diff --git a/movies.py b/movies.py
index f212dfc79f49edf2aaecfafcfaaf962c240a5e48..c218c7dc0988ae9e2b3a80e302b20cc426f9839a 100644
--- a/movies.py
+++ b/movies.py
@@ -357,8 +357,10 @@ def train(args):
     train_ds, test_ds, _, movies = load_dataset(args.input)
 
     with tf.device("cpu:0"):
-        model = build_model(movies)
-        fit_model(model, train_ds)
+        model = build_model(movies,
+                            embedding_dimension=args.embedding_dimension,
+                            learning_rate=args.learning_rate)
+        fit_model(model, train_ds, epochs=args.epochs)
 
         if not args.debug:
             eval_model(model, test_ds)
@@ -384,6 +386,12 @@ def get_default_parser():
                         help="Path to output file(s)")
     parser.add_argument("--debug", default=False, action="store_true",
                         help="Limit test data size")
+    parser.add_argument("-e", "--epochs", default=3, type=int,
+                        help="Number of training epochs")
+    parser.add_argument("-l", "--learning-rate", default=0.1, type=float,
+                        help="Learning rate")
+    parser.add_argument("-d", "--embedding-dimension", default=32, type=int,
+                        help="Dimension of embedding space")
 
     return parser