Skip to content

Commit

Permalink
Merge pull request #73 from dataforgoodfr/rs/last-one
Browse files Browse the repository at this point in the history
Many fixes when trying to use LlamaParse on multi-page tables
  • Loading branch information
RonanMorgan authored Apr 22, 2024
2 parents 252053c + 76ba652 commit e3daf98
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 76 deletions.
2 changes: 1 addition & 1 deletion app/pages/0_Import_File.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
modelfile: random_forest_model_low_false_positive.joblib
table_extraction:
- type: LLamaParse
- type: LlamaParse
- type: Unstructured
params:
hi_res_model_name: "yolox"
Expand Down
36 changes: 18 additions & 18 deletions app/pages/2_Merge_Tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@


def merge_table(table_extractor: str) -> None:
for asset in st.session_state["assets"]["table_extractors"]:
if asset["type"] == table_extractor:
first_df_columns = asset["tables"][0].columns
first_df_columns = pd.Series([])
table_list = []
for key, table in st.session_state["tables"].items():
if table_extractor in key:
if first_df_columns.empty:
first_df_columns = table.columns
# Replace column names for all DataFrames in the list
for df in asset["tables"]:
df.columns = first_df_columns
table.columns = first_df_columns
table_list.append(table)

st.session_state["new_tables"] = pd.concat(
asset["tables"], ignore_index=True, sort=False
)
st.session_state["new_tables"] = pd.concat(
table_list, ignore_index=True, sort=False
)


def save_merge(table_extractor: str) -> None:
Expand Down Expand Up @@ -83,16 +86,13 @@ def save_merge(table_extractor: str) -> None:
)

if table_extractor is not None:
for asset in st.session_state["assets"]["table_extractors"]:
i = 0
if asset["type"] == table_extractor:
for table in asset["tables"]:
st.markdown("Table shape :" + str(table.shape))
st.markdown("Table index : _" + str(i))
i += 1
st.dataframe(
table,
)
for key, table in st.session_state["tables"].items():
if table_extractor in key:
st.markdown("Table shape :" + str(table.shape))
st.markdown("Table name : " + key)
st.dataframe(
table,
)

