diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 128ab9206bfd4..f5183b49bc4b7 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -243,6 +243,7 @@ def __init__(
     self.contains_external_transforms = False
 
     self._display_data = display_data or {}
+    self._error_handlers = []
 
   def display_data(self):
     # type: () -> Dict[str, Any]
@@ -258,6 +259,9 @@ def allow_unsafe_triggers(self):
     # type: () -> bool
     return self._options.view_as(TypeOptions).allow_unsafe_triggers
 
+  def _register_error_handler(self, error_handler):
+    self._error_handlers.append(error_handler)
+
   def _current_transform(self):
     # type: () -> AppliedPTransform
@@ -531,6 +535,9 @@ def run(self, test_runner_api='AUTO'):
     """Runs the pipeline. Returns whatever our runner returns after running."""
+    for error_handler in self._error_handlers:
+      error_handler.verify_closed()
+
     # Records whether this pipeline contains any cross-language transforms.
     self.contains_external_transforms = (
         ExternalTransformFinder.contains_external_transforms(self))
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 68c9eecd9f3fc..32ce05de62065 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1573,6 +1573,7 @@ def with_exception_handling(
       threshold=1,
       threshold_windowing=None,
       timeout=None,
+      error_handler=None,
       on_failure_callback: typing.Optional[typing.Callable[
           [Exception, typing.Any], None]] = None):
     """Automatically provides a dead letter output for skipping bad records.
@@ -1622,6 +1623,8 @@ def with_exception_handling(
           defaults to the windowing of the input.
       timeout: If the element has not finished processing in timeout seconds,
           raise a TimeoutError. Defaults to None, meaning no time limit.
+      error_handler: An ErrorHandler that should be used to consume the bad
+          records, rather than returning the good and bad records as a tuple.
       on_failure_callback: If an element fails or times out,
           on_failure_callback will be invoked. It will receive the exception
           and the element being processed in as args. In case of a timeout,
@@ -1642,8 +1645,20 @@ def with_exception_handling(
         threshold,
         threshold_windowing,
         timeout,
+        error_handler,
         on_failure_callback)
 
+  def with_error_handler(self, error_handler, **exception_handling_kwargs):
+    """An alias for `with_exception_handling(error_handler=error_handler, ...)`
+
+    This is provided to fit the general ErrorHandler conventions.
+ """ + if error_handler is None: + return self + else: + return self.with_exception_handling( + error_handler=error_handler, **exception_handling_kwargs) + def default_type_hints(self): return self.fn.get_type_hints() @@ -2242,6 +2257,7 @@ def __init__( threshold, threshold_windowing, timeout, + error_handler, on_failure_callback): if partial and use_subprocess: raise ValueError('partial and use_subprocess are mutually incompatible.') @@ -2256,6 +2272,7 @@ def __init__( self._threshold = threshold self._threshold_windowing = threshold_windowing self._timeout = timeout + self._error_handler = error_handler self._on_failure_callback = on_failure_callback def expand(self, pcoll): @@ -2306,7 +2323,11 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam): _ = bad_count_pcoll | Map( check_threshold, input_count_view, self._threshold) - return result + if self._error_handler: + self._error_handler.add_error_pcollection(result[self._dead_letter_tag]) + return result[self._main_tag] + else: + return result class _ExceptionHandlingWrapperDoFn(DoFn): diff --git a/sdks/python/apache_beam/transforms/error_handling.py b/sdks/python/apache_beam/transforms/error_handling.py new file mode 100644 index 0000000000000..8671c66a12e02 --- /dev/null +++ b/sdks/python/apache_beam/transforms/error_handling.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Utilities for gracefully handling errors and excluding bad elements.""" + +import traceback + +from apache_beam import transforms + + +class ErrorHandler: + """ErrorHandlers are used to skip and otherwise process bad records. + + Error handlers allow one to implement the "dead letter queue" pattern in + a fluent manner, disaggregating the error processing specification from + the main processing chain. + + This is typically used as follows:: + + with error_handling.ErrorHandler(WriteToSomewhere(...)) as error_handler: + result = pcoll | SomeTransform().with_error_handler(error_handler) + + in which case errors encountered by `SomeTransform()`` in processing pcoll + will be written by the PTransform `WriteToSomewhere(...)` and excluded from + `result` rather than failing the pipeline. + + To implement `with_error_handling` on a PTransform, one caches the provided + error handler for use in `expand`. During `expand()` one can invoke + `error_handler.add_error_pcollection(...)` any number of times with + PCollections containing error records to be processed by the given error + handler, or (if applicable) simply invoke `with_error_handling(...)` on any + subtransforms. + + The `with_error_handling` should accept `None` to indicate that error handling + is not enabled (and make implementation-by-forwarding-error-handlers easier). 
+  In this case, any non-recoverable errors should fail the pipeline
+  (e.g. propagate exceptions in `process` methods) rather than silently
+  ignore errors.
+  """
+  def __init__(self, consumer):
+    self._consumer = consumer
+    self._creation_traceback = traceback.format_stack()[-2]
+    self._error_pcolls = []
+    self._closed = False
+
+  def __enter__(self):
+    self._error_pcolls = []
+    self._closed = False
+    return self
+
+  def __exit__(self, *exec_info):
+    if exec_info[0] is None:
+      self.close()
+
+  def close(self):
+    """Indicates all error-producing operations have reported any errors.
+
+    Invokes the provided error consuming PTransform on any provided error
+    PCollections.
+    """
+    self._output = (
+        tuple(self._error_pcolls) | transforms.Flatten() | self._consumer)
+    self._closed = True
+
+  def output(self):
+    """Returns result of applying the error consumer to the error pcollections.
+    """
+    if not self._closed:
+      raise RuntimeError(
+          "Cannot access the output of an error handler "
+          "until it has been closed.")
+    return self._output
+
+  def add_error_pcollection(self, pcoll):
+    """Called by a class implementing error handling on the error records.
+    """
+    pcoll.pipeline._register_error_handler(self)
+    self._error_pcolls.append(pcoll)
+
+  def verify_closed(self):
+    """Called at end of pipeline construction to ensure errors are not ignored.
+    """
+    if not self._closed:
+      raise RuntimeError(
+          "Unclosed error handler initialized at %s" % self._creation_traceback)
+
+
+class _IdentityPTransform(transforms.PTransform):
+  def expand(self, pcoll):
+    return pcoll
+
+
+class CollectingErrorHandler(ErrorHandler):
+  """An ErrorHandler that simply collects all errors for further processing.
+
+  This ErrorHandler requires the set of errors be retrieved via `output()`
+  and consumed (or explicitly discarded).
+  """
+  def __init__(self):
+    super().__init__(_IdentityPTransform())
+    self._creation_traceback = traceback.format_stack()[-2]
+    self._output_accessed = False
+
+  def output(self):
+    self._output_accessed = True
+    return super().output()
+
+  def verify_closed(self):
+    if not self._output_accessed:
+      raise RuntimeError(
+          "CollectingErrorHandler requires the output to be retrieved. "
+          "Initialized at %s" % self._creation_traceback)
+    return super().verify_closed()
diff --git a/sdks/python/apache_beam/transforms/error_handling_test.py b/sdks/python/apache_beam/transforms/error_handling_test.py
new file mode 100644
index 0000000000000..4d8c2d23dc149
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/error_handling_test.py
@@ -0,0 +1,148 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms import error_handling
+
+
+class PTransformWithErrors(beam.PTransform):
+  def __init__(self, limit):
+    self._limit = limit
+    self._error_handler = None
+
+  def with_error_handler(self, error_handler):
+    self._error_handler = error_handler
+    return self
+
+  def expand(self, pcoll):
+    limit = self._limit
+
+    def process(element):
+      if len(element) < limit:
+        return element.title()
+      else:
+        return beam.pvalue.TaggedOutput('bad', element)
+
+    def raise_on_everything(element):
+      raise ValueError(element)
+
+    good, bad = pcoll | beam.Map(process).with_outputs('bad', main='good')
+    if self._error_handler:
+      self._error_handler.add_error_pcollection(bad)
+    else:
+      # Will throw an exception if there are any bad elements.
+      _ = bad | beam.Map(raise_on_everything)
+    return good
+
+
+def exception_throwing_map(x, limit):
+  if len(x) > limit:
+    raise ValueError(x)
+  else:
+    return x.title()
+
+
+class ErrorHandlingTest(unittest.TestCase):
+  def test_error_handling(self):
+    with beam.Pipeline() as p:
+      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
+      with error_handling.ErrorHandler(
+          beam.Map(lambda x: "error: %s" % x)) as error_handler:
+        result = pcoll | PTransformWithErrors(3).with_error_handler(
+            error_handler)
+      error_pcoll = error_handler.output()
+
+      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
+      assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')
+
+  def test_error_handling_pardo(self):
+    with beam.Pipeline() as p:
+      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
+      with error_handling.ErrorHandler(
+          beam.Map(lambda x: "error: %s" % x[0])) as error_handler:
+        result = pcoll | beam.Map(
+            exception_throwing_map, limit=3).with_error_handler(error_handler)
+      error_pcoll = error_handler.output()
+
+      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
+      assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')
+
+  def test_error_handling_pardo_with_exception_handling_kwargs(self):
+    def side_effect(*args):
+      beam._test_error_handling_pardo_with_exception_handling_kwargs_val = True
+
+    def check_side_effect():
+      return getattr(
+          beam,
+          '_test_error_handling_pardo_with_exception_handling_kwargs_val',
+          False)
+
+    self.assertFalse(check_side_effect())
+
+    with beam.Pipeline() as p:
+      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
+      with error_handling.ErrorHandler(
+          beam.Map(lambda x: "error: %s" % x[0])) as error_handler:
+        result = pcoll | beam.Map(
+            exception_throwing_map, limit=3).with_error_handler(
+                error_handler, on_failure_callback=side_effect)
+      error_pcoll = error_handler.output()
+
+      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
+      assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')
+
+    self.assertTrue(check_side_effect())
+
+  def test_error_on_unclosed_error_handler(self):
+    with self.assertRaisesRegex(RuntimeError, r'.*Unclosed error handler.*'):
+      with beam.Pipeline() as p:
+        pcoll = p | beam.Create(['a', 'bb', 'cccc'])
+        # Use this outside of a context to allow it to remain unclosed.
+        error_handler = error_handling.ErrorHandler(beam.Map(lambda x: x))
+        _ = pcoll | PTransformWithErrors(3).with_error_handler(error_handler)
+
+  def test_collecting_error_handler(self):
+    with beam.Pipeline() as p:
+      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
+      with error_handling.CollectingErrorHandler() as error_handler:
+        result = pcoll | beam.Map(
+            exception_throwing_map, limit=3).with_error_handler(error_handler)
+      error_pcoll = error_handler.output() | beam.Map(lambda x: x[0])
+
+      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
+      assert_that(error_pcoll, equal_to(['cccc']), label='CheckBad')
+
+  def test_error_on_collecting_error_handler_without_output_retrieval(self):
+    with self.assertRaisesRegex(
+        RuntimeError,
+        r'.*CollectingErrorHandler requires the output to be retrieved.*'):
+      with beam.Pipeline() as p:
+        pcoll = p | beam.Create(['a', 'bb', 'cccc'])
+        with error_handling.CollectingErrorHandler() as error_handler:
+          _ = pcoll | beam.Map(
+              exception_throwing_map,
+              limit=3).with_error_handler(error_handler)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
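
For orientation, a minimal usage sketch of how the pieces added above compose, assembled from the new docstrings and tests; it is not part of the patch itself and assumes the patch is applied. The `parse` function, its `limit` argument, and the 'failed: ...' formatting are illustrative placeholders, not names introduced by this change.

    import apache_beam as beam
    from apache_beam.transforms import error_handling


    def parse(element, limit):
      # Mirrors exception_throwing_map in the tests: long elements raise.
      if len(element) > limit:
        raise ValueError(element)
      return element.title()


    with beam.Pipeline() as p:
      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
      # Failed elements reach the consumer paired with their exception
      # details (hence bad[0] for the original element). They are diverted
      # to the handler instead of failing the pipeline, so `good` contains
      # only the successfully processed results.
      with error_handling.ErrorHandler(
          beam.Map(lambda bad: 'failed: %s' % bad[0])) as handler:
        good = pcoll | beam.Map(parse, limit=3).with_error_handler(handler)
      failures = handler.output()  # PCollection of 'failed: ...' strings.

Swapping in error_handling.CollectingErrorHandler() for the consumer-based handler would instead hold the bad records for retrieval via handler.output(), which that handler requires to be accessed before the pipeline runs (its verify_closed raises otherwise).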