From 35d90390e7ee120b737c1e6f6cbcee4c7054c475 Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <frank@sauerburger.com> Date: Sat, 27 Aug 2022 11:50:41 +0200 Subject: [PATCH] Create tf.train.Examples --- movies.py | 38 ++++++++++++++++++++++----- movies_test.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 10 deletions(-) diff --git a/movies.py b/movies.py index fa2218e..636c7f0 100644 --- a/movies.py +++ b/movies.py @@ -5,7 +5,7 @@ Movie recommender system as a showcase project import collections import pandas as pd -# import tensorflow as tf +import tensorflow as tf def load_movielense(): """Load an return movies and ratings as dataframes""" @@ -17,14 +17,40 @@ def load_movielense(): return ratings_df, movies_df -def collect_user_context(ratings, min_rating=2.1): + +def collect_user_histories(ratings, min_rating=2.1): """Create a per-user rating list""" - user_movies = collections.defaultdict(lambda: []) # dict mapping ids to movies + user_histories = collections.defaultdict(lambda: []) # dict mapping ids to movies ratings = ratings.sort_values(by=['userId', 'timestamp']) for user_id, movie_id, rating, _ in ratings.values: if rating >= min_rating: - user_movies[user_id].append(movie_id) - - return user_movies + user_histories[user_id].append(movie_id) + + return user_histories + + +def create_user_examples(user_history, min_len=3, max_len=100): + """Create examples by sliding a 3-100 window over the history""" + examples = [] + for label_idx in range(min_len, len(user_history)): + start_idx = max(0, label_idx - max_len) + context_movie_ids = user_history[start_idx:label_idx] + if len(context_movie_ids) < min_len: + continue + + movie_id = user_history[label_idx] + + feature = { + "context_movie_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=context_movie_ids) + ), + "label_movie_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[movie_id]) + ) + } + example = 
tf.train.Example(features=tf.train.Features(feature=feature)) + examples.append(example) + + return examples diff --git a/movies_test.py b/movies_test.py index 8211f15..3ec5aac 100644 --- a/movies_test.py +++ b/movies_test.py @@ -21,19 +21,82 @@ class LoadTests(unittest.TestCase): }, columns=["userId", "movieId", "rating", "timestamp"]) - def test_collect_user_context(self): + + def test_collect_user_histories(self): """Check that ratings are correctly aggregated""" rating = self.toy_ratings() - user_movies = movies.collect_user_context(rating) + user_movies = movies.collect_user_histories(rating) self.assertEqual(user_movies[1], [1, 4]) self.assertEqual(user_movies[2], [3, 1]) self.assertEqual(set(user_movies.keys()), {1, 2}) - def test_collect_user_context_min_rating(self): + + def test_collect_user_histories_min_rating(self): """Check the min rating threshold""" rating = self.toy_ratings() - user_movies = movies.collect_user_context(rating, min_rating=5) + user_movies = movies.collect_user_histories(rating, min_rating=5) self.assertEqual(user_movies[1], [4]) self.assertEqual(set(user_movies.keys()), {1}) + + + def assert_feature(self, example, feature_name, value): + """Assert the int64_list value of a tf.train.Example feature""" + self.assertEqual( + example.features.feature[feature_name].int64_list.value, + value + ) + + + def test_user_examples(self): + """Test the window sliding with default args""" + history = [1, 2, 3, 4, 5, 6] + + examples = movies.create_user_examples(history) + + # First + self.assert_feature(examples[0], "context_movie_id", [1, 2, 3]) + self.assert_feature(examples[0], "label_movie_id", [4]) + + # Second + self.assert_feature(examples[1], "context_movie_id", [1, 2, 3, 4]) + self.assert_feature(examples[1], "label_movie_id", [5]) + + # Third + self.assert_feature(examples[2], "context_movie_id", [1, 2, 3, 4, 5]) + self.assert_feature(examples[2], "label_movie_id", [6]) + + self.assertEqual(len(examples), 3) + + + def 
test_user_examples_args(self): + """Test the window sliding with custom args""" + history = [1, 2, 3, 4, 5] + + examples = movies.create_user_examples(history, min_len=2, max_len=3) + + # First + self.assert_feature(examples[0], "context_movie_id", [1, 2]) + self.assert_feature(examples[0], "label_movie_id", [3]) + + # Second + self.assert_feature(examples[1], "context_movie_id", [1, 2, 3]) + self.assert_feature(examples[1], "label_movie_id", [4]) + + # Third + self.assert_feature(examples[2], "context_movie_id", [2, 3, 4]) + self.assert_feature(examples[2], "label_movie_id", [5]) + + self.assertEqual(len(examples), 3) + + + def test_user_examples_empty(self): + """Test that too-short and empty histories yield no examples""" + history = [1, 2, 3] + examples = movies.create_user_examples(history) + self.assertEqual(len(examples), 0) + + history = [] + examples = movies.create_user_examples(history) + self.assertEqual(len(examples), 0) -- GitLab