add hypodd to japan
zhuwq0 committed Sep 16, 2024
1 parent 4062b31 commit acf2fd4
Showing 6 changed files with 542 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/japan/.gitignore
@@ -0,0 +1 @@
local/
23 changes: 23 additions & 0 deletions examples/japan/plot_catalog.py
@@ -28,6 +28,10 @@
with open(f"{root_path}/{region}/config.json", "r") as f:
config = json.load(f)
print(json.dumps(config, indent=4, sort_keys=True))
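# Fall back to a 0-30 km depth range when the config does not specify one.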
if "mindepth" not in config:
config["mindepth"] = 0
if "maxdepth" not in config:
config["maxdepth"] = 30
xlim = [config["minlongitude"], config["maxlongitude"]]
ylim = [config["minlatitude"], config["maxlatitude"]]

@@ -935,6 +939,25 @@ def plot3d(x, y, z, config, fig_name):
config_plot3d,
f"{root_path}/{figure_path}/earthquake_location_adloc.html",
)

if adloc_dt_exist and len(adloc_dt_catalog) > 0:
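# 3-D view of the ADLoc double-difference (dt) catalog, matching the standard ADLoc plot above.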
plot3d(
adloc_dt_catalog["longitude"],
adloc_dt_catalog["latitude"],
adloc_dt_catalog["depth_km"],
config_plot3d,
f"{root_path}/{figure_path}/earthquake_location_adloc_dt.html",
)

if adloc_dtcc_exist and len(adloc_dtcc_catalog) > 0:
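# 3-D view of the ADLoc double-difference catalog built from cross-correlation times (dtcc).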
plot3d(
adloc_dtcc_catalog["longitude"],
adloc_dtcc_catalog["latitude"],
adloc_dtcc_catalog["depth_km"],
config_plot3d,
f"{root_path}/{figure_path}/earthquake_location_adloc_dtcc.html",
)

if hypodd_ct_exist and len(catalog_ct_hypodd) > 0:
plot3d(
catalog_ct_hypodd["LON"],
330 changes: 330 additions & 0 deletions examples/japan/run_adloc_cc.py
@@ -0,0 +1,330 @@
# %%
import json
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.optim as optim
from adloc.adloc import TravelTimeDD
from adloc.data import PhaseDatasetDTCC
from adloc.eikonal2d import init_eikonal2d
from adloc.inversion import optimize_dd

# from args import parse_args
from plotting import plotting_dd
from pyproj import Proj
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

torch.manual_seed(0)
np.random.seed(0)


# %%
if __name__ == "__main__":

# %%
# args = parse_args()
# root_path = args.root_path
# region = args.region
root_path = "local"
region = "hinet"
data_path = f"{root_path}/{region}/cctorch"
result_path = f"{root_path}/{region}/adloc_dd"
figure_path = f"{root_path}/{region}/adloc_dd/figures"
if not os.path.exists(data_path):
os.makedirs(data_path)
if not os.path.exists(result_path):
os.makedirs(result_path)

# %%
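# Detect a distributed launch: torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE in the environment
# (e.g. `torchrun --nproc_per_node=4 run_adloc_cc.py`); a plain `python run_adloc_cc.py` run
# falls through to the single-process branch below.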
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
if ddp:
init_process_group(backend="gloo")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
print(f"DDP rank {ddp_rank}, local rank {ddp_local_rank}, world size {ddp_world_size}")
else:
# vanilla, non-DDP run
ddp_rank = 0
ddp_local_rank = 0
ddp_world_size = 1
master_process = True
print("Non-DDP run")

# %% reading from the generated files
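# Inputs produced by the CCTorch step: event/pick/station tables plus dtcc.csv with the
# cross-correlation differential travel times (columns idx_eve1, idx_eve2, idx_sta, dt used below).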
events = pd.read_csv(os.path.join(data_path, "cctorch_events.csv"))
events["event_time"] = pd.to_datetime(events["event_time"])
events.rename(columns={"event_time": "time"}, inplace=True)
stations = pd.read_csv(os.path.join(data_path, "cctorch_stations.csv"))
picks = pd.read_csv(os.path.join(data_path, "cctorch_picks.csv"), parse_dates=["phase_time"])
pairs = pd.read_csv(os.path.join(data_path, "dtcc.csv"))

with open(f"{root_path}/{region}/config.json", "r") as fp:
config = json.load(fp)
print(json.dumps(config, indent=4))
config["use_amplitude"] = True

# ## Eikonal for 1D velocity model
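# zz: layer-top depths (km), vp: P velocities (km/s); vs is derived from a constant Vp/Vs ratio,
# and h is the grid step handed to the eikonal solver.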
zz = [0.0, 5.5, 16.0, 32.0]
vp = [5.5, 5.5, 6.7, 7.8]
vp_vs_ratio = 1.73
vs = [v / vp_vs_ratio for v in vp]
h = 0.3

# %%
if not os.path.exists(result_path):
os.makedirs(result_path)
if not os.path.exists(figure_path):
os.makedirs(figure_path)

# %%
## Automatic region; you can also specify a region
# lon0 = stations["longitude"].median()
# lat0 = stations["latitude"].median()
lat0 = (config["minlatitude"] + config["maxlatitude"]) / 2
lon0 = (config["minlongitude"] + config["maxlongitude"]) / 2
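# Local oblique stereographic projection in km, centered on the middle of the study region.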
proj = Proj(f"+proj=sterea +lon_0={lon0} +lat_0={lat0} +units=km")

stations["x_km"], stations["y_km"] = proj(stations["longitude"], stations["latitude"])
stations["z_km"] = stations["depth_km"]
events["time"] = pd.to_datetime(events["time"])
events["x_km"], events["y_km"] = proj(events["longitude"], events["latitude"])
events["z_km"] = events["depth_km"]

events_init = events.copy()

## set up the config; you can also specify the region manually
if ("xlim_km" not in config) or ("ylim_km" not in config) or ("zlim_km" not in config):
xmin, ymin = proj(config["minlongitude"], config["minlatitude"])
xmax, ymax = proj(config["maxlongitude"], config["maxlatitude"])
# zmin, zmax = config["mindepth"], config["maxdepth"]
zmin = config["mindepth"] if "mindepth" in config else 0.0
zmax = config["maxdepth"] if "maxdepth" in config else 30.0
config["xlim_km"] = (xmin, xmax)
config["ylim_km"] = (ymin, ymax)
config["zlim_km"] = (zmin, zmax)

mapping_phase_type_int = {"P": 0, "S": 1}
config["vel"] = {"P": 6.0, "S": 6.0 / 1.73}
config["vel"] = {mapping_phase_type_int[k]: v for k, v in config["vel"].items()}

# %%
config["eikonal"] = None
## Eikonal for 1D velocity model
# zz = [0.0, 5.5, 16.0, 32.0]
# vp = [5.5, 5.5, 6.7, 7.8]
# vp_vs_ratio = 1.73
# vs = [v / vp_vs_ratio for v in vp]
# zz = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 30.0]
# vp = [4.746, 4.793, 4.799, 5.045, 5.721, 5.879, 6.504, 6.708, 6.725, 7.800]
# vs = [2.469, 2.470, 2.929, 2.930, 3.402, 3.403, 3.848, 3.907, 3.963, 4.500]
# h = 0.3
vel = {"Z": zz, "P": vp, "S": vs}
config["eikonal"] = {
"vel": vel,
"h": h,
"xlim_km": config["xlim_km"],
"ylim_km": config["ylim_km"],
"zlim_km": config["zlim_km"],
}
config["eikonal"] = init_eikonal2d(config["eikonal"])

# %%
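# Box bounds for relocation: x/y padded by 1 km, depth in [0, zmax + 1], origin-time shift unbounded.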
config["bfgs_bounds"] = (
(config["xlim_km"][0] - 1, config["xlim_km"][1] + 1), # x
(config["ylim_km"][0] - 1, config["ylim_km"][1] + 1), # y
(0, config["zlim_km"][1] + 1),
(None, None), # t
)

# %%
phase_dataset = PhaseDatasetDTCC(pairs, picks, events, stations, rank=ddp_local_rank, world_size=ddp_world_size)
data_loader = DataLoader(phase_dataset, batch_size=None, shuffle=False, num_workers=0, drop_last=False)

# %%
num_event = len(events)
num_station = len(stations)
station_loc = stations[["x_km", "y_km", "z_km"]].values
event_loc = events[["x_km", "y_km", "z_km"]].values # + np.random.randn(num_event, 3) * 2.0
travel_time = TravelTimeDD(
num_event,
num_station,
station_loc=station_loc,
event_loc=event_loc,
# event_time=event_time,
eikonal=config["eikonal"],
)
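# Under DDP the model is wrapped; raw_travel_time keeps a handle on the underlying module
# so parameters can be read and clamped directly.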
if ddp:
travel_time = DDP(travel_time)
raw_travel_time = travel_time.module if ddp else travel_time

if ddp_local_rank == 0:
print(f"Dataset: {len(events)} events, {len(stations)} stations, {len(data_loader)} batches")

## invert loss
######################################################################################################
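# Adam inversion: gradients from all pair batches are accumulated, then a single optimizer step per epoch.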
optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1)
valid_index = np.ones(len(pairs), dtype=bool)
EPOCHS = 100
for epoch in range(EPOCHS):
loss = 0
optimizer.zero_grad()
# for meta in tqdm(phase_dataset, desc=f"Epoch {i}"):
for meta in data_loader:
out = travel_time(
meta["idx_sta"],
meta["idx_eve"],
meta["phase_type"],
meta["phase_time"],
meta["phase_weight"],
)
pred_, loss_ = out["phase_time"], out["loss"]

loss_.backward()

if ddp:
dist.all_reduce(loss_, op=dist.ReduceOp.SUM)
# loss_ /= ddp_world_size

loss += loss_

# torch.nn.utils.clip_grad_norm_(travel_time.parameters(), 1.0)
optimizer.step()
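# Keep depths inside the model box and zero out any NaN locations after the update.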
with torch.no_grad():
raw_travel_time.event_loc.weight.data[:, 2].clamp_(
min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1
)
raw_travel_time.event_loc.weight.data[torch.isnan(raw_travel_time.event_loc.weight)] = 0.0
if ddp_local_rank == 0:
print(f"Epoch {epoch}: loss {loss:.6e} of {np.sum(valid_index)} picks, {loss / np.sum(valid_index):.6e}")

### filtering
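# Recompute predicted differential times for every pair, keep residuals within 3 standard deviations
# (measured over the previously valid pairs), then drop event pairs with fewer than MIN_OBS observations.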
pred_time = []
phase_dataset.valid_index = np.ones(len(pairs), dtype=bool)
for meta in phase_dataset:
meta = travel_time(
meta["idx_sta"],
meta["idx_eve"],
meta["phase_type"],
meta["phase_time"],
meta["phase_weight"],
)
pred_time.append(meta["phase_time"].detach().numpy())

pred_time = np.concatenate(pred_time)
valid_index = (
np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * 3.0
) # * (np.cos(epoch * np.pi / EPOCHS) + 2.0) # 3std -> 1std

pairs_df = pd.DataFrame(
{
"event_index1": pairs["idx_eve1"],
"event_index2": pairs["idx_eve2"],
"station_index": pairs["idx_sta"],
}
)
pairs_df = pairs_df[valid_index]
config["MIN_OBS"] = 8
pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter(
lambda x: len(x) >= config["MIN_OBS"]
)
valid_index = np.zeros(len(pairs), dtype=bool)
valid_index[pairs_df.index] = True

phase_dataset.valid_index = valid_index

invert_event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
invert_event_time = raw_travel_time.event_time.weight.clone().detach().numpy()
valid_event_index = np.concatenate(
[np.unique(pairs["idx_eve1"][valid_index]), np.unique(pairs["idx_eve2"][valid_index])]
)
valid_event_index = np.sort(np.unique(valid_event_index))

if ddp_local_rank == 0 and (epoch % 10 == 0):
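# Every 10 epochs on rank 0: write an intermediate catalog snapshot and a relocation figure.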
events = events_init.copy()
events["time"] = events["time"] + pd.to_timedelta(np.squeeze(invert_event_time), unit="s")
events["x_km"] = invert_event_loc[:, 0]
events["y_km"] = invert_event_loc[:, 1]
events["z_km"] = invert_event_loc[:, 2]
events[["longitude", "latitude"]] = events.apply(
lambda x: pd.Series(proj(x["x_km"], x["y_km"], inverse=True)), axis=1
)
events["depth_km"] = events["z_km"]
events = events.iloc[valid_event_index]
events.to_csv(
f"{result_path}/adloc_dtcc_events_{epoch//10}.csv",
index=False,
float_format="%.5f",
date_format="%Y-%m-%dT%H:%M:%S.%f",
)
plotting_dd(events, stations, config, figure_path, events_init, suffix=f"_ddcc_{epoch//10}")

# ######################################################################################################
# optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=10, line_search_fn="strong_wolfe")

# def closure():
# optimizer.zero_grad()
# loss = 0
# # for meta in tqdm(phase_dataset, desc=f"BFGS"):
# if ddp_local_rank == 0:
# print(f"BFGS: ", end="")
# for meta in phase_dataset:
# if ddp_local_rank == 0:
# print(".", end="")

# loss_ = travel_time(
# meta["idx_sta"],
# meta["idx_eve"],
# meta["phase_type"],
# meta["phase_time"],
# meta["phase_weight"],
# )["loss"]
# loss_.backward()

# if ddp:
# dist.all_reduce(loss_, op=dist.ReduceOp.SUM)
# # loss_ /= ddp_world_size

# loss += loss_

# if ddp_local_rank == 0:
# print(f"Loss: {loss}")
# raw_travel_time.event_loc.weight.data[:, 2].clamp_(min=config["zlim_km"][0], max=config["zlim_km"][1])
# return loss

# optimizer.step(closure)
# ######################################################################################################

# %%
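# Final catalog: apply the inverted origin-time shifts and locations, convert x/y back to
# longitude/latitude, and keep only events that still appear in valid pairs.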
if ddp_local_rank == 0:

plotting_dd(events, stations, config, figure_path, events_init, suffix="_dtcc")

invert_event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
invert_event_time = raw_travel_time.event_time.weight.clone().detach().numpy()

events = events_init.copy()
events["time"] = events["time"] + pd.to_timedelta(np.squeeze(invert_event_time), unit="s")
events["x_km"] = invert_event_loc[:, 0]
events["y_km"] = invert_event_loc[:, 1]
events["z_km"] = invert_event_loc[:, 2]
events[["longitude", "latitude"]] = events.apply(
lambda x: pd.Series(proj(x["x_km"], x["y_km"], inverse=True)), axis=1
)
events["depth_km"] = events["z_km"]
events = events.iloc[valid_event_index]
events.to_csv(
f"{result_path}/adloc_dtcc_events.csv", index=False, float_format="%.5f", date_format="%Y-%m-%dT%H:%M:%S.%f"
)
