From 3cf9295515ce477ed9547d59bed053df30e24ccb Mon Sep 17 00:00:00 2001 From: Sharan Thakur Date: Thu, 25 Jul 2024 16:24:34 +0530 Subject: [PATCH 1/2] fixes sd3-version argument needed for torch2coreml --- .../mixed_bit_compression_pre_analysis.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python_coreml_stable_diffusion/mixed_bit_compression_pre_analysis.py b/python_coreml_stable_diffusion/mixed_bit_compression_pre_analysis.py index 505895c..e1b6e07 100644 --- a/python_coreml_stable_diffusion/mixed_bit_compression_pre_analysis.py +++ b/python_coreml_stable_diffusion/mixed_bit_compression_pre_analysis.py @@ -578,6 +578,11 @@ def main(args): "If specified, the specified VAE will be converted instead of the one associated to the `--model-version` checkpoint. " "No precision override is applied when using a custom VAE." )) + # needed since this calls `torch2coreml` and that would throw an error + parser.add_argument( + "--sd3-version", + action="store_true", + help=("If specified, the pre-trained model will be treated as an SD3 model.")) args = parser.parse_args() main(args) From a31c2af578ef9b16240d33dcd30e6048da230a50 Mon Sep 17 00:00:00 2001 From: Sharan Thakur Date: Thu, 25 Jul 2024 17:16:07 +0530 Subject: [PATCH 2/2] Fix for sd3-version flag --- python_coreml_stable_diffusion/mixed_bit_compression_apply.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python_coreml_stable_diffusion/mixed_bit_compression_apply.py b/python_coreml_stable_diffusion/mixed_bit_compression_apply.py index b65f07c..8e2f8b1 100644 --- a/python_coreml_stable_diffusion/mixed_bit_compression_apply.py +++ b/python_coreml_stable_diffusion/mixed_bit_compression_apply.py @@ -120,6 +120,10 @@ def get_tensor_hash(tensor): "If specified, the specified VAE will be converted instead of the one associated to the `--model-version` checkpoint. " "No precision override is applied when using a custom VAE." )) + parser.add_argument( + "--sd3-version", + action="store_true", + help=("If specified, the pre-trained model will be treated as an SD3 model.")) args = parser.parse_args()