diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 758ad59..41a934b 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -7,23 +7,23 @@ import torch.nn.functional as F
 major, minor = torch.cuda.get_device_capability(None)
 bfloat16_supported = major >= 8

-# try:
-#     from xformers.ops import memory_efficient_attention
-# except ImportError:
-#     memory_efficient_attention = None
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError:
+    memory_efficient_attention = None

-# try:
-#     import flash_attn_interface
-#     FLASH_ATTN_3_AVAILABLE = True
-# except ModuleNotFoundError:
-#     FLASH_ATTN_3_AVAILABLE = False
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False

-# try:
-#     import flash_attn
-#     FLASH_ATTN_2_AVAILABLE = True
-# except ModuleNotFoundError:
-#     FLASH_ATTN_2_AVAILABLE = False
-#     flash_attn = None
+try:
+    import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_2_AVAILABLE = False
+    flash_attn = None

 try:
     from sageattention import sageattn_varlen
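
A minimal, hypothetical sketch of how guarded optional imports like the ones re-enabled above are typically consumed downstream; the pick_attention_backend helper is illustrative only and not part of this diff:

try:
    import flash_attn  # optional accelerated backend; absent on many installs
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False
    flash_attn = None

try:
    from xformers.ops import memory_efficient_attention  # optional fallback backend
except ImportError:
    memory_efficient_attention = None


def pick_attention_backend() -> str:
    """Return the name of the first optional backend that imported cleanly."""
    if FLASH_ATTN_2_AVAILABLE:
        return "flash_attn_2"
    if memory_efficient_attention is not None:
        return "xformers"
    # torch.nn.functional.scaled_dot_product_attention is always available as a fallback
    return "torch_sdpa"


if __name__ == "__main__":
    print(f"attention backend: {pick_attention_backend()}")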