diff --git a/apps/pretrained_compound/ChemRL/GEM-2/README.md b/apps/pretrained_compound/ChemRL/GEM-2/README.md index fa6bdfca..effb9e3d 100644 --- a/apps/pretrained_compound/ChemRL/GEM-2/README.md +++ b/apps/pretrained_compound/ChemRL/GEM-2/README.md @@ -95,6 +95,16 @@ To reproduce the result from the ogb leaderboard, just run the inference command sh scripts/inference.sh +## Run inference +To reproduce the result from the ogb leaderboard, you can download the checkponit from [here](https://baidu-nlp.bj.bcebos.com/PaddleHelix/models/molecular_modeling/gem2_l12_c256.pdparams). +Then put it under the local `./model` folder and run the inference command: + + sh scripts/inference.sh +We also provide a checkpoint with smaller embedding size(128), you can download it from [here](https://baidu-nlp.bj.bcebos.com/PaddleHelix/models/molecular_modeling/gem2_l12_c128.pdparams) . + +Then change the `encoder_config` to `opt3d_l12_c128.json`, `init_model` to `gem2_l12_c128.pdparams` in `inference.sh`. + +Now you can run the inference command with the new checkpoint. ## Citing this work diff --git a/apps/pretrained_compound/ChemRL/GEM-2/configs/model_configs/opt3d_l12_c128.json b/apps/pretrained_compound/ChemRL/GEM-2/configs/model_configs/opt3d_l12_c128.json new file mode 100644 index 00000000..aa9e659b --- /dev/null +++ b/apps/pretrained_compound/ChemRL/GEM-2/configs/model_configs/opt3d_l12_c128.json @@ -0,0 +1,5 @@ +{ + "node_channel": 128, + "pair_channel": 128, + "triple_channel": 128 +} diff --git a/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py b/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py index fb9e878f..ad82d697 100644 --- a/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py +++ b/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py @@ -274,6 +274,7 @@ def _read_json(path): ### build model model = MolRegressionModel(model_config, encoder_config) + single_model = model print("parameter size:", calc_parameter_size(model.parameters())) if args.distributed: model = paddle.DataParallel(model) @@ -308,7 +309,7 @@ def _read_json(path): ema.register() if epoch_id == 69: - model.encoder.reduce_dropout() + single_model.encoder.reduce_dropout() ## train s_time = time.time()