
Switch to RunInferenceCore #11

Open · wants to merge 11 commits into base: master

Conversation


@hgarrereyn hgarrereyn commented Jul 30, 2020

Overview

This PR moves the internal implementation of RunInferenceImpl into _RunInferenceCore. This core component accepts query tuples, i.e. (InferenceSpecType, Example) pairs. The PR has no public-facing changes but will allow for the implementation of a streaming model API.
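For orientation, here is a minimal, hedged sketch of what a query looks like and how the wrapper relationship works; the alias definitions and import paths below are assumptions based on this description, not the final code in the PR:

# Illustrative sketch only -- shapes inferred from the PR description.
from typing import Tuple

import tensorflow as tf
from tfx_bsl.public.proto import model_spec_pb2

ExampleType = tf.train.Example  # the real alias may cover more example types
QueryType = Tuple[model_spec_pb2.InferenceSpecType, ExampleType]

# Conceptually, RunInferenceImpl stays the public entry point: it pairs every
# example with the single fixed InferenceSpecType and delegates the resulting
# (spec, example) queries to _RunInferenceCore.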

What was changed:

  • RunInferenceImpl is now a wrapper around _RunInferenceCore
  • _RunInferenceCore and the internal PTransforms that previously took inference_spec_type arguments now accept an optional fixed_inference_spec_type that can be None.
    • When the model is known at pipeline creation time, fixed_inference_spec_type is set and the internal PTransforms can take a "fast" path that collapses down to a single sequence of PTransforms and loads the model during the setup method of the DoFn (similar to the current implementation).
    • When the model is not known, _RunInferenceCore will build a graph containing all possible operations and types (local/remote) and queries will be batched and routed to the correct operation at runtime. In this case, models will be loaded during the process method of a DoFn (but caching is still possible).
  • Internal logic in both _BaseBatchSavedModelDoFn and _RemotePredictDoFn was restructured:
    • model loading is moved to _setup_model(self, inference_spec_type: model_spec_pb2.InferenceSpecType). This function includes code previously split between __init__ and setup. It will be called either in setup or in process depending on whether the inference spec is available at pipeline construction time.
    • during local inference, operation subclasses (e.g. _BatchClassifyDoFn) could previously implement operation-specific model validation by overriding the setup method and optionally raising an error. This check occurs after the model signature is available but before the model has been loaded. Since all of this logic is now contained in _setup_model, there is a new _validate_model(self) method that is unimplemented in the base class and can be overridden to perform validation logic.
  • Type simplification: to simplify type signatures and allow for easier future improvements, several type aliases were created:
    • ExampleType, QueryType, _QueryBatchType (the first two will become public-facing after the model streaming API is implemented)
  • Query batching: currently examples are batched with beam.BatchElements; when working with queries, it is also necessary to group by model spec. beam.GroupIntoBatches is currently experimental but contains this functionality. Unfortunately, BEAM-2717 currently blocks RunInference in the GCP Dataflow v2 runner, and the v1 runner does not support stateful DoFns, which are required for GroupIntoBatches. For now, BatchElements is used as a temporary replacement, with the understanding that the current implementation will not use more than one model at a time (a sketch of the keyed variant appears after this list).
  • Tests
    • Added a test for _BatchQueries and a TODO test that addresses the comment above
    • Added a test for running _RunInferenceCore with raw queries
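As a concrete illustration of the keyed batching mentioned in the query batching item above, here is a hedged sketch of what a GroupIntoBatches-based variant could look like once runner support is available. The helper name, default batch size, and batch shape (a list of queries sharing one spec) are assumptions, not the code in this PR:

import apache_beam as beam


@beam.ptransform_fn
def _BatchQueriesByModel(queries, batch_size=1000):  # hypothetical helper
  """Groups (InferenceSpecType, Example) queries into single-model batches."""
  return (
      queries
      # Key each query by a deterministic serialization of its inference spec.
      | 'KeyBySpec' >> beam.Map(
          lambda q: (q[0].SerializeToString(deterministic=True), q))
      # Experimental at the time of this PR; needs stateful DoFn support.
      | 'GroupIntoBatches' >> beam.GroupIntoBatches(batch_size)
      # Drop the byte key; each output batch holds queries for exactly one model.
      | 'DropKey' >> beam.MapTuple(lambda _, batch: list(batch)))

Until that path is unblocked, the BatchElements-based replacement is sufficient because only one model is used per pipeline.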

Benchmarks:

A test set of 1,000,000 examples (Chicago Taxi example) was run in a small Beam pipeline on Dataflow (v1 runner). These are the total wall times for three separate runs of the RunInference component:

(Current)

  • 208 s
  • 209 s
  • 190 s

(RunInferenceCore)

  • 217 s
  • 208 s
  • 209 s

@hgarrereyn (Author):

@rose-rong-liu @SherylLuo

def _BatchQueries(queries: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
  """Groups queries into batches."""

  def _add_key(query: QueryType) -> Tuple[bytes, QueryType]:


Before streaming is enabled, the model will be the same during inference. Is this still needed?

@hgarrereyn (Author):

I've reduced _BatchQueries to the minimum required to operate on identical inference specs. #13 can introduce keyed batching and inference spec serialization

Comment on lines 291 to 304
if fixed_inference_spec_type is None:
  tagged = pcoll | 'Tag inference type' >> _TagUsingInProcessInference()
  tagged['remote'] | 'NotImplemented' >> _NotImplementedTransform()
  raw_predictions = (
      tagged['local']
      | 'Regress' >> beam.ParDo(_BatchRegressDoFn(shared.Shared())))
else:
  if _using_in_process_inference(fixed_inference_spec_type):
    raw_predictions = (
        pcoll
        | 'Regress' >> beam.ParDo(_BatchRegressDoFn(shared.Shared(),
            fixed_inference_spec_type=fixed_inference_spec_type)))
  else:
    raise NotImplementedError


Seems this part is repeated several times, can this be extracted to a function?

@hgarrereyn (Author):

Refactored this section into a single operation constructor:

def _BuildInferenceOperation(
    name: str,
    in_process_dofn: _BaseBatchSavedModelDoFn,
    remote_dofn: Optional[_BaseDoFn],
    build_prediction_log_dofn: beam.DoFn
):
  """Construct an operation specific inference sub-pipeline.

  Args:
    name: name of the operation (e.g. "Classify")
    in_process_dofn: a _BaseBatchSavedModelDoFn class to use for in-process
      inference
    remote_dofn: an optional DoFn that is used for remote inference
    build_prediction_log_dofn: a DoFn that can build prediction logs from the
      output of `in_process_dofn` and `remote_dofn`

  Returns:
    A PTransform of the type (_QueryBatchType -> PredictionLog)
  """
  @beam.ptransform_fn
  @beam.typehints.with_input_types(_QueryBatchType)
  @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
  def _Op(
      pcoll: beam.pvalue.PCollection,
      fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None
  ):  # pylint: disable=invalid-name
    raw_result = None

    if fixed_inference_spec_type is None:
      tagged = pcoll | 'TagInferenceType' >> _TagUsingInProcessInference()

      in_process_result = (
          tagged['in_process']
          | ('InProcess%s' % name) >> beam.ParDo(
              in_process_dofn(shared.Shared())))

      if remote_dofn:
        remote_result = (
            tagged['remote']
            | ('Remote%s' % name) >> beam.ParDo(
                remote_dofn(pcoll.pipeline.options)))

        raw_result = (
            [in_process_result, remote_result]
            | 'FlattenResult' >> beam.Flatten())
      else:
        raw_result = in_process_result
    else:
      if _using_in_process_inference(fixed_inference_spec_type):
        raw_result = (
            pcoll
            | ('InProcess%s' % name) >> beam.ParDo(in_process_dofn(
                shared.Shared(),
                fixed_inference_spec_type=fixed_inference_spec_type)))
      else:
        raw_result = (
            pcoll
            | ('Remote%s' % name) >> beam.ParDo(remote_dofn(
                pcoll.pipeline.options,
                fixed_inference_spec_type=fixed_inference_spec_type)))

    return (
        raw_result
        | ('BuildPredictionLogFor%s' % name) >> beam.ParDo(
            build_prediction_log_dofn()))

  return _Op


_Classify = _BuildInferenceOperation(
    'Classify', _BatchClassifyDoFn, None,
    _BuildPredictionLogForClassificationsDoFn)
_Regress = _BuildInferenceOperation(
    'Regress', _BatchRegressDoFn, None,
    _BuildPredictionLogForRegressionsDoFn)
_Predict = _BuildInferenceOperation(
    'Predict', _BatchPredictDoFn, _RemotePredictDoFn,
    _BuildPredictionLogForPredictionsDoFn)
_MultiInference = _BuildInferenceOperation(
    'MultiInference', _BatchMultiInferenceDoFn, None,
    _BuildMultiInferenceLogDoFn)

Comment on lines 542 to 543
if self._use_fixed_model:
  self._setup_model(self._fixed_inference_spec_type)


Why is this in setup instead of __init__?

@hgarrereyn (Author):

Moving this to __init__ would require the API client to be serializable (as part of the DoFn). This may be possible, but could also lead to strange issues if multiple DoFns sharing an API client is a problem. The original code configured this in setup, so this PR just maintains that convention.
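To make the setup-versus-process distinction concrete, here is a rough, hypothetical sketch of the loading lifecycle; the class name, method bodies, and batch shape are stand-ins, not the PR's actual DoFns:

import apache_beam as beam


class _ExampleModelDoFn(beam.DoFn):  # hypothetical stand-in
  """Loads the model in setup() for the fixed case, lazily in process() otherwise."""

  def __init__(self, fixed_inference_spec_type=None):
    super().__init__()
    self._fixed_inference_spec_type = fixed_inference_spec_type
    self._model = None

  def _setup_model(self, inference_spec_type):
    # In the PR this configures the API client / loads the SavedModel and runs
    # _validate_model(); stubbed out here.
    self._model = object()

  def setup(self):
    # Fast path: the spec is known at pipeline construction time.
    if self._fixed_inference_spec_type is not None:
      self._setup_model(self._fixed_inference_spec_type)

  def process(self, batch):
    # Assumed batch shape: a list of (spec, example) queries sharing one spec.
    inference_spec_type = batch[0][0]
    examples = [example for _, example in batch]
    # Dynamic path: the spec arrives with each batch; load (or reuse) the model now.
    if self._fixed_inference_spec_type is None:
      self._setup_model(inference_spec_type)
    yield examples  # real code would run inference and emit predictions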

saved_model_spec=model_spec_pb2.SavedModelSpec(
model_path=model_path))

def test_batch_queries_single_model(self):


What does this test?

@hgarrereyn (Author):

This was intended to contrast with the TODO test below, but I think it is redundant, so I've removed it.

@hgarrereyn requested a review from SherylLuo on August 7, 2020.

@hgarrereyn (Author):
Benchmarks showed that TagByOperation was a performance bottleneck* as it requires disk access per query batch. To mitigate this I implemented operation caching inside the DoFn. For readability, I also renamed this operation to "SplitByOperation" as that more accurately describes its purpose.

*On a dataset with 1M examples, TagByOperation took ~25% of the total wall time. After implementing caching, this was reduced to ~2%.
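A hedged sketch of the caching idea described above; the DoFn name, helper function, and batch shape are assumptions rather than the PR's code:

import apache_beam as beam


def _get_operation_from_signature(model_path):
  # Placeholder for the lookup that actually reads the SavedModel signature
  # from disk to decide classify / regress / predict / multi-inference.
  return 'classify'


class _SplitByOperationDoFn(beam.DoFn):  # hypothetical stand-in
  """Tags each query batch by operation, caching the lookup per model path."""

  def setup(self):
    self._operation_cache = {}  # model_path -> operation tag

  def process(self, batch):
    # Assumed batch shape: a list of (spec, example) queries sharing one spec.
    inference_spec_type = batch[0][0]
    model_path = inference_spec_type.saved_model_spec.model_path
    operation = self._operation_cache.get(model_path)
    if operation is None:
      # Only the first batch per model pays the disk access.
      operation = _get_operation_from_signature(model_path)
      self._operation_cache[model_path] = operation
    yield beam.pvalue.TaggedOutput(operation, batch)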
super(_BaseDoFn, self).__init__()
self._clock = None
self._metrics_collector = self._MetricsCollector(inference_spec_type)
self._metrics_collector = self._MetricsCollector(fixed_inference_spec_type)


Can this be moved to setup_model()? That way we can know the proximity and operation_type for the model streaming case.

@hgarrereyn (Author):

My concern with this is that we might run into issues initializing Beam metrics outside of __init__. That might be a question for someone with more experience with Beam/Dataflow.

For the purposes of metrics collection, I think the end goal here is to collect metrics per operation type (e.g. classify, regress, ...). The fixed_inference_spec_type is only used here to determine the operation type so we could get rid of that and use a different solution where we initialize unique metric collectors for each operation type (which does not require knowing the inference spec). Then at runtime we choose which metric collector to use based on the inference spec (which will be available).

e.g. here we would have something like:

self._classify_metrics = self._MetricsCollector(OperationType.CLASSIFICATION, _METRICS_DESCRIPTOR_IN_PROCESS)
self._regress_metrics = self._MetricsCollector(OperationType.REGRESSION, _METRICS_DESCRIPTOR_IN_PROCESS)
...

and maybe expose a new method:

class _BaseDoFn(beam.DoFn):
  ...
  def _metrics_collector_for_inference_spec(
      self, inference_spec_type: InferenceSpecType) -> _MetricsCollector:
    ...

internally, we could replace:

self._metrics_collector.update(...)

with

self._metrics_collector_for_inference_spec(inference_spec).update(...)

Does this sound reasonable?


It sounds good. Also note that we need to toggle between _METRICS_DESCRIPTOR_IN_PROCESS and _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION.

@hgarrereyn (Author):

I found a slightly cleaner solution here: I refactored MetricsCollector to accept an operation type and proximity directly: https://github.com/hgarrereyn/tfx-bsl/blob/core/tfx_bsl/beam/run_inference.py#L299-L309

Each subclass DoFn is responsible for configuring the correct operation type and proximity (for this we don't need the inference spec).

I refactored the MetricsCollector methods for readability, so now there is:

  • update_inference for updating inference metrics
  • update_model_load + commit_cached_metrics for updating model-loading metrics: the first caches the metrics and the second commits them. In the fixed-model case, update_model_load is called in DoFn.setup and commit_cached_metrics is called in DoFn.finish_bundle, and I added documentation explaining why. In the dynamic-model case, both are called sequentially in DoFn.process. A rough sketch of this split follows below.
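Under the assumption that the collector takes an operation type and a proximity descriptor (as in the linked refactor), the caching/commit split might look roughly like this; the metric names and exact signatures are illustrative, not the linked implementation:

from apache_beam.metrics.metric import Metrics


class _MetricsCollector(object):  # illustrative shape only
  """Caches model-load metrics so they can be committed after setup()."""

  def __init__(self, operation_type, proximity_descriptor):
    namespace = 'RunInference.%s.%s' % (operation_type, proximity_descriptor)
    self._inference_counter = Metrics.counter(namespace, 'num_inferences')
    self._load_latency = Metrics.distribution(
        namespace, 'load_model_latency_milli_secs')
    self._cached_load_latency = None

  def update_inference(self, num_elements):
    self._inference_counter.inc(num_elements)

  def update_model_load(self, load_latency_milli_secs):
    # Cache only; committed later (in finish_bundle for the fixed-model case).
    self._cached_load_latency = load_latency_milli_secs

  def commit_cached_metrics(self):
    if self._cached_load_latency is not None:
      self._load_latency.update(self._cached_load_latency)
      self._cached_load_latency = None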

@rose-rong-liu

Thanks Harrison! It looks good in general. Just some minor comments.
