diff --git a/wan/modules/vae.py b/wan/modules/vae.py index 3c5f345..ed6a4ab 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -33,9 +33,27 @@ class CausalConv3d(nn.Conv3d): padding[4] -= cache_x.shape[2] cache_x = None x = F.pad(x, padding) - x = super().forward(x) - - return x + try: + out = super().forward(x) + return out + except RuntimeError as e: + if "miopenStatus" in str(e): + print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be " + "used for this decoding (which is very slow). Consider using tiled VAE Decoding.") + x_cpu = x.float().cpu() + weight_cpu = self.weight.float().cpu() + bias_cpu = self.bias.float().cpu() if self.bias is not None else None + print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}") + out = F.conv3d(x_cpu, weight_cpu, bias_cpu, + self.stride, (0, 0, 0), # avoid double padding here + self.dilation, self.groups) + out = out.to(x.device) + if x.dtype in (torch.float16, torch.bfloat16): + out = out.half() + if x.dtype != out.dtype: + out = out.to(x.dtype) + return out + raise class RMS_norm(nn.Module):