From b5676254f871f790d661fc336c6dbd46f19fe632 Mon Sep 17 00:00:00 2001
From: deepbeepmeep
Date: Mon, 21 Jul 2025 13:54:18 +0200
Subject: [PATCH] added lora unet support for flux

---
 flux/model.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/flux/model.py b/flux/model.py
index 4ad10e9..aa1d0b1 100644
--- a/flux/model.py
+++ b/flux/model.py
@@ -84,14 +84,34 @@ class Flux(nn.Module):
     def preprocess_loras(self, model_type, sd):
         new_sd = {}
         if len(sd) == 0: return sd
-        
+
         def swap_scale_shift(weight):
             shift, scale = weight.chunk(2, dim=0)
             new_weight = torch.cat([scale, shift], dim=0)
             return new_weight
 
         first_key= next(iter(sd))
-        if first_key.startswith("transformer."):
+        if first_key.startswith("lora_unet_"):
+            new_sd = {}
+            print("Converting Lora Safetensors format to Lora Diffusers format")
+            repl_list = ["linear1", "linear2", "modulation_lin"]
+            src_list = ["_" + k + "." for k in repl_list]
+            tgt_list = ["." + k.replace("_", ".") + "." for k in repl_list]
+
+            for k,v in sd.items():
+                k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
+                k = k.replace("lora_unet__blocks_","diffusion_model.blocks.")
+                k = k.replace("lora_unet_single_blocks_","diffusion_model.single_blocks.")
+
+                for s,t in zip(src_list, tgt_list):
+                    k = k.replace(s,t)
+
+                k = k.replace("lora_up","lora_B")
+                k = k.replace("lora_down","lora_A")
+
+                new_sd[k] = v
+
+        elif first_key.startswith("transformer."):
             root_src = ["time_text_embed.timestep_embedder.linear_1", "time_text_embed.timestep_embedder.linear_2", "time_text_embed.text_embedder.linear_1",
                         "time_text_embed.text_embedder.linear_2", "time_text_embed.guidance_embedder.linear_1", "time_text_embed.guidance_embedder.linear_2",
                         "x_embedder", "context_embedder", "proj_out" ]
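
Illustration (not part of the commit): a minimal standalone sketch of the key
renaming performed by the new "lora_unet_" branch, traced on one hypothetical
Kohya/sd-scripts-style key. The sample key name is made up for demonstration;
real checkpoints may contain block or layer names this branch does not cover.

    # Hypothetical Kohya-style LoRA key, as it might appear in a .safetensors
    # state dict (the key name itself is an assumption, not from the patch).
    k = "lora_unet_single_blocks_0_linear1.lora_down.weight"

    # 1) Map the block prefix to the target naming scheme.
    k = k.replace("lora_unet_single_blocks_", "diffusion_model.single_blocks.")

    # 2) Turn "_linear1." / "_linear2." / "_modulation_lin." into dotted form,
    #    mirroring the src_list/tgt_list construction in the patch.
    repl_list = ["linear1", "linear2", "modulation_lin"]
    src_list = ["_" + name + "." for name in repl_list]
    tgt_list = ["." + name.replace("_", ".") + "." for name in repl_list]
    for s, t in zip(src_list, tgt_list):
        k = k.replace(s, t)

    # 3) Rename the LoRA matrices: Kohya's up/down become B/A.
    k = k.replace("lora_up", "lora_B")
    k = k.replace("lora_down", "lora_A")

    print(k)  # -> diffusion_model.single_blocks.0.linear1.lora_A.weight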