Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Euler Ancestral Discrete Scheduler #308

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions swift/StableDiffusion/pipeline/EulerAncestralDiscreteScheduler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import Accelerate
import CoreML

/// A scheduler implementing Euler Ancestral discrete sampling.
///
/// Port of Hugging Face diffusers' `EulerAncestralDiscreteScheduler`.
/// Each `step` performs one Euler update toward the predicted denoised
/// sample and then re-injects freshly drawn Gaussian noise (the "ancestral"
/// part), so results are stochastic and depend on `randomSource`.
@available(iOS 16.2, macOS 13.1, *)
public final class EulerAncestralDiscreteScheduler: Scheduler {
    /// Number of diffusion steps the model was trained with
    public let trainStepCount: Int
    /// Number of inference steps to run
    public let inferenceStepCount: Int
    /// Training beta schedule
    public let betas: [Float]
    /// Timesteps (descending) at which the model is evaluated during inference
    public let timeSteps: [Int]
    /// 1 - beta for each training step
    public let alphas: [Float]
    /// Cumulative product of `alphas`
    public let alphasCumProd: [Float]
    /// Noise level per inference step (descending), with a trailing 0 entry
    /// so `sigmas[stepIndex + 1]` is valid on the final step
    public let sigmas: [Float]
    /// Standard deviation used to scale the initial noise latents
    public let initNoiseSigma: Float
    /// Source of Gaussian noise for the ancestral noise injection
    private(set) var randomSource: RandomSource
    // NOTE(review): never populated by this scheduler; exposed for protocol
    // conformance — confirm no caller relies on it being filled.
    public private(set) var modelOutputs: [MLShapedArray<Float32>] = []

