Skip to content
Snippets Groups Projects
Verified Commit 80b23288 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Make training parameters configurable

parent 6b93d906
No related branches found
No related tags found
1 merge request!1Resolve "Track training with mlflow"
Pipeline #9628 passed
......@@ -357,8 +357,10 @@ def train(args):
train_ds, test_ds, _, movies = load_dataset(args.input)
with tf.device("cpu:0"):
model = build_model(movies)
fit_model(model, train_ds)
model = build_model(movies,
embedding_dimension=args.embedding_dimension,
learning_rate=args.learning_rate)
fit_model(model, train_ds, epochs=args.epochs)
if not args.debug:
eval_model(model, test_ds)
......@@ -384,6 +386,12 @@ def get_default_parser():
help="Path to output file(s)")
parser.add_argument("--debug", default=False, action="store_true",
help="Limit test data size")
parser.add_argument("-e", "--epochs", default=3, type=int,
help="Number of training epochs")
parser.add_argument("-l", "--learning-rate", default=0.1, type=float,
help="Learning rate")
parser.add_argument("-d", "--embedding-dimension", default=32, type=int,
help="Dimension of embedding space")
return parser
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment