From 18bf248fb875d5145f1c6b73303e62ee602ec5ae Mon Sep 17 00:00:00 2001 From: Christopher Anderson Date: Sun, 15 Jun 2025 09:15:18 +1000 Subject: [PATCH 1/2] Catch errors from AMD's MIOpen for requests > 512 in a given dimension (hello, just using tiled vae decoding!) --- wan/modules/vae.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/wan/modules/vae.py b/wan/modules/vae.py index 3c5f345..dcf635b 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -33,9 +33,32 @@ class CausalConv3d(nn.Conv3d): padding[4] -= cache_x.shape[2] cache_x = None x = F.pad(x, padding) - x = super().forward(x) + try: + out = super().forward(x) + print("(ran fine)") + return out + except RuntimeError as e: + if "miopenStatus" in str(e): + print("⚠️ MIOpen fallback: running Conv3d on CPU") + + x_cpu = x.float().cpu() + weight_cpu = self.weight.float().cpu() + bias_cpu = self.bias.float().cpu() if self.bias is not None else None + + print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}") + out = F.conv3d(x_cpu, weight_cpu, bias_cpu, + self.stride, (0, 0, 0), # <-- FIX: no padding here + self.dilation, self.groups) + + out = out.to(x.device) + if x.dtype in (torch.float16, torch.bfloat16): + out = out.half() + if x.dtype != out.dtype: + out = out.to(x.dtype) + print("... returned (from CPU fallback)") + return out + raise - return x class RMS_norm(nn.Module): From 4ca0666aa5c0d1d66a101f2d3ed357d016d24cd6 Mon Sep 17 00:00:00 2001 From: Christopher Anderson Date: Sun, 15 Jun 2025 10:12:03 +1000 Subject: [PATCH 2/2] If an error occurs because AMD is asked to VAE Decode without tiling, warn and use CPU decoding. 
--- wan/modules/vae.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/wan/modules/vae.py b/wan/modules/vae.py index dcf635b..ed6a4ab 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -35,32 +35,27 @@ class CausalConv3d(nn.Conv3d): x = F.pad(x, padding) try: out = super().forward(x) - print("(ran fine)") return out except RuntimeError as e: if "miopenStatus" in str(e): - print("⚠️ MIOpen fallback: running Conv3d on CPU") - + print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be " + "used for this decoding (which is very slow). Consider using tiled VAE Decoding.") x_cpu = x.float().cpu() weight_cpu = self.weight.float().cpu() bias_cpu = self.bias.float().cpu() if self.bias is not None else None - print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}") out = F.conv3d(x_cpu, weight_cpu, bias_cpu, - self.stride, (0, 0, 0), # <-- FIX: no padding here + self.stride, (0, 0, 0), # avoid double padding here self.dilation, self.groups) - out = out.to(x.device) if x.dtype in (torch.float16, torch.bfloat16): out = out.half() if x.dtype != out.dtype: out = out.to(x.dtype) - print("... returned (from CPU fallback)") return out raise - class RMS_norm(nn.Module): def __init__(self, dim, channel_first=True, images=True, bias=False):