From 3a995bbec9d6ccdfe6a5250de3bb44eb20cdb6b8 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <frank@sauerburger.com>
Date: Sat, 27 Aug 2022 13:46:40 +0200
Subject: [PATCH] Dump movie indices and titles

---
 movies.py      | 36 ++++++++++++++++++++++++++++++++++--
 movies_test.py | 43 ++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 76 insertions(+), 3 deletions(-)

diff --git a/movies.py b/movies.py
index da83814..3574b99 100644
--- a/movies.py
+++ b/movies.py
@@ -5,8 +5,10 @@ Movie recommender system as a showcase project
 
 import argparse
 import collections
-import random
+import json
 import os
+import random
+import re
 import pandas as pd
 import tensorflow as tf
 
@@ -79,6 +81,23 @@ def create_user_examples(user_histories, frac=0.8, random_seed=20220827, **kwds)
     return train_examples, test_examples
 
 
+def parse_title_year(movie_name):
+    """Return (title, year) from full title"""
+    pattern = r"^\s*(.+\S)\s+\((\d+)\)\s*$"
+    rem = re.match(pattern, movie_name)
+    if rem:
+        return rem.group(1), int(rem.group(2))
+    return movie_name, None
+
+
+def index_titles(movies_df):
+    """Create dict index of titles"""
+    titles = {}
+    for movie_id, title, _ in movies_df.values:
+        titles[movie_id] = parse_title_year(title)
+    return titles
+
+
 def write_user_examples(examples, filename):
     """Write examples to rt Record"""
     progress_bar = tf.keras.utils.Progbar(len(examples))
@@ -88,9 +107,15 @@ def write_user_examples(examples, filename):
             progress_bar.add(1)
 
 
+def write_movie_titles(titles, filename):
+    """Write movie titles to json file"""
+    with open(filename, "w", encoding="utf-8") as fileobj:
+        json.dump(titles, fileobj)
+
+
 def prepare_dataset(args):
     """Read raw CSV files and write TF Record"""
-    ratings_df, _ = load_movielense()
+    ratings_df, movies_df = load_movielense()
     if args.debug:
         ratings_df = ratings_df[:10000]
 
@@ -108,6 +133,13 @@ def prepare_dataset(args):
     print(f"File {test_file} with {n_test:d} records created.")
 
 
+    movie_titles = index_titles(movies_df)
+    titles_file = os.path.join(args.output, "titles.json")
+    write_movie_titles(movie_titles, titles_file)
+    n_titles = len(titles_file)
+    print(f"File {titles_file} with {n_titles:d} titles created.")
+
+
 commands = {
     "prepare": prepare_dataset,
 }
diff --git a/movies_test.py b/movies_test.py
index 5e29e6f..90315e7 100644
--- a/movies_test.py
+++ b/movies_test.py
@@ -92,7 +92,7 @@ class LoadTests(unittest.TestCase):
 
 
     def test_user_examples_empty(self):
-        """Test the window sliding with constom args"""
+        """Test the window sliding with custom args"""
         history = [1, 2, 3]
         examples = movies.create_single_user_examples(history)
         self.assertEqual(len(examples), 0)
@@ -100,3 +100,44 @@ class LoadTests(unittest.TestCase):
         history = []
         examples = movies.create_single_user_examples(history)
         self.assertEqual(len(examples), 0)
+
+
+    def test_parse_title_year(self):
+        """Test year parsing"""
+        title, year = movies.parse_title_year("Hello (2022)")
+        self.assertEqual(title, "Hello")
+        self.assertEqual(year, 2022)
+
+
+    def test_parse_title_year_whitespace(self):
+        """Test year parsing with white space"""
+        title, year = movies.parse_title_year("  Hello   (2022) ")
+        self.assertEqual(title, "Hello")
+        self.assertEqual(year, 2022)
+
+
+    def test_parse_title_year_no_year(self):
+        """Test year parsing without year"""
+        title, year = movies.parse_title_year("No time for a year")
+        self.assertEqual(title, "No time for a year")
+        self.assertIsNone(year)
+
+
+    @staticmethod
+    def toy_movies():
+        """Return toy movie dataframe"""
+        return pd.DataFrame({
+            "movieId": [1, 2],
+            "title": ["Hello (2020)", "Hello again (2021)"],
+            "genres": ["Thriller", "Romcom"]
+        },
+        columns=["movieId", "title", "generes"])
+
+
+    def test_index_titles(self):
+        """Check indexing movie titles"""
+        movies_df = self.toy_movies()
+        titles = movies.index_titles(movies_df)
+
+        self.assertEqual(titles[1], ("Hello", 2020))
+        self.assertEqual(titles[2], ("Hello again", 2021))
-- 
GitLab