Skip to content
Snippets Groups Projects
Verified Commit 4f2f40ca authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Build prediction model

parent 02db354b
No related branches found
No related tags found
No related merge requests found
......@@ -281,6 +281,15 @@ def eval_model(model, test_ds):
pprint(result)
def build_prediction_model(model, movies):
"""Build an index to get movie suggestions"""
index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)
index.index_from_dataset(tf.data.Dataset.zip(
(movies.batch(100), movies.batch(100).map(model.cand_model))
))
return index
def train(args):
"""Load the dataset, train and evaluate the model"""
train_ds, test_ds, _, movies = load_dataset(args.input)
......@@ -291,6 +300,7 @@ def train(args):
eval_model(model, test_ds)
commands = {
"prepare": prepare_dataset,
"train": train,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment