From 18bf248fb875d5145f1c6b73303e62ee602ec5ae Mon Sep 17 00:00:00 2001
From: Christopher Anderson
Date: Sun, 15 Jun 2025 09:15:18 +1000
Subject: [PATCH] Catch errors from AMD's mi-open for requests > 512 in a given
 dimension (hello, just using tiled vae decoding!)

---
 wan/modules/vae.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/wan/modules/vae.py b/wan/modules/vae.py
index 3c5f345..dcf635b 100644
--- a/wan/modules/vae.py
+++ b/wan/modules/vae.py
@@ -33,9 +33,25 @@ class CausalConv3d(nn.Conv3d):
             padding[4] -= cache_x.shape[2]
             cache_x = None
         x = F.pad(x, padding)
-        x = super().forward(x)
+        try:
+            return super().forward(x)
+        except RuntimeError as e:
+            # MIOpen (AMD ROCm) rejects convolutions > 512 in a dimension;
+            # anything else is a real error and must propagate.
+            if "miopenStatus" not in str(e):
+                raise
+            print("⚠️ MIOpen fallback: running Conv3d on CPU")
+
+            x_cpu = x.float().cpu()
+            weight_cpu = self.weight.float().cpu()
+            bias_cpu = self.bias.float().cpu() if self.bias is not None else None
+
+            # x is already padded by F.pad above, so the CPU conv uses zero padding.
+            out = F.conv3d(x_cpu, weight_cpu, bias_cpu,
+                           self.stride, (0, 0, 0),
+                           self.dilation, self.groups)
+
+            # Restore the caller's device and dtype (fp16/bf16/fp32) in one step.
+            return out.to(device=x.device, dtype=x.dtype)
 
-        return x
 
 
 class RMS_norm(nn.Module):