Skip to content

Commit

Permalink
Better handling of optional virtual files (e.g., shading in Figure.gr…
Browse files Browse the repository at this point in the history
…dimage) (#2493)

Co-authored-by: Wei Ji <[email protected]>
  • Loading branch information
seisman and weiji14 authored Aug 5, 2023
1 parent d580dff commit 109f209
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 79 deletions.
40 changes: 23 additions & 17 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Uses ctypes to wrap most of the core functions from the C API.
"""
import ctypes as ctp
import pathlib
import sys
from contextlib import contextmanager, nullcontext

Expand Down Expand Up @@ -1474,6 +1475,7 @@ def virtualfile_from_data(
z=None,
extra_arrays=None,
required_z=False,
required_data=True,
):
"""
Store any data inside a virtual file.
Expand All @@ -1484,7 +1486,7 @@ def virtualfile_from_data(
Parameters
----------
check_kind : str
check_kind : str or None
Used to validate the type of data that can be passed in. Choose
from 'raster', 'vector', or None. Default is None (no validation).
data : str or pathlib.Path or xarray.DataArray or {table-like} or None
Expand All @@ -1498,6 +1500,9 @@ def virtualfile_from_data(
All of these arrays must be of the same size as the x/y/z arrays.
required_z : bool
State whether the 'z' column is required.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
Returns
-------
Expand Down Expand Up @@ -1528,21 +1533,25 @@ def virtualfile_from_data(
...
<vector memory>: N = 3 <7/9> <4/6> <1/3>
"""
kind = data_kind(data, x, y, z, required_z=required_z)

if check_kind == "raster" and kind not in ("file", "grid"):
raise GMTInvalidInput(f"Unrecognized data type for grid: {type(data)}")
if check_kind == "vector" and kind not in (
"file",
"matrix",
"vectors",
"geojson",
):
raise GMTInvalidInput(f"Unrecognized data type for vector: {type(data)}")
kind = data_kind(
data, x, y, z, required_z=required_z, required_data=required_data
)

if check_kind:
valid_kinds = ("file", "arg") if required_data is False else ("file",)
if check_kind == "raster":
valid_kinds += ("grid",)
elif check_kind == "vector":
valid_kinds += ("matrix", "vectors", "geojson")
if kind not in valid_kinds:
raise GMTInvalidInput(
f"Unrecognized data type for {check_kind}: {type(data)}"
)

# Decide which virtualfile_from_ function to use
_virtualfile_from = {
"file": nullcontext,
"arg": nullcontext,
"geojson": tempfile_from_geojson,
"grid": self.virtualfile_from_grid,
# Note: virtualfile_from_matrix is not used because a matrix can be
Expand All @@ -1553,11 +1562,8 @@ def virtualfile_from_data(
}[kind]

# Ensure the data is an iterable (Python list or tuple)
if kind in ("geojson", "grid"):
_data = (data,)
elif kind == "file":
# Useful to handle `pathlib.Path` and string file path alike
_data = (str(data),)
if kind in ("geojson", "grid", "file", "arg"):
_data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),)
elif kind == "vectors":
_data = [np.atleast_1d(x), np.atleast_1d(y)]
if z is not None:
Expand Down
75 changes: 62 additions & 13 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def _validate_data_input(
data=None, x=None, y=None, z=None, required_z=False, kind=None
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
):
"""
Check if the combination of data/x/y/z is valid.
Expand All @@ -25,6 +25,7 @@ def _validate_data_input(
>>> _validate_data_input(data="infile")
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
>>> _validate_data_input(data=None, required_data=False)
>>> _validate_data_input()
Traceback (most recent call last):
...
Expand All @@ -41,6 +42,30 @@ def _validate_data_input(
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_z=True,
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... required_z=True,
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
Traceback (most recent call last):
...
Expand All @@ -61,11 +86,11 @@ def _validate_data_input(
"""
if data is None: # data is None
if x is None and y is None: # both x and y are None
raise GMTInvalidInput("No input data provided.")
if x is None or y is None: # either x or y is None
if required_data: # data is not optional
raise GMTInvalidInput("No input data provided.")
elif x is None or y is None: # either x or y is None
raise GMTInvalidInput("Must provide both x and y.")
# both x and y are not None, now check z
if required_z and z is None:
if required_z and z is None: # both x and y are not None, now check z
raise GMTInvalidInput("Must provide x, y, and z.")
else: # data is not None
if x is not None or y is not None or z is not None:
Expand All @@ -81,38 +106,43 @@ def _validate_data_input(
raise GMTInvalidInput("data must provide x, y, and z columns.")


def data_kind(data=None, x=None, y=None, z=None, required_z=False):
def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True):
"""
Check what kind of data is provided to a module.
Possible types:
* a file name provided as 'data'
* a pathlib.Path provided as 'data'
* an xarray.DataArray provided as 'data'
* a matrix provided as 'data'
* a pathlib.PurePath object provided as 'data'
* an xarray.DataArray object provided as 'data'
* a 2-D matrix provided as 'data'
* 1-D arrays x and y (and z, optionally)
* an optional argument (None, bool, int or float) provided as 'data'
Arguments should be ``None`` if not used. If doesn't fit any of these
categories (or fits more than one), will raise an exception.
Parameters
----------
data : str or pathlib.Path or xarray.DataArray or {table-like} or None
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
table, an :class:`xarray.DataArray`, a 1-D/2-D
{table-classes}.
{table-classes} or an option argument.
x/y : 1-D arrays or None
x and y columns as numpy arrays.
z : 1-D array or None
z column as numpy array. To be used optionally when x and y are given.
required_z : bool
State whether the 'z' column is required.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
Returns
-------
kind : str
One of: ``'file'``, ``'grid'``, ``'matrix'``, ``'vectors'``.
One of ``'arg'``, ``'file'``, ``'grid'``, ``'geojson'``, ``'matrix'``,
or ``'vectors'``.
Examples
--------
Expand All @@ -128,20 +158,39 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False):
'file'
>>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None)
'file'
>>> data_kind(data=None, x=None, y=None, required_data=False)
'arg'
>>> data_kind(data=2.0, x=None, y=None, required_data=False)
'arg'
>>> data_kind(data=True, x=None, y=None, required_data=False)
'arg'
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
'grid'
"""
# determine the data kind
if isinstance(data, (str, pathlib.PurePath)):
kind = "file"
elif isinstance(data, (bool, int, float)) or (data is None and not required_data):
kind = "arg"
elif isinstance(data, xr.DataArray):
kind = "grid"
elif hasattr(data, "__geo_interface__"):
# geo-like Python object that implements ``__geo_interface__``
# (geopandas.GeoDataFrame or shapely.geometry)
kind = "geojson"
elif data is not None:
kind = "matrix"
else:
kind = "vectors"
_validate_data_input(data=data, x=x, y=y, z=z, required_z=required_z, kind=kind)
_validate_data_input(
data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)
return kind


Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/dimfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def dimfilter(grid, **kwargs):

with GMTTempFile(suffix=".nc") as tmpfile:
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind=None, data=grid)
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
if (outgrid := kwargs.get("G")) is None:
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
Expand Down
26 changes: 7 additions & 19 deletions pygmt/src/grdimage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
"""
grdimage - Plot grids or images.
"""
import contextlib

from pygmt.clib import Session
from pygmt.helpers import (
build_arg_string,
data_kind,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias

__doctest_skip__ = ["grdimage"]

Expand Down Expand Up @@ -180,16 +172,12 @@ def grdimage(self, grid, **kwargs):
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with contextlib.ExitStack() as stack:
# shading using an xr.DataArray
if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid":
shading_context = lib.virtualfile_from_data(
check_kind="raster", data=kwargs["I"]
)
kwargs["I"] = stack.enter_context(shading_context)

fname = stack.enter_context(file_context)
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as fname, lib.virtualfile_from_data(
check_kind="raster", data=kwargs.get("I"), required_data=False
) as shadegrid:
kwargs["I"] = shadegrid
lib.call_module(
module="grdimage", args=build_arg_string(kwargs, infile=fname)
)
34 changes: 7 additions & 27 deletions pygmt/src/grdview.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
"""
grdview - Create a three-dimensional plot from a grid.
"""
import contextlib

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias

__doctest_skip__ = ["grdview"]

Expand Down Expand Up @@ -155,23 +146,12 @@ def grdview(self, grid, **kwargs):
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)

with contextlib.ExitStack() as stack:
if kwargs.get("G") is not None:
# deal with kwargs["G"] if drapegrid is xr.DataArray
drapegrid = kwargs["G"]
if data_kind(drapegrid) in ("file", "grid"):
if data_kind(drapegrid) == "grid":
drape_context = lib.virtualfile_from_data(
check_kind="raster", data=drapegrid
)
kwargs["G"] = stack.enter_context(drape_context)
else:
raise GMTInvalidInput(
f"Unrecognized data type for drapegrid: {type(drapegrid)}"
)
fname = stack.enter_context(file_context)
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as fname, lib.virtualfile_from_data(
check_kind="raster", data=kwargs.get("G"), required_data=False
) as drapegrid:
kwargs["G"] = drapegrid
lib.call_module(
module="grdview", args=build_arg_string(kwargs, infile=fname)
)
8 changes: 6 additions & 2 deletions pygmt/tests/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,9 @@ def test_virtualfile_from_data_required_z_matrix(array_func, kind):
)
data = array_func(dataframe)
with clib.Session() as lib:
with lib.virtualfile_from_data(data=data, required_z=True) as vfile:
with lib.virtualfile_from_data(
data=data, required_z=True, check_kind="vector"
) as vfile:
with GMTTempFile() as outfile:
lib.call_module("info", f"{vfile} ->{outfile.name}")
output = outfile.read(keep_tabs=True)
Expand All @@ -461,7 +463,9 @@ def test_virtualfile_from_data_required_z_matrix_missing():
data = np.ones((5, 2))
with clib.Session() as lib:
with pytest.raises(GMTInvalidInput):
with lib.virtualfile_from_data(data=data, required_z=True):
with lib.virtualfile_from_data(
data=data, required_z=True, check_kind="vector"
):
pass


Expand Down

0 comments on commit 109f209

Please sign in to comment.