-
Notifications
You must be signed in to change notification settings - Fork 20
/
train.py
40 lines (33 loc) · 1.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Train one model for each task
from behavioural_cloning import behavioural_cloning_train
# VPT foundation model and weights that every task's training run starts from.
FOUNDATION_MODEL = "data/VPT-models/foundation-model-1x.model"
FOUNDATION_WEIGHTS = "data/VPT-models/foundation-model-1x.weights"

# MineRL BASALT task names, in the order they are trained. Each name is
# substituted into the dataset directory and the output weights path.
TASKS = (
    "FindCave",
    "MakeWaterfall",
    "CreateVillageAnimalPen",
    "BuildVillageHouse",
)


def training_kwargs(task):
    """Return the keyword arguments passed to ``behavioural_cloning_train`` for *task*.

    Args:
        task: BASALT task name, e.g. ``"FindCave"``.

    Returns:
        A dict with ``data_dir`` (``data/MineRLBasalt<task>-v0``), the shared
        foundation ``in_model``/``in_weights`` paths, and ``out_weights``
        (``train/MineRLBasalt<task>.weights``).
    """
    return {
        "data_dir": f"data/MineRLBasalt{task}-v0",
        "in_model": FOUNDATION_MODEL,
        "in_weights": FOUNDATION_WEIGHTS,
        "out_weights": f"train/MineRLBasalt{task}.weights",
    }


def main(tasks=TASKS):
    """Train one behavioural-cloning model per task, one after another.

    Args:
        tasks: Iterable of BASALT task names to train; defaults to all of
            ``TASKS`` in their listed order.
    """
    for task in tasks:
        print(f"===Training {task} model===")
        behavioural_cloning_train(**training_kwargs(task))


if __name__ == "__main__":
    main()