From 35d90390e7ee120b737c1e6f6cbcee4c7054c475 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <frank@sauerburger.com>
Date: Sat, 27 Aug 2022 11:50:41 +0200
Subject: [PATCH] Create tf.train.Examples

---
 movies.py      | 38 ++++++++++++++++++++++-----
 movies_test.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/movies.py b/movies.py
index fa2218e..636c7f0 100644
--- a/movies.py
+++ b/movies.py
@@ -5,7 +5,7 @@ Movie recommender system as a showcase project
 
 import collections
 import pandas as pd
-# import tensorflow as tf
+import tensorflow as tf
 
 def load_movielense():
     """Load an return movies and ratings as dataframes"""
@@ -17,14 +17,40 @@ def load_movielense():
 
     return ratings_df, movies_df
 
-def collect_user_context(ratings, min_rating=2.1):
+
+def collect_user_histories(ratings, min_rating=2.1):
     """Create a per-user rating list"""
-    user_movies = collections.defaultdict(lambda: [])  # dict mapping ids to movies
+    user_histories = collections.defaultdict(lambda: [])  # dict mapping ids to movies
 
     ratings = ratings.sort_values(by=['userId', 'timestamp'])
 
     for user_id, movie_id, rating, _ in ratings.values:
         if rating >= min_rating:
-            user_movies[user_id].append(movie_id)
-
-    return user_movies
+            user_histories[user_id].append(movie_id)
+
+    return user_histories
+
+
+def create_user_examples(user_history, min_len=3, max_len=100):
+    """Create examples by sliding a min_len..max_len window over the history"""
+    examples = []
+    for label_idx in range(min_len, len(user_history)):
+        start_idx = max(0, label_idx - max_len)
+        context_movie_ids = user_history[start_idx:label_idx]
+        if len(context_movie_ids) < min_len:
+            continue
+
+        movie_id = user_history[label_idx]
+
+        feature = {
+            "context_movie_id": tf.train.Feature(
+                int64_list=tf.train.Int64List(value=context_movie_ids)
+             ),
+             "label_movie_id": tf.train.Feature(
+                int64_list=tf.train.Int64List(value=[movie_id])
+             )
+        }
+        example = tf.train.Example(features=tf.train.Features(feature=feature))
+        examples.append(example)
+
+    return examples
diff --git a/movies_test.py b/movies_test.py
index 8211f15..3ec5aac 100644
--- a/movies_test.py
+++ b/movies_test.py
@@ -21,19 +21,82 @@ class LoadTests(unittest.TestCase):
         },
         columns=["userId", "movieId", "rating", "timestamp"])
 
-    def test_collect_user_context(self):
+
+    def test_collect_user_histories(self):
         """Check that ratings are correctly aggregated"""
         rating = self.toy_ratings()
-        user_movies = movies.collect_user_context(rating)
+        user_movies = movies.collect_user_histories(rating)
 
         self.assertEqual(user_movies[1], [1, 4])
         self.assertEqual(user_movies[2], [3, 1])
         self.assertEqual(set(user_movies.keys()), {1, 2})
 
-    def test_collect_user_context_min_rating(self):
+
+    def test_collect_user_histories_min_rating(self):
         """Check the min rating threshold"""
         rating = self.toy_ratings()
-        user_movies = movies.collect_user_context(rating, min_rating=5)
+        user_movies = movies.collect_user_histories(rating, min_rating=5)
 
         self.assertEqual(user_movies[1], [4])
         self.assertEqual(set(user_movies.keys()), {1})
+
+
+    def assert_feature(self, example, feature_name, value):
+        """Assert the int64_list value of a tf.train.Example feature"""
+        self.assertEqual(
+            example.features.feature[feature_name].int64_list.value,
+            value
+        )
+
+
+    def test_user_examples(self):
+        """Test the window sliding with default args"""
+        history = [1, 2, 3, 4, 5, 6]
+
+        examples = movies.create_user_examples(history)
+
+        # First
+        self.assert_feature(examples[0], "context_movie_id", [1, 2, 3])
+        self.assert_feature(examples[0], "label_movie_id", [4])
+
+        # Second
+        self.assert_feature(examples[1], "context_movie_id", [1, 2, 3, 4])
+        self.assert_feature(examples[1], "label_movie_id", [5])
+
+        # Third
+        self.assert_feature(examples[2], "context_movie_id", [1, 2, 3, 4, 5])
+        self.assert_feature(examples[2], "label_movie_id", [6])
+
+        self.assertEqual(len(examples), 3)
+
+
+    def test_user_examples_args(self):
+        """Test the window sliding with custom args"""
+        history = [1, 2, 3, 4, 5]
+
+        examples = movies.create_user_examples(history, min_len=2, max_len=3)
+
+        # First
+        self.assert_feature(examples[0], "context_movie_id", [1, 2])
+        self.assert_feature(examples[0], "label_movie_id", [3])
+
+        # Second
+        self.assert_feature(examples[1], "context_movie_id", [1, 2, 3])
+        self.assert_feature(examples[1], "label_movie_id", [4])
+
+        # Third
+        self.assert_feature(examples[2], "context_movie_id", [2, 3, 4])
+        self.assert_feature(examples[2], "label_movie_id", [5])
+
+        self.assertEqual(len(examples), 3)
+
+
+    def test_user_examples_empty(self):
+        """Test that too-short or empty histories yield no examples"""
+        history = [1, 2, 3]
+        examples = movies.create_user_examples(history)
+        self.assertEqual(len(examples), 0)
+
+        history = []
+        examples = movies.create_user_examples(history)
+        self.assertEqual(len(examples), 0)
-- 
GitLab