diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs b/modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs
new file mode 100644
index 0000000000..0b1d9923ab
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs
@@ -0,0 +1,81 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http.Headers;
+using System.Reactive;
+using System.Text;
+using System.Threading.Tasks;
+using DryIoc.ImTools;
+using Google.OrTools.LinearSolver;
+using Nncase.IR;
+using Nncase.IR.CPU;
+using Nncase.IR.Math;
+using Nncase.IR.NN;
+using Nncase.IR.Tensors;
+using Nncase.Passes.Rules.Neutral;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.Utility;
+using static Nncase.Utilities.ReplaceUtility;
+
+namespace Nncase.Passes.Rules;
+
+///
+/// Unet Merger for all.
+///
+public sealed class FusionMerger : ExprCloner
+{
+ private readonly IReadOnlyDictionary _multiVarMap;
+
+ public FusionMerger(IReadOnlyDictionary multiVarMap)
+ {
+ _multiVarMap = multiVarMap;
+ }
+
+ protected override Expr VisitCall(Call expr, Unit context)
+ {
+ if (_multiVarMap.TryGetValue(expr, out var newVar))
+ {
+ return newVar;
+ }
+
+ return base.VisitCall(expr, context);
+ }
+
+ protected override Expr VisitLeafCall(Call expr, Unit context)
+ {
+ var target = Clone(expr.Target, context);
+ var arguments = CloneArray(expr.Arguments, context);
+ if (target is Binary)
+ {
+ arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
+ }
+
+ if (target is Conv2D conv)
+ {
+ var bias = (TensorConst)arguments[2];
+ var fusedClamp = ((TensorConst)arguments[7]).Value.ToArray();
+ var newConv = IR.F.NN.Conv2D(arguments[0], arguments[1], Tensor.Zeros(bias.CheckedShape), arguments[3], arguments[4], arguments[5], conv.PadMode, arguments[6], new[] { float.NegativeInfinity, float.PositiveInfinity });
+ var newBias = IR.F.Math.Add(newConv, Tensor.FromBytes(bias.CheckedDataType, bias.Value.BytesBuffer.ToArray(), new[] { bias.CheckedShape[0].FixedValue, 1, 1 }));
+ var newClamp = IR.F.Math.Clamp(newBias, fusedClamp[0], fusedClamp[1]);
+ return newClamp;
+ }
+
+ return expr.With(target: target, arguments: arguments);
+ }
+
+ protected override Expr VisitLeafVar(Var expr, Unit context)
+ {
+ if (_multiVarMap.TryGetValue(expr, out var newVar))
+ {
+ return newVar;
+ }
+
+ throw new InvalidOperationException();
+ }
+}
+
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs
index 74ca6d7929..cd1d6040a7 100644
--- a/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs
@@ -23,51 +23,6 @@
namespace Nncase.Passes.Rules;
-///
-/// MHA Merger for all.
-///
-public sealed class MHAMerger : ExprCloner
-{
- private readonly IReadOnlyDictionary _multiVarMap;
-
- public MHAMerger(IReadOnlyDictionary multiVarMap)
- {
- _multiVarMap = multiVarMap;
- }
-
- protected override Expr VisitCall(Call expr, Unit context)
- {
- if (_multiVarMap.TryGetValue(expr, out var newVar))
- {
- return newVar;
- }
-
- return base.VisitCall(expr, context);
- }
-
- protected override Expr VisitLeafCall(Call expr, Unit context)
- {
- var target = Clone(expr.Target, context);
- var arguments = CloneArray(expr.Arguments, context);
- if (target is Binary)
- {
- arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
- }
-
- return expr.With(target: target, arguments: arguments);
- }
-
- protected override Expr VisitLeafVar(Var expr, Unit context)
- {
- if (_multiVarMap.TryGetValue(expr, out var newVar))
- {
- return newVar;
- }
-
- throw new InvalidOperationException();
- }
-}
-
// pattern from BERT base
[RuleGenerator]
public sealed partial class FuseMHA1 : FusionMaker
@@ -248,7 +203,7 @@ private static Pattern CreatePattern()
{ position_ids, (Var)newInputs[1] },
{ attn_mask, (Var)newInputs[2] },
};
- var merger = new MHAMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("MHALLaMA65B", $"{nameof(FuseMHA2)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), hidden_in, position_ids, attn_mask);
@@ -308,7 +263,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
- var merger = new MHAMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("SDTextEncoderMHA", $"{nameof(FuseMHA3)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input);
@@ -345,7 +300,7 @@ private static Pattern CreatePattern()
{
{ input, newInputs[0] },
};
- var merger = new MHAMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("SDTextEncoderHeader", $"{nameof(FuseSDTextEncoderHeader)}_{Count++}", ModuleKind, clonedRoot, newInputs.ToArray()), input);
@@ -389,7 +344,7 @@ private static Pattern CreatePattern()
{ input, newInputs[1] },
};
- var merger = new MHAMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("SDTextEncoderTail", $"{nameof(FuseSDTextEncoderTail)}_{Count++}", ModuleKind, clonedRoot, newInputs), input_ids, input);
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs
index ae382f985b..ae30765b80 100644
--- a/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs
@@ -27,61 +27,6 @@
namespace Nncase.Passes.Rules;
-///
-/// Unet Merger for all.
-///
-public sealed class UnetMerger : ExprCloner
-{
- private readonly IReadOnlyDictionary _multiVarMap;
-
- public UnetMerger(IReadOnlyDictionary multiVarMap)
- {
- _multiVarMap = multiVarMap;
- }
-
- protected override Expr VisitCall(Call expr, Unit context)
- {
- if (_multiVarMap.TryGetValue(expr, out var newVar))
- {
- return newVar;
- }
-
- return base.VisitCall(expr, context);
- }
-
- protected override Expr VisitLeafCall(Call expr, Unit context)
- {
- var target = Clone(expr.Target, context);
- var arguments = CloneArray(expr.Arguments, context);
- if (target is Binary)
- {
- arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
- }
-
- if (target is Conv2D conv)
- {
- var bias = (TensorConst)arguments[2];
- var fusedClamp = ((TensorConst)arguments[7]).Value.ToArray();
- var newConv = IR.F.NN.Conv2D(arguments[0], arguments[1], Tensor.Zeros(bias.CheckedShape), arguments[3], arguments[4], arguments[5], conv.PadMode, arguments[6], new[] { float.NegativeInfinity, float.PositiveInfinity });
- var newBias = IR.F.Math.Add(newConv, Tensor.FromBytes(bias.CheckedDataType, bias.Value.BytesBuffer.ToArray(), new[] { bias.CheckedShape[0].FixedValue, 1, 1 }));
- var newClamp = IR.F.Math.Clamp(newBias, fusedClamp[0], fusedClamp[1]);
- return newClamp;
- }
-
- return expr.With(target: target, arguments: arguments);
- }
-
- protected override Expr VisitLeafVar(Var expr, Unit context)
- {
- if (_multiVarMap.TryGetValue(expr, out var newVar))
- {
- return newVar;
- }
-
- throw new InvalidOperationException();
- }
-}
-
///
/// stable-disffusion Unet spatial transformer.
///
@@ -194,7 +139,7 @@ private static Pattern CreatePattern()
{ input, (Var)newInputs[0] },
{ encoderHiddenStates, (Var)newInputs[1] },
};
- var merger = new UnetMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("UnetSpatialTransformer", $"{nameof(FuseUnetSpatialTransformer)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input, encoderHiddenStates);
@@ -282,7 +227,7 @@ private static Pattern CreatePattern()
multiVarMap.Add(oldInputs[i], (Var)newInputs[i]);
}
- var merger = new UnetMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("UnetResBlock", $"{nameof(FuseUnetResBlock)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), oldInputs.ToArray());
@@ -336,7 +281,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
- var merger = new UnetMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("UnetTimeEmb", $"{nameof(FuseUnetTimeEmb)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input);
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs
index b43271a7b6..59e22d2587 100644
--- a/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs
@@ -24,51 +24,6 @@
namespace Nncase.Passes.Rules;
-///
-/// VAE Merger for all.
-///
-public sealed class VAEMerger : ExprCloner
-{
- private readonly IReadOnlyDictionary _multiVarMap;
-
- public VAEMerger(IReadOnlyDictionary multiVarMap)
- {
- _multiVarMap = multiVarMap;
- }
-
- protected override Expr VisitCall(Call expr, Unit context)
- {
- if (_multiVarMap.TryGetValue(expr, out var newVar))
- {
- return newVar;
- }
-
- return base.VisitCall(expr, context);
- }
-
- protected override Expr VisitLeafCall(Call expr, Unit context)
- {
- var target = Clone(expr.Target, context);
- var arguments = CloneArray(expr.Arguments, context);
- if (target is Binary)
- {
- arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
- }
-
- return expr.With(target: target, arguments: arguments);
- }
-
- protected override Expr VisitLeafVar(Var expr, Unit context)
- {
- if (_multiVarMap.TryGetValue(expr, out var newVar))
- {
- return newVar;
- }
-
- throw new InvalidOperationException();
- }
-}
-
///
/// stable-disffusion VAE Decoder res-block.
///
@@ -120,7 +75,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
- var merger = new VAEMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("VAEDecRes", $"{nameof(FuseVAEDecRes)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input);
@@ -163,7 +118,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
- var merger = new VAEMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("VAEDecHead", $"{nameof(FuseVAEDecHead)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input);
@@ -222,7 +177,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
- var merger = new VAEMerger(multiVarMap);
+ var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);
var callFusion = new Call(new Fusion("VAEDecMHA", $"{nameof(FuseVAEDecMHA)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input);