Skip to content

Commit

Permalink
Merge pull request #44 from seareport/fix-3D
Browse files Browse the repository at this point in the history
fix: fix bug for 3D data inversion
  • Loading branch information
tomsail authored Oct 16, 2024
2 parents efc85f0 + eca55c9 commit 6abc326
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xarray-selafin"
version = "0.1.8"
version = "0.1.9"
description = ""
authors = ["tomsail <[email protected]>", "lucduron <[email protected]>"]
readme = "README.md"
Expand All @@ -26,3 +26,7 @@ xarray = {version = "*", extras = ["io"]}
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.black]
target-version = ['py37']
line-length = 107
27 changes: 13 additions & 14 deletions xarray_selafin/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def write_serafin(fout, ds):
first_date_str = first_time.values.astype(str) # "1900-01-01T00:00:00.000000000"
first_date_str = first_date_str.rstrip("0") + "0" # "1900-01-01T00:00:00.0"
try:
date = datetime.strptime(first_date_str, '%Y-%m-%dT%H:%M:%S.%f')
date = datetime.strptime(first_date_str, "%Y-%m-%dT%H:%M:%S.%f")
slf_header.date = attrgetter("year", "month", "day", "hour", "minute", "second")(date)
except ValueError:
slf_header.date = DEFAULT_DATE_START
Expand Down Expand Up @@ -133,7 +133,7 @@ def write_serafin(fout, ds):
if "plan" in ds.dims: # 3D
slf_header.nb_planes = len(ds.plan)
slf_header.is_2d = False
shape = (slf_header.nb_var, slf_header.nb_nodes_2d, slf_header.nb_planes)
shape = (slf_header.nb_var, slf_header.nb_planes, slf_header.nb_nodes_2d)
else: # 2D (converted if required)
# if ds.attrs["type"] == "3D":
# slf_header.is_2d = False # to enable conversion from 3D
Expand Down Expand Up @@ -165,6 +165,8 @@ def write_serafin(fout, ds):
temp[iv] = ds[var]
else:
temp[iv] = ds.isel(time=it)[var]
if slf_header.nb_planes > 1:
temp[iv] = np.reshape(np.ravel(temp[iv]), (slf_header.nb_planes, slf_header.nb_nodes_2d))
resout.write_entire_frame(
slf_header,
t_,
Expand Down Expand Up @@ -233,15 +235,14 @@ def _raw_indexing_method(self, key):
for it, t in enumerate(time_indices):
temp = self.slf_reader.read_var_in_frame(t, self.var) # shape = (nb_nodes,)
temp = np.reshape(temp, self.shape[1:]) # shape = (nb_nodes_2d, nb_planes)
if node_key == slice(None) and plan_key == slice(
None
): # speedup if not selection
if node_key == slice(None) and plan_key == slice(None): # speedup if not selection
data[it] = temp
else:
if plan_key is None:
data[it] = temp[node_indices]
else:
data[it] = temp[node_indices][:, plan_indices]
values = temp[node_indices][:, plan_indices]
data[it] = np.reshape(values, (len(plan_indices), len(node_indices))).T

# Remove dimension if key was an integer
if isinstance(node_key, int):
Expand Down Expand Up @@ -290,8 +291,8 @@ def open_dataset(
shape = (len(times), npoin2)
dims = ["time", "node"]
else:
shape = (len(times), npoin2, nplan)
dims = ["time", "node", "plan"]
shape = (len(times), nplan, npoin2)
dims = ["time", "plan", "node"]

for var in vars:
if lazy_loading:
Expand All @@ -305,14 +306,15 @@ def open_dataset(
if is_2d:
data[time_index, :] = values
else:
data[time_index, :, :] = values.reshape(npoin2, nplan)
data[time_index, :, :] = np.reshape(values, (nplan, npoin2))
data_vars[var] = xr.Variable(dims=dims, data=data)

coords = {
"x": ("node", x[:npoin2]),
"y": ("node", y[:npoin2]),
"time": times,
# Consider how to include IPOBO (with node and plan dimensions?) if it's essential for your analysis
# Consider how to include IPOBO (with node and plan dimensions?)
# if it's essential for your analysis
}

ds = xr.Dataset(data_vars=data_vars, coords=coords)
Expand All @@ -327,10 +329,7 @@ def open_dataset(
if not is_2d:
ds.attrs["ikle3"] = np.reshape(slf.header.ikle, (slf.header.nb_elements, ndp3))
ds.attrs["variables"] = {
var_ID: (
name.decode(Serafin.SLF_EIT).rstrip(),
unit.decode(Serafin.SLF_EIT).rstrip()
)
var_ID: (name.decode(Serafin.SLF_EIT).rstrip(), unit.decode(Serafin.SLF_EIT).rstrip())
for var_ID, name, unit in slf.header.iter_on_all_variables()
}
ds.attrs["date_start"] = slf.header.date
Expand Down

0 comments on commit 6abc326

Please sign in to comment.