From 1656eb6e24675b07906f4f8bcd7716c5e6c9a736 Mon Sep 17 00:00:00 2001 From: Shauvik RC Date: Sat, 2 Mar 2024 13:34:48 -0800 Subject: [PATCH 01/22] Fixing broken links to point to github (#54) --- benchmarks/README.md | 2 +- ...troduction to Federated Learning with CIFAR10.ipynb | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 02aa242..611b612 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -53,7 +53,7 @@ There are multiple official benchmarks for `pfl` to simulate various scenarios, ## Run distributed simulations Each benchmark can run in distributed mode with multiple cores, GPUs and machines. -See the [distributed simulation guide](https://pages.github.apple.com/apple/pfl-research/tutorials/simulation_distributed.html) on how it works. +See the [distributed simulation guide](https://apple.github.io/pfl-research/guides/simulation_distributed.html) on how it works. In summary, to quickly get started running distributed simulations: 1. Install [Horovod](https://horovod.readthedocs.io/en/stable/install_include.html). We have a helper script [here](https://github.com/apple/pfl-research/blob/main/build_scripts/install_horovod.sh). 2. Invoke your Python script with the `horovodrun` command. E.g. to run the same CIFAR10 training as described above in the quickstart, but train with 2 processes on the same machine, the command will look like this: diff --git a/tutorials/Introduction to Federated Learning with CIFAR10.ipynb b/tutorials/Introduction to Federated Learning with CIFAR10.ipynb index 7243f55..655d464 100644 --- a/tutorials/Introduction to Federated Learning with CIFAR10.ipynb +++ b/tutorials/Introduction to Federated Learning with CIFAR10.ipynb @@ -156,7 +156,7 @@ "source": [ "As displayed above, the images are of dimensions `32x32` with `3` color channels and the labels are one-hot vectors `[0-9]`. 
Also, you can see that we exclude all but 2 classes, Airplane (0) and Bird (2), to make things simple for this tutorial.\n", "\n", - "As the data is not inherently split into real users, we need to use [ArtificialFederatedDataset](https://pages.github.apple.com/apple/pfl-research/reference/data.html#pfl.data.federated_dataset.ArtificialFederatedDataset) to sample datapoints to create artificial user datasets for federated learning, which in turn requires a function for sampling the length of a user dataset and the actual datapoints of this dataset.\n", + "As the data is not inherently split into real users, we need to use [ArtificialFederatedDataset](https://apple.github.io/pfl-research/reference/data.html#pfl.data.federated_dataset.ArtificialFederatedDataset) to sample datapoints to create artificial user datasets for federated learning, which in turn requires a function for sampling the length of a user dataset and the actual datapoints of this dataset.\n", "\n", "Firstly, we define a function which samples the length of a dataset from a Poisson distribution with mean 5:" ] @@ -360,7 +360,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To train a model using `pfl`, a [Model](https://pages.github.apple.com/apple/pfl-research/reference/model.html#models) object must be instantiated for the Deep Learning framework in which the model was defined. A `pfl` Model acts as an adapter between the model of the Deep Learning framework and `pfl`. We use Keras with TensorFlow for this example.\n", + "To train a model using `pfl`, a [Model](https://apple.github.io/pfl-research/reference/model.html#models) object must be instantiated for the Deep Learning framework in which the model was defined. A `pfl` Model acts as an adapter between the model of the Deep Learning framework and `pfl`. We use Keras with TensorFlow for this example.\n", "\n", "The central optimizer is an input parameter to the model object. 
This optimizer is used when applying the aggregated model update to the central model." ] @@ -401,9 +401,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`Backend` is a component in `pfl` used to collect and average statistics from user devices. [SimulatedBackend](https://pages.github.apple.com/apple/pfl-research/reference/aggregate.html#pfl.aggregate.SimulatedBackend) provides simulation. For larger datasets or models, it also [supports distributed simulations](https://pages.github.apple.com/apple/pfl-research/tutorials/simulation_distributed.html) with multiple processes, GPUs and machines.\n", + "`Backend` is a component in `pfl` used to collect and average statistics from user devices. [SimulatedBackend](https://apple.github.io/pfl-research/reference/aggregate.html#pfl.aggregate.SimulatedBackend) provides simulation. For larger datasets or models, it also [supports distributed simulations](https://apple.github.io/pfl-research/tutorials/simulation_distributed.html) with multiple processes, GPUs and machines.\n", "\n", - "`pfl` provides various privacy mechanisms in [pfl.privacy](https://pages.github.apple.com/apple/pfl-research/reference/privacy.html#module-privacy).\n", + "`pfl` provides various privacy mechanisms in [pfl.privacy](https://apple.github.io/pfl-research/reference/privacy.html#module-privacy).\n", "These mechanisms are used to guarantee local and/or central differential privacy and are parameterized by `SimulatedBackend` in simulations.\n", "\n", "Below, we initialize `SimulatedBackend` with Gaussian Moments Accountant as the central differential privacy mechanism, which provides the minimum privacy guarantee required by most use-cases." @@ -445,7 +445,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The main component of a `pfl` modeling setup is the algorithm. In this tutorial, we will be using the [Federated Averaging](https://arxiv.org/pdf/1602.05629.pdf) algorithm. 
`pfl` implements Federated Averaging using the class [FederatedAveraging](https://pages.github.apple.com/apple/pfl-research/reference/algorithm.html#pfl.algorithm.federated_averaging.FederatedAveraging).\n", + "The main component of a `pfl` modeling setup is the algorithm. In this tutorial, we will be using the [Federated Averaging](https://arxiv.org/pdf/1602.05629.pdf) algorithm. `pfl` implements Federated Averaging using the class [FederatedAveraging](https://apple.github.io/pfl-research/reference/algorithm.html#pfl.algorithm.federated_averaging.FederatedAveraging).\n", "\n", "Everything is tied together in the `run` method of an algorithm. This method requires the model, backend, optional callbacks, as well as hyperparameters, which depends on the model and algorithm combination. In this case, training a neural network with federated averaging requires a `NNTrainHyperParams` and `NNAlgorithmParams`.\n", "\n", From b876fdac755d39f136e9dfabac21ccc86c8e5df3 Mon Sep 17 00:00:00 2001 From: ac554 <47990575+ac554@users.noreply.github.com> Date: Mon, 4 Mar 2024 16:30:32 -0800 Subject: [PATCH 02/22] remove broken link (#57) --- pfl/privacy/privacy_accountant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pfl/privacy/privacy_accountant.py b/pfl/privacy/privacy_accountant.py index 368c440..54de581 100644 --- a/pfl/privacy/privacy_accountant.py +++ b/pfl/privacy/privacy_accountant.py @@ -105,7 +105,7 @@ class PLDPrivacyAccountant(PrivacyAccountant): """ Privacy Loss Distribution (PLD) privacy accountant, from dp-accounting package. 
- Code: https://github.com/google/differential-privacy/blob/main/python/dp_accounting/pld/pld_privacy_accountant.py # pylint: disable=line-too-long + The PLD algorithm is based on: “Tight on budget?: Tight bounds for r-fold approximate differential privacy.”, Meiser and Mohammadi, in CCS, pages 247-264, 2018, https://eprint.iacr.org/2017/1034.pdf From efc22f7989d863c5a827e808c125dc1468609b3f Mon Sep 17 00:00:00 2001 From: Rogier van Dalen Date: Tue, 5 Mar 2024 00:34:48 +0000 Subject: [PATCH 03/22] Rename EMMGMMHyperParams to EMGMMHyperParams (#55) --- pfl/algorithm/expectation_maximization_gmm.py | 6 +++--- tests/algorithm/test_expectation_maximization_gmm.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pfl/algorithm/expectation_maximization_gmm.py b/pfl/algorithm/expectation_maximization_gmm.py index f451229..d097156 100644 --- a/pfl/algorithm/expectation_maximization_gmm.py +++ b/pfl/algorithm/expectation_maximization_gmm.py @@ -18,7 +18,7 @@ @dataclass(frozen=True) -class EMMGMMHyperParams(AlgorithmHyperParams): +class EMGMMHyperParams(AlgorithmHyperParams): """ Parameters for EM GMM algorithms. 
@@ -113,7 +113,7 @@ def compute_new_num_components(iteration, num_iterations_since_last_mix_up, return compute_new_num_components -class ExpectationMaximizationGMM(FederatedAlgorithm[EMMGMMHyperParams, +class ExpectationMaximizationGMM(FederatedAlgorithm[EMGMMHyperParams, GMMHyperParams, GaussianMixtureModel, MappedVectorStatistics, @@ -132,7 +132,7 @@ def get_next_central_contexts( self, model: GaussianMixtureModel, iteration: int, - algorithm_params: EMMGMMHyperParams, + algorithm_params: EMGMMHyperParams, model_train_params: GMMHyperParams, model_eval_params: Optional[GMMHyperParams] = None, ) -> Tuple[Optional[Tuple[CentralContext, ...]], GaussianMixtureModel, diff --git a/tests/algorithm/test_expectation_maximization_gmm.py b/tests/algorithm/test_expectation_maximization_gmm.py index b4e546b..ed09964 100644 --- a/tests/algorithm/test_expectation_maximization_gmm.py +++ b/tests/algorithm/test_expectation_maximization_gmm.py @@ -6,7 +6,7 @@ from pfl.aggregate.base import Backend, get_total_weight_name from pfl.algorithm.expectation_maximization_gmm import ( - EMMGMMHyperParams, + EMGMMHyperParams, ExpectationMaximizationGMM, make_compute_new_num_components, ) @@ -144,7 +144,7 @@ def test_run(self, tmpdir, numpy_ops, mock_gmm_backend, max_num_components=64, step_components=2) - algorithm_params = EMMGMMHyperParams( + algorithm_params = EMGMMHyperParams( central_num_iterations=central_num_iterations, evaluation_frequency=1, val_cohort_size=fixed_val_cohort_size, From 22a2a6b68a3c71f52d01d7fb90e78e1a8c119121 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Tue, 12 Mar 2024 00:14:44 +0100 Subject: [PATCH 04/22] Link from readme to docs, docs to readme (#58) --- README.md | 4 ++++ docs/source/index.rst | 2 ++ 2 files changed, 6 insertions(+) diff --git a/README.md b/README.md index 6d47ff5..ee97d37 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # `pfl`: Python framework for Private Federated Learning simulations +**Documentation website:** 
https://apple.github.io/pfl-research + `pfl` is a Python framework developed at Apple to empower researchers to run efficient simulations with privacy-preserving federated learning (FL) and disseminate the results of their research in FL. We are a team comprising engineering and research expertise, and we encourage researchers to publish their papers, with this code, with confidence. The framework is `not` intended to be used for third-party FL deployments but the results of the simulations can be tremendously useful in actual FL deployments. @@ -32,6 +34,8 @@ pip install 'pfl[tf,pytorch,trees]' To try out `pfl` immediately without installation, we provide several colab notebooks for learning the different components in `pfl` hands-on. `` +Also available as Jupyter notebooks [here](https://github.com/apple/pfl-research/tree/develop/tutorials). + ## Getting started - benchmarks `pfl` aims to streamline the benchmarking process of testing hypotheses in the Federated Learning paradigm. The official benchmarks are available in the [benchmarks](./benchmarks) directory, using a variety of realistic dataset-model combinations with and without differential privacy (yes, we do also have CIFAR10). diff --git a/docs/source/index.rst b/docs/source/index.rst index 931099b..18e1a71 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -6,6 +6,8 @@ pfl: Python framework for Private Federated Learning simulations ================================================================ +**Github repo available at** https://github.com/apple/pfl-research. + ``pfl`` is a Python framework developed at Apple to enable researchers to `run efficient simulations with privacy-preserving federated learning (FL)` and `disseminate the results of their research in FL`. The framework is `not` intended to be used for third-party FL deployments but the results of the simulations can be tremendously useful in actual FL deployments. 
We hope that ``pfl`` will promote open research in FL and its effective dissemination. From 2c9d089ac8b0c8b82e6c98e49a5d0d7b083322e0 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Wed, 13 Mar 2024 22:48:11 +0000 Subject: [PATCH 05/22] license, CI, py version badges (#59) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index ee97d37..21f56c7 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # `pfl`: Python framework for Private Federated Learning simulations +[![GitHub License](https://img.shields.io/github/license/apple/pfl-research)](https://github.com/apple/pfl-research/blob/main/LICENSE) +[![CircleCI](https://dl.circleci.com/status-badge/img/gh/apple/pfl-research/tree/main.svg?style=shield)](https://dl.circleci.com/status-badge/redirect/gh/apple/pfl-research/tree/main) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pfl)](https://github.com/apple/pfl-research/blob/main/pyproject.toml#L18) + **Documentation website:** https://apple.github.io/pfl-research `pfl` is a Python framework developed at Apple to empower researchers to run efficient simulations with privacy-preserving federated learning (FL) and disseminate the results of their research in FL. We are a team comprising engineering and research expertise, and we encourage researchers to publish their papers, with this code, with confidence. 
From 2d79b66ead81333b88b9b00f5e5e3e2db9f7baa1 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Fri, 15 Mar 2024 00:05:58 +0000 Subject: [PATCH 06/22] run ci on develop merge (#60) --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1a90f60..51034f2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -210,7 +210,7 @@ workflows: matches: # Only on branches approved by Apple CircleCI policy: # https://app.circleci.com/settings/organization/github/apple/policies/baseline_apple - pattern: "^main|gh-readonly-queue/main/pr-\\d+-[0-9a-f]{40}.*$" + pattern: "^main$|^develop$|gh-readonly-queue/main/pr-\\d+-[0-9a-f]{40}.*$" value: << pipeline.git.branch >> jobs: - code-quality From 93f938f5d55a2a3f2bccb7157f4dd8dc4f2b86b7 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Thu, 28 Mar 2024 09:01:47 +0100 Subject: [PATCH 07/22] correctly trigger develop ci (#61) --- .circleci/config.yml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 51034f2..5902dd0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -207,11 +207,13 @@ jobs: workflows: build_and_test: when: - matches: - # Only on branches approved by Apple CircleCI policy: - # https://app.circleci.com/settings/organization/github/apple/policies/baseline_apple - pattern: "^main$|^develop$|gh-readonly-queue/main/pr-\\d+-[0-9a-f]{40}.*$" - value: << pipeline.git.branch >> + or: + - matches: + # Only on branches approved by Apple CircleCI policy: + # https://app.circleci.com/settings/organization/github/apple/policies/baseline_apple + pattern: "^main|gh-readonly-queue/main/pr-\\d+-[0-9a-f]{40}.*$" + value: << pipeline.git.branch >> + - equal: [ develop, << pipeline.git.branch >> ] jobs: - code-quality - build-documentation-wheel From 9be095ccbd371299392c1a3d2abe4e7438736d2a Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Mon, 1 Apr 2024 
01:25:14 +0200 Subject: [PATCH 08/22] set seeds for tf examples (#62) --- benchmarks/image_classification/tf/train.py | 3 +++ benchmarks/lm/tf/train.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/benchmarks/image_classification/tf/train.py b/benchmarks/image_classification/tf/train.py index 9d3da23..6c80ebd 100644 --- a/benchmarks/image_classification/tf/train.py +++ b/benchmarks/image_classification/tf/train.py @@ -2,6 +2,7 @@ import argparse import logging import os +import random from uuid import uuid4 import numpy as np @@ -55,6 +56,8 @@ def main(): argument_parser = add_model_arguments(argument_parser) arguments = argument_parser.parse_args() + os.environ['PYTHONHASHSEED'] = str(arguments.seed) + random.seed(arguments.seed) tf.random.set_seed(arguments.seed) np.random.seed(arguments.seed) diff --git a/benchmarks/lm/tf/train.py b/benchmarks/lm/tf/train.py index 1456ffa..bbddaff 100644 --- a/benchmarks/lm/tf/train.py +++ b/benchmarks/lm/tf/train.py @@ -2,6 +2,7 @@ import argparse import logging import os +import random from uuid import uuid4 import numpy as np @@ -58,6 +59,8 @@ def main(): parser = add_weighting_arguments(parser) arguments = parser.parse_args() + os.environ['PYTHONHASHSEED'] = str(arguments.seed) + random.seed(arguments.seed) np.random.seed(arguments.seed) tf.random.set_seed(arguments.seed) From a55c301a4dc86b822f544d2f1f2e283e00253936 Mon Sep 17 00:00:00 2001 From: congzheng-song <90863343+congzheng-song@users.noreply.github.com> Date: Wed, 3 Apr 2024 02:12:12 -0700 Subject: [PATCH 09/22] Update LLM Benchmark Configs (#63) --- benchmarks/dataset/hugging_face/__init__.py | 40 ++++++++++++++--- benchmarks/dataset/hugging_face/alpaca.py | 10 ++--- benchmarks/dataset/hugging_face/oasst.py | 10 ++--- benchmarks/llm/README.md | 4 +- benchmarks/llm/argument_parsing.py | 15 ++++++- benchmarks/llm/configs/alpaca.yaml | 50 +++++++++++++++++++++ benchmarks/llm/configs/aya.yaml | 48 ++++++++++++++++++++ benchmarks/llm/configs/baseline.yaml | 38 
---------------- benchmarks/llm/configs/oasst.yaml | 48 ++++++++++++++++++++ benchmarks/llm/train.py | 15 ++++++- pfl/data/pytorch.py | 3 +- 11 files changed, 219 insertions(+), 62 deletions(-) create mode 100644 benchmarks/llm/configs/alpaca.yaml create mode 100644 benchmarks/llm/configs/aya.yaml delete mode 100644 benchmarks/llm/configs/baseline.yaml create mode 100644 benchmarks/llm/configs/oasst.yaml diff --git a/benchmarks/dataset/hugging_face/__init__.py b/benchmarks/dataset/hugging_face/__init__.py index c7616c7..4f1a4fb 100644 --- a/benchmarks/dataset/hugging_face/__init__.py +++ b/benchmarks/dataset/hugging_face/__init__.py @@ -1,23 +1,49 @@ # Copyright © 2023-2024 Apple Inc. -import logging -from typing import Dict, List, Union +from collections import namedtuple +from typing import Any, Callable, Dict, List, Union import torch IGNORE_INDEX = -100 -logger = logging.getLogger(__name__) + +IndexedData = namedtuple('IndexedData', ['data', 'index']) class GetItemDataset(torch.utils.data.Dataset): """ Wraps a dataset that has __getitem__. """ - def __init__(self, data: Union[Dict, List]): + def __init__(self, data: Union[Dict, List], return_index: bool = False): super().__init__() - self.data = data + self._data = data + self._return_index = return_index def __len__(self): - return len(self.data) + return len(self._data) def __getitem__(self, i): - return self.data[i] + if self._return_index: + # Return both data and index. The index can be used in the below + # `UserIDCollatorWrapper` to pass the `user_id` to + # `PyTorchFederatedDataset` + return IndexedData(data=self._data[i], index=i) + return self._data[i] + + +class UserIDCollatorWrapper: + """ + Wraps an existing collator to add user ID to the resulting collated data. + This is useful for `PyTorchFederatedDataset` which does not have `user_id` + by default. 
+ """ + + def __init__(self, collator: Callable[[List], Dict[str, Any]]): + self._collator = collator + + def __call__(self, indexed_data: IndexedData) -> Dict[str, Any]: + assert isinstance(indexed_data, IndexedData), ( + "`UserIDCollatorWrapper` only supports `torch.utils.data.Dataset` " + "that returns `IndexedData`") + data = self._collator(indexed_data.data) + data["user_id"] = indexed_data.index + return data diff --git a/benchmarks/dataset/hugging_face/alpaca.py b/benchmarks/dataset/hugging_face/alpaca.py index 27d1365..1efecdd 100644 --- a/benchmarks/dataset/hugging_face/alpaca.py +++ b/benchmarks/dataset/hugging_face/alpaca.py @@ -15,10 +15,7 @@ from pfl.data.pytorch import PyTorchDataDataset, PyTorchFederatedDataset from pfl.data.sampling import get_user_sampler -from . import ( - IGNORE_INDEX, - GetItemDataset, -) +from . import IGNORE_INDEX, GetItemDataset, UserIDCollatorWrapper logger = logging.getLogger(__name__) @@ -134,11 +131,12 @@ def make_iid_federated_dataset(user_dataset: Dict[str, List[Dict]], """ Split the dataset into IID artificial users. """ user_sampler = get_user_sampler('random', list(user_dataset.keys())) user_id_to_weight = {k: len(v) for k, v in user_dataset.items()} - return PyTorchFederatedDataset(GetItemDataset(user_dataset), + collate_fn = UserIDCollatorWrapper(AlpacaDataCollator(tokenizer)) + return PyTorchFederatedDataset(GetItemDataset(user_dataset, True), user_sampler, user_id_to_weight=user_id_to_weight, batch_size=None, - collate_fn=AlpacaDataCollator(tokenizer), + collate_fn=collate_fn, **dataloader_kwargs) diff --git a/benchmarks/dataset/hugging_face/oasst.py b/benchmarks/dataset/hugging_face/oasst.py index 8169948..c513504 100644 --- a/benchmarks/dataset/hugging_face/oasst.py +++ b/benchmarks/dataset/hugging_face/oasst.py @@ -15,10 +15,7 @@ from pfl.data.pytorch import PyTorchDataDataset, PyTorchFederatedDataset from pfl.data.sampling import get_user_sampler -from . import ( - IGNORE_INDEX, - GetItemDataset, -) +from . 
import IGNORE_INDEX, GetItemDataset, UserIDCollatorWrapper logger = logging.getLogger(__name__) @@ -94,11 +91,12 @@ def make_federated_dataset(user_dataset: Dict[str, List[Dict]], """ user_sampler = get_user_sampler('random', list(user_dataset.keys())) user_id_to_weight = {k: len(v) for k, v in user_dataset.items()} - return PyTorchFederatedDataset(GetItemDataset(user_dataset), + collate_fn = UserIDCollatorWrapper(default_data_collator) + return PyTorchFederatedDataset(GetItemDataset(user_dataset, True), user_sampler, user_id_to_weight=user_id_to_weight, batch_size=None, - collate_fn=default_data_collator, + collate_fn=collate_fn, **dataloader_kwargs) diff --git a/benchmarks/llm/README.md b/benchmarks/llm/README.md index e9bf19a..b9ae583 100644 --- a/benchmarks/llm/README.md +++ b/benchmarks/llm/README.md @@ -21,10 +21,10 @@ dataset=alpaca LLM benchmark no DP: ``` -python -m llm.train --args_config llm/configs/{dataset}_baseline.yaml +python -m llm.train --args_config llm/configs/{dataset}.yaml ``` LLM benchmark Central DP: ``` -python -m llm.train --args_config llm/configs/{dataset}_baseline.yaml --central_privacy_mechanism gaussian_moments_accountant +python -m llm.train --args_config llm/configs/{dataset}.yaml --central_privacy_mechanism gaussian_moments_accountant ``` diff --git a/benchmarks/llm/argument_parsing.py b/benchmarks/llm/argument_parsing.py index 2e6c0e1..269c9b4 100644 --- a/benchmarks/llm/argument_parsing.py +++ b/benchmarks/llm/argument_parsing.py @@ -16,10 +16,15 @@ def parse_peft_config(args) -> Optional[PeftConfig]: if args.peft_type is None: return None assert args.peft_type == 'lora', "Currently only supports PEFT with LoRA." 
+ target_modules = None + if args.lora_target_modules is not None: + target_modules = args.lora_target_modules.split(',') + return LoraConfig(task_type=args.peft_task_type, r=args.lora_r, lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout) + lora_dropout=args.lora_dropout, + target_modules=target_modules) def add_peft_arguments(argument_parser): @@ -53,6 +58,14 @@ def add_peft_arguments(argument_parser): default=0.0, help='LoRA dropout.') + argument_parser.add_argument('--lora_target_modules', + type=str, + default=None, + help='Modules in the model to apply LoRA.' + 'Use comma to add multiple modules, e.g.' + '`query,key,value` applies LoRA on all ' + '`query`, `key` and `value` modules.') + return argument_parser diff --git a/benchmarks/llm/configs/alpaca.yaml b/benchmarks/llm/configs/alpaca.yaml new file mode 100644 index 0000000..ea98a43 --- /dev/null +++ b/benchmarks/llm/configs/alpaca.yaml @@ -0,0 +1,50 @@ +dataset: alpaca +mean_datapoints_per_user: 16 +datapoints_per_user_distribution: poisson + +# Model config +hugging_face_model_name_or_path: TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T +model_max_length: 512 +use_fast_tokenizer: True +padding_side: right +amp_dtype: bfloat16 +model_dtype_same_as_amp: False +use_torch_compile: True + +# PEFT - LoRA +peft_type: lora +lora_r: 8 + +central_optimizer: adam +adaptivity_degree: 1.0e-4 +learning_rate: 0.01 +central_lr_scheduler: cosine +central_lr_num_warmup_iterations: 50 + +central_num_iterations: 1000 +evaluation_frequency: 10 +central_eval_batch_size: 8 +cohort_size: 100 +val_cohort_size: 0 +noise_cohort_size: 5000 + +local_batch_size: 4 +local_num_epochs: 1 +local_learning_rate: 0.01 +local_max_grad_norm: 1.0 + +central_privacy_mechanism: none # gaussian_moments_accountant +central_epsilon: 2.0 +central_delta: 1e-6 +central_privacy_clipping_bound: 0.1 +central_order: 2 +population: 1e6 + +use_tensorboard: False + +add_all_arguments: True +algorithm_name: fedavg +mu: 0.01 
+scaffold_population: 2934 +adafedprox_metric_name: "Central val | loss" +adafedprox_adapt_frequency: 20 diff --git a/benchmarks/llm/configs/aya.yaml b/benchmarks/llm/configs/aya.yaml new file mode 100644 index 0000000..2eb9ad0 --- /dev/null +++ b/benchmarks/llm/configs/aya.yaml @@ -0,0 +1,48 @@ +dataset: aya + +# Model config +hugging_face_model_name_or_path: TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T +model_max_length: 512 +use_fast_tokenizer: True +padding_side: right +amp_dtype: bfloat16 +model_dtype_same_as_amp: False +use_torch_compile: True + +# PEFT - LoRA +peft_type: lora +lora_r: 8 + +central_optimizer: adam +adaptivity_degree: 1.0e-4 +learning_rate: 0.01 +central_lr_scheduler: cosine +central_lr_num_warmup_iterations: 50 + +central_num_iterations: 1000 +evaluation_frequency: 10 +central_eval_batch_size: 12 +cohort_size: 100 +val_cohort_size: 0 +noise_cohort_size: 5000 + +local_batch_size: 4 +local_num_epochs: 1 +local_learning_rate: 0.1 +local_max_grad_norm: 1.0 + +central_privacy_mechanism: none # gaussian_moments_accountant +central_epsilon: 2.0 +central_delta: 1e-6 +central_privacy_clipping_bound: 0.1 +central_order: 2 +population: 1e6 + +use_tensorboard: False + +add_all_arguments: True +algorithm_name: fedavg +mu: 0.01 +scaffold_population: 4089 +adafedprox_metric_name: "Central val | loss" +adafedprox_adapt_frequency: 20 diff --git a/benchmarks/llm/configs/baseline.yaml b/benchmarks/llm/configs/baseline.yaml deleted file mode 100644 index 016f216..0000000 --- a/benchmarks/llm/configs/baseline.yaml +++ /dev/null @@ -1,38 +0,0 @@ -# Dataset config -#dataset: alpaca -#mean_datapoints_per_user: 16 -#datapoints_per_user_distribution: poisson - -# Model config -hugging_face_model_name_or_path: facebook/opt-1.3b -model_max_length: 512 -use_fast_tokenizer: True -padding_side: right -amp_dtype: bfloat16 -model_dtype_same_as_amp: True - -central_optimizer: adam -adaptivity_degree: 0.01 - -central_privacy_mechanism: none -central_epsilon: 2.0 
-central_delta: 1e-6 -central_privacy_clipping_bound: 1.0 -central_order: 2 -population: 1e6 - -evaluation_frequency: 10 -learning_rate: 0.1 -cohort_size: 100 -val_cohort_size: 0 - -central_num_iterations: 2000 -local_batch_size: 4 -central_eval_batch_size: 8 -local_num_epochs: 1 -local_learning_rate: 0.1 - -use_tensorboard: False - -# PEFT -peft_type: lora diff --git a/benchmarks/llm/configs/oasst.yaml b/benchmarks/llm/configs/oasst.yaml new file mode 100644 index 0000000..b5c8201 --- /dev/null +++ b/benchmarks/llm/configs/oasst.yaml @@ -0,0 +1,48 @@ +dataset: oasst + +# Model config +hugging_face_model_name_or_path: TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T +model_max_length: 512 +use_fast_tokenizer: True +padding_side: right +amp_dtype: bfloat16 +model_dtype_same_as_amp: False +use_torch_compile: True + +# PEFT - LoRA +peft_type: lora +lora_r: 8 + +central_optimizer: adam +adaptivity_degree: 1.0e-4 +learning_rate: 0.01 +central_lr_scheduler: cosine +central_lr_num_warmup_iterations: 50 + +central_num_iterations: 1000 +evaluation_frequency: 10 +central_eval_batch_size: 12 +cohort_size: 100 +val_cohort_size: 0 +noise_cohort_size: 5000 + +local_batch_size: 4 +local_num_epochs: 1 +local_learning_rate: 0.1 +local_max_grad_norm: 1.0 + +central_privacy_mechanism: none # gaussian_moments_accountant +central_epsilon: 2.0 +central_delta: 1e-6 +central_privacy_clipping_bound: 0.1 +central_order: 2 +population: 1e6 + +use_tensorboard: False + +add_all_arguments: True +algorithm_name: fedavg +mu: 0.01 +scaffold_population: 12870 +adafedprox_metric_name: "Central val | loss" +adafedprox_adapt_frequency: 20 diff --git a/benchmarks/llm/train.py b/benchmarks/llm/train.py index 70d1210..1e205fe 100644 --- a/benchmarks/llm/train.py +++ b/benchmarks/llm/train.py @@ -2,6 +2,8 @@ import argparse import logging +import os +from uuid import uuid4 import numpy as np import torch @@ -21,7 +23,7 @@ from llm.argument_parsing import add_llm_arguments, 
parse_central_lr_scheduler, parse_peft_config from pfl.aggregate.simulate import SimulatedBackend -from pfl.callback import AggregateMetricsToDisk, CentralEvaluationCallback, StopwatchCallback +from pfl.callback import AggregateMetricsToDisk, CentralEvaluationCallback, StopwatchCallback, WandbCallback from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams from pfl.model.pytorch import PyTorchModel @@ -138,6 +140,17 @@ def main(): algorithm, algorithm_params, algorithm_callbacks = get_algorithm(arguments) callbacks.extend(algorithm_callbacks) + if arguments.wandb_project_id: + callbacks.append( + WandbCallback( + wandb_project_id=arguments.wandb_project_id, + wandb_experiment_name=os.environ.get('WANDB_TASK_ID', + str(uuid4())), + # List of dicts to one dict. + wandb_config=dict(vars(arguments)), + tags=os.environ.get('WANDB_TAGS', 'empty-tag').split(','), + group=os.environ.get('WANDB_GROUP', None))) + logger.info("Starts federated learning.") model = algorithm.run(algorithm_params=algorithm_params, backend=backend, diff --git a/pfl/data/pytorch.py b/pfl/data/pytorch.py index fb9f470..5533535 100644 --- a/pfl/data/pytorch.py +++ b/pfl/data/pytorch.py @@ -277,7 +277,8 @@ def _tensors_to_pfl_dataset(self, tensors): if isinstance(tensors, Dict): # The tensors are from a dictionary, no need to do extra processing # on the tensors as below. 
- return self._dataset_cls(tensors, **self._dataset_kwargs) + user_id = tensors.pop("user_id") if "user_id" in tensors else None + return self._dataset_cls(tensors, user_id, **self._dataset_kwargs) def process_tensor(tensor): assert tensor.shape[0] == 1, ( From 1dcd9a2b479094bcbce1d3696e26b8343519704c Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Fri, 19 Apr 2024 16:59:32 +0200 Subject: [PATCH 10/22] fix error in error in GBDT config (#64) --- pfl/tree/federated_gbdt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pfl/tree/federated_gbdt.py b/pfl/tree/federated_gbdt.py index dfa8d03..d9c394b 100644 --- a/pfl/tree/federated_gbdt.py +++ b/pfl/tree/federated_gbdt.py @@ -176,8 +176,8 @@ def __post_init__(self): 'validation cohort size must be an integer >= 0') assert self.cohort_size_per_layer_modifier_fn in [ 'none', 'linear', 'power' - ], (f'{self.cohort_size_per_layer_modifier_function} is not a', - 'valid value for cohort_size_per_layer_modifier_function') + ], (f'{self.cohort_size_per_layer_modifier_fn} is not a', + 'valid value for cohort_size_per_layer_modifier_fn') assert isinstance( self.leaf_nodes_reduction_factor, int) and self.leaf_nodes_reduction_factor >= 1, ( From b1e7be618271e8f50ed698545d27f0152d3220ae Mon Sep 17 00:00:00 2001 From: Martin Pelikan <154003090+martin-pelikan-apple@users.noreply.github.com> Date: Fri, 10 May 2024 09:54:03 -0700 Subject: [PATCH 11/22] Fixes werkzeug vulnerability (#68) --- benchmarks/poetry.lock | 8 ++++---- poetry.lock | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/benchmarks/poetry.lock b/benchmarks/poetry.lock index 124d9ed..fb31b53 100644 --- a/benchmarks/poetry.lock +++ b/benchmarks/poetry.lock @@ -1696,7 +1696,7 @@ requests = ">=2.21.0,<3" setuptools = ">=41.0.0" six = ">1.9" tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" +werkzeug = ">=3.0.3" [[package]] name = "tensorboard-data-server" @@ -2202,13 +2202,13 @@ test = ["covdefaults (>=2.3)", 
"coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "werkzeug" -version = "3.0.0" +version = "3.0.3" description = "The comprehensive WSGI web application library." optional = true python-versions = ">=3.8" files = [ - {file = "werkzeug-3.0.0-py3-none-any.whl", hash = "sha256:cbb2600f7eabe51dbc0502f58be0b3e1b96b893b05695ea2b35b43d4de2d9962"}, - {file = "werkzeug-3.0.0.tar.gz", hash = "sha256:3ffff4dcc32db52ef3cc94dff3000a3c2846890f3a5a51800a27b909c5e770f0"}, + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, ] [package.dependencies] diff --git a/poetry.lock b/poetry.lock index 547ea88..bbd03f3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2079,7 +2079,7 @@ requests = ">=2.21.0,<3" setuptools = ">=41.0.0" six = ">1.9" tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" +werkzeug = ">=3.0.3" [[package]] name = "tensorboard-data-server" @@ -2483,13 +2483,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "werkzeug" -version = "3.0.0" +version = "3.0.3" description = "The comprehensive WSGI web application library." 
optional = true python-versions = ">=3.8" files = [ - {file = "werkzeug-3.0.0-py3-none-any.whl", hash = "sha256:cbb2600f7eabe51dbc0502f58be0b3e1b96b893b05695ea2b35b43d4de2d9962"}, - {file = "werkzeug-3.0.0.tar.gz", hash = "sha256:3ffff4dcc32db52ef3cc94dff3000a3c2846890f3a5a51800a27b909c5e770f0"}, + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, ] [package.dependencies] From 2df910ebbbf7e37141b0a9efac12e03a8bc4b847 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Wed, 29 May 2024 11:50:21 +0200 Subject: [PATCH 12/22] move noise cohort to privacy args (#70) --- benchmarks/utils/argument_parsing.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/benchmarks/utils/argument_parsing.py b/benchmarks/utils/argument_parsing.py index bd5cc2f..00885b5 100644 --- a/benchmarks/utils/argument_parsing.py +++ b/benchmarks/utils/argument_parsing.py @@ -184,15 +184,6 @@ def add_dnn_training_arguments(argument_parser): help='The target number of users for one iteration ' 'of training.') - argument_parser.add_argument( - '--noise_cohort_size', - type=int, - default=1000, - help=('The cohort size to use in calculating noise for DP. ' - 'If you run cohort_size=100 but noise_cohort_size=1000, ' - 'then your results will only be valid if running with ' - 'cohort_size=1000 outside simulation')) - argument_parser.add_argument( '--val_cohort_size', type=int, @@ -347,6 +338,16 @@ def __call__(self, parser, namespace, values, option_string=None): 'sampled to participate again. This parameter is used in ' '`BandedMatrixFactorizationMechanism`.') + argument_parser.add_argument( + '--noise_cohort_size', + type=int, + default=1000, + help=('The cohort size to use when calculating noise for DP. 
' + 'If you run cohort_size=100 but noise_cohort_size=1000, ' + 'then the noise will be scaled down by a factor of 0.1 ' + 'and your results will only be valid if running with ' + 'cohort_size=1000 outside simulation')) + return argument_parser From 5b488286b3bb053c8ba0094bc7ed8ec0ac1c21b5 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Thu, 30 May 2024 13:21:36 +0200 Subject: [PATCH 13/22] FLAIR HF (#72) --- benchmarks/dataset/flair/download_dataset.py | 94 -- .../dataset/flair/download_preprocess.py | 213 +++++ benchmarks/dataset/flair/prepare_dataset.py | 236 ----- benchmarks/flair/README.md | 3 +- benchmarks/poetry.lock | 806 +++++++++++++++++- benchmarks/pyproject.toml | 1 + 6 files changed, 1014 insertions(+), 339 deletions(-) delete mode 100644 benchmarks/dataset/flair/download_dataset.py create mode 100644 benchmarks/dataset/flair/download_preprocess.py delete mode 100644 benchmarks/dataset/flair/prepare_dataset.py diff --git a/benchmarks/dataset/flair/download_dataset.py b/benchmarks/dataset/flair/download_dataset.py deleted file mode 100644 index cb4a82a..0000000 --- a/benchmarks/dataset/flair/download_dataset.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
-import argparse -import functools -import logging -import multiprocessing -import os -import subprocess -import sys -from urllib.parse import urljoin -from urllib.request import urlretrieve - -logger = logging.getLogger(name=__name__) - -DATA_URL = "https://docs-assets.developer.apple.com/ml-research/datasets/flair/" -NUM_IMAGE_BATCHES = 43 -SMALL_IMAGE_URLS = [ - urljoin(DATA_URL, f"images/small/small_images-{str(i).zfill(2)}.tar.gz") - for i in range(NUM_IMAGE_BATCHES) -] -RAW_IMAGE_URLS = [ - urljoin(DATA_URL, f"images/raw/images-{str(i).zfill(2)}.tar.gz") - for i in range(NUM_IMAGE_BATCHES) -] -LABELS_AND_METADATA_URL = urljoin(DATA_URL, "labels/labels_and_metadata.json") -LABEL_RELATIONSHIP_URL = urljoin(DATA_URL, "labels/label_relationship.txt") - - -def extract_tar(compressed_path: str, dataset_dir: str, - keep_archive_after_decompress: bool): - subprocess.run(f"tar -zxf {compressed_path} -C {dataset_dir}".split(), - check=True) - if not keep_archive_after_decompress: - os.remove(compressed_path) - - -def decompress_images(dataset_dir: str, keep_archive_after_decompress: bool): - compressed_paths = [ - os.path.join(dataset_dir, path) for path in os.listdir(dataset_dir) - if path.endswith(".tar.gz") - ] - decompress = functools.partial( - extract_tar, - dataset_dir=dataset_dir, - keep_archive_after_decompress=keep_archive_after_decompress) - with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: - pool.map(decompress, compressed_paths) - - -if __name__ == '__main__': - logging.basicConfig(stream=sys.stdout, - level=logging.INFO, - format='%(asctime)s %(levelname)s: %(message)s') - - parser = argparse.ArgumentParser( - description='Download the images and labels of FLAIR dataset.') - parser.add_argument("--dataset_dir", - required=True, - help="Path to directory of dataset to be downloaded") - parser.add_argument("--download_raw", - action="store_true", - help="Whether to download the raw images, " - "which need storage space ~1.2TB") - 
parser.add_argument("--keep_archive_after_decompress", - action="store_true", - help="Whether to keep the image tarball archives") - arguments = parser.parse_args() - os.makedirs(arguments.dataset_dir, exist_ok=True) - - # download labels and metadata - logger.info("Downloading labels...") - urlretrieve( - LABELS_AND_METADATA_URL, - os.path.join(arguments.dataset_dir, - os.path.basename(LABELS_AND_METADATA_URL))) - urlretrieve( - LABEL_RELATIONSHIP_URL, - os.path.join(arguments.dataset_dir, - os.path.basename(LABEL_RELATIONSHIP_URL))) - # download and decompress all images - for image_url in SMALL_IMAGE_URLS: - logger.info(f"Downloading small image: {image_url}") - urlretrieve( - image_url, - os.path.join(arguments.dataset_dir, os.path.basename(image_url))) - if arguments.download_raw: - for image_url in RAW_IMAGE_URLS: - logger.info(f"Downloading raw image: {image_url}") - urlretrieve( - image_url, - os.path.join(arguments.dataset_dir, - os.path.basename(image_url))) - logger.info("Decompressing images...") - decompress_images(arguments.dataset_dir, - arguments.keep_archive_after_decompress) diff --git a/benchmarks/dataset/flair/download_preprocess.py b/benchmarks/dataset/flair/download_preprocess.py new file mode 100644 index 0000000..f6e833e --- /dev/null +++ b/benchmarks/dataset/flair/download_preprocess.py @@ -0,0 +1,213 @@ +# Copyright © 2024 Apple Inc. +import argparse +import json +import logging +import sys +from collections import Counter, defaultdict +from typing import Counter as TCounter + +import h5py +import numpy as np +import tqdm +from datasets import load_dataset + +logger = logging.getLogger(name=__name__) + +LABEL_DELIMITER = '|' # Labels will be joined by delimiter and saved to hdf5 +LOG_INTERVAL = 100 # Log the preprocessing progress every interval steps + + +def load_image_from_huggingface(dataset, image_id): + """ + Load an image from the HuggingFace dataset by image_id. + + :param dataset: + The loaded HuggingFace dataset. 
+ :param image_id: + The image_id of the image to be loaded. + :return: + The loaded image as a PIL Image object. + """ + # Assuming image_id is unique and using filter to get the specific image + image_data = dataset.filter(lambda x: x['image_id'] == image_id) + image_entry = next(iter(image_data)) # Get the first (and only) entry + return image_entry['image'] + + +def preprocess_federated_dataset(output_file: str): + """ + Process images and labels into a HDF5 federated dataset where data is + first split by train/test partitions and then split again by user ID. + + :param dataset: + The loaded HuggingFace dataset. + :param output_file: + Output path for HDF5 file. Use the postfix `.hdf5`. + """ + + # Load dataset from HuggingFace + # This is a Dict[str, Dataset] where key is the split. + dataset_splits = load_dataset('apple/flair', keep_in_memory=False) + logger.info( + f'Preprocessing federated dataset, sample record: {next(iter(dataset_splits["train"]))}' + ) + + label_counter: TCounter[str] = Counter() + fine_grained_label_counter: TCounter[str] = Counter() + + partition_to_user_to_ix: defaultdict = defaultdict( + lambda: defaultdict(list)) + for partition, ds in dataset_splits.items(): + for i, entry in tqdm.tqdm( + enumerate(ds), + total=len(ds), + desc=f'{partition} - Mapping users to datapoints.'): + # Make user to datapoints mapping. + partition_to_user_to_ix[partition][entry['user_id']].append(i) + label_counter.update(entry["labels"]) + fine_grained_label_counter.update(entry["fine_grained_labels"]) + + label_to_index = { + label: index + for index, label in enumerate(sorted(label_counter.keys())) + } + fine_grained_label_to_index = { + fine_grained_label: index + for index, fine_grained_label in enumerate( + sorted(fine_grained_label_counter.keys())) + } + + with h5py.File(output_file, 'w') as h5file: + for partition, user_to_ix in partition_to_user_to_ix.items(): + ds = dataset_splits[partition] + + # Iterate through users of each partition. 
+ for i, (user_id, data_indices) in tqdm.tqdm( + enumerate(user_to_ix.items()), + total=len(user_to_ix), + desc=f'{partition} - constructing dataset'): + # Load and concatenate all images and labels of a user. + image_array, image_id_array = [], [] + labels_row, labels_col = [], [] + fine_grained_labels_row, fine_grained_labels_col = [], [] + for j, data_index in enumerate(data_indices): + metadata = ds[data_index] + image_id = metadata["image_id"] + image_array.append(np.asarray(metadata["image"])) + image_id_array.append(image_id) + # Encode labels as row indices and column indices + labels_row.extend([j] * len(metadata["labels"])) + labels_col.extend( + [label_to_index[l] for l in metadata["labels"]]) + fine_grained_labels_row.extend( + [j] * len(metadata["fine_grained_labels"])) + fine_grained_labels_col.extend([ + fine_grained_label_to_index[l] + for l in metadata["fine_grained_labels"] + ]) + # Update label counter + label_counter.update(metadata["labels"]) + fine_grained_label_counter.update( + metadata["fine_grained_labels"]) + + # Multiple variable-length labels. Needs to be stored as a string. 
+ h5file[f'/{partition}/{user_id}/labels_row'] = np.asarray( + labels_row, dtype=np.uint16) + h5file[f'/{partition}/{user_id}/labels_col'] = np.asarray( + labels_col, dtype=np.uint8) + h5file[ + f'/{partition}/{user_id}/fine_grained_labels_row'] = np.asarray( + fine_grained_labels_row, dtype=np.uint16) + h5file[ + f'/{partition}/{user_id}/fine_grained_labels_col'] = np.asarray( + fine_grained_labels_col, dtype=np.uint16) + h5file[f'/{partition}/{user_id}/image_ids'] = np.asarray( + image_id_array, dtype='S') + # Tensor with dimensions [num_images,width,height,channels] + h5file.create_dataset(f'/{partition}/{user_id}/images', + data=np.stack(image_array)) + + if (i + 1) % LOG_INTERVAL == 0: + logger.info(f"Processed {i + 1}/{len(user_to_ix)} users") + + # Write metadata + h5file['/metadata/label_mapping'] = json.dumps(label_to_index) + h5file['/metadata/fine_grained_label_mapping'] = json.dumps( + fine_grained_label_to_index) + + logger.info('Finished preprocess federated dataset successfully!') + + +def preprocess_central_dataset(output_file: str): + """ + Process images and labels into a HDF5 (not federated) dataset where + data is split by train/val/test partitions. + + :param dataset: + The loaded HuggingFace dataset. + :param output_file: + Output path for HDF5 file. Use the postfix `.hdf5`. + """ + logger.info('Preprocessing central dataset.') + dataset_splits = load_dataset('apple/flair', keep_in_memory=False) + + label_counter: TCounter[str] = Counter() + fine_grained_label_counter: TCounter[str] = Counter() + with h5py.File(output_file, 'w') as h5file: + # Iterate through dataset. 
+ for partition, dataset in dataset_splits.items(): + for i, entry in tqdm.tqdm(enumerate(dataset), total=len(dataset)): + image_id = entry["image_id"] + image = np.asarray( + entry["image"] + ) # Directly use the image array from the dataset + h5file.create_dataset(f'/{partition}/{image_id}/image', + data=image) + # Encode labels as a single string, separated by delimiter | + h5file[ + f'/{partition}/{image_id}/labels'] = LABEL_DELIMITER.join( + entry["labels"]) + h5file[f'/{partition}/{image_id}/fine_grained_labels'] = ( + LABEL_DELIMITER.join(entry["fine_grained_labels"])) + h5file[f'/{partition}/{image_id}/user_id'] = entry["user_id"] + # Update label counter + label_counter.update(entry["labels"]) + fine_grained_label_counter.update(entry["fine_grained_labels"]) + + if (i + 1) % LOG_INTERVAL == 0: + logger.info(f"Processed {i + 1}/{len(dataset)} entries") + + # Write metadata + h5file['/metadata/label_mapping'] = json.dumps(dict(label_counter)) + h5file['/metadata/fine_grained_label_mapping'] = json.dumps( + dict(fine_grained_label_counter)) + + logger.info('Finished preprocessing central dataset successfully!') + + +if __name__ == '__main__': + logging.basicConfig(stream=sys.stdout, + level=logging.INFO, + format='%(asctime)s %(levelname)s: %(message)s') + + argument_parser = argparse.ArgumentParser( + description= + 'Download and preprocess the images and labels of FLAIR dataset into HDF5 files.' + ) + argument_parser.add_argument( + '--output_file', + required=True, + help='Path to output HDF5 file that will be constructed by this script' + ) + argument_parser.add_argument('--not_group_data_by_user', + action='store_true', + default=False, + help='If true, do not group data by user IDs.' 
+ 'If false, group data by user IDs to ' + 'make suitable for federated learning.') + arguments = argument_parser.parse_args() + + if arguments.not_group_data_by_user: + preprocess_central_dataset(arguments.output_file) + else: + preprocess_federated_dataset(arguments.output_file) diff --git a/benchmarks/dataset/flair/prepare_dataset.py b/benchmarks/dataset/flair/prepare_dataset.py deleted file mode 100644 index 5a9ba50..0000000 --- a/benchmarks/dataset/flair/prepare_dataset.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. -import argparse -import json -import logging -import os -import sys -from collections import Counter, defaultdict -from typing import Dict, Tuple - -import h5py -import numpy as np -import tqdm -from PIL import Image - -logger = logging.getLogger(name=__name__) - -LABEL_DELIMITER = '|' # Labels will be joined by delimiter and saved to hdf5 -LOG_INTERVAL = 100 # Log the preprocessing progress every interval steps - - -def load_user_metadata_and_label_counters( - labels_file: str) -> Tuple[Dict, Counter, Counter]: - """ - Load labels and metadata keyed by `user_id`, and label counts. - - :param labels_file: - A .json file with a list of labels and metadata dictionaries. Each - dictionary has keys: `[image_id,user_id,labels,fine_grained_labels]`. - * `image_id` is the ID of an image. - * `user_id` is the ID of the user `image_id` belongs to. - * `labels` is a list of 17 higher-order class labels. - * `fine_grained_labels` is a list of 1,628 fine-grained class labels. - :return: - Three dictionaries. First dictionary has key being `user_id` and value - being a list of labels and metadata for each image `user_id` owns. - Second and third dictionaries are counts for the labels for coarse-grained - and fine-grained taxonomies. 
- """ - user_metadata = defaultdict(list) - with open(labels_file) as f: - metadata_list = json.load(f) - - label_counter: Counter = Counter() - fine_grained_label_counter: Counter = Counter() - for metadata in metadata_list: - user_metadata[metadata["user_id"]].append(metadata) - label_counter.update(metadata["labels"]) - fine_grained_label_counter.update(metadata["fine_grained_labels"]) - return user_metadata, label_counter, fine_grained_label_counter - - -def preprocess_federated_dataset(image_dir: str, labels_file: str, - output_file: str): - """ - Process images and labels into a HDF5 federated dataset where data is - first split by train/test partitions and then split again by user ID. - - :param image_dir: - Path to directory of images output from the script - `download_dataset.sh`. - :param labels_file: - A .json file with a list of labels and metadata dictionaries. Each - dictionary has keys: `[image_id,user_id,labels,fine_grained_labels]`. - * `image_id` is the ID of an image. - * `user_id` is the ID of the user `image_id` belongs to. - * `labels` is a list of 17 higher-order class labels. - * `fine_grained_labels` is a list of ~1,600 fine-grained class labels. - :param output_file: - Output path for HDF5 file. Use the postfix `.hdf5`. - """ - logger.info('Preprocessing federated dataset.') - (user_metadata, label_counter, fine_grained_label_counter - ) = load_user_metadata_and_label_counters(labels_file) - - label_to_index = { - label: index - for index, label in enumerate(sorted(label_counter.keys())) - } - fine_grained_label_to_index = { - fine_grained_label: index - for index, fine_grained_label in enumerate( - sorted(fine_grained_label_counter.keys())) - } - - label_counter = Counter() - fine_grained_label_counter = Counter() - with h5py.File(output_file, 'w') as h5file: - # Iterate through users of each partition. 
- for i, user_id in tqdm.tqdm(enumerate(user_metadata), - total=len(user_metadata)): - # This snippet was used to generate flair_federated_small.h5 - #if i > len(user_metadata)*0.01: - # break - #if len(user_metadata[user_id]) > 20: - # # Skip large users - # continue - # This snippet was used to generate flair_federated_ci.h5 - #if i > 12: - # break - #if i > 10: - # user_metadata[user_id][0]['partition'] = 'test' - - # Load and concatenate all images of a user. - image_array, image_id_array = [], [] - labels_row, labels_col = [], [] - fine_grained_labels_row, fine_grained_labels_col = [], [] - # Load and concatenate all images and labels of a user. - for j, metadata in enumerate(user_metadata[user_id]): - image_id = metadata["image_id"] - image = Image.open(os.path.join(image_dir, f"{image_id}.jpg")) - image_array.append(np.asarray(image)) - image_id_array.append(image_id) - # Encode labels as row indices and column indices - labels_row.extend([j] * len(metadata["labels"])) - labels_col.extend( - [label_to_index[l] for l in metadata["labels"]]) - fine_grained_labels_row.extend( - [j] * len(metadata["fine_grained_labels"])) - fine_grained_labels_col.extend([ - fine_grained_label_to_index[l] - for l in metadata["fine_grained_labels"] - ]) - # Update label counter - label_counter.update(metadata["labels"]) - fine_grained_label_counter.update( - metadata["fine_grained_labels"]) - - partition = user_metadata[user_id][0]["partition"] - # Multiple variable-length labels. Needs to be stored as a string. 
- h5file[f'/{partition}/{user_id}/labels_row'] = np.asarray( - labels_row, dtype=np.uint16) - h5file[f'/{partition}/{user_id}/labels_col'] = np.asarray( - labels_col, dtype=np.uint8) - h5file[ - f'/{partition}/{user_id}/fine_grained_labels_row'] = np.asarray( - fine_grained_labels_row, dtype=np.uint16) - h5file[ - f'/{partition}/{user_id}/fine_grained_labels_col'] = np.asarray( - fine_grained_labels_col, dtype=np.uint16) - h5file[f'/{partition}/{user_id}/image_ids'] = np.asarray( - image_id_array, dtype='S') - # Tensor with dimensions [num_images,width,height,channels] - h5file.create_dataset(f'/{partition}/{user_id}/images', - data=np.stack(image_array)) - - if (i + 1) % LOG_INTERVAL == 0: - logger.info(f"Processed {i + 1}/{len(user_metadata)} users") - - # Write metadata - h5file['/metadata/label_mapping'] = json.dumps(label_to_index) - h5file['/metadata/fine_grained_label_mapping'] = json.dumps( - fine_grained_label_to_index) - - logger.info('Finished preprocess federated dataset successfully!') - - -def preprocess_central_dataset(image_dir: str, labels_file: str, - output_file: str): - """ - Process images and labels into a HDF5 (not federated) dataset where - data is split by train/val/test partitions. - - Same parameters as `preprocess_federated_dataset`. - """ - logger.info('Preprocessing central dataset.') - (user_metadata, _, _) = load_user_metadata_and_label_counters(labels_file) - label_counter: Counter = Counter() - fine_grained_label_counter: Counter = Counter() - with h5py.File(output_file, 'w') as h5file: - # Iterate through users of each partition. - for i, user_id in enumerate(user_metadata): - # Load and concatenate all images of a user. 
- for metadata in user_metadata[user_id]: - image_id = metadata["image_id"] - image = Image.open(os.path.join(image_dir, f"{image_id}.jpg")) - partition = metadata["partition"] - h5file.create_dataset(f'/{partition}/{image_id}/image', - data=np.asarray(image)) - # Encode labels as a single string, separated by delimiter | - h5file[ - f'/{partition}/{image_id}/labels'] = LABEL_DELIMITER.join( - metadata["labels"]) - h5file[f'/{partition}/{image_id}/fine_grained_labels'] = ( - LABEL_DELIMITER.join(metadata["fine_grained_labels"])) - h5file[f'/{partition}/{image_id}/user_id'] = user_id - # Update label counter - label_counter.update(metadata["labels"]) - fine_grained_label_counter.update( - metadata["fine_grained_labels"]) - - if (i + 1) % LOG_INTERVAL == 0: - logger.info(f"Processed {i + 1}/{len(user_metadata)} users") - - # Write metadata - h5file['/metadata/label_mapping'] = json.dumps(label_counter) - h5file['/metadata/fine_grained_label_mapping'] = json.dumps( - fine_grained_label_counter) - - logger.info('Finished preprocessing central dataset successfully!') - - -if __name__ == '__main__': - logging.basicConfig(stream=sys.stdout, - level=logging.INFO, - format='%(asctime)s %(levelname)s: %(message)s') - - argument_parser = argparse.ArgumentParser( - description= - 'Preprocess the images and labels of FLAIR dataset into HDF5 files.') - argument_parser.add_argument( - '--dataset_dir', - required=True, - help='Path to directory of images and label file. ' - 'Can be downloaded using download_dataset.py') - argument_parser.add_argument( - '--output_file', - required=True, - help='Path to output HDF5 file that will be constructed by this script' - ) - argument_parser.add_argument('--not_group_data_by_user', - action='store_true', - default=False, - help='If true, do not group data by user IDs.' 
- 'If false, group data by user IDs to ' - 'make suitable for federated learning.') - arguments = argument_parser.parse_args() - - image_dir = os.path.join(arguments.dataset_dir, "small_images") - labels_file = os.path.join(arguments.dataset_dir, - "labels_and_metadata.json") - if arguments.not_group_data_by_user: - preprocess_central_dataset(image_dir, labels_file, - arguments.output_file) - else: - preprocess_federated_dataset(image_dir, labels_file, - arguments.output_file) diff --git a/benchmarks/flair/README.md b/benchmarks/flair/README.md index 49e177f..89e9fcf 100644 --- a/benchmarks/flair/README.md +++ b/benchmarks/flair/README.md @@ -13,8 +13,7 @@ Same as the [default setup](../README.md). ## Download and preprocess FLAIR dataset ``` -python -m dataset.flair.download_dataset --dataset_dir ./data/flair -python -m dataset.flair.prepare_dataset --dataset_dir data/flair/ --output_file data/flair/flair_federated.hdf5 +python -m dataset.flair.download_preprocess --output_file data/flair/flair_federated.hdf5 ``` ## Run benchmarks diff --git a/benchmarks/poetry.lock b/benchmarks/poetry.lock index fb31b53..83b6ce2 100644 --- a/benchmarks/poetry.lock +++ b/benchmarks/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -11,6 +11,116 @@ files = [ {file = "absl_py-1.4.0-py3-none-any.whl", hash = "sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47"}, ] +[[package]] +name = "aiohttp" +version = "3.9.5" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"}, + {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"}, + {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"}, + {file = 
"aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"}, + {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"}, + {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"}, + {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"}, + {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"}, + {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"}, + {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:694d828b5c41255e54bc2dddb51a9f5150b4eefa9886e38b52605a05d96566e8"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0605cc2c0088fcaae79f01c913a38611ad09ba68ff482402d3410bf59039bfb8"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4558e5012ee03d2638c681e156461d37b7a113fe13970d438d95d10173d25f78"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:9dbc053ac75ccc63dc3a3cc547b98c7258ec35a215a92bd9f983e0aac95d3d5b"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4109adee842b90671f1b689901b948f347325045c15f46b39797ae1bf17019de"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6ea1a5b409a85477fd8e5ee6ad8f0e40bf2844c270955e09360418cfd09abac"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3c2890ca8c59ee683fd09adf32321a40fe1cf164e3387799efb2acebf090c11"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3916c8692dbd9d55c523374a3b8213e628424d19116ac4308e434dbf6d95bbdd"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8d1964eb7617907c792ca00b341b5ec3e01ae8c280825deadbbd678447b127e1"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d5ab8e1f6bee051a4bf6195e38a5c13e5e161cb7bad83d8854524798bd9fcd6e"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:52c27110f3862a1afbcb2af4281fc9fdc40327fa286c4625dfee247c3ba90156"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:7f64cbd44443e80094309875d4f9c71d0401e966d191c3d469cde4642bc2e031"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8b4f72fbb66279624bfe83fd5eb6aea0022dad8eec62b71e7bf63ee1caadeafe"}, + {file = "aiohttp-3.9.5-cp38-cp38-win32.whl", hash = "sha256:6380c039ec52866c06d69b5c7aad5478b24ed11696f0e72f6b807cfb261453da"}, + {file = "aiohttp-3.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:da22dab31d7180f8c3ac7c7635f3bcd53808f374f6aa333fe0b0b9e14b01f91a"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"}, + {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"}, + {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = 
"sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"}, + {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"}, +] + +[package.dependencies] +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns", "brotlicffi"] + +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + [[package]] name = "astunparse" version = "1.6.3" @@ -26,11 +136,22 @@ files = [ six = ">=1.6.1,<2.0" wheel = ">=0.23.0,<1.0" +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + [[package]] name = "attrs" version = "23.2.0" description = "Classes Without Boilerplate" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, @@ -98,7 +219,7 @@ files = [ name = "certifi" version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." 
-optional = true +optional = false python-versions = ">=3.6" files = [ {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, @@ -120,7 +241,7 @@ files = [ name = "charset-normalizer" version = "3.3.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -optional = true +optional = false python-versions = ">=3.7.0" files = [ {file = "charset-normalizer-3.3.0.tar.gz", hash = "sha256:63563193aec44bce707e0c5ca64ff69fa72ed7cf34ce6e11d5127555756fd2f6"}, @@ -266,6 +387,50 @@ files = [ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] +[[package]] +name = "datasets" +version = "2.19.1" +description = "HuggingFace community-driven open-source library of datasets" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "datasets-2.19.1-py3-none-any.whl", hash = "sha256:f7a78d15896f45004ccac1c298f3c7121f92f91f6f2bfbd4e4f210f827e6e411"}, + {file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"}, +] + +[package.dependencies] +aiohttp = "*" +dill = ">=0.3.0,<0.3.9" +filelock = "*" +fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]} +huggingface-hub = ">=0.21.2" +multiprocess = "*" +numpy = ">=1.17" +packaging = "*" +pandas = "*" +pyarrow = ">=12.0.0" +pyarrow-hotfix = "*" +pyyaml = ">=5.1" +requests = ">=2.19.0" +tqdm = ">=4.62.1" +xxhash = "*" + +[package.extras] +apache-beam = ["apache-beam (>=2.26.0)"] +audio = ["librosa", "soundfile (>=0.12.1)"] +benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", 
"py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"] +jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] +metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] +quality = ["ruff (>=0.3.0)"] +s3 = ["s3fs"] +tensorflow = ["tensorflow (>=2.6.0)"] +tensorflow-gpu = ["tensorflow (>=2.6.0)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +torch = ["torch"] +vision = ["Pillow (>=6.2.1)"] + [[package]] name = "decorator" version = "5.1.1" @@ -435,6 +600,130 @@ files = [ {file = "flatbuffers-23.5.26.tar.gz", hash = "sha256:9ea1144cac05ce5d86e2859f431c6cd5e66cd9c78c558317c7955fb8d4c78d89"}, ] +[[package]] +name = "frozenlist" +version = "1.4.1" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = 
"sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = 
"sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = 
"frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = 
"frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + +[[package]] +name = "fsspec" +version = "2024.3.1" +description = "File-system specification" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, +] + +[package.dependencies] +aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + [[package]] name = "gast" version = "0.5.4" @@ -605,6 +894,40 @@ files = [ [package.dependencies] numpy = ">=1.17.3" +[[package]] +name = "huggingface-hub" +version = "0.23.1" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = 
"huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"}, + {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +hf-transfer = ["hf-transfer (>=0.1.4)"] +inference = ["aiohttp", "minijinja (>=1.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors", "torch"] +typing = ["types-PyYAML", "types-requests", 
"types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] + [[package]] name = "identify" version = "2.5.30" @@ -623,7 +946,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -optional = true +optional = false python-versions = ">=3.5" files = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, @@ -889,6 +1212,105 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] +[[package]] +name = "multidict" +version = "6.0.5" +description = "multidict implementation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = 
"multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = 
"sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = 
"multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = 
"sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, +] + [[package]] name = "multiprocess" version = "0.70.15" @@ -1087,6 +1509,75 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] +[[package]] +name = "pandas" +version = "2.2.2" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = 
"sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, + {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, + {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, + {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, +] + +[package.dependencies] +numpy = {version = ">=1.22.4", markers = "python_version < \"3.11\""} +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.7" + +[package.extras] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", 
"bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] + [[package]] name = 
"pfl" version = "0.1.0" @@ -1294,6 +1785,65 @@ scipy = "*" [package.extras] extra = ["flake8", "jupyter", "nbconvert", "pandas", "plotly", "pytest", "sympy", "tensorflow-privacy", "tqdm"] +[[package]] +name = "pyarrow" +version = "16.1.0" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, + {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, + {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, + {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, + {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, + {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, + {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, + {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + +[[package]] +name = "pyarrow-hotfix" +version = "0.6" +description = "" +optional = false 
+python-versions = ">=3.5" +files = [ + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, +] + [[package]] name = "pyasn1" version = "0.5.0" @@ -1389,6 +1939,17 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pytz" +version = "2024.1" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, + {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, +] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1452,7 +2013,7 @@ files = [ name = "requests" version = "2.31.0" description = "Python HTTP for Humans." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, @@ -2163,6 +2724,17 @@ files = [ {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, ] +[[package]] +name = "tzdata" +version = "2024.1" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, +] + [[package]] name = "urllib3" version = "2.0.6" @@ -2314,6 +2886,123 @@ files = [ {file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"}, ] +[[package]] +name = "xxhash" +version = "3.4.1" 
+description = "Python binding for xxHash" +optional = false +python-versions = ">=3.7" +files = [ + {file = "xxhash-3.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:91dbfa55346ad3e18e738742236554531a621042e419b70ad8f3c1d9c7a16e7f"}, + {file = "xxhash-3.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:665a65c2a48a72068fcc4d21721510df5f51f1142541c890491afc80451636d2"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb11628470a6004dc71a09fe90c2f459ff03d611376c1debeec2d648f44cb693"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5bef2a7dc7b4f4beb45a1edbba9b9194c60a43a89598a87f1a0226d183764189"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0f7b2d547d72c7eda7aa817acf8791f0146b12b9eba1d4432c531fb0352228"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00f2fdef6b41c9db3d2fc0e7f94cb3db86693e5c45d6de09625caad9a469635b"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23cfd9ca09acaf07a43e5a695143d9a21bf00f5b49b15c07d5388cadf1f9ce11"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6a9ff50a3cf88355ca4731682c168049af1ca222d1d2925ef7119c1a78e95b3b"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f1d7c69a1e9ca5faa75546fdd267f214f63f52f12692f9b3a2f6467c9e67d5e7"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:672b273040d5d5a6864a36287f3514efcd1d4b1b6a7480f294c4b1d1ee1b8de0"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4178f78d70e88f1c4a89ff1ffe9f43147185930bb962ee3979dba15f2b1cc799"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9804b9eb254d4b8cc83ab5a2002128f7d631dd427aa873c8727dba7f1f0d1c2b"}, + {file = 
"xxhash-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c09c49473212d9c87261d22c74370457cfff5db2ddfc7fd1e35c80c31a8c14ce"}, + {file = "xxhash-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:ebbb1616435b4a194ce3466d7247df23499475c7ed4eb2681a1fa42ff766aff6"}, + {file = "xxhash-3.4.1-cp310-cp310-win_arm64.whl", hash = "sha256:25dc66be3db54f8a2d136f695b00cfe88018e59ccff0f3b8f545869f376a8a46"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58c49083801885273e262c0f5bbeac23e520564b8357fbb18fb94ff09d3d3ea5"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b526015a973bfbe81e804a586b703f163861da36d186627e27524f5427b0d520"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36ad4457644c91a966f6fe137d7467636bdc51a6ce10a1d04f365c70d6a16d7e"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:248d3e83d119770f96003271fe41e049dd4ae52da2feb8f832b7a20e791d2920"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2070b6d5bbef5ee031666cf21d4953c16e92c2f8a24a94b5c240f8995ba3b1d0"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2746035f518f0410915e247877f7df43ef3372bf36cfa52cc4bc33e85242641"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ba6181514681c2591840d5632fcf7356ab287d4aff1c8dea20f3c78097088"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aac5010869240e95f740de43cd6a05eae180c59edd182ad93bf12ee289484fa"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4cb11d8debab1626181633d184b2372aaa09825bde709bf927704ed72765bed1"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b29728cff2c12f3d9f1d940528ee83918d803c0567866e062683f300d1d2eff3"}, + {file = 
"xxhash-3.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a15cbf3a9c40672523bdb6ea97ff74b443406ba0ab9bca10ceccd9546414bd84"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e66df260fed01ed8ea790c2913271641c58481e807790d9fca8bfd5a3c13844"}, + {file = "xxhash-3.4.1-cp311-cp311-win32.whl", hash = "sha256:e867f68a8f381ea12858e6d67378c05359d3a53a888913b5f7d35fbf68939d5f"}, + {file = "xxhash-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:200a5a3ad9c7c0c02ed1484a1d838b63edcf92ff538770ea07456a3732c577f4"}, + {file = "xxhash-3.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:1d03f1c0d16d24ea032e99f61c552cb2b77d502e545187338bea461fde253583"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c4bbba9b182697a52bc0c9f8ec0ba1acb914b4937cd4a877ad78a3b3eeabefb3"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9fd28a9da300e64e434cfc96567a8387d9a96e824a9be1452a1e7248b7763b78"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6066d88c9329ab230e18998daec53d819daeee99d003955c8db6fc4971b45ca3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93805bc3233ad89abf51772f2ed3355097a5dc74e6080de19706fc447da99cd3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64da57d5ed586ebb2ecdde1e997fa37c27fe32fe61a656b77fabbc58e6fbff6e"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a97322e9a7440bf3c9805cbaac090358b43f650516486746f7fa482672593df"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe750d512982ee7d831838a5dee9e9848f3fb440e4734cca3f298228cc957a6"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fd79d4087727daf4d5b8afe594b37d611ab95dc8e29fe1a7517320794837eb7d"}, + {file = 
"xxhash-3.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:743612da4071ff9aa4d055f3f111ae5247342931dedb955268954ef7201a71ff"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b41edaf05734092f24f48c0958b3c6cbaaa5b7e024880692078c6b1f8247e2fc"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:a90356ead70d715fe64c30cd0969072de1860e56b78adf7c69d954b43e29d9fa"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac56eebb364e44c85e1d9e9cc5f6031d78a34f0092fea7fc80478139369a8b4a"}, + {file = "xxhash-3.4.1-cp312-cp312-win32.whl", hash = "sha256:911035345932a153c427107397c1518f8ce456f93c618dd1c5b54ebb22e73747"}, + {file = "xxhash-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:f31ce76489f8601cc7b8713201ce94b4bd7b7ce90ba3353dccce7e9e1fee71fa"}, + {file = "xxhash-3.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:b5beb1c6a72fdc7584102f42c4d9df232ee018ddf806e8c90906547dfb43b2da"}, + {file = "xxhash-3.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6d42b24d1496deb05dee5a24ed510b16de1d6c866c626c2beb11aebf3be278b9"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b685fab18876b14a8f94813fa2ca80cfb5ab6a85d31d5539b7cd749ce9e3624"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419ffe34c17ae2df019a4685e8d3934d46b2e0bbe46221ab40b7e04ed9f11137"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e041ce5714f95251a88670c114b748bca3bf80cc72400e9f23e6d0d59cf2681"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc860d887c5cb2f524899fb8338e1bb3d5789f75fac179101920d9afddef284b"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:312eba88ffe0a05e332e3a6f9788b73883752be63f8588a6dc1261a3eaaaf2b2"}, + {file = 
"xxhash-3.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e01226b6b6a1ffe4e6bd6d08cfcb3ca708b16f02eb06dd44f3c6e53285f03e4f"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9f3025a0d5d8cf406a9313cd0d5789c77433ba2004b1c75439b67678e5136537"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:6d3472fd4afef2a567d5f14411d94060099901cd8ce9788b22b8c6f13c606a93"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:43984c0a92f06cac434ad181f329a1445017c33807b7ae4f033878d860a4b0f2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a55e0506fdb09640a82ec4f44171273eeabf6f371a4ec605633adb2837b5d9d5"}, + {file = "xxhash-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:faec30437919555b039a8bdbaba49c013043e8f76c999670aef146d33e05b3a0"}, + {file = "xxhash-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:c9e1b646af61f1fc7083bb7b40536be944f1ac67ef5e360bca2d73430186971a"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:961d948b7b1c1b6c08484bbce3d489cdf153e4122c3dfb07c2039621243d8795"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:719a378930504ab159f7b8e20fa2aa1896cde050011af838af7e7e3518dd82de"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74fb5cb9406ccd7c4dd917f16630d2e5e8cbbb02fc2fca4e559b2a47a64f4940"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dab508ac39e0ab988039bc7f962c6ad021acd81fd29145962b068df4148c476"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c59f3e46e7daf4c589e8e853d700ef6607afa037bfad32c390175da28127e8c"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc07256eff0795e0f642df74ad096f8c5d23fe66bc138b83970b50fc7f7f6c5"}, + {file = 
"xxhash-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9f749999ed80f3955a4af0eb18bb43993f04939350b07b8dd2f44edc98ffee9"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7688d7c02149a90a3d46d55b341ab7ad1b4a3f767be2357e211b4e893efbaaf6"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a8b4977963926f60b0d4f830941c864bed16aa151206c01ad5c531636da5708e"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:8106d88da330f6535a58a8195aa463ef5281a9aa23b04af1848ff715c4398fb4"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4c76a77dbd169450b61c06fd2d5d436189fc8ab7c1571d39265d4822da16df22"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:11f11357c86d83e53719c592021fd524efa9cf024dc7cb1dfb57bbbd0d8713f2"}, + {file = "xxhash-3.4.1-cp38-cp38-win32.whl", hash = "sha256:0c786a6cd74e8765c6809892a0d45886e7c3dc54de4985b4a5eb8b630f3b8e3b"}, + {file = "xxhash-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:aabf37fb8fa27430d50507deeab2ee7b1bcce89910dd10657c38e71fee835594"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6127813abc1477f3a83529b6bbcfeddc23162cece76fa69aee8f6a8a97720562"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef2e194262f5db16075caea7b3f7f49392242c688412f386d3c7b07c7733a70a"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71be94265b6c6590f0018bbf73759d21a41c6bda20409782d8117e76cd0dfa8b"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10e0a619cdd1c0980e25eb04e30fe96cf8f4324758fa497080af9c21a6de573f"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa122124d2e3bd36581dd78c0efa5f429f5220313479fb1072858188bc2d5ff1"}, + {file = 
"xxhash-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17032f5a4fea0a074717fe33477cb5ee723a5f428de7563e75af64bfc1b1e10"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca7783b20e3e4f3f52f093538895863f21d18598f9a48211ad757680c3bd006f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d77d09a1113899fad5f354a1eb4f0a9afcf58cefff51082c8ad643ff890e30cf"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:21287bcdd299fdc3328cc0fbbdeaa46838a1c05391264e51ddb38a3f5b09611f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:dfd7a6cc483e20b4ad90224aeb589e64ec0f31e5610ab9957ff4314270b2bf31"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:543c7fcbc02bbb4840ea9915134e14dc3dc15cbd5a30873a7a5bf66039db97ec"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fe0a98d990e433013f41827b62be9ab43e3cf18e08b1483fcc343bda0d691182"}, + {file = "xxhash-3.4.1-cp39-cp39-win32.whl", hash = "sha256:b9097af00ebf429cc7c0e7d2fdf28384e4e2e91008130ccda8d5ae653db71e54"}, + {file = "xxhash-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:d699b921af0dcde50ab18be76c0d832f803034d80470703700cb7df0fbec2832"}, + {file = "xxhash-3.4.1-cp39-cp39-win_arm64.whl", hash = "sha256:2be491723405e15cc099ade1280133ccfbf6322d2ef568494fb7d07d280e7eee"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:431625fad7ab5649368c4849d2b49a83dc711b1f20e1f7f04955aab86cd307bc"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc6dbd5fc3c9886a9e041848508b7fb65fd82f94cc793253990f81617b61fe49"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ff8dbd0ec97aec842476cb8ccc3e17dd288cd6ce3c8ef38bff83d6eb927817"}, + {file = 
"xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef73a53fe90558a4096e3256752268a8bdc0322f4692ed928b6cd7ce06ad4fe3"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:450401f42bbd274b519d3d8dcf3c57166913381a3d2664d6609004685039f9d3"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a162840cf4de8a7cd8720ff3b4417fbc10001eefdd2d21541a8226bb5556e3bb"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b736a2a2728ba45017cb67785e03125a79d246462dfa892d023b827007412c52"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0ae4c2e7698adef58710d6e7a32ff518b66b98854b1c68e70eee504ad061d8"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6322c4291c3ff174dcd104fae41500e75dad12be6f3085d119c2c8a80956c51"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:dd59ed668801c3fae282f8f4edadf6dc7784db6d18139b584b6d9677ddde1b6b"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92693c487e39523a80474b0394645b393f0ae781d8db3474ccdcead0559ccf45"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4603a0f642a1e8d7f3ba5c4c25509aca6a9c1cc16f85091004a7028607ead663"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa45e8cbfbadb40a920fe9ca40c34b393e0b067082d94006f7f64e70c7490a6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:595b252943b3552de491ff51e5bb79660f84f033977f88f6ca1605846637b7c6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-win_amd64.whl", hash = 
"sha256:562d8b8f783c6af969806aaacf95b6c7b776929ae26c0cd941d54644ea7ef51e"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:41ddeae47cf2828335d8d991f2d2b03b0bdc89289dc64349d712ff8ce59d0647"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c44d584afdf3c4dbb3277e32321d1a7b01d6071c1992524b6543025fb8f4206f"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd7bddb3a5b86213cc3f2c61500c16945a1b80ecd572f3078ddbbe68f9dabdfb"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9ecb6c987b62437c2f99c01e97caf8d25660bf541fe79a481d05732e5236719c"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:696b4e18b7023527d5c50ed0626ac0520edac45a50ec7cf3fc265cd08b1f4c03"}, + {file = "xxhash-3.4.1.tar.gz", hash = "sha256:0379d6cf1ff987cd421609a264ce025e74f346e3e145dd106c0cc2e3ec3f99a9"}, +] + [[package]] name = "yapf" version = "0.40.2" @@ -2330,6 +3019,109 @@ importlib-metadata = ">=6.6.0" platformdirs = ">=3.5.1" tomli = ">=2.0.1" +[[package]] +name = "yarl" +version = "1.9.4" +description = "Yet another URL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = 
"yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = 
"yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = 
"yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = 
"sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + [[package]] name = "zipp" version = "3.17.0" @@ -2352,4 +3144,4 @@ tf = ["pfl", "pfl", "tensorflow", "tensorflow_addons"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.11" -content-hash = "46ec49debe16e94d0c654c385c8a0386938fd4beac106ffbf6eee551671cff1a" +content-hash = "041acca8627b3f11a5d041a9e99fb06a09dc1109424881d944e33c8be6d38476" diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml index e928a86..d7cf297 100644 --- a/benchmarks/pyproject.toml +++ b/benchmarks/pyproject.toml @@ -29,6 +29,7 @@ pfl = [ { path = "../", extras = ["pytorch"], markers="extra=='pytorch'", develop = true }, ] pillow = ">=10.2.0" +datasets = "^2.19.1" [tool.poetry.extras] tf = ["pfl", "tensorflow_addons", "tensorflow"] From d73c07bcfe0c0bafb085967200107919f36ca460 Mon Sep 17 
00:00:00 2001 From: Mona Chitnis Date: Thu, 30 May 2024 19:05:14 +0200 Subject: [PATCH 14/22] option to lowercase metric keys on serialization. --- pfl/metrics.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pfl/metrics.py b/pfl/metrics.py index 12cfb3a..0664a06 100644 --- a/pfl/metrics.py +++ b/pfl/metrics.py @@ -575,9 +575,9 @@ def __or__(self, other) -> 'Metrics': return Metrics(itertools.chain(self, other)) def to_simple_dict( - self, - force_serialize_all_metrics: bool = False - ) -> Dict[str, Union[float, int]]: + self, + force_serialize_all_metrics: bool = False, + to_lowercase: bool = False) -> Dict[str, Union[float, int]]: """ Returns a python dictionary of name-value pairs of metrics and their values, e.g. {'Loss': 0.12, 'Accuracy': 0.45}. All metric names are @@ -586,6 +586,12 @@ def to_simple_dict( :param force_serialize_all_metrics: Default to False. Indicate whether or not to include metrics that are marked to be ignored on serialization. + :param to_lowercase: + Default to False. Indicate lowercasing entire metrics key name. + This uniformity is necessary for downstream analytics pipelines + such as Splunk and Delphi. + On the contrary, Proper-casing might be better for readability + in stdout in iPython notebooks and logs. """ def convert(metric_name, weighted_value): @@ -599,9 +605,10 @@ def convert(metric_name, weighted_value): metric_name = str(metric_name) - # Uppercase the first character. 
- name_uppercase = metric_name[0].upper() + metric_name[1:] - return (name_uppercase, get_overall_value(weighted_value)) + # conditionally Uppercase the first character + modified_name = metric_name.lower( + ) if to_lowercase else metric_name[0].upper() + metric_name[1:] + return (modified_name, get_overall_value(weighted_value)) return dict( convert(*value) for value in self._hash_to_keyvalue.values() From ad9cc5e0b5937621e7c460b4bcb844879a197424 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Fri, 31 May 2024 11:55:51 +0200 Subject: [PATCH 15/22] Relax Pytorch version to >2,<3 (#69) --- benchmarks/poetry.lock | 8 ++++---- benchmarks/pyproject.toml | 8 ++++---- poetry.lock | 4 ++-- pyproject.toml | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/benchmarks/poetry.lock b/benchmarks/poetry.lock index 83b6ce2..2a27500 100644 --- a/benchmarks/poetry.lock +++ b/benchmarks/poetry.lock @@ -1598,13 +1598,13 @@ tensorflow = {version = ">=2.14,<3.0", optional = true, markers = "sys_platform tensorflow-macos = {version = "^2.14", optional = true, markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} tensorflow-probability = {version = "^0.22", optional = true} torch = [ - {version = "2.0.1+cu118", optional = true, markers = "sys_platform == \"linux\""}, - {version = "2.0.1", optional = true, markers = "sys_platform == \"darwin\""}, + {version = ">=2.0.1+cu118,<3.0.0", optional = true, markers = "sys_platform == \"linux\""}, + {version = ">=2.0.1,<3.0.0", optional = true, markers = "sys_platform == \"darwin\""}, ] wheel = "^0.41.2" [package.extras] -pytorch = ["cmake (>=3.27.5,<4.0.0)", "torch (==2.0.1)", "torch (==2.0.1+cu118)"] +pytorch = ["cmake (>=3.27.5,<4.0.0)", "torch (>=2.0.1+cu118,<3.0.0)", "torch (>=2.0.1,<3.0.0)"] tf = ["cmake (>=3.27.5,<4.0.0)", "tensorflow (>=2.14,<3.0)", "tensorflow (>=2.14,<3.0)", "tensorflow-macos (>=2.14,<3.0)", "tensorflow-probability (>=0.22,<0.23)"] trees = ["scikit-learn (>=1.0.2,<2.0.0)", 
"xgboost (>=1.4.2,<2.0.0)"] @@ -3144,4 +3144,4 @@ tf = ["pfl", "pfl", "tensorflow", "tensorflow_addons"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.11" -content-hash = "041acca8627b3f11a5d041a9e99fb06a09dc1109424881d944e33c8be6d38476" +content-hash = "2c4a3e255f8fd453e4048dc5aaf861c95169cc74ad99f2b3e501f43cbdab1833" diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml index d7cf297..1ed3bc9 100644 --- a/benchmarks/pyproject.toml +++ b/benchmarks/pyproject.toml @@ -16,12 +16,12 @@ tensorflow = { version = "^2.14.0", optional = true } tensorflow_probability = { version = "^0.22", optional = true } tensorflow_addons = { version = ">=0.20.0,<1", optional = true } torch = [ - { version = "2.0.1+cu118", source = "torch_cu118", markers = "sys_platform == 'linux'", optional = true }, - { version = "2.0.1", source = "PyPI", markers = "sys_platform == 'darwin'", optional = true }, + { version = "^2.0.1+cu118", source = "torch_cu118", markers = "sys_platform == 'linux'", optional = true }, + { version = "^2.0.1", source = "PyPI", markers = "sys_platform == 'darwin'", optional = true }, ] torchvision = [ - { version = "0.15.2+cu118", source = "torch_cu118", markers = "sys_platform == 'linux'", optional = true }, - { version = "0.15.2", source = "PyPI", markers = "sys_platform == 'darwin'", optional = true }, + { version = "^0.15.2+cu118", source = "torch_cu118", markers = "sys_platform == 'linux'", optional = true }, + { version = "^0.15.2", source = "PyPI", markers = "sys_platform == 'darwin'", optional = true }, ] # Installs pfl from source. pfl = [ diff --git a/poetry.lock b/poetry.lock index bbd03f3..241f194 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -2661,4 +2661,4 @@ trees = ["scikit-learn", "xgboost"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "edf6985c2d793fdc44b121ae23867d960e3d884c089103aa2e53db8a45ddc498" +content-hash = "f9141fce1445e145c3fbe2c3f7bb9ada28750d8a10a819b3c6dd4f3b107f8013" diff --git a/pyproject.toml b/pyproject.toml index 73673fc..bb3b6dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,8 @@ dp-accounting = "^0.4" prv-accountant = "^0.2.0" #### Will be installed only with "pytorch" install extra torch = [ - { version = "2.0.1+cu118", source = "torch_cu118", markers = "sys_platform == 'linux'", optional = true }, - { version = "2.0.1", source = "PyPI", markers = "sys_platform == 'darwin'", optional = true }, + { version = "^2.0.1+cu118", source = "torch_cu118", markers = "sys_platform == 'linux'", optional = true }, + { version = "^2.0.1", source = "PyPI", markers = "sys_platform == 'darwin'", optional = true }, ] #### Will be installed only with "tf" install extra tensorflow = [ From 82bf921bd00e5a595951198f8ed63cacc7fe3c92 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Fri, 31 May 2024 15:14:49 +0200 Subject: [PATCH 16/22] return local metadata from training (#71) --- pfl/context.py | 13 +++++++++++++ pfl/model/pytorch.py | 10 +++++++--- pfl/model/tensorflow.py | 9 +++++++-- tests/model/test_pytorch_model.py | 16 +++++++++++----- tests/model/test_tensorflow_model.py | 14 ++++++++++++++ 5 files changed, 52 insertions(+), 10 deletions(-) diff --git a/pfl/context.py b/pfl/context.py index a7658bf..c14d7db 100644 --- a/pfl/context.py +++ b/pfl/context.py @@ -7,6 +7,19 @@ from pfl.metrics import Metrics +@dataclass(frozen=True) +class LocalResultMetaData: + """ + Data that is typically returned by a model's local optimization procedure, + e.g. ``PyTorchModel.do_multiple_epochs_of``. Can have useful information + needed by the algorithm. 
+ + :param num_steps: + The number of local steps taken during the local optimization procedure. + """ + num_steps: int + + @dataclass(frozen=True) class UserContext: """ diff --git a/pfl/model/pytorch.py b/pfl/model/pytorch.py index bae6ee4..19c77c9 100644 --- a/pfl/model/pytorch.py +++ b/pfl/model/pytorch.py @@ -6,6 +6,7 @@ import torch +from pfl.context import LocalResultMetaData from pfl.data.dataset import AbstractDatasetType from pfl.exception import CheckpointNotFoundError from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams @@ -314,7 +315,8 @@ def _get_local_num_steps(train_params: NNTrainHyperParams, def do_multiple_epochs_of(self, user_dataset: AbstractDatasetType, train_params: NNTrainHyperParams, - train_step_fn: Callable, **kwargs) -> None: + train_step_fn: Callable, + **kwargs) -> LocalResultMetaData: """ Perform multiple epochs of training. The customizable training function that will use a batch of data to update the local @@ -350,13 +352,14 @@ def do_multiple_epochs_of(self, user_dataset: AbstractDatasetType, local_optimizer.zero_grad() # Common arguments used in `train_step_fn` + local_num_steps = self._get_local_num_steps(train_params, + len(user_dataset)) train_step_args = pytorch_ops.PyTorchTrainStepArgs( amp_context=self._amp_context or contextlib.nullcontext(), grad_scaler=self._grad_scaler, max_grad_norm=train_params.local_max_grad_norm, grad_accumulation_state=pytorch_ops.GradAccumulationState( - self._get_local_num_steps(train_params, len(user_dataset)), - train_params.grad_accumulation_steps)) + local_num_steps, train_params.grad_accumulation_steps)) for _ in range(num_epochs): for batch_ix, batch in enumerate( user_dataset.iter(train_params.get('local_batch_size'))): @@ -366,6 +369,7 @@ def do_multiple_epochs_of(self, user_dataset: AbstractDatasetType, train_step_fn(self._model, local_optimizer, batch, user_dataset.train_kwargs, train_step_args, **kwargs) + return LocalResultMetaData(num_steps=local_num_steps) def 
evaluate(self, dataset: AbstractDatasetType, diff --git a/pfl/model/tensorflow.py b/pfl/model/tensorflow.py index f235c95..6819505 100644 --- a/pfl/model/tensorflow.py +++ b/pfl/model/tensorflow.py @@ -5,8 +5,9 @@ import uuid from typing import Callable, Dict, Optional, Tuple, Union -import tensorflow as tf # type: ignore +import tensorflow as tf +from pfl.context import LocalResultMetaData from pfl.data.dataset import AbstractDatasetType from pfl.exception import CheckpointNotFoundError from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams @@ -271,7 +272,8 @@ def _reset_local_optimizer(self, optimizer, learning_rate): def do_multiple_epochs_of(self, user_dataset: AbstractDatasetType, train_params: NNTrainHyperParams, - train_step_fn: Callable, **kwargs) -> None: + train_step_fn: Callable, + **kwargs) -> LocalResultMetaData: """ Perform multiple epochs of training. The customizable training function that will use a batch of data to update the local @@ -308,6 +310,7 @@ def do_multiple_epochs_of(self, user_dataset: AbstractDatasetType, assert train_params.grad_accumulation_steps == 1, ( "Gradient accumulation is not yet supported in TensorFlow") + num_steps_taken = 0 for _ in range(num_epochs): for batch_ix, batch in enumerate( user_dataset.iter(train_params.get('local_batch_size'))): @@ -319,6 +322,8 @@ def do_multiple_epochs_of(self, user_dataset: AbstractDatasetType, *batch, user_dataset.train_kwargs, **kwargs) + num_steps_taken += 1 + return LocalResultMetaData(num_steps=num_steps_taken) def evaluate(self, dataset: AbstractDatasetType, diff --git a/tests/model/test_pytorch_model.py b/tests/model/test_pytorch_model.py index 67066f8..ccd1b69 100644 --- a/tests/model/test_pytorch_model.py +++ b/tests/model/test_pytorch_model.py @@ -71,21 +71,27 @@ def step_side_effect(): def new_local_optimizer(*args, **kwargs): return mock_local_optimizer + local_batch_size = 1 pytorch_model_setup.model.new_local_optimizer = new_local_optimizer - 
bridges.sgd_bridge().do_sgd( - pytorch_model_setup.model, user_dataset, + # This is same as bridges.sgd_bridge().do_sgd, but we want + # to check the returned metadata as well. + from pfl.internal.bridge.pytorch.sgd import _sgd_train_step + train_metadata = pytorch_model_setup.model.do_multiple_epochs_of( + user_dataset, NNTrainHyperParams( local_learning_rate=local_learning_rate, local_num_epochs=local_num_epochs, - local_batch_size=1, - grad_accumulation_steps=grad_accumulation_steps)) + local_batch_size=local_batch_size, + grad_accumulation_steps=grad_accumulation_steps), + _sgd_train_step) # Check if optimizer step is called correct number of times - total_steps = 2 * local_num_epochs + total_steps = len(user_dataset) / local_batch_size * local_num_epochs expected_optimizer_calls = ( total_steps // grad_accumulation_steps + int(total_steps % grad_accumulation_steps != 0)) assert mock_local_optimizer.step.call_count == expected_optimizer_calls + assert train_metadata.num_steps == total_steps # Check if each step the gradient is accumulated correctly assert len(step_grads) == len(expected_step_grads) diff --git a/tests/model/test_tensorflow_model.py b/tests/model/test_tensorflow_model.py index 1a58248..dd95192 100644 --- a/tests/model/test_tensorflow_model.py +++ b/tests/model/test_tensorflow_model.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from pfl.hyperparam.base import NNTrainHyperParams from pfl.internal.ops import get_tf_major_version from pfl.internal.ops.selector import _internal_reset_framework_module @@ -155,3 +156,16 @@ def test_save_and_load_central_optimizer_impl( Test if central optimizer could be save and restored """ check_save_and_load_central_optimizer_impl(tensorflow_model_setup) + + def test_local_train_metadata(self, tensorflow_model_setup, user_dataset): + model = tensorflow_model_setup.model + from pfl.internal.bridge.tensorflow.sgd import _make_train_step + step_fn = _make_train_step(model) + train_metadata = 
model.do_multiple_epochs_of( + user_dataset, + NNTrainHyperParams(local_learning_rate=0.1, + local_num_epochs=3, + local_batch_size=1), + step_fn, + max_grad_norm=None) + assert train_metadata.num_steps == 6 From 9f1878d21c23b07bb01cdba102fa3613a718ca09 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Tue, 4 Jun 2024 12:32:47 +0200 Subject: [PATCH 17/22] Put model to default device before initializing optimizer (#76) --- benchmarks/flair/train.py | 4 +++- benchmarks/image_classification/pytorch/train.py | 4 +++- benchmarks/llm/train.py | 3 +++ benchmarks/lm/pytorch/train.py | 3 +++ benchmarks/poetry.lock | 11 ++++++----- benchmarks/pyproject.toml | 3 +++ 6 files changed, 21 insertions(+), 7 deletions(-) diff --git a/benchmarks/flair/train.py b/benchmarks/flair/train.py index e4580ba..ce935ba 100644 --- a/benchmarks/flair/train.py +++ b/benchmarks/flair/train.py @@ -33,7 +33,7 @@ WandbCallback, ) from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams -from pfl.internal.ops.pytorch_ops import to_tensor +from pfl.internal.ops.pytorch_ops import get_default_device, to_tensor from pfl.model.pytorch import PyTorchModel from .argument_parsing import add_flair_training_arguments @@ -96,6 +96,8 @@ def main(): arguments.num_classes = num_classes pytorch_model = get_model_pytorch(arguments) + # Put on GPU if available. 
+ pytorch_model = pytorch_model.to(get_default_device()) variables = [p for p in pytorch_model.parameters() if p.requires_grad] if arguments.central_optimizer == 'adam': diff --git a/benchmarks/image_classification/pytorch/train.py b/benchmarks/image_classification/pytorch/train.py index 0397da0..62bc667 100644 --- a/benchmarks/image_classification/pytorch/train.py +++ b/benchmarks/image_classification/pytorch/train.py @@ -31,7 +31,7 @@ WandbCallback, ) from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams -from pfl.internal.ops.pytorch_ops import to_tensor +from pfl.internal.ops.pytorch_ops import get_default_device, to_tensor from pfl.model.pytorch import PyTorchModel from pfl.privacy import CentrallyAppliedPrivacyMechanism @@ -81,6 +81,8 @@ def main(): _) = get_datasets(arguments) pytorch_model = get_model_pytorch(arguments) + # Put on GPU if available. + pytorch_model = pytorch_model.to(get_default_device()) params = [p for p in pytorch_model.parameters() if p.requires_grad] diff --git a/benchmarks/llm/train.py b/benchmarks/llm/train.py index 1e205fe..1321623 100644 --- a/benchmarks/llm/train.py +++ b/benchmarks/llm/train.py @@ -25,6 +25,7 @@ from pfl.aggregate.simulate import SimulatedBackend from pfl.callback import AggregateMetricsToDisk, CentralEvaluationCallback, StopwatchCallback, WandbCallback from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams +from pfl.internal.ops import pytorch_ops from pfl.model.pytorch import PyTorchModel @@ -88,6 +89,8 @@ def main(): peft_config = parse_peft_config(arguments) hf_model = wrap_hugging_face_model(hf_model, peft_config, causal_lm_metrics_fn) + # Put on GPU if available. 
+ hf_model = hf_model.to(pytorch_ops.get_default_device()) params = [p for p in hf_model.parameters() if p.requires_grad] if arguments.central_optimizer == 'adam': diff --git a/benchmarks/lm/pytorch/train.py b/benchmarks/lm/pytorch/train.py index e175e96..72ae60b 100644 --- a/benchmarks/lm/pytorch/train.py +++ b/benchmarks/lm/pytorch/train.py @@ -31,6 +31,7 @@ WandbCallback, ) from pfl.hyperparam import NNEvalHyperParams, NNTrainHyperParams +from pfl.internal.ops.pytorch_ops import get_default_device from pfl.model.pytorch import PyTorchModel from pfl.privacy import CentrallyAppliedPrivacyMechanism @@ -88,6 +89,8 @@ def main(): arguments.max_sequence_length = metadata['max_sequence_length'] pytorch_model = get_model_pytorch(arguments) + # Put on GPU if available. + pytorch_model = pytorch_model.to(get_default_device()) params = [p for p in pytorch_model.parameters() if p.requires_grad] if arguments.central_optimizer == 'adam': diff --git a/benchmarks/poetry.lock b/benchmarks/poetry.lock index 2a27500..986c974 100644 --- a/benchmarks/poetry.lock +++ b/benchmarks/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -1082,7 +1082,7 @@ testing = ["coverage", "pyyaml"] name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." 
-optional = true +optional = false python-versions = ">=3.7" files = [ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, @@ -1601,6 +1601,7 @@ torch = [ {version = ">=2.0.1+cu118,<3.0.0", optional = true, markers = "sys_platform == \"linux\""}, {version = ">=2.0.1,<3.0.0", optional = true, markers = "sys_platform == \"darwin\""}, ] +werkzeug = ">=3.0.3" wheel = "^0.41.2" [package.extras] @@ -2257,7 +2258,7 @@ requests = ">=2.21.0,<3" setuptools = ">=41.0.0" six = ">1.9" tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=3.0.3" +werkzeug = ">=1.0.1" [[package]] name = "tensorboard-data-server" @@ -2776,7 +2777,7 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess name = "werkzeug" version = "3.0.3" description = "The comprehensive WSGI web application library." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, @@ -3144,4 +3145,4 @@ tf = ["pfl", "pfl", "tensorflow", "tensorflow_addons"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.11" -content-hash = "2c4a3e255f8fd453e4048dc5aaf861c95169cc74ad99f2b3e501f43cbdab1833" +content-hash = "0443f4cbcdf7aa09adf1ff78133d37ee5735b4ffc33f43e88b54ff7bd28df52e" diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml index 1ed3bc9..6ec985a 100644 --- a/benchmarks/pyproject.toml +++ b/benchmarks/pyproject.toml @@ -30,6 +30,9 @@ pfl = [ ] pillow = ">=10.2.0" datasets = "^2.19.1" +# Fixes vulnerability https://github.com/advisories/GHSA-2g68-c3qc-8985 +werkzeug = ">=3.0.3" + [tool.poetry.extras] tf = ["pfl", "tensorflow_addons", "tensorflow"] From cdc991a10e0f600dfdb99abc5b6acd1062374955 Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Tue, 4 Jun 2024 13:17:47 +0200 Subject: [PATCH 18/22] Improved greedy worker scheduling (#73) 
--- benchmarks/dataset/argument_parsing.py | 16 +++++++ benchmarks/dataset/flair/__init__.py | 13 +++--- benchmarks/dataset/flair/numpy.py | 33 ++++++++++++-- benchmarks/flair/configs/baseline.yaml | 1 + pfl/data/federated_dataset.py | 63 +++++++++++++++++--------- tests/data/test_federated_dataset.py | 2 +- 6 files changed, 97 insertions(+), 31 deletions(-) diff --git a/benchmarks/dataset/argument_parsing.py b/benchmarks/dataset/argument_parsing.py index d72c1fb..3304284 100644 --- a/benchmarks/dataset/argument_parsing.py +++ b/benchmarks/dataset/argument_parsing.py @@ -132,6 +132,20 @@ def add_dataset_arguments( default=100, help='Maximum number of images per user') + parser.add_argument( + '--scheduling_base_weight_multiplier', + type=float, + default=1.0, + help=('Figure 3b in pfl-research paper ' + 'https://arxiv.org/abs/2404.06430 show adding a' + 'base value for each user\'s weight for scheduling ' + 'in distributed simulations speeds up training. ' + 'This parameter adds a ' + 'multiplicative factor of the median user weight ' + 'as base value. 0.0 means no base value added and ' + '~1.0 is the optimal value for the FLAIR benchmark ' + 'according to Figure 3b (can be different for ' + 'other setups).')) elif known_args.dataset == 'alpaca': parser = add_artificial_fed_dataset_arguments(parser) @@ -289,6 +303,8 @@ def get_datasets( data_path=args.data_path, use_fine_grained_labels=args.use_fine_grained_labels, max_num_user_images=args.max_num_user_images, + scheduling_base_weight_multiplier=args. 
+ scheduling_base_weight_multiplier, numpy_to_tensor=numpy_to_tensor) elif args.dataset == 'flair_pytorch': from .flair import make_flair_pytorch_datasets diff --git a/benchmarks/dataset/flair/__init__.py b/benchmarks/dataset/flair/__init__.py index 8f5cb82..b2fa017 100644 --- a/benchmarks/dataset/flair/__init__.py +++ b/benchmarks/dataset/flair/__init__.py @@ -24,7 +24,9 @@ def get_central_data_and_metadata(data_path: str, def make_flair_datasets(data_path: str, use_fine_grained_labels: bool, - max_num_user_images: int, numpy_to_tensor: Callable): + max_num_user_images: int, + scheduling_base_weight_multiplier: float, + numpy_to_tensor: Callable): """ Create a train and val ``FederatedDataset`` as well as a central dataset from the FLAIR dataset. @@ -33,11 +35,10 @@ def make_flair_datasets(data_path: str, use_fine_grained_labels: bool, training_federated_dataset = make_federated_dataset( data_path, 'train', use_fine_grained_labels, max_num_user_images, - numpy_to_tensor) - val_federated_dataset = make_federated_dataset(data_path, 'val', - use_fine_grained_labels, - max_num_user_images, - numpy_to_tensor) + scheduling_base_weight_multiplier, numpy_to_tensor) + val_federated_dataset = make_federated_dataset( + data_path, 'val', use_fine_grained_labels, max_num_user_images, + scheduling_base_weight_multiplier, numpy_to_tensor) central_data, metadata = get_central_data_and_metadata( data_path, use_fine_grained_labels) diff --git a/benchmarks/dataset/flair/numpy.py b/benchmarks/dataset/flair/numpy.py index 7920634..05d65e6 100644 --- a/benchmarks/dataset/flair/numpy.py +++ b/benchmarks/dataset/flair/numpy.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. 
+import logging from typing import Callable import h5py @@ -11,12 +12,15 @@ from .common import get_label_mapping, get_multi_hot_targets, get_user_num_images +logger = logging.getLogger(name=__name__) + def make_federated_dataset( hdf5_path: str, partition: str, use_fine_grained_labels: bool, max_num_user_images: int, + scheduling_base_weight_multiplier: float = 1., numpy_to_tensor: Callable = lambda x: x) -> FederatedDataset: """ Create federated dataset from the flair dataset, to use in simulations. @@ -31,14 +35,37 @@ def make_federated_dataset( Whether to use fine-grained label taxonomy. :param max_num_user_images: Maximum number of images each user can have. + :param scheduling_base_weight_multiplier: + The number of datapoints per user is used as a weight to greedily + schedule users in distributed simulations. Figure 3b in pfl-research + paper https://arxiv.org/abs/2404.06430 show adding a base value for + each user\'s weight for scheduling in distributed simulations + speeds up training. + This parameter adds a multiplicative factor of the median user weight + as base value. 0.0 means no base value added and ~1.0 is the optimal + value for the FLAIR benchmark according to Figure 3b (can be different + for other setups. :param numpy_to_tensor: Function that convert numpy array to ML framework tensor. :return: Federated dataset from the HDF5 data file. 
""" num_classes = len(get_label_mapping(hdf5_path, use_fine_grained_labels)) - user_num_images = get_user_num_images(hdf5_path, partition) - user_ids = sorted(user_num_images.keys()) + user_id_to_weight = { + k: min(v, max_num_user_images) + for k, v in get_user_num_images(hdf5_path, partition).items() + } + median_datapoints = np.median(list(user_id_to_weight.values())) + base_value = scheduling_base_weight_multiplier * median_datapoints + logger.info( + f'User mean datapoints: {np.mean(list(user_id_to_weight.values()))}, median datapoints: {median_datapoints}, base_value: {base_value}' + ) + user_id_to_weight = { + k: v + base_value + for k, v in user_id_to_weight.items() + } + + user_ids = sorted(user_id_to_weight.keys()) sampler = get_user_sampler('random', user_ids) def make_dataset_fn(user_id): @@ -59,7 +86,7 @@ def make_dataset_fn(user_id): return FederatedDataset(make_dataset_fn, sampler, - user_id_to_weight=user_num_images) + user_id_to_weight=user_id_to_weight) def make_artificial_federated_dataset( diff --git a/benchmarks/flair/configs/baseline.yaml b/benchmarks/flair/configs/baseline.yaml index dbf46b4..7a91c9a 100644 --- a/benchmarks/flair/configs/baseline.yaml +++ b/benchmarks/flair/configs/baseline.yaml @@ -25,6 +25,7 @@ evaluation_frequency: 20 central_eval_batch_size: 512 pretrained: true #save_model_path: './checkpoints' +scheduling_base_weight_multiplier: 1.0 # Simulate I.I.D. 
data by ignoring user ID # - dataset: flair_iid diff --git a/pfl/data/federated_dataset.py b/pfl/data/federated_dataset.py index 15dd394..c7cb82e 100644 --- a/pfl/data/federated_dataset.py +++ b/pfl/data/federated_dataset.py @@ -49,8 +49,9 @@ def _distributed_sampler(sampler, get_seed, rank, world_size): yield users_all_workers[rank], seeds_all_workers[rank] -def _sorted_cohort_subprocess(q_send, q_request, sampler, seed_sampler, rank, - world_size, user_id_to_weight): +def _sorted_cohort_subprocess(q_send, q_request, q_send_num, sampler, + seed_sampler, rank, world_size, + user_id_to_weight): cache = [] cohort_size = None cache_size = _INITIAL_USER_SAMPLES_CACHE_SIZE @@ -68,15 +69,26 @@ def _sorted_cohort_subprocess(q_send, q_request, sampler, seed_sampler, rank, if cohort_size is not None and len(cache) > cohort_size: # Perform request cohort_samples = cache[:cohort_size] - if user_id_to_weight is not None: - cohort_samples = sorted( - cohort_samples, - key=lambda tup: user_id_to_weight[tup[0]], - reverse=True) - q_send.put(cohort_samples) - + cohort_samples = sorted(cohort_samples, + key=lambda tup: user_id_to_weight[tup[0]], + reverse=True) + + # This optimization procedure should be deterministic + # across all worker processes. + worker_samples = [[] for _ in range(world_size)] + worker_total_weight = np.zeros(world_size) + for sample in cohort_samples: + # Add user to worker with the minimum load. 
+ min_worker_index = np.argmin(worker_total_weight) + worker_samples[min_worker_index].append(sample) + worker_total_weight[min_worker_index] += user_id_to_weight[ + sample[0]] + + q_send_num.put(len(worker_samples[rank])) + q_send.put(worker_samples[rank]) cache = cache[cohort_size:] cohort_size = None + # Check if the queue is full if len(cache) < cache_size: # Generate multiple samples at a time @@ -119,21 +131,22 @@ def __init__(self, sampler, seed_sampler, rank, world_size, self._samples_q = mp.Queue() self._cohort_request_q = mp.Queue() + self._cohort_num_response_q = mp.Queue() self._sample_process = mp.Process( target=_sorted_cohort_subprocess, - args=(self._samples_q, self._cohort_request_q, sampler, - seed_sampler, rank, world_size, user_id_to_weight)) + args=(self._samples_q, self._cohort_request_q, + self._cohort_num_response_q, sampler, seed_sampler, rank, + world_size, user_id_to_weight)) self._sample_process.start() atexit.register(self.__del__) def __iter__(self): while True: - for i, sample in enumerate(self._samples_q.get()): - if i % self._world_size == self._rank: - yield sample + yield from self._samples_q.get() def set_cohort_size(self, cohort_size): self._cohort_request_q.put(cohort_size) + return self._cohort_num_response_q.get() def __del__(self): self._samples_q.close() @@ -353,10 +366,12 @@ def __init__(self, def _try_set_cohort_size(self, cohort_size: int): # Doesn't need a cohort to be set, can continue. + worker_cohort_size = None with contextlib.suppress(AttributeError): # pytype: disable=attribute-error - self._sample_fn.set_cohort_size(cohort_size) + worker_cohort_size = self._sample_fn.set_cohort_size(cohort_size) # pytype: enable=attribute-error + return worker_cohort_size def __next__(self) -> Tuple[AbstractDataset, int]: # Each worker will make a dataset with its own sampled user. 
@@ -429,12 +444,18 @@ def from_slices_with_dirichlet_class_distribution( def get_cohort(self, cohort_size: int) -> Iterable[Tuple[AbstractDataset, int]]: - self._try_set_cohort_size(cohort_size) - for i in range(cohort_size): - if (i % get_ops().distributed.world_size - ) == get_ops().distributed.global_rank: - user_ids, seed = next(self.sampler) - yield self.make_dataset_fn(user_ids), seed + # Set next cohort size for sampler if possible + worker_cohort_size = self._try_set_cohort_size(cohort_size) + if worker_cohort_size is None: + for i in range(cohort_size): + if (i % get_ops().distributed.world_size + ) == get_ops().distributed.global_rank: + user_id, seed = next(self.sampler) + yield self.make_dataset_fn(user_id), seed + else: + for _ in range(worker_cohort_size): + user_id, seed = next(self.sampler) + yield self.make_dataset_fn(user_id), seed class FederatedDatasetMixture(FederatedDatasetBase): diff --git a/tests/data/test_federated_dataset.py b/tests/data/test_federated_dataset.py index 20331e7..0818bb1 100644 --- a/tests/data/test_federated_dataset.py +++ b/tests/data/test_federated_dataset.py @@ -266,7 +266,7 @@ def test_3_workers(self, mock_get_ops, federated_dataset_mixture, @pytest.fixture def user_id_to_weight(request): if hasattr(request, 'param') and request.param: - return {i: i for i in range(100)} + return {i: i + 1 for i in range(100)} else: return None From 2173c3b381e70ee40a6b8f2d07352c2c7f5c9975 Mon Sep 17 00:00:00 2001 From: ac554 <47990575+ac554@users.noreply.github.com> Date: Tue, 4 Jun 2024 12:18:33 +0100 Subject: [PATCH 19/22] Code to accompany MDM paper (#75) --- publications/mdm/README.md | 26 + publications/mdm/mdm/__init__.py | 4 + publications/mdm/mdm/algorithm.py | 290 ++++++++++ publications/mdm/mdm/bridge/base.py | 46 ++ publications/mdm/mdm/bridge/factory.py | 104 ++++ .../mdm/mdm/bridge/pytorch/__init__.py | 0 .../mdm/mdm/bridge/pytorch/polya_mixture.py | 102 ++++ publications/mdm/mdm/init_algorithm.py | 266 +++++++++ 
publications/mdm/mdm/model.py | 137 +++++ publications/mdm/mdm_paper/README.md | 3 + publications/mdm/mdm_paper/__init__.py | 0 .../notebooks/MSE_cifar10_experiments.ipynb | 501 +++++++++++++++++ .../notebooks/cifar10_visualisations.ipynb | 241 ++++++++ .../notebooks/femnist_visualisations.ipynb | 516 ++++++++++++++++++ .../mdm/mdm_paper/training/__init__.py | 0 publications/mdm/mdm_paper/training/mle.py | 102 ++++ publications/mdm/mdm_paper/training/train.py | 149 +++++ .../mdm/mdm_paper/training/train_femnist.py | 101 ++++ .../training/train_femnist_rebuttal.py | 117 ++++ publications/mdm/mdm_utils/__init__.py | 0 .../mdm/mdm_utils/datasets/__init__.py | 3 + .../mdm/mdm_utils/datasets/cifar10_dataset.py | 158 ++++++ .../mdm/mdm_utils/datasets/femnist_dataset.py | 324 +++++++++++ .../mdm/mdm_utils/datasets/mixture_dataset.py | 154 ++++++ .../mdm/mdm_utils/datasets/sampler.py | 123 +++++ publications/mdm/mdm_utils/models/__init__.py | 1 + .../mdm/mdm_utils/models/argument_parsing.py | 188 +++++++ .../mdm/mdm_utils/models/pytorch/__init__.py | 6 + .../mdm/mdm_utils/models/pytorch/cnn.py | 226 ++++++++ .../mdm/mdm_utils/models/pytorch/dnn.py | 51 ++ .../mdm/mdm_utils/models/pytorch/layer.py | 88 +++ .../mdm/mdm_utils/models/pytorch/metrics.py | 54 ++ .../models/pytorch/module_modification.py | 131 +++++ .../mdm/mdm_utils/models/pytorch_model.py | 105 ++++ publications/mdm/mdm_utils/utils/__init__.py | 5 + .../mdm/mdm_utils/utils/argument_parsing.py | 154 ++++++ publications/mdm/mdm_utils/utils/tools.py | 55 ++ .../mdm/mdm_utils/utils/visualize_results.py | 71 +++ .../run_cifar10_mse_alpha_phi_experiments.sh | 70 +++ publications/mdm/run_femnist.sh | 34 ++ 40 files changed, 4706 insertions(+) create mode 100644 publications/mdm/README.md create mode 100644 publications/mdm/mdm/__init__.py create mode 100644 publications/mdm/mdm/algorithm.py create mode 100644 publications/mdm/mdm/bridge/base.py create mode 100644 publications/mdm/mdm/bridge/factory.py create mode 
100644 publications/mdm/mdm/bridge/pytorch/__init__.py create mode 100644 publications/mdm/mdm/bridge/pytorch/polya_mixture.py create mode 100644 publications/mdm/mdm/init_algorithm.py create mode 100644 publications/mdm/mdm/model.py create mode 100644 publications/mdm/mdm_paper/README.md create mode 100644 publications/mdm/mdm_paper/__init__.py create mode 100644 publications/mdm/mdm_paper/notebooks/MSE_cifar10_experiments.ipynb create mode 100644 publications/mdm/mdm_paper/notebooks/cifar10_visualisations.ipynb create mode 100644 publications/mdm/mdm_paper/notebooks/femnist_visualisations.ipynb create mode 100644 publications/mdm/mdm_paper/training/__init__.py create mode 100644 publications/mdm/mdm_paper/training/mle.py create mode 100644 publications/mdm/mdm_paper/training/train.py create mode 100644 publications/mdm/mdm_paper/training/train_femnist.py create mode 100644 publications/mdm/mdm_paper/training/train_femnist_rebuttal.py create mode 100644 publications/mdm/mdm_utils/__init__.py create mode 100644 publications/mdm/mdm_utils/datasets/__init__.py create mode 100644 publications/mdm/mdm_utils/datasets/cifar10_dataset.py create mode 100644 publications/mdm/mdm_utils/datasets/femnist_dataset.py create mode 100644 publications/mdm/mdm_utils/datasets/mixture_dataset.py create mode 100644 publications/mdm/mdm_utils/datasets/sampler.py create mode 100644 publications/mdm/mdm_utils/models/__init__.py create mode 100644 publications/mdm/mdm_utils/models/argument_parsing.py create mode 100644 publications/mdm/mdm_utils/models/pytorch/__init__.py create mode 100644 publications/mdm/mdm_utils/models/pytorch/cnn.py create mode 100644 publications/mdm/mdm_utils/models/pytorch/dnn.py create mode 100644 publications/mdm/mdm_utils/models/pytorch/layer.py create mode 100644 publications/mdm/mdm_utils/models/pytorch/metrics.py create mode 100644 publications/mdm/mdm_utils/models/pytorch/module_modification.py create mode 100644 
publications/mdm/mdm_utils/models/pytorch_model.py create mode 100644 publications/mdm/mdm_utils/utils/__init__.py create mode 100644 publications/mdm/mdm_utils/utils/argument_parsing.py create mode 100644 publications/mdm/mdm_utils/utils/tools.py create mode 100644 publications/mdm/mdm_utils/utils/visualize_results.py create mode 100755 publications/mdm/run_cifar10_mse_alpha_phi_experiments.sh create mode 100755 publications/mdm/run_femnist.sh diff --git a/publications/mdm/README.md b/publications/mdm/README.md new file mode 100644 index 0000000..38e2a26 --- /dev/null +++ b/publications/mdm/README.md @@ -0,0 +1,26 @@ +# Improved Modelling of Federated Datasets using Mixtures-of-Dirichlet-Multinomials (MDMs) + +This software project accompanies the research paper, "Improved Modelling of Federated Datasets using Mixtures-of-Dirichlet-Multinomials". + +Mixture-of-Dirichlet-Multinomial (MDM) models allow one to model heterogeneous federated datasets, and such MDM models can be trained with privacy-preserving federated learning. + +## Documentation + +This repo contains the code to run all experiments in the paper "Improved Modelling of Federated Datasets using Mixtures-of-Dirichlet-Multinomials", and to process the results to produce the plots shown in the paper. The code is available in the `mdm-paper` directory on this fork of the `pfl-research` framework, for running simulations using Private Federated Learning. + +The structure of the `mdm` repo is: +- `mdm/`: This directory contains the algorithmic code, implementing the MDM model algorithm in the pfl-research framework. +- `mdm_paper/`: This directory contains subdirectories: `training/` contains the python training scripts to run inference of MDM parameters on the CIFAR-10 and FEMNIST datasets; `notebooks/` contains Jupyter notebooks used to visualise results and create plots shown in the paper. +- `mdm_utils/`: This directory contains utilities to help with training setup, e.g.
argument parsers, dataset helper functions, etc. + + +## Setup for experiments + +It is assumed that you have first followed the default setup for benchmarks in the pfl-research framework. The details to follow for this default setup are available [here](https://github.com/apple/pfl-research/blob/develop/benchmarks/README.md). + +It is further assumed that you have the FEMNIST and CIFAR-10 datasets downloaded locally in the directories `data/femnist/` and `data/cifar10/`, respectively. To download the data and ensure it is preprocessed correctly, please follow the pfl-research instructions for data setup [here](https://github.com/apple/pfl-research/tree/develop/benchmarks/image_classification). + +## Running experiments in paper +To run MDM parameter inference on CIFAR-10: `bash publications/mdm/run_cifar10_mse_alpha_phi_experiments.sh`. + +To run MDM parameter inference on FEMNIST, where the users are split between the server-side dataset and the live dataset: `bash publications/mdm/run_femnist.sh`. diff --git a/publications/mdm/mdm/__init__.py b/publications/mdm/mdm/__init__.py new file mode 100644 index 0000000..7ae9963 --- /dev/null +++ b/publications/mdm/mdm/__init__.py @@ -0,0 +1,4 @@ +from .model import (MDMModelHyperParams, MDMModel) +from .init_algorithm import (MDMInitializationAlgorithmParams, + MDMInitializationAlgorithm) +from .algorithm import (MDMAlgorithmParams, MDMAlgorithm) diff --git a/publications/mdm/mdm/algorithm.py b/publications/mdm/mdm/algorithm.py new file mode 100644 index 0000000..0f3be49 --- /dev/null +++ b/publications/mdm/mdm/algorithm.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- + +from dataclasses import dataclass +from typing import Tuple, Optional, TypeVar, Callable, Union + +import numpy as np +import torch + +from pfl.common_types import Population +from pfl.data.dataset import AbstractDataset +from pfl.hyperparam import get_param_value +from pfl.metrics import Metrics +from pfl.context import CentralContext +from pfl.stats import
MappedVectorStatistics +from pfl.algorithm.base import FederatedAlgorithm, AlgorithmHyperParams +from pfl.data.dataset import AbstractDatasetType + +from publications.mdm.mdm.model import MDMModelType, MDMModelHyperParamsType +from publications.mdm.mdm.bridge.factory import FrameworkBridgeFactory as bridges + + +@dataclass(frozen=True) +class MDMAlgorithmParams(AlgorithmHyperParams): + """ + Parameters for initialization algorithm of Polya Mixture. + + :param central_num_iterations: + Number of iterations of training + :param extract_categories_fn: + Function to extract categories from user dataset. By default return + labels of user dataset. + """ + cohort_size: int + num_samples_mixture_bins: np.ndarray + central_num_iterations: int = 1 + extract_categories_fn: Callable[[AbstractDataset], Union[ + np.ndarray, + torch.Tensor]] = lambda user_dataset: user_dataset.raw_data[1] + + +MDMAlgorithmParamsType = TypeVar('MDMAlgorithmParamsType', + bound=MDMAlgorithmParams) + + +class MDMAlgorithm(FederatedAlgorithm[MDMAlgorithmParamsType, + MDMModelHyperParamsType, MDMModelType, + MappedVectorStatistics, AbstractDatasetType]): + """ + Federated algorithm class for learning mixture of Polya + (Dirichlet-Multinomial) distribution using MLE algorithm. 
+ """ + + def get_next_central_contexts( + self, + model: MDMModelType, + iteration: int, + algorithm_params: MDMAlgorithmParamsType, + model_train_params: MDMModelHyperParamsType, + model_eval_params: Optional[MDMModelHyperParamsType] = None, + ) -> Tuple[Optional[Tuple[CentralContext[MDMAlgorithmParamsType, + MDMModelHyperParamsType], ...]], + MDMModelType, Metrics]: + + if (model.alphas <= 0).any(): + raise AssertionError( + f'Cannot have zero elements in alphas: {model.alphas}') + if (model.num_samples_distribution <= 0).any(): + raise AssertionError( + f'Cannot have zero elements in num_samples_distribution: {model.num_samples_distribution}' + ) + if (model.phi <= 0).any(): + raise AssertionError( + f'Cannot have zero elements in phi: {model.phi}') + + if iteration == algorithm_params.central_num_iterations: + return None, model, Metrics() + + configs = [ + CentralContext( + current_central_iteration=iteration, + do_evaluation=False, + cohort_size=get_param_value(algorithm_params.cohort_size), + population=Population.TRAIN, + model_train_params=model_train_params.static_clone(), + model_eval_params=None, + algorithm_params=algorithm_params.static_clone(), + seed=self._get_seed()) + ] + return tuple(configs), model, Metrics() + + def simulate_one_user( + self, + model: MDMModelType, + user_dataset: AbstractDataset, + central_context: CentralContext[MDMAlgorithmParamsType, + MDMModelHyperParamsType], + ) -> Tuple[Optional[MappedVectorStatistics], Metrics]: + """ + Encode user's dataset into statistics with a `MDMModel`. + + central_context.algorithm_params.extract_categories_fn is a callable + used to extract the categories tracked with the Polya-Mixture model. + """ + + if (model.alphas <= 0).any(): + raise AssertionError( + '> 0 alpha params have zero values, which would cause algorithm to fail. Cannot proceed.' + ) + + if (model.phi <= 0).any(): + raise AssertionError( + '> 0 phi params have zero values, which would cause algorithm to fail. Cannot proceed.' 
+ ) + + if (model.num_samples_distribution <= 0).any(): + raise AssertionError( + '> 0 num_samples_distribution have zero values, which would cause algorithm to fail. Cannot proceed.' + ) + + # Get counts + categories = central_context.algorithm_params.extract_categories_fn( + user_dataset) + #num_user_samples = len(categories) + category_counts = bridges.polya_mixture_bridge( + ).category_counts_polya_mixture( + categories, central_context.model_train_params.num_categories) + + e = torch.zeros( + (central_context.model_train_params.num_components, + central_context.algorithm_params.num_samples_mixture_bins.shape[1] + )) + + num_samples = len(categories) + + # TODO make more flexible so it supports if each mixture has different bin edges. + # TODO here I just assume that all components have the same bin edges, which is why I index 0 in central_context.algorithm_params.num_samples_mixture_bins[0] + selected_bin = -1 + for i, bin_edge in enumerate( + central_context.algorithm_params.num_samples_mixture_bins[0]): + bin_edge = int(bin_edge) + if num_samples <= bin_edge: + selected_bin = i + break + + user_num_samples_distribution = model.num_samples_distribution[:, + selected_bin] + + # E Step - compute posterior probability of each component + posterior_probabilities = bridges.polya_mixture_bridge( + ).expectation_step(model.phi, model.alphas, + user_num_samples_distribution, category_counts) + + # M Step - compute client update to alphas for fixed point update + # which will be applied by the model in process_aggregated_statistics. + # Note the numerator and denominator are both weighted by w (the + # probability vector giving the client belonging to each component). 
def process_aggregated_statistics(
    self, central_context: CentralContext[MDMAlgorithmParamsType,
                                          MDMModelHyperParamsType],
    aggregate_metrics: Metrics, model: MDMModelType,
    statistics: MappedVectorStatistics
) -> Tuple[MDMModelType, Metrics]:
    """
    Convert the aggregated client statistics into new model parameters.

    Reads ``posterior_probabilities``, ``numerator``, ``denominator`` and
    ``num_samples_distribution`` from ``statistics``, clamps negatives to
    zero, removes exact zeros by re-allocating a small share of the mass,
    and then applies the fixed-point update for ``alphas`` together with
    new ``phi`` and ``num_samples_distribution`` values.

    :param central_context:
        Context of the current iteration; only
        ``algorithm_params.cohort_size`` is read here.
    :param aggregate_metrics:
        Not used by this method.
    :param model:
        Model whose ``alphas`` are read and which receives the update via
        ``apply_model_update``.
    :param statistics:
        Aggregated statistics holding the four keys listed above.
    :returns:
        The updated model and metrics containing the new ``alphas``,
        ``phi`` and ``num_samples_distribution``.
    :raises AssertionError:
        If zeros remain in a statistic after the zero-prevention step, or
        if ``model.alphas`` / the updated alphas contain zeros.
    """

    # The new weight of a mixture component is the mean client weight of
    # that component

    # TODO prevent any <= 0 values in posterior_probabilities, numerator and denominator and num_samples_distribution
    posterior_probabilities = statistics['posterior_probabilities']
    numerator = statistics['numerator']
    denominator = statistics['denominator']
    num_samples_distribution = statistics['num_samples_distribution']

    # Negative values are clamped before any further processing.
    posterior_probabilities = torch.clamp(posterior_probabilities, min=0)
    numerator = torch.clamp(numerator, min=0)
    denominator = torch.clamp(denominator, min=0)
    num_samples_distribution = torch.clamp(num_samples_distribution, min=0)

    print('\n\nProcess aggregated statistics')
    print('numerator', numerator.shape, numerator)
    print('denominator', denominator.shape, denominator)
    print('num_samples_distribution', num_samples_distribution.shape,
          num_samples_distribution)

    def prevent_zero(tensor, min_val=0, mass_reallocation_percentage=0.01):
        """
        Replace entries ``<= min_val`` of a 2-D tensor, row by row.

        In each row with positive mass, the small entries share
        ``mass_reallocation_percentage`` of that row's mass and the row is
        rescaled back to its original total. Rows without any mass are
        filled with 0.02.
        """
        if not (tensor == 0).any():
            print('no zeros in tensor')
            return tensor
        num_small = torch.sum(tensor <= min_val, dim=1, keepdim=True)
        total_mass = torch.sum(tensor, dim=1, keepdim=True)
        # clamp(min=1) only guards rows with nothing to replace against a
        # division by zero; the quotient is unused for those rows.
        extra_mass = (total_mass * mass_reallocation_percentage /
                      num_small.clamp(min=1))
        filled = torch.where(tensor > min_val, tensor,
                             extra_mass.expand_as(tensor))
        rescaled = filled / torch.sum(filled, dim=1,
                                      keepdim=True) * total_mass
        # Decide per row. The previous `if total_mass > 0:` evaluated a
        # (rows, 1) tensor as a bool, which raises for more than one row.
        return torch.where(total_mass > 0, rescaled,
                           torch.ones_like(tensor) * 0.02)

    if (posterior_probabilities == 0).any():
        num_zero = torch.sum(posterior_probabilities == 0)
        total_mass = torch.sum(posterior_probabilities)
        extra_mass = total_mass * 0.01 / num_zero
        posterior_probabilities = torch.where(posterior_probabilities > 0,
                                              posterior_probabilities,
                                              extra_mass)
        posterior_probabilities = posterior_probabilities / torch.sum(
            posterior_probabilities) * total_mass
        # Compare with a tolerance: exact float equality fails spuriously
        # after a divide/multiply round trip, and a bare `assert` is
        # stripped under `python -O`.
        if not torch.isclose(total_mass,
                             torch.sum(posterior_probabilities)):
            raise AssertionError(
                'rescaling posterior_probabilities did not preserve '
                'its total mass')

    if (numerator == 0).any():
        # NOTE(review): min_val=1 also replaces every numerator entry that
        # is <= 1, not only the zeros - confirm this is intended.
        numerator = prevent_zero(numerator, min_val=1)
        if (numerator == 0).any():
            raise AssertionError('prevent zero did not work on numerator')

    if (denominator == 0).any():
        denominator = prevent_zero(denominator)
        if (denominator == 0).any():
            raise AssertionError(
                'prevent zero did not work on denominator')

    if (num_samples_distribution == 0).any():
        modified_num_samples_distribution = prevent_zero(
            num_samples_distribution, min_val=0.001)
        mass_reallocation_percentage = 0.01
        # Keep doubling the re-allocated share until no zero is left.
        while (modified_num_samples_distribution == 0).any():
            mass_reallocation_percentage *= 2
            if mass_reallocation_percentage >= 1:
                raise AssertionError(
                    f'prevent zero did not work on num_samples_distribution: {num_samples_distribution}'
                )
            modified_num_samples_distribution = prevent_zero(
                num_samples_distribution,
                min_val=0.001,
                mass_reallocation_percentage=mass_reallocation_percentage)
        num_samples_distribution = modified_num_samples_distribution

    phi = posterior_probabilities / central_context.algorithm_params.cohort_size

    # Each alpha is updated using the fixed point update, note that the
    # numerator and denominator are weighted by each client before being
    # aggregated, so this is a weighted update.

    if (model.alphas == 0).any():
        raise AssertionError('model.alphas had zeros before update')
    alphas = bridges.polya_mixture_bridge().update_alpha(
        model.alphas, numerator, denominator)
    if (alphas == 0).any():
        raise AssertionError('alphas has zeros after update')

    # Divide each component's bin statistics by that component's summed
    # posterior probability.
    num_samples_distribution = num_samples_distribution / posterior_probabilities.reshape(
        -1, 1).expand_as(num_samples_distribution)

    # renormalise num_samples_distribution again on server, since DP might mean that statistics don't sum to 1 per mixture component
    num_samples_distribution = num_samples_distribution / num_samples_distribution.sum(
        dim=1, keepdim=True)
    if (num_samples_distribution == 0).any():
        num_samples_distribution = prevent_zero(num_samples_distribution,
                                                min_val=0.001)

    # renormalise phi
    phi = phi / phi.sum()

    model, metrics = model.apply_model_update(
        MappedVectorStatistics({
            'alphas':
            alphas,
            'phi':
            phi,
            'num_samples_distribution':
            num_samples_distribution
        }))
    metrics['alphas'] = alphas
    metrics['phi'] = phi
    metrics['num_samples_distribution'] = num_samples_distribution
    return model, metrics
class FrameworkBridgeFactory:
    """
    A collection of bridges to deep learning specific
    implementations for several algorithms.
    The bridge returned depends on the Deep Learning
    framework in use.
    This way, we can inject framework-specific code
    into an algorithm, and only have one implementation
    of each algorithm in the public interface, e.g. one
    public FedAvg class instead of one for each of TF,
    PyTorch, etc.

    Each method returns a class with utility functions
    for a particular algorithm.
    """

    # Every method reads the active framework from
    # `get_framework_module().FRAMEWORK_TYPE` and imports the matching
    # implementation lazily, so only the framework in use gets imported.
    #
    # NOTE(review): the imports are relative to this package. Only
    # `.pytorch.polya_mixture` is visibly added next to this factory;
    # confirm that `.pytorch.common`, `.tensorflow.*`, `.numpy.*` etc. exist
    # here, otherwise every bridge except `polya_mixture_bridge` fails at
    # import time.

    @staticmethod
    def common_bridge() -> CommonFrameworkBridge:
        """
        Return the common bridge class for the active framework.

        :raises NotImplementedError:
            If the framework is not PyTorch, TensorFlow or NumPy.
        """
        framework = get_framework_module().FRAMEWORK_TYPE
        if framework == MLFramework.PYTORCH:
            from .pytorch import common as common_pt
            return common_pt.PyTorchCommonBridge
        elif framework == MLFramework.TENSORFLOW:
            from .tensorflow import common as common_tf
            return common_tf.TFCommonBridge
        elif framework == MLFramework.NUMPY:
            from .numpy import common as common_np
            return common_np.NumpyCommonBridge
        else:
            raise NotImplementedError("Common bridge not available "
                                      f"for framework {framework}")

    @staticmethod
    def sgd_bridge() -> SGDFrameworkBridge:
        """
        Return the SGD bridge class for the active framework.

        :raises NotImplementedError:
            If the framework is not PyTorch or TensorFlow.
        """
        framework = get_framework_module().FRAMEWORK_TYPE
        if framework == MLFramework.PYTORCH:
            from .pytorch import sgd as sgd_pt
            return sgd_pt.PyTorchSGDBridge
        elif framework == MLFramework.TENSORFLOW:
            from .tensorflow import sgd as sgd_tf
            return sgd_tf.TFSGDBridge
        else:
            raise NotImplementedError("SGD bridge not available "
                                      f"for framework {framework}")

    @staticmethod
    def fedprox_bridge() -> FedProxFrameworkBridge:
        """
        Return the FedProx bridge class for the active framework.

        :raises NotImplementedError:
            If the framework is not PyTorch or TensorFlow.
        """
        framework = get_framework_module().FRAMEWORK_TYPE
        if framework == MLFramework.PYTORCH:
            from .pytorch import proximal as proximal_pt
            return proximal_pt.PyTorchFedProxBridge
        elif framework == MLFramework.TENSORFLOW:
            from .tensorflow import proximal as proximal_tf
            return proximal_tf.TFFedProxBridge
        else:
            raise NotImplementedError("FedProx bridge not available "
                                      f"for framework {framework}")

    @staticmethod
    def scaffold_bridge() -> SCAFFOLDFrameworkBridge:
        """
        Return the SCAFFOLD bridge class for the active framework.

        :raises NotImplementedError:
            If the framework is not PyTorch.
        """
        framework = get_framework_module().FRAMEWORK_TYPE
        if framework == MLFramework.PYTORCH:
            from .pytorch import scaffold as scaffold_pt
            return scaffold_pt.PyTorchSCAFFOLDBridge
        else:
            raise NotImplementedError("SCAFFOLD bridge not available "
                                      f"for framework {framework}")

    @staticmethod
    def ftrl_bridge() -> FTRLFrameworkBridge:
        """
        Return the FTRL bridge class for the active framework.

        :raises NotImplementedError:
            If the framework is not PyTorch or TensorFlow.
        """
        framework = get_framework_module().FRAMEWORK_TYPE
        if framework == MLFramework.PYTORCH:
            from .pytorch import ftrl as ftrl_pt
            return ftrl_pt.PyTorchFTRLBridge
        elif framework == MLFramework.TENSORFLOW:
            from .tensorflow import ftrl as ftrl_tf
            return ftrl_tf.TFFTRLBridge
        else:
            raise NotImplementedError("FTRL bridge not available "
                                      f"for framework {framework}")

    @staticmethod
    def polya_mixture_bridge() -> PolyaMixtureFrameworkBridge:
        """
        Return the Polya-Mixture bridge class for the active framework.

        :raises NotImplementedError:
            If the framework is not PyTorch.
        """
        framework = get_framework_module().FRAMEWORK_TYPE
        if framework == MLFramework.PYTORCH:
            from .pytorch import polya_mixture as polya_mixture_pt
            return polya_mixture_pt.PyTorchPolyaMixtureBridge
        else:
            raise NotImplementedError("PolyaMixture bridge not available "
                                      f"for framework {framework}")
@staticmethod
def expectation_step(phi, alphas, num_samples_distribution,
                     category_counts) -> torch.Tensor:
    """
    E step: posterior probability of each mixture component for one user.

    :param phi:
        Mixture weights, one entry per component.
    :param alphas:
        Dirichlet parameters, one row per component and one column per
        category.
    :param num_samples_distribution:
        Per-component probability of the user's dataset-size bin, one
        entry per component. Must not contain zeros.
    :param category_counts:
        Number of samples the user has in each category.
    :returns:
        Tensor with one posterior probability per component, normalised
        to sum to 1.
    :raises AssertionError:
        If ``num_samples_distribution`` contains a zero (its log would be
        ``-inf``).
    """
    # Coerce first so tensors, arrays and lists are all accepted and the
    # zero check below works on any of them.
    phi = torch.as_tensor(phi).to('cpu')
    alphas = torch.as_tensor(alphas).to('cpu')
    category_counts = torch.as_tensor(category_counts).to('cpu')
    num_samples_distribution = torch.as_tensor(
        num_samples_distribution).to('cpu')

    if (num_samples_distribution == 0).any():
        raise AssertionError('num_samples_distribution contains zero values, which cannot work with expectation step on clients')

    # Compute log prior + log likelihood
    # TODO log_v might be missing + torch.lgamma(torch.sum(counts)+1) - torch.sum(torch.lgamma(category_counts+1), dim=1, keepdim=False)
    # (that term is identical for every component, so it cancels in the
    # normalisation below)
    log_v = torch.log(phi) + (
        torch.lgamma(torch.sum(alphas, dim=1, keepdim=False)) -
        torch.lgamma(
            torch.sum(category_counts + alphas, dim=1, keepdim=False)) +
        torch.sum(
            torch.lgamma(category_counts + alphas) - torch.lgamma(alphas),
            dim=1,
            keepdim=False)) + torch.log(num_samples_distribution)

    # Log probability of the data. logsumexp shifts by the maximum
    # internally; shifting by log_v[0] (as done before) overflows `exp`
    # whenever another component is far more likely than the first one,
    # which turned every posterior into 0.
    log_normalization_constant = torch.logsumexp(log_v, dim=0)

    # Posterior probability of each component
    return torch.exp(log_v - log_normalization_constant)
@staticmethod
def update_alpha(alphas, numerator, denominator) -> torch.Tensor:
    """
    Apply the fixed-point update ``alphas * numerator / denominator``.

    :param alphas:
        Current Dirichlet parameters.
    :param numerator:
        Aggregated numerator statistics; must broadcast against
        ``alphas``.
    :param denominator:
        Aggregated denominator statistics; must broadcast against
        ``alphas``.
    :returns:
        The updated alphas as a CPU tensor.
    """
    # Coerce before moving to CPU, so lists/arrays are accepted as in the
    # sibling bridge methods, instead of wrapping an existing tensor in
    # the legacy `torch.Tensor(...)` constructor afterwards.
    alphas = torch.as_tensor(alphas).to('cpu')
    numerator = torch.as_tensor(numerator).to('cpu')
    denominator = torch.as_tensor(denominator).to('cpu')
    return alphas * numerator / denominator
def simulate_one_user(
    self,
    model: MDMModelType,
    user_dataset: AbstractDataset,
    central_context: CentralContext[MDMInitializationAlgorithmParamsType,
                                    MDMModelHyperParamsType],
) -> Tuple[Optional[MappedVectorStatistics], Metrics]:
    """
    Encode user's dataset into statistics with a `MDMModel`.

    The user is assigned to one mixture component and contributes three
    statistics: ``p`` (normalised category frequencies in the row of that
    component), ``q`` (``p`` squared element-wise) and ``e`` (a one-hot
    marker of the component and dataset-size bin of this user).

    :param model:
        Not read by this method.
    :param user_dataset:
        Dataset passed to ``algorithm_params.extract_categories_fn``.
    :param central_context:
        Provides the algorithm parameters and the number of components
        and categories.
    :returns:
        ``(statistics, Metrics())``, or ``(None, Metrics())`` when the
        user has no datapoints.
    :raises ValueError:
        If ``algorithm_params.strategy`` is not ``'random'``.
    """
    algorithm_params = central_context.algorithm_params
    model_train_params = central_context.model_train_params

    if algorithm_params.strategy == 'random':
        # Randomly assign user to a mixture component
        # TODO this approach might not work when num mixture > 1 and the
        # cohort size is large, as the p and q values will likely be very
        # similar for all clusters and this symmetry could hurt convergence.
        component = np.random.choice(
            range(model_train_params.num_components))
    else:
        # The two literals are concatenated; the first one now ends with a
        # space so the message no longer reads "recognized.Only".
        raise ValueError(
            f'Strategy {algorithm_params.strategy} not recognized. '
            'Only "random" strategy is implemented.')

    categories = algorithm_params.extract_categories_fn(user_dataset)
    num_samples = len(categories)
    if num_samples == 0:
        # Normalising the counts of an empty dataset is 0/0 = NaN, and a
        # single NaN would poison the aggregated `p` and `q`. The declared
        # return type allows contributing no statistics instead.
        # NOTE(review): assumes the aggregation skips `None` statistics -
        # confirm against the backend in use.
        return None, Metrics()

    # Compute counts of each class and normalize to probability vectors
    # User contributes only to estimate for their mixture component
    p = bridges.polya_mixture_bridge(
    ).category_probabilities_polya_mixture_initialization(
        model_train_params.num_components,
        model_train_params.num_categories, component, categories)
    q = p**2

    # Record user mixture component
    # component sizes are needed for initialization computation
    e = torch.zeros((model_train_params.num_components,
                     algorithm_params.num_samples_mixture_bins.shape[1]))

    # Pick the first bin whose edge is >= the dataset size. If the dataset
    # is larger than every edge, index -1 (the last bin) is used.
    selected_bin = -1
    for i, bin_edge in enumerate(
            algorithm_params.num_samples_mixture_bins[component]):
        if num_samples <= int(bin_edge):
            selected_bin = i
            break

    e[component, selected_bin] = 1

    statistics = MappedVectorStatistics()
    statistics['p'] = p.to('cpu')
    statistics['q'] = q.to('cpu')
    statistics['e'] = e.to('cpu')
    return statistics, Metrics()
+ + # Check if any element is equal to zero + if (p == 0).any(): + num_categories_leq_zero = torch.sum(p <= 0, dim=1) + extra_mass = torch.sum(p, + dim=1) * 0.01 / num_categories_leq_zero + + p = torch.where(p > 0, p, extra_mass.unsqueeze(1).expand_as(p)) + # Note fix zero issue in p and q separately because q >= p^2 + + # TODO Consider approximating p as p/(cohort size/num_components), + # since users are randomly assigned to components + # Similarly can approximate q as q/(cohort_size/num_components). + # This might make the results more accurate, since DP noise will + # be added to e, component_sums will be noisy. + #component_sums = torch.sum(e, dim=1, keepdim=True) + num_users_component = central_context.algorithm_params.cohort_size / num_components + p = torch.divide(p, num_users_component) + + if (q == 0).any(): + q[q == 0] = torch.pow(p[q == 0], 2) * 1.1 + q = torch.divide(q, num_users_component) + + if (e == 0).any(): + num_zero = torch.sum(e == 0, dim=1, keepdim=True) + extra_mass = torch.sum(e, dim=1, + keepdim=True) * 0.01 / num_zero + e = torch.where(e > 0, e, extra_mass.expand_as(e)) + + num_samples_distribution = torch.divide(e, num_users_component) + + # Compute alpha that matches the first two moments + # of the empirical distribution + + # Use either category with max probability or else average/median over all categories + init_coefficient = (p - q) / (q - torch.pow(p, 2)) + + k_pmax = torch.argmax(p, dim=1) + k_pmax_coefficient = init_coefficient[ + torch.arange(init_coefficient.size(0)), k_pmax] + + moment_matching_alpha_k_pmax = p * k_pmax_coefficient.reshape( + -1, 1) + + alphas = moment_matching_alpha_k_pmax + + # if only two categories and one component + #alpha_0 = 0.5 * (1 - p[0,1]) / (1 - p[0,0] - p[0,1]) + #alpha_1 = 0.5 * (1 - p[0,0]) / (1 - p[0,0] - p[0,1]) + #alphas = [torch.Tensor([alpha_0, alpha_1])] + + phi = 1 / num_components * torch.ones(num_components) + + model, metrics = model.apply_model_update( + MappedVectorStatistics({ + 
class MDMModel(Model, Generic[MDMModelHyperParamsType, Tensor]):
    """
    Polya Mixture model.

    Used Fixed-Point solver.

    Model that applies a weighted version of the fixed point Polya update from
    https://tminka.github.io/papers/dirichlet/minka-dirichlet.pdf to the alpha
    of each mixture component.

    :param alphas:
        Array-like of shape (number_mixture_components, num_categories),
        stores the alpha parameter for the Dirichlet of each mixture
        component.
    :param phi:
        Array-like of shape (number_mixture_components,) giving the weight
        of each component, sums to 1.
    :param num_samples_distribution:
        Array-like with one row per mixture component, holding that
        component's distribution over dataset-size bins.

    Each parameter is converted with ``get_ops().to_tensor`` when given and
    stays ``None`` otherwise.
    """
    # Executed once, when the class body is evaluated (i.e. on import).
    set_framework_module(pytorch_ops)
    _MODEL_CKPT_NAME = "polya-mixture.joblib"

    def __init__(self,
                 alphas: Optional[Union[List, np.ndarray]] = None,
                 phi: Optional[Union[List, np.ndarray]] = None,
                 num_samples_distribution: Optional[Union[List,
                                                          np.ndarray]] = None):

        self._alphas = get_ops().to_tensor(
            alphas) if alphas is not None else None
        self._phi = get_ops().to_tensor(phi) if phi is not None else None
        self._num_samples_distribution = get_ops().to_tensor(
            num_samples_distribution
        ) if num_samples_distribution is not None else None

    def _to_dict(self):
        """Return the parameters in the layout used for checkpoints."""
        return {
            'alphas': self._alphas,
            'phi': self._phi,
            'num_samples_distribution': self._num_samples_distribution
        }

    def save(self, dir_path: str) -> None:
        """
        Save a Polya-Mixture model to disk.

        :param dir_path:
            Directory in which the Polya-Mixture model will be saved. It
            is created if it does not exist.
        """
        # exist_ok avoids the check-then-create race of isdir + makedirs.
        os.makedirs(dir_path, exist_ok=True)

        save_path = os.path.join(dir_path, self._MODEL_CKPT_NAME)
        joblib.dump(self._to_dict(), save_path)

    def load(self, dir_path: str) -> None:
        """
        Load the parameters of a Polya-Mixture model from disk.

        :param dir_path:
            Directory containing a checkpoint written by ``save``.
        :raises CheckpointNotFoundError:
            If the checkpoint file does not exist in ``dir_path``.
        :raises KeyError:
            If the checkpoint lacks one of the required keys. The model is
            left unchanged in that case.
        """
        save_path = os.path.join(dir_path, self._MODEL_CKPT_NAME)
        if not os.path.exists(save_path):
            raise CheckpointNotFoundError(save_path)

        # NOTE: joblib.load unpickles the file and can execute arbitrary
        # code - only load checkpoints from trusted sources.
        parameters = joblib.load(save_path)
        try:
            alphas = parameters['alphas']
            phi = parameters['phi']
            num_samples_distribution = parameters[
                'num_samples_distribution']
        except KeyError as e:
            raise KeyError(
                'Polya-Mixture model checkpoint does not '
                'contain required keys: "alphas", "phi", '
                f'"num_samples_distribution": {e}') from e

        # Assign only after every key was found, so a broken checkpoint
        # cannot leave the model half-loaded.
        self._alphas = alphas
        self._phi = phi
        self._num_samples_distribution = num_samples_distribution

    @property
    def phi(self) -> Optional[Tensor]:
        """Weight of each mixture component, or ``None`` if unset."""
        return self._phi

    @property
    def alphas(self) -> Optional[Tensor]:
        """Dirichlet parameters per component, or ``None`` if unset."""
        return self._alphas

    @property
    def num_samples_distribution(self) -> Optional[Tensor]:
        """Dataset-size distribution per component, or ``None`` if unset."""
        return self._num_samples_distribution

    def apply_model_update(
            self,
            statistics: MappedVectorStatistics) -> Tuple['MDMModel', Metrics]:
        """
        Overwrite the parameters with the values in ``statistics``.

        :param statistics:
            Must contain ``alphas``, ``phi`` and
            ``num_samples_distribution``.
        :returns:
            This model (updated in place) and empty metrics.
        """
        self._alphas = statistics['alphas']
        self._phi = statistics['phi']
        self._num_samples_distribution = statistics['num_samples_distribution']

        return self, Metrics()
diff --git a/publications/mdm/mdm_paper/__init__.py b/publications/mdm/mdm_paper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/mdm_paper/notebooks/MSE_cifar10_experiments.ipynb b/publications/mdm/mdm_paper/notebooks/MSE_cifar10_experiments.ipynb new file mode 100644 index 0000000..5d3ce36 --- /dev/null +++ b/publications/mdm/mdm_paper/notebooks/MSE_cifar10_experiments.ipynb @@ -0,0 +1,501 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "id": "d168630b-085e-4ad9-82fe-413abc04c6b0", + "metadata": {}, + "outputs": [], + "source": [ + "import joblib\n", + "import os\n", + "from collections import defaultdict\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 315, + "id": "0eb31238-e448-4ba9-a15a-e64cbc2459a4", + "metadata": {}, + "outputs": [], + "source": [ + "results = defaultdict(dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 348, + "id": "49c3fd57-e79f-4f01-95a1-d100794f3dd3", + "metadata": {}, + "outputs": [], + "source": [ + "def mse(y, t):\n", + " return np.linalg.norm(y-t, ord=2) / np.linalg.norm(t, ord=2) #np.size(y)\n", + "\n", + "def get_results(dirpath, targets):\n", + " mse_alphas = []\n", + " mse_phi = []\n", + " \n", + " i = 0\n", + " while True:\n", + " path = f'{dirpath}/{str(i)}/polya-mixture.joblib'\n", + " if os.path.exists(path):\n", + " d = joblib.load(f'{dirpath}/{str(i)}/polya-mixture.joblib')\n", + " y_a = d['alphas'].numpy() \n", + " y_p = d['phi'].numpy()\n", + " t_a = np.array(targets['alphas']).reshape(y_a.shape)\n", + " t_p = np.array(targets['phi']).reshape(y_p.shape)\n", + "\n", + " if False: # instead of this, just reorder targets\n", + " # Order according to increasing phi value\n", + " y_sorting_indices = np.argsort(y_p)\n", + " t_sorting_indices = np.argsort(t_p)\n", + " y_a = y_a[y_sorting_indices, :] if y_a.ndim == 2 else y_a[y_sorting_indices]\n", + " t_a = 
t_a[t_sorting_indices, :] if t_a.ndim == 2 else t_a[t_sorting_indices]\n", + " y_p = y_p[y_sorting_indices]\n", + " t_p = t_p[t_sorting_indices]\n", + " \n", + " mse_alphas.append(mse(y_a, t_a))\n", + " mse_phi.append(mse(y_p, t_p))\n", + " i += 1\n", + " else:\n", + " break\n", + " print('y_a', y_a)\n", + " print('t_a', t_a)\n", + " print('y_p', y_p)\n", + " print('t_p', t_p)\n", + "\n", + " return mse_alphas, mse_phi" + ] + }, + { + "cell_type": "code", + "execution_count": 349, + "id": "34f11c71-e68e-4a6b-a548-684632ac2e45", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[0.9830888 0.9836113 0.949971 1.0266845 0.98187065 0.93955207\n", + " 1.0021306 1.0542753 0.98674315 0.9831437 ]]\n", + "t_a [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", + "y_p [1.]\n", + "t_p [1.]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_1_mixture_easy_iteration_models'\n", + "targets = {'phi': 1.0, 'alphas': [1.0]*10}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['Low heterogeneity - 1 mixture component'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 350, + "id": "a44366a7-d41e-43cf-a2a4-0d2104a4c77d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[0.8949972 0.8292597 0.8842902 0.8516318 0.85689884 0.8867707\n", + " 0.8828852 0.8397239 0.8045578 0.89611095]\n", + " [0.9422966 0.91131413 0.9135793 0.90758824 0.94545484 0.94940096\n", + " 0.9347347 0.91238105 0.90499514 0.96489817]]\n", + "t_a [[0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8]\n", + " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 
]]\n", + "y_p [0.38324493 0.6167551 ]\n", + "t_p [0.5 0.5]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_2_mixture_easy_iteration_models'\n", + "targets = {'phi': [0.5, 0.5], 'alphas': [[0.8]*10, [1.0]*10]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['Low heterogeneity - 2 mixture components'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 351, + "id": "090350bd-bbc9-4b57-b893-5350a2f8a389", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[1.066593 1.0640484 0.99796337 0.99338555 1.0719604 1.0735439\n", + " 1.0606164 0.99636525 1.0607691 1.0325032 ]\n", + " [0.8113324 0.7412155 0.7421719 0.76742476 0.7978137 0.80386424\n", + " 0.8376431 0.7696238 0.78532445 0.78983146]\n", + " [0.88768315 0.8750244 0.8565851 0.8092215 0.9191889 0.90774536\n", + " 0.9050219 0.8186563 0.82398546 0.86230224]]\n", + "t_a [[1.1 1.1 1.1 1.1 1.1 1.1 1.1 1.1 1.1 1.1]\n", + " [0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8]\n", + " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. ]]\n", + "y_p [0.6212769 0.13670582 0.2420173 ]\n", + "t_p [0.334 0.333 0.333]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_3_mixture_easy_iteration_models'\n", + "targets = {'phi': [0.334, 0.333, 0.333], 'alphas': [[1.1]*10, [0.8]*10, [1.0]*10]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['Low heterogeneity - 3 mixture components'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 352, + "id": "3af0041b-ef85-4d82-a8e0-d24c3113e9b0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[0.09486745 0.18039015 0.6038319 1.0086943 1.9933786 0.09489949\n", + " 0.99804974 2.0103388 0.5047902 0.501056 ]]\n", + "t_a [[0.1 0.2 0.6 1. 2. 0.1 1. 2. 
0.5 0.5]]\n", + "y_p [1.]\n", + "t_p [1.]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_1_mixture_medium_iteration_models'\n", + "targets = {'phi': 1.0, 'alphas': [0.1, 0.2, 0.6, 1.0, 2.0, 0.1, 1.0, 2.0, 0.5, 0.5]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['Medium heterogeneity - 1 mixture component'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 353, + "id": "c23bfe0d-d870-4410-8fca-fe4b08040ec3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[0.09601912 0.203612 0.576978 0.9885703 1.942201 0.09881581\n", + " 0.9899525 2.032183 0.52628845 0.511908 ]\n", + " [1.1092798 0.1123076 0.6980247 0.90534943 1.0147401 0.18994084\n", + " 0.5059109 0.9570854 0.79806316 1.5191599 ]]\n", + "t_a [[0.1 0.2 0.6 1. 2. 0.1 1. 2. 0.5 0.5]\n", + " [1.1 0.1 0.7 0.9 1. 0.2 0.5 1. 0.8 1.5]]\n", + "y_p [0.37693217 0.6230678 ]\n", + "t_p [0.4 0.6]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_2_mixture_medium_iteration_models'\n", + "targets = {'phi': [0.4, 0.6], 'alphas': [[0.1, 0.2, 0.6, 1.0, 2.0, 0.1, 1.0, 2.0, 0.5, 0.5], [1.1, 0.1, 0.7, 0.9, 1.0, 0.2, 0.5, 1.0, 0.8, 1.5]]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['Medium heterogeneity - 2 mixture components'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 354, + "id": "2e1b0397-184b-4c0c-81fb-979899e35e11", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[0.51219416 0.51829547 0.47929743 0.53371894 0.50652635 0.499473\n", + " 0.47228953 0.4900735 0.4821596 0.48754203]\n", + " [0.13210084 0.17273213 0.5936526 0.997307 1.9424596 0.10008831\n", + " 1.0737354 2.0053742 0.52007437 0.49917984]\n", + " [1.039937 0.09993631 0.75087065 0.89299226 0.9463883 0.17576067\n", + " 0.5173786 1.0259137 0.78837794 1.5182108 ]]\n", + "t_a [[0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 
0.5]\n", + " [0.1 0.2 0.6 1. 2. 0.1 1. 2. 0.5 0.5]\n", + " [1.1 0.1 0.7 0.9 1. 0.2 0.5 1. 0.8 1.5]]\n", + "y_p [0.20531842 0.3174111 0.47727048]\n", + "t_p [0.2 0.3 0.5]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_3_mixture_medium_iteration_models'\n", + "targets = {'phi': [0.2, 0.3, 0.5], 'alphas': [[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.1, 0.2, 0.6, 1.0, 2.0, 0.1, 1.0, 2.0, 0.5, 0.5], [1.1, 0.1, 0.7, 0.9, 1.0, 0.2, 0.5, 1.0, 0.8, 1.5]]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['Medium heterogeneity - 3 mixture components'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 355, + "id": "64a00cc7-e66b-494f-adea-df47c093b755", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[0.10798693 0.08479945 0.09277119 0.10356809 0.0958684 0.09630357\n", + " 0.08943666 0.10059562 0.10332205 0.09725104]]\n", + "t_a [[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]\n", + "y_p [1.]\n", + "t_p [1.]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_1_mixture_hard_iteration_models'\n", + "targets = {'phi': 1.0, 'alphas': [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['High heterogeneity - 1 mixture component'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 356, + "id": "4514fb77-c061-4d90-863c-f8901b1baf90", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[0.1807397 0.17358837 0.1946859 0.18628934 0.19083592 0.17438039\n", + " 0.18713334 0.18131252 0.18108001 0.19267383]\n", + " [0.34003973 0.34449705 0.35007524 0.35807556 0.36345354 0.3291148\n", + " 0.33449852 0.3500933 0.3563593 0.3211949 ]]\n", + "t_a [[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n", + " [0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3]]\n", + "y_p [0.37103534 0.6289647 ]\n", + "t_p [0.1 0.9]\n" + ] + 
} + ], + "source": [ + "dirpath = '../mle_params/cifar10_2_mixture_hard_iteration_models'\n", + "targets = {'phi': [0.1, 0.9], 'alphas': [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['High heterogeneity - 2 mixture components'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 357, + "id": "44415ead-b5ee-4a28-b42b-14ad43aff582", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_a [[1.1966388 0.08676551 0.72974104 0.9376729 1.0623652 0.20623356\n", + " 0.51617557 1.0524946 0.8251987 1.5541486 ]\n", + " [0.24933581 0.2668144 0.27757534 0.25659457 0.2190239 0.25226066\n", + " 0.25375283 0.26445332 0.2606129 0.21766934]\n", + " [1.0591303 0.08694185 0.626189 0.8334799 0.9366921 0.16620831\n", + " 0.47646916 0.944582 0.7469978 1.4314592 ]]\n", + "t_a [[1.1 0.1 0.7 0.9 1. 0.2 0.5 1. 0.8 1.5]\n", + " [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n", + " [0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3]]\n", + "y_p [0.5383507 0.19458508 0.2670642 ]\n", + "t_p [0.8 0.05 0.15]\n" + ] + } + ], + "source": [ + "dirpath = '../mle_params/cifar10_3_mixture_hard_iteration_models'\n", + "targets = {'phi': [0.8, 0.05, 0.15], 'alphas': [[1.1, 0.1, 0.7, 0.9, 1.0, 0.2, 0.5, 1.0, 0.8, 1.5], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]]}\n", + "\n", + "a, p = get_results(dirpath, targets)\n", + "results['High heterogeneity - 3 mixture components'] = {'alphas': a, 'phi': p}" + ] + }, + { + "cell_type": "code", + "execution_count": 358, + "id": "5bdb12e7-6422-4636-b5a3-2b224e60ab7e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Normalised MSE of alpha parameters during training\\non artificially federated CIFAR-10\\nunder different target heterogeneity conditions\\nand using different numbers of 
mixture components')" + ] + }, + "execution_count": 358, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for k in results.keys():\n", + " if 'Low' in k:\n", + " line = '--+'\n", + " elif 'Medium' in k:\n", + " line = '-o'\n", + " else:\n", + " line = '--'\n", + "\n", + " if '1' in k:\n", + " colour = 'red'\n", + " elif '2' in k:\n", + " colour = 'green'\n", + " elif '3' in k:\n", + " colour = 'blue'\n", + " plt.plot(results[k]['alphas'], line, color=colour, label=k+' - MSE alphas')\n", + "\n", + "\n", + " \n", + "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", + "plt.xlabel('Training iterations')\n", + "plt.ylabel('Normalised MSE')\n", + "plt.title('Normalised MSE of alpha parameters during training\\non artificially federated CIFAR-10\\nunder different target heterogeneity conditions\\nand using different numbers of mixture components')" + ] + }, + { + "cell_type": "code", + "execution_count": 359, + "id": "e3be4caa-0215-491f-aee9-6dda4aa1f08d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n" + ] + }, + { + "data": { + 
"text/plain": [ + "Text(0.5, 1.0, 'Normalised MSE of phi parameters during training\\non artificially federated CIFAR-10\\nunder different target heterogeneity conditions\\nand using different numbers of mixture components')" + ] + }, + "execution_count": 359, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for k in results.keys():\n", + " if 'Low' in k:\n", + " line = '--+'\n", + " elif 'Medium' in k:\n", + " line = '-o'\n", + " else:\n", + " line = '--'\n", + "\n", + " if '1' in k:\n", + " colour = 'red'\n", + " print(results[k]['phi'])\n", + " elif '2' in k:\n", + " colour = 'green'\n", + " elif '3' in k:\n", + " colour = 'blue'\n", + " plt.plot(results[k]['phi'], line, color=colour, label=k+' - MSE phi')\n", + " \n", + "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", + "plt.xlabel('Training iterations')\n", + "plt.ylabel('Normalised MSE')\n", + "plt.title('Normalised MSE of phi parameters during training\\non artificially federated CIFAR-10\\nunder different target heterogeneity conditions\\nand using different numbers of mixture components')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1db253c-f175-41f8-852e-988acd0f1445", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29131b53-2dba-4a5f-a6e9-254314bf4cf8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afc1151f-476a-4669-b215-35ae64ff87cd", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3e76b4c-675e-4a46-b600-36ece1a819e5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:v310] *", + "language": "python", + "name": "conda-env-v310-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/publications/mdm/mdm_paper/notebooks/cifar10_visualisations.ipynb b/publications/mdm/mdm_paper/notebooks/cifar10_visualisations.ipynb new file mode 100644 index 0000000..fb04e25 --- /dev/null +++ b/publications/mdm/mdm_paper/notebooks/cifar10_visualisations.ipynb @@ -0,0 +1,241 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "47be4a41-58e1-44e2-bd52-a366cf901fb9", + "metadata": {}, + "source": [ + "This notebook visualizes CIFAR10 users using various methods of generating simulation users." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "77d25f5d-0b22-48fe-999f-74e1f0bc6ddf", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from sklearn.manifold import TSNE\n", + "import joblib\n", + "\n", + "from ramsay.data.sampling import get_data_sampler\n", + "from polya_mixture.datasets.cifar10_dataset import load_and_preprocess\n", + "from polya_mixture.datasets.sampler import DirichletDataSampler" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "25e2c77a-1a4b-469f-abaf-8ad552f71747", + "metadata": {}, + "outputs": [], + "source": [ + "# Functions to generate users\n", + "\n", + "def generate_mixture_users(num_users, alphas, phi, num_samples_dists, all_labels):\n", + " users = []\n", + " ks = []\n", + " for _ in range(num_users):\n", + " k = np.random.choice(range(len(alphas)), p=phi)\n", + " d = num_samples_dists[k]\n", + " n = np.random.choice(range(len(d)), p=d)\n", + "\n", + " sampler = DirichletDataSampler(alphas[k], all_labels)\n", + " idxs = sampler(n)\n", + " y = np.array(all_labels)[idxs]\n", + " full_y = np.zeros(62)\n", + " vals, count = np.unique(y, return_counts=True)\n", + " full_y[vals] = count\n", + " users.append(full_y)\n", + " return np.array(users)\n", + "\n", + "def generate_single_dirichlet_users(num_users, alpha, user_len_sampler, all_labels):\n", + " sampler = DirichletDataSampler(alpha, all_labels)\n", + " 
users = []\n", + " for _ in range(num_users):\n", + " n = user_len_sampler()\n", + " idxs = sampler(n)\n", + " y = np.array(all_labels)[idxs]\n", + " full_y = np.zeros(62)\n", + " vals, count = np.unique(y, return_counts=True)\n", + " full_y[vals] = count\n", + " users.append(full_y)\n", + " return np.array(users)\n", + "\n", + "def generate_uniform_users(num_users, user_len_sampler, all_labels):\n", + " sampler = get_data_sampler('random', len(all_labels))\n", + " users = []\n", + " for _ in range(num_users):\n", + " n = user_len_sampler()\n", + " idxs = sampler(n)\n", + " y = np.array(all_labels)[idxs]\n", + " full_y = np.zeros(62)\n", + " vals, count = np.unique(y, return_counts=True)\n", + " full_y[vals] = count\n", + " users.append(full_y)\n", + " return np.array(users)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d838856d-85f1-4f14-b098-85b158b5d462", + "metadata": {}, + "outputs": [], + "source": [ + "# Load Cifar10 and get all labels\n", + "data_dir = 'data/cifar10'\n", + "_, all_labels, _, _ = load_and_preprocess(os.path.join(data_dir, 'cifar10_train.p'))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8034bc9f-c053-48e8-a91f-fe1f8a2b8d90", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['alphas', 'phi'])\n", + "alpha_mixture [[0.2790543 0.19542773 1.7441087 0.9729872 0.18570527 0.09938517\n", + " 0.09327706 0.10400821 3.144703 1.0748252 ]\n", + " [0.26441148 0.17979808 0.11472267 0.99709547 1.2816628 1.3844566\n", + " 0.98972297 0.4767252 1.1671178 0.0879145 ]]\n", + "phi_mixture [0.9021526 0.09784738]\n", + "num_samples_distribution_mixture (2, 500)\n", + "(2,) (2, 10) (2, 500)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Load single dirichlet saved params\n", + "mixture = True\n", + "\n", + "if not mixture:\n", + " # load learned parameters\n", + " params = joblib.load(os.path.join('mle_params', 'cifar10_1_mixture', 'polya-mixture.joblib'))\n", + " print(params.keys())\n", + " alpha_one_component = np.array(params['alphas'][0])\n", + " \n", + " num_samples_distribution_one_component = joblib.load(os.path.join('num_samples_distribution', 'cifar10_1_mixture.joblib'))\n", + " num_samples_distribution_one_component = num_samples_distribution_one_component.numpy().reshape(-1,)\n", + " \n", + " len_sampler = lambda: np.random.choice(range(len(num_samples_distribution_one_component)), p=num_samples_distribution_one_component)\n", + "\n", + " # true distribution\n", + " true_alpha = [0.3, 0.2, 0.1, 1, 1.2, 1.3, 0.8, 0.5, 1.2, 0.1]\n", + " true_len_sampler = lambda: 20\n", + "\n", + " # create distributions\n", + " num_users = 500\n", + " true_users = generate_single_dirichlet_users(num_users, true_alpha, true_len_sampler, all_labels)\n", + " simulated_single_dirichlet_users = generate_single_dirichlet_users(num_users, alpha_one_component, len_sampler, all_labels)\n", + " simulated_uniform_users = generate_uniform_users(num_users, len_sampler, all_labels)\n", + "\n", + " # Run TSNE on the label counts of all users\n", + " tsne2 = TSNE(n_components=2)\n", + " X = np.vstack([true_users, simulated_single_dirichlet_users, simulated_uniform_users])\n", + " X = X / X.sum(axis=1, keepdims=True) # This will normalize the counts to proportions to ignore num sample per user effects\n", + " X_2dim = tsne2.fit_transform(X)\n", + "\n", + " N = num_users\n", + " plt.figure(figsize=(8, 5))\n", + " plt.title('TSNE visualisation of user label distributions in CIFAR10 dataset')\n", + " plt.scatter(X_2dim[:N, 0], X_2dim[:N, 1], s=1, color='red', label='true users')\n", + " plt.scatter(X_2dim[N:2*N, 0], X_2dim[N:2*N, 1], 
s=5, color='green', label='single dirichlet simulation')\n", + " plt.scatter(X_2dim[2*N:3*N, 0], X_2dim[2*N:3*N, 1], s=1, color='blue', label='Simulated users: partitioned uniformly randomly')\n", + " plt.legend(fontsize=10)\n", + "\n", + "\n", + "else:\n", + " # Load two component mixture dirichlet saved params\n", + " params = joblib.load(os.path.join('mle_params', 'cifar10_2_mixture', 'polya-mixture.joblib'))\n", + " print(params.keys())\n", + " alpha_mixture = np.array(params['alphas'])\n", + " print('alpha_mixture', alpha_mixture)\n", + " phi_mixture = np.array(params['phi'])\n", + " phi_mixture /= sum(phi_mixture)\n", + " print('phi_mixture', phi_mixture)\n", + " \n", + " num_samples_distribution_mixture = joblib.load(os.path.join('num_samples_distribution', 'cifar10_2_mixture.joblib'))\n", + " num_samples_distribution_mixture = num_samples_distribution_mixture.numpy()\n", + " print('num_samples_distribution_mixture', num_samples_distribution_mixture.shape)\n", + " print(phi_mixture.shape, alpha_mixture.shape, num_samples_distribution_mixture.shape)\n", + "\n", + " # true distribution\n", + " #true_alpha = [[1.8]*10, [0.2]*10]\n", + " #true_phi = [0.3, 0.7]\n", + " true_alpha = [[0.3, 0.2, 0.1, 1, 1.2, 1.3, 0.8, 0.5, 1.2, 0.1], [0.3, 0.2, 1.8, 1, 0.2, 0.1, 0.1, 0.1, 3.2, 1.1]]\n", + " true_phi = [0.1, 0.9]\n", + "\n", + " # generate users\n", + " num_users = 500\n", + " # TODO change num_samples_distribution_mixture so it's not just tied to learned parameters, but is the actual true distribution\n", + " true_users = generate_mixture_users(num_users, true_alpha, true_phi, num_samples_distribution_mixture, all_labels)\n", + " simulated_dirichlet_mixture_users = generate_mixture_users(num_users, alpha_mixture, phi_mixture, num_samples_distribution_mixture, all_labels)\n", + " len_sampler = lambda: 20\n", + " simulated_uniform_users = generate_uniform_users(num_users, len_sampler, all_labels)\n", + "\n", + " # Run TSNE on the label counts of all users\n", + " 
tsne2 = TSNE(n_components=2)\n", + " X = np.vstack([true_users, simulated_dirichlet_mixture_users, simulated_uniform_users]) \n", + " X = X / X.sum(axis=1, keepdims=True) # This will normalize the counts to proportions to ignore num sample per user effects\n", + " X_2dim = tsne2.fit_transform(X)\n", + " \n", + " # plot\n", + " N = num_users\n", + " plt.figure(figsize=(8, 5))\n", + " plt.title('TSNE visualisation of user label distributions in CIFAR10 dataset')\n", + " plt.scatter(X_2dim[:N, 0], X_2dim[:N, 1], s=1, color='red', label='true users')\n", + " plt.scatter(X_2dim[N:2*N, 0], X_2dim[N:2*N, 1], s=5, color='green', label='single mixture dirichlet simulation')\n", + " plt.scatter(X_2dim[2*N:3*N, 0], X_2dim[2*N:3*N, 1], s=1, color='blue', label='Simulated users: partitioned uniformly randomly')\n", + " plt.legend(fontsize=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05dc6205-7096-412a-9188-f1086cadbda0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/publications/mdm/mdm_paper/notebooks/femnist_visualisations.ipynb b/publications/mdm/mdm_paper/notebooks/femnist_visualisations.ipynb new file mode 100644 index 0000000..5bd9046 --- /dev/null +++ b/publications/mdm/mdm_paper/notebooks/femnist_visualisations.ipynb @@ -0,0 +1,516 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "47be4a41-58e1-44e2-bd52-a366cf901fb9", + "metadata": {}, + "source": [ + "This notebook visualizes FEMNIST users using various methods of generating simulation users."
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "77d25f5d-0b22-48fe-999f-74e1f0bc6ddf", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from sklearn.manifold import TSNE\n", + "import joblib\n", + "\n", + "from ramsay.data.sampling import get_data_sampler\n", + "from polya_mixture.datasets.femnist_dataset import _load_h5_into_dict\n", + "from polya_mixture.datasets.sampler import DirichletDataSampler" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "59530d17-5a4c-43f2-b2f8-b6e91ab4e2c8", + "metadata": {}, + "outputs": [], + "source": [ + "# Helper functions\n", + "\n", + "def load_femnist_as_dict(data_dir):\n", + " h5_file_path = os.path.join(data_dir, 'fed_emnist_train.h5')\n", + " digits_only = False\n", + " numpy_to_tensor = lambda x: x\n", + " user_id_to_data = _load_h5_into_dict(h5_file_path, digits_only,\n", + " numpy_to_tensor)\n", + " return user_id_to_data\n", + "\n", + "def get_all_labels(user_id_to_data):\n", + " all_labels = []\n", + " for _, y in user_id_to_data.values():\n", + " all_labels.append(y)\n", + " return np.hstack(all_labels)\n", + "\n", + "\n", + "def get_femnist_len_sampler(user_id_to_data):\n", + " histo = np.zeros(500)\n", + " for _, y in user_id_to_data.values():\n", + " num_samples = len(y)\n", + " histo[num_samples] += 1\n", + "\n", + " p = histo / sum(histo) \n", + " len_sampler = lambda p=p: np.random.choice(range(len(p)), p=p)\n", + " return len_sampler" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "25e2c77a-1a4b-469f-abaf-8ad552f71747", + "metadata": {}, + "outputs": [], + "source": [ + "# Functions to generate users\n", + "\n", + "def generate_true_users(user_id_to_data):\n", + " users = []\n", + " for _, y in user_id_to_data.values():\n", + " labels, counts = np.unique(y, return_counts=True)\n", + " full_y = np.zeros(62)\n", + " full_y[labels] = counts\n", + " 
users.append(full_y)\n", + " return np.array(users)\n", + "\n", + "\n", + "def generate_mixture_users(num_users, alphas, phi, num_samples_dists, all_labels):\n", + " users = []\n", + " ks = []\n", + " for _ in range(num_users):\n", + " k = np.random.choice(range(len(alphas)), p=phi)\n", + " d = num_samples_dists[k]\n", + " n = np.random.choice(range(len(d)), p=d)\n", + "\n", + " sampler = DirichletDataSampler(alphas[k], all_labels)\n", + " idxs = sampler(n)\n", + " y = np.array(all_labels)[idxs]\n", + " full_y = np.zeros(62)\n", + " vals, count = np.unique(y, return_counts=True)\n", + " full_y[vals] = count\n", + " users.append(full_y)\n", + " return np.array(users)\n", + "\n", + "def generate_single_dirichlet_users(num_users, alpha, user_len_sampler, all_labels):\n", + " sampler = DirichletDataSampler(alpha, all_labels)\n", + " users = []\n", + " for _ in range(num_users):\n", + " n = user_len_sampler()\n", + " idxs = sampler(n)\n", + " y = np.array(all_labels)[idxs]\n", + " full_y = np.zeros(62)\n", + " vals, count = np.unique(y, return_counts=True)\n", + " full_y[vals] = count\n", + " users.append(full_y)\n", + " return np.array(users)\n", + "\n", + "def generate_uniform_users(num_users, user_len_sampler, all_labels):\n", + " sampler = get_data_sampler('random', len(all_labels))\n", + " users = []\n", + " for _ in range(num_users):\n", + " n = user_len_sampler()\n", + " idxs = sampler(n)\n", + " y = np.array(all_labels)[idxs]\n", + " full_y = np.zeros(62)\n", + " vals, count = np.unique(y, return_counts=True)\n", + " full_y[vals] = count\n", + " users.append(full_y)\n", + " return np.array(users)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d838856d-85f1-4f14-b098-85b158b5d462", + "metadata": {}, + "outputs": [], + "source": [ + "# Load femnist and get all labels\n", + "data_dir = 'data/femnist'\n", + "user_id_to_data = load_femnist_as_dict(data_dir)\n", + "all_labels = get_all_labels(user_id_to_data)" + ] + }, + { + "cell_type": "code", 
+ "execution_count": 5, + "id": "37a3dee8-e3fd-40ce-9577-9f84a544c818", + "metadata": {}, + "outputs": [], + "source": [ + "num_users = len(user_id_to_data.keys())" + ] + }, + { + "cell_type": "markdown", + "id": "92e05e84-3154-4460-afb0-2cba1a665aa0", + "metadata": {}, + "source": [ + "# One mixture component" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "8034bc9f-c053-48e8-a91f-fe1f8a2b8d90", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['alphas', 'phi'])\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Load single dirichlet saved params\n", + "\n", + "# load learned parameters\n", + "params = joblib.load(os.path.join('mle_params', 'femnist_learn_dirichlet_1_mixture', 'polya-mixture.joblib'))\n", + "print(params.keys())\n", + "alpha_one_component = np.array(params['alphas'][0])\n", + "\n", + "num_samples_distribution_one_component = joblib.load(os.path.join('num_samples_distribution', 'femnist_learn_dirichlet_1_mixture.joblib'))\n", + "num_samples_distribution_one_component = num_samples_distribution_one_component.numpy().reshape(-1,)\n", + "\n", + "len_sampler = lambda: np.random.choice(range(len(num_samples_distribution_one_component)), p=num_samples_distribution_one_component)\n", + "\n", + "# create distributions\n", + "true_users = generate_true_users(user_id_to_data)\n", + "simulated_single_dirichlet_users = generate_single_dirichlet_users(len(true_users), alpha_one_component, len_sampler, all_labels)\n", + "simulated_uniform_users = generate_uniform_users(len(true_users), len_sampler, all_labels)\n", + "\n", + "# Run TSNE on the label counts of all users\n", + "tsne2 = TSNE(n_components=2)\n", + "X = np.vstack([true_users, simulated_single_dirichlet_users])#, simulated_uniform_users])\n", + "X = X / X.sum(axis=1, keepdims=True) # This will normalize the counts to proportions to ignore num sample per user effects\n", + "X_2dim = tsne2.fit_transform(X)\n", + "\n", + "N = num_users\n", + "plt.figure(figsize=(8, 5))\n", + "plt.title('TSNE visualisation of user label distributions in CIFAR10 dataset')\n", + "plt.scatter(X_2dim[:N, 0], X_2dim[:N, 1], s=1, color='red', label='true users')\n", + "plt.scatter(X_2dim[N:2*N, 0], X_2dim[N:2*N, 1], s=5, color='green', label='single dirichlet simulation')\n", + "#plt.scatter(X_2dim[2*N:3*N, 0], X_2dim[2*N:3*N, 1], s=1, color='blue', label='Simulated users: partitioned uniformly randomly')\n", + "plt.legend(fontsize=10)" + ] + }, + { 
+ "cell_type": "markdown", + "id": "69f2e996-8985-45f7-adb0-1ad56ded1d8c", + "metadata": {}, + "source": [ + "# Two mixture components" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "eeae50e6-b4e4-409e-b8b1-f45d27f2ed70", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['alphas', 'phi'])\n", + "alpha_mixture [[12.486713 13.534921 11.7629795 12.550145 11.480532 10.293059\n", + " 11.921352 12.587899 11.840125 11.941666 3.9614604 2.2822833\n", + " 7.3454185 2.6096861 2.1497195 6.5782847 1.219504 1.6149502\n", + " 8.219717 2.2317927 1.2118896 2.750498 6.845294 5.16876\n", + " 20.386126 6.246426 1.2428675 2.2957945 16.824781 6.5312443\n", + " 9.8076315 2.787484 3.030136 1.3144614 2.9685311 1.3412713\n", + " 6.926577 3.528872 1.3654627 7.718317 19.724745 1.1862212\n", + " 2.2440865 6.1563945 1.3286887 1.0327483 1.2686301 9.836655\n", + " 1.2492453 7.9961534 1.320585 1.1531479 1.781072 10.346889\n", + " 1.3283604 14.271842 1.3694632 1.4000323 1.3433177 1.3627586\n", + " 1.2021077 1.3795127]\n", + " [44.51527 50.686996 45.006035 45.453217 44.310375 41.812286\n", + " 45.320396 47.128685 44.48823 44.521595 3.9697187 3.7719092\n", + " 4.3950214 3.7158391 3.4485605 3.7217412 3.4760392 3.436453\n", + " 5.725511 3.7413826 3.4088106 4.3864307 3.7450933 4.0867777\n", + " 4.2136483 4.049439 3.738219 3.8356507 3.954183 3.9112475\n", + " 3.938368 4.162987 3.984091 3.9271889 3.6723745 4.0041394\n", + " 3.9384997 3.9038906 3.8839767 3.7883527 3.8080533 3.7992022\n", + " 3.401485 4.051268 3.5428376 2.285073 3.3482904 5.7865825\n", + " 3.7956715 3.8241901 3.8699336 3.475093 2.9511976 3.8318043\n", + " 3.9740224 3.9577084 3.8802054 4.14223 3.9508624 4.002048\n", + " 3.295877 3.841521 ]]\n", + "phi_mixture [0.39602438 0.60397565]\n", + "num_samples_distribution_mixture (2, 500)\n", + "(2,) (2, 62) (2, 500)\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 17, + "metadata": {}, 
+ "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Load two component mixture dirichlet saved params\n", + "params = joblib.load(os.path.join('mle_params', 'femnist_2_mixture', 'polya-mixture.joblib'))\n", + "print(params.keys())\n", + "alpha_mixture = np.array(params['alphas'])\n", + "print('alpha_mixture', alpha_mixture)\n", + "phi_mixture = np.array(params['phi'])\n", + "phi_mixture /= sum(phi_mixture)\n", + "print('phi_mixture', phi_mixture)\n", + "\n", + "num_samples_distribution_mixture = joblib.load(os.path.join('num_samples_distribution', 'femnist_2_mixture.joblib'))\n", + "num_samples_distribution_mixture = num_samples_distribution_mixture.numpy()\n", + "print('num_samples_distribution_mixture', num_samples_distribution_mixture.shape)\n", + "print(phi_mixture.shape, alpha_mixture.shape, num_samples_distribution_mixture.shape)\n", + "\n", + "# generate users\n", + "num_users = 500\n", + "true_users = generate_true_users(user_id_to_data)\n", + "simulated_dirichlet_mixture_users = generate_mixture_users(len(true_users), alpha_mixture, phi_mixture, num_samples_distribution_mixture, all_labels)\n", + "len_sampler = get_femnist_len_sampler(user_id_to_data)\n", + "simulated_uniform_users = generate_uniform_users(len(true_users), len_sampler, all_labels)\n", + "\n", + "# Run TSNE on the label counts of all users\n", + "tsne2 = TSNE(n_components=2)\n", + "X = np.vstack([true_users, simulated_dirichlet_mixture_users, simulated_uniform_users]) \n", + "X = X / X.sum(axis=1, keepdims=True) # This will normalize the counts to proportions to ignore num sample per user effects\n", + "X_2dim = tsne2.fit_transform(X)\n", + "\n", + "# plot\n", + "N = num_users\n", + "plt.figure(figsize=(8, 5))\n", + "plt.title('TSNE visualisation of user label distributions in CIFAR10 dataset')\n", + "plt.scatter(X_2dim[:N, 0], X_2dim[:N, 1], s=1, color='red', label='true users')\n", + "plt.scatter(X_2dim[N:2*N, 0], X_2dim[N:2*N, 1], s=5, color='green', 
label='single mixture dirichlet simulation')\n", + "plt.scatter(X_2dim[2*N:3*N, 0], X_2dim[2*N:3*N, 1], s=1, color='blue', label='Simulated users: partitioned uniformly randomly')\n", + "plt.legend(fontsize=10)" + ] + }, + { + "cell_type": "markdown", + "id": "54127757-97f1-4619-912a-ddc91f085e32", + "metadata": {}, + "source": [ + "# Three mixture components" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "05dc6205-7096-412a-9188-f1086cadbda0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['alphas', 'phi'])\n", + "alpha_mixture [[ 7.331553 7.954057 6.9863634 7.491554 6.840809 6.176658\n", + " 7.037211 7.3866696 7.0418634 7.0186877 1.7180551 1.3936192\n", + " 4.336823 1.3972884 1.021399 3.9414258 0.8377679 0.8865488\n", + " 4.666944 1.3927824 0.82493305 1.5903513 4.075701 2.323618\n", + " 11.663883 3.662332 0.8683789 0.9965437 9.724308 3.1967835\n", + " 5.7382493 1.7378958 1.9411799 0.8642131 1.9035438 0.8644596\n", + " 4.9368877 2.3315918 0.90976584 5.0233307 12.5443945 0.7650201\n", + " 1.4112251 4.09973 0.91791886 0.70593554 0.8813811 5.6781783\n", + " 0.82670665 5.477468 0.91935456 0.7856333 1.1689324 7.026988\n", + " 0.83059084 9.19348 0.8955782 0.8933723 0.8802291 0.93215483\n", + " 0.79196453 0.9271174 ]\n", + " [18.83051 21.391619 19.192364 19.395418 18.862967 18.080807\n", + " 19.338434 20.053724 18.768417 19.05818 1.8996079 1.8278965\n", + " 2.0523577 1.8387978 1.6229041 1.7828559 1.7393044 1.6851878\n", + " 2.5363393 1.7556136 1.6119165 1.9750109 1.85119 1.9215419\n", + " 1.9163609 1.8660287 1.758909 1.8497199 1.8336132 1.8163389\n", + " 1.8787143 1.8776171 1.9155294 1.9301225 1.710874 1.9480428\n", + " 1.8938534 1.7851719 1.9563513 1.8582665 1.8689779 1.7700071\n", + " 1.6418803 1.9925227 1.6899055 1.1193209 1.7580124 2.5709739\n", + " 1.841688 1.8507614 1.861241 1.6949676 1.3854067 1.8211398\n", + " 1.7923293 1.8661227 1.8890306 1.8883762 1.8973581 1.8782278\n", + " 
1.5289938 1.8693763 ]\n", + " [ 7.890613 8.639257 7.737538 8.249985 7.3764997 6.5286493\n", + " 7.720129 8.291615 7.935243 7.8369765 6.848669 2.595738\n", + " 4.581804 3.8827293 4.726481 4.3145986 0.87768173 2.2444005\n", + " 4.7095323 1.4775896 0.93684626 3.5841136 4.011161 7.8069186\n", + " 12.993805 4.214646 0.95914626 3.931263 9.673255 8.524542\n", + " 6.182734 1.9064779 1.9609798 1.0343878 2.0815775 1.0086254\n", + " 1.1224002 1.3733677 1.0175576 2.079446 2.4763615 1.0114193\n", + " 1.0507283 1.8728429 0.81286395 0.9262054 1.1011454 3.7723496\n", + " 0.9499362 1.327955 0.9742082 0.79845166 1.0648375 1.2909755\n", + " 1.0237517 2.5326338 1.0434972 1.0078688 0.9999532 1.0237565\n", + " 0.9694846 1.0330129 ]]\n", + "phi_mixture [0.32083896 0.60536647 0.07379463]\n", + "num_samples_distribution_mixture (3, 500)\n", + "(3,) (3, 62) (3, 500)\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Load three component mixture dirichlet saved params\n", + "num_mixture_components = 3\n", + "\n", + "params = joblib.load(os.path.join('mle_params', f'femnist_{num_mixture_components}_mixture', 'polya-mixture.joblib'))\n", + "print(params.keys())\n", + "alpha_mixture = np.array(params['alphas'])\n", + "print('alpha_mixture', alpha_mixture)\n", + "phi_mixture = np.array(params['phi'])\n", + "phi_mixture /= sum(phi_mixture)\n", + "print('phi_mixture', phi_mixture)\n", + "\n", + "num_samples_distribution_mixture = joblib.load(os.path.join('num_samples_distribution', f'femnist_{num_mixture_components}_mixture.joblib'))\n", + "num_samples_distribution_mixture = num_samples_distribution_mixture.numpy()\n", + "print('num_samples_distribution_mixture', num_samples_distribution_mixture.shape)\n", + "print(phi_mixture.shape, alpha_mixture.shape, num_samples_distribution_mixture.shape)\n", + "\n", + "# generate users\n", + "num_users = 500\n", + "true_users = generate_true_users(user_id_to_data)\n", + "simulated_dirichlet_mixture_users = generate_mixture_users(len(true_users), alpha_mixture, phi_mixture, num_samples_distribution_mixture, all_labels)\n", + "len_sampler = get_femnist_len_sampler(user_id_to_data)\n", + "simulated_uniform_users = generate_uniform_users(len(true_users), len_sampler, all_labels)\n", + "\n", + "# Run TSNE on the label counts of all users\n", + "tsne2 = TSNE(n_components=2)\n", + "X = np.vstack([true_users, simulated_dirichlet_mixture_users, simulated_uniform_users]) \n", + "X = X / X.sum(axis=1, keepdims=True) # This will normalize the counts to proportions to ignore num sample per user effects\n", + "X_2dim = tsne2.fit_transform(X)\n", + "\n", + "# plot\n", + "N = num_users\n", + "plt.figure(figsize=(8, 5))\n", + "plt.title('TSNE visualisation of user label distributions in CIFAR10 dataset')\n", + "plt.scatter(X_2dim[:N, 0], X_2dim[:N, 1], s=1, color='red', 
label='true users')\n", + "plt.scatter(X_2dim[N:2*N, 0], X_2dim[N:2*N, 1], s=5, color='green', label='single mixture dirichlet simulation')\n", + "plt.scatter(X_2dim[2*N:3*N, 0], X_2dim[2*N:3*N, 1], s=1, color='blue', label='Simulated users: partitioned uniformly randomly')\n", + "plt.legend(fontsize=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "80b5780a-71b5-4e73-8f85-fa7445700f5a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "N = num_users\n", + "plt.figure(figsize=(8, 5))\n", + "plt.title('TSNE visualisation of user label distributions in CIFAR10 dataset')\n", + "plt.scatter(X_2dim[:N, 0], X_2dim[:N, 1], s=1, color='red', label='true users')\n", + "plt.scatter(X_2dim[N:2*N, 0], X_2dim[N:2*N, 1], s=5, color='green', label='single mixture dirichlet simulation')\n", + "#plt.scatter(X_2dim[2*N:3*N, 0], X_2dim[2*N:3*N, 1], s=1, color='blue', label='Simulated users: partitioned uniformly randomly')\n", + "plt.legend(fontsize=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "863cdccf-3222-4139-b4fc-9b60b2925b62", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/publications/mdm/mdm_paper/training/__init__.py b/publications/mdm/mdm_paper/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/mdm_paper/training/mle.py b/publications/mdm/mdm_paper/training/mle.py new file mode 100644 index 0000000..3376740 --- /dev/null +++ b/publications/mdm/mdm_paper/training/mle.py @@ -0,0 +1,102 @@ +import numpy as np + +from pfl.aggregate.simulate import SimulatedBackend +from pfl.callback import ModelCheckpointingCallback +from pfl.privacy import (CentrallyAppliedPrivacyMechanism, PLDPrivacyAccountant, GaussianMechanism) + +from publications.mdm.mdm_utils.utils.tools import ModelCheckpointingIterationCallback +from publications.mdm.mdm import (MDMModel, MDMModelHyperParams, + MDMAlgorithm, MDMAlgorithmParams, + 
def solve_polya_mixture_mle(
        arguments,
        training_federated_dataset,
        val_federated_dataset,
        num_components,
        num_categories,
        save_path,
        save_path_histogram,
        add_DP=False,
        extract_labels_fn=lambda user_dataset: user_dataset.raw_data[1]):
    """
    Solve the polya-mixture MLE by running the MDM initialization
    algorithm and then the MDM algorithm on a simulated backend.

    :param arguments:
        Parsed experiment arguments. The attributes read here are
        ``central_num_iterations_init_algorithm``,
        ``central_num_iterations_algorithm``,
        ``max_num_samples_mixture_component_init_algorithm``,
        ``cohort_size_init_algorithm``, ``cohort_size_algorithm`` and
        ``strategy``.
    :param training_federated_dataset:
        Training data handed to ``SimulatedBackend``.
    :param val_federated_dataset:
        Validation data handed to ``SimulatedBackend`` (may be ``None``).
    :param num_components:
        Number of mixture components, forwarded to ``MDMModelHyperParams``.
    :param num_categories:
        Number of categories, forwarded to ``MDMModelHyperParams``.
    :param save_path:
        Checkpoint directory. Per-iteration checkpoints of the second
        phase go to ``save_path + '_iteration_models'``.
    :param save_path_histogram:
        Accepted for interface compatibility; not read in this function.
    :param add_DP:
        When true, a centrally applied Gaussian mechanism is installed as
        a backend postprocessor.
    :param extract_labels_fn:
        Callable mapping a user dataset to its categories; passed to both
        algorithms as ``extract_categories_fn``.
    :returns:
        ``(model.phi, model.alphas, model.num_samples_distribution)``.
    """
    model = MDMModel()
    model_params = MDMModelHyperParams(num_components, num_categories)
    print(f'before init algo - model alphas: {model.alphas}, phi: {model.phi}')

    postprocessors = []
    if add_DP:
        # The accountant composes over both phases, so it is sized for the
        # total number of central iterations.
        total_iterations = (arguments.central_num_iterations_init_algorithm +
                            arguments.central_num_iterations_algorithm)
        privacy_accountant = PLDPrivacyAccountant(
            num_compositions=total_iterations,
            sampling_probability=0.001,
            mechanism='gaussian',
            epsilon=2,
            delta=1e-7,
            noise_scale=1.0)
        gaussian_mechanism = GaussianMechanism.from_privacy_accountant(
            accountant=privacy_accountant, clipping_bound=0.5)
        postprocessors.append(
            CentrallyAppliedPrivacyMechanism(gaussian_mechanism))

    backend = SimulatedBackend(training_data=training_federated_dataset,
                               val_data=val_federated_dataset,
                               postprocessors=postprocessors)

    # Ten upper bin edges, evenly spaced up to the configured maximum
    # (the leading 0 of the 11-point grid is dropped).
    upper_bin_edges = np.linspace(
        0, arguments.max_num_samples_mixture_component_init_algorithm, 11)[1:]
    # NOTE(review): the edges are stacked 6 times regardless of
    # num_components - confirm whether this should follow num_components.
    num_samples_mixture_bins = np.vstack([upper_bin_edges] * 6)
    print('\nnum_samples_mixture_bins', num_samples_mixture_bins)

    # Phase 1: initialization algorithm.
    init_algorithm = MDMInitializationAlgorithm()
    init_algorithm_params = MDMInitializationAlgorithmParams(
        cohort_size=arguments.cohort_size_init_algorithm,
        num_samples_mixture_bins=num_samples_mixture_bins,
        strategy=arguments.strategy,
        central_num_iterations=arguments.central_num_iterations_init_algorithm,
        extract_categories_fn=extract_labels_fn)
    init_callbacks = [ModelCheckpointingCallback(model_checkpoint_dir=save_path)]
    init_algorithm.run(algorithm_params=init_algorithm_params,
                       backend=backend,
                       model=model,
                       model_train_params=model_params,
                       model_eval_params=None,
                       callbacks=init_callbacks)
    print(f'after init algo - model alphas: {model.alphas}, phi: {model.phi}')

    # TODO do I need to reset phi and alpha in model?
    # Require model phi = (1 / num_mixture_components) *
    # np.ones(num_mixture_components)

    # Phase 2: main algorithm, continuing from the same model object.
    algorithm = MDMAlgorithm()
    algorithm_params = MDMAlgorithmParams(
        cohort_size=arguments.cohort_size_algorithm,
        num_samples_mixture_bins=num_samples_mixture_bins,
        central_num_iterations=arguments.central_num_iterations_algorithm,
        extract_categories_fn=extract_labels_fn)
    main_callbacks = [
        ModelCheckpointingCallback(model_checkpoint_dir=save_path),
        ModelCheckpointingIterationCallback(
            model_checkpoint_dir=save_path + '_iteration_models',
            checkpoint_frequency=1)
    ]
    algorithm.run(algorithm_params=algorithm_params,
                  backend=backend,
                  model=model,
                  model_train_params=model_params,
                  model_eval_params=None,
                  callbacks=main_callbacks)
    print(
        f'after algo - model alphas: {model.alphas}, phi: {model.phi}, num_samples_distribution: {model.num_samples_distribution}'
    )

    return model.phi, model.alphas, model.num_samples_distribution
def get_arguments():
    """Build the argument parser for the CIFAR10 experiment and parse argv.

    Each ``add_*_args`` helper registers one group of options on the parser.
    """
    parser = argparse.ArgumentParser()
    add_experiment_args(parser)
    add_dataset_args(parser)
    add_mle_args(parser)
    add_init_algorithm_args(parser)
    add_algorithm_args(parser)
    add_histogram_algorithm_args(parser)
    add_user_visualisation_args(parser)
    return parser.parse_args()


set_framework_module(pytorch_ops)
arguments = get_arguments()
np.random.seed(arguments.seed)
torch.random.manual_seed(arguments.seed)

# Solve MLE using only CPU
# NOTE(review): this is set after the framework module has been imported
# and selected - confirm the variable name and that it is still read.
os.environ['RAMSAY_PYTORCH_DEVICE'] = 'cpu'

# Create data of live users
num_classes = 10
input_shape = (32, 32, 3)

# check arguments for mixture components
print('arguments.num_mixture_components', arguments.num_mixture_components,
      type(arguments.num_mixture_components))
print('arguments.component_mean_user_dataset_length',
      arguments.component_mean_user_dataset_length)
print('arguments.component_phi', arguments.component_phi,
      type(arguments.component_phi))

# Explicit errors instead of ``assert`` so the checks survive ``python -O``.
if not (arguments.num_mixture_components == len(
        arguments.component_mean_user_dataset_length) == len(
            arguments.component_phi)):
    raise ValueError(
        'num_mixture_components must equal the number of entries in '
        'component_mean_user_dataset_length and component_phi')

if len(arguments.component_alphas) == arguments.num_mixture_components:
    # one alpha for all classes for each mixture component
    alphas = np.array(arguments.component_alphas).reshape(
        -1, 1) * np.ones(num_classes)
    print('alphas', alphas)
else:
    # must have length num mixture components * num_classes
    print('len(arguments.component_alphas)', len(arguments.component_alphas))
    print('arguments.num_mixture_components * num_classes',
          arguments.num_mixture_components * num_classes)
    if len(arguments.component_alphas
           ) != arguments.num_mixture_components * num_classes:
        raise ValueError(
            'component_alphas must have num_mixture_components or '
            'num_mixture_components * num_classes entries')
    # assumes alphas are ordered each class, each mixture component
    print('arguments.component_alphas', arguments.component_alphas)
    alphas = np.array(arguments.component_alphas).reshape(
        arguments.num_mixture_components, num_classes)
    print('alphas', alphas.shape, alphas)

print('arguments.component_phi', arguments.component_phi)
phi = np.array(arguments.component_phi)
# One constant length sampler per mixture component. ``x=x`` binds the
# current value as a default so every lambda keeps its own length instead
# of all of them seeing the last loop value.
samplers = [
    lambda x=x: x for x in arguments.component_mean_user_dataset_length
]
print('samplers', [s() for s in samplers])
print('true phi', phi)
print('true alphas', alphas)

# option to create artificial_federated_dataset or federated_dataset
live_training_data, live_val_data, central_val_data = make_cifar10_datasets(
    dataset_type='artificial_federated_dataset',
    data_dir=arguments.data_dir,
    user_dataset_len_samplers=samplers,
    numpy_to_tensor=get_ops().to_tensor,
    phi=phi,
    alphas=alphas)

# TODO support modelling federated dataset using mixture-polya
# or uniform distribution

# If running simulations then compute phi, alphas and num_samples_histos
# either by solving the polya MLE or just computing the histogram for
# uniform simulations
print('simulated_dirichlet_mixture experiment')
if arguments.precomputed_parameter_filepath is None:
    print('learn simulated_dirichlet_mixture parameters')
    dir_path = get_platform().create_checkpoint_directories(
        [arguments.mle_param_dirname])[0]
    save_dir = (
        f'cifar10_{arguments.num_mixture_components}_mixture_{arguments.dirname}'
    )
    save_path = os.path.join(dir_path, save_dir)

    dir_path_histogram = get_platform().create_checkpoint_directories(
        ['num_samples_distribution'])[0]
    save_path_histogram = os.path.join(dir_path_histogram,
                                       save_dir + '.joblib')

    # Solve polya-mixture MLE
    # TODO use mle arguments for cohort size, num iterations, etc.
    # TODO what to do about num_samples_histos?
    phi, alphas, num_samples_distributions = solve_polya_mixture_mle(
        arguments=arguments,
        training_federated_dataset=live_training_data,
        val_federated_dataset=None,
        num_components=arguments.num_mixture_components,
        num_categories=num_classes,
        save_path=save_path,
        save_path_histogram=save_path_histogram)

    # Convert each result to NumPy only when it is a tensor, so all three
    # outputs are handled the same way.
    phi = phi.numpy() if isinstance(phi, torch.Tensor) else phi
    alphas = alphas.numpy() if isinstance(alphas, torch.Tensor) else alphas
    num_samples_distributions = (
        num_samples_distributions.numpy() if isinstance(
            num_samples_distributions, torch.Tensor) else
        num_samples_distributions)
    print('phi', phi)
    print('alphas', alphas)
    print('num_samples_distributions', num_samples_distributions)

else:
    # NOTE(review): assumes the file stores 'phi', 'alphas' and
    # 'num_samples_distributions' - confirm against the checkpoint writer.
    params = joblib.load(arguments.precomputed_parameter_filepath)

    # Cast to float so the in-place normalisation also works when the
    # stored values are integer counts.
    phi = np.asarray(params['phi'], dtype=float)
    phi /= phi.sum()
    alphas = np.array(params['alphas'])
    num_samples_distributions = np.asarray(
        params['num_samples_distributions'], dtype=float)
    num_samples_distributions /= np.sum(num_samples_distributions,
                                        axis=1,
                                        keepdims=True)
def get_arguments():
    """Build the argument parser for the FEMNIST experiment and parse argv.

    Each ``add_*_args`` helper registers one group of options on the parser.
    """
    parser = argparse.ArgumentParser()
    add_experiment_args(parser)
    add_mle_args(parser)
    add_init_algorithm_args(parser)
    add_algorithm_args(parser)
    add_histogram_algorithm_args(parser)
    add_user_visualisation_args(parser)
    return parser.parse_args()


set_framework_module(pytorch_ops)
arguments = get_arguments()
np.random.seed(arguments.seed)
torch.random.manual_seed(arguments.seed)

# Solve MLE using only CPU
# NOTE(review): this is set after the framework module has been imported
# and selected - confirm the variable name and that it is still read.
os.environ['RAMSAY_PYTORCH_DEVICE'] = 'cpu'

# Create data of live users
num_classes = 62
input_shape = (28, 28, 1)

live_training_data, live_val_data, central_val_data = make_femnist_datasets(
    arguments.data_dir,
    digits_only=False,
    numpy_to_tensor=get_ops().to_tensor,
    dataset_type=arguments.dataset_type)

# Differential privacy is always on in this script; it also tags the
# checkpoint directory name below.
add_DP = True

# If running simulations then compute phi, alphas and num_samples_histos
# either by solving the polya MLE or just computing the histogram for
# uniform simulations
print('simulated_dirichlet_mixture experiment')
if arguments.precomputed_parameter_filepath is None:
    print('learn simulated_dirichlet_mixture parameters')
    dir_path = get_platform().create_checkpoint_directories(
        [arguments.mle_param_dirname])[0]
    save_dir = (
        f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture'
    )
    if add_DP:
        save_dir += '_DP'
    save_path = os.path.join(dir_path, save_dir)

    dir_path_histogram = get_platform().create_checkpoint_directories(
        ['num_samples_distribution'])[0]
    save_path_histogram = os.path.join(dir_path_histogram,
                                       save_dir + '.joblib')

    # Solve polya-mixture MLE
    # TODO use mle arguments for cohort size, num iterations, etc.
    # TODO what to do about num_samples_histos?
    phi, alphas, num_samples_distributions = solve_polya_mixture_mle(
        arguments=arguments,
        training_federated_dataset=live_training_data,
        val_federated_dataset=None,
        num_components=arguments.num_mixture_components,
        num_categories=num_classes,
        save_path=save_path,
        save_path_histogram=save_path_histogram,
        add_DP=add_DP)

    # Convert each result to NumPy only when it is a tensor, so all three
    # outputs are handled the same way.
    phi = phi.numpy() if isinstance(phi, torch.Tensor) else phi
    alphas = alphas.numpy() if isinstance(alphas, torch.Tensor) else alphas
    num_samples_distributions = (
        num_samples_distributions.numpy() if isinstance(
            num_samples_distributions, torch.Tensor) else
        num_samples_distributions)
    print('phi', phi)
    print('alphas', alphas)
    print('num_samples_distributions', num_samples_distributions)

else:
    # NOTE(review): assumes the file stores 'phi', 'alphas' and
    # 'num_samples_distributions' - confirm against the checkpoint writer.
    params = joblib.load(arguments.precomputed_parameter_filepath)

    # Cast to float so the in-place normalisation also works when the
    # stored values are integer counts.
    phi = np.asarray(params['phi'], dtype=float)
    phi /= phi.sum()
    alphas = np.array(params['alphas'])
    num_samples_distributions = np.asarray(
        params['num_samples_distributions'], dtype=float)
    num_samples_distributions /= np.sum(num_samples_distributions,
                                        axis=1,
                                        keepdims=True)
def get_arguments():
    """Build the argument parser for the FEMNIST rebuttal runs and parse argv.

    Same option groups as the plain FEMNIST script plus the dataset
    preprocessing (user filtering) options.
    """
    parser = argparse.ArgumentParser()
    add_experiment_args(parser)
    add_mle_args(parser)
    add_init_algorithm_args(parser)
    add_algorithm_args(parser)
    add_histogram_algorithm_args(parser)
    add_user_visualisation_args(parser)
    add_dataset_preprocessing_args(parser)
    return parser.parse_args()


set_framework_module(pytorch_ops)
arguments = get_arguments()
np.random.seed(arguments.seed)
torch.random.manual_seed(arguments.seed)

# Solve MLE using only CPU
# NOTE(review): this is set after the framework module has been imported
# and selected - confirm the variable name and that it is still read.
os.environ['RAMSAY_PYTORCH_DEVICE'] = 'cpu'

# Create data of live users
num_classes = 62
input_shape = (28, 28, 1)

live_training_data, live_val_data, central_val_data = make_femnist_datasets(
    arguments.data_dir,
    digits_only=False,
    numpy_to_tensor=get_ops().to_tensor,
    dataset_type=arguments.dataset_type,
    filter_method=arguments.filter_method,
    sample_fraction=arguments.sample_fraction,
    start_idx=arguments.start_idx,
    end_idx=arguments.end_idx,
    include_sampled=arguments.include_sampled)

# Differential privacy is off for these runs.
add_DP = False

# If running simulations then compute phi, alphas and num_samples_histos
# either by solving the polya MLE or just computing the histogram for
# uniform simulations
print('simulated_dirichlet_mixture experiment')
if arguments.precomputed_parameter_filepath is None:
    print('learn simulated_dirichlet_mixture parameters')
    dir_path = get_platform().create_checkpoint_directories(
        [arguments.mle_param_dirname])[0]
    save_dir = (
        f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture_{arguments.filter_method}_filter_method'
    )
    # Encode the filter settings in the directory name so differently
    # filtered runs do not overwrite each other's checkpoints.
    if arguments.filter_method is not None:
        if arguments.filter_method == 'index':
            save_dir += f'_{arguments.start_idx}_start_idx_{arguments.end_idx}_end_idx'
        elif arguments.filter_method == 'sample':
            save_dir += f'_{arguments.sample_fraction}_sample_fraction_{arguments.include_sampled}_include_sampled'
        else:
            raise ValueError(
                f'Invalid value for filter_method: {arguments.filter_method}')

    if add_DP:
        save_dir += '_DP'
    save_path = os.path.join(dir_path, save_dir)

    dir_path_histogram = get_platform().create_checkpoint_directories(
        ['num_samples_distribution'])[0]
    save_path_histogram = os.path.join(dir_path_histogram,
                                       save_dir + '.joblib')

    # Solve polya-mixture MLE
    # TODO use mle arguments for cohort size, num iterations, etc.
    # TODO what to do about num_samples_histos?
    phi, alphas, num_samples_distributions = solve_polya_mixture_mle(
        arguments=arguments,
        training_federated_dataset=live_training_data,
        val_federated_dataset=None,
        num_components=arguments.num_mixture_components,
        num_categories=num_classes,
        save_path=save_path,
        save_path_histogram=save_path_histogram,
        add_DP=add_DP)

    # Convert each result to NumPy only when it is a tensor, so all three
    # outputs are handled the same way.
    phi = phi.numpy() if isinstance(phi, torch.Tensor) else phi
    alphas = alphas.numpy() if isinstance(alphas, torch.Tensor) else alphas
    num_samples_distributions = (
        num_samples_distributions.numpy() if isinstance(
            num_samples_distributions, torch.Tensor) else
        num_samples_distributions)
    print('phi', phi)
    print('alphas', alphas)
    print('num_samples_distributions', num_samples_distributions)

else:
    # NOTE(review): assumes the file stores 'phi', 'alphas' and
    # 'num_samples_distributions' - confirm against the checkpoint writer.
    params = joblib.load(arguments.precomputed_parameter_filepath)

    # Cast to float so the in-place normalisation also works when the
    # stored values are integer counts.
    phi = np.asarray(params['phi'], dtype=float)
    phi /= phi.sum()
    alphas = np.array(params['alphas'])
    num_samples_distributions = np.asarray(
        params['num_samples_distributions'], dtype=float)
    num_samples_distributions /= np.sum(num_samples_distributions,
                                        axis=1,
                                        keepdims=True)
def load_and_preprocess(pickle_file_path: str,
                        channel_means: Optional[np.ndarray] = None,
                        channel_stddevs: Optional[np.ndarray] = None,
                        exclude_classes=None):
    """
    Load a pickled ``(images, labels)`` pair and normalize the images
    per channel.

    :param pickle_file_path:
        Path of a pickle file containing the tuple ``(images, labels)``.
    :param channel_means:
        Per-channel means to subtract. When ``None`` they are computed
        from the loaded images over the first three axes.
    :param channel_stddevs:
        Per-channel standard deviations to divide by. When ``None`` they
        are computed from the loaded images over the first three axes.
    :param exclude_classes:
        Optional iterable of label values; every datapoint whose label is
        one of them is dropped.
    :returns:
        ``(images, labels, channel_means, channel_stddevs)``. The
        statistics are returned so a second split can be normalized with
        the statistics of the first.
    """
    # NOTE: pickle can execute arbitrary code while loading - only use
    # this on trusted data files.
    # The ``with`` block closes the file even if unpickling fails.
    with open(pickle_file_path, 'rb') as pickle_file:
        images, labels = pickle.load(pickle_file)
    images = images.astype(np.float32)

    # Normalize per-channel.
    if channel_means is None:
        channel_means = images.mean(axis=(0, 1, 2), dtype='float64')
    if channel_stddevs is None:
        channel_stddevs = images.std(axis=(0, 1, 2), dtype='float64')
    images = (images - channel_means) / channel_stddevs

    if exclude_classes is not None:
        # One pass over the labels instead of one filtering pass per
        # excluded class. ``list(...)`` so sets and tuples work as well.
        keep = ~np.isin(labels, list(exclude_classes)).reshape(-1)
        labels = labels[keep]
        images = images[keep]

    return images, labels, channel_means, channel_stddevs
def make_federated_dataset(
        images: np.ndarray,
        labels: np.ndarray,
        user_dataset_len_samplers: List[Callable],
        phi: Optional[np.ndarray] = None,
        alphas: Optional[np.ndarray] = None,
        numpy_to_tensor: Callable = lambda x: x) -> FederatedDataset:
    """
    Create a federated dataset from the CIFAR10 dataset.

    :param images:
        All images, indexable along the first axis.
    :param labels:
        Labels aligned with ``images``.
    :param user_dataset_len_samplers:
        Callables returning a user dataset length. All of them are passed
        to the Dirichlet-mixture partitioner when ``alphas`` is given;
        otherwise only the first one is used.
    :param phi:
        Passed through to the Dirichlet-mixture partitioner.
    :param alphas:
        When not ``None`` the data is partitioned with
        ``partition_by_dirichlet_mixture_class_distribution``; when
        ``None`` it is shuffled and cut into consecutive chunks.
    :param numpy_to_tensor:
        Conversion applied to ``images`` and ``labels`` before slicing.
    :returns:
        A ``FederatedDataset`` with integer user ids ``0..num_users-1``.
    :raises ValueError:
        If the length sampler returns a non-positive length in the
        uniform partitioning path.
    """
    if alphas is not None:
        user_idxs = partition_by_dirichlet_mixture_class_distribution(
            labels, phi, alphas, user_dataset_len_samplers)
    else:
        remaining = np.arange(len(labels)).astype(int)
        np.random.shuffle(remaining)
        user_idxs = []
        # Stop as soon as every index is assigned, so no empty trailing
        # user is created when the data divides exactly.
        while len(remaining) > 0:
            n = user_dataset_len_samplers[0]()
            if n <= 0:
                # A non-positive length never consumes data, so the loop
                # would never terminate.
                raise ValueError(
                    f'user dataset length sampler returned {n}; '
                    'expected a positive length')
            # Slicing clips at the end, so the last user simply receives
            # whatever is left.
            user_idxs.append(remaining[:n])
            remaining = remaining[n:]

    user_sampler = get_user_sampler('random', list(range(len(user_idxs))))
    images = numpy_to_tensor(images)
    labels = numpy_to_tensor(labels)

    data = {
        user_id: [images[idxs], labels[idxs]]
        for user_id, idxs in enumerate(user_idxs)
    }

    return FederatedDataset.from_slices(data, user_sampler)
def make_cifar10_datasets(
    dataset_type: str,
    data_dir: str,
    user_dataset_len_samplers: List[Callable],
    numpy_to_tensor: Callable,
    phi: np.ndarray = None,
    alphas: np.ndarray = None
) -> Tuple[FederatedDataset, FederatedDataset, Dataset]:
    """
    Build the federated train and val datasets plus a central val dataset
    from the CIFAR10 pickle files in ``data_dir``.

    ``cifar10_train.p`` is loaded first; the channel statistics computed
    on it are reused to normalize ``cifar10_test.p``.

    The data files can be found at ``s3://pfl/data/cifar10/``.

    :param dataset_type:
        ``'artificial_federated_dataset'`` selects
        ``make_artificial_federated_dataset``; any other value selects
        ``make_federated_dataset``.
    :param data_dir:
        Directory containing ``cifar10_train.p`` and ``cifar10_test.p``.
    :param user_dataset_len_samplers:
        Forwarded to the selected dataset factory.
    :param numpy_to_tensor:
        Forwarded to the selected dataset factory.
    :param phi:
        Forwarded to the selected dataset factory.
    :param alphas:
        Forwarded to the selected dataset factory.
    :returns:
        ``(training_federated_dataset, val_federated_dataset,
        central_val_data)``.
    """
    train_path = os.path.join(data_dir, 'cifar10_train.p')
    val_path = os.path.join(data_dir, 'cifar10_test.p')

    train_images, train_labels, channel_means, channel_stddevs = (
        load_and_preprocess(train_path))
    # Validation images are normalized with the training statistics.
    val_images, val_labels, _, _ = load_and_preprocess(
        val_path, channel_means, channel_stddevs)

    # supports artificial federated dataset and federated dataset
    if dataset_type == 'artificial_federated_dataset':
        build_federated = make_artificial_federated_dataset
    else:
        build_federated = make_federated_dataset

    # Keyword arguments shared by the train and val constructions.
    shared_kwargs = dict(phi=phi,
                         alphas=alphas,
                         numpy_to_tensor=numpy_to_tensor)

    training_federated_dataset = build_federated(train_images, train_labels,
                                                 user_dataset_len_samplers,
                                                 **shared_kwargs)
    val_federated_dataset = build_federated(val_images, val_labels,
                                            user_dataset_len_samplers,
                                            **shared_kwargs)
    central_val_data = make_central_dataset(val_images, val_labels)

    return training_federated_dataset, val_federated_dataset, central_val_data
+ training_federated_dataset = fed_dataset_fn( + train_images, + train_labels, + user_dataset_len_samplers, + phi=phi, + alphas=alphas, + numpy_to_tensor=numpy_to_tensor) + val_federated_dataset = fed_dataset_fn(val_images, + val_labels, + user_dataset_len_samplers, + phi=phi, + alphas=alphas, + numpy_to_tensor=numpy_to_tensor) + central_val_data = make_central_dataset(val_images, val_labels) + + return training_federated_dataset, val_federated_dataset, central_val_data diff --git a/publications/mdm/mdm_utils/datasets/femnist_dataset.py b/publications/mdm/mdm_utils/datasets/femnist_dataset.py new file mode 100644 index 0000000..c921e4f --- /dev/null +++ b/publications/mdm/mdm_utils/datasets/femnist_dataset.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Callable, Dict, Tuple, List, Optional + +import h5py +import numpy as np +import torch + +from pfl.data import ArtificialFederatedDataset, FederatedDataset +from pfl.data.sampling import get_user_sampler, get_data_sampler +from pfl.data.dataset import Dataset + +from .mixture_dataset import (ArtificialFederatedDatasetMixture, + partition_by_dirichlet_mixture_class_distribution + ) +from .sampler import DirichletDataSampler + + +def _sample_users(user_id_to_data: Dict[str, List[np.ndarray]], + filter_method: Optional[str] = None, + sample_fraction: float = None, + start_idx: int = None, + end_idx: int = None, + include_sampled: bool = True): + + user_ids = list(user_id_to_data.keys()) + + if filter_method is None: + print('\nKEEP ALL USERS') + # no change. 
Use all users + return user_id_to_data + + elif filter_method == 'index': + assert start_idx is not None + assert end_idx is not None + print(f'\nKEEP ALL USERS WITH IDS IN RANGE {start_idx}-{end_idx}') + + selected_user_ids = user_ids[start_idx:end_idx] + + elif filter_method == 'sample': + assert sample_fraction >= 0 and sample_fraction <= 1 + + sample_number = int(sample_fraction * len(user_ids)) + + original_state = np.random.get_state() + np.random.seed(0) + sampled_ids = np.random.choice(len(user_ids), + sample_number, + replace=False) + np.random.set_state(original_state) + + if not include_sampled: + sampled_ids = np.setdiff1d(np.arange(len(user_ids)), sampled_ids) + print( + f'\nKEEP {len(sampled_ids)} SAMPLED USERS WITH IDS: {sampled_ids}') + selected_user_ids = [user_ids[i] for i in sampled_ids] + + else: + raise ValueError(f'filter_method {filter_method} is not valid') + + return {user_id: user_id_to_data[user_id] for user_id in selected_user_ids} + + +def _load_h5_into_dict( + h5_file_path: str, + digits_only: bool, + numpy_to_tensor: Callable = lambda x: x +) -> Dict[str, List[np.ndarray]]: + """ + Load data into memory and create a mapping from user ids to that + user's data. + + :returns: + A dictionary mapping user ids to data. The data is a tuple + `(pixels,labels)`, where `pixels` is a `BxWxHx1` `np.ndarray` + (stacked images of a user) and `labels` is a `Bx1` vector of + categorical labels. 
def make_federated_dataset(user_id_to_data: Dict[str, List[np.ndarray]],
                           use_existing_partition: bool = True,
                           phi=None,
                           alphas=None,
                           user_dataset_len_samplers=None) -> FederatedDataset:
    """
    Create federated dataset from a FEMNIST data file.

    :param user_id_to_data:
        Mapping from user id to ``[images, labels]``.
    :param use_existing_partition:
        When true the mapping is used as is. When false all users' data
        is concatenated and re-partitioned into new users with string ids
        ``'0'..'num_users-1'``.
    :param phi:
        Passed through to the Dirichlet-mixture partitioner.
    :param alphas:
        When re-partitioning: if not ``None`` the data is split with
        ``partition_by_dirichlet_mixture_class_distribution``, otherwise
        it is shuffled and cut into consecutive chunks.
    :param user_dataset_len_samplers:
        Callables returning a user dataset length; required when
        re-partitioning. Only the first one is used for the uniform
        split.
    :returns:
        A ``FederatedDataset`` with a random user sampler.
    :raises ValueError:
        If re-partitioning is requested without length samplers, or the
        length sampler returns a non-positive length in the uniform
        split.
    """
    user_ids = list(user_id_to_data.keys())

    if not use_existing_partition:
        if not user_dataset_len_samplers:
            raise ValueError(
                'user_dataset_len_samplers is required when '
                'use_existing_partition is False')

        images = torch.cat([data[0] for data in user_id_to_data.values()])
        labels = torch.cat([data[1] for data in user_id_to_data.values()])

        if alphas is not None:
            user_idxs = partition_by_dirichlet_mixture_class_distribution(
                labels.cpu().numpy(), phi, alphas, user_dataset_len_samplers)
        else:
            remaining = np.arange(len(labels)).astype(int)
            np.random.shuffle(remaining)
            user_idxs = []
            # Stop as soon as every index is assigned, so no empty
            # trailing user is created when the data divides exactly.
            while len(remaining) > 0:
                n = user_dataset_len_samplers[0]()
                if n <= 0:
                    # A non-positive length never consumes data, so the
                    # loop would never terminate.
                    raise ValueError(
                        f'user dataset length sampler returned {n}; '
                        'expected a positive length')
                # Slicing clips at the end, so the last user simply
                # receives whatever is left.
                user_idxs.append(remaining[:n])
                remaining = remaining[n:]

        user_ids = [str(i) for i in range(len(user_idxs))]
        user_id_to_data = {
            user_id: (images[idx], labels[idx])
            for user_id, idx in zip(user_ids, user_idxs)
        }

    sampler = get_user_sampler('random', user_ids)
    return FederatedDataset.from_slices(user_id_to_data, sampler)
+ """ + + user_ids = list(user_id_to_data.keys()) + images = torch.cat([data[0] for data in user_id_to_data.values()]) + + labels = torch.cat([data[1] + for data in user_id_to_data.values()]).cpu().numpy() + unique_labels = np.unique(labels) + indices_per_class = { + i: np.random.permutation(np.nonzero(labels == i)[0]) + for i in unique_labels + } + #from collections import Counter + #print('labels count',Counter(labels)) + #for k,v in indices_per_class.items(): + # print('indices_per_class', k, len(v)) + + new_user_id_to_data = dict() + start_id_per_class = {i: 0 for i in unique_labels} + #print('start_id_per_class', start_id_per_class) + for user_id, data in user_id_to_data.items(): + #print('user_id', user_id) + user_labels = data[1].cpu().numpy() + #print('labels', labels) + + # sample images based off labels. + sampled_data_idx = [] + for label in user_labels: + #print('label', label) + #print('start_id_per_class[label]', start_id_per_class[label]) + #print(' indices_per_class[label]', type(indices_per_class[label]), len(indices_per_class[label])) + sampled_data_idx.append( + indices_per_class[label][start_id_per_class[label]]) + start_id_per_class[label] += 1 + + # TODO might need to ensure labels in not on cpu any more. + new_user_id_to_data[user_id] = [images[sampled_data_idx], data[1]] + + #new_labels = torch.cat([data[1] for data in new_user_id_to_data.values()]).cpu().numpy() + #print('new label counts', Counter(new_labels)) + + sampler = get_user_sampler('random', user_ids) + federated_dataset = FederatedDataset.from_slices(new_user_id_to_data, + sampler) + + return federated_dataset + + +def make_artificial_federated_dataset( + user_id_to_data: Dict[str, List[np.ndarray]], + user_dataset_len_samplers: List[Callable], + phi: np.ndarray = None, + alphas: np.ndarray = None) -> Tuple[ArtificialFederatedDataset, dict]: + """ + Create artificial federated dataset from a FEMNIST data file. 
+ """ + images = torch.cat([data[0] for data in user_id_to_data.values()]) + labels = torch.cat([data[1] for data in user_id_to_data.values()]) + + data = [images, labels] + + if alphas is not None: + data_samplers = [ + DirichletDataSampler(alpha, + labels.cpu().numpy()) for alpha in alphas + ] + return ArtificialFederatedDatasetMixture.from_slices( + phi, data, data_samplers, user_dataset_len_samplers) + else: + data_sampler = get_data_sampler('random', len(labels)) + return ArtificialFederatedDataset.from_slices( + data, data_sampler, user_dataset_len_samplers[0]) + + +def make_central_dataset( + user_id_to_data: Dict[str, List[np.ndarray]]) -> Dataset: + """ + Create central dataset from a FEMNIST data file. + """ + images = np.concatenate([data[0].cpu() for data in user_id_to_data.values()], + axis=0) + labels = np.concatenate([data[1].cpu() for data in user_id_to_data.values()], + axis=0) + + return Dataset(raw_data=[images, labels]) + + +def make_femnist_datasets( + data_dir: str, + digits_only: bool = False, + numpy_to_tensor: Callable = lambda x: x, + dataset_type: str = 'original', + phi=None, + alphas=None, + user_dataset_len_samplers=None, + filter_method: Optional[str] = None, + sample_fraction: float = None, + start_idx: int = None, + end_idx: int = None, + include_sampled: bool = True +) -> Tuple[FederatedDataset, FederatedDataset, Dataset]: + """ + Create a train and val ``FederatedDataset`` as well as a central dataset + from the FEMNIST data. 
+ """ + + train_h5_file_path = os.path.join(data_dir, 'fed_emnist_train.h5') + val_h5_file_path = os.path.join(data_dir, 'fed_emnist_test.h5') + + train_user_id_to_data = _load_h5_into_dict(train_h5_file_path, digits_only, + numpy_to_tensor) + train_user_id_to_data = _sample_users(train_user_id_to_data, filter_method, + sample_fraction, start_idx, end_idx, + include_sampled) + + val_user_id_to_data = _load_h5_into_dict(val_h5_file_path, digits_only, + numpy_to_tensor) + val_user_id_to_data = _sample_users(val_user_id_to_data, filter_method, + sample_fraction, start_idx, end_idx, + include_sampled) + + # create federated training and val datasets from central training and val + # data + if dataset_type == 'original': + training_federated_dataset = make_federated_dataset( + train_user_id_to_data) + val_federated_dataset = make_federated_dataset(val_user_id_to_data) + + elif dataset_type == 'original_labels_uniform_datapoints': + training_federated_dataset = make_special_federated_dataset( + train_user_id_to_data) + val_federated_dataset = make_special_federated_dataset( + val_user_id_to_data) + + elif dataset_type == 'polya_mixture_federated': + training_federated_dataset = make_federated_dataset( + user_id_to_data=train_user_id_to_data, + use_existing_partition=False, + phi=phi, + alphas=alphas, + user_dataset_len_samplers=user_dataset_len_samplers) + val_federated_dataset = make_federated_dataset( + user_id_to_data=val_user_id_to_data, + use_existing_partition=False, + phi=phi, + alphas=alphas, + user_dataset_len_samplers=user_dataset_len_samplers) + + elif dataset_type == 'polya_mixture_artificial_federated': + training_federated_dataset = make_artificial_federated_dataset( + train_user_id_to_data, + user_dataset_len_samplers, + phi=phi, + alphas=alphas) + val_federated_dataset = make_artificial_federated_dataset( + val_user_id_to_data, + user_dataset_len_samplers, + phi=phi, + alphas=alphas) + + elif dataset_type == 'uniform_federated': + 
training_federated_dataset = make_federated_dataset( + train_user_id_to_data, + use_existing_partition=False, + user_dataset_len_samplers=user_dataset_len_samplers) + val_federated_dataset = make_federated_dataset( + val_user_id_to_data, + use_existing_partition=False, + user_dataset_len_samplers=user_dataset_len_samplers) + + elif dataset_type == 'uniform_artificial_federated': + training_federated_dataset = make_artificial_federated_dataset( + train_user_id_to_data, user_dataset_len_samplers) + val_federated_dataset = make_artificial_federated_dataset( + val_user_id_to_data, user_dataset_len_samplers) + + else: + raise NotImplementedError( + f'Dataset type {dataset_type} not recognized.') + + central_data = make_central_dataset(val_user_id_to_data) + + return training_federated_dataset, val_federated_dataset, central_data diff --git a/publications/mdm/mdm_utils/datasets/mixture_dataset.py b/publications/mdm/mdm_utils/datasets/mixture_dataset.py new file mode 100644 index 0000000..495d4b1 --- /dev/null +++ b/publications/mdm/mdm_utils/datasets/mixture_dataset.py @@ -0,0 +1,154 @@ +from collections import defaultdict +from typing import Callable, Iterable, List, Tuple + +import numpy as np +import joblib + +from pfl.data import ArtificialFederatedDataset, FederatedDatasetBase +from pfl.data.dataset import AbstractDataset +from pfl.internal.ops.selector import (get_default_framework_module as get_ops) + + +class ArtificialFederatedDatasetMixture(FederatedDatasetBase): + """ + A type of federated dataset that is a mixture of multiple + ArtificialFederatedDataset. To sample a new users we randomly + sample a component with corresponding probability vector phi, + we then sample a user from the corresponding + ArtificialFederatedDataset. 
+ :param phi: + A np.ndarray of shape len(mixture_component_datasets), + probability vector that gives the probability of + each mixture component + :param mixture_component_datasets: + List of type ArtificialFederatedDataset + """ + + def __init__(self, phi, mixture_component_datasets): + super().__init__() + self.phi = phi + self.mixture_component_datasets = mixture_component_datasets + + def __next__(self): + mixture_component = np.random.choice(range(len(self.phi)), p=self.phi) + return next(self.mixture_component_datasets[mixture_component]) + + def get_cohort(self, + cohort_size: int) -> Iterable[Tuple[AbstractDataset, int]]: + for i in range(cohort_size): + if (i % get_ops().distributed.world_size + ) == get_ops().distributed.global_rank: + yield next(self) + + @classmethod + def from_slices(cls, phi, data, data_samplers, dataset_len_samplers): + """ + Construct a mixture of simulated federated datasets from a single + regular dataset where there is no such thing as a user identifier. + Each mixture samples a user from the same data but using its + own data_sampler and dataset_len_sampler. + :param phi: + A np.ndarray probability vector + :param data: + A list of ``np.ndarray``, i.e. the same format as a ``Dataset`` + accepts. + :param data_samplers: + List of callables of length len(phi), each callable is a data + sampler for an ArtificialFederatedDataset + :param dataset_len_samplers: + List of callables of length len(phi), each callable is a data + length sampler for an ArtificialFederatedDataset + :returns: + An instance of `ArtificialFederatedDatasetMixture`. 
def partition_by_dirichlet_mixture_class_distribution(
    labels: np.ndarray,
    phi: np.ndarray,
    alphas: np.ndarray,
    user_dataset_len_samplers: List[Callable],
    spread_distribution_after_num_fails: int = 20,
    spread_distribution_after_fails_percentage: float = 0.02
) -> List[List[int]]:
    """
    Partitions central data using a mixture of dirichlet distributions. Works
    the same as partition_by_dirichlet_class_distribution except that it first
    randomly samples a mixture component using probability vector phi, and then
    selects the corresponding alpha and user_dataset_len_sampler.

    :param labels:
        One-dimensional array with the label of every datapoint.
    :param phi:
        Probability vector over mixture components.
    :param alphas:
        One Dirichlet parameter vector per component. Entry ``k`` of a
        vector belongs to the ``k``-th smallest distinct value in
        ``labels``.
    :param user_dataset_len_samplers:
        One callable per component returning the size of the next user.
    :param spread_distribution_after_num_fails:
        Together with the user size, sets how many class draws pass
        between two flattening steps of the user's class distribution.
    :param spread_distribution_after_fails_percentage:
        Amount added to every class prior at each flattening step.
    :returns:
        One list of datapoint indices per generated user. Every index is
        used at most once. Partitioning stops at the first sampled user
        size that exceeds the number of unassigned datapoints.
    """
    num_components = len(phi)
    # One pool of unassigned datapoint indices per distinct label value, in
    # sorted label order. Using the values themselves (rather than
    # ``range(num_classes)``) keeps every datapoint reachable when the
    # labels are not exactly ``0..K-1``.
    indices_per_class = [
        list(np.where(labels == label)[0]) for label in np.unique(labels)
    ]
    num_remaining = sum(len(pool) for pool in indices_per_class)
    users_to_indices = []

    while True:
        component = np.random.choice(num_components, p=phi)
        alpha = alphas[component]
        user_dataset_len_sampler = user_dataset_len_samplers[component]
        class_priors = np.random.dirichlet(alpha=alpha)
        class_prior_cdf = np.cumsum(class_priors)
        user_num_datapoints = user_dataset_len_sampler()
        if user_num_datapoints > num_remaining:
            # Not enough datapoints left.
            break

        user_indices = []
        i = 1
        while len(user_indices) < user_num_datapoints:
            # Sample class from user's class distribution (Dirichlet)
            sampled_class = np.argmax(np.random.uniform() <= class_prior_cdf)
            if indices_per_class[sampled_class]:
                # Add datapoint to user if there are still datapoints
                # available of that class.
                user_indices.append(indices_per_class[sampled_class].pop())
                num_remaining -= 1
            if i % (user_num_datapoints *
                    spread_distribution_after_num_fails) == 0:
                # Every this number of draws, even out the class
                # distribution a tiny bit (at least 2% chance for every
                # class with the default setting) such that sampling
                # classes with datapoints remaining to be allocated is
                # more probable. This will typically only be an issue for
                # the final few 1-5 users.
                class_priors += spread_distribution_after_fails_percentage
                class_priors /= sum(class_priors)
                class_prior_cdf = np.cumsum(class_priors)
            i += 1
        users_to_indices.append(user_indices)

    return users_to_indices


def get_user_counts(training_federated_dataset, num_classes,
                    num_central_iterations, cohort_size, save_path):
    """
    Helper function to check the label counts of a cohort of users.
    Can be used to visualize the users generated from different experiments
    over a number of central iterations in train.py.

    For every central iteration ``1..num_central_iterations`` a cohort is
    drawn and, per user, a length-``num_classes`` list of label counts is
    recorded. The resulting ``{iteration: [counts, ...]}`` dictionary is
    written to ``save_path`` with ``joblib.dump``.
    """
    print('get_user_counts')
    all_counts = {}
    for iteration in range(1, num_central_iterations + 1):
        cohort_counts = []
        for user_dataset, _ in training_federated_dataset.get_cohort(
                cohort_size):
            _, y = user_dataset.raw_data
            y = y.cpu().numpy()
            uniques, counts = np.unique(y, return_counts=True)
            # Classes that do not occur for this user keep a count of 0.
            full_y = np.zeros(num_classes)
            full_y[uniques.astype(int)] = counts
            cohort_counts.append(full_y.tolist())
        all_counts[iteration] = cohort_counts

    joblib.dump(all_counts, save_path)


class DirichletDataSampler:
    """
    Data sampling mechanism that samples user class proportions from a
    Dirichlet distribution with a given alpha parameter.
    Sampling is done by first drawing a vector of class proportions
    p ~ Dir(alpha), then sampling a class from a categorical
    distribution with parameter p and uniformly at random choosing
    (with replacement) an index with the corresponding class.

    :param alpha:
        Parameter of the Dirichlet distribution. Must be array_like and
        have length equal to the number of unique classes present in labels.
    :param labels:
        A one-dimensional array of all labels (integers). This should have
        length equal to the size of the corresponding dataset.
    """

    def __init__(self, alpha: np.ndarray, labels: np.ndarray):
        self.unique_labels = np.unique(labels)
        assert len(alpha) == len(
            self.unique_labels
        ), "Number of classes doesn't equal dirichlet parameter dimension."
        # Positions in ``labels`` of the datapoints of each class.
        self.indices_per_class = {
            i: np.nonzero(labels == i)[0]
            for i in self.unique_labels
        }
        self.alpha = alpha

    def __call__(self, n: int):
        """
        Sample a list of datapoint indices.
        :param n:
            Number of samples to draw.
        :returns:
            List of ``n`` sampled indices in range ``[0, len(labels))``,
            grouped by class in sorted label order.
        """
        class_priors = np.random.dirichlet(alpha=self.alpha)
        sampled_class_counts = np.random.multinomial(n, pvals=class_priors)
        per_class_samples = [
            np.random.choice(self.indices_per_class[label],
                             size=class_count,
                             replace=True)
            for label, class_count in zip(self.unique_labels,
                                          sampled_class_counts)
        ]
        # Flatten in a single pass; ``sum(lists, [])`` copies the growing
        # result once per class.
        return list(itertools.chain.from_iterable(per_class_samples))


class DirichletMixtureDataSampler:
    """
    Data sampling mechanism backed by a mixture of ``DirichletDataSampler``.
    Every call first picks a mixture component with probabilities ``phi``
    and then delegates to the sampler built from that component's alpha.

    :param phi:
        Probability vector over the mixture components.
    :param alphas:
        One Dirichlet parameter vector per mixture component.
    :param labels:
        A one-dimensional array of all labels, shared by all components.
    """

    def __init__(self, phi: np.ndarray, alphas: np.ndarray,
                 labels: np.ndarray):
        self.phi = phi
        self.dirichlet_samplers = [
            DirichletDataSampler(alpha, labels) for alpha in alphas
        ]

    def __call__(self, n: int):
        """
        Sample a list of ``n`` datapoint indices from a randomly chosen
        mixture component.
        """
        component = np.random.choice(len(self.phi), p=self.phi)
        return self.dirichlet_samplers[component](n)
+ This is done by simply iterating through the sample space in linear fashion + and starting over once the end is reached. + Every data sampling mechanism returns a list of indices when called. + The indices can be used to construct an artificial user dataset. + :param max_bound: + Maximum bound for sampling space. + Will sample in the range `[0, max_bound)`. + """ + + def __init__(self, max_bound): + self._index_iter = itertools.cycle(range(max_bound)) + + def __call__(self, n): + """ + Sample a list of data point indices. + :param n: + Number of samples to draw. + :returns: + Sampled indices in range `[0, max_bound)`. + """ + return list(itertools.islice(self._index_iter, n)) + + +def get_data_sampler(sample_type, + max_bound=None, + phi=None, + alphas=None, + alpha=None, + labels=None): + """ + Factory for data sampling mechanisms. + These samplers can be used when sampling data points for an artificial + user dataset in `ArtificialFederatedDataset`, by providing it as the + `sampler` argument. + Implemented samplers: + * random - Randomly sample from the range `[0, max_bound)`. + * minimize_reuse - Sample while minimizing the number of times a number is + sampled again. + * dirichlet - Sample class proportions from a Dirichlet with given alpha + parameter and sample classes according to these proportions. 
+ """ + if sample_type == 'random': + return lambda n: np.random.randint(0, max_bound, size=n) + elif sample_type == 'minimize_reuse': + return MinimizeReuseDataSampler(max_bound) + elif sample_type == 'dirichlet': + return DirichletDataSampler(alpha, labels) + elif sample_type == 'dirichlet_mixture': + return DirichletMixtureDataSampler(phi, alphas, labels) + else: + raise NotImplementedError diff --git a/publications/mdm/mdm_utils/models/__init__.py b/publications/mdm/mdm_utils/models/__init__.py new file mode 100644 index 0000000..92c2260 --- /dev/null +++ b/publications/mdm/mdm_utils/models/__init__.py @@ -0,0 +1 @@ +from .pytorch_model import simple_cnn diff --git a/publications/mdm/mdm_utils/models/argument_parsing.py b/publications/mdm/mdm_utils/models/argument_parsing.py new file mode 100644 index 0000000..3d70d70 --- /dev/null +++ b/publications/mdm/mdm_utils/models/argument_parsing.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- + +import argparse +from typing import Optional, Tuple + + +def add_model_arguments( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """ + Add `model_name` argument to parser and add + model-specific arguments depending on the model specified in + `model_name` argument. + """ + + parser.add_argument( + '--model_name', + choices=[ + 'simple_cnn', 'simple_dnn', 'resnet18', 'lm_lstm', + 'lm_transformer', 'multi_label_cnn' + ], + default='simple_cnn', + help='Which model to train. See models.py for definitions.') + + # Get the value of `model_name` argument and dynamically add + # arguments depending on which model is chosen. 
+ known_args, _ = parser.parse_known_args() + + if known_args.model_name in {'lm_lstm', 'lm_transformer'}: + parser.add_argument("--embedding_size", + type=int, + required=True, + help='Number of dimensions in embedding layer.') + + if known_args.model_name == 'lm_lstm': + parser.add_argument("--num_cell_states", + type=int, + required=True, + help='Number of cell states in each LSTM.') + + parser.add_argument("--num_lstm_layers", + type=int, + required=True, + help='Number of stacked LSTM layers.') + + if known_args.model_name == 'lm_transformer': + parser.add_argument("--hidden_size", + type=int, + required=True, + help='Number of hidde states in each Transformer.') + + parser.add_argument( + "--num_heads", + type=int, + required=True, + help='Number of heads in multi-head attention layers.') + + parser.add_argument( + '--feedforward_size', + type=int, + required=True, + help='Number of feed forward hidden states in each Transformer.') + + parser.add_argument("--num_transformer_layers", + type=int, + required=True, + help='Number of stacked Transformer layers.') + + parser.add_argument( + '--dropout_rate', + type=float, + default=0.1, + help='Dropout rate applied in the Transformer model.') + + if known_args.model_name == 'multi_label_cnn': + _torchvision_architectures = [ + 'alexnet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2', 'vgg11', 'vgg11_bn', + 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19', + 'squeezenet1_0', 'squeezenet1_1', 'inception_v3', 'densenet121', + 'densenet169', 'densenet201', 'densenet161', 'googlenet', + 'mobilenet_v2', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', + 'mnasnet1_3', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', + 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' + ] + + parser.add_argument('--model_type', + choices=_torchvision_architectures, + help='Model architecture.') + + return parser + + +def _get_model_dims_for_dataset( 
+ dataset_name: str) -> Tuple[Optional[Tuple[int, ...]], Optional[int]]: + """ + Get the correct input shape and number of outputs for the + specified dataset. + """ + if dataset_name == 'femnist': + input_shape = (28, 28, 1) + num_outputs = 62 + elif dataset_name == 'femnist_digits': + input_shape = (28, 28, 1) + num_outputs = 10 + elif dataset_name in ['cifar10', 'cifar10_iid']: + input_shape = (32, 32, 3) + num_outputs = 10 + else: + input_shape = None + num_outputs = None + + return input_shape, num_outputs + + +def get_model_tf2(args: argparse.Namespace): + """ + Initialize the TensorFlow v2 model specified by ``args.model_name`` with + other required arguments also available in ``args``. + Use ``add_model_arguments`` to dynamically add arguments required by + the selected model. + """ + assert 'model_name' in vars(args) + from . import tf2 + + input_shape, num_outputs = _get_model_dims_for_dataset(args.dataset) + + model_name = args.model_name.lower() + if model_name == 'dnn': + model = tf2.dnn(input_shape, args.hidden_dims, num_outputs) + elif model_name == 'simple_dnn': + model = tf2.simple_dnn(input_shape, num_outputs) + elif model_name == 'simple_cnn': + model = tf2.simple_cnn(input_shape, num_outputs) + elif model_name == 'resnet18': + model = tf2.resnet18(input_shape, num_outputs) + elif model_name == 'lm_lstm': + model = tf2.lm_lstm(args.embedding_size, args.num_cell_states, + args.num_lstm_layers, args.vocab_size) + elif model_name == 'lm_transformer': + model = tf2.lm_transformer(args.embedding_size, args.hidden_size, + args.num_heads, args.feedforward_size, + args.num_transformer_layers, + args.vocab_size, args.max_sequence_length, + args.dropout_rate) + else: + raise TypeError(f'Model {model_name} not implemented for TF2.') + + return model + + +def get_model_pytorch(args: argparse.Namespace): + """ + Initialize the PyTorch model specified by ``args.model_name`` with + other required arguments also available in ``args``. 
+ Use ``add_model_arguments`` to dynamically add arguments required by + the selected model. + """ + assert 'model_name' in vars(args) + from . import pytorch + + input_shape, num_outputs = _get_model_dims_for_dataset(args.dataset) + + model_name = args.model_name.lower() + + if model_name == 'dnn': + model = pytorch.dnn(input_shape, args.hidden_dims, num_outputs) + elif model_name == 'simple_dnn': + model = pytorch.simple_dnn(input_shape, num_outputs) + elif model_name == 'simple_cnn': + model = pytorch.simple_cnn(input_shape, num_outputs) + elif model_name == 'lm_lstm': + model = pytorch.lm_lstm(args.embedding_size, args.num_cell_states, + args.num_lstm_layers, args.vocab_size, + args.pad_symbol, args.unk_symbol) + elif model_name == 'lm_transformer': + model = pytorch.lm_transformer( + args.embedding_size, args.hidden_size, args.num_heads, + args.feedforward_size, args.num_transformer_layers, + args.vocab_size, args.max_sequence_length, args.pad_symbol, + args.unk_symbol, args.dropout_rate) + elif model_name == 'multi_label_cnn': + model = pytorch.multi_label_cnn(args.model_type, args.num_classes, + args.channel_mean, + args.channel_stddevs, args.pretrained) + else: + raise TypeError(f'Model {model_name} not implemented for PyTorch.') + return model diff --git a/publications/mdm/mdm_utils/models/pytorch/__init__.py b/publications/mdm/mdm_utils/models/pytorch/__init__.py new file mode 100644 index 0000000..664abd6 --- /dev/null +++ b/publications/mdm/mdm_utils/models/pytorch/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .cnn import simple_cnn, multi_label_cnn +from .dnn import dnn, simple_dnn +from .lstm import lm_lstm +from .transformer import lm_transformer diff --git a/publications/mdm/mdm_utils/models/pytorch/cnn.py b/publications/mdm/mdm_utils/models/pytorch/cnn.py new file mode 100644 index 0000000..ef85a40 --- /dev/null +++ b/publications/mdm/mdm_utils/models/pytorch/cnn.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- + +import types +from 
typing import Tuple, List + +import numpy as np +import torch # type: ignore +import torch.nn as nn +import torch.nn.functional as F +from pfl.metrics import Weighted + +from .layer import Transpose2D +from .metrics import image_classification_metrics, image_classification_loss +from ..numpy.metrics import AveragedPrecision, MacroWeighted + + +def multi_label_cnn( + model_type: str, + num_outputs: int, + channel_mean: List[float], + channel_stddevs: List[float], + pretrained: bool, +): + """ + A CNN used for multi-label classification task. + + :param model_type: + The architecture of the model. + :param num_outputs: + Size of the output multi-label classification layer. + :param channel_mean: + Means for input image RGB channels. + :param channel_stddevs: + Standard deviations for input image RGB channels. + :param pretrained: + Whether to use ImageNet pretrained model. + + :return: + A Pytorch CNN module for multi-label classification. + """ + + import torchvision.models # type: ignore + import torchvision.transforms as transforms # type: ignore + from .module_modification import (validate_no_batchnorm, + freeze_batchnorm_modules, + convert_batchnorm_modules) + + torchvision_models = torchvision.models.__dict__ + + class MultiLabelCNN(nn.Module): + """ + Wrapper of torchvision.models used for Ramsay PFL training on multi-label + classification task, e.g. on FLAIR dataset. 
+ """ + + def __init__( + self, + torchvision_model_type: str, + num_outputs: int, + channel_mean: List[float], + channel_stddevs: List[float], + pretrained: bool, + ): + super().__init__() + self._num_outputs = num_outputs + + # input image transformation, same as standard ImageNet training + self.train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.Normalize(channel_mean, channel_stddevs) + ]) + self.eval_transform = transforms.Compose([ + transforms.Resize(224), + transforms.Normalize(channel_mean, channel_stddevs) + ]) + + # per-class binary cross-entropy for multi-label classification + # learning objective + self.loss_fct = nn.BCEWithLogitsLoss() + + # https://github.com/pytorch/examples/blob/master/imagenet/main.py + base_model = torchvision_models[torchvision_model_type]( + num_classes=self._num_outputs) + if pretrained: + pretrained_model = torchvision_models[torchvision_model_type]( + pretrained=True) + pretrained_state = pretrained_model.state_dict() + # Pretrained models typically use Batch Normalization. Since we do + # not want to collect channel statistics in private learning, we + # freeze the trained statistics in all batch norm modules, i.e., + # the statistics will be from pretrained dataset (ImageNet) instead + # of private data on device. 
+ base_model = freeze_batchnorm_modules(base_model) + base_state = base_model.state_dict() + state_to_load = {} + # skip loading the final classifier layer's weight and bias + for k, v in list(pretrained_state.items())[:-2]: + assert k in base_state and v.size() == base_state[k].size() + state_to_load[k] = v + base_model.load_state_dict(state_to_load, strict=False) + self.base_model = base_model + else: + # convert all batch norm module to group norm if not using + # pretrained models + self.base_model = convert_batchnorm_modules(base_model) + + # assert there is no batch norm module in current model + validate_no_batchnorm(self) + + def transform(self, images: torch.Tensor): + images = (images.float() / 255.0).permute(0, 3, 1, 2) + if self.training: + return self.train_transform(images) + else: + return self.eval_transform(images) + + def forward(self, images): + x = self.transform(images) + return self.base_model(x) + + def loss(self, inputs, targets, eval=False): + self.eval() if eval else self.train() + return self.loss_fct(self(inputs), targets) + + @torch.no_grad() + def metrics(self, inputs, targets, eval=True): + self.eval() if eval else self.train() + logits = self(inputs) + + num_data = len(inputs) + num_predictions = np.ones(self._num_outputs) * num_data + summed_loss = F.binary_cross_entropy_with_logits( + logits, targets, reduction='none').sum(dim=0) + + scores = torch.sigmoid(logits) + predictions = torch.round(scores) + correct = torch.sum(torch.eq(targets, predictions), dim=0) + + # evaluate precision and recall + predictions = predictions.bool() + targets = targets.bool() + # true positives, positive label predicted as positive + tps = torch.sum(predictions & targets, dim=0) + # false positives, negative label predicted as positive + fps = torch.sum(predictions & ~targets, dim=0) + # false negatives, positive label predicted as negative + fns = torch.sum(~predictions & targets, dim=0) + # micro true positives, false positives, false negatives + 
tps_sum = tps.sum().item() + fps_sum = fps.sum().item() + fns_sum = fns.sum().item() + + return { + # Micro metrics: averaged over all predictions + "micro loss": + Weighted(summed_loss.sum().item(), num_predictions.sum()), + "micro accuracy": + Weighted(correct.sum().item(), num_predictions.sum()), + "micro precision": + Weighted(tps_sum, tps_sum + fps_sum), + "micro recall": + Weighted(tps_sum, tps_sum + fns_sum), + "micro AP": + AveragedPrecision(y_true=targets.cpu().numpy(), + y_pred=scores.cpu().numpy(), + multi_label=False), + # Macro metrics: averaged over classes + "macro loss": + MacroWeighted(summed_loss.cpu().numpy(), num_predictions), + "macro accuracy": + MacroWeighted(correct.cpu().numpy(), num_predictions), + "macro precision": + MacroWeighted(tps.cpu().numpy(), + tps.cpu().numpy() + fps.cpu().numpy()), + "macro recall": + MacroWeighted(tps.cpu().numpy(), + tps.cpu().numpy() + fns.cpu().numpy()), + "macro AP": + AveragedPrecision(y_true=targets.cpu().numpy(), + y_pred=scores.cpu().numpy(), + multi_label=True), + } + + return MultiLabelCNN(model_type, num_outputs, channel_mean, + channel_stddevs, pretrained) + + +def simple_cnn(input_shape: Tuple[int, ...], num_outputs: int) -> nn.Module: + """ + A simple CNN with 2 convolutional layers and one dense hidden layer. + + :param input_shape: + The shape of the input images, e.g. (32,32,3). + :param num_outputs: + Size of output softmax layer. + :return: + A PyTorch CNN model. 
def dnn(input_shape: Tuple[int, ...], hidden_dims: Tuple[int, ...],
        num_outputs: int) -> nn.Module:
    """
    Define a feed-forward neural network in PyTorch.

    :param input_shape:
        The shape of the input data (excluding batch size). E.g. if the
        input is an image of dimensions (12,12,3), then it will be flattened
        into a 432-dimensional vector before propagated through the network.
    :param hidden_dims:
        A tuple describing the size of each hidden layer.
    :param num_outputs:
        Size of output softmax layer.
    :return:
        A PyTorch DNN model with ``loss`` and ``metrics`` methods attached.
    """
    in_features = int(np.prod(input_shape))
    layers = [nn.Flatten()]
    for dim in hidden_dims:
        layers.extend([nn.Linear(in_features, dim), nn.ReLU()])
        in_features = dim
    layers.append(nn.Linear(in_features, num_outputs))
    model = nn.Sequential(*layers)
    # Bind the shared classification loss/metrics as methods of this model.
    model.loss = types.MethodType(image_classification_loss, model)
    model.metrics = types.MethodType(image_classification_metrics, model)
    return model


def simple_dnn(input_shape: Tuple[int, ...], num_outputs: int) -> nn.Module:
    """
    Define a feed-forward neural network with 2 hidden layers of size 200.
    This is the same architecture as used in
    McMahan et al. 2017 https://arxiv.org/pdf/1602.05629.pdf.
    See ``dnn`` for description about parameters.
    """
    return dnn(input_shape, hidden_dims=(200, 200), num_outputs=num_outputs)


class _FrozenBatchNorm(_NormBase, ABC):
    """
    A special batch normalization module that will freeze the statistics
    during training and only update the affine parameters.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``input`` with the stored running statistics.

        Subclasses provide ``_check_input_dim`` to validate the input rank.
        """
        self._check_input_dim(input)

        # Turn off training so no batchnorm statistics are collected
        # and use pretrained statistics in training as well.
        # NOTE: this flips the flag on the module itself, so it stays
        # ``False`` after the call until the mode is set again.
        self.training = False

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        # Batch statistics are only used when there are no stored ones.
        bn_training = (self.running_mean is None) and (self.running_var
                                                       is None)

        return F.batch_norm(
            input,
            # If buffers are not tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats else None,
            self.running_var
            if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps)


class FrozenBatchNorm1D(_FrozenBatchNorm):
    """ Frozen batch normalization for 2D or 3D inputs. """

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'.format(
                input.dim()))


class FrozenBatchNorm2D(_FrozenBatchNorm):
    """ Frozen batch normalization for 4D inputs. """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'.format(
                input.dim()))


class FrozenBatchNorm3D(_FrozenBatchNorm):
    """ Frozen batch normalization for 5D inputs. """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(
                input.dim()))


class Transpose2D(nn.Module):
    """
    Transpose Tensorflow style image to PyTorch compatible
    """

    def forward(self, inputs: torch.Tensor):
        # Move the last axis to position 1: (0, 1, 2, 3) -> (0, 3, 1, 2).
        return inputs.permute((0, 3, 1, 2))


class PositionalEncoding(nn.Module):
    """
    Adds a fixed positional encoding to its input.

    The encoding is computed once by ``positional_encoding`` and stored as
    the (non-trainable) buffer ``pe``.

    :param embedding_size:
        Passed to ``positional_encoding`` as its second argument.
    :param max_sequence_length:
        Passed to ``positional_encoding`` as its first argument.
    """

    def __init__(self, embedding_size: int, max_sequence_length: int):
        super().__init__()
        pe = positional_encoding(max_sequence_length, embedding_size)
        self.register_buffer('pe', torch.as_tensor(pe))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Use the first ``x.size(1)`` rows of the encoding and broadcast
        # them over the leading axis of ``x``.
        return x + torch.unsqueeze(self.pe[:x.size(1)], 0)
b/publications/mdm/mdm_utils/models/pytorch/metrics.py new file mode 100644 index 0000000..2f9f096 --- /dev/null +++ b/publications/mdm/mdm_utils/models/pytorch/metrics.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +from typing import Dict + +import torch +import torch.nn as nn + +from pfl.metrics import Weighted + + +def cross_entropy(logits: torch.Tensor, targets: torch.Tensor, + reduction: str) -> torch.Tensor: + """ PyTorch cross entropy loss """ + # TODO: support logits with more than 2 dimensions + assert logits.ndim == 2, f"expect 2D tensor, get {logits.ndim}D" + loss_fct = nn.CrossEntropyLoss(reduction=reduction) + if targets.dim() > 1: + targets = targets.squeeze() + return loss_fct(logits, targets.long()) + + +@torch.no_grad() +def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ PyTorch classification accuracy """ + # TODO: support logits with more than 2 dimensions + assert logits.ndim == 2, f"expect 2D tensor, get {logits.ndim}D" + correct = logits.argmax(-1) == targets.squeeze().long() + return correct.float().sum() + + +def image_classification_loss(self: nn.Module, + inputs: torch.Tensor, + targets: torch.Tensor, + eval: bool = False) -> torch.Tensor: + """ Loss function to be attached to `PyTorchModel` for classification """ + self.eval() if eval else self.train() + return cross_entropy(self(inputs), targets, "mean") + + +@torch.no_grad() +def image_classification_metrics(self: nn.Module, + inputs: torch.Tensor, + targets: torch.Tensor, + eval: bool = True) -> Dict[str, Weighted]: + """ Metrics function to be attached to `PyTorchModel` for classification """ + self.eval() if eval else self.train() + logits = self(inputs) + num_samples = len(inputs) + correct = accuracy(logits, targets).item() + loss = cross_entropy(logits, targets, "sum").item() + return { + "loss": Weighted(loss, num_samples), + "accuracy": Weighted(correct, num_samples) + } diff --git a/publications/mdm/mdm_utils/models/pytorch/module_modification.py 
b/publications/mdm/mdm_utils/models/pytorch/module_modification.py new file mode 100644 index 0000000..39d9047 --- /dev/null +++ b/publications/mdm/mdm_utils/models/pytorch/module_modification.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- + +from typing import Callable, Type + +from torch import nn + +from .layer import FrozenBatchNorm1D, FrozenBatchNorm2D, FrozenBatchNorm3D + + +def _replace_child(root: nn.Module, child_name: str, + converter: Callable[[nn.Module], nn.Module]) -> None: + """ + Converts a sub-module to a new module given a helper + function, the root module and a string representing + the name of the submodule to be replaced. + """ + # find the immediate parent + parent = root + nameList = child_name.split(".") + for name in nameList[:-1]: + parent = parent._modules[name] + # set to identity + parent._modules[nameList[-1]] = converter(parent._modules[nameList[-1]]) + + +def replace_all_modules( + root: nn.Module, + target_class: Type[nn.Module], + converter: Callable[[nn.Module], nn.Module], +) -> nn.Module: + """ + Converts all the submodules (of root) that have the same + type as target_class, given a converter, a module root, + and a target class type. + """ + # base case + if isinstance(root, target_class): + return converter(root) + + for name, obj in root.named_modules(): + if isinstance(obj, target_class): + _replace_child(root, name, converter) + return root + + +def _batchnorm_to_groupnorm( + module: nn.modules.batchnorm._BatchNorm) -> nn.Module: + """ + Converts a BatchNorm ``module`` to GroupNorm module. + This is a helper function. 
+ + Notes: + A default value of 32 is chosen for the number of groups based on the + paper *Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour* + https://arxiv.org/pdf/1706.02677.pdf + """ + return nn.GroupNorm(min(32, module.num_features), + module.num_features, + affine=True) + + +def _batchnorm_to_freeze_batchnorm( + module: nn.modules.batchnorm._BatchNorm) -> nn.Module: + """ + Converts a BatchNorm module to the corresponding FrozenBatchNorm module. + This is useful for private finetuning models with BatchNorm module since + we do not want to collect training data statistics for updating BatchNorm + parameters. Instead, the statistics of FrozenBatchNorm is never updated. + """ + + def match_dim(): + if isinstance(module, nn.BatchNorm1d): + return FrozenBatchNorm1D + elif isinstance(module, nn.BatchNorm2d): + return FrozenBatchNorm2D + elif isinstance(module, nn.BatchNorm3d): + return FrozenBatchNorm3D + + return match_dim()(num_features=module.num_features, + eps=module.eps, + momentum=module.momentum, + affine=module.affine, + track_running_stats=module.track_running_stats) + + +def validate_no_batchnorm(module: nn.Module): + """ + Assert no regular batch normalization in model architecture. + """ + ans = not isinstance(module, nn.modules.batchnorm._BatchNorm) + for child in module.children(): + ans = ans and validate_no_batchnorm(child) + assert ans + return ans + + +def convert_batchnorm_modules( + model: nn.Module, + converter: Callable[[nn.modules.batchnorm._BatchNorm], + nn.Module] = _batchnorm_to_groupnorm, +) -> nn.Module: + """ + Converts all BatchNorm modules to another module + (defaults to GroupNorm) that is privacy compliant. + + :param model + Module instance, potentially with sub-modules + :param converter + Function or a lambda that converts an instance of a + Batchnorm to another nn.Module. 
+ + :return + Model with all the BatchNorm types replaced by another operation + by using the provided converter, defaulting to GroupNorm if one + isn't provided. + """ + return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, + converter) + + +def freeze_batchnorm_modules(model: nn.Module): + """ + Convert all BatchNorm modules to FrozenBatchNorm modules where the stats + are frozen and not updated during training. + + :param model + Module instance, potentially with sub-modules + :return + A `FrozenBatchNorm` module with the same dimension as input module. + """ + return convert_batchnorm_modules(model, _batchnorm_to_freeze_batchnorm) diff --git a/publications/mdm/mdm_utils/models/pytorch_model.py b/publications/mdm/mdm_utils/models/pytorch_model.py new file mode 100644 index 0000000..cadc368 --- /dev/null +++ b/publications/mdm/mdm_utils/models/pytorch_model.py @@ -0,0 +1,105 @@ +import types +from typing import Tuple, Dict + +import torch +from torch import nn + +from pfl.metrics import Weighted + +# Taken directly from the Ramsay Examples Repo, all methods needed to define the +# pytorch model used for training. + + +def simple_cnn(input_shape: Tuple[int, ...], num_outputs: int) -> nn.Module: + """ + A simple CNN with 2 convolutional layers and one dense hidden layer. + + :param input_shape: + The shape of the input images, e.g. (32,32,3). + :param num_outputs: + Size of output softmax layer. + :return: + A PyTorch CNN model. 
+ """ + in_channels = input_shape[-1] + maxpool_output_size = (input_shape[0] - 4) // 2 + flatten_size = maxpool_output_size * maxpool_output_size * 64 + + model = nn.Sequential(*[ + Transpose2D(), + nn.Conv2d(in_channels, 32, kernel_size=(3, 3)), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=(3, 3)), + nn.ReLU(), + nn.MaxPool2d((2, 2)), + nn.Dropout(0.25), + nn.Flatten(), + nn.Linear(flatten_size, 128), + nn.ReLU(), + nn.Dropout(0.5), + nn.Linear(128, num_outputs), + ]) + + # Apply Glorot (Xavier) uniform initialization to match TF2 model. + for m in model.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + + model.loss = types.MethodType(image_classification_loss, model) + model.metrics = types.MethodType(image_classification_metrics, model) + return model + + +class Transpose2D(nn.Module): + """ + Transpose Tensorflow style image to PyTorch compatible + """ + + def forward(self, inputs: torch.Tensor): + return inputs.permute((0, 3, 1, 2)) + + +def image_classification_loss(self: nn.Module, + inputs: torch.Tensor, + targets: torch.Tensor, + eval: bool = False) -> torch.Tensor: + """ Loss function to be attached to `PyTorchModel` for classification """ + self.eval() if eval else self.train() + return cross_entropy(self(inputs), targets, "mean") + + +def cross_entropy(logits: torch.Tensor, targets: torch.Tensor, + reduction: str) -> torch.Tensor: + """ PyTorch cross entropy loss """ + # TODO: support logits with more than 2 dimensions + assert logits.ndim == 2, f"expect 2D tensor, get {logits.ndim}D" + loss_fct = nn.CrossEntropyLoss(reduction=reduction) + if targets.dim() > 1: + targets = targets.squeeze() + return loss_fct(logits, targets.long()) + + +@torch.no_grad() +def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ PyTorch classification accuracy """ + # TODO: support logits with more than 2 dimensions + assert logits.ndim == 2, f"expect 2D tensor, get {logits.ndim}D" + 
correct = logits.argmax(-1) == targets.squeeze().long() + return correct.float().sum() + + +@torch.no_grad() +def image_classification_metrics(self: nn.Module, + inputs: torch.Tensor, + targets: torch.Tensor, + eval: bool = True) -> Dict[str, Weighted]: + """ Metrics function to be attached to `PyTorchModel` for classification """ + self.eval() if eval else self.train() + logits = self(inputs) + num_samples = len(inputs) + correct = accuracy(logits, targets).item() + loss = cross_entropy(logits, targets, "sum").item() + return { + "loss": Weighted(loss, num_samples), + "accuracy": Weighted(correct, num_samples) + } diff --git a/publications/mdm/mdm_utils/utils/__init__.py b/publications/mdm/mdm_utils/utils/__init__.py new file mode 100644 index 0000000..b3fccc1 --- /dev/null +++ b/publications/mdm/mdm_utils/utils/__init__.py @@ -0,0 +1,5 @@ +from .argument_parsing import (add_dataset_args, add_experiment_args, + add_init_algorithm_args, add_algorithm_args, + add_mle_args, add_histogram_algorithm_args, + add_user_visualisation_args, + add_dataset_preprocessing_args) diff --git a/publications/mdm/mdm_utils/utils/argument_parsing.py b/publications/mdm/mdm_utils/utils/argument_parsing.py new file mode 100644 index 0000000..2650e66 --- /dev/null +++ b/publications/mdm/mdm_utils/utils/argument_parsing.py @@ -0,0 +1,154 @@ +import argparse + + +class store_bool(argparse.Action): + + def __init__(self, option_strings, dest, **kwargs): + argparse.Action.__init__(self, option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + false_values = set(['false', 'no']) + true_values = set(['true', 'yes']) + + values = values.lower() + + if not values in (false_values | true_values): + raise argparse.ArgumentError( + self, 'Value must be either "true" or "false"') + value = (values in true_values) + + setattr(namespace, self.dest, value) + + +def add_experiment_args(parser): + parser.add_argument('--seed', type=int, default=0) + 
parser.add_argument('--data_dir', type=str) + parser.add_argument('--dirname', type=str) + parser.add_argument('--mle_param_dirname', type=str, default='publications/mdm/mle_params') + parser.add_argument( + '--precomputed_parameter_filepath', + type=str, + default=None, + # default='saved_mle_params/femnist_2_mixture.pkl', + help='If given then the inferred dirichlet mixture' + 'params will be loaded from here, do not specify file extension') + parser.add_argument('--dataset_type', + type=str, + default='original', + choices=[ + 'original', + 'original_labels_uniform_datapoints', + 'polya_mixture_federated', + 'polya_mixture_artificial_federated', + 'uniform_federated', + 'uniform_artificial_federated', + ]) + return parser + + +def add_dataset_preprocessing_args(parser): + parser.add_argument('--filter_method', + type=str, + default=None, + choices=['index', 'sample']) + parser.add_argument('--sample_fraction', type=float, default=1.0) + parser.add_argument('--start_idx', type=int, default=0) + parser.add_argument('--end_idx', type=int, default=-1) + parser.add_argument('--include_sampled', action=store_bool, default=True) + return parser + + +def float_list(arg): + try: + float_values = [float(val) for val in arg.split()] + return float_values + except ValueError: + raise argparse.ArgumentTypeError("Invalid float values in the list") + + +def int_list(arg): + try: + int_values = [int(val) for val in arg.split()] + return int_values + except ValueError: + raise argparse.ArgumentTypeError("Invalid int values in the list") + + +def add_dataset_args(parser): + # Args that are used to create a dirichlet mixture partition of cifar10 + parser.add_argument( + '--component_mean_user_dataset_length', + type=int_list, + default=50, + help="Mean number of samples per user in each component") + parser.add_argument('--component_phi', + type=float_list, + default=1.0, + help='True mixture component weights') + parser.add_argument( + '--component_alphas', + type=float_list, + 
default=0.1, + help='The alpha value for each mixture component ' + '(all entries of vector/all categories take same value)') + return parser + + +def add_init_algorithm_args(parser): + parser.add_argument('--cohort_size_init_algorithm', type=int) + parser.add_argument('--max_num_samples_mixture_component_init_algorithm', + type=int) + parser.add_argument('--strategy', type=str, default='random') + parser.add_argument('--central_num_iterations_init_algorithm', + type=int, + default=1) + return parser + + +def add_algorithm_args(parser): + parser.add_argument('--cohort_size_algorithm', type=int) + parser.add_argument('--max_num_samples_mixture_component_algorithm', + type=int) + parser.add_argument('--central_num_iterations_algorithm', + type=int, + default=10) + return parser + + +def add_histogram_algorithm_args(parser): + parser.add_argument('--cohort_size_histogram_algorithm', type=int) + parser.add_argument('--central_num_iterations_histogram_algorithm', + type=int, + default=1) + parser.add_argument('--num_bins_histogram', type=int, default=500) + return parser + + +def add_mle_args(parser): + # Add the arguments related to the solving of the Polya Mixture MLE + parser.add_argument( + '--num_mixture_components', + type=int, + default=1, + help='Number of Polya mixture components to try to infer') + return parser + + +def add_user_visualisation_args(parser): + parser.add_argument('--cohort_size_visualization', type=int, default=100) + parser.add_argument('--num_iterations_visualization', type=int, default=20) + return parser + + +def add_flair_visualisation_args(parser): + parser.add_argument('--use_fine_grained_labels', + action=store_bool, + default=False, + help='Whether to use fine-grained label taxonomy.') + + parser.add_argument('--max_num_user_images', + type=int, + default=100, + help='Maximum number of images per user') + + return parser diff --git a/publications/mdm/mdm_utils/utils/tools.py b/publications/mdm/mdm_utils/utils/tools.py new file mode 
100644 index 0000000..6280ee1 --- /dev/null +++ b/publications/mdm/mdm_utils/utils/tools.py @@ -0,0 +1,55 @@ +import os +from typing import Tuple + +from pfl.callback import TrainingProcessCallback +from pfl.internal.ops.selector import get_default_framework_module as get_ops + +from pfl.metrics import Metrics +from pfl.model.base import StatefulModel + + +class ModelCheckpointingIterationCallback(TrainingProcessCallback): + """ + Callback to save model checkpoints. + + :param model_checkpoint_dir: + A path to disk for saving the trained model. + :param checkpoint_frequency: + The number of central iterations after which to save a model. + When zero (the default), the model is saved once after + training is complete. + """ + + def __init__(self, + model_checkpoint_dir: str, + *, + checkpoint_frequency: int = 0): + if get_ops().distributed.local_rank == 0: + self.checkpoint_frequency = checkpoint_frequency + from pfl.internal.platform.selector import get_platform + self.model_checkpoint_dir = get_platform( + ).create_checkpoint_directories([model_checkpoint_dir])[0] + + def _should_checkpoint_now(self, central_iteration: int) -> bool: + """ + Return true when the number of `central_iteration`s that have + completed is a non-zero multiple of `self.checkpoint_frequency`. 
+ """ + return (self.checkpoint_frequency > 0 + and central_iteration % self.checkpoint_frequency + == (self.checkpoint_frequency - 1)) + + def after_central_iteration( + self, aggregate_metrics: Metrics, model: StatefulModel, *, + central_iteration: int) -> Tuple[bool, Metrics]: + if get_ops().distributed.local_rank == 0: + if self._should_checkpoint_now(central_iteration): + model.save( + os.path.join(self.model_checkpoint_dir, + f'{central_iteration}')) + return False, Metrics() + + def on_train_end(self, *, model: StatefulModel) -> None: + if get_ops().distributed.local_rank == 0: + if self.checkpoint_frequency == 0: + model.save(self.model_checkpoint_dir + '_end') diff --git a/publications/mdm/mdm_utils/utils/visualize_results.py b/publications/mdm/mdm_utils/utils/visualize_results.py new file mode 100644 index 0000000..281f39d --- /dev/null +++ b/publications/mdm/mdm_utils/utils/visualize_results.py @@ -0,0 +1,71 @@ +import os +from itertools import product + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +PARENT_DIR = os.path.dirname(os.getcwd()) + + +def plot_cifar10_results(): + filename = os.path.join(PARENT_DIR, 'results', 'cifar10_hp_search.csv') + df = pd.read_csv(filename) + experiments = np.unique(df['experiment'].values).tolist() + + dfs = dict() + for experiment in experiments: + dfs[experiment] = (df.loc[df['experiment'] == experiment]) + + column_names = [ + 'cohort_size', 'local_batch_size', 'local_learning_rate', + 'local_num_epochs' + ] + unique_vals = dict() + for column_name in column_names: + unique_vals[column_name] = np.unique(dfs['live'][column_name]).tolist() + + accs = dict() + for name, df in dfs.items(): + accs[name] = dict() + for tup in product(*unique_vals.values()): + filter_dic = dict(zip(column_names, tup)) + a = df.loc[(df[list(filter_dic)] == pd.Series(filter_dic)).all( + axis=1)]['Central val | accuracy (avg)'].values.mean() + accs[name][tup] = a + + x = np.array(list(accs['live'].values())) + 
permutation = np.argsort(-x) + mask = np.array(list(accs['live'].values()))[permutation] >= 0.6 + + dic = dict() + c = dict(zip(accs.keys(), ['blue', 'red', 'green'])) + + plt.rcParams.update({'font.size': 13}) + for name, d in accs.items(): + x = np.array(list(d.values()))[permutation][mask] + dic[name] = x + + plt.plot(x, label=name, c=c[name]) + + plt.xlabel('Random hyperparameter setting') + plt.ylabel('Classification accuracy (%)') + plt.legend() + plt.show() + + print( + np.mean( + np.abs( + np.array(dic['live']) - np.array(dic['simulated_dirichlet'])))) + print( + np.mean( + np.abs(np.array(dic['live']) - + np.array(dic['simulated_uniform'])))) + + +def main(): + plot_cifar10_results() + + +if __name__ == '__main__': + main() diff --git a/publications/mdm/run_cifar10_mse_alpha_phi_experiments.sh b/publications/mdm/run_cifar10_mse_alpha_phi_experiments.sh new file mode 100755 index 0000000..23e6be9 --- /dev/null +++ b/publications/mdm/run_cifar10_mse_alpha_phi_experiments.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +set -oeux + +export PYTHONPATH=. 
+ +num_components='1' +type='easy' + +if [ "$num_components" -eq 1 ] && [ "$type" == "easy" ]; then + + alphas='1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0' + phi='1.0' + user_dataset_length='20' + +elif [ "$num_components" -eq 2 ] && [ "$type" == "easy" ]; then + + alphas='1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8' + phi='0.5 0.5' + user_dataset_length='20 20' + +elif [ "$num_components" -eq 3 ] && [ "$type" == "easy" ]; then + + alphas='1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 1.1 1.1 1.1 1.1 1.1 1.1 1.1 1.1 1.1 1.1' + phi='0.334 0.333 0.333' + user_dataset_length='20 20 20' + +elif [ "$num_components" -eq 1 ] && [ "$type" == "medium" ]; then + + alphas='0.1 0.2 0.6 1.0 2.0 0.1 1.0 2.0 0.5 0.5' + phi='1.0' + user_dataset_length='20' + +elif [ "$num_components" -eq 2 ] && [ "$type" == "medium" ]; then + + alphas='0.1 0.2 0.6 1.0 2.0 0.1 1.0 2.0 0.5 0.5 1.1 0.1 0.7 0.9 1.0 0.2 0.5 1.0 0.8 1.5' + phi='0.4 0.6' + user_dataset_length='20 20' + +elif [ "$num_components" -eq 3 ] && [ "$type" == "medium" ]; then + + alphas='0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.1 0.2 0.6 1.0 2.0 0.1 1.0 2.0 0.5 0.5 1.1 0.1 0.7 0.9 1.0 0.2 0.5 1.0 0.8 1.5' + phi='0.2 0.3 0.5' + user_dataset_length='20 20 20' + +elif [ "$num_components" -eq 1 ] && [ "$type" == "hard" ]; then + + alphas=0.1 + phi='1.0' + user_dataset_length='20' + +elif [ "$num_components" -eq 2 ] && [ "$type" == "hard" ]; then + + alphas='0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3' + phi='0.1 0.9' + user_dataset_length='20 20' + +elif [ "$num_components" -eq 3 ] && [ "$type" == "hard" ]; then + + alphas='0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3 1.1 0.1 0.7 0.9 1.0 0.2 0.5 1.0 0.8 1.5' + phi='0.05 0.15 0.8' + user_dataset_length='20 20 20' + +else + + echo 'fail' + +fi + +python3 publications/mdm/mdm_paper/training/train.py --num_mixture_components "$num_components" 
--component_mean_user_dataset_length "$user_dataset_length" --component_alphas "$alphas" --cohort_size_init_algorithm 1000 --central_num_iterations_init_algorithm 1 --max_num_samples_mixture_component_init_algorithm 40 --cohort_size_algorithm 1000 --central_num_iterations_algorithm 60 --component_phi "$phi" --data_dir data/cifar10 --dirname "$type" diff --git a/publications/mdm/run_femnist.sh b/publications/mdm/run_femnist.sh new file mode 100755 index 0000000..419d978 --- /dev/null +++ b/publications/mdm/run_femnist.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -oeux + +export PYTHONPATH=. + +for num_components in 1 2 3; do + max_num_samples_mixture_component=450 + filter_method='sample' # 'sample' 'index' + sample_fraction=0.5 + include_sampled='True' + start_idx=0 + end_idx=1302 + + if [ "$filter_method" == "sample" ]; then + cohort_size=$(echo "scale=0; ($sample_fraction * 3400)" | bc | cut -d'.' -f1) + else + cohort_size=$(echo "$end_idx - $start_idx" | bc) + fi + + cohort_size_init_algorithm=$cohort_size + central_num_iterations_init_algorithm=1 + + cohort_size_algorithm=$cohort_size + central_num_iterations_algorithm=50 + + if [ "$filter_method" == "index" ]; then + python3 publications/mdm/mdm_paper/training/train_femnist_rebuttal.py --num_mixture_components "$num_components" --cohort_size_init_algorithm "$cohort_size_init_algorithm" --max_num_samples_mixture_component_init_algorithm "$max_num_samples_mixture_component" --central_num_iterations_init_algorithm "$central_num_iterations_init_algorithm" --cohort_size_algorithm "$cohort_size_algorithm" --max_num_samples_mixture_component_algorithm "$max_num_samples_mixture_component" --central_num_iterations_algorithm "$central_num_iterations_algorithm" --data_dir data/femnist --filter_method "$filter_method" --start_idx "$start_idx" --end_idx "$end_idx" + + elif [ "$filter_method" == "sample" ]; then + python3 publications/mdm/mdm_paper/training/train_femnist_rebuttal.py --num_mixture_components "$num_components" 
--cohort_size_init_algorithm "$cohort_size_init_algorithm" --max_num_samples_mixture_component_init_algorithm "$max_num_samples_mixture_component" --central_num_iterations_init_algorithm "$central_num_iterations_init_algorithm" --cohort_size_algorithm "$cohort_size_algorithm" --max_num_samples_mixture_component_algorithm "$max_num_samples_mixture_component" --central_num_iterations_algorithm "$central_num_iterations_algorithm" --data_dir data/femnist --filter_method "$filter_method" --sample_fraction "$sample_fraction" --include_sampled "$include_sampled" + + fi +done From d10e441d41663fedbafca86ab5db9d9423164eba Mon Sep 17 00:00:00 2001 From: fgranqvist Date: Tue, 4 Jun 2024 13:27:55 +0200 Subject: [PATCH 20/22] Prepare 0.2 release: changelog + bump version (#74) --- CHANGELOG.md | 23 +++++++++++++++++++++++ VERSION | 2 +- pfl/version.py | 2 +- pyproject.toml | 2 +- 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10438ac..c49ba7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,29 @@ * +## v0.2.0 + +### Breaking change! + +* `EMMGMMHyperParams` is renamed to `EMGMMHyperParams` (#55) + +### New features + +* Return local metadata from model training to algorithm (#71). + +### Tasks completed + +* Update FLAIR preprocessing script to download dataset from HuggingFace, available at https://huggingface.co/datasets/apple/flair (#72). +* Update LLM Benchmark Configs (#63). +* New improved worker scheduling in distributed simulations. Speeds up FLAIR benchmark by 19% (#73). +* Don't pin PyTorch version to 2.0.1 (#69). +* Move `--noise_cohort_size` to `add_mechanism_arguments` (#70). 
+ +### Bug fixes + +* + + ## v0.1.0 2024-03-01 diff --git a/VERSION b/VERSION index 6e8bf73..0ea3a94 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.0 +0.2.0 diff --git a/pfl/version.py b/pfl/version.py index 3dc1f76..d3ec452 100644 --- a/pfl/version.py +++ b/pfl/version.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/pyproject.toml b/pyproject.toml index bb3b6dc..7aba581 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -205,7 +205,7 @@ markers = [ # Run with command: # poetry run bump-my-version bump [tool.bumpversion] -current_version = "0.1.0" +current_version = "0.2.0" commit = "True" tag = "False" From 19ef43765b07fe514d83c7891fac2bf8181d2617 Mon Sep 17 00:00:00 2001 From: Filip Granqvist Date: Tue, 4 Jun 2024 13:49:34 +0200 Subject: [PATCH 21/22] same ruff ignores in publications as in benchmarks --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7aba581..242ebaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,6 +168,7 @@ ignore = [ "tests/integration/*" = ["S607"] "pfl/*" = ["S101", "S607"] "benchmarks/*" = ["S101", "SIM115", "A002", "B007", "E741", "S310"] +"publications/*" = ["S101", "SIM115", "A002", "B007", "E741", "S310"] [tool.coverage.report] skip_empty = true From 2065d11e735e9b6d32e90a863fd625fa3b51f081 Mon Sep 17 00:00:00 2001 From: ac554 <47990575+ac554@users.noreply.github.com> Date: Tue, 4 Jun 2024 21:07:43 +0100 Subject: [PATCH 22/22] ruff and mypy on publications/mdm (#78) --- .pre-commit-config.yaml | 1 + publications/__init__.py | 0 publications/mdm/__init__.py | 0 publications/mdm/mdm/__init__.py | 7 ++-- publications/mdm/mdm/algorithm.py | 20 ++++++------ publications/mdm/mdm/bridge/__init__.py | 0 publications/mdm/mdm/bridge/base.py | 7 ++-- publications/mdm/mdm/bridge/factory.py | 1 - .../mdm/mdm/bridge/pytorch/polya_mixture.py | 16 ++++++---- publications/mdm/mdm/init_algorithm.py | 18 ++++------- publications/mdm/mdm/model.py | 18 
+++++------ .../mdm/mdm_paper/notebooks/__init__.py | 0 publications/mdm/mdm_paper/training/mle.py | 28 ++++++++-------- publications/mdm/mdm_paper/training/train.py | 25 ++++++++------- .../mdm/mdm_paper/training/train_femnist.py | 21 ++++++------ .../training/train_femnist_rebuttal.py | 23 +++++++------ .../mdm/mdm_utils/datasets/__init__.py | 2 +- .../mdm/mdm_utils/datasets/cifar10_dataset.py | 13 +++----- .../mdm/mdm_utils/datasets/femnist_dataset.py | 32 ++++++++----------- .../mdm/mdm_utils/datasets/mixture_dataset.py | 6 ++-- .../mdm/mdm_utils/datasets/sampler.py | 2 -- .../mdm/mdm_utils/models/argument_parsing.py | 2 -- .../mdm/mdm_utils/models/pytorch/__init__.py | 4 +-- .../mdm/mdm_utils/models/pytorch/cnn.py | 16 ++++------ .../mdm/mdm_utils/models/pytorch/dnn.py | 8 ++--- .../mdm/mdm_utils/models/pytorch/layer.py | 18 ++++------- .../mdm/mdm_utils/models/pytorch/metrics.py | 2 -- .../models/pytorch/module_modification.py | 2 -- .../mdm/mdm_utils/models/pytorch_model.py | 4 +-- publications/mdm/mdm_utils/utils/__init__.py | 15 ++++++--- .../mdm/mdm_utils/utils/argument_parsing.py | 16 ++++++---- publications/mdm/mdm_utils/utils/tools.py | 1 - .../mdm/mdm_utils/utils/visualize_results.py | 10 +++--- 33 files changed, 159 insertions(+), 179 deletions(-) create mode 100644 publications/__init__.py create mode 100644 publications/mdm/__init__.py create mode 100644 publications/mdm/mdm/bridge/__init__.py create mode 100644 publications/mdm/mdm_paper/notebooks/__init__.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da4f768..e91a0b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,6 +45,7 @@ repos: rev: v1.5.0 hooks: - id: mypy + exclude: ^publications/ # TODO: license header hook diff --git a/publications/__init__.py b/publications/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/__init__.py b/publications/mdm/__init__.py new file mode 100644 index 0000000..e69de29 diff 
--git a/publications/mdm/mdm/__init__.py b/publications/mdm/mdm/__init__.py index 7ae9963..89c1ea9 100644 --- a/publications/mdm/mdm/__init__.py +++ b/publications/mdm/mdm/__init__.py @@ -1,4 +1,3 @@ -from .model import (MDMModelHyperParams, MDMModel) -from .init_algorithm import (MDMInitializationAlgorithmParams, - MDMInitializationAlgorithm) -from .algorithm import (MDMAlgorithmParams, MDMAlgorithm) +from .algorithm import MDMAlgorithm, MDMAlgorithmParams +from .init_algorithm import MDMInitializationAlgorithm, MDMInitializationAlgorithmParams +from .model import MDMModel, MDMModelHyperParams diff --git a/publications/mdm/mdm/algorithm.py b/publications/mdm/mdm/algorithm.py index 0f3be49..0f360c6 100644 --- a/publications/mdm/mdm/algorithm.py +++ b/publications/mdm/mdm/algorithm.py @@ -1,22 +1,18 @@ -# -*- coding: utf-8 -*- - from dataclasses import dataclass -from typing import Tuple, Optional, TypeVar, Callable, Union +from typing import Callable, Optional, Tuple, TypeVar, Union import numpy as np import torch +from pfl.algorithm.base import AlgorithmHyperParams, FederatedAlgorithm from pfl.common_types import Population -from pfl.data.dataset import AbstractDataset +from pfl.context import CentralContext +from pfl.data.dataset import AbstractDataset, AbstractDatasetType from pfl.hyperparam import get_param_value from pfl.metrics import Metrics -from pfl.context import CentralContext from pfl.stats import MappedVectorStatistics -from pfl.algorithm.base import FederatedAlgorithm, AlgorithmHyperParams -from pfl.data.dataset import AbstractDatasetType - -from publications.mdm.mdm.model import MDMModelType, MDMModelHyperParamsType from publications.mdm.mdm.bridge.factory import FrameworkBridgeFactory as bridges +from publications.mdm.mdm.model import MDMModelHyperParamsType, MDMModelType @dataclass(frozen=True) @@ -44,7 +40,8 @@ class MDMAlgorithmParams(AlgorithmHyperParams): class MDMAlgorithm(FederatedAlgorithm[MDMAlgorithmParamsType, MDMModelHyperParamsType, 
MDMModelType, - MappedVectorStatistics, AbstractDatasetType]): + MappedVectorStatistics, + AbstractDatasetType]): """ Federated algorithm class for learning mixture of Polya (Dirichlet-Multinomial) distribution using MLE algorithm. @@ -161,7 +158,8 @@ def simulate_one_user( e[:, selected_bin] = posterior_probabilities.view(-1) statistics = MappedVectorStatistics() - statistics['posterior_probabilities'] = posterior_probabilities.to('cpu') + statistics['posterior_probabilities'] = posterior_probabilities.to( + 'cpu') statistics['numerator'] = numerator.to('cpu') statistics['denominator'] = denominator.to('cpu') statistics['num_samples_distribution'] = e.to('cpu') diff --git a/publications/mdm/mdm/bridge/__init__.py b/publications/mdm/mdm/bridge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/mdm/bridge/base.py b/publications/mdm/mdm/bridge/base.py index f63e763..a9900ef 100644 --- a/publications/mdm/mdm/bridge/base.py +++ b/publications/mdm/mdm/bridge/base.py @@ -1,7 +1,4 @@ -# -*- coding: utf-8 -*- - -from typing import Any, Dict, Protocol, TypeVar, Tuple - +from typing import Any, Dict, Protocol, Tuple, TypeVar Tensor = TypeVar('Tensor') @@ -9,7 +6,7 @@ class PolyaMixtureFrameworkBridge(Protocol[Tensor]): """ Interface for Polya-Mixture algorithm for a particular Deep Learning - framework. + framework. 
""" @staticmethod diff --git a/publications/mdm/mdm/bridge/factory.py b/publications/mdm/mdm/bridge/factory.py index c395e26..4c3e449 100644 --- a/publications/mdm/mdm/bridge/factory.py +++ b/publications/mdm/mdm/bridge/factory.py @@ -8,7 +8,6 @@ ) from pfl.internal.ops.framework_types import MLFramework from pfl.internal.ops.selector import get_framework_module - from publications.mdm.mdm.bridge.base import PolyaMixtureFrameworkBridge diff --git a/publications/mdm/mdm/bridge/pytorch/polya_mixture.py b/publications/mdm/mdm/bridge/pytorch/polya_mixture.py index a43ed9c..463e42a 100644 --- a/publications/mdm/mdm/bridge/pytorch/polya_mixture.py +++ b/publications/mdm/mdm/bridge/pytorch/polya_mixture.py @@ -1,6 +1,5 @@ -# -*- coding: utf-8 -*- - from typing import Tuple + import torch from ..base import PolyaMixtureFrameworkBridge @@ -40,11 +39,13 @@ def category_probabilities_polya_mixture_initialization( def expectation_step(phi, alphas, num_samples_distribution, category_counts) -> torch.Tensor: if (num_samples_distribution == 0).any(): - raise AssertionError('num_samples_distribution contains zero values, which cannot work with expectation step on clients') + raise AssertionError( + 'num_samples_distribution contains zero values, which cannot work with expectation step on clients' + ) # E Step - compute posterior probability of each component # Compute log prior + log likelihood - # TODO log_v might be missing + torch.lgamma(torch.sum(counts)+1) - torch.sum(torch.lgamma(category_counts+1), dim=1, keepdim=False) + # TODO log_v might be missing + torch.lgamma(torch.sum(counts)+1) - torch.sum(torch.lgamma(category_counts+1), dim=1, keepdim=False) phi = torch.Tensor(phi).to('cpu') alphas = torch.Tensor(alphas).to('cpu') category_counts = category_counts.to('cpu') @@ -56,7 +57,7 @@ def expectation_step(phi, alphas, num_samples_distribution, torch.sum( torch.lgamma(category_counts + alphas) - torch.lgamma(alphas), dim=1, - keepdim=False)) + 
torch.log(num_samples_distribution) + keepdim=False)) + torch.log(num_samples_distribution) # TODO Ignore this as log(0) => NaN # TODO fix this equation so that it works with num_samples_distribution = 0 @@ -75,12 +76,13 @@ def expectation_step(phi, alphas, num_samples_distribution, @staticmethod def maximization_step(posterior_probabilities, category_counts, - alphas) -> torch.Tensor: + alphas) -> torch.Tensor: # M Step - compute client update to alphas for fixed point update # which will be applied by the model in process_aggregated_statistics. # Note the numerator and denominator are both weighted by w (the # probability vector giving the client belonging to each component). - posterior_probabilities = torch.Tensor(posterior_probabilities).to('cpu') + posterior_probabilities = torch.Tensor(posterior_probabilities).to( + 'cpu') category_counts = torch.Tensor(category_counts).to('cpu') alphas = torch.Tensor(alphas).to('cpu') numerator = posterior_probabilities.reshape( diff --git a/publications/mdm/mdm/init_algorithm.py b/publications/mdm/mdm/init_algorithm.py index 9a24c1b..9acaaee 100644 --- a/publications/mdm/mdm/init_algorithm.py +++ b/publications/mdm/mdm/init_algorithm.py @@ -1,24 +1,20 @@ -# -*- coding: utf-8 -*- - -from dataclasses import dataclass -from typing import Tuple, Optional, TypeVar, Callable, Union from collections import defaultdict +from dataclasses import dataclass +from typing import Callable, Optional, Tuple, TypeVar, Union import numpy as np import torch +from pfl.algorithm.base import AlgorithmHyperParams, FederatedAlgorithm from pfl.common_types import Population -from pfl.data.dataset import AbstractDataset +from pfl.context import CentralContext +from pfl.data.dataset import AbstractDataset, AbstractDatasetType from pfl.hyperparam import get_param_value +from pfl.internal.ops import get_ops from pfl.metrics import Metrics -from pfl.context import CentralContext from pfl.stats import MappedVectorStatistics -from pfl.internal.ops import 
get_ops -from pfl.algorithm.base import FederatedAlgorithm, AlgorithmHyperParams -from pfl.data.dataset import AbstractDatasetType - -from publications.mdm.mdm.model import MDMModelType, MDMModelHyperParamsType from publications.mdm.mdm.bridge.factory import FrameworkBridgeFactory as bridges +from publications.mdm.mdm.model import MDMModelHyperParamsType, MDMModelType @dataclass(frozen=True) diff --git a/publications/mdm/mdm/model.py b/publications/mdm/mdm/model.py index 3372fa3..65813f0 100644 --- a/publications/mdm/mdm/model.py +++ b/publications/mdm/mdm/model.py @@ -1,20 +1,18 @@ -# -*- coding: utf-8 -*- - -from typing import TypeVar, Generic, Tuple, List, Union, Optional -from dataclasses import dataclass import os -import joblib +from dataclasses import dataclass +from typing import Generic, List, Optional, Tuple, TypeVar, Union +import joblib import numpy as np -from pfl.hyperparam.base import ModelHyperParams -from pfl.model.base import Model -from pfl.metrics import Metrics -from pfl.stats import MappedVectorStatistics from pfl.exception import CheckpointNotFoundError -from pfl.internal.ops.selector import set_framework_module +from pfl.hyperparam.base import ModelHyperParams from pfl.internal.ops import pytorch_ops from pfl.internal.ops.selector import get_default_framework_module as get_ops +from pfl.internal.ops.selector import set_framework_module +from pfl.metrics import Metrics +from pfl.model.base import Model +from pfl.stats import MappedVectorStatistics Tensor = TypeVar('Tensor') FrameworkModelType = TypeVar('FrameworkModelType') diff --git a/publications/mdm/mdm_paper/notebooks/__init__.py b/publications/mdm/mdm_paper/notebooks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/publications/mdm/mdm_paper/training/mle.py b/publications/mdm/mdm_paper/training/mle.py index 3376740..88d9394 100644 --- a/publications/mdm/mdm_paper/training/mle.py +++ b/publications/mdm/mdm_paper/training/mle.py @@ -2,13 +2,16 @@ from 
pfl.aggregate.simulate import SimulatedBackend from pfl.callback import ModelCheckpointingCallback -from pfl.privacy import (CentrallyAppliedPrivacyMechanism, PLDPrivacyAccountant, GaussianMechanism) - +from pfl.privacy import CentrallyAppliedPrivacyMechanism, GaussianMechanism, PLDPrivacyAccountant +from publications.mdm.mdm import ( + MDMAlgorithm, + MDMAlgorithmParams, + MDMInitializationAlgorithm, + MDMInitializationAlgorithmParams, + MDMModel, + MDMModelHyperParams, +) from publications.mdm.mdm_utils.utils.tools import ModelCheckpointingIterationCallback -from publications.mdm.mdm import (MDMModel, MDMModelHyperParams, - MDMAlgorithm, MDMAlgorithmParams, - MDMInitializationAlgorithm, - MDMInitializationAlgorithmParams) def solve_polya_mixture_mle( @@ -33,13 +36,12 @@ def solve_polya_mixture_mle( if add_DP: num_iterations = arguments.central_num_iterations_init_algorithm + arguments.central_num_iterations_algorithm - accountant = PLDPrivacyAccountant( - num_compositions=num_iterations, - sampling_probability=0.001, - mechanism='gaussian', - epsilon=2, - delta=1e-7, - noise_scale=1.0) + accountant = PLDPrivacyAccountant(num_compositions=num_iterations, + sampling_probability=0.001, + mechanism='gaussian', + epsilon=2, + delta=1e-7, + noise_scale=1.0) mechanism = GaussianMechanism.from_privacy_accountant( accountant=accountant, clipping_bound=0.5) diff --git a/publications/mdm/mdm_paper/training/train.py b/publications/mdm/mdm_paper/training/train.py index 9431ccd..811cefe 100644 --- a/publications/mdm/mdm_paper/training/train.py +++ b/publications/mdm/mdm_paper/training/train.py @@ -1,24 +1,26 @@ -import os import argparse import datetime +import os +import joblib import numpy as np import torch -import joblib from pfl.internal.ops import pytorch_ops from pfl.internal.ops.selector import get_default_framework_module as get_ops from pfl.internal.ops.selector import set_framework_module from pfl.internal.platform.selector import get_platform - -from 
publications.mdm.mdm_utils.datasets import make_cifar10_datasets -from publications.mdm.mdm_utils.utils import (add_dataset_args, add_experiment_args, - add_mle_args, add_init_algorithm_args, - add_algorithm_args, - add_histogram_algorithm_args, - add_user_visualisation_args) - from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle +from publications.mdm.mdm_utils.datasets import make_cifar10_datasets +from publications.mdm.mdm_utils.utils import ( + add_algorithm_args, + add_dataset_args, + add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) def get_arguments(): @@ -104,7 +106,8 @@ def get_arguments(): print('simulated_dirichlet_mixture experiment') if arguments.precomputed_parameter_filepath is None: print('learn simulated_dirichlet_mixture parameters') - dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0] + dir_path = get_platform().create_checkpoint_directories( + [arguments.mle_param_dirname])[0] current_time = datetime.datetime.now() timestamp = current_time.strftime("%Y-%m-%d_%H-%M") save_dir = ( diff --git a/publications/mdm/mdm_paper/training/train_femnist.py b/publications/mdm/mdm_paper/training/train_femnist.py index dd6c90e..9b118b3 100644 --- a/publications/mdm/mdm_paper/training/train_femnist.py +++ b/publications/mdm/mdm_paper/training/train_femnist.py @@ -1,5 +1,5 @@ -import os import argparse +import os import joblib import numpy as np @@ -9,14 +9,16 @@ from pfl.internal.ops.selector import get_default_framework_module as get_ops from pfl.internal.ops.selector import set_framework_module from pfl.internal.platform.selector import get_platform - -from publications.mdm.mdm_utils.datasets import make_femnist_datasets -from publications.mdm.mdm_utils.utils import (add_experiment_args, add_mle_args, - add_init_algorithm_args, add_algorithm_args, - add_histogram_algorithm_args, - add_user_visualisation_args) - from 
publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle +from publications.mdm.mdm_utils.datasets import make_femnist_datasets +from publications.mdm.mdm_utils.utils import ( + add_algorithm_args, + add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) def get_arguments(): @@ -56,7 +58,8 @@ def get_arguments(): print('simulated_dirichlet_mixture experiment') if arguments.precomputed_parameter_filepath is None: print('learn simulated_dirichlet_mixture parameters') - dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0] + dir_path = get_platform().create_checkpoint_directories( + [arguments.mle_param_dirname])[0] save_dir = ( f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture' ) diff --git a/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py b/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py index 096c7ec..9848838 100644 --- a/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py +++ b/publications/mdm/mdm_paper/training/train_femnist_rebuttal.py @@ -1,5 +1,5 @@ -import os import argparse +import os import joblib import numpy as np @@ -9,15 +9,17 @@ from pfl.internal.ops.selector import get_default_framework_module as get_ops from pfl.internal.ops.selector import set_framework_module from pfl.internal.platform.selector import get_platform - -from publications.mdm.mdm_utils.datasets import make_femnist_datasets -from publications.mdm.mdm_utils.utils import (add_experiment_args, add_mle_args, - add_init_algorithm_args, add_algorithm_args, - add_histogram_algorithm_args, - add_user_visualisation_args, - add_dataset_preprocessing_args) - from publications.mdm.mdm_paper.training.mle import solve_polya_mixture_mle +from publications.mdm.mdm_utils.datasets import make_femnist_datasets +from publications.mdm.mdm_utils.utils import ( + add_algorithm_args, + add_dataset_preprocessing_args, + 
add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) def get_arguments(): @@ -63,7 +65,8 @@ def get_arguments(): print('simulated_dirichlet_mixture experiment') if arguments.precomputed_parameter_filepath is None: print('learn simulated_dirichlet_mixture parameters') - dir_path = get_platform().create_checkpoint_directories([arguments.mle_param_dirname])[0] + dir_path = get_platform().create_checkpoint_directories( + [arguments.mle_param_dirname])[0] save_dir = ( f'femnist_{arguments.dataset_type}_{arguments.num_mixture_components}_mixture_{arguments.filter_method}_filter_method' ) diff --git a/publications/mdm/mdm_utils/datasets/__init__.py b/publications/mdm/mdm_utils/datasets/__init__.py index e189449..533151f 100644 --- a/publications/mdm/mdm_utils/datasets/__init__.py +++ b/publications/mdm/mdm_utils/datasets/__init__.py @@ -1,3 +1,3 @@ -from .mixture_dataset import get_user_counts from .cifar10_dataset import make_cifar10_datasets from .femnist_dataset import make_femnist_datasets +from .mixture_dataset import get_user_counts diff --git a/publications/mdm/mdm_utils/datasets/cifar10_dataset.py b/publications/mdm/mdm_utils/datasets/cifar10_dataset.py index 3d06410..bd27c0e 100644 --- a/publications/mdm/mdm_utils/datasets/cifar10_dataset.py +++ b/publications/mdm/mdm_utils/datasets/cifar10_dataset.py @@ -1,19 +1,14 @@ -# -*- coding: utf-8 -*- - import os import pickle from typing import Callable, List, Optional, Tuple import numpy as np -from pfl.data import (ArtificialFederatedDataset, FederatedDataset, - FederatedDatasetBase) -from pfl.data.sampling import get_user_sampler, get_data_sampler +from pfl.data import ArtificialFederatedDataset, FederatedDataset, FederatedDatasetBase from pfl.data.dataset import Dataset +from pfl.data.sampling import get_data_sampler, get_user_sampler -from .mixture_dataset import (ArtificialFederatedDatasetMixture, - 
partition_by_dirichlet_mixture_class_distribution - ) +from .mixture_dataset import ArtificialFederatedDatasetMixture, partition_by_dirichlet_mixture_class_distribution from .sampler import DirichletDataSampler @@ -94,7 +89,7 @@ def make_federated_dataset( images = numpy_to_tensor(images) labels = numpy_to_tensor(labels) - data = dict() + data = {} for user_id in range(len(user_idxs)): data[user_id] = [ images[user_idxs[user_id]], labels[user_idxs[user_id]] diff --git a/publications/mdm/mdm_utils/datasets/femnist_dataset.py b/publications/mdm/mdm_utils/datasets/femnist_dataset.py index c921e4f..21880a8 100644 --- a/publications/mdm/mdm_utils/datasets/femnist_dataset.py +++ b/publications/mdm/mdm_utils/datasets/femnist_dataset.py @@ -1,27 +1,23 @@ -# -*- coding: utf-8 -*- - import os -from typing import Callable, Dict, Tuple, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import h5py import numpy as np import torch from pfl.data import ArtificialFederatedDataset, FederatedDataset -from pfl.data.sampling import get_user_sampler, get_data_sampler from pfl.data.dataset import Dataset +from pfl.data.sampling import get_data_sampler, get_user_sampler -from .mixture_dataset import (ArtificialFederatedDatasetMixture, - partition_by_dirichlet_mixture_class_distribution - ) +from .mixture_dataset import ArtificialFederatedDatasetMixture, partition_by_dirichlet_mixture_class_distribution from .sampler import DirichletDataSampler def _sample_users(user_id_to_data: Dict[str, List[np.ndarray]], filter_method: Optional[str] = None, - sample_fraction: float = None, - start_idx: int = None, - end_idx: int = None, + sample_fraction: Optional[float] = None, + start_idx: Optional[int] = None, + end_idx: Optional[int] = None, include_sampled: bool = True): user_ids = list(user_id_to_data.keys()) @@ -156,7 +152,7 @@ def make_special_federated_dataset( #for k,v in indices_per_class.items(): # print('indices_per_class', k, len(v)) - new_user_id_to_data = dict() + 
new_user_id_to_data = {} start_id_per_class = {i: 0 for i in unique_labels} #print('start_id_per_class', start_id_per_class) for user_id, data in user_id_to_data.items(): @@ -218,10 +214,10 @@ def make_central_dataset( """ Create central dataset from a FEMNIST data file. """ - images = np.concatenate([data[0].cpu() for data in user_id_to_data.values()], - axis=0) - labels = np.concatenate([data[1].cpu() for data in user_id_to_data.values()], - axis=0) + images = np.concatenate( + [data[0].cpu() for data in user_id_to_data.values()], axis=0) + labels = np.concatenate( + [data[1].cpu() for data in user_id_to_data.values()], axis=0) return Dataset(raw_data=[images, labels]) @@ -235,9 +231,9 @@ def make_femnist_datasets( alphas=None, user_dataset_len_samplers=None, filter_method: Optional[str] = None, - sample_fraction: float = None, - start_idx: int = None, - end_idx: int = None, + sample_fraction: Optional[float] = None, + start_idx: Optional[int] = None, + end_idx: Optional[int] = None, include_sampled: bool = True ) -> Tuple[FederatedDataset, FederatedDataset, Dataset]: """ diff --git a/publications/mdm/mdm_utils/datasets/mixture_dataset.py b/publications/mdm/mdm_utils/datasets/mixture_dataset.py index 495d4b1..2de7789 100644 --- a/publications/mdm/mdm_utils/datasets/mixture_dataset.py +++ b/publications/mdm/mdm_utils/datasets/mixture_dataset.py @@ -1,12 +1,12 @@ from collections import defaultdict from typing import Callable, Iterable, List, Tuple -import numpy as np import joblib +import numpy as np from pfl.data import ArtificialFederatedDataset, FederatedDatasetBase from pfl.data.dataset import AbstractDataset -from pfl.internal.ops.selector import (get_default_framework_module as get_ops) +from pfl.internal.ops.selector import get_default_framework_module as get_ops class ArtificialFederatedDatasetMixture(FederatedDatasetBase): @@ -139,7 +139,7 @@ def get_user_counts(training_federated_dataset, num_classes, over a number of central iterations in train.py. 
""" print('get_user_counts') - all_counts = dict() + all_counts = {} for r in range(num_central_iterations): all_counts[r + 1] = [] l = list(training_federated_dataset.get_cohort(cohort_size)) diff --git a/publications/mdm/mdm_utils/datasets/sampler.py b/publications/mdm/mdm_utils/datasets/sampler.py index 070783b..c419d53 100644 --- a/publications/mdm/mdm_utils/datasets/sampler.py +++ b/publications/mdm/mdm_utils/datasets/sampler.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import itertools import numpy as np diff --git a/publications/mdm/mdm_utils/models/argument_parsing.py b/publications/mdm/mdm_utils/models/argument_parsing.py index 3d70d70..3f109a8 100644 --- a/publications/mdm/mdm_utils/models/argument_parsing.py +++ b/publications/mdm/mdm_utils/models/argument_parsing.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import argparse from typing import Optional, Tuple diff --git a/publications/mdm/mdm_utils/models/pytorch/__init__.py b/publications/mdm/mdm_utils/models/pytorch/__init__.py index 664abd6..ee4accf 100644 --- a/publications/mdm/mdm_utils/models/pytorch/__init__.py +++ b/publications/mdm/mdm_utils/models/pytorch/__init__.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -from .cnn import simple_cnn, multi_label_cnn +from .cnn import multi_label_cnn, simple_cnn from .dnn import dnn, simple_dnn from .lstm import lm_lstm from .transformer import lm_transformer diff --git a/publications/mdm/mdm_utils/models/pytorch/cnn.py b/publications/mdm/mdm_utils/models/pytorch/cnn.py index ef85a40..f8d00ba 100644 --- a/publications/mdm/mdm_utils/models/pytorch/cnn.py +++ b/publications/mdm/mdm_utils/models/pytorch/cnn.py @@ -1,17 +1,16 @@ -# -*- coding: utf-8 -*- - import types -from typing import Tuple, List +from typing import List, Tuple import numpy as np import torch # type: ignore import torch.nn as nn import torch.nn.functional as F + from pfl.metrics import Weighted -from .layer import Transpose2D -from .metrics import image_classification_metrics, 
image_classification_loss from ..numpy.metrics import AveragedPrecision, MacroWeighted +from .layer import Transpose2D +from .metrics import image_classification_loss, image_classification_metrics def multi_label_cnn( @@ -41,9 +40,8 @@ def multi_label_cnn( import torchvision.models # type: ignore import torchvision.transforms as transforms # type: ignore - from .module_modification import (validate_no_batchnorm, - freeze_batchnorm_modules, - convert_batchnorm_modules) + + from .module_modification import convert_batchnorm_modules, freeze_batchnorm_modules, validate_no_batchnorm torchvision_models = torchvision.models.__dict__ @@ -218,7 +216,7 @@ def simple_cnn(input_shape: Tuple[int, ...], num_outputs: int) -> nn.Module: # Apply Glorot (Xavier) uniform initialization to match TF2 model. for m in model.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) model.loss = types.MethodType(image_classification_loss, model) diff --git a/publications/mdm/mdm_utils/models/pytorch/dnn.py b/publications/mdm/mdm_utils/models/pytorch/dnn.py index a07e152..b6ac8bf 100644 --- a/publications/mdm/mdm_utils/models/pytorch/dnn.py +++ b/publications/mdm/mdm_utils/models/pytorch/dnn.py @@ -1,13 +1,11 @@ -# -*- coding: utf-8 -*- - -from typing import Tuple import functools import types +from typing import Tuple -import torch.nn as nn import numpy as np +import torch.nn as nn -from .metrics import image_classification_metrics, image_classification_loss +from .metrics import image_classification_loss, image_classification_metrics def dnn(input_shape: Tuple[int, ...], hidden_dims: Tuple[int, ...], diff --git a/publications/mdm/mdm_utils/models/pytorch/layer.py b/publications/mdm/mdm_utils/models/pytorch/layer.py index b21d35d..9c62736 100644 --- a/publications/mdm/mdm_utils/models/pytorch/layer.py +++ b/publications/mdm/mdm_utils/models/pytorch/layer.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 
-*- - from abc import ABC import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.batchnorm import _NormBase + from ..numpy.layer import positional_encoding @@ -22,10 +21,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # and use pretrained statistics in training as well self.training = False - if self.momentum is None: - exponential_average_factor = 0.0 - else: - exponential_average_factor = self.momentum + exponential_average_factor = 0.0 if self.momentum is None else self.momentum bn_training = (self.running_mean is None) and (self.running_var is None) @@ -48,24 +44,22 @@ class FrozenBatchNorm1D(_FrozenBatchNorm): def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)'.format( - input.dim())) + raise ValueError( + f'expected 2D or 3D input (got {input.dim()}D input)') class FrozenBatchNorm2D(_FrozenBatchNorm): def _check_input_dim(self, input): if input.dim() != 4: - raise ValueError('expected 4D input (got {}D input)'.format( - input.dim())) + raise ValueError(f'expected 4D input (got {input.dim()}D input)') class FrozenBatchNorm3D(_FrozenBatchNorm): def _check_input_dim(self, input): if input.dim() != 5: - raise ValueError('expected 5D input (got {}D input)'.format( - input.dim())) + raise ValueError(f'expected 5D input (got {input.dim()}D input)') class Transpose2D(nn.Module): diff --git a/publications/mdm/mdm_utils/models/pytorch/metrics.py b/publications/mdm/mdm_utils/models/pytorch/metrics.py index 2f9f096..92c1523 100644 --- a/publications/mdm/mdm_utils/models/pytorch/metrics.py +++ b/publications/mdm/mdm_utils/models/pytorch/metrics.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from typing import Dict import torch diff --git a/publications/mdm/mdm_utils/models/pytorch/module_modification.py b/publications/mdm/mdm_utils/models/pytorch/module_modification.py index 39d9047..345fe64 100644 --- 
a/publications/mdm/mdm_utils/models/pytorch/module_modification.py +++ b/publications/mdm/mdm_utils/models/pytorch/module_modification.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from typing import Callable, Type from torch import nn diff --git a/publications/mdm/mdm_utils/models/pytorch_model.py b/publications/mdm/mdm_utils/models/pytorch_model.py index cadc368..5ab603e 100644 --- a/publications/mdm/mdm_utils/models/pytorch_model.py +++ b/publications/mdm/mdm_utils/models/pytorch_model.py @@ -1,5 +1,5 @@ import types -from typing import Tuple, Dict +from typing import Dict, Tuple import torch from torch import nn @@ -42,7 +42,7 @@ def simple_cnn(input_shape: Tuple[int, ...], num_outputs: int) -> nn.Module: # Apply Glorot (Xavier) uniform initialization to match TF2 model. for m in model.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) model.loss = types.MethodType(image_classification_loss, model) diff --git a/publications/mdm/mdm_utils/utils/__init__.py b/publications/mdm/mdm_utils/utils/__init__.py index b3fccc1..37bf735 100644 --- a/publications/mdm/mdm_utils/utils/__init__.py +++ b/publications/mdm/mdm_utils/utils/__init__.py @@ -1,5 +1,10 @@ -from .argument_parsing import (add_dataset_args, add_experiment_args, - add_init_algorithm_args, add_algorithm_args, - add_mle_args, add_histogram_algorithm_args, - add_user_visualisation_args, - add_dataset_preprocessing_args) +from .argument_parsing import ( + add_algorithm_args, + add_dataset_args, + add_dataset_preprocessing_args, + add_experiment_args, + add_histogram_algorithm_args, + add_init_algorithm_args, + add_mle_args, + add_user_visualisation_args, +) diff --git a/publications/mdm/mdm_utils/utils/argument_parsing.py b/publications/mdm/mdm_utils/utils/argument_parsing.py index 2650e66..e936c8b 100644 --- a/publications/mdm/mdm_utils/utils/argument_parsing.py +++ 
b/publications/mdm/mdm_utils/utils/argument_parsing.py @@ -7,12 +7,12 @@ def __init__(self, option_strings, dest, **kwargs): argparse.Action.__init__(self, option_strings, dest, **kwargs) def __call__(self, parser, namespace, values, option_string=None): - false_values = set(['false', 'no']) - true_values = set(['true', 'yes']) + false_values = {'false', 'no'} + true_values = {'true', 'yes'} values = values.lower() - if not values in (false_values | true_values): + if values not in (false_values | true_values): raise argparse.ArgumentError( self, 'Value must be either "true" or "false"') value = (values in true_values) @@ -24,7 +24,9 @@ def add_experiment_args(parser): parser.add_argument('--seed', type=int, default=0) parser.add_argument('--data_dir', type=str) parser.add_argument('--dirname', type=str) - parser.add_argument('--mle_param_dirname', type=str, default='publications/mdm/mle_params') + parser.add_argument('--mle_param_dirname', + type=str, + default='publications/mdm/mle_params') parser.add_argument( '--precomputed_parameter_filepath', type=str, @@ -61,17 +63,19 @@ def add_dataset_preprocessing_args(parser): def float_list(arg): try: float_values = [float(val) for val in arg.split()] - return float_values except ValueError: raise argparse.ArgumentTypeError("Invalid float values in the list") + else: + return float_values def int_list(arg): try: int_values = [int(val) for val in arg.split()] - return int_values except ValueError: raise argparse.ArgumentTypeError("Invalid int values in the list") + else: + return int_values def add_dataset_args(parser): diff --git a/publications/mdm/mdm_utils/utils/tools.py b/publications/mdm/mdm_utils/utils/tools.py index 6280ee1..44deab7 100644 --- a/publications/mdm/mdm_utils/utils/tools.py +++ b/publications/mdm/mdm_utils/utils/tools.py @@ -3,7 +3,6 @@ from pfl.callback import TrainingProcessCallback from pfl.internal.ops.selector import get_default_framework_module as get_ops - from pfl.metrics import Metrics from 
pfl.model.base import StatefulModel diff --git a/publications/mdm/mdm_utils/utils/visualize_results.py b/publications/mdm/mdm_utils/utils/visualize_results.py index 281f39d..67ddc80 100644 --- a/publications/mdm/mdm_utils/utils/visualize_results.py +++ b/publications/mdm/mdm_utils/utils/visualize_results.py @@ -13,7 +13,7 @@ def plot_cifar10_results(): df = pd.read_csv(filename) experiments = np.unique(df['experiment'].values).tolist() - dfs = dict() + dfs = {} for experiment in experiments: dfs[experiment] = (df.loc[df['experiment'] == experiment]) @@ -21,13 +21,13 @@ def plot_cifar10_results(): 'cohort_size', 'local_batch_size', 'local_learning_rate', 'local_num_epochs' ] - unique_vals = dict() + unique_vals = {} for column_name in column_names: unique_vals[column_name] = np.unique(dfs['live'][column_name]).tolist() - accs = dict() + accs = {} for name, df in dfs.items(): - accs[name] = dict() + accs[name] = {} for tup in product(*unique_vals.values()): filter_dic = dict(zip(column_names, tup)) a = df.loc[(df[list(filter_dic)] == pd.Series(filter_dic)).all( @@ -38,7 +38,7 @@ def plot_cifar10_results(): permutation = np.argsort(-x) mask = np.array(list(accs['live'].values()))[permutation] >= 0.6 - dic = dict() + dic = {} c = dict(zip(accs.keys(), ['blue', 'red', 'green'])) plt.rcParams.update({'font.size': 13})