diff --git a/wgp.py b/wgp.py
index db02e07..f0a1df3 100644
--- a/wgp.py
+++ b/wgp.py
@@ -28,6 +28,8 @@ from wan.utils import prompt_parser
 import base64
 import io
 from PIL import Image
+import atexit
+from pathlib import Path
 
 PROMPT_VARS_MAX = 10
 target_mmgp_version = "3.3.4"
@@ -41,6 +42,9 @@ current_task_id = None
 task_id = 0
 # progress_tracker = {}
 # tracker_lock = threading.Lock()
+last_model_type = None
+QUEUE_FILENAME = "queue.json"
+global_dict = []
 
 def format_time(seconds):
     if seconds < 60:
@@ -83,170 +87,195 @@ def pil_to_base64_uri(pil_image, format="png", quality=75):
 
 def process_prompt_and_add_tasks(state, model_choice):
-
+    if state.get("validate_success",0) != 1:
-        return
-
+        gr.Info("Validation failed, not adding tasks.")  # Added Info
+        return gr.update()  # Return an update to avoid downstream errors
+
+    state["validate_success"] = 0
     model_filename = state["model_filename"]
     if model_choice != get_model_type(model_filename):
-        raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
+        raise gr.Error("Webform model mismatch. The App's selected model has changed since the form was displayed. Please refresh the page or re-select the model.")
+
+    # Get inputs specific to the current model type from the state
     inputs = state.get(get_model_type(model_filename), None)
-    inputs["state"] = state
-    if inputs == None:
-        return
-    prompt = inputs["prompt"]
-    if len(prompt) ==0:
-        return
+    if inputs is None:
+        gr.Warning(f"Could not find inputs for model type {get_model_type(model_filename)} in state.")
+        return gr.update()  # Return empty update
+
+    inputs["state"] = state  # Re-add state for add_video_task
+
+    prompt = inputs.get("prompt", "")  # Use .get for safety
+    if not prompt:
+        gr.Info("Prompt is empty, not adding tasks.")
+        return gr.update()
+
     prompt, errors = prompt_parser.process_template(prompt)
-    if len(errors) > 0:
+    if errors:
         gr.Info("Error processing prompt template: " + errors)
-        return
-
-    inputs["model_filename"] = model_filename
+        return gr.update()
+
+    inputs["model_filename"] = model_filename  # Ensure model_filename is in inputs for add_video_task
     prompts = prompt.replace("\r", "").split("\n")
-    prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
-    if len(prompts) ==0:
-        return
+    prompts = [p.strip() for p in prompts if p.strip() and not p.startswith("#")]
+    if not prompts:
+        gr.Info("No valid prompts found after processing, not adding tasks.")
+        return gr.update()
 
-    resolution = inputs["resolution"]
-    width, height = resolution.split("x")
-    width, height = int(width), int(height)
+    resolution = inputs.get("resolution", "832x480")  # Use .get
+    width, height = map(int, resolution.split("x"))
+
+    # --- Validation specific to model types ---
     if test_class_i2v(model_filename):
-        if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
+        if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
             gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
-            return
-        resolution = str(width) + "*" + str(height)
-        if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
-            gr.Info(f"Resolution {resolution} not supported by image 2 video")
-            return
-
-    if "1.3B" in model_filename and width * height > 848*480:
-        gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
-        return
+            return gr.update()
+        # Ensure resolution format 
is correct for I2V, if needed (adjust based on actual requirements) + # resolution_str = f"{width}x{height}" # Use 'x' as separator consistently? Check MAX_AREA_CONFIGS keys + # if resolution_str not in MAX_AREA_CONFIGS: # Or a specific list for I2V + # gr.Info(f"Resolution {resolution} might not be directly supported by image 2 video. Check MAX_AREA_CONFIGS.") + # return gr.update() # Decide if this is a hard error + if "1.3B" in model_filename and width * height > 848*480: + # This check might be too strict depending on the model. Re-evaluate if needed. + gr.Info("You might need the 14B model to generate videos with a resolution equivalent to 720P") + # return gr.update() # Decide if this is a hard error + # --- Task generation based on model type --- + tasks_added = 0 if "Vace" in model_filename: - video_prompt_type = inputs["video_prompt_type"] - image_refs = inputs["image_refs"] - video_guide = inputs["video_guide"] - video_mask = inputs["video_mask"] - if "Vace" in model_filename and "1.3B" in model_filename : - resolution_reformated = str(height) + "*" + str(width) - if not resolution_reformated in VACE_SIZE_CONFIGS: - res = VACE_SIZE_CONFIGS.keys().join(" and ") - gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.") - return - if not "I" in video_prompt_type: - image_refs = None - if not "V" in video_prompt_type: - video_guide = None - if not "M" in video_prompt_type: - video_mask = None + video_prompt_type = inputs.get("video_prompt_type", "") + image_ref_paths = inputs.get("image_refs") # Now contains file paths + video_guide_path = inputs.get("video_guide") + video_mask_path = inputs.get("video_mask") - if isinstance(image_refs, list): - image_refs = [ convert_image(tup[0]) for tup in image_refs ] + # Input filtering based on type + if "I" not in video_prompt_type: image_ref_paths = None + if "V" not in video_prompt_type: video_guide_path = None + if "M" not in video_prompt_type: video_mask_path = None - from wan.utils.utils import resize_and_remove_background - image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1) - + # VACE specific validation (e.g., resolution) + if "1.3B" in model_filename: + resolution_reformated = f"{height}x{width}" # Check VACE_SIZE_CONFIGS format + if resolution_reformated not in VACE_SIZE_CONFIGS: + allowed_res = " and ".join(VACE_SIZE_CONFIGS.keys()) + gr.Info(f"Video Resolution {resolution} for Vace 1.3B model is not supported. 
Only {allowed_res} resolutions are allowed.")
+                return gr.update()
 
-        for single_prompt in prompts:
-            extra_inputs = {
-                "prompt" : single_prompt,
-                "image_refs": image_refs,
-                "video_guide" : video_guide,
-                "video_mask" : video_mask ,
-            }
-            inputs.update(extra_inputs)
-            add_video_task(**inputs)
-    elif "image2video" in model_filename or "Fun_InP" in model_filename :
-        image_prompt_type = inputs["image_prompt_type"]
+        # --- VACE image refs stay as file paths here; the PIL conversion and
+        # background removal now happen later, inside generate_video ---
 
-        image_start = inputs["image_start"]
-        image_end = inputs["image_end"]
-        if image_start == None or isinstance(image_start, list) and len(image_start) == 0:
-            return
-        if not "E" in image_prompt_type:
-            image_end = None
-        if isinstance(image_start, list):
-            image_start = [ convert_image(tup[0]) for tup in image_start ]
-        else:
-            image_start = [convert_image(image_start)]
-        if image_end != None:
-            if isinstance(image_end , list):
-                image_end = [ convert_image(tup[0]) for tup in image_end ]
+        for single_prompt in prompts:
+            task_params = inputs.copy()  # Start with base inputs
+            task_params.update({
+                "prompt": single_prompt,
+                "image_refs": image_ref_paths,  # Pass paths
+                "video_guide": video_guide_path,
+                "video_mask": video_mask_path,
+            })
+            add_video_task(**task_params)
+            tasks_added += 1
+
+    elif "image2video" in model_filename or "Fun_InP" in model_filename:
+        image_prompt_type = inputs.get("image_prompt_type", "S")
+        image_start_paths = inputs.get("image_start")  # Now list of file paths or single path
+        image_end_paths = inputs.get("image_end")  # Now list of file paths or single path
+
+        if not image_start_paths or (isinstance(image_start_paths, list) and not image_start_paths):
+            gr.Info("Image 2 Video requires at least one start image.")
+            return gr.update()
+
+        # Ensure paths are lists
+        if image_start_paths and not isinstance(image_start_paths, list):
+            image_start_paths = [image_start_paths]
+        if image_end_paths and not isinstance(image_end_paths, list):
+            image_end_paths = [image_end_paths]
+
+        # Input filtering based on type
+        if "E" not in image_prompt_type:
+            image_end_paths = None
+
+        # Validation
+        if image_end_paths and len(image_start_paths) != len(image_end_paths):
+            gr.Info("The number of start and end images provided must be the same when using End Images.")
+            return gr.update()
+
+        # --- I2V start/end images are now paths, processing happens in generate_video ---
+
+        # --- Handle multiple prompts/images (using paths) ---
+        combined_prompts = []
+        combined_start_paths = []
+        combined_end_paths = [] if image_end_paths else None
+
+        multi_type = inputs.get("multi_images_gen_type", 0)
+        num_prompts = len(prompts)
+        num_images = len(image_start_paths)
+
+        if multi_type == 0:  # Cartesian product
+            for i in range(num_prompts * num_images):
+                prompt_idx = i % num_prompts
+                image_idx = i // num_prompts
+                combined_prompts.append(prompts[prompt_idx])
+                combined_start_paths.append(image_start_paths[image_idx])
+                if combined_end_paths is not None:
+                    combined_end_paths.append(image_end_paths[image_idx])
+        else:  # Match/Repeat
+            if num_prompts >= num_images:
+                # gr.Error must be raised to display; use a Warning toast and bail out
+                if num_prompts % num_images != 0:
+                    gr.Warning("If more prompts than images (matching type), prompt count must be a multiple of image count.")
+                    return gr.update()
+                rep = num_prompts // num_images
+                for i in range(num_prompts):
+                    img_idx = i // rep
+                    combined_prompts.append(prompts[i])
+                    combined_start_paths.append(image_start_paths[img_idx])
+                    if combined_end_paths is not None:
+                        combined_end_paths.append(image_end_paths[img_idx])
             else:
-                image_end = [convert_image(image_end) ]
-            if len(image_start) != len(image_end):
-                gr.Info("The number of start and end images should be the same ")
-                return
-
-        if inputs["multi_images_gen_type"] == 0:
-            new_prompts = []
-            new_image_start = []
-            new_image_end = []
-            for i in range(len(prompts) * len(image_start) ):
-                new_prompts.append( prompts[ i % len(prompts)] )
-                new_image_start.append(image_start[i // len(prompts)] )
-                if image_end != None:
-                    new_image_end.append(image_end[i // len(prompts)] )
-            prompts = new_prompts
-            image_start = new_image_start
-            if image_end != None:
-                image_end = new_image_end
-        else:
-            if len(prompts) >= len(image_start):
-                if len(prompts) % len(image_start) != 0:
-                    raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
-                rep = len(prompts) // len(image_start)
-                new_image_start = []
-                new_image_end = []
-                for i, _ in enumerate(prompts):
-                    new_image_start.append(image_start[i//rep] )
-                    if image_end != None:
-                        new_image_end.append(image_end[i//rep] )
-                image_start = new_image_start
-                if image_end != None:
-                    image_end = new_image_end
-            else:
-                if len(image_start) % len(prompts) !=0:
-                    raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
-                rep = len(image_start) // len(prompts)
-                new_prompts = []
-                for i, _ in enumerate(image_start):
-                    new_prompts.append( prompts[ i//rep] )
-                prompts = new_prompts
+                if num_images % num_prompts != 0:
+                    gr.Warning("If more images than prompts (matching type), image count must be a multiple of prompt count.")
+                    return gr.update()
+                rep = num_images // num_prompts
+                for i in range(num_images):
+                    prompt_idx = i // rep
+                    combined_prompts.append(prompts[prompt_idx])
+                    combined_start_paths.append(image_start_paths[i])
+                    if combined_end_paths is not None:
+                        combined_end_paths.append(image_end_paths[i])
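The cartesian and match/repeat branches above are the heart of the multi-prompt × multi-image expansion, and are easier to check in isolation. A minimal sketch of the same index arithmetic, under a hypothetical helper name with plain lists standing in for prompts and image paths:

```python
def pair_prompts_with_images(prompts, images, cartesian=True):
    """Sketch of the queue builder's pairing rules (illustrative only)."""
    if cartesian:
        # Every prompt is rendered once per image: len(prompts) * len(images)
        # tasks, prompts cycling fastest (i % n), images slowest (i // n).
        return [(prompts[i % len(prompts)], images[i // len(prompts)])
                for i in range(len(prompts) * len(images))]
    # Match/repeat: the longer list must be an exact multiple of the shorter one.
    if len(prompts) >= len(images):
        if len(prompts) % len(images):
            raise ValueError("prompt count must be a multiple of image count")
        rep = len(prompts) // len(images)
        return [(p, images[i // rep]) for i, p in enumerate(prompts)]
    if len(images) % len(prompts):
        raise ValueError("image count must be a multiple of prompt count")
    rep = len(images) // len(prompts)
    return [(prompts[i // rep], img) for i, img in enumerate(images)]

# 2 prompts x 2 images -> 4 cartesian tasks, prompts cycling fastest
pairs = pair_prompts_with_images(["a", "b"], ["x.png", "y.png"])
assert pairs == [("a", "x.png"), ("b", "x.png"), ("a", "y.png"), ("b", "y.png")]
```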
-
-        if image_start == None:
-            image_start = [None] * len(prompts)
-        if image_end == None:
-            image_end = [None] * len(prompts)
+        # Create tasks
+        for i, single_prompt in enumerate(combined_prompts):
+            task_params = inputs.copy()
+            task_params.update({
+                "prompt": single_prompt,
+                "image_start": combined_start_paths[i],  # Pass single path
+                "image_end": combined_end_paths[i] if combined_end_paths is not None else None,  # Pass single path or None
+            })
+            # Ensure multi_images_gen_type doesn't cause issues later if it was 0/1
+            task_params["multi_images_gen_type"] = -1  # Indicate already processed
+            add_video_task(**task_params)
+            tasks_added += 1
 
-        for single_prompt, start, end in zip(prompts, image_start, image_end) :
-            extra_inputs = {
-                "prompt" : single_prompt,
-                "image_start": start,
-                "image_end" : end,
-            }
-            inputs.update(extra_inputs)
-            add_video_task(**inputs)
-    else:
-        for single_prompt in prompts :
-            extra_inputs = {
-                "prompt" : single_prompt,
-            }
-            inputs.update(extra_inputs)
-            add_video_task(**inputs)
+    else:  # Text to Video (no image inputs specific to generation)
+        for single_prompt in prompts:
+            task_params 
= inputs.copy() + task_params.update({"prompt": single_prompt}) + add_video_task(**task_params) + tasks_added += 1 + # --- Update queue UI --- gen = get_gen_info(state) - gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0) + gen["prompts_max"] = tasks_added + gen.get("prompts_max", 0) state["validate_success"] = 1 - queue= gen.get("queue", []) + queue = gen.get("queue", []) return update_queue_data(queue) @@ -259,31 +288,93 @@ def add_video_task(**inputs): queue = gen["queue"] task_id += 1 current_task_id = task_id - inputs_to_query = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] - start_image_data = None - end_image_data = None - for name in inputs_to_query: - image= inputs.get(name, None) - if image != None: - image= [image] if not isinstance(image, list) else image - if start_image_data == None: - start_image_data = image - else: - end_image_data = image - break - queue.append({ + # --- Identify image paths from inputs --- + # Use .get() for safety + start_image_paths = inputs.get("image_start") # Could be single path or list + end_image_paths = inputs.get("image_end") # Could be single path or list + ref_image_paths = inputs.get("image_refs") # Could be list or None + + # Standardize to lists or None + if start_image_paths and not isinstance(start_image_paths, list): + start_image_paths = [start_image_paths] + if end_image_paths and not isinstance(end_image_paths, list): + end_image_paths = [end_image_paths] + # ref_image_paths is likely already a list if present + + # Prioritize which images to show as previews in the queue UI + # Typically start/end for I2V, refs for VACE? Or just first available? + primary_preview_paths = None + secondary_preview_paths = None + + if start_image_paths: + primary_preview_paths = start_image_paths + if end_image_paths: + secondary_preview_paths = end_image_paths + elif ref_image_paths: + primary_preview_paths = ref_image_paths + # Add logic for video previews if needed (e.g., video_guide) + + # --- Generate Base64 previews from paths --- + start_image_data_base64 = [] + if primary_preview_paths: + try: + # Load only the first image for preview if it's a list + path_to_load = primary_preview_paths[0] + if path_to_load and Path(path_to_load).is_file(): + loaded_image = Image.open(path_to_load) + b64 = pil_to_base64_uri(loaded_image, format="jpeg", quality=70) + if b64: + start_image_data_base64.append(b64) + else: + print(f"Warning: Primary preview image path not found or invalid: {path_to_load}") + start_image_data_base64.append(None) # Add placeholder if needed + except Exception as e: + print(f"Warning: Could not load primary preview image for UI: {e}") + start_image_data_base64.append(None) + + end_image_data_base64 = [] + if secondary_preview_paths: + try: + path_to_load = secondary_preview_paths[0] + if path_to_load and Path(path_to_load).is_file(): + loaded_image = Image.open(path_to_load) + b64 = pil_to_base64_uri(loaded_image, format="jpeg", quality=70) + if b64: + end_image_data_base64.append(b64) + else: + print(f"Warning: Secondary preview image path not found or invalid: {path_to_load}") + end_image_data_base64.append(None) + except Exception as e: + print(f"Warning: Could not load secondary preview image for UI: {e}") + end_image_data_base64.append(None) + + + # --- Prepare params for the queue (ensure paths are stored) --- + params_copy = inputs.copy() + # Remove state object before storing + if 'state' in params_copy: + del params_copy['state'] + # Ensure image keys contain the paths as received + # (No 
need to explicitly set image_start_paths etc. if inputs already has them correctly) + + queue_item = { "id": current_task_id, - "params": inputs.copy(), - "repeats": inputs["repeat_generation"], - "length": inputs["video_length"], - "steps": inputs["num_inference_steps"], - "prompt": inputs["prompt"], - "start_image_data": start_image_data, - "end_image_data": end_image_data, - "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, - "end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None - }) + "params": params_copy, # Contains paths for image_start, image_end, image_refs etc. + "repeats": inputs.get("repeat_generation", 1), + "length": inputs.get("video_length"), + "steps": inputs.get("num_inference_steps"), + "prompt": inputs.get("prompt"), + # Store the base64 previews separately for the UI + "start_image_data_base64": start_image_data_base64 if start_image_data_base64 else None, + "end_image_data_base64": end_image_data_base64 if end_image_data_base64 else None, + # Keep original paths in params for saving/loading consistency if needed, + # but the generate_video function will use the primary keys like 'image_start' + # "start_image_paths_ref": start_image_paths, # Example if needed for save/load + # "end_image_paths_ref": end_image_paths, # Example if needed for save/load + } + + queue.append(queue_item) return update_queue_data(queue) def move_up(queue, selected_indices): @@ -326,7 +417,258 @@ def remove_task(queue, selected_indices): del queue[idx] return update_queue_data(queue) +def maybe_trigger_processing(should_start, state): + if should_start: + yield from maybe_start_processing(state) + else: + gen = get_gen_info(state) + last_msg = gen.get("last_msg", "Idle") + yield last_msg +def maybe_start_processing(state, progress=gr.Progress()): + gen = get_gen_info(state) + queue = gen.get("queue", []) + in_progress = gen.get("in_progress", False) + initial_status = gen.get("last_msg", "Idle") + if queue and not in_progress: + initial_status = "Starting automatic processing..." + yield initial_status + try: + for status_update in process_tasks(state, progress): + print(f"*** Yielding from process_tasks: '{status_update}' ***") + yield status_update + print(f"*** Finished iterating process_tasks normally. 
***") + except Exception as e: + print(f"*** Error during maybe_start_processing -> process_tasks: {e} ***") + yield f"Error during processing: {str(e)}" + else: + last_msg = gen.get("last_msg", "Idle") + initial_status = last_msg + yield initial_status + +def save_queue_to_json(queue, filename=QUEUE_FILENAME): + """Saves the current task queue to a JSON file.""" + tasks_to_save = [] + max_id = 0 + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + + params_to_save = task.get('params', {}).copy() + + # --- REMOVE THESE INCORRECT setdefault calls --- + # params_to_save.setdefault('prompt', '') + # params_to_save.setdefault('repeats', task.get('repeats', 1)) # NO + # params_to_save.setdefault('length', task.get('length')) # NO + # params_to_save.setdefault('steps', task.get('steps')) # NO + # params_to_save.setdefault('model_filename', '') # NO + + # --- INSTEAD: Check if essential model_filename exists --- + if 'model_filename' not in params_to_save or not params_to_save['model_filename']: + print(f"Warning: Skipping task {task.get('id')} during save due to missing model_filename in params.") + continue # Don't save tasks without essential info + + # Remove non-serializable items explicitly if any remain + params_to_save.pop('state', None) + # Remove potentially large PIL objects if they slipped through + keys_to_remove = [k for k, v in params_to_save.items() if isinstance(v, Image.Image)] + for k in keys_to_remove: del params_to_save[k] + + + task_data = { + "id": task.get('id', 0), + "params": params_to_save, # This should now be clean + # Get these values from the original task dict OR the params dict as fallback + "repeats": task.get('repeats', params_to_save.get('repeat_generation', 1)), + "length": task.get('length', params_to_save.get('video_length')), + "steps": task.get('steps', params_to_save.get('num_inference_steps')), + "prompt": task.get('prompt', params_to_save.get('prompt', '')), + # Keep base64 for fast UI reload without reading files + "start_image_data_base64": task.get("start_image_data_base64"), + "end_image_data_base64": task.get("end_image_data_base64"), + } + tasks_to_save.append(task_data) + max_id = max(max_id, task_data["id"]) + + try: + # Use ensure_ascii=False for wider character support, though prompt sanitization helps + with open(filename, 'w', encoding='utf-8') as f: + json.dump(tasks_to_save, f, indent=4, ensure_ascii=False) + print(f"Queue saved successfully to {filename}") + return max_id + except Exception as e: + print(f"Error saving queue to {filename}: {e}") + gr.Warning(f"Failed to save queue: {e}") + return max_id + +def load_queue_from_json(filename=QUEUE_FILENAME): + """Loads tasks from a JSON file back into the queue format.""" + global task_id # To update the global counter + if not Path(filename).is_file(): + print(f"Queue file {filename} not found. 
Starting with empty queue.") + return [], 0 + try: + with open(filename, 'r', encoding='utf-8') as f: + loaded_tasks_data = json.load(f) + except Exception as e: + print(f"Error loading or parsing queue file {filename}: {e}") + gr.Warning(f"Failed to load queue: {e}") + return [], 0 + + reconstructed_queue = [] + max_id = 0 + print(f"Loading {len(loaded_tasks_data)} tasks from {filename}...") + + for task_data in loaded_tasks_data: + if task_data is None or not isinstance(task_data, dict): continue + + params = task_data.get('params', {}) + if not params or 'model_filename' not in params: + print(f"Skipping task {task_data.get('id')} due to missing params or model_filename.") + continue + + task_id_loaded = task_data.get('id', 0) + max_id = max(max_id, task_id_loaded) + + # Get base64 previews directly from saved data + start_image_data_base64 = task_data.get('start_image_data_base64') + end_image_data_base64 = task_data.get('end_image_data_base64') + + # --- Verify paths in params still exist (optional but good practice) --- + image_path_keys = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] + for key in image_path_keys: + paths = params.get(key) + if paths: + if isinstance(paths, list): + valid_paths = [p for p in paths if p and Path(p).exists()] + if len(valid_paths) != len(paths): + print(f"Warning: Some paths for '{key}' in loaded task {task_id_loaded} not found. Using only valid ones.") + params[key] = valid_paths # Update params with only existing paths + elif isinstance(paths, str): + if not Path(paths).exists(): + print(f"Warning: Path for '{key}' in loaded task {task_id_loaded} not found: {paths}. Setting to None.") + params[key] = None + # --- + + queue_item = { + "id": task_id_loaded, + "params": params, # Contains paths and all other settings + "repeats": task_data.get('repeats', params.get('repeat_generation', 1)), # Get from task_data or params + "length": task_data.get('length', params.get('video_length')), + "steps": task_data.get('steps', params.get('num_inference_steps')), + "prompt": task_data.get('prompt', params.get('prompt', '')), + # Store base64 previews for UI + "start_image_data_base64": start_image_data_base64, + "end_image_data_base64": end_image_data_base64, + # 'start_image_data' and 'end_image_data' (PIL) are not stored/loaded + } + reconstructed_queue.append(queue_item) + + print(f"Queue loaded successfully from {filename}. Max ID found: {max_id}") + # Update global task_id if needed (handled in load_queue_action) + return reconstructed_queue, max_id + +def save_queue_action(state): + gen = get_gen_info(state) + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is empty. 
Nothing to save.") + return None + tasks_to_save = [] + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + params_copy = task.get('params', {}).copy() + if 'state' in params_copy: + del params_copy['state'] + task_data = { + "id": task.get('id', 0), + "image2video": task.get('image2video', False), + "params": params_copy, + "repeats": task.get('repeats', 1), + "length": task.get('length', 0), + "steps": task.get('steps', 0), + "prompt": task.get('prompt', ''), + "start_image_paths": task.get('start_image_paths', []), + "end_image_path": task.get('end_image_path', None), + } + tasks_to_save.append(task_data) + try: + json_string = json.dumps(tasks_to_save, indent=4) + print("Queue data prepared as JSON string for client-side download.") + return json_string + except Exception as e: + print(f"Error converting queue to JSON string: {e}") + gr.Warning(f"Failed to prepare queue data for saving: {e}") + return None + +def load_queue_action(filepath, state_dict): + global task_id + if not filepath or not Path(filepath.name).is_file(): + gr.Warning(f"No file selected or file not found.") + return None + loaded_queue, max_id = load_queue_from_json(filepath.name) + + with lock: + gen = get_gen_info(state_dict) + + if "queue" not in gen or not isinstance(gen.get("queue"), list): + gen["queue"] = [] + + existing_queue = gen["queue"] + existing_queue.clear() + existing_queue.extend(loaded_queue) + + task_id = max(task_id, max_id) + gen["prompts_max"] = len(existing_queue) + + gr.Info(f"Queue loaded from {Path(filepath.name).name}") + return None + +def update_queue_ui_after_load(state_dict): + gen = get_gen_info(state_dict) + queue = gen.get("queue", []) + raw_data = get_queue_table(queue) + is_visible = len(raw_data) > 0 + return gr.update(value=raw_data, visible=is_visible), gr.update(visible=is_visible) + +def clear_queue_action(state): + gen = get_gen_info(state) + with lock: + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is already empty.") + return update_queue_data([]) + + queue.clear() + gen["prompts_max"] = 0 + gr.Info("Queue cleared.") + return update_queue_data([]) + +def autoload_queue(state_dict): + global task_id + gen = get_gen_info(state_dict) + queue_changed = False + if Path(QUEUE_FILENAME).exists(): + print(f"Autoloading queue from {QUEUE_FILENAME}...") + if not gen["queue"]: + loaded_queue, max_id = load_queue_from_json(QUEUE_FILENAME) + if loaded_queue: + with lock: + gen["queue"] = loaded_queue + task_id = max(task_id, max_id) + gen["prompts_max"] = len(loaded_queue) + queue_changed = True + +def autosave_queue(): + print("Attempting to autosave queue on exit...") + global global_dict + + if global_dict: + print(f"Autosaving queue ({len(global_dict)} items) from state dict...") + save_queue_to_json(global_dict, QUEUE_FILENAME) + else: + print("Queue is empty in the determined active state dictionary, autosave skipped.") def get_queue_table(queue): data = [] @@ -397,9 +739,9 @@ def update_queue_data(queue): # else: # return gr.HTML(value=data, visible= True) if len(data) == 0: - return gr.DataFrame(visible=False) + return gr.update(value=[], visible=False) else: - return gr.DataFrame(value=data, visible= True) + return gr.update(value=data, visible=True) def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom" @@ -1143,7 +1485,6 @@ def load_i2v_model(model_filename, value): def load_models(model_filename): global transformer_filename 
- transformer_filename = model_filename download_models(model_filename, text_encoder_filename) if test_class_i2v(model_filename): @@ -1448,9 +1789,20 @@ def finalize_generation(state): time.sleep(0.2) global gen_in_progress gen_in_progress = False - return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="") + queue = gen.get("queue", []) + queue_is_visible = bool(queue) + current_gen_column_visible = queue_is_visible + gen_info_visible = queue_is_visible + return ( + gr.Gallery(selected_index=choice), + gr.Button(interactive=True), + gr.Button(visible=True), + gr.Button(visible=False), + gr.Column(visible=current_gen_column_visible), + gr.HTML(visible=gen_info_visible, value="") + ) def refresh_gallery_on_trigger(state): gen = get_gen_info(state) @@ -1604,6 +1956,33 @@ def generate_video( trans = wan_model.model temp_filename = None + + loaded_image_start_pil = None + loaded_image_end_pil = None + loaded_image_refs_pil = [] + + try: + if image_start and isinstance(image_start, str) and Path(image_start).is_file(): + loaded_image_start_pil = convert_image(Image.open(image_start)) + elif image_start: + print(f"Warning: Start image path not found or invalid: {image_start}") + + if image_end and isinstance(image_end, str) and Path(image_end).is_file(): + loaded_image_end_pil = convert_image(Image.open(image_end)) + elif image_end: + print(f"Warning: End image path not found or invalid: {image_end}") + + if image_refs and isinstance(image_refs, list): + valid_ref_paths = [p for p in image_refs if p and Path(p).is_file()] + if len(valid_ref_paths) != len(image_refs): + print("Warning: Some VACE reference image paths were invalid.") + loaded_image_refs_pil = [convert_image(Image.open(p)) for p in valid_ref_paths] + if not loaded_image_refs_pil and "I" in (video_prompt_type or ""): + print("Warning: No valid VACE reference images loaded despite type 'I'.") + + except Exception as e: + print(f"ERROR loading image file: {e}") + raise gr.Error(f"Failed to load input image: {e}") loras = state["loras"] if len(loras) > 0: @@ -1742,13 +2121,12 @@ def generate_video( trans.teacache_skipped_steps = 0 trans.previous_residual_uncond = None trans.previous_residual_cond = None - video_no += 1 if image2video: samples = wan_model.generate( prompt, - image_start, - image_end if image_end != None else None, + loaded_image_start_pil, + loaded_image_end_pil if loaded_image_end_pil != None else None, frame_num=(video_length // 4)* 4 + 1, max_area=MAX_AREA_CONFIGS[resolution_reformated], shift=flow_shift, @@ -1975,7 +2353,7 @@ def process_tasks(state, progress=gr.Progress()): task = queue[0] task_id = task["id"] params = task['params'] - iterator = iter(generate_video(task_id, progress, **params)) + iterator = iter(generate_video(task_id, progress, **params, state=state)) while True: try: ok = False @@ -2341,41 +2719,83 @@ def switch_advanced(state, new_advanced, lset_name): def prepare_inputs_dict(target, inputs ): - - state = inputs.pop("state") - loras = state["loras"] - if "loras_choices" in inputs: - loras_choices = inputs.pop("loras_choices") - inputs.pop("model_filename", None) - activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ] - inputs["activated_loras"] = activated_loras + """Prepares the inputs dictionary for saving state, settings, or metadata.""" + + # Make a copy to avoid modifying the original dict + inputs_copy = inputs.copy() + + # --- Remove target 
early --- + inputs_copy.pop("target", None) # Remove the target key itself + + # Remove objects not suitable for JSON/saving + state = inputs_copy.pop("state", None) # Remove state object + + # Lora handling: activated_loras should already be a list of names/stems + # If loras_choices was passed (from UI save), convert it + if "loras_choices" in inputs_copy and state and "loras" in state: + loras_choices = inputs_copy.pop("loras_choices") + loras = state.get("loras", []) + try: + # Use basename to match how activated_loras is likely stored in settings + activated_lora_names = [os.path.basename(loras[int(no)]) for no in loras_choices] + inputs_copy["activated_loras"] = activated_lora_names + except (IndexError, ValueError, TypeError) as e: + print(f"Warning: Could not convert loras_choices to names: {e}") + inputs_copy["activated_loras"] = [] # Default to empty on error + elif "activated_loras" not in inputs_copy: + inputs_copy["activated_loras"] = [] # Ensure key exists + + + # --- Target-specific adjustments --- if target == "state": - return inputs + # For saving to the main state dict, we want the raw inputs as received + # including file paths. Target is already removed. + return inputs_copy - unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] - for k in unsaved_params: - inputs.pop(k) + # For settings and metadata, remove non-serializable or large data (like paths?) + # Keep paths for settings so they can be reloaded, but maybe not metadata? + + # Remove PIL objects if they accidentally got passed (shouldn't happen now) + keys_to_remove = [] + for k, v in inputs_copy.items(): + if isinstance(v, Image.Image): + keys_to_remove.append(k) + print(f"Warning: Removing unexpected PIL Image object for key '{k}' during {target} preparation.") + for k in keys_to_remove: + del inputs_copy[k] - model_filename = state["model_filename"] - inputs["type"] = "Wan2.1GP by DeepBeepMeep - " + get_model_name(model_filename) if target == "settings": - return inputs - - if not any(k in model_filename for k in ["image2video", "Fun_InP"]): - inputs.pop("image_prompt_type") + # Keep file paths (image_start, image_end, image_refs, video_guide, video_mask) + # Ensure prompt key exists if 'prompts' was used in UI defaults + if "prompts" in inputs_copy and "prompt" not in inputs_copy: + inputs_copy["prompt"] = inputs_copy.pop("prompts") + # Convert image_prompt_type back to S/SE if needed? Or keep string? String is fine. + return inputs_copy + elif target == "metadata": + # Add type information + # Get model filename safely + model_filename = inputs_copy.get("model_filename") + if not model_filename and state: + model_filename = state.get("model_filename", "unknown") + elif not model_filename: + model_filename = "unknown" - if not "Vace" in model_filename: - unsaved_params = ["video_prompt_type", "max_frames", "remove_background_image_ref"] - for k in unsaved_params: - inputs.pop(k) + inputs_copy["type"] = f"WanGP by DeepBeepMeep - {get_model_name(model_filename)}" - if target == "metadata": - inputs = {k: v for k,v in inputs.items() if v != None } + # Remove file paths from metadata? Or keep them? Keep for reproducibility. + # Remove None values? + metadata_dict = {k: v for k, v in inputs_copy.items() if v is not None} - return inputs + # Clean up keys not relevant for metadata? 
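+        # (multi_images_gen_type is dropped because it only drove the prompt/image
+        #  pairing when the task was queued; it says nothing about the generation itself)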
+ metadata_dict.pop("multi_images_gen_type", None) # This was processing logic + + return metadata_dict + + # Should not reach here if target is valid + return inputs_copy def get_function_arguments(func, locals): args_names = list(inspect.signature(func).parameters) @@ -2405,7 +2825,7 @@ def save_inputs( image_prompt_type, image_start, image_end, - video_prompt_type, + video_prompt_type, image_refs, video_guide, video_mask, @@ -2414,7 +2834,7 @@ def save_inputs( temporal_upsampling, spatial_upsampling, RIFLEx_setting, - slg_switch, + slg_switch, slg_layers, slg_start_perc, slg_end_perc, @@ -2423,22 +2843,27 @@ def save_inputs( state, ): - - # if state.get("validate_success",0) != 1: - # return - model_filename = state["model_filename"] - inputs = get_function_arguments(save_inputs, locals()) - inputs.pop("target") - cleaned_inputs = prepare_inputs_dict(target, inputs) + current_locals = locals() + inputs_dict = get_function_arguments(save_inputs, current_locals) + cleaned_inputs = prepare_inputs_dict(target, inputs_dict) + model_filename = state.get("model_filename") + if not model_filename: + print("Warning: Cannot save inputs, model_filename not found in state.") + return + if target == "settings": defaults_filename = get_settings_file_name(model_filename) + try: + with open(defaults_filename, "w", encoding="utf-8") as f: + json.dump(cleaned_inputs, f, indent=4) + gr.Info(f"Default Settings saved for {get_model_name(model_filename)}") + except Exception as e: + print(f"Error saving settings to {defaults_filename}: {e}") + gr.Error(f"Failed to save settings: {e}") - with open(defaults_filename, "w", encoding="utf-8") as f: - json.dump(cleaned_inputs, f, indent=4) - - gr.Info("New Default Settings saved") elif target == "state": - state[get_model_type(model_filename)] = cleaned_inputs + model_type_key = get_model_type(model_filename) + state[model_type_key] = cleaned_inputs def download_loras(): from huggingface_hub import snapshot_download @@ -2660,23 +3085,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if args.multiple_images: image_start = gr.Gallery( - label="Images as starting points for new videos", type ="pil", #file_types= "image", + label="Images as starting points for new videos", type ="filepath", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) else: - image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) + image_start = gr.Image(label= "Image as a starting point for a new video", type ="filepath",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) if args.multiple_images: image_end = gr.Gallery( - label="Images as ending points for new videos", type ="pil", #file_types= "image", + label="Images as ending points for new videos", type ="filepath", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) else: - image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + image_end = gr.Image(label= "Last Image for a new video", type ="filepath", visible="E" in image_prompt_type_value, value= 
ui_defaults.get("image_end", None)) with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","I") video_prompt_type = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =video_prompt_type_value, label="Location", show_label= False, scale= 3) image_refs = gr.Gallery( - label="Reference Images of Faces and / or Object to be found in the Video", type ="pil", + label="Reference Images of Faces and / or Object to be found in the Video", type ="filepath", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None) ) video_guide = gr.Video(label= "Reference Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None) ) @@ -2898,7 +3323,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non queue_df = gr.DataFrame( headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""], datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], - column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"], + column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "3%"], interactive=False, col_count=(9, "fixed"), wrap=True, @@ -2907,7 +3332,68 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non visible= False, elem_id="queue_df" ) - + with gr.Row(): + queue_json_output = gr.Text(visible=False, label="_queue_json") + save_queue_btn = gr.DownloadButton("Save Queue") + load_queue_btn = gr.UploadButton("Load Queue", file_types=[".json"], type="filepath") + clear_queue_btn = gr.Button("Clear Queue") + trigger_download_js = """ + (jsonString) => { + if (!jsonString) { + console.log("No JSON data received, skipping download."); + return; + } + const blob = new Blob([jsonString], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'queue.json'; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } + """ + if not update_form: + save_queue_btn.click( + fn=save_queue_action, + inputs=[state], + outputs=[queue_json_output] + ).then( + fn=None, + inputs=[queue_json_output], + outputs=None, + js=trigger_download_js + ) + load_queue_btn.upload( + fn=load_queue_action, + inputs=[load_queue_btn, state], + outputs=None + ).then( + fn=update_queue_ui_after_load, + inputs=[state], + outputs=[queue_df, current_gen_column], + ).then( + fn=maybe_start_processing, + inputs=[state], + outputs=[gen_status], + show_progress="minimal", + trigger_mode="always_last" + ).then( + fn=finalize_generation, + inputs=[state], + outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info], + trigger_mode="always_last" + ) + clear_queue_btn.click( + clear_queue_action, + inputs=[state], + outputs=[queue_df] + ).then( + fn=lambda: gr.update(visible=False), + inputs=None, + outputs=[current_gen_column] + ) extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column, if update_form: @@ -2988,6 +3474,10 @@ def 
generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then(unload_model_if_needed, inputs= [state], outputs= [] + ).then( + fn=lambda state_dict: gr.update(visible=bool(get_gen_info(state_dict).get("queue", []))), + inputs=[state], + outputs=[current_gen_column] ) add_to_queue_btn.click(fn=validate_wizard_prompt, @@ -3002,6 +3492,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then( fn=update_status, inputs = [state], + ).then( + fn=lambda state_dict: gr.update(visible=bool(get_gen_info(state_dict).get("queue", []))), + inputs=[state], + outputs=[current_gen_column] ) close_modal_button.click( @@ -3010,7 +3504,24 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non outputs=[modal_container] ) - return loras_choices, lset_name, state + return ( + loras_choices, + lset_name, + state, + queue_df, + current_gen_column, + gen_status, + output, + abort_btn, + generate_btn, + add_to_queue_btn, + gen_info + # , + # prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var, + # prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, + # advanced_row, image_prompt_column, video_prompt_column, + # *prompt_vars + ) def generate_download_tab(lset_name,loras_choices, state): with gr.Row(): @@ -3294,16 +3805,13 @@ def create_demo(): vertical-align: middle; font-size:11px; } - #xqueue_df table { + #queue_df table { width: 100%; - overflow: hidden !important; } - #xqueue_df::-webkit-scrollbar { - display: none !important; - } - #xqueue_df { + #queue_df { scrollbar-width: none !important; - -ms-overflow-style: none !important; + overflow-x: hidden !important; + overflow-y: auto; } .selection-button { display: none; @@ -3474,8 +3982,15 @@ def create_demo(): with gr.Row(): header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True) with gr.Row(): - - loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header) + ( + loras_choices, lset_name, state, queue_df, current_gen_column, + gen_status, output, abort_btn, generate_btn, add_to_queue_btn, + gen_info + # ,prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var, + # prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, + # advanced_row, image_prompt_column, video_prompt_column, + # *prompt_vars_outputs + ) = generate_video_tab(model_choice=model_choice, header=header) with gr.Tab("Informations"): generate_info_tab() if not args.lock_config: @@ -3485,11 +4000,54 @@ def create_demo(): generate_configuration_tab(header, model_choice) with gr.Tab("About"): generate_about_tab() + def run_autoload_and_update(current_state): + autoload_queue(current_state) + gen = get_gen_info(current_state) + queue = gen.get("queue", []) + global global_dict + global_dict = queue + raw_data = get_queue_table(queue) + is_visible = len(raw_data) > 0 + should_start_processing = bool(queue) + df_update = gr.update(value=raw_data, visible=is_visible) + col_update = gr.update(visible=is_visible) + return ( + df_update, + col_update, + should_start_processing + ) + + should_start_flag = gr.State(False) + + load_outputs_ui = [ + queue_df, + current_gen_column, + should_start_flag, + ] + + demo.load( + fn=run_autoload_and_update, + inputs=[state], + outputs=load_outputs_ui + ).then( + fn=maybe_trigger_processing, + inputs=[should_start_flag, state], + outputs=[gen_status], + ).then( + fn=finalize_generation, + inputs=[state], + outputs=[ + output, 
abort_btn, generate_btn, add_to_queue_btn, + current_gen_column, gen_info + ], + trigger_mode="always_last" + ) return demo if __name__ == "__main__": # threading.Thread(target=runner, daemon=True).start() + atexit.register(autosave_queue) os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" server_port = int(args.server_port) if os.name == "nt":
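For reference, the exit-time autosave wiring registered above (`atexit.register(autosave_queue)`) reduces to this self-contained pattern; the names below are illustrative stand-ins for `global_dict` and `save_queue_to_json`:

```python
import atexit
import json

_queue = []  # stands in for the module-level global_dict

def autosave():
    # atexit handlers run on normal interpreter shutdown only -- not after
    # os._exit(), SIGKILL, or a hard crash -- so treat this as a best-effort net.
    if _queue:
        with open("queue.json", "w", encoding="utf-8") as f:
            json.dump(_queue, f, indent=4, ensure_ascii=False)

# Register once at startup; handlers run in reverse registration order.
atexit.register(autosave)
```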