diff --git a/iatidata/__init__.py b/iatidata/__init__.py index 4d1fdf3..b2d6677 100644 --- a/iatidata/__init__.py +++ b/iatidata/__init__.py @@ -15,6 +15,7 @@ from collections import defaultdict from io import StringIO from textwrap import dedent +from typing import Any, Iterator import iatikit import requests @@ -212,7 +213,7 @@ def save_converted_xml_to_csv(dataset_etree, csv_file, prefix=None, filename=Non ) ) - schama_dict = get_sorted_schema_dict() + schema_dict = get_sorted_schema_dict() for activity in dataset_etree.findall("iati-activity"): version = dataset_etree.get("version", "1.01") @@ -223,7 +224,7 @@ def save_converted_xml_to_csv(dataset_etree, csv_file, prefix=None, filename=Non if version.startswith("1"): activities = transform(activities).getroot() - sort_iati_element(activities.getchildren()[0], schama_dict) + sort_iati_element(activities.getchildren()[0], schema_dict) activity, error = xmlschema.to_dict( activities, schema=schema, validation="lax", decimal_type=float @@ -299,9 +300,10 @@ def save_all(parts=5, sample=None, refresh=False): if sample and num > sample: break + print("Loading registry data into database") with concurrent.futures.ProcessPoolExecutor() as executor: for job in executor.map(save_part, buckets.items()): - print("DONE {job}") + print(f"DONE {job}") continue @@ -368,7 +370,9 @@ def flatten_object(obj, current_path="", no_index_path=tuple()): @functools.lru_cache(1000) -def path_info(full_path, no_index_path): +def path_info( + full_path: tuple[str | int, ...], no_index_path: tuple[str, ...] +) -> tuple[str, list[str], list[str], str, tuple[dict[str, str], ...]]: all_paths = [] for num, part in enumerate(full_path): if isinstance(part, int): @@ -390,7 +394,9 @@ def path_info(full_path, no_index_path): return object_key, parent_keys_list, parent_keys_no_index, object_type, parent_keys -def traverse_object(obj, emit_object, full_path=tuple(), no_index_path=tuple()): +def traverse_object( + obj: dict[str, Any], emit_object: bool, full_path=tuple(), no_index_path=tuple() +) -> Iterator[Any]: for original_key, value in list(obj.items()): key = original_key.replace("-", "") @@ -446,7 +452,7 @@ def create_rows(result): # get activity dates before traversal remove them activity_dates = result.activity.get("activity-date", []) or [] - for object, full_path, no_index_path in traverse_object(result.activity, 1): + for object, full_path, no_index_path in traverse_object(result.activity, True): ( object_key, parent_keys_list, @@ -533,7 +539,7 @@ def activity_objects(): def schema_analysis(): - print("doing schema analysis") + print("Creating tables '_fields' and '_tables'") create_table( "_object_type_aggregate", f"""SELECT diff --git a/iatidata/tests/test_iatidata.py b/iatidata/tests/test_iatidata.py index 2374016..b4146a2 100644 --- a/iatidata/tests/test_iatidata.py +++ b/iatidata/tests/test_iatidata.py @@ -2,7 +2,7 @@ from lxml import etree -from iatidata import sort_iati_element +from iatidata import path_info, sort_iati_element, traverse_object def test_sort_iati_element(): @@ -43,3 +43,161 @@ def test_sort_iati_element(): b"" ) assert etree.tostring(element) == expected_xml + + +def test_traverse_object_strings(): + activity_object = { + "iati-identifier": "AA-BBB-000000000-CCCC", + } + result = list(traverse_object(activity_object, True)) + expected_result = [ + ( + { + "iatiidentifier": "AA-BBB-000000000-CCCC", + }, + (), + (), + ), + ] + assert result == expected_result + + +def test_traverse_object_dicts(): + activity_object = { + "activity-status": {"@code": "2"}, + } + result = list(traverse_object(activity_object, True)) + expected_result = [ + ( + { + "activitystatus": {"@code": "2"}, + }, + (), + (), + ), + ] + assert result == expected_result + + +def test_traverse_object_narratives_plain(): + activity_object = { + "title": {"narrative": ["Title narrative"]}, + } + result = list(traverse_object(activity_object, True)) + expected_result = [ + ( + { + "title": {"narrative": "Title narrative"}, + }, + (), + (), + ), + ] + assert result == expected_result + + +def test_traverse_object_narratives_with_language(): + activity_object = { + "title": { + "narrative": [ + { + "$": "Title narrative", + "@{http://www.w3.org/XML/1998/namespace}lang": "en", + } + ] + }, + } + result = list(traverse_object(activity_object, True)) + expected_result = [ + ( + { + "title": {"narrative": "EN: Title narrative"}, + }, + (), + (), + ), + ] + assert result == expected_result + + +def test_traverse_object_lists_of_dicts(): + activity_object = { + "transaction": [ + { + "value": { + "$": 2000000.0, + "@currency": "GBP", + "@value-date": "2024-01-30", + }, + "description": {"narrative": ["Transaction 0 description narrative"]}, + }, + { + "value": { + "$": 600000.0, + "@currency": "USD", + "@value-date": "2024-01-31", + }, + "description": {"narrative": ["Transaction 1 description narrative"]}, + }, + ] + } + result = list(traverse_object(activity_object, True)) + expected_result = [ + ( + { + "value": { + "$": 2000000.0, + "@currency": "GBP", + "@value-date": "2024-01-30", + }, + "description": {"narrative": "Transaction 0 description narrative"}, + }, + ("transaction", 0), + ("transaction",), + ), + ( + { + "value": { + "$": 600000.0, + "@currency": "USD", + "@value-date": "2024-01-31", + }, + "description": {"narrative": "Transaction 1 description narrative"}, + }, + ("transaction", 1), + ("transaction",), + ), + ] + assert result == expected_result + + +def test_path_info(): + full_path = ("result", 12, "indicator", 3, "period", 0, "actual", 0) + no_index_path = ("result", "indicator", "period", "actual") + ( + object_key, + parent_keys_list, + parent_keys_no_index, + object_type, + parent_keys, + ) = path_info(full_path, no_index_path) + + assert object_key == "result.12.indicator.3.period.0.actual.0" + assert parent_keys_list == [ + "result.12", + "result.12.indicator.3", + "result.12.indicator.3.period.0", + ] + assert parent_keys_no_index == [ + "result", + "result_indicator", + "result_indicator_period", + ] + assert object_type == "result_indicator_period_actual" + assert parent_keys == ( + { + "result": "result.12", + "result_indicator": "result.12.indicator.3", + "result_indicator_period": "result.12.indicator.3.period.0", + }, + )