with col2:
st.markdown(
Expand Down
4 changes: 1 addition & 3 deletions app/pages/3_Clean_Headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ def set_headers(algorithm_name: str) -> None:

st.markdown("# Current extraction")
st.markdown("The extracted table is displaye below")
df = st.data_editor(
st.dataframe(
st.session_state.tables[st.session_state["algorithm_name"]],
num_rows="dynamic",
width=900,
height=900,
disabled=True,
)
58 changes: 30 additions & 28 deletions app/pages/4_Clean_Tables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import streamlit as st
from utils import set_algorithm_name, get_pdf_iframe, to_csv_file, update_df_csv_to_save
from utils import set_algorithm_name, get_pdf_iframe, to_csv_file
from menu import display_pages_menu
from country_by_country.utils.constants import JURIDICTIONS
from Levenshtein import distance
Expand Down Expand Up @@ -55,7 +55,7 @@ def convert_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
return dataframe


special_characters = "#&()[]@"
special_characters = "#&()[]@©"


def style_symbol(v, props=""):
Expand All @@ -82,6 +82,8 @@ def update_min(string, min_distance, most_similar, input_string=input_string):
else:
return min_distance, most_similar

if input_string == None:
return "None"
min_distance = float("inf")
most_similar = None
for string in JURIDICTIONS.keys():
Expand All @@ -98,6 +100,22 @@ def update_min(string, min_distance, most_similar, input_string=input_string):
return most_similar


def validate(data: pd.DataFrame) -> None:
    """Store *data* as the current table for the selected algorithm.

    Used as the ``on_click`` callback of the "Save the table above" button;
    it overwrites the table cached in ``st.session_state.tables`` under the
    active ``algorithm_name`` key with the (styled/cleaned) dataframe.
    """
    st.session_state.tables[st.session_state["algorithm_name"]] = data


def update_df_csv_to_save() -> None:
    """Flush pending data-editor edits into the active table and rebuild the CSV.

    Reads the ``edited_rows`` recorded by the ``st.data_editor`` widget
    (stored under the ``"changes"`` session key), writes every edited cell
    back into the table kept for the currently selected algorithm, then
    regenerates the CSV bytes exposed for download in
    ``st.session_state["df_csv_to_save"]``.
    """
    # The table is mutated in place, so one lookup up front is enough.
    current_table = st.session_state.tables[st.session_state["algorithm_name"]]

    for row_index, row_edits in st.session_state.changes["edited_rows"].items():
        for column_label, new_value in row_edits.items():
            current_table.loc[row_index, column_label] = new_value

    st.session_state["df_csv_to_save"] = to_csv_file(current_table)


st.set_page_config(layout="wide", page_title="Tables customization") # page_icon="📈"
st.title("Country by Country Tax Reporting analysis : Tables")
st.subheader(
Expand All @@ -110,10 +128,6 @@ def update_min(string, min_distance, most_similar, input_string=input_string):
and "pdf_after_page_validation" in st.session_state
):

st.session_state.tables[st.session_state["algorithm_name"]] = convert_dataframe(
st.session_state.tables[st.session_state["algorithm_name"]]
)

col3, col4 = st.columns(2)
with col3:
st.markdown(
Expand Down Expand Up @@ -156,10 +170,11 @@ def update_min(string, min_distance, most_similar, input_string=input_string):
),
)

st.session_state.tables[st.session_state["algorithm_name"]] = st.data_editor(
st.data_editor(
st.session_state.tables[st.session_state["algorithm_name"]],
num_rows="dynamic",
on_change=update_df_csv_to_save,
key="changes",
width=800,
height=900,
)
Expand All @@ -184,6 +199,11 @@ def update_min(string, min_distance, most_similar, input_string=input_string):

dataframe = st.session_state.tables[st.session_state["algorithm_name"]].copy()

if country:
dataframe.iloc[:-2, 0] = dataframe.iloc[:-2, 0].apply(
lambda x: most_similar_string(x)
)

if remove_symbols:
pattern = "\(.*?\)" + "|[" + re.escape(special_characters) + "]"
for column in dataframe.columns:
Expand All @@ -198,11 +218,6 @@ def update_min(string, min_distance, most_similar, input_string=input_string):
new_row.iloc[0] = "Total Calculated"
dataframe.loc[-1] = new_row.transpose()

if country:
dataframe.iloc[:-2, 0] = dataframe.iloc[:-2, 0].apply(
lambda x: most_similar_string(x)
)

dataframe_styler = dataframe.style

if total:
Expand Down Expand Up @@ -241,21 +256,8 @@ def update_min(string, min_distance, most_similar, input_string=input_string):

st.dataframe(dataframe_styler, use_container_width=True, height=1000)

validated = st.button(
st.button(
"Save the table above",
on_click=validate,
args=(dataframe_styler.data,),
)
if validated:
st.session_state.tables[
st.session_state["algorithm_name"]
] = dataframe_styler.data
# This does not work
# Update the csv file to download as well
# print("clicked")
# st.session_state["df_csv_to_save"] = to_csv_file(
# st.session_state.tables[st.session_state["algorithm_name"]]
# )
# We rather rerun , which reloads the page and updates the data
# to be downloaded
# Otherwise, if you click the download button, you get the previous data
# the first time and then the right data on the second click
st.rerun()
6 changes: 0 additions & 6 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,3 @@ def set_algorithm_name(my_key: str) -> None:
@st.cache_data
def to_csv_file(df: pd.DataFrame) -> bytes:
    """Serialize *df* to UTF-8 encoded CSV bytes, without the index column.

    Wrapped in ``st.cache_data`` so repeated reruns of a page rendering a
    download button do not re-serialize an unchanged dataframe.
    """
    return df.to_csv(index=False).encode("utf-8")


def update_df_csv_to_save() -> None:
st.session_state["df_csv_to_save"] = to_csv_file(
st.session_state.tables[st.session_state["algorithm_name"]],
)
7 changes: 7 additions & 0 deletions country_by_country/table_extraction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@
# SOFTWARE.

# Local imports
import logging
import sys

from .camelot_extractor import Camelot
from .from_csv import FromCSV
from .llama_parse_extractor import LlamaParseExtractor
from .unstructured import Unstructured
from .unstructured_api import UnstructuredAPI

logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(message)s")


def from_config(config: dict) -> Camelot:
extractor_type = config["type"]
Expand All @@ -52,3 +57,5 @@ def from_config(config: dict) -> Camelot:
from .extract_table_api import ExtractTableAPI

return ExtractTableAPI(**extractor_params)
else:
logging.info(f"There are no extractors of the type : {extractor_type}")
10 changes: 10 additions & 0 deletions country_by_country/table_extraction/llama_parse_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def __call__(self, pdf_filepath: str) -> dict:
for page in json_objs[0]["pages"]:
for item in page["items"]:
if item["type"] == "table":
# If the number of columns in the header row is greater than the data rows
header_length = len(item["rows"][0])

for i in range(1, len(item["rows"])):
while len(item["rows"][i]) < header_length:
item["rows"][i].append("No Extract ")
while len(item["rows"][i]) > header_length:
item["rows"][0].append("No Extract ")
header_length = len(item["rows"][0])

df = pd.DataFrame(item["rows"][1:], columns=item["rows"][0])
tables_list.append(df)

Expand Down
33 changes: 13 additions & 20 deletions country_by_country/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,21 @@ def gather_tables(
tables_by_name = {}
for asset in assets["table_extractors"]:
tables = asset["tables"]
if len(tables) == 1:
for column in tables[0].columns:
for i in range(len(tables)):
for label, _content in tables[i].items():
if isinstance(tables[i][label], pd.DataFrame):
tables[i].columns = [
"No Extract " + str(i + 1) for i in range(tables[i].shape[1])
]
break
for label, content in tables[i].items():
if (
tables[0][column].dtype == "object"
content.dtype == "object"
): # Check if the column contains string data
tables[0][column] = tables[0][column].replace("", None)
tables[0][column] = tables[0][column].str.replace(
",",
".",
) # else we wont be able to convert to float
tables[0][column] = tables[0][column].str.replace(".", "")
tables_by_name[asset["type"]] = tables[0]
elif len(tables) > 1:
for i in range(len(tables)):
for column in tables[i].columns:
if (
tables[i][column].dtype == "object"
): # Check if the column contains string data
tables[i][column] = tables[i][column].replace("", None)
tables[i][column] = tables[i][column].str.replace(",", ".")
tables[i][column] = tables[i][column].str.replace(".", "")
tables_by_name[asset["type"] + "_" + str(i)] = tables[i]
tables[i][label] = tables[i][label].replace("", None)
tables[i][label] = tables[i][label].str.replace(".", "")
tables[i][label] = tables[i][label].str.replace(",", ".")
tables_by_name[asset["type"] + "_" + str(i)] = tables[i]

return tables_by_name

Expand Down

0 comments on commit e3daf98

Please sign in to comment.