diff --git a/movies.py b/movies.py index f212dfc79f49edf2aaecfafcfaaf962c240a5e48..c218c7dc0988ae9e2b3a80e302b20cc426f9839a 100644 --- a/movies.py +++ b/movies.py @@ -357,8 +357,10 @@ def train(args): train_ds, test_ds, _, movies = load_dataset(args.input) with tf.device("cpu:0"): - model = build_model(movies) - fit_model(model, train_ds) + model = build_model(movies, + embedding_dimension=args.embedding_dimension, + learning_rate=args.learning_rate) + fit_model(model, train_ds, epochs=args.epochs) if not args.debug: eval_model(model, test_ds) @@ -384,6 +386,12 @@ def get_default_parser(): help="Path to output file(s)") parser.add_argument("--debug", default=False, action="store_true", help="Limit test data size") + parser.add_argument("-e", "--epochs", default=3, type=int, + help="Number of training epochs") + parser.add_argument("-l", "--learning-rate", default=0.1, type=float, + help="Learning rate") + parser.add_argument("-d", "--embedding-dimension", default=32, type=int, + help="Dimension of embedding space") return parser