Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more tests #24

Merged
merged 4 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions iatidata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import defaultdict
from io import StringIO
from textwrap import dedent
from typing import Any, Iterator

import iatikit
import requests
Expand Down Expand Up @@ -212,7 +213,7 @@ def save_converted_xml_to_csv(dataset_etree, csv_file, prefix=None, filename=Non
)
)

schama_dict = get_sorted_schema_dict()
schema_dict = get_sorted_schema_dict()

for activity in dataset_etree.findall("iati-activity"):
version = dataset_etree.get("version", "1.01")
Expand All @@ -223,7 +224,7 @@ def save_converted_xml_to_csv(dataset_etree, csv_file, prefix=None, filename=Non
if version.startswith("1"):
activities = transform(activities).getroot()

sort_iati_element(activities.getchildren()[0], schama_dict)
sort_iati_element(activities.getchildren()[0], schema_dict)

activity, error = xmlschema.to_dict(
activities, schema=schema, validation="lax", decimal_type=float
Expand Down Expand Up @@ -299,9 +300,10 @@ def save_all(parts=5, sample=None, refresh=False):
if sample and num > sample:
break

print("Loading registry data into database")
with concurrent.futures.ProcessPoolExecutor() as executor:
for job in executor.map(save_part, buckets.items()):
print("DONE {job}")
print(f"DONE {job}")
continue


Expand Down Expand Up @@ -368,7 +370,9 @@ def flatten_object(obj, current_path="", no_index_path=tuple()):


@functools.lru_cache(1000)
def path_info(full_path, no_index_path):
def path_info(
full_path: tuple[str | int, ...], no_index_path: tuple[str, ...]
) -> tuple[str, list[str], list[str], str, tuple[dict[str, str], ...]]:
all_paths = []
for num, part in enumerate(full_path):
if isinstance(part, int):
Expand All @@ -390,7 +394,9 @@ def path_info(full_path, no_index_path):
return object_key, parent_keys_list, parent_keys_no_index, object_type, parent_keys


def traverse_object(obj, emit_object, full_path=tuple(), no_index_path=tuple()):
def traverse_object(
obj: dict[str, Any], emit_object: bool, full_path=tuple(), no_index_path=tuple()
) -> Iterator[Any]:
for original_key, value in list(obj.items()):
key = original_key.replace("-", "")

Expand Down Expand Up @@ -446,7 +452,7 @@ def create_rows(result):
# get activity dates before traversal remove them
activity_dates = result.activity.get("activity-date", []) or []

for object, full_path, no_index_path in traverse_object(result.activity, 1):
for object, full_path, no_index_path in traverse_object(result.activity, True):
(
object_key,
parent_keys_list,
Expand Down Expand Up @@ -533,7 +539,7 @@ def activity_objects():


def schema_analysis():
print("doing schema analysis")
print("Creating tables '_fields' and '_tables'")
create_table(
"_object_type_aggregate",
f"""SELECT
Expand Down
160 changes: 159 additions & 1 deletion iatidata/tests/test_iatidata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from lxml import etree

from iatidata import sort_iati_element
from iatidata import path_info, sort_iati_element, traverse_object


def test_sort_iati_element():
Expand Down Expand Up @@ -43,3 +43,161 @@ def test_sort_iati_element():
b"</iati-activity>"
)
assert etree.tostring(element) == expected_xml


def test_traverse_object_strings():
activity_object = {
"iati-identifier": "AA-BBB-000000000-CCCC",
}
result = list(traverse_object(activity_object, True))
expected_result = [
(
{
"iatiidentifier": "AA-BBB-000000000-CCCC",
},
(),
(),
),
]
assert result == expected_result


def test_traverse_object_dicts():
activity_object = {
"activity-status": {"@code": "2"},
}
result = list(traverse_object(activity_object, True))
expected_result = [
(
{
"activitystatus": {"@code": "2"},
},
(),
(),
),
]
assert result == expected_result


def test_traverse_object_narratives_plain():
activity_object = {
"title": {"narrative": ["Title narrative"]},
}
result = list(traverse_object(activity_object, True))
expected_result = [
(
{
"title": {"narrative": "Title narrative"},
},
(),
(),
),
]
assert result == expected_result


def test_traverse_object_narratives_with_language():
activity_object = {
"title": {
"narrative": [
{
"$": "Title narrative",
"@{http://www.w3.org/XML/1998/namespace}lang": "en",
}
]
},
}
result = list(traverse_object(activity_object, True))
expected_result = [
(
{
"title": {"narrative": "EN: Title narrative"},
},
(),
(),
),
]
assert result == expected_result


def test_traverse_object_lists_of_dicts():
activity_object = {
"transaction": [
{
"value": {
"$": 2000000.0,
"@currency": "GBP",
"@value-date": "2024-01-30",
},
"description": {"narrative": ["Transaction 0 description narrative"]},
},
{
"value": {
"$": 600000.0,
"@currency": "USD",
"@value-date": "2024-01-31",
},
"description": {"narrative": ["Transaction 1 description narrative"]},
},
]
}
result = list(traverse_object(activity_object, True))
expected_result = [
(
{
"value": {
"$": 2000000.0,
"@currency": "GBP",
"@value-date": "2024-01-30",
},
"description": {"narrative": "Transaction 0 description narrative"},
},
("transaction", 0),
("transaction",),
),
(
{
"value": {
"$": 600000.0,
"@currency": "USD",
"@value-date": "2024-01-31",
},
"description": {"narrative": "Transaction 1 description narrative"},
},
("transaction", 1),
("transaction",),
),
]
assert result == expected_result


def test_path_info():
full_path = ("result", 12, "indicator", 3, "period", 0, "actual", 0)
no_index_path = ("result", "indicator", "period", "actual")
(
object_key,
parent_keys_list,
parent_keys_no_index,
object_type,
parent_keys,
) = path_info(full_path, no_index_path)

assert object_key == "result.12.indicator.3.period.0.actual.0"
assert parent_keys_list == [
"result.12",
"result.12.indicator.3",
"result.12.indicator.3.period.0",
]
assert parent_keys_no_index == [
"result",
"result_indicator",
"result_indicator_period",
]
assert object_type == "result_indicator_period_actual"
assert parent_keys == (
{
"result": "result.12",
"result_indicator": "result.12.indicator.3",
"result_indicator_period": "result.12.indicator.3.period.0",
},
)