    /// Create a scheduler that uses Euler Ancestral discrete sampling.
    ///
    /// - Parameters:
    ///   - randomSource: RNG used for the per-step noise injection
    ///   - stepCount: Number of inference steps to schedule
    ///   - trainStepCount: Number of training diffusion steps
    ///   - betaStart: First beta of the training schedule
    ///   - betaEnd: Last beta of the training schedule
    ///   - betaSchedule: Interpolation method between `betaStart` and `betaEnd`
    ///   - timestepSpacing: How inference timesteps are distributed over the
    ///     training range
    ///   - stepsOffset: Offset added to each timestep for `.leading` spacing
    public init(
        randomSource: RandomSource,
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        betaStart: Float = 0.0001,
        betaEnd: Float = 0.02,
        betaSchedule: BetaSchedule = .scaledLinear,
        timestepSpacing: TimestepSpacing = .leading,
        stepsOffset: Int = 1
    ) {
        self.randomSource = randomSource
        self.trainStepCount = trainStepCount
        inferenceStepCount = stepCount

        switch betaSchedule {
        case .linear:
            betas = linspace(betaStart, betaEnd, trainStepCount)
        case .scaledLinear:
            // Linear in sqrt-space, then squared (Stable Diffusion convention)
            betas = linspace(sqrt(betaStart), sqrt(betaEnd), trainStepCount).map { $0 * $0 }
        }

        alphas = betas.map { 1.0 - $0 }
        var alphasCumProd = alphas
        for i in 1 ..< alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i - 1]
        }
        self.alphasCumProd = alphasCumProd

        // sigma_t = sqrt((1 - alphaBar_t) / alphaBar_t), one value per training step
        var sigmas = vForce.sqrt(
            vDSP.divide(
                vDSP.subtract(
                    [Float](repeating: 1, count: alphasCumProd.count),
                    alphasCumProd
                ),
                alphasCumProd
            )
        )

        var timeSteps = [Float](repeating: 0.0, count: stepCount)
        switch timestepSpacing {
        case .linspace:
            timeSteps = linspace(0, Float(trainStepCount - 1), stepCount)
        case .leading:
            let stepRatio = trainStepCount / stepCount
            timeSteps = (0 ..< stepCount).map { Float($0 * stepRatio + stepsOffset) }
        case .trailing:
            let stepRatio = trainStepCount / stepCount
            timeSteps = (1 ... stepCount).map { Float($0 * stepRatio - 1) }
        }
        timeSteps.reverse()

        // Linearly interpolate the per-training-step sigmas at the (possibly
        // fractional) inference timesteps
        var sigmasInt = [Float](repeating: 0.0, count: timeSteps.count)
        vDSP_vlint(&sigmas, &timeSteps, vDSP_Stride(1), &sigmasInt, vDSP_Stride(1), vDSP_Length(timeSteps.count), vDSP_Length(sigmas.count))
        // Trailing zero so the final step's sigmaTo lookup is in bounds
        sigmasInt.append(0.0)

        // Match the reference implementation (diffusers): for .linspace and
        // .trailing the initial latents are scaled by sigmaMax itself, while
        // the default .leading spacing uses sqrt(sigmaMax^2 + 1).
        let sigmaMax = sigmasInt.max()!
        switch timestepSpacing {
        case .linspace, .trailing:
            initNoiseSigma = sigmaMax
        case .leading:
            initNoiseSigma = sqrt(sigmaMax * sigmaMax + 1)
        }

        self.timeSteps = timeSteps.map { Int($0) }
        self.sigmas = sigmasInt
    }

    /// Perform one Euler Ancestral update.
    ///
    /// Predicts the denoised sample from the model's noise prediction, takes
    /// an Euler step toward it over `dt = sigmaDown - sigma`, then adds fresh
    /// Gaussian noise scaled by `sigmaUp`.
    ///
    /// - Parameters:
    ///   - output: Model noise prediction (epsilon) at `timeStep`
    ///   - timeStep: Current timestep; unknown values fall back to the last
    ///     scheduled step
    ///   - sample: Current noisy latents
    /// - Returns: Latents advanced one step toward the denoised image
    public func step(
        output: MLShapedArray<Float32>,
        timeStep: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        let stepIndex = timeSteps.firstIndex(of: timeStep) ?? timeSteps.count - 1
        let sigma = sigmas[stepIndex]
        // predOriginalSample = sample - sigma * epsilon
        let predOriginalSample = weightedSum([1.0, Double(-1.0 * sigma)], [sample, output])

        let sigmaFrom = sigmas[stepIndex]
        let sigmaTo = sigmas[stepIndex + 1]
        let sigmaUp = sqrt(pow(sigmaTo, 2.0) * (pow(sigmaFrom, 2.0) - pow(sigmaTo, 2.0)) / pow(sigmaFrom, 2.0))
        let sigmaDown = sqrt(pow(sigmaTo, 2.0) - pow(sigmaUp, 2.0))

        // Convert to an ODE derivative:
        let derivative = weightedSum([Double(1.0 / sigma), Double(-1.0 / sigma)], [sample, predOriginalSample])
        let dt = sigmaDown - sigma
        let prevSample = weightedSum([1.0, Double(dt)], [sample, derivative])

        // Ancestral part: re-inject unit Gaussian noise scaled by sigmaUp
        let noise = MLShapedArray<Float32>(
            converting: randomSource.normalShapedArray(
                output.shape,
                mean: 0.0,
                stdev: 1.0
            )
        )
        return weightedSum([1.0, Double(sigmaUp)], [prevSample, noise])
    }

    /// Scale the model input by `1 / sqrt(sigma^2 + 1)` so the Unet sees
    /// inputs of roughly unit variance (required by the Euler algorithm).
    ///
    /// - Parameters:
    ///   - sample: Latents to scale
    ///   - timeStep: Current timestep; unknown values fall back to the last
    ///     scheduled step
    /// - Returns: The scaled latents
    public func scaleModelInput(
        sample: MLShapedArray<Float32>,
        timeStep: Int
    ) -> MLShapedArray<Float32> {
        let stepIndex = timeSteps.firstIndex(of: timeStep) ?? timeSteps.count - 1
        let sigma = sigmas[stepIndex]
        let scale = sqrt(pow(sigma, 2.0) + 1.0)
        let scalarCount = sample.scalarCount
        return MLShapedArray(unsafeUninitializedShape: sample.shape) { scalars, _ in
            assert(scalars.count == scalarCount)
            sample.withUnsafeShapedBufferPointer { sample, _, _ in
                for i in 0 ..< scalarCount {
                    scalars.initializeElement(at: i, to: sample[i] / scale)
                }
            }
        }
    }
}
19 changes: 19 additions & 0 deletions swift/StableDiffusion/pipeline/Scheduler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public protocol Scheduler {
timeStep t: Int,
sample s: MLShapedArray<Float32>
) -> MLShapedArray<Float32>

func scaleModelInput(
sample: MLShapedArray<Float32>,
timeStep: Int
) -> MLShapedArray<Float32>
}

@available(iOS 16.2, macOS 13.1, *)
Expand Down Expand Up @@ -100,6 +105,13 @@ public extension Scheduler {

return noisySamples
}

/// Default implementation: returns the sample unchanged.
///
/// Schedulers whose algorithms require sigma-dependent input scaling
/// (e.g. Euler Ancestral) override this to normalize the Unet input.
/// - Parameters:
///   - sample: Latents about to be fed to the model
///   - timeStep: Current timestep (unused by this default)
/// - Returns: `sample`, unmodified
func scaleModelInput(
sample: MLShapedArray<Float32>,
timeStep: Int
) -> MLShapedArray<Float32> {
return sample
}
}

// MARK: - Timesteps
Expand All @@ -125,6 +137,13 @@ public enum BetaSchedule {
case scaledLinear
}

