From be7a084e737a09a9296e091f3767ab6ae1828bc9 Mon Sep 17 00:00:00 2001 From: Tom van der Weide Date: Tue, 15 Oct 2024 04:16:46 -0700 Subject: [PATCH] Update some types to modern Python PiperOrigin-RevId: 686048315 --- tensorflow_datasets/core/logging/__init__.py | 12 +-- .../core/logging/base_logger.py | 90 +++++++++---------- 2 files changed, 51 insertions(+), 51 deletions(-) diff --git a/tensorflow_datasets/core/logging/__init__.py b/tensorflow_datasets/core/logging/__init__.py index c89345da669..205bf98209e 100644 --- a/tensorflow_datasets/core/logging/__init__.py +++ b/tensorflow_datasets/core/logging/__init__.py @@ -20,7 +20,7 @@ import collections import functools import threading -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Callable, TypeVar from absl import flags from absl import logging @@ -34,15 +34,15 @@ _LoggerMethod = Callable[..., None] -_registered_loggers: Optional[List[base_logger.Logger]] = None +_registered_loggers: list[base_logger.Logger] | None = None -_import_operations: List[Tuple[call_metadata.CallMetadata, int, int]] = [] +_import_operations: list[tuple[call_metadata.CallMetadata, int, int]] = [] _import_operations_lock = threading.Lock() _thread_id_to_builder_init_count = collections.Counter() -def _init_registered_loggers() -> List[base_logger.Logger]: +def _init_registered_loggers() -> list[base_logger.Logger]: """Initializes the registered loggers if they are not set yet.""" global _registered_loggers if _registered_loggers is None: @@ -65,7 +65,7 @@ def _log_import_operation(): _import_operations.clear() -def _get_registered_loggers() -> List[base_logger.Logger]: +def _get_registered_loggers() -> list[base_logger.Logger]: _log_import_operation() return _init_registered_loggers() @@ -188,7 +188,7 @@ class _DsbuilderMethodDecorator(_FunctionDecorator): IS_PROPERTY: bool = False @staticmethod - def _get_info(dsbuilder: Any) -> Tuple[str, str, str, str]: + def _get_info(dsbuilder: Any) -> tuple[str, str, str, str]: """Gets information about the builder. Args: diff --git a/tensorflow_datasets/core/logging/base_logger.py b/tensorflow_datasets/core/logging/base_logger.py index d3c1f021783..e003af1a1e4 100644 --- a/tensorflow_datasets/core/logging/base_logger.py +++ b/tensorflow_datasets/core/logging/base_logger.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Union +from typing import Any from etils import epy @@ -51,7 +51,7 @@ def tfds_import( metadata: call_metadata.CallMetadata, import_time_ms_tensorflow: int, import_time_ms_dataset_builders: int, - ): + ) -> None: """Callback called when user calls `import tensorflow_datasets`.""" pass @@ -60,11 +60,11 @@ def builder_init( *, metadata: call_metadata.CallMetadata, name: str, - data_dir: Optional[str], - config: Optional[str], - version: Optional[str], + data_dir: str | None, + config: str | None, + version: str | None, is_read_only_builder: bool, - ): + ) -> None: """Callback called when user calls `DatasetBuilder(...)`.""" pass @@ -73,10 +73,10 @@ def builder_info( *, metadata: call_metadata.CallMetadata, name: str, - config_name: Optional[str], + config_name: str | None, version: str, data_path: str, - ): + ) -> None: """Callback called when user calls `builder.info()`.""" pass @@ -85,16 +85,16 @@ def as_dataset( *, metadata: call_metadata.CallMetadata, name: str, - config_name: Optional[str], + config_name: str | None, version: str, data_path: str, - split: Optional[type_utils.Tree[splits_lib.SplitArg]], - batch_size: Optional[int], + split: type_utils.Tree[splits_lib.SplitArg] | None, + batch_size: int | None, shuffle_files: bool, read_config: read_config_lib.ReadConfig, as_supervised: bool, - decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]], - ): + decoders: TreeDict[decode.partial_decode.DecoderArg] | None, + ) -> None: """Callback called when user calls `dataset_builder.as_dataset`. Callback is also triggered by `tfds.load`, which calls `as_dataset`. @@ -122,13 +122,13 @@ def download_and_prepare( *, metadata: call_metadata.CallMetadata, name: str, - config_name: Optional[str], + config_name: str | None, version: str, data_path: str, - download_dir: Optional[str], - download_config: Optional[download_lib.DownloadConfig], - file_format: Union[None, str, file_adapters.FileFormat], - ): + download_dir: str | None, + download_config: download_lib.DownloadConfig | None, + file_format: str | file_adapters.FileFormat | None, + ) -> None: """Callback called when user calls `dataset_builder.download_and_prepare`.""" pass @@ -141,8 +141,8 @@ def builder( *, metadata: call_metadata.CallMetadata, name: str, - try_gcs: Optional[bool], - ): + try_gcs: bool | None, + ) -> None: """Callback called when user calls `tfds.builder(...)`.""" pass @@ -150,8 +150,8 @@ def dataset_collection( self, metadata: call_metadata.CallMetadata, name: str, - loader_kwargs: Optional[Dict[str, Any]], - ): + loader_kwargs: dict[str, Any] | None, + ) -> None: """Callback called when user calls `tfds.dataset_collection(...)`.""" pass @@ -160,17 +160,17 @@ def load( *, metadata: call_metadata.CallMetadata, name: str, - split: Optional[type_utils.Tree[splits_lib.SplitArg]], - data_dir: Optional[str], - batch_size: Optional[int], - shuffle_files: Optional[bool], - download: Optional[bool], - as_supervised: Optional[bool], - decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]], - read_config: Optional[read_config_lib.ReadConfig], - with_info: Optional[bool], - try_gcs: Optional[bool], - ): + split: type_utils.Tree[splits_lib.SplitArg] | None, + data_dir: str | None, + batch_size: int | None, + shuffle_files: bool | None, + download: bool | None, + as_supervised: bool | None, + decoders: TreeDict[decode.partial_decode.DecoderArg] | None, + read_config: read_config_lib.ReadConfig | None, + with_info: bool | None, + try_gcs: bool | None, + ) -> None: """Callback called when user calls `tfds.load(...)`.""" pass @@ -178,8 +178,8 @@ def list_builders( self, *, metadata: call_metadata.CallMetadata, - with_community_datasets: Optional[bool], - ): + with_community_datasets: bool | None, + ) -> None: """Callback called when user calls `tfds.list_builders(...)`.""" pass @@ -192,12 +192,12 @@ def data_source( *, metadata: call_metadata.CallMetadata, name: str, - split: Optional[type_utils.Tree[splits_lib.SplitArg]], - data_dir: Optional[str], - download: Optional[bool], - decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]], - try_gcs: Optional[bool], - ): + split: type_utils.Tree[splits_lib.SplitArg] | None, + data_dir: str | None, + download: bool | None, + decoders: TreeDict[decode.partial_decode.DecoderArg] | None, + try_gcs: bool | None, + ) -> None: """Callback called when user calls `tfds.data_source(...)`.""" pass @@ -206,11 +206,11 @@ def as_data_source( *, metadata: call_metadata.CallMetadata, name: str, - config_name: Optional[str], + config_name: str | None, version: str, data_path: str, - split: Optional[type_utils.Tree[splits_lib.SplitArg]], - decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]], - ): + split: type_utils.Tree[splits_lib.SplitArg] | None, + decoders: TreeDict[decode.partial_decode.DecoderArg] | None, + ) -> None: """Callback called when user calls `dataset_builder.as_data_source(...)`.""" pass