
Commit

Merge branch 'main' into pr/1699
fkiraly committed Oct 22, 2024
2 parents 0cd03f5 + 5ebe691 commit 842f927
Showing 5 changed files with 10 additions and 10 deletions.
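All five diffs below apply the same cleanup: list comprehensions passed to any() and all() become generator expressions, and a couple of redundant comprehensions are replaced with list(). The generator form avoids materializing a throwaway list and lets any()/all() short-circuit on the first decisive element. A minimal sketch of the pattern, using a hypothetical scales list rather than values from the repository:

```python
scales = [0.5, 1e-9, 2.0]  # hypothetical example values

# Before: the list comprehension builds the full list before any() runs.
has_tiny_scale = any([s < 1e-7 for s in scales])

# After: the generator expression is evaluated lazily, so any() can stop
# at the first True without allocating an intermediate list.
has_tiny_scale = any(s < 1e-7 for s in scales)
```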
4 changes: 2 additions & 2 deletions pytorch_forecasting/data/encoders.py
@@ -896,7 +896,7 @@ def swap_parameters(norm):
self.missing_ = self.norm_.median().to_dict()

if (
-(self.scale_by_group and any([(self.norm_[group]["scale"] < 1e-7).any() for group in self.groups]))
+(self.scale_by_group and any((self.norm_[group]["scale"] < 1e-7).any() for group in self.groups))
or (not self.scale_by_group and isinstance(self.norm_["scale"], float) and self.norm_["scale"] < 1e-7)
or (
not self.scale_by_group
@@ -1186,7 +1186,7 @@ def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError as e:
-attribute_exists = all([hasattr(norm, name) for norm in self.normalizers])
+attribute_exists = all(hasattr(norm, name) for norm in self.normalizers)
if attribute_exists:
# check if to return callable or not and return function if yes
if callable(getattr(self.normalizers[0], name)):
2 changes: 1 addition & 1 deletion pytorch_forecasting/metrics/base_metrics.py
@@ -443,7 +443,7 @@ def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError as e:
-attribute_exists = all([hasattr(metric, name) for metric in self.metrics])
+attribute_exists = all(hasattr(metric, name) for metric in self.metrics)
if attribute_exists:
# check if to return callable or not and return function if yes
if callable(getattr(self.metrics[0], name)):
6 changes: 3 additions & 3 deletions pytorch_forecasting/models/base_model.py
@@ -95,7 +95,7 @@ def _torch_cat_na(x: List[torch.Tensor]) -> torch.Tensor:

# check if remaining dimensions are all equal
if x[0].ndim > 2:
-remaining_dimensions_equal = all([all([xi.size(i) == x[0].size(i) for xi in x]) for i in range(2, x[0].ndim)])
+remaining_dimensions_equal = all(all(xi.size(i) == x[0].size(i) for xi in x) for i in range(2, x[0].ndim))
else:
remaining_dimensions_equal = True

@@ -796,7 +796,7 @@ def step(
[self.hparams.x_reals.index(name) for name in self.hparams.monotone_constaints.keys()]
)
monotonicity = torch.tensor(
-[val for val in self.hparams.monotone_constaints.values()], dtype=gradient.dtype, device=gradient.device
+list(self.hparams.monotone_constaints.values()), dtype=gradient.dtype, device=gradient.device
)
# add additionl loss if gradient points in wrong direction
gradient = gradient[..., indices] * monotonicity[None, None]
@@ -1225,7 +1225,7 @@ def configure_optimizers(self):
)
from pytorch_optimizer import Ranger21

-if any([isinstance(c, LearningRateFinder) for c in self.trainer.callbacks]):
+if any(isinstance(c, LearningRateFinder) for c in self.trainer.callbacks):
# if finding learning rate, switch off warm up and cool down
optimizer_params.setdefault("num_warm_up_iterations", 0)
optimizer_params.setdefault("num_warm_down_iterations", 0)
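In base_model.py the same swap also lands inside the nested dimension check of _torch_cat_na. A standalone sketch of that nested generator pattern, with made-up tensors rather than values from the library:

```python
import torch

# Hypothetical tensors that agree in every dimension past the first two.
x = [torch.randn(2, 5, 8, 3), torch.randn(4, 7, 8, 3)]

# Inner all(): one trailing dimension compared across every tensor.
# Outer all(): that check repeated for each trailing dimension index,
# stopping as soon as a mismatch is found.
remaining_dimensions_equal = all(
    all(xi.size(i) == x[0].size(i) for xi in x) for i in range(2, x[0].ndim)
)
assert remaining_dimensions_equal
```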
4 changes: 2 additions & 2 deletions pytorch_forecasting/models/deepar/__init__.py
@@ -208,7 +208,7 @@ def from_dataset(
new_kwargs.update(kwargs)
assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
not isinstance(dataset.target_normalizer, MultiNormalizer)
-or all([not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer])
+or all(not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer)
), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction
if isinstance(new_kwargs.get("loss", None), MultivariateDistributionLoss):
assert (
@@ -230,7 +230,7 @@ def construct_input_vector(
# create input vector
if len(self.categoricals) > 0:
embeddings = self.embeddings(x_cat)
-flat_embeddings = torch.cat([emb for emb in embeddings.values()], dim=-1)
+flat_embeddings = torch.cat(list(embeddings.values()), dim=-1)
input_vector = flat_embeddings

if len(self.reals) > 0:
4 changes: 2 additions & 2 deletions pytorch_forecasting/models/rnn/__init__.py
@@ -179,7 +179,7 @@ def from_dataset(
new_kwargs.update(cls.deduce_default_output_parameters(dataset=dataset, kwargs=kwargs, default_loss=MAE()))
assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
not isinstance(dataset.target_normalizer, MultiNormalizer)
-or all([not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer])
+or all(not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer)
), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction
return super().from_dataset(
dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs
@@ -197,7 +197,7 @@ def construct_input_vector(
# create input vector
if len(self.categoricals) > 0:
embeddings = self.embeddings(x_cat)
-flat_embeddings = torch.cat([emb for emb in embeddings.values()], dim=-1)
+flat_embeddings = torch.cat(list(embeddings.values()), dim=-1)
input_vector = flat_embeddings

if len(self.reals) > 0:
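The DeepAR and RNN diffs also drop a no-op comprehension when flattening embedding outputs: [emb for emb in embeddings.values()] is simply list(embeddings.values()). A small illustration with hypothetical embedding tensors and shapes, not taken from the models themselves:

```python
import torch

# Hypothetical per-feature embedding outputs keyed by feature name.
embeddings = {
    "weekday": torch.randn(4, 3),
    "month": torch.randn(4, 5),
}

# torch.cat takes a sequence of tensors, so list() is all that is needed;
# the comprehension produced the same list with extra work.
flat_embeddings = torch.cat(list(embeddings.values()), dim=-1)
assert flat_embeddings.shape == (4, 8)
```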
