Modernize Python type hints for apache_beam. #32872

Open · wants to merge 2 commits into base: master
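Every hunk in this diff follows the same pattern: comment-style hints (`# type: ...`) become inline annotations per PEP 484 (signatures) and PEP 526 (variables). A minimal sketch of the before/after shape, using made-up names rather than code from the diff:

```python
from typing import Any, Dict, List


class Expression:
    """Placeholder class so the sketch is self-contained."""


# Before: a comment-style hint, invisible at runtime and easy to let drift.
# _ALL_RESULTS = {}  # type: Dict[str, List[Any]]

# After: an inline variable annotation (PEP 526) and inline signature
# annotations (PEP 484).
_ALL_RESULTS: Dict[str, List[Any]] = {}


def evaluate(expr: Expression) -> Any:
    """Trivial body; only the annotated signature matters for the sketch."""
    return expr
```

One practical difference: inline annotations show up in `__annotations__` and in IDE tooltips, while the comment form was visible only to the type checker.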
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/dataframe/doctests.py
@@ -146,7 +146,7 @@ class _InMemoryResultRecorder(object):
"""

# Class-level value to survive pickling.
_ALL_RESULTS = {} # type: Dict[str, List[Any]]
_ALL_RESULTS: Dict[str, List[Any]] = {}

def __init__(self):
self._id = id(self)
33 changes: 15 additions & 18 deletions sdks/python/apache_beam/dataframe/expressions.py
@@ -36,12 +36,12 @@ class Session(object):
def __init__(self, bindings=None):
self._bindings = dict(bindings or {})

def evaluate(self, expr): # type: (Expression) -> Any
def evaluate(self, expr: 'Expression') -> Any:
if expr not in self._bindings:
self._bindings[expr] = expr.evaluate_at(self)
return self._bindings[expr]

def lookup(self, expr): # type: (Expression) -> Any
def lookup(self, expr: 'Expression') -> Any:
return self._bindings[expr]
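The `evaluate` and `lookup` signatures above quote the parameter type (`'Expression'`), presumably because `Expression` is defined later in the module, so the name does not yet exist when the annotation is evaluated. A small sketch of the alternative PEP 563 offers (not what this PR does): with `from __future__ import annotations`, the quotes become unnecessary.

```python
from __future__ import annotations  # PEP 563: annotations stay unevaluated strings

from typing import Any, Dict


class Session:
    def __init__(self) -> None:
        self._bindings: Dict[Expression, Any] = {}

    # No quotes needed even though Expression is defined further down,
    # because the __future__ import defers evaluation of annotations.
    def evaluate(self, expr: Expression) -> Any:
        if expr not in self._bindings:
            self._bindings[expr] = expr.evaluate_at(self)
        return self._bindings[expr]


class Expression:
    def evaluate_at(self, session: Session) -> Any:
        return None  # toy implementation, just to keep the sketch runnable
```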


@@ -251,9 +251,9 @@ def preserves_partition_by(self) -> partitionings.Partitioning:
class PlaceholderExpression(Expression):
"""An expression whose value must be explicitly bound in the session."""
def __init__(
self, # type: PlaceholderExpression
proxy, # type: T
reference=None, # type: Any
self,
proxy: T,
reference: Any = None,
):
"""Initialize a placeholder expression.

@@ -282,11 +282,7 @@ def preserves_partition_by(self):

class ConstantExpression(Expression):
"""An expression whose value is known at pipeline construction time."""
def __init__(
self, # type: ConstantExpression
value, # type: T
proxy=None # type: Optional[T]
):
def __init__(self, value: T, proxy: Optional[T] = None):
"""Initialize a constant expression.

Args:
@@ -319,14 +315,15 @@ def preserves_partition_by(self):
class ComputedExpression(Expression):
"""An expression whose value must be computed at pipeline execution time."""
def __init__(
self, # type: ComputedExpression
name, # type: str
func, # type: Callable[...,T]
args, # type: Iterable[Expression]
proxy=None, # type: Optional[T]
_id=None, # type: Optional[str]
requires_partition_by=partitionings.Index(), # type: partitionings.Partitioning
preserves_partition_by=partitionings.Singleton(), # type: partitionings.Partitioning
self,
name: str,
func: Callable[..., T],
args: Iterable[Expression],
proxy: Optional[T] = None,
_id: Optional[str] = None,
requires_partition_by: partitionings.Partitioning = partitionings.Index(),
preserves_partition_by: partitionings.Partitioning = partitionings.
Singleton(),
):
"""Initialize a computed expression.

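For wide constructors like ComputedExpression's above, the inline form keeps each parameter, its type, and its default on one line instead of a trailing comment per line. A rough sketch of that shape with illustrative names (`Partitioning` here is a stand-in, not Beam's class):

```python
from typing import Callable, Iterable, Optional, TypeVar

T = TypeVar("T")


class Partitioning:
    """Stand-in for partitionings.Partitioning; illustrative only."""


def computed(
    name: str,
    func: Callable[..., T],
    args: Iterable[object],
    proxy: Optional[T] = None,
    requires_partition_by: Partitioning = Partitioning(),
) -> str:
    # Parameters are unused here; the point is that each parameter's type and
    # default now sit together in the signature.
    return name
```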
24 changes: 10 additions & 14 deletions sdks/python/apache_beam/dataframe/transforms.py
@@ -32,15 +32,13 @@
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frames # pylint: disable=unused-import
from apache_beam.dataframe import partitionings
from apache_beam.pvalue import PCollection
from apache_beam.utils import windowed_value

__all__ = [
'DataframeTransform',
]

if TYPE_CHECKING:
# pylint: disable=ungrouped-imports
from apache_beam.pvalue import PCollection

T = TypeVar('T')
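The `PCollection` import moves out of the `if TYPE_CHECKING:` block because unquoted annotations on function signatures (such as `_apply_deferred_ops` below) are evaluated when the function is defined, so the name must exist at runtime. A sketch of the trade-off, with a stand-in class instead of importing Beam:

```python
from typing import Any, Dict


class PCollection:
    """Stand-in so the sketch runs without Beam installed."""


# As in this diff: a real module-level import (simulated by the class above)
# lets the annotation stay unquoted.
def flatten_inputs(pcolls: Any) -> Dict[Any, PCollection]:
    return {"main": pcolls}


# The old comment-style hints were never evaluated, which is why the import
# could live under `if TYPE_CHECKING:`. With inline annotations, keeping it
# there would require quoting the type, e.g. `-> "Dict[Any, PCollection]"`,
# or adding `from __future__ import annotations` to defer evaluation.
```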

@@ -108,15 +106,15 @@ def expand(self, input_pcolls):
from apache_beam.dataframe import convert

# Convert inputs to a flat dict.
input_dict = _flatten(input_pcolls) # type: Dict[Any, PCollection]
input_dict: Dict[Any, PCollection] = _flatten(input_pcolls)
proxies = _flatten(self._proxy) if self._proxy is not None else {
tag: None
for tag in input_dict
}
input_frames = {
input_frames: Dict[Any, DeferredFrame] = {
k: convert.to_dataframe(pc, proxies[k])
for k, pc in input_dict.items()
} # type: Dict[Any, DeferredFrame] # noqa: F821
} # noqa: F821

# Apply the function.
frames_input = _substitute(input_pcolls, input_frames)
@@ -152,9 +150,9 @@ def expand(self, inputs):

def _apply_deferred_ops(
self,
inputs, # type: Dict[expressions.Expression, PCollection]
outputs, # type: Dict[Any, expressions.Expression]
): # -> Dict[Any, PCollection]
inputs: Dict[expressions.Expression, PCollection],
outputs: Dict[Any, expressions.Expression],
): # -> Dict[Any, PCollection]
"""Construct a Beam graph that evaluates a set of expressions on a set of
input PCollections.

@@ -585,11 +583,9 @@ def _concat(parts):


def _flatten(
valueish, # type: Union[T, List[T], Tuple[T], Dict[Any, T]]
root=(), # type: Tuple[Any, ...]
):
# type: (...) -> Mapping[Tuple[Any, ...], T]

valueish: Union[T, List[T], Tuple[T], Dict[Any, T]],
root: Tuple[Any, ...] = (),
) -> Mapping[Tuple[Any, ...], T]:
"""Given a nested structure of dicts, tuples, and lists, return a flat
dictionary where the values are the leafs and the keys are the "paths" to
these leaves.
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/io/avroio_test.py
@@ -82,7 +82,7 @@

class AvroBase(object):

_temp_files = [] # type: List[str]
_temp_files: List[str] = []

def __init__(self, methodName='runTest'):
super().__init__(methodName)
36 changes: 12 additions & 24 deletions sdks/python/apache_beam/io/fileio.py
@@ -115,15 +115,13 @@
from apache_beam.options.value_provider import ValueProvider
from apache_beam.transforms.periodicsequence import PeriodicImpulse
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.window import BoundedWindow
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import IntervalWindow
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.transforms.window import BoundedWindow

__all__ = [
'EmptyMatchTreatment',
'MatchFiles',
@@ -382,8 +380,7 @@ def create_metadata(
mime_type="application/octet-stream",
compression_type=CompressionTypes.AUTO)

def open(self, fh):
# type: (BinaryIO) -> None
def open(self, fh: BinaryIO) -> None:
raise NotImplementedError

def write(self, record):
@@ -575,8 +572,7 @@ class signature or an instance of FileSink to this parameter. If none is
self._max_num_writers_per_bundle = max_writers_per_bundle

@staticmethod
def _get_sink_fn(input_sink):
# type: (...) -> Callable[[Any], FileSink]
def _get_sink_fn(input_sink) -> Callable[[Any], FileSink]:
if isinstance(input_sink, type) and issubclass(input_sink, FileSink):
return lambda x: input_sink()
elif isinstance(input_sink, FileSink):
@@ -588,8 +584,7 @@ def _get_sink_fn(input_sink):
return lambda x: TextSink()

@staticmethod
def _get_destination_fn(destination):
# type: (...) -> Callable[[Any], str]
def _get_destination_fn(destination) -> Callable[[Any], str]:
if isinstance(destination, ValueProvider):
return lambda elm: destination.get()
elif callable(destination):
@@ -757,12 +752,8 @@ def _check_orphaned_files(self, writer_key):


class _WriteShardedRecordsFn(beam.DoFn):

def __init__(self,
base_path,
sink_fn, # type: Callable[[Any], FileSink]
shards # type: int
):
def __init__(
self, base_path, sink_fn: Callable[[Any], FileSink], shards: int):
self.base_path = base_path
self.sink_fn = sink_fn
self.shards = shards
@@ -805,17 +796,13 @@ def process(


class _AppendShardedDestination(beam.DoFn):
def __init__(
self,
destination, # type: Callable[[Any], str]
shards # type: int
):
def __init__(self, destination: Callable[[Any], str], shards: int):
self.destination_fn = destination
self.shards = shards

# We start the shards for a single destination at an arbitrary point.
self._shard_counter = collections.defaultdict(
lambda: random.randrange(self.shards)) # type: DefaultDict[str, int]
self._shard_counter: DefaultDict[str, int] = collections.defaultdict(
lambda: random.randrange(self.shards))

def _next_shard_for_destination(self, destination):
self._shard_counter[destination] = ((self._shard_counter[destination] + 1) %
@@ -835,8 +822,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn):
SPILLED_RECORDS = 'spilled_records'
WRITTEN_FILES = 'written_files'

_writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]]
_file_names = None # type: Dict[Tuple[str, BoundedWindow], str]
_writers_and_sinks: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO,
FileSink]] = None
_file_names: Dict[Tuple[str, BoundedWindow], str] = None

def __init__(
self,
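One judgment call in `_WriteUnshardedRecordsFn` above: the class-level defaults stay `None` while the annotations say `Dict[...]`. Strict type checkers generally flag a `None` default under a non-Optional annotation; a hedged sketch of the more explicit spelling (illustrative names, not a required change to this PR):

```python
from typing import Dict, Optional, Tuple


class UnshardedWriter:
    # Optional[...] records that the attribute starts out as None and is only
    # populated later (e.g. in a start_bundle-style setup method).
    _writers: Optional[Dict[Tuple[str, int], str]] = None

    def start_bundle(self) -> None:
        self._writers = {}
```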
10 changes: 6 additions & 4 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -560,7 +560,7 @@ def _insert_load_job(

def _start_job(
self,
request, # type: bigquery.BigqueryJobsInsertRequest
request: 'bigquery.BigqueryJobsInsertRequest',
stream=None,
):
"""Inserts a BigQuery job.
@@ -1786,9 +1786,11 @@ def generate_bq_job_name(job_name, step_id, job_type, random=None):


def check_schema_equal(
left, right, *, ignore_descriptions=False, ignore_field_order=False):
# type: (Union[bigquery.TableSchema, bigquery.TableFieldSchema], Union[bigquery.TableSchema, bigquery.TableFieldSchema], bool, bool) -> bool

left: Union[bigquery.TableSchema, bigquery.TableFieldSchema],
right: Union[bigquery.TableSchema, bigquery.TableFieldSchema],
*,
ignore_descriptions: bool = False,
ignore_field_order: bool = False) -> bool:
"""Check whether schemas are equivalent.

This comparison function differs from using == to compare TableSchema
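`check_schema_equal` above also shows why the comment form scaled poorly: its single `# type: (...) -> bool` line listed the flag types far away from their keyword-only defaults. A small sketch of the annotated keyword-only shape (toy comparison, not the BigQuery schema logic):

```python
def check_equal(
    left: object,
    right: object,
    *,
    ignore_descriptions: bool = False,
    ignore_field_order: bool = False,
) -> bool:
    """Toy comparison; the flags are accepted but only equality is checked."""
    # Everything after the bare `*` must be passed by keyword, and each flag
    # now carries its type and default together in the signature instead of a
    # separate trailing type comment.
    del ignore_descriptions, ignore_field_order  # unused in this sketch
    return left == right
```

Calling it as `check_equal(a, b, ignore_field_order=True)` works; passing the flags positionally raises TypeError.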
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -137,8 +137,10 @@ def create_storage_client(pipeline_options, use_credentials=True):

class GcsIO(object):
"""Google Cloud Storage I/O client."""
def __init__(self, storage_client=None, pipeline_options=None):
# type: (Optional[storage.Client], Optional[Union[dict, PipelineOptions]]) -> None
def __init__(
self,
storage_client: Optional[storage.Client] = None,
pipeline_options: Optional[Union[dict, PipelineOptions]] = None) -> None:
if pipeline_options is None:
pipeline_options = PipelineOptions()
elif isinstance(pipeline_options, dict):
54 changes: 26 additions & 28 deletions sdks/python/apache_beam/metrics/monitoring_infos.py
@@ -182,9 +182,8 @@ def create_labels(ptransform=None, namespace=None, name=None, pcollection=None):
return labels


def int64_user_counter(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_counter(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the counter monitoring info for the specifed URN, metric and labels.

Args:
@@ -199,9 +198,12 @@ def int64_user_counter(namespace, name, metric, ptransform=None):
USER_COUNTER_URN, SUM_INT64_TYPE, metric, labels)


def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_counter(
urn,
metric,
ptransform=None,
pcollection=None,
labels=None) -> metrics_pb2.MonitoringInfo:
"""Return the counter monitoring info for the specifed URN, metric and labels.

Args:
@@ -217,9 +219,8 @@ def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None):
return create_monitoring_info(urn, SUM_INT64_TYPE, metric, labels)


def int64_user_distribution(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_distribution(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the distribution monitoring info for the URN, metric and labels.

Args:
@@ -234,9 +235,11 @@ def int64_user_distribution(namespace, name, metric, ptransform=None):
USER_DISTRIBUTION_URN, DISTRIBUTION_INT64_TYPE, payload, labels)


def int64_distribution(urn, metric, ptransform=None, pcollection=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_distribution(
urn,
metric,
ptransform=None,
pcollection=None) -> metrics_pb2.MonitoringInfo:
"""Return a distribution monitoring info for the URN, metric and labels.

Args:
@@ -251,9 +254,8 @@ def int64_distribution(urn, metric, ptransform=None, pcollection=None):
return create_monitoring_info(urn, DISTRIBUTION_INT64_TYPE, payload, labels)


def int64_user_gauge(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_gauge(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, metric and labels.

Args:
@@ -276,9 +278,7 @@ def int64_user_gauge(namespace, name, metric, ptransform=None):
USER_GAUGE_URN, LATEST_INT64_TYPE, payload, labels)


def int64_gauge(urn, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_gauge(urn, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, metric and labels.

Args:
@@ -320,9 +320,8 @@ def user_set_string(namespace, name, metric, ptransform=None):
USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels)


def create_monitoring_info(urn, type_urn, payload, labels=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def create_monitoring_info(
urn, type_urn, payload, labels=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, type, metric and labels.

Args:
@@ -366,9 +365,9 @@ def is_user_monitoring_info(monitoring_info_proto):
return monitoring_info_proto.urn in USER_METRIC_URNS


def extract_metric_result_map_value(monitoring_info_proto):
# type: (...) -> Union[None, int, DistributionResult, GaugeResult, set]

def extract_metric_result_map_value(
monitoring_info_proto
) -> Union[None, int, DistributionResult, GaugeResult, set]:
"""Returns the relevant GaugeResult, DistributionResult or int value for
counter metric, set for StringSet metric.

@@ -408,14 +407,13 @@ def get_step_name(monitoring_info_proto):
return monitoring_info_proto.labels.get(PTRANSFORM_LABEL)


def to_key(monitoring_info_proto):
# type: (metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]

def to_key(
monitoring_info_proto: metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]:
"""Returns a key based on the URN and labels.

This is useful in maps to prevent reporting the same MonitoringInfo twice.
"""
key_items = list(monitoring_info_proto.labels.items()) # type: List[Hashable]
key_items: List[Hashable] = list(monitoring_info_proto.labels.items())
key_items.append(monitoring_info_proto.urn)
return frozenset(key_items)
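Most helpers in monitoring_infos.py gain only a return annotation, leaving the parameters dynamic as before, and `to_key` keeps an annotated local (`key_items: List[Hashable]`). A sketch of both shapes, with a toy record standing in for `metrics_pb2.MonitoringInfo`:

```python
from typing import Dict, FrozenSet, Hashable, List, NamedTuple


class MonitoringInfo(NamedTuple):
    """Toy stand-in for metrics_pb2.MonitoringInfo."""
    urn: str
    labels: Dict[str, str]


def int64_counter(urn, metric, labels=None) -> MonitoringInfo:
    # Parameters stay unannotated; only the return type moved out of the old
    # trailing `# type: (...) -> ...` comment.
    return MonitoringInfo(urn=urn, labels=dict(labels or {}, metric=str(metric)))


def to_key(info: MonitoringInfo) -> FrozenSet[Hashable]:
    # An annotated local, mirroring the key_items line in to_key above.
    key_items: List[Hashable] = list(info.labels.items())
    key_items.append(info.urn)
    return frozenset(key_items)
```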
