forked from nasaharvest/crop-maml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble.py
66 lines (44 loc) · 1.79 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from argparse import ArgumentParser
from pathlib import Path
import sys
sys.path.append("..")
from src.maml.predict import predict
def prefix_from_name(model_name: str) -> str:
    """Return ``model_name`` with its last 15 characters removed.

    Names of 15 characters or fewer yield an empty string.

    NOTE(review): 15 presumably matches a fixed model-file suffix (it equals
    ``len("_state_dict.pth")``), but the saving code is not visible here —
    confirm against how the models are written to disk.
    """
    suffix_length = 15
    # A negative cut index behaves exactly like slicing with [:-15].
    cut_index = len(model_name) - suffix_length
    return model_name[:cut_index]
def _ensemble_predict(version_folder, test_path, models):
    """Average the ``"prediction_0"`` output of ``predict`` over ``models``.

    Each model contributes via ``predict(version_folder, test_path,
    prefix=prefix_from_name(model.name))``. The first model's result object
    is reused as the accumulator. Its ``"prediction_0"`` entry is summed
    in place with the other models' outputs and then divided by the model
    count.

    NOTE(review): the return value of ``predict`` is presumably an xarray
    Dataset (it is indexed by ``"prediction_0"`` and has ``to_netcdf``) —
    confirm against ``src.maml.predict``.

    Raises:
        ValueError: if ``models`` is empty (there is nothing to average).
    """
    if not models:
        raise ValueError("cannot build an ensemble from zero models")
    out = None
    for model in models:
        pred = predict(version_folder, test_path, prefix=prefix_from_name(model.name))
        if out is None:
            out = pred
        else:
            out["prediction_0"] += pred["prediction_0"]
    out["prediction_0"] /= len(models)
    return out


def landcover_mapper():
    """Run an ensemble of MAML models over every ``*.tif`` in a test folder.

    Command-line arguments:
        --version: selects ``../data/maml_models/version_{version}``.
        --query: glob pattern, relative to the version folder, selecting the
            model files to ensemble. Must be supplied: ``Path.glob`` cannot
            take the ``None`` default.
        --test_folder_name: sub-folder of ``../data/raw`` whose ``*.tif``
            files are predicted on.

    For each input file, the ``"prediction_0"`` outputs of all selected
    models are averaged. The result is written with ``to_netcdf`` to
    ``<version folder>/<test folder name>/preds_<tif name>``. Files whose
    output already exists are skipped. Paths are relative to the current
    working directory.

    Exits with a usage error (``SystemExit``) if ``--query`` is missing or
    matches no models.
    """
    parser = ArgumentParser()
    # figure out which model to use
    parser.add_argument("--version", type=int, default=0)
    parser.add_argument("--query", type=str, default=None)
    parser.add_argument(
        "--test_folder_name", type=str, default="earth_engine_region_busia"
    )
    args = parser.parse_args()

    if args.query is None:
        # Path.glob(None) would fail with an opaque TypeError.
        parser.error("--query is required (a glob pattern matching model files)")

    version_folder = Path(f"../data/maml_models/version_{args.version}")
    test_folder = Path(f"../data/raw/{args.test_folder_name}")
    # Sorted so runs are reproducible: raw glob order is filesystem-dependent.
    test_files = sorted(test_folder.glob("*.tif"))
    all_models = sorted(version_folder.glob(args.query))
    if not all_models:
        # Otherwise the averaging step would divide by zero or use an unbound result.
        parser.error(f"no models matching {args.query!r} found in {version_folder}")

    print("Using the following models: ")
    print(all_models)
    print(f"Using model {version_folder}")

    save_dir = version_folder / test_folder.name
    save_dir.mkdir(exist_ok=True)

    for test_path in test_files:
        output = save_dir / f"preds_{test_path.name}"
        if output.exists():
            print(f"{test_path.name} already run! skipping")
            continue
        print(f"Running for {test_path}")
        out = _ensemble_predict(version_folder, test_path, all_models)
        out.to_netcdf(output)
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the model ensemble.
    landcover_mapper()