diff --git a/wan/modules/model.py b/wan/modules/model.py
index 5eba92b..d006f9e 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -447,6 +447,21 @@ class WanAttentionBlock(nn.Module):
             grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
+        hint = None
+        if self.block_id is not None and hints is not None:
+            kwargs = {
+                "seq_lens" : seq_lens,
+                "grid_sizes" : grid_sizes,
+                "freqs" :freqs,
+                "context" : context,
+                "context_lens" : context_lens,
+                "e" : e,
+            }
+            if self.block_id == 0:
+                hint = self.vace(hints, x, **kwargs)
+            else:
+                hint = self.vace(hints, None, **kwargs)
+
         e = (self.modulation + e).chunk(6, dim=1)
 
         # self-attention
@@ -485,13 +500,16 @@ class WanAttentionBlock(nn.Module):
 
         x.addcmul_(y, e[5])
-
-        if self.block_id is not None and hints != None:
+
+
+        if hint is not None:
             if context_scale == 1:
-                x.add_(hints[self.block_id])
+                x.add_(hint)
             else:
-                x.add_(hints[self.block_id], alpha =context_scale)
-        return x
+                x.add_(hint, alpha= context_scale)
+        return x
+
+
 
 class VaceWanAttentionBlock(WanAttentionBlock):
     def __init__(
@@ -516,18 +534,29 @@ class VaceWanAttentionBlock(WanAttentionBlock):
         nn.init.zeros_(self.after_proj.weight)
         nn.init.zeros_(self.after_proj.bias)
 
-    def forward(self, c, x, **kwargs):
+    def forward(self, hints, x, **kwargs):
         # behold dbm magic !
+        c = hints[0]
+        hints[0] = None
         if self.block_id == 0:
             c = self.before_proj(c) + x
-            all_c = []
-        else:
-            all_c = c
-            c = all_c.pop(-1)
         c = super().forward(c, **kwargs)
         c_skip = self.after_proj(c)
-        all_c += [c_skip, c]
-        return all_c
+        hints[0] = c
+        return c_skip
+
+    # def forward(self, c, x, **kwargs):
+    #     # behold dbm magic !
+    #     if self.block_id == 0:
+    #         c = self.before_proj(c) + x
+    #         all_c = []
+    #     else:
+    #         all_c = c
+    #         c = all_c.pop(-1)
+    #     c = super().forward(c, **kwargs)
+    #     c_skip = self.after_proj(c)
+    #     all_c += [c_skip, c]
+    #     return all_c
 
 
 class Head(nn.Module):
@@ -764,35 +793,37 @@ class WanModel(ModelMixin, ConfigMixin):
 
         print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
         return best_threshold
 
-    def forward_vace(
-        self,
-        x,
-        vace_context,
-        seq_len,
-        context,
-        e,
-        kwargs
-    ):
-        # embeddings
-        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
-        c = [u.flatten(2).transpose(1, 2) for u in c]
-        if (len(c) == 1 and seq_len == c[0].size(1)):
-            c = c[0]
-        else:
-            c = torch.cat([
-                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                dim=1) for u in c
-            ])
-        # arguments
-        new_kwargs = dict(x=x)
-        new_kwargs.update(kwargs)
-        for block in self.vace_blocks:
-            c = block(c, context= context, e= e, **new_kwargs)
-        hints = c[:-1]
+    # def forward_vace(
+    #     self,
+    #     x,
+    #     vace_context,
+    #     seq_len,
+    #     context,
+    #     e,
+    #     kwargs
+    # ):
+    #     # embeddings
+    #     c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+    #     c = [u.flatten(2).transpose(1, 2) for u in c]
+    #     if (len(c) == 1 and seq_len == c[0].size(1)):
+    #         c = c[0]
+    #     else:
+    #         c = torch.cat([
+    #             torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+    #             dim=1) for u in c
+    #         ])
 
-        return hints
+    #     # arguments
+    #     new_kwargs = dict(x=x)
+    #     new_kwargs.update(kwargs)
+
+    #     for block in self.vace_blocks:
+    #         c = block(c, context= context, e= e, **new_kwargs)
+    #     hints = c[:-1]
+
+    #     return hints
 
 
     def forward(
         self,
@@ -904,6 +935,34 @@ class WanModel(ModelMixin, ConfigMixin):
             x_list = [x]
             context_list = [context]
         del x
+
+        # arguments
+
+        kwargs = dict(
+            seq_lens=seq_lens,
+            grid_sizes=grid_sizes,
+            freqs=freqs,
+            context_lens=context_lens,
+        )
+
+        if vace_context == None:
+            hints_list = [None ] *len(x_list)
+        else:
+            # embeddings
+            c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+            c = [u.flatten(2).transpose(1, 2) for u in c]
+            if (len(c) == 1 and seq_len == c[0].size(1)):
+                c = c[0]
+            else:
+                c = torch.cat([
+                    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                    dim=1) for u in c
+                ])
+
+            kwargs['context_scale'] = vace_context_scale
+            hints_list = [ [c] if i==0 else [c.clone()] for i in range(len(x_list)) ]
+            del c
+
         should_calc = True
         if self.enable_teacache:
             if is_uncond:
@@ -935,23 +994,6 @@ class WanModel(ModelMixin, ConfigMixin):
                 if joint_pass or not is_uncond:
                     self.previous_residual_cond = None
                 ori_hidden_states = x_list[0].clone()
 
-        # arguments
-
-        kwargs = dict(
-            seq_lens=seq_lens,
-            grid_sizes=grid_sizes,
-            freqs=freqs,
-            context_lens=context_lens)
-
-        if vace_context == None:
-            hints_list = [None ] *len(x_list)
-        else:
-            hints_list = []
-            for x, context in zip(x_list, context_list) :
-                hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
-            del x, context
-            kwargs['context_scale'] = vace_context_scale
-
         for block_idx, block in enumerate(self.blocks):
             offload.shared_state["layer"] = block_idx
diff --git a/wan/text2video.py b/wan/text2video.py
index f86284e..d77414d 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -143,6 +143,8 @@ class WanT2V:
                             seq_len=32760,
                             keep_last=True)
 
+        self.adapt_vace_model()
+
     def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
         if ref_images is None:
             ref_images = [None] * len(frames)
@@ -505,3 +507,14 @@ class WanT2V:
             dist.barrier()
 
         return videos[0] if self.rank == 0 else None
+
+    def adapt_vace_model(self):
+        model = self.model
+        modules_dict= { k: m for k, m in model.named_modules()}
+        for num in range(15):
+            module = modules_dict[f"vace_blocks.{num}"]
+            target = modules_dict[f"blocks.{2*num}"]
+            setattr(target, "vace", module )
+        delattr(model, "vace_blocks")
+
+        
\ No newline at end of file
diff --git a/wgp.py b/wgp.py
index 53e6b5c..5a25676 100644
--- a/wgp.py
+++ b/wgp.py
@@ -910,14 +910,6 @@ def get_queue_table(queue):
     if len(queue) == 1:
         return data
 
-    # def td(l, content, width =None):
-    #     if width !=None:
-    #         l.append("
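
Note on the hint-passing change in wan/modules/model.py above: each VaceWanAttentionBlock now receives a one-element list (hints), reads the running VACE state out of slot 0, writes the updated state back into that same slot, and returns only the skip tensor that the host WanAttentionBlock adds to its hidden states (scaled by context_scale). The standalone sketch below reproduces that control flow with toy Linear layers so it can be run in isolation; the ToyVaceBlock/ToyMainBlock names, bodies, and dimensions are illustrative stand-ins, not the real attention blocks.

import torch
import torch.nn as nn

class ToyVaceBlock(nn.Module):
    """Reads the shared carrier, updates it in place, returns only a skip."""
    def __init__(self, dim, block_id):
        super().__init__()
        self.block_id = block_id
        self.body = nn.Linear(dim, dim)
        self.after_proj = nn.Linear(dim, dim)

    def forward(self, hints, x):
        c = hints[0]            # pull the running VACE state out of the carrier
        hints[0] = None         # release the reference while this block works
        if self.block_id == 0:  # the first VACE block folds in the main stream
            c = c + x
        c = self.body(c)
        hints[0] = c            # stash the updated state for the next VACE block
        return self.after_proj(c)

class ToyMainBlock(nn.Module):
    """Host block: computes its own update, then adds the scaled hint."""
    def __init__(self, dim, vace=None, block_id=None):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.vace, self.block_id = vace, block_id

    def forward(self, x, hints=None, context_scale=1.0):
        hint = None
        if self.vace is not None and hints is not None:
            hint = self.vace(hints, x if self.block_id == 0 else None)
        x = self.body(x)
        if hint is not None:
            x = x + context_scale * hint
        return x

# every second main block hosts a VACE block, mirroring blocks.{2*n} <- vace_blocks.{n}
dim = 8
blocks = [ToyMainBlock(dim, ToyVaceBlock(dim, i // 2), i) if i % 2 == 0 else ToyMainBlock(dim)
          for i in range(4)]
x, hints = torch.randn(1, 4, dim), [torch.randn(1, 4, dim)]
for blk in blocks:
    x = blk(x, hints, context_scale=0.8)
print(x.shape)  # torch.Size([1, 4, 8])

Compared with the removed forward_vace path, which ran all vace_blocks up front and kept every per-block skip in memory, the new flow interleaves each VACE block with its host block, so only the current VACE state and one skip tensor are alive at a time.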
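For completeness, a small self-contained illustration of the adapt_vace_model remapping added in wan/text2video.py: every vace_blocks.{n} submodule is looked up by name, attached to blocks.{2n} as a .vace attribute (which is what the self.vace(...) calls in WanAttentionBlock.forward rely on), and the standalone vace_blocks container is then deleted. ToyModel and its Linear layers are placeholders for the real WanModel blocks, not part of the patch.

import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, n_blocks=30, n_vace=15, dim=8):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_blocks)])
        self.vace_blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_vace)])

def adapt_vace_model(model, n_vace=15):
    # same remapping as the patch: vace_blocks.{n} becomes blocks.{2n}.vace
    modules_dict = {k: m for k, m in model.named_modules()}
    for num in range(n_vace):
        module = modules_dict[f"vace_blocks.{num}"]
        target = modules_dict[f"blocks.{2 * num}"]
        setattr(target, "vace", module)   # register the VACE block under its host block
    delattr(model, "vace_blocks")         # drop the now-redundant container

model = ToyModel()
adapt_vace_model(model)
assert hasattr(model.blocks[0], "vace")
assert not hasattr(model, "vace_blocks")

After the remap the VACE weights live inside the same module tree as their host blocks, which presumably lets the existing per-block handling (e.g. the offload.shared_state bookkeeping seen in forward) treat them uniformly instead of special-casing vace_blocks.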