Skip to content

Commit

Permalink
Docs and types (#304)
Browse files Browse the repository at this point in the history
* Clean docstrings and add types
  • Loading branch information
benjello authored Oct 21, 2024
1 parent ddf3bad commit 2c1a336
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 36 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

### 2.2.3 [#304](https://github.com/openfisca/openfisca-survey-manager/pull/304)

* Minor change.
* Add docstrings and types to some tests.

### 2.2.2 [#297](https://github.com/openfisca/openfisca-survey-manager/pull/297)

* Minor change.
Expand Down
9 changes: 4 additions & 5 deletions openfisca_survey_manager/surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ def find_tables(self, variable, tables = None, rename_ident = True):
return container_tables

def get_columns(self, table, rename_ident = True):
"""
Get columns of a table.
"""
"""Get columns of a table."""
assert table is not None
if self.hdf5_file_path is not None:
store = pandas.HDFStore(self.hdf5_file_path, "r")
Expand Down Expand Up @@ -352,8 +350,8 @@ def get_values(self, variables = None, table = None, lowercase = False, ignoreca
return df

def insert_table(self, label = None, name = None, **kwargs):
"""
Inserts a table in the Survey object
"""Insert a table in the Survey object.
If a pandas dataframe is provided, it is saved in the store file.
"""
parquet_file = kwargs.pop('parquet_file', None)
Expand Down Expand Up @@ -387,6 +385,7 @@ def insert_table(self, label = None, name = None, **kwargs):
self.tables[name][key] = val

def to_json(self):
"""Convert the survey to a JSON object."""
self_json = collections.OrderedDict((
))
self_json['hdf5_file_path'] = self.hdf5_file_path
Expand Down
3 changes: 3 additions & 0 deletions openfisca_survey_manager/tests/test_matching.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Tests for the matching functionality in OpenFisca Survey Manager."""

import pandas as pd

Expand All @@ -13,6 +14,7 @@


def test_reproduction():
"""Test the reproduction of examples from the StatMatch documentation."""
if rpy2 is None:
return

Expand Down Expand Up @@ -88,6 +90,7 @@ def test_reproduction():


def test_nnd_hotdeck_using_rpy2():
"""Test the nnd_hotdeck_using_rpy2 function with iris dataset."""
if rpy2 is None:
print('rpy2 is absent: skipping test') # noqa analysis:ignore
return
Expand Down
16 changes: 9 additions & 7 deletions openfisca_survey_manager/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

@pytest.mark.order(after="test_write_parquet.py::TestWriteParquet::test_write_parquet_one_file_per_entity")
class TestParquet(TestCase):
"""Tests for Parquet file operations."""

data_dir = os.path.join(
openfisca_survey_manager_location,
"openfisca_survey_manager",
Expand All @@ -32,6 +34,7 @@ class TestParquet(TestCase):
survey_name = "test_parquet_survey"

def test_add_survey_to_collection_parquet(self):
"""Test adding a parquet survey to a collection."""
survey_collection = SurveyCollection(name=self.collection_name)
survey_file_path = os.path.join(self.data_dir, self.collection_name)
add_survey_to_collection(
Expand All @@ -43,6 +46,7 @@ def test_add_survey_to_collection_parquet(self):
assert self.survey_name in list(ordered_dict["surveys"].keys())

def test_build_collection(self):
"""Test building a survey collection from parquet files."""
collection_name = self.collection_name
json_file = os.path.join(
self.data_dir,
Expand Down Expand Up @@ -73,9 +77,7 @@ def test_build_collection(self):

@pytest.mark.order(after="test_build_collection")
def test_load_single_parquet_monolithic(self):
"""
Test the loading all the data from parquet files in memory.
"""
"""Test loading all the data from parquet files in memory."""
# Create survey scenario
survey_scenario = AbstractSurveyScenario()
survey_name = self.collection_name + "_2020"
Expand Down Expand Up @@ -126,9 +128,7 @@ def test_load_single_parquet_monolithic(self):
assert (results['income_tax'] == [195.00001525878906, 3.0, 510.0000305175781, 600.0, 750.0])

def test_load_multiple_parquet_monolithic(self):
"""
Test the loading all the data from parquet files in memory.
"""
"""Test loading all data from parquet files in memory."""
collection_name = 'test_multiple_parquet_collection'
data_dir = os.path.join(self.data_dir, collection_name)
# Create survey scenario
Expand Down Expand Up @@ -182,7 +182,9 @@ def test_load_multiple_parquet_monolithic(self):

def test_load_parquet_batch(self):
"""
Test the batch loading of data from parquet files. This allow loading larger than memory datasets.
Test the batch loading of data from parquet files.
This allows loading larger-than-memory datasets.
"""
df = pd.read_parquet(
os.path.join(self.data_dir, self.collection_name, "household.parquet")
Expand Down
174 changes: 151 additions & 23 deletions openfisca_survey_manager/tests/test_scenario.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Tests for the survey scenario functionality in OpenFisca Survey Manager."""

import shutil
import logging
import os
import pytest
from typing import Dict, Any, List, Optional, Callable


from openfisca_core import periods
Expand All @@ -22,9 +25,30 @@
log = logging.getLogger(__name__)


def create_randomly_initialized_survey_scenario(nb_persons = 10, nb_groups = 5, salary_max_value = 50000,
rent_max_value = 1000, collection = "test_random_generator", use_marginal_tax_rate = False, reform = None):

def create_randomly_initialized_survey_scenario(
nb_persons: int = 10,
nb_groups: int = 5,
salary_max_value: float = 50000,
rent_max_value: float = 1000,
collection: Optional[str] = "test_random_generator",
use_marginal_tax_rate: bool = False,
reform: Optional[Callable] = None
) -> AbstractSurveyScenario:
"""
Create a randomly initialized survey scenario.
Args:
nb_persons (int): Number of persons
nb_groups (int): Number of groups
salary_max_value (float): Maximum salary value
rent_max_value (float): Maximum rent value
collection (Optional[str]): Collection name
use_marginal_tax_rate (bool): Use marginal tax rate
reform (Optional[Callable]): Reform function
Returns:
AbstractSurveyScenario: Initialized survey scenario
"""
if collection is not None:
return create_randomly_initialized_survey_scenario_from_table(
nb_persons, nb_groups, salary_max_value, rent_max_value, collection, use_marginal_tax_rate, reform = reform)
Expand All @@ -33,7 +57,30 @@ def create_randomly_initialized_survey_scenario(nb_persons = 10, nb_groups = 5,
nb_persons, nb_groups, salary_max_value, rent_max_value, use_marginal_tax_rate, reform = reform)


def create_randomly_initialized_survey_scenario_from_table(nb_persons, nb_groups, salary_max_value, rent_max_value, collection, use_marginal_tax_rate, reform = None):
def create_randomly_initialized_survey_scenario_from_table(
nb_persons: int,
nb_groups: int,
salary_max_value: float,
rent_max_value: float,
collection: str,
use_marginal_tax_rate: bool,
reform: Optional[Callable] = None
) -> AbstractSurveyScenario:
"""
Create a randomly initialized survey scenario from a table.
Args:
nb_persons (int): Number of persons
nb_groups (int): Number of groups
salary_max_value (float): Maximum salary value
rent_max_value (float): Maximum rent value
collection (str): Collection name
use_marginal_tax_rate (bool): Use marginal tax rate
reform (Optional[Callable]): Reform function
Returns:
AbstractSurveyScenario: Initialized survey scenario
"""
variable_generators_by_period = {
periods.period('2017-01'): [
{
Expand Down Expand Up @@ -87,7 +134,28 @@ def create_randomly_initialized_survey_scenario_from_table(nb_persons, nb_groups
return survey_scenario


def create_randomly_initialized_survey_scenario_from_data_frame(nb_persons, nb_groups, salary_max_value, rent_max_value, use_marginal_tax_rate = False, reform = None):
def create_randomly_initialized_survey_scenario_from_data_frame(
nb_persons: int,
nb_groups: int,
salary_max_value: float,
rent_max_value: float,
use_marginal_tax_rate: bool = False,
reform: Optional[Callable] = None
) -> AbstractSurveyScenario:
"""
Create a randomly initialized survey scenario from a data frame.
Args:
nb_persons (int): Number of persons
nb_groups (int): Number of groups
salary_max_value (float): Maximum salary value
rent_max_value (float): Maximum rent value
use_marginal_tax_rate (bool): Use marginal tax rate
reform (Optional[Callable]): Reform function
Returns:
AbstractSurveyScenario: Initialized survey scenario
"""
input_data_frame_by_entity = generate_input_input_dataframe_by_entity(
nb_persons, nb_groups, salary_max_value, rent_max_value)
survey_scenario = AbstractSurveyScenario()
Expand Down Expand Up @@ -121,7 +189,24 @@ def create_randomly_initialized_survey_scenario_from_data_frame(nb_persons, nb_g
return survey_scenario


def generate_input_input_dataframe_by_entity(nb_persons, nb_groups, salary_max_value, rent_max_value):
def generate_input_input_dataframe_by_entity(
nb_persons: int,
nb_groups: int,
salary_max_value: float,
rent_max_value: float
) -> Dict[str, Any]:
"""
Generate input dataframe by entity with randomly initialized variables.
Args:
nb_persons (int): Number of persons
nb_groups (int): Number of groups
salary_max_value (float): Maximum salary value
rent_max_value (float): Maximum rent value
Returns:
Dict[str, Any]: Input dataframe by entity
"""
input_dataframe_by_entity = make_input_dataframe_by_entity(tax_benefit_system, nb_persons, nb_groups)
randomly_init_variable(
tax_benefit_system,
Expand All @@ -145,8 +230,13 @@ def generate_input_input_dataframe_by_entity(nb_persons, nb_groups, salary_max_v
return input_dataframe_by_entity


def test_input_dataframe_generator(nb_persons = 10, nb_groups = 5, salary_max_value = 50000,
rent_max_value = 1000):
def test_input_dataframe_generator(
nb_persons: int = 10,
nb_groups: int = 5,
salary_max_value: float = 50000,
rent_max_value: float = 1000
) -> None:
"""Test the input dataframe generator function."""
input_dataframe_by_entity = generate_input_input_dataframe_by_entity(
nb_persons, nb_groups, salary_max_value, rent_max_value)
assert (input_dataframe_by_entity['person']['household_role'] == "first_parent").sum() == 5
Expand All @@ -171,9 +261,21 @@ def test_input_dataframe_generator(nb_persons = 10, nb_groups = 5, salary_max_va
# On vérifie que l'attribut `used_as_input_variables` correspond à la liste des variables
# qui sont employées dans le calcul des simulations, les autres variables n'étant pas utilisées dans le calcul,
# étant dans la base en entrée mais pas dans la base en sortie (la base de la simulation)
def test_init_from_data(nb_persons = 10, nb_groups = 5, salary_max_value = 50000,
rent_max_value = 1000):

def test_init_from_data(
nb_persons: int = 10,
nb_groups: int = 5,
salary_max_value: float = 50000,
rent_max_value: float = 1000,
) -> None:
"""
Test the initialization of data in the survey scenario.
Args:
nb_persons: Number of persons to generate in the test data.
nb_groups: Number of household groups to generate.
salary_max_value: Maximum value for randomly generated salaries.
rent_max_value: Maximum value for randomly generated rents.
"""
# Set up test : the minimum necessary data to perform an `init_from_data`
survey_scenario = AbstractSurveyScenario()
assert survey_scenario.simulations is None
Expand Down Expand Up @@ -225,9 +327,21 @@ def test_init_from_data(nb_persons = 10, nb_groups = 5, salary_max_value = 50000
assert data_out['household']['rent'].equals(table_men['rent'])


def test_survey_scenario_input_dataframe_import(nb_persons = 10, nb_groups = 5, salary_max_value = 50000,
rent_max_value = 1000):

def test_survey_scenario_input_dataframe_import(
nb_persons: int = 10,
nb_groups: int = 5,
salary_max_value: float = 50000,
rent_max_value: float = 1000,
) -> None:
"""
Test the import of input dataframes into a survey scenario.
Args:
nb_persons: Number of persons to generate.
nb_groups: Number of household groups.
salary_max_value: Maximum salary value.
rent_max_value: Maximum rent value.
"""
input_data_frame_by_entity = generate_input_input_dataframe_by_entity(
nb_persons, nb_groups, salary_max_value, rent_max_value)
survey_scenario = AbstractSurveyScenario()
Expand All @@ -252,11 +366,21 @@ def test_survey_scenario_input_dataframe_import(nb_persons = 10, nb_groups = 5,
).all()


def test_survey_scenario_input_dataframe_import_scrambled_ids(nb_persons = 10, nb_groups = 5, salary_max_value = 50000,
rent_max_value = 1000):
'''
On teste que .init_from_data fait
'''
def test_survey_scenario_input_dataframe_import_scrambled_ids(
nb_persons: int = 10,
nb_groups: int = 5,
salary_max_value: float = 50000,
rent_max_value: float = 1000
) -> None:
"""
Test survey scenario input dataframe import with scrambled IDs.
Args:
nb_persons: Number of persons to generate.
nb_groups: Number of household groups.
salary_max_value: Maximum salary value.
rent_max_value: Maximum rent value.
"""
input_data_frame_by_entity = generate_input_input_dataframe_by_entity(
nb_persons, nb_groups, salary_max_value, rent_max_value) # Un dataframe d'exemple que l'on injecte
input_data_frame_by_entity['person']['household_id'] = 4 - input_data_frame_by_entity['person']['household_id']
Expand All @@ -282,7 +406,8 @@ def test_survey_scenario_input_dataframe_import_scrambled_ids(nb_persons = 10, n
).all()


def test_dump_survey_scenario():
def test_dump_survey_scenario() -> None:
"""Test the dump and restore functionality of survey scenarios."""
survey_scenario = create_randomly_initialized_survey_scenario()
directory = os.path.join(
openfisca_survey_manager_location,
Expand Down Expand Up @@ -314,7 +439,8 @@ def test_dump_survey_scenario():


@pytest.mark.order(before="test_add_survey_to_collection.py::test_add_survey_to_collection")
def test_inflate():
def test_inflate() -> None:
"""Test the inflate method of the survey scenario."""
survey_scenario = create_randomly_initialized_survey_scenario(collection = None)
period = "2017-01"
inflator = 2.42
Expand Down Expand Up @@ -354,7 +480,8 @@ def test_inflate():


@pytest.mark.order(before="test_add_survey_to_collection.py::test_add_survey_to_collection")
def test_compute_pivot_table():
def test_compute_pivot_table() -> None:
"""Test the compute_pivot_table method of the survey scenario."""
survey_scenario = create_randomly_initialized_survey_scenario(collection = None)
period = "2017-01"
pivot_table = survey_scenario.compute_pivot_table(columns = ['age'], values = ["salary"], period = period, simulation = "baseline")
Expand All @@ -369,7 +496,8 @@ def test_compute_pivot_table():
assert pivot_table.values.round() == 13570.


def test_compute_quantile():
def test_compute_quantile() -> List[float]:
"""Test the compute_quantiles method of the survey scenario."""
survey_scenario = create_randomly_initialized_survey_scenario()
period = "2017-01"
quintiles = survey_scenario.compute_quantiles(variable = "salary", nquantiles = 5, period = period, weighted = False, simulation = "baseline")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

setup(
name = 'OpenFisca-Survey-Manager',
version = '2.2.3',
version = '2.2.4',
author = 'OpenFisca Team',
author_email = '[email protected]',
classifiers = [classifier for classifier in classifiers.split('\n') if classifier],
Expand Down

0 comments on commit 2c1a336

Please sign in to comment.