From 24cfd48979c29a5f85d7bb20aca3f9879ff4fd30 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <frank@sauerburger.com>
Date: Sat, 27 Aug 2022 13:20:35 +0200
Subject: [PATCH] Add import preparation CLI

---
 movies.py      | 84 +++++++++++++++++++++++++++++++++++++++++++++++++-
 movies_test.py |  8 ++---
 2 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/movies.py b/movies.py
index 636c7f0..da83814 100644
--- a/movies.py
+++ b/movies.py
@@ -3,7 +3,10 @@
 Movie recommender system as a showcase project
 """
 
+import argparse
 import collections
+import random
+import os
 import pandas as pd
 import tensorflow as tf
 
@@ -31,7 +34,7 @@ def collect_user_histories(ratings, min_rating=2.1):
     return user_histories
 
 
-def create_user_examples(user_history, min_len=3, max_len=100):
+def create_single_user_examples(user_history, min_len=3, max_len=100):
     """Create examples by sliding a 3-100 window over the history"""
     examples = []
     for label_idx in range(min_len, len(user_history)):
@@ -54,3 +57,82 @@ def create_user_examples(user_history, min_len=3, max_len=100):
         examples.append(example)
 
     return examples
+
+
+def create_user_examples(user_histories, frac=0.8, random_seed=20220827, **kwds):
+    """Create user examples for each user history and return train,test lists"""
+    examples = []
+    progress_bar = tf.keras.utils.Progbar(len(user_histories))
+    for user_history in user_histories.values():
+        single_user_examples = create_single_user_examples(
+            user_history, **kwds
+        )
+        examples.extend(single_user_examples)
+        progress_bar.add(1)
+
+    random.seed(random_seed)
+    random.shuffle(examples)
+
+    last_train_index = int(len(examples) * frac)
+    train_examples = examples[:last_train_index]
+    test_examples = examples[last_train_index:]
+    return train_examples, test_examples
+
+
+def write_user_examples(examples, filename):
+    """Write examples to rt Record"""
+    progress_bar = tf.keras.utils.Progbar(len(examples))
+    with tf.io.TFRecordWriter(filename) as record_writer:
+        for example in examples:
+            record_writer.write(example.SerializeToString())
+            progress_bar.add(1)
+
+
+def prepare_dataset(args):
+    """Read raw CSV files and write TF Record"""
+    ratings_df, _ = load_movielense()
+    if args.debug:
+        ratings_df = ratings_df[:10000]
+
+    user_histories = collect_user_histories(ratings_df)
+    train_examples, test_examples = create_user_examples(user_histories)
+
+    train_file = os.path.join(args.output, "train.tfrecord")
+    write_user_examples(train_examples, train_file)
+    n_train = len(train_examples)
+    print(f"File {train_file} with {n_train:d} records created.")
+
+    test_file = os.path.join(args.output, "test.tfrecord")
+    write_user_examples(test_examples, test_file)
+    n_test = len(test_examples)
+    print(f"File {test_file} with {n_test:d} records created.")
+
+
+commands = {
+    "prepare": prepare_dataset,
+}
+
+def get_default_parser():
+    """Return the default command-line args parser"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("command", metavar="CMD", choices=commands,
+                        help="Operation to execute")
+    parser.add_argument("-i", "--input", metavar="PATH",
+                        help="Path to input file")
+    parser.add_argument("-o", "--output", metavar="PATH",
+                        help="Path to output file(s)")
+    parser.add_argument("--debug", default=False, action="store_true",
+                        help="Limit test data size")
+
+    return parser
+
+
+def cli_main(args):
+    """Handle CLI args and call requested command"""
+    commands[args.command](args)
+
+
+if __name__ == "__main__":
+    parser_ = get_default_parser()
+    args_ = parser_.parse_args()
+    cli_main(args_)
diff --git a/movies_test.py b/movies_test.py
index 3ec5aac..5e29e6f 100644
--- a/movies_test.py
+++ b/movies_test.py
@@ -53,7 +53,7 @@ class LoadTests(unittest.TestCase):
         """Test the window sliding with default args"""
         history = [1, 2, 3, 4, 5, 6]
 
-        examples = movies.create_user_examples(history)
+        examples = movies.create_single_user_examples(history)
 
         # First
         self.assert_feature(examples[0], "context_movie_id", [1, 2, 3])
@@ -74,7 +74,7 @@ class LoadTests(unittest.TestCase):
         """Test the window sliding with constom args"""
         history = [1, 2, 3, 4, 5]
 
-        examples = movies.create_user_examples(history, min_len=2, max_len=3)
+        examples = movies.create_single_user_examples(history, min_len=2, max_len=3)
 
         # First
         self.assert_feature(examples[0], "context_movie_id", [1, 2])
@@ -94,9 +94,9 @@ class LoadTests(unittest.TestCase):
     def test_user_examples_empty(self):
         """Test the window sliding with constom args"""
         history = [1, 2, 3]
-        examples = movies.create_user_examples(history)
+        examples = movies.create_single_user_examples(history)
         self.assertEqual(len(examples), 0)
 
         history = []
-        examples = movies.create_user_examples(history)
+        examples = movies.create_single_user_examples(history)
         self.assertEqual(len(examples), 0)
-- 
GitLab