diff --git a/tfx/components/testdata/module_file/trainer_module.py b/tfx/components/testdata/module_file/trainer_module.py
index bf46404c88..456feb30ce 100644
--- a/tfx/components/testdata/module_file/trainer_module.py
+++ b/tfx/components/testdata/module_file/trainer_module.py
@@ -13,33 +13,29 @@
 # limitations under the License.
 """Python source file include taxi pipeline functions and necesasry utils.

-For a TFX pipeline to successfully run, a preprocessing_fn and a
-_build_estimator function needs to be provided. This file contains both.
-
-This file is equivalent to examples/chicago_taxi/trainer/model.py and
-examples/chicago_taxi/preprocess.py.
+The utilities in this file are used to build a model with native Keras.
+This module file will be used in Transform and generic Trainer.
 """

-import absl
+from typing import Optional
+
+from absl import logging
 import tensorflow as tf
-from tensorflow import estimator as tf_estimator
-import tensorflow_model_analysis as tfma
 import tensorflow_transform as tft
-from tensorflow_transform.tf_metadata import schema_utils
-from tfx.components.trainer import executor
-from tfx.utils import io_utils
-from tfx.utils import path_utils
-from tfx_bsl.public.tfxio import TensorFlowDatasetOptions
-from tensorflow_metadata.proto.v0 import schema_pb2
-
+from tfx.components.trainer import fn_args_utils
+from tfx_bsl.tfxio import dataset_options

 # Categorical features are assumed to each have a maximum value in the dataset.
 _MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12]

 _CATEGORICAL_FEATURE_KEYS = [
-    'trip_start_hour', 'trip_start_day', 'trip_start_month',
-    'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area',
-    'dropoff_community_area'
+    'trip_start_hour',
+    'trip_start_day',
+    'trip_start_month',
+    'pickup_census_tract',
+    'dropoff_census_tract',
+    'pickup_community_area',
+    'dropoff_community_area',
 ]

 _DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds']
@@ -48,8 +44,10 @@
 _FEATURE_BUCKET_COUNT = 10

 _BUCKET_FEATURE_KEYS = [
-    'pickup_latitude', 'pickup_longitude', 'dropoff_latitude',
-    'dropoff_longitude'
+    'pickup_latitude',
+    'pickup_longitude',
+    'dropoff_latitude',
+    'dropoff_longitude',
 ]

 # Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform
@@ -76,276 +74,293 @@ def _transformed_names(keys):
   return [_transformed_name(key) for key in keys]


-# Tf.Transform considers these features as "raw"
-def _get_raw_feature_spec(schema):
-  return schema_utils.schema_as_feature_spec(schema).feature_spec
-
-
-def _gzip_reader_fn(filenames):
-  """Small utility returning a record reader that can read gzip'ed files."""
-  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')
+def _fill_in_missing(x):
+  """Replace missing values in a SparseTensor.

-
-def _build_estimator(config, hidden_units=None, warm_start_from=None):
-  """Build an estimator for predicting the tipping behavior of taxi riders.
+  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.

   Args:
-    config: tf.estimator.RunConfig defining the runtime environment for the
-      estimator (including model_dir).
-    hidden_units: [int], the layer sizes of the DNN (input layer first)
-    warm_start_from: Optional directory to warm start from.
+    x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
+      in the second dimension.

   Returns:
-    A dict of the following:
-      - estimator: The estimator that will be used for training and eval.
-      - train_spec: Spec for training.
-      - eval_spec: Spec for eval.
-      - eval_input_receiver_fn: Input function for eval.
+    A rank 1 tensor where missing values of `x` have been filled in.
   """
-  real_valued_columns = [
-      tf.feature_column.numeric_column(key, shape=())
-      for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS)
-  ]
-  categorical_columns = [
-      tf.feature_column.categorical_column_with_identity(
-          key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0)
-      for key in _transformed_names(_VOCAB_FEATURE_KEYS)
-  ]
-  categorical_columns += [
-      tf.feature_column.categorical_column_with_identity(
-          key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0)
-      for key in _transformed_names(_BUCKET_FEATURE_KEYS)
-  ]
-  categorical_columns += [
-      tf.feature_column.categorical_column_with_identity(  # pylint: disable=g-complex-comprehension
-          key,
-          num_buckets=num_buckets,
-          default_value=0) for key, num_buckets in zip(
-              _transformed_names(_CATEGORICAL_FEATURE_KEYS),
-              _MAX_CATEGORICAL_FEATURE_VALUES)
-  ]
-  return tf_estimator.DNNLinearCombinedClassifier(
-      config=config,
-      linear_feature_columns=categorical_columns,
-      dnn_feature_columns=real_valued_columns,
-      dnn_hidden_units=hidden_units or [100, 70, 50, 25],
-      warm_start_from=warm_start_from)
-
-
-def _example_serving_receiver_fn(tf_transform_output, schema):
-  """Build the serving in inputs.
+  if not isinstance(x, tf.sparse.SparseTensor):
+    return x
+
+  default_value = '' if x.dtype == tf.string else 0
+  return tf.squeeze(
+      tf.sparse.to_dense(
+          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
+          default_value,
+      ),
+      axis=1,
+  )
+
+
+def _get_tf_examples_serving_signature(model, tf_transform_output):
+  """Returns a serving signature that accepts `tensorflow.Example`."""
+
+  # We need to track the layers in the model in order to save it.
+  # TODO(b/162357359): Revise once the bug is resolved.
+  model.tft_layer_inference = tf_transform_output.transform_features_layer()
+
+  @tf.function(
+      input_signature=[
+          tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
+      ]
+  )
+  def serve_tf_examples_fn(serialized_tf_example):
+    """Returns the output to be used in the serving signature."""
+    raw_feature_spec = tf_transform_output.raw_feature_spec()
+    # Remove label feature since these will not be present at serving time.
+    raw_feature_spec.pop(_LABEL_KEY)
+    raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec)
+    transformed_features = model.tft_layer_inference(raw_features)
+    logging.info('serve_transformed_features = %s', transformed_features)
+
+    outputs = model(transformed_features)
+    # TODO(b/154085620): Convert the predicted labels from the model using a
+    # reverse-lookup (opposite of transform.py).
+    return {'outputs': outputs}
+
+  return serve_tf_examples_fn
+
+
+def _get_transform_features_signature(model, tf_transform_output):
+  """Returns a serving signature that applies tf.Transform to features."""
+
+  # We need to track the layers in the model in order to save it.
+  # TODO(b/162357359): Revise once the bug is resolved.
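+  # Reusing the Transform graph here keeps evaluation preprocessing identical
+  # to the preprocessing applied when the training examples were materialized.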
+  model.tft_layer_eval = tf_transform_output.transform_features_layer()
+
+  @tf.function(
+      input_signature=[
+          tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
+      ]
+  )
+  def transform_features_fn(serialized_tf_example):
+    """Returns the transformed_features to be fed as input to evaluator."""
+    raw_feature_spec = tf_transform_output.raw_feature_spec()
+    raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec)
+    transformed_features = model.tft_layer_eval(raw_features)
+    logging.info('eval_transformed_features = %s', transformed_features)
+    return transformed_features
+
+  return transform_features_fn

-  Args:
-    tf_transform_output: A TFTransformOutput.
-    schema: the schema of the input data.

-  Returns:
-    Tensorflow graph which parses examples, applying tf-transform to them.
-  """
-  raw_feature_spec = _get_raw_feature_spec(schema)
-  raw_feature_spec.pop(_LABEL_KEY)
-
-  raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn(
-      raw_feature_spec, default_batch_size=None)
-  serving_input_receiver = raw_input_fn()
-
-  transformed_features = tf_transform_output.transform_raw_features(
-      serving_input_receiver.features)
-
-  return tf_estimator.export.ServingInputReceiver(
-      transformed_features, serving_input_receiver.receiver_tensors)
-
-
-def _eval_input_receiver_fn(tf_transform_output, schema):
-  """Build everything needed for the tf-model-analysis to run the model.
+def _input_fn(
+    file_pattern: list[str],
+    data_accessor: fn_args_utils.DataAccessor,
+    tf_transform_output: tft.TFTransformOutput,
+    batch_size: int = 200,
+) -> tf.data.Dataset:
+  """Generates features and label for tuning/training.

   Args:
+    file_pattern: List of paths or patterns of input tfrecord files.
+    data_accessor: fn_args_utils.DataAccessor for converting input to
+      RecordBatch.
     tf_transform_output: A TFTransformOutput.
-    schema: the schema of the input data.
+    batch_size: Number of consecutive elements of the returned dataset to
+      combine in a single batch.

   Returns:
-    EvalInputReceiver function, which contains:
-      - Tensorflow graph which parses raw untransformed features, applies the
-        tf-transform preprocessing operators.
-      - Set of raw, untransformed features.
-      - Label against which predictions will be compared.
+    A dataset that contains (features, indices) tuple where features is a
+    dictionary of Tensors, and indices is a single Tensor of label indices.
   """
-  # Notice that the inputs are raw features, not transformed features here.
-  raw_feature_spec = _get_raw_feature_spec(schema)
-
-  serialized_tf_example = tf.compat.v1.placeholder(
-      dtype=tf.string, shape=[None], name='input_example_tensor')
+  return data_accessor.tf_dataset_factory(
+      file_pattern,
+      dataset_options.TensorFlowDatasetOptions(
+          batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)
+      ),
+      tf_transform_output.transformed_metadata.schema,
+  ).repeat()

-  # Add a parse_example operator to the tensorflow graph, which will parse
-  # raw, untransformed, tf examples.
-  features = tf.io.parse_example(
-      serialized=serialized_tf_example, features=raw_feature_spec)
-  # Now that we have our raw examples, process them through the tf-transform
-  # function computed during the preprocessing step.
-  transformed_features = tf_transform_output.transform_raw_features(
-      features)
-
-  # The key name MUST be 'examples'.
-  receiver_tensors = {'examples': serialized_tf_example}
-
-  # NOTE: Model is driven by transformed features (since training works on the
-  # materialized output of TFT, but slicing will happen on raw features.
-  features.update(transformed_features)
-
-  return tfma.export.EvalInputReceiver(
-      features=features,
-      receiver_tensors=receiver_tensors,
-      labels=transformed_features[_transformed_name(_LABEL_KEY)])
-
-
-def _input_fn(
-    filenames, data_accessor, tf_transform_output, batch_size=200):
-  """Generates features and labels for training or evaluation.
+def _build_keras_model(
+    hidden_units: Optional[list[int]] = None,
+) -> tf.keras.Model:
+  """Creates a DNN Keras model for classifying taxi data.

   Args:
-    filenames: [str] list of CSV files to read data from.
-    data_accessor: fn_args_utils.DataAccessor.
-    tf_transform_output: A TFTransformOutput.
-    batch_size: int First dimension size of the Tensors returned by input_fn
+    hidden_units: [int], the layer sizes of the DNN (input layer first).

   Returns:
-    A (features, indices) tuple where features is a dictionary of
-    Tensors, and indices is a single Tensor of label indices.
+    A Wide and Deep Keras model.
   """
-  dataset = data_accessor.tf_dataset_factory(
-      filenames,
-      TensorFlowDatasetOptions(
-          batch_size=batch_size,
-          label_key=_transformed_name(_LABEL_KEY)),
-      tf_transform_output.transformed_metadata.schema)
-
-  return tf.compat.v1.data.make_one_shot_iterator(
-      dataset).get_next()
+  # The following values are hard-coded for simplicity in this example;
+  # preferably they should be passed in as hparams.
+  # Keras needs the feature definitions at compile time.
+  deep_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(), dtype=tf.float32)
+      for colname in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS)
+  }
+  wide_vocab_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
+      for colname in _transformed_names(_VOCAB_FEATURE_KEYS)
+  }
+  wide_bucket_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
+      for colname in _transformed_names(_BUCKET_FEATURE_KEYS)
+  }
+  wide_categorical_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
+      for colname in _transformed_names(_CATEGORICAL_FEATURE_KEYS)
+  }
+  input_layers = {
+      **deep_input,
+      **wide_vocab_input,
+      **wide_bucket_input,
+      **wide_categorical_input,
+  }

-# TFX will call this function
-def trainer_fn(trainer_fn_args, schema):
-  """Build the estimator using the high level API.
+  # TODO(b/161952382): Replace with Keras premade models and
+  # Keras preprocessing layers.
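+  # Deep tower: a Normalization layer over the dense float inputs followed by
+  # a stack of Dense layers; the wide features are category-encoded below.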
+  deep = tf.keras.layers.Normalization()(deep_input)
+  for numnodes in (hidden_units or [100, 70, 50, 25]):
+    deep = tf.keras.layers.Dense(numnodes)(deep)
+
+  wide_layers = []
+  for key in _transformed_names(_VOCAB_FEATURE_KEYS):
+    wide_layers.append(
+        tf.keras.layers.CategoryEncoding(num_tokens=_VOCAB_SIZE + _OOV_SIZE)(
+            input_layers[key]
+        )
+    )
+  for key in _transformed_names(_BUCKET_FEATURE_KEYS):
+    wide_layers.append(
+        tf.keras.layers.CategoryEncoding(num_tokens=_FEATURE_BUCKET_COUNT)(
+            input_layers[key]
+        )
+    )
+  for key, num_tokens in zip(
+      _transformed_names(_CATEGORICAL_FEATURE_KEYS),
+      _MAX_CATEGORICAL_FEATURE_VALUES,
+  ):
+    wide_layers.append(
+        tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
+            input_layers[key]
+        )
+    )
+  wide = tf.keras.layers.concatenate(wide_layers)
+
+  output = tf.keras.layers.Dense(1, activation='sigmoid')(
+      tf.keras.layers.concatenate([deep, wide])
+  )
+  output = tf.squeeze(output, -1)
+
+  model = tf.keras.Model(input_layers, output)
+  model.compile(
+      loss='binary_crossentropy',
+      optimizer=tf.keras.optimizers.Adam(lr=0.001),
+      metrics=[tf.keras.metrics.BinaryAccuracy()],
+  )
+  model.summary(print_fn=logging.info)
+  return model
+
+
+# TFX Transform will call this function.
+def preprocessing_fn(inputs):
+  """tf.transform's callback function for preprocessing inputs.

   Args:
-    trainer_fn_args: Holds args used to train the model as name/value pairs.
-    schema: Holds the schema of the training examples.
+    inputs: map from feature keys to raw not-yet-transformed features.

   Returns:
-    A dict of the following:
-      - estimator: The estimator that will be used for training and eval.
-      - train_spec: Spec for training.
-      - eval_spec: Spec for eval.
-      - eval_input_receiver_fn: Input function for eval.
+    Map from string feature key to transformed feature operations.
   """
-  if trainer_fn_args.hyperparameters:
-    hp = trainer_fn_args.hyperparameters
-    first_dnn_layer_size = hp.get('first_dnn_layer_size')
-    num_dnn_layers = hp.get('num_dnn_layers')
-    dnn_decay_factor = hp.get('dnn_decay_factor')
-  else:
-    # Number of nodes in the first layer of the DNN
-    first_dnn_layer_size = 100
-    num_dnn_layers = 4
-    dnn_decay_factor = 0.7
-
-  train_batch_size = 40
-  eval_batch_size = 40
-
-  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)
-
-  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
-      trainer_fn_args.train_files,
-      trainer_fn_args.data_accessor,
-      tf_transform_output,
-      batch_size=train_batch_size)
-
-  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
-      trainer_fn_args.eval_files,
-      trainer_fn_args.data_accessor,
-      tf_transform_output,
-      batch_size=eval_batch_size)
-
-  train_spec = tf_estimator.TrainSpec(  # pylint: disable=g-long-lambda
-      train_input_fn,
-      max_steps=trainer_fn_args.train_steps)
-
-  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
-      tf_transform_output, schema)
-
-  exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
-  eval_spec = tf_estimator.EvalSpec(
-      eval_input_fn,
-      steps=trainer_fn_args.eval_steps,
-      exporters=[exporter],
-      name='chicago-taxi-eval')
-
-  run_config = tf_estimator.RunConfig(
-      save_checkpoints_steps=999,
-      # keep_checkpoint_max must be more than the number of worker replicas
-      # nodes if training distributed, in order to avoid race condition.
-      keep_checkpoint_max=5)
-
-  export_dir = path_utils.serving_model_dir(trainer_fn_args.model_run_dir)
-  run_config = run_config.replace(model_dir=export_dir)
-  warm_start_from = trainer_fn_args.base_model
-
-  estimator = _build_estimator(
-      # Construct layers sizes with exponetial decay
-      hidden_units=[
-          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
-          for i in range(num_dnn_layers)
-      ],
-      config=run_config,
-      warm_start_from=warm_start_from)
-
-  # Create an input receiver for TFMA processing
-  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
-      tf_transform_output, schema)
-
-  return {
-      'estimator': estimator,
-      'train_spec': train_spec,
-      'eval_spec': eval_spec,
-      'eval_input_receiver_fn': receiver_fn
-  }
-
-
-# TFX generic trainer will call this function
-def run_fn(fn_args: executor.TrainerFnArgs):
+  outputs = {}
+  for key in _DENSE_FLOAT_FEATURE_KEYS:
+    # If sparse, make it dense, setting NaNs to 0 or '', and apply z-score.
+    outputs[_transformed_name(key)] = tft.scale_to_z_score(
+        _fill_in_missing(inputs[key])
+    )
+
+  for key in _VOCAB_FEATURE_KEYS:
+    # Build a vocabulary for this feature.
+    outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary(
+        _fill_in_missing(inputs[key]),
+        top_k=_VOCAB_SIZE,
+        num_oov_buckets=_OOV_SIZE,
+    )
+
+  for key in _BUCKET_FEATURE_KEYS:
+    outputs[_transformed_name(key)] = tft.bucketize(
+        _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT
+    )
+
+  for key in _CATEGORICAL_FEATURE_KEYS:
+    outputs[_transformed_name(key)] = _fill_in_missing(inputs[key])
+
+  # Was this passenger a big tipper?
+  taxi_fare = _fill_in_missing(inputs[_FARE_KEY])
+  tips = _fill_in_missing(inputs[_LABEL_KEY])
+  outputs[_transformed_name(_LABEL_KEY)] = tf.where(
+      tf.math.is_nan(taxi_fare),
+      tf.cast(tf.zeros_like(taxi_fare), tf.int64),
+      # Test if the tip was > 20% of the fare.
+      tf.cast(
+          tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64
+      ),
+  )
+
+  return outputs
+
+
+# TFX Trainer will call this function.
+def run_fn(fn_args: fn_args_utils.FnArgs):
   """Train the model based on given args.

   Args:
     fn_args: Holds args used to train the model as name/value pairs.
   """
-  schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())
-
-  training_spec = trainer_fn(fn_args, schema)
-
-  # Train the model
-  absl.logging.info('Training model.')
-  tf_estimator.train_and_evaluate(training_spec['estimator'],
-                                  training_spec['train_spec'],
-                                  training_spec['eval_spec'])
-
-  # Export an eval savedmodel for TFMA
-  # NOTE: When trained in distributed training cluster, eval_savedmodel must be
-  # exported only by the chief worker.
-  absl.logging.info('Exporting eval_savedmodel for TFMA.')
-  tfma.export.export_eval_savedmodel(
-      estimator=training_spec['estimator'],
-      export_dir_base=path_utils.eval_model_dir(fn_args.model_run_dir),
-      eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])
-
-  # TODO(b/160795287): Deprecate estimator based executor.
-  # Copy serving and eval model from model_run to model artifact directory.
-  serving_source = path_utils.serving_model_path(fn_args.model_run_dir)
-  io_utils.copy_dir(serving_source, fn_args.serving_model_dir)
-
-  eval_source = path_utils.eval_model_path(fn_args.model_run_dir)
-  io_utils.copy_dir(eval_source, fn_args.eval_model_dir)
-
-  absl.logging.info('Training complete.  Model written to %s',
-                    fn_args.serving_model_dir)
-  absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
+  # Number of nodes in the first layer of the DNN
+  first_dnn_layer_size = 100
+  num_dnn_layers = 4
+  dnn_decay_factor = 0.7
+
+  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
+
+  train_dataset = _input_fn(
+      fn_args.train_files, fn_args.data_accessor, tf_transform_output, 40
+  )
+  eval_dataset = _input_fn(
+      fn_args.eval_files, fn_args.data_accessor, tf_transform_output, 40
+  )
+
+  mirrored_strategy = tf.distribute.MirroredStrategy()
+  with mirrored_strategy.scope():
+    model = _build_keras_model(
+        # Construct layer sizes with exponential decay.
+        hidden_units=[
+            max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
+            for i in range(num_dnn_layers)
+        ]
+    )
+
+  # Write logs to path
+  tensorboard_callback = tf.keras.callbacks.TensorBoard(
+      log_dir=fn_args.model_run_dir, update_freq='epoch'
+  )
+
+  model.fit(
+      train_dataset,
+      steps_per_epoch=fn_args.train_steps,
+      validation_data=eval_dataset,
+      validation_steps=fn_args.eval_steps,
+      callbacks=[tensorboard_callback],
+  )
+
+  signatures = {
+      'serving_default': _get_tf_examples_serving_signature(
+          model, tf_transform_output
+      ),
+      'transform_features': _get_transform_features_signature(
+          model, tf_transform_output
+      ),
+  }
+  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)