Adapting the MOSAIKS example notebook by selecting different bands #26
We are trying to adapt the MOSAIKS example notebook by selecting different bands than 'visual' (which has bands B04, B03, and B02). Let's say we want to use bands B04, B07, and B08. How would we adapt the code so that the images are stacked for the featurization process? The root of the problem is that we would like to switch from Sentinel-2 to Landsat 8, which only has individual bands on the data catalog. We watched the PyData Global talk which showed how to use [...]. It appears that the [...]. Below is a boiled-down version of the example notebook for reference.

!pip install -q git+https://github.com/geopandas/dask-geopandas
import warnings
import time
import os
RASTERIO_BEST_PRACTICES = dict(  # See https://github.com/pangeo-data/cog-best-practices
    CURL_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt",
    GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
    AWS_NO_SIGN_REQUEST="YES",
    GDAL_MAX_RAW_BLOCK_CACHE_SIZE="200000000",
    GDAL_SWATH_SIZE="200000000",
    VSI_CURL_CACHE_SIZE="200000000",
)
os.environ.update(RASTERIO_BEST_PRACTICES)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import rasterio
import rasterio.warp
import rasterio.mask
import shapely.geometry
import geopandas
import dask_geopandas
from dask.distributed import Client
import pystac_client
import planetary_computer as pc
def featurize(input_img, model, device):
    """Helper method for running an image patch through the model.

    Args:
        input_img (np.ndarray): Image in (C x H x W) format with a dtype of uint8.
        model (torch.nn.Module): Feature extractor network
        device (torch.device): Device on which to run the model
    """
    assert len(input_img.shape) == 3
    input_img = torch.from_numpy(input_img / 255.0).float()
    input_img = input_img.to(device)
    with torch.no_grad():
        feats = model(input_img.unsqueeze(0)).cpu().numpy()
    return feats
class RCF(nn.Module):
    """A model for extracting Random Convolution Features (RCF) from input imagery."""

    def __init__(self, num_features=16, kernel_size=3, num_input_channels=3):
        super(RCF, self).__init__()
        # We create `num_features / 2` filters so require `num_features` to be divisible by 2
        assert num_features % 2 == 0
        self.conv1 = nn.Conv2d(
            num_input_channels,
            num_features // 2,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
        )
        nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)
        nn.init.constant_(self.conv1.bias, -1.0)

    def forward(self, x):
        x1a = F.relu(self.conv1(x), inplace=True)
        x1b = F.relu(-self.conv1(x), inplace=True)
        x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
        x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()
        if len(x1a.shape) == 1:  # case where we passed a single input
            return torch.cat((x1a, x1b), dim=0)
        elif len(x1a.shape) == 2:  # case where we passed a batch of > 1 inputs
            return torch.cat((x1a, x1b), dim=1)
num_features = 1024
device = torch.device("cuda")
model = RCF(num_features).eval().to(device)
df = pd.read_csv(
    "https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download",  # noqa: E501
    index_col=0,
    na_values=[-999],
).dropna()
points = df[["lon", "lat"]]
population = df["population"]
gdf = geopandas.GeoDataFrame(df, geometry=geopandas.points_from_xy(df.lon, df.lat))
gdf
population_log = np.log10(population + 1)
NPARTITIONS = 250
ddf = dask_geopandas.from_geopandas(gdf, npartitions=1)
hd = ddf.hilbert_distance().compute()
gdf["hd"] = hd
gdf = gdf.sort_values("hd")
dgdf = dask_geopandas.from_geopandas(gdf, npartitions=NPARTITIONS, sort=False)
def query(points):
    """
    Find a STAC item for points in the `points` DataFrame

    Parameters
    ----------
    points : geopandas.GeoDataFrame
        A GeoDataFrame

    Returns
    -------
    geopandas.GeoDataFrame
        A new geopandas.GeoDataFrame with a `stac_item` column containing the STAC
        item that covers each point.
    """
    intersects = shapely.geometry.mapping(points.unary_union.convex_hull)
    search_start = "2018-01-01"
    search_end = "2019-12-31"
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1"
    )
    # The time frame in which we search for non-cloudy imagery
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        intersects=intersects,
        datetime=[search_start, search_end],
        query={"eo:cloud_cover": {"lt": 10}},
        limit=500,
    )
    ic = search.get_all_items_as_dict()
    features = ic["features"]
    features_d = {item["id"]: item for item in features}
    data = {
        "eo:cloud_cover": [],
        "geometry": [],
    }
    index = []
    for item in features:
        data["eo:cloud_cover"].append(item["properties"]["eo:cloud_cover"])
        data["geometry"].append(shapely.geometry.shape(item["geometry"]))
        index.append(item["id"])
    items = geopandas.GeoDataFrame(data, index=index, geometry="geometry").sort_values(
        "eo:cloud_cover"
    )
    point_list = points.geometry.tolist()
    point_items = []
    for point in point_list:
        covered_by = items[items.covers(point)]
        if len(covered_by):
            point_items.append(features_d[covered_by.index[0]])
        else:
            # There weren't any scenes matching our conditions for this point (too cloudy)
            point_items.append(None)
    return points.assign(stac_item=point_items)
with Client(n_workers=16) as client:
    print(client.dashboard_link)
    meta = dgdf._meta.assign(stac_item=[])
    df2 = dgdf.map_partitions(query, meta=meta).compute()
df3 = df2.dropna(subset=["stac_item"])
matching_urls = [
    pc.sign(item["assets"]["visual"]["href"]) for item in df3.stac_item.tolist()
]
points = df3[["lon", "lat"]].to_numpy()
population_log = np.log10(df3["population"].to_numpy() + 1)
class CustomDataset(Dataset):
    def __init__(self, points, fns, buffer=500):
        self.points = points
        self.fns = fns
        self.buffer = buffer

    def __len__(self):
        return self.points.shape[0]

    def __getitem__(self, idx):
        lon, lat = self.points[idx]
        fn = self.fns[idx]
        if fn is None:
            return None
        else:
            point_geom = shapely.geometry.mapping(shapely.geometry.Point(lon, lat))
            with rasterio.Env():
                with rasterio.open(fn, "r") as f:
                    point_geom = rasterio.warp.transform_geom(
                        "epsg:4326", f.crs.to_string(), point_geom
                    )
                    point_shape = shapely.geometry.shape(point_geom)
                    mask_shape = point_shape.buffer(self.buffer).envelope
                    mask_geom = shapely.geometry.mapping(mask_shape)
                    try:
                        out_image, out_transform = rasterio.mask.mask(
                            f, [mask_geom], crop=True
                        )
                    except ValueError as e:
                        if "Input shapes do not overlap raster." in str(e):
                            return None
            out_image = out_image / 255.0
            out_image = torch.from_numpy(out_image).float()
            return out_image
dataset = CustomDataset(points, matching_urls)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=os.cpu_count() * 2,
    collate_fn=lambda x: x,
    pin_memory=False,
)
x_all = np.zeros((points.shape[0], num_features), dtype=float)
tic = time.time()
i = 0
for images in dataloader:
    for image in images:
        if image is not None:
            # A full image should be ~101x101 pixels (i.e. ~1km^2 at a 10m/px spatial
            # resolution), however we can receive smaller images if an input point
            # happens to be at the edge of a S2 scene (a literal edge case). To deal
            # with these (edge) cases we crudely drop all images where the spatial
            # dimensions aren't both greater than 20 pixels.
            if image.shape[1] >= 20 and image.shape[2] >= 20:
                image = image.to(device)
                with torch.no_grad():
                    feats = model(image.unsqueeze(0)).cpu().numpy()
                x_all[i] = feats
            else:
                # this happens if the point is close to the edge of a scene
                # (one or both of the spatial dimensions of the image are very small)
                pass
        else:
            pass  # this happens if we do not find a S2 scene for some point

        if i % 1000 == 0:
            print(
                f"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%"
                + f" -- {time.time()-tic:0.2f} seconds"
            )
            tic = time.time()
        i += 1

Would it be possible to define a new band in the [...]? Or would we want to search individual bands for matching urls and then use [...]? Any help, hints, or guidance would be greatly appreciated!
Hey @cullen-molitor,
Here's an example of reading different band sets with stackstac -- https://gist.github.com/calebrob6/09b6a0210d63f96e80317a5796623d9b. To implement this in the above code I think you could replace:
matching_urls = [pc.sign(item["assets"]["visual"]["href"]) for item in df3.stac_item.tolist()]
with something that signs the whole item, and then use stackstac instead of rasterio.open in the CustomDataset.
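For concreteness, here is a rough sketch of what that could look like, building on the notebook code above (this is an adaptation I'm improvising here, not code from the gist). It assumes the Sentinel-2 asset keys you mentioned (B04, B07, B08 -- Landsat 8 asset key names differ, so check the collection's assets), uses an arbitrary ~0.005 degree half-width bounding box in place of the 500 m projected buffer, and divides by 10000 as a rough reflectance scaling since these raw bands are not 8-bit like "visual". The helper name load_patch is just for illustration.

import numpy as np
import pystac
import stackstac
import torch

# Sign the whole STAC item (not just one asset href).
# df3 and pc come from the notebook code above.
matching_items = [
    pc.sign(pystac.Item.from_dict(item)) for item in df3.stac_item.tolist()
]

BANDS = ["B04", "B07", "B08"]  # Sentinel-2 asset keys; Landsat 8 names differ


def load_patch(item, lon, lat, buffer_deg=0.005, bands=BANDS, resolution=10):
    """Read a small (band, y, x) patch around (lon, lat) from a signed STAC item."""
    if item is None:
        return None
    bounds = (lon - buffer_deg, lat - buffer_deg, lon + buffer_deg, lat + buffer_deg)
    stack = stackstac.stack(
        [item],
        assets=bands,
        resolution=resolution,
        bounds_latlon=bounds,
    )
    # A single item gives a single time step; pull it into memory as a numpy array.
    data = stack.isel(time=0).data.compute()
    data = np.nan_to_num(data) / 10000.0  # rough reflectance scaling instead of / 255
    return torch.from_numpy(data).float()

In the CustomDataset, __getitem__ would then call something like load_patch(self.items[idx], lon, lat) in place of the rasterio.open / rasterio.mask.mask block, with the dataset holding matching_items instead of matching_urls. With three bands the default RCF(num_features) (num_input_channels=3) still matches; if you stack a different number of bands, construct the model with num_input_channels set accordingly.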