From 4f2f40ca3cc55d45d090af8cfdbdc6d694750f71 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <frank@sauerburger.com>
Date: Sat, 27 Aug 2022 23:49:15 +0200
Subject: [PATCH] Build prediction model

---
 movies.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/movies.py b/movies.py
index a78639e..16f2480 100644
--- a/movies.py
+++ b/movies.py
@@ -281,6 +281,15 @@ def eval_model(model, test_ds):
     pprint(result)
 
 
+def build_prediction_model(model, movies):
+    """Build an index to get movie suggestions"""
+    index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)
+    index.index_from_dataset(tf.data.Dataset.zip(
+        (movies.batch(100), movies.batch(100).map(model.cand_model))
+    ))
+    return index
+
+
 def train(args):
     """Load the dataset, train and evaluate the model"""
     train_ds, test_ds, _, movies = load_dataset(args.input)
@@ -291,6 +300,7 @@ def train(args):
         eval_model(model, test_ds)
 
 
+
 commands = {
     "prepare": prepare_dataset,
     "train": train,
-- 
GitLab