From 24cfd48979c29a5f85d7bb20aca3f9879ff4fd30 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <frank@sauerburger.com>
Date: Sat, 27 Aug 2022 13:20:35 +0200
Subject: [PATCH] Add import preparation CLI

---
 movies.py      | 84 +++++++++++++++++++++++++++++++++++++++++++++++++-
 movies_test.py |  8 ++---
 2 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/movies.py b/movies.py
index 636c7f0..da83814 100644
--- a/movies.py
+++ b/movies.py
@@ -3,7 +3,10 @@
 Movie recommender system as a showcase project
 """
 
+import argparse
 import collections
+import random
+import os
 
 import pandas as pd
 import tensorflow as tf
@@ -31,7 +34,7 @@ def collect_user_histories(ratings, min_rating=2.1):
     return user_histories
 
 
-def create_user_examples(user_history, min_len=3, max_len=100):
+def create_single_user_examples(user_history, min_len=3, max_len=100):
     """Create examples by sliding a 3-100 window over the history"""
     examples = []
     for label_idx in range(min_len, len(user_history)):
@@ -54,3 +57,82 @@ def create_user_examples(user_history, min_len=3, max_len=100):
         examples.append(example)
 
     return examples
+
+
+def create_user_examples(user_histories, frac=0.8, random_seed=20220827, **kwds):
+    """Create user examples for each user history and return train,test lists"""
+    examples = []
+    progress_bar = tf.keras.utils.Progbar(len(user_histories))
+    for user_history in user_histories.values():
+        single_user_examples = create_single_user_examples(
+            user_history, **kwds
+        )
+        examples.extend(single_user_examples)
+        progress_bar.add(1)
+
+    random.seed(random_seed)
+    random.shuffle(examples)
+
+    last_train_index = int(len(examples) * frac)
+    train_examples = examples[:last_train_index]
+    test_examples = examples[last_train_index:]
+    return train_examples, test_examples
+
+
+def write_user_examples(examples, filename):
+    """Write examples to rt Record"""
+    progress_bar = tf.keras.utils.Progbar(len(examples))
+    with tf.io.TFRecordWriter(filename) as record_writer:
+        for example in examples:
+            record_writer.write(example.SerializeToString())
+            progress_bar.add(1)
+
+
+def prepare_dataset(args):
+    """Read raw CSV files and write TF Record"""
+    ratings_df, _ = load_movielense()
+    if args.debug:
+        ratings_df = ratings_df[:10000]
+
+    user_histories = collect_user_histories(ratings_df)
+    train_examples, test_examples = create_user_examples(user_histories)
+
+    train_file = os.path.join(args.output, "train.tfrecord")
+    write_user_examples(train_examples, train_file)
+    n_train = len(train_examples)
+    print(f"File {train_file} with {n_train:d} records created.")
+
+    test_file = os.path.join(args.output, "test.tfrecord")
+    write_user_examples(test_examples, test_file)
+    n_test = len(test_examples)
+    print(f"File {test_file} with {n_test:d} records created.")
+
+
+commands = {
+    "prepare": prepare_dataset,
+}
+
+def get_default_parser():
+    """Return the default command-line args parser"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("command", metavar="CMD", choices=commands,
+                        help="Operation to execute")
+    parser.add_argument("-i", "--input", metavar="PATH",
+                        help="Path to input file")
+    parser.add_argument("-o", "--output", metavar="PATH",
+                        help="Path to output file(s)")
+    parser.add_argument("--debug", default=False, action="store_true",
+                        help="Limit test data size")
+
+    return parser
+
+
+def cli_main(args):
+    """Handle CLI args and call requested command"""
+    commands[args.command](args)
+
+
+if __name__ == "__main__":
+    parser_ = get_default_parser()
+    args_ = parser_.parse_args()
+    cli_main(args_)
diff --git a/movies_test.py b/movies_test.py
index 3ec5aac..5e29e6f 100644
--- a/movies_test.py
+++ b/movies_test.py
@@ -53,7 +53,7 @@ class LoadTests(unittest.TestCase):
         """Test the window sliding with default args"""
         history = [1, 2, 3, 4, 5, 6]
 
-        examples = movies.create_user_examples(history)
+        examples = movies.create_single_user_examples(history)
 
         # First
         self.assert_feature(examples[0], "context_movie_id", [1, 2, 3])
@@ -74,7 +74,7 @@ class LoadTests(unittest.TestCase):
         """Test the window sliding with constom args"""
         history = [1, 2, 3, 4, 5]
 
-        examples = movies.create_user_examples(history, min_len=2, max_len=3)
+        examples = movies.create_single_user_examples(history, min_len=2, max_len=3)
 
         # First
         self.assert_feature(examples[0], "context_movie_id", [1, 2])
@@ -94,9 +94,9 @@ class LoadTests(unittest.TestCase):
     def test_user_examples_empty(self):
         """Test the window sliding with constom args"""
         history = [1, 2, 3]
-        examples = movies.create_user_examples(history)
+        examples = movies.create_single_user_examples(history)
         self.assertEqual(len(examples), 0)
 
         history = []
-        examples = movies.create_user_examples(history)
+        examples = movies.create_single_user_examples(history)
         self.assertEqual(len(examples), 0)
-- 
GitLab