Skip to content

Commit

Permalink
Feat: Flux support (#450)
Browse files Browse the repository at this point in the history
* feat: support for flux

* feat: reduce batching on flux

* normalize baseline names

* chore: linting

* fix: license

* tests: more images to fix the SDK
  • Loading branch information
db0 authored Sep 11, 2024
1 parent 4c0b12e commit bdb099e
Show file tree
Hide file tree
Showing 13 changed files with 44 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ SPDX-License-Identifier: AGPL-3.0-or-later

# Changelog

# 4.42.0

* Adds support for the Flux family of models

# 4.41.0

* Adds support for extra backends behind LLM bridges, and for knowing which are validated.
Expand Down
5 changes: 5 additions & 0 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ def validate(self):
if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch")
if any(model_reference.get_model_baseline(model_name).startswith("flux_1") for model_name in self.args.models):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Flux currently.", rc="ControlNetMismatch")
if self.params.get("hires_fix", False) is True:
raise e.BadRequest("HiRes Fix does not work with Flux currently.", rc="HiResMismatch")
if "loras" in self.params:
if len(self.params["loras"]) > 5:
raise e.BadRequest("You cannot request more than 5 loras per generation.", rc="TooManyLoras")
Expand Down
3 changes: 3 additions & 0 deletions horde/classes/stable/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def get_gen_kudos(self):
if self.wp.params.get("hires_fix", False):
return self.wp.kudos * 7
return self.wp.kudos * 4
if model_reference.get_model_baseline(self.model) in ["flux_1"]:
# Flux is double the size of SDXL and much slower, so it gives double the rewards from it.
return self.wp.kudos * 8
return self.wp.kudos

def log_aborted_generation(self):
Expand Down
24 changes: 15 additions & 9 deletions horde/classes/stable/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from horde.classes.base.waiting_prompt import WaitingPrompt
from horde.classes.stable.kudos import KudosModel
from horde.consts import (
BASELINE_BATCHING_MULTIPLIERS,
HEAVY_POST_PROCESSORS,
KNOWN_LCM_LORA_IDS,
KNOWN_LCM_LORA_VERSIONS,
Expand Down Expand Up @@ -364,18 +365,16 @@ def require_upfront_kudos(self, counted_totals, total_threads):
max_res = 768
# We allow everyone to use SDXL up to 1024
if max_res < 1024 and any(
model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade"] for mn in model_names
model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade", "flux_1"] for mn in model_names
):
max_res = 1024
if max_res > 1024:
max_res = 1024
# Using more than 10 steps with LCM requires upfront kudos
if self.is_using_lcm() and self.get_accurate_steps() > 10:
return (True, max_res, False)
# Stable Cascade doesn't need so many steps, so we limit it a bit to prevent abuse.
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in model_names) and self.get_accurate_steps() > 30:
return (True, max_res, False)
if self.get_accurate_steps() > 50:
# Some models don't require a lot of steps, so we check their requirements. The max steps we allow without upfront kudos is 40
if any(model_reference.get_model_requirements(mn).get("max_steps", 40) > self.get_accurate_steps() for mn in model_names):
return (True, max_res, False)
if self.width * self.height > max_res * max_res:
return (True, max_res, False)
Expand All @@ -400,9 +399,7 @@ def downgrade(self, max_resolution):
# Break, just in case we went too low
if self.width * self.height < 512 * 512:
break
max_steps = 50
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in self.get_model_names()):
max_steps = 30
max_steps = min(model_reference.get_model_requirements(mn).get("max_steps", 30) for mn in self.get_model_names())
if self.params.get("control_type"):
max_steps = 20
if self.is_using_lcm():
Expand Down Expand Up @@ -435,7 +432,7 @@ def get_accurate_steps(self):
if self.params.get("sampler_name", "k_euler_a") in ["k_dpm_adaptive"]:
# This sampler chooses the steps amount automatically
# and disregards the steps value from the user
# so we just calculate it as an average 50 steps
# so we just calculate it as an average 40 steps
return 40
steps = self.params["steps"]
if self.params.get("sampler_name", "k_euler_a") in SECOND_ORDER_SAMPLERS:
Expand Down Expand Up @@ -499,6 +496,8 @@ def extrapolate_dry_run_kudos(self):
return (self.calculate_extra_kudos_burn(kudos) * self.n * 2) + 1
if model_reference.get_model_baseline(model_name) in ["stable_cascade"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 4) + 1
if model_reference.get_model_baseline(model_name) in ["flux_1"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 8) + 1
# The +1 is the extra kudos burn per request
return (self.calculate_extra_kudos_burn(kudos) * self.n) + 1

Expand All @@ -513,5 +512,12 @@ def has_heavy_operations(self):
return True
return False

def get_highest_model_batching_multiplier(self):
highest_multiplier = 1
for mn in self.get_model_names():
if BASELINE_BATCHING_MULTIPLIERS.get(mn, 1) > highest_multiplier:
highest_multiplier = BASELINE_BATCHING_MULTIPLIERS.get(mn, 1)
return highest_multiplier

def count_pp(self):
return len(self.params.get("post_processing", []))
1 change: 1 addition & 0 deletions horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def get_safe_amount(self, amount, wp):
if wp.has_heavy_operations():
pp_multiplier *= 1.8
mps *= pp_multiplier
mps *= wp.get_highest_model_batching_multiplier()
safe_amount = round(safe_generations / mps)
if safe_amount > amount:
safe_amount = amount
Expand Down
9 changes: 8 additions & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: AGPL-3.0-or-later

HORDE_VERSION = "4.41.0 "
HORDE_VERSION = "4.42.0 "

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down Expand Up @@ -38,6 +38,13 @@
"4x_AnimeSharp" "CodeFormers",
}

# These models are very large in VRAM, so we increase the calculated MPS
# used to figure out batches by a set multiplier to reduce how many images are batched
# at a time when these models are used.
BASELINE_BATCHING_MULTIPLIERS = {
"flux_1": 2,
}


KNOWN_SAMPLERS = {
"k_lms",
Expand Down
1 change: 1 addition & 0 deletions horde/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
"InvalidTransparencyModel",
"InvalidTransparencyImg2Img",
"InvalidTransparencyCN",
"HiResMismatch",
]


Expand Down
1 change: 1 addition & 0 deletions horde/model_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def call_function(self):
"stable diffusion 2 512",
"stable_diffusion_xl",
"stable_cascade",
"flux_1",
}:
self.stable_diffusion_names.add(model)
if self.reference[model].get("nsfw"):
Expand Down
File renamed without changes.
Binary file added img_stable/2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions img_stable/2.jpg.license
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis <[email protected]>

SPDX-License-Identifier: CC0-1.0
Binary file added img_stable/3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions img_stable/3.jpg.license
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis <[email protected]>

SPDX-License-Identifier: CC0-1.0

0 comments on commit bdb099e

Please sign in to comment.