sd3 support
IlyasMoutawwakil committed Oct 21, 2024
1 parent 8af46e5 commit 07d9952
Showing 10 changed files with 385 additions and 81 deletions.
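For context, a minimal sketch of how these changes would be exercised end to end, assuming the standard main_export entry point; the model id and output directory below are placeholders, not taken from this commit:

# Minimal sketch; model id, output path, and explicit task are assumptions.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers",
    output="sd3_onnx",
    task="text-to-image",
)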
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
@@ -319,6 +319,7 @@ def fix_dynamic_axes(
input_shapes = {}
dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes)
dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names)
dummy_inputs = self.rename_ambiguous_inputs(dummy_inputs)

onnx_inputs = {}
for name, value in dummy_inputs.items():
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/convert.py
@@ -1183,6 +1183,10 @@ def onnx_export_from_model(
if tokenizer_2 is not None:
tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

tokenizer_3 = getattr(model, "tokenizer_3", None)
if tokenizer_3 is not None:
tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))

model.save_config(output)

if float_dtype == "bf16":
81 changes: 71 additions & 10 deletions optimum/exporters/onnx/model_configs.py
@@ -1015,22 +1015,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
"pooler_output": {0: "batch_size"},
}

if self._normalized_config.output_hidden_states:
for i in range(self._normalized_config.num_layers + 1):
common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

return common_outputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

# TODO: fix should be by casting inputs during inference and not export
if framework == "pt":
import torch

dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
return dummy_inputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
@@ -1160,6 +1151,76 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}


class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("pooled_projections",)

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
**kwargs,
):
self.task = task
self.batch_size = batch_size
self.pooled_projection_dim = normalized_config.config.pooled_projection_dim

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
return self.random_float_tensor(
[self.batch_size, self.pooled_projection_dim], framework=framework, dtype=float_dtype
)
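A rough usage sketch of this generator (not part of the diff: `pipeline` is assumed to be a loaded StableDiffusion3Pipeline, and the 2048 dimension is an assumption about SD3-medium's config):

# Rough sketch; `pipeline` is assumed to be a loaded StableDiffusion3Pipeline.
normalized_config = NormalizedConfig(pipeline.transformer.config)
generator = PooledProjectionsDummyInputGenerator(
    task="semantic-segmentation",
    normalized_config=normalized_config,
    batch_size=2,
)
pooled = generator.generate("pooled_projections", framework="pt")
# pooled_projection_dim is 2048 for SD3-medium, so `pooled` would have shape (2, 2048).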


class DummyTransformerTimestepsInputGenerator(DummyTimestepInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "timestep":
shape = [self.batch_size]
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
return super().generate(input_name, framework, int_dtype, float_dtype)


class SD3TransformerOnnxConfig(UNetOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
(DummyTransformerTimestepsInputGenerator,)
+ UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+ (PooledProjectionsDummyInputGenerator,)
)
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
image_size="sample_size",
num_channels="in_channels",
hidden_size="joint_attention_dim",
vocab_size="attention_head_dim",
allow_new=True,
)

@property
def inputs(self):
common_inputs = super().inputs
common_inputs["pooled_projections"] = {0: "batch_size"}
return common_inputs

def rename_ambiguous_inputs(self, inputs):
# The input name in the model's forward signature is `hidden_states`, hence the export input name is updated.
hidden_states = inputs.pop("sample", None)
if hidden_states is not None:
inputs["hidden_states"] = hidden_states
return inputs
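For illustration, a sketch of the renaming (the config instance and the stand-in values below are assumptions, not part of the diff):

# Sketch; `pipeline` is assumed to be a loaded StableDiffusion3Pipeline.
onnx_config = SD3TransformerOnnxConfig(pipeline.transformer.config, task="semantic-segmentation")
dummy_inputs = {"sample": "latent_tensor", "timestep": "timestep_tensor"}  # stand-in values
print(onnx_config.rename_ambiguous_inputs(dummy_inputs))
# {'timestep': 'timestep_tensor', 'hidden_states': 'latent_tensor'}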


class T5EncoderOnnxConfig(CLIPTextOnnxConfig):
@property
def inputs(self):
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
}

@property
def outputs(self):
return {
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
}
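A quick sketch of the axes this config declares for the third text encoder (the instantiation is hypothetical, assuming `pipeline.text_encoder_3` holds the T5 encoder):

# Hypothetical instantiation, not part of the diff.
t5_onnx_config = T5EncoderOnnxConfig(pipeline.text_encoder_3.config, task="feature-extraction")
print(t5_onnx_config.inputs)   # {'input_ids': {0: 'batch_size', 1: 'sequence_length'}}
print(t5_onnx_config.outputs)  # {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}}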


class GroupViTOnnxConfig(CLIPOnnxConfig):
pass

8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
@@ -335,6 +335,10 @@ class TasksManager:
}

_DIFFUSERS_SUPPORTED_MODEL_TYPE = {
"t5-encoder": supported_tasks_mapping(
"feature-extraction",
onnx="T5EncoderOnnxConfig",
),
"clip-text-model": supported_tasks_mapping(
"feature-extraction",
onnx="CLIPTextOnnxConfig",
@@ -347,6 +351,10 @@
"semantic-segmentation",
onnx="UNetOnnxConfig",
),
"sd3-transformer": supported_tasks_mapping(
"semantic-segmentation",
onnx="SD3TransformerOnnxConfig",
),
"vae-encoder": supported_tasks_mapping(
"semantic-segmentation",
onnx="VaeEncoderOnnxConfig",
161 changes: 111 additions & 50 deletions optimum/exporters/utils.py
@@ -50,6 +50,14 @@
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)

if check_if_diffusers_greater("0.30.0"):
from diffusers import (
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3Pipeline,
)

from diffusers.models.attention_processor import (
Attention,
AttnAddedKVProcessor,
@@ -87,56 +95,95 @@ def _get_submodels_for_export_diffusion(
Returns the components of a diffusion pipeline (Stable Diffusion, SDXL, or SD3).
"""

models_for_export = {}

is_stable_diffusion_xl = isinstance(
pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline)
)
if is_stable_diffusion_xl:
projection_dim = pipeline.text_encoder_2.config.projection_dim
else:
projection_dim = pipeline.text_encoder.config.projection_dim
is_stable_diffusion_3 = isinstance(
pipeline, (StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline)
)

models_for_export = {}
is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")

# Text encoder
text_encoder = getattr(pipeline, "text_encoder", None)
if text_encoder is not None:
if is_stable_diffusion_xl:
if is_stable_diffusion_xl or is_stable_diffusion_3:
text_encoder.config.output_hidden_states = True
text_encoder.text_model.config.output_hidden_states = True

if is_stable_diffusion_3:
text_encoder.config.export_model_type = "clip-text-with-projection"
else:
text_encoder.config.export_model_type = "clip-text-model"

models_for_export["text_encoder"] = text_encoder

# U-NET
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")
if not is_torch_greater_or_equal_than_2_1:
pipeline.unet.set_attn_processor(AttnProcessor())
# Text encoder 2
text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
text_encoder_2.text_model.config.output_hidden_states = True
text_encoder_2.config.export_model_type = "clip-text-with-projection"

pipeline.unet.config.text_encoder_projection_dim = projection_dim
# The U-NET `time_ids` input shape depends on the value of `requires_aesthetics_score`
# https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
models_for_export["unet"] = pipeline.unet
models_for_export["text_encoder_2"] = text_encoder_2

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
# Text encoder 3
text_encoder_3 = getattr(pipeline, "text_encoder_3", None)
if text_encoder_3 is not None:
text_encoder_3.config.export_model_type = "t5-encoder"
models_for_export["text_encoder_3"] = text_encoder_3

# U-NET
unet = getattr(pipeline, "unet", None)
if unet is not None:
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
unet.set_attn_processor(AttnProcessor())

# The U-NET `time_ids` input shape depends on the value of `requires_aesthetics_score`
# https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
unet.config.time_cond_proj_dim = getattr(pipeline.unet.config, "time_cond_proj_dim", None)
unet.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
unet.config.export_model_type = "unet"
models_for_export["unet"] = unet

# Transformer
transformer = getattr(pipeline, "transformer", None)
if transformer is not None:
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
transformer.set_attn_processor(AttnProcessor())

transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
transformer.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None)
transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
transformer.config.export_model_type = "sd3-transformer"
models_for_export["transformer"] = transformer

# VAE Encoder
vae_encoder = copy.deepcopy(pipeline.vae)

# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder)

# we return the distribution parameters to be able to recreate it in the decoder
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
models_for_export["vae_encoder"] = vae_encoder

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
# VAE Decoder
vae_decoder = copy.deepcopy(pipeline.vae)

# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder)

vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
models_for_export["vae_decoder"] = vae_decoder

text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
text_encoder_2.text_model.config.output_hidden_states = True
models_for_export["text_encoder_2"] = text_encoder_2

return models_for_export
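As a rough illustration of the result for SD3 (a sketch, assuming the function takes just the pipeline as shown above and that the SD3 weights are accessible; the model id is only an example):

from diffusers import StableDiffusion3Pipeline

# Sketch only; requires access to the SD3 weights.
pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
submodels = _get_submodels_for_export_diffusion(pipeline)
print(sorted(submodels))
# ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'transformer',
#  'vae_decoder', 'vae_encoder']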


@@ -294,31 +341,58 @@ def get_diffusion_models_for_export(
`Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `ExportConfig`]: A Dict containing the model and
export configs for the different components of the model.
"""

models_for_export = _get_submodels_for_export_diffusion(pipeline)

# Text encoder
if "text_encoder" in models_for_export:
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model=pipeline.text_encoder, exporter=exporter, library_name="diffusers", task="feature-extraction"
)
text_encoder_export_config = text_encoder_config_constructor(
pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_export_config)

# Text encoder 2
if "text_encoder_2" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2, exporter=exporter, library_name="diffusers", task="feature-extraction"
)
export_config = export_config_constructor(
pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config)

# Text encoder 3
if "text_encoder_3" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_3, exporter=exporter, library_name="diffusers", task="feature-extraction"
)
export_config = export_config_constructor(
pipeline.text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder_3"] = (models_for_export["text_encoder_3"], export_config)

# U-NET
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="unet",
)
unet_export_config = export_config_constructor(pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype)
models_for_export["unet"] = (models_for_export["unet"], unet_export_config)
if "unet" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet, exporter=exporter, library_name="diffusers", task="semantic-segmentation"
)
unet_export_config = export_config_constructor(
pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["unet"] = (models_for_export["unet"], unet_export_config)

# Transformer
if "transformer" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.transformer, exporter=exporter, library_name="diffusers", task="semantic-segmentation"
)
transformer_export_config = export_config_constructor(
pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["transformer"] = (models_for_export["transformer"], transformer_export_config)

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
vae_encoder = models_for_export["vae_encoder"]
@@ -344,19 +418,6 @@
vae_export_config = vae_config_constructor(vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype)
models_for_export["vae_decoder"] = (vae_decoder, vae_export_config)

if "text_encoder_2" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model_type="clip-text-with-projection",
)
export_config = export_config_constructor(
pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config)

return models_for_export

