Skip to content
Snippets Groups Projects
Verified Commit 35d90390 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Create tf.train.Examples

parent 5b83a8c1
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,7 @@ Movie recommender system as a showcase project
import collections
import pandas as pd
# import tensorflow as tf
import tensorflow as tf
def load_movielense():
"""Load an return movies and ratings as dataframes"""
......@@ -17,14 +17,40 @@ def load_movielense():
return ratings_df, movies_df
def collect_user_context(ratings, min_rating=2.1):
def collect_user_histories(ratings, min_rating=2.1):
"""Create a per-user rating list"""
user_movies = collections.defaultdict(lambda: []) # dict mapping ids to movies
user_histories = collections.defaultdict(lambda: []) # dict mapping ids to movies
ratings = ratings.sort_values(by=['userId', 'timestamp'])
for user_id, movie_id, rating, _ in ratings.values:
if rating >= min_rating:
user_movies[user_id].append(movie_id)
return user_movies
user_histories[user_id].append(movie_id)
return user_histories
def create_user_examples(user_history, min_len=3, max_len=100):
"""Create examples by sliding a 3-100 window over the history"""
examples = []
for label_idx in range(min_len, len(user_history)):
start_idx = max(0, label_idx - max_len)
context_movie_ids = user_history[start_idx:label_idx]
if len(context_movie_ids) < min_len:
continue
movie_id = user_history[label_idx]
feature = {
"context_movie_id": tf.train.Feature(
int64_list=tf.train.Int64List(value=context_movie_ids)
),
"label_movie_id": tf.train.Feature(
int64_list=tf.train.Int64List(value=[movie_id])
)
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
examples.append(example)
return examples
......@@ -21,19 +21,82 @@ class LoadTests(unittest.TestCase):
},
columns=["userId", "movieId", "rating", "timestamp"])
def test_collect_user_context(self):
def test_collect_user_histories(self):
"""Check that ratings are correctly aggregated"""
rating = self.toy_ratings()
user_movies = movies.collect_user_context(rating)
user_movies = movies.collect_user_histories(rating)
self.assertEqual(user_movies[1], [1, 4])
self.assertEqual(user_movies[2], [3, 1])
self.assertEqual(set(user_movies.keys()), {1, 2})
def test_collect_user_context_min_rating(self):
def test_collect_user_histories_min_rating(self):
"""Check the min rating threshold"""
rating = self.toy_ratings()
user_movies = movies.collect_user_context(rating, min_rating=5)
user_movies = movies.collect_user_histories(rating, min_rating=5)
self.assertEqual(user_movies[1], [4])
self.assertEqual(set(user_movies.keys()), {1})
def assert_feature(self, example, feature_name, value):
"""Assert the int64_list value of a tf.train.Example feature"""
self.assertEqual(
example.features.feature[feature_name].int64_list.value,
value
)
def test_user_examples(self):
"""Test the window sliding with default args"""
history = [1, 2, 3, 4, 5, 6]
examples = movies.create_user_examples(history)
# First
self.assert_feature(examples[0], "context_movie_id", [1, 2, 3])
self.assert_feature(examples[0], "label_movie_id", [4])
# Second
self.assert_feature(examples[1], "context_movie_id", [1, 2, 3, 4])
self.assert_feature(examples[1], "label_movie_id", [5])
# Third
self.assert_feature(examples[2], "context_movie_id", [1, 2, 3, 4, 5])
self.assert_feature(examples[2], "label_movie_id", [6])
self.assertEqual(len(examples), 3)
def test_user_examples_args(self):
"""Test the window sliding with constom args"""
history = [1, 2, 3, 4, 5]
examples = movies.create_user_examples(history, min_len=2, max_len=3)
# First
self.assert_feature(examples[0], "context_movie_id", [1, 2])
self.assert_feature(examples[0], "label_movie_id", [3])
# Second
self.assert_feature(examples[1], "context_movie_id", [1, 2, 3])
self.assert_feature(examples[1], "label_movie_id", [4])
# Third
self.assert_feature(examples[2], "context_movie_id", [2, 3, 4])
self.assert_feature(examples[2], "label_movie_id", [5])
self.assertEqual(len(examples), 3)
def test_user_examples_empty(self):
"""Test the window sliding with constom args"""
history = [1, 2, 3]
examples = movies.create_user_examples(history)
self.assertEqual(len(examples), 0)
history = []
examples = movies.create_user_examples(history)
self.assertEqual(len(examples), 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment