diff --git a/movies.py b/movies.py index a78639e84c9440a9e7269a88ff1af8a94f20f9e3..16f2480f4008e1d514f99757f87f8b7ae527fba3 100644 --- a/movies.py +++ b/movies.py @@ -281,6 +281,15 @@ def eval_model(model, test_ds): pprint(result) +def build_prediction_model(model, movies): + """Build an index to get movie suggestions""" + index = tfrs.layers.factorized_top_k.BruteForce(model.query_model) + index.index_from_dataset(tf.data.Dataset.zip( + (movies.batch(100), movies.batch(100).map(model.cand_model)) + )) + return index + + def train(args): """Load the dataset, train and evaluate the model""" train_ds, test_ds, _, movies = load_dataset(args.input) @@ -291,6 +300,7 @@ def train(args): eval_model(model, test_ds) + commands = { "prepare": prepare_dataset, "train": train,