From 80b23288e95f3fa65a1122db54ef1164a46547f5 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <frank@sauerburger.com>
Date: Sun, 28 Aug 2022 13:02:00 +0200
Subject: [PATCH] Make training parameters configurable

---
 movies.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/movies.py b/movies.py
index f212dfc..c218c7d 100644
--- a/movies.py
+++ b/movies.py
@@ -357,8 +357,10 @@ def train(args):
     train_ds, test_ds, _, movies = load_dataset(args.input)
 
     with tf.device("cpu:0"):
-        model = build_model(movies)
-        fit_model(model, train_ds)
+        model = build_model(movies,
+                            embedding_dimension=args.embedding_dimension,
+                            learning_rate=args.learning_rate)
+        fit_model(model, train_ds, epochs=args.epochs)
 
         if not args.debug:
             eval_model(model, test_ds)
@@ -384,6 +386,12 @@ def get_default_parser():
                         help="Path to output file(s)")
     parser.add_argument("--debug", default=False, action="store_true",
                         help="Limit test data size")
+    parser.add_argument("-e", "--epochs", default=3, type=int,
+                        help="Number of training epochs")
+    parser.add_argument("-l", "--learning-rate", default=0.1, type=float,
+                        help="Learning rate")
+    parser.add_argument("-d", "--embedding-dimension", default=32, type=int,
+                        help="Dimension of embedding space")
 
     return parser
 
-- 
GitLab