From 4f2f40ca3cc55d45d090af8cfdbdc6d694750f71 Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <frank@sauerburger.com> Date: Sat, 27 Aug 2022 23:49:15 +0200 Subject: [PATCH] Build prediction model --- movies.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/movies.py b/movies.py index a78639e..16f2480 100644 --- a/movies.py +++ b/movies.py @@ -281,6 +281,15 @@ def eval_model(model, test_ds): pprint(result) +def build_prediction_model(model, movies): + """Build an index to get movie suggestions""" + index = tfrs.layers.factorized_top_k.BruteForce(model.query_model) + index.index_from_dataset(tf.data.Dataset.zip( + (movies.batch(100), movies.batch(100).map(model.cand_model)) + )) + return index + + def train(args): """Load the dataset, train and evaluate the model""" train_ds, test_ds, _, movies = load_dataset(args.input) @@ -291,6 +300,7 @@ def train(args): eval_model(model, test_ds) + commands = { "prepare": prepare_dataset, "train": train, -- GitLab