/// How inference timesteps are distributed across the training-step range.
@available(iOS 16.2, macOS 13.1, *)
public enum TimestepSpacing {
/// Evenly spaced from 0 to trainStepCount - 1 inclusive
case linspace
/// Multiples of (trainStepCount / stepCount), offset by `stepsOffset`
case leading
/// Multiples of (trainStepCount / stepCount) counted down from the end, minus 1
case trailing
}

// MARK: - PNDMScheduler

/// A scheduler used to compute a de-noised image
Expand Down
13 changes: 12 additions & 1 deletion swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public enum StableDiffusionScheduler {
case pndmScheduler
/// Scheduler that uses a second order DPM-Solver++ algorithm
case dpmSolverMultistepScheduler

case eulerAncestralDiscreteScheduler
}

/// RNG compatible with StableDiffusionPipeline
Expand Down Expand Up @@ -229,6 +231,9 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {
switch config.schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount, timeStepSpacing: config.schedulerTimestepSpacing)
case .eulerAncestralDiscreteScheduler: return EulerAncestralDiscreteScheduler(
randomSource: randomSource(from: config.rngType, seed: config.seed),
stepCount: config.stepCount)
}
}

Expand Down Expand Up @@ -258,10 +263,16 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {

// Expand the latents for classifier-free guidance
// and input to the Unet noise prediction model
let latentUnetInput = latents.map {
var latentUnetInput = latents.map {
MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
}

for i in 0..<config.imageCount {
latentUnetInput[i] = scheduler[i].scaleModelInput(
sample: latentUnetInput[i],
timeStep: t)
}

// Before Unet, execute controlNet and add the output into Unet inputs
let additionalResiduals = try controlNet?.execute(
latents: latentUnetInput,
Expand Down
12 changes: 11 additions & 1 deletion swift/StableDiffusion/pipeline/StableDiffusionXLPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ public struct StableDiffusionXLPipeline: StableDiffusionPipelineProtocol {
switch config.schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount, timeStepSpacing: config.schedulerTimestepSpacing)
case .eulerAncestralDiscreteScheduler: return EulerAncestralDiscreteScheduler(
randomSource: randomSource(from: config.rngType, seed: config.seed),
stepCount: config.stepCount)
}
}

Expand Down Expand Up @@ -208,10 +211,17 @@ public struct StableDiffusionXLPipeline: StableDiffusionPipelineProtocol {
for (step,t) in timeSteps.enumerated() {
// Expand the latents for classifier-free guidance
// and input to the Unet noise prediction model
let latentUnetInput = latents.map {
var latentUnetInput = latents.map {
MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
}

for i in 0..<config.imageCount {
latentUnetInput[i] = scheduler[i].scaleModelInput(
sample: latentUnetInput[i],
timeStep: t
)
}

// Switch to refiner if specified
if let refiner = unetRefiner, step == refinerStartStep {
unet.unloadResources()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import XCTest
import CoreML
@testable import StableDiffusion

@available(iOS 16.2, macOS 14.0, *)
final class EulerAncestralDiscreteSchedulerTests: XCTestCase {

    /// `scaleModelInput` divides every scalar by sqrt(sigma^2 + 1), so the
    /// result must keep the input's shape, shrink every positive scalar
    /// (scale factor > 1 whenever sigma > 0), and preserve the ratios
    /// between scalars.
    func testScal() throws {
        let scheduler = EulerAncestralDiscreteScheduler(randomSource: TorchRandomSource(seed: 0), stepCount: 25)
        let input: [Float32] = [0.1, 0.2, 0.3, 0.4, 0.5]
        let scaledSample = scheduler.scaleModelInput(
            sample: MLShapedArray(
                scalars: input,
                shape: [5]
            ),
            timeStep: 960)
        XCTAssertEqual(scaledSample.shape, [5])
        let scaled = scaledSample.scalars
        // Scaling is a strict shrink for positive inputs
        for (original, value) in zip(input, scaled) {
            XCTAssertLessThan(value, original)
        }
        // Uniform scaling preserves ratios between elements
        for i in 1 ..< scaled.count {
            XCTAssertEqual(scaled[i] / scaled[0], input[i] / input[0], accuracy: 1e-4)
        }
    }

    /// Golden-value regression test: one `step` with a fixed Torch RNG seed
    /// must reproduce the reference output exactly.
    func testLinspaceStep() throws {
        let scheduler = EulerAncestralDiscreteScheduler(randomSource: TorchRandomSource(seed: 0), stepCount: 25)
        let preSample = scheduler.step(
            output: MLShapedArray<Float32>(repeating: 0.5, shape: [5]),
            timeStep: 960,
            sample: MLShapedArray<Float32>(repeating: 0.5, shape: [5])
        )
        XCTAssertEqual(
            preSample,
            MLShapedArray(
                scalars: [18.519714, -15.129302, -49.71263, 0.6798735, -29.640398],
                shape: [5]
            )
        )
    }
}