From 61489e1a177a18562369df3aaa8022fde142e106 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Mon, 7 Apr 2025 14:47:43 +1000 Subject: [PATCH] add queue saving/loading/clearing/autoloading/autosaving, fix styling --- wgp.py | 1011 +++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 782 insertions(+), 229 deletions(-) diff --git a/wgp.py b/wgp.py index c81ab9a..44146a2 100644 --- a/wgp.py +++ b/wgp.py @@ -28,6 +28,7 @@ from wan.utils import prompt_parser import base64 import io from PIL import Image +import atexit PROMPT_VARS_MAX = 10 target_mmgp_version = "3.3.4" @@ -41,6 +42,9 @@ current_task_id = None task_id = 0 # progress_tracker = {} # tracker_lock = threading.Lock() +last_model_type = None +QUEUE_FILENAME = "queue.json" +global_dict = [] def format_time(seconds): if seconds < 60: @@ -83,170 +87,195 @@ def pil_to_base64_uri(pil_image, format="png", quality=75): def process_prompt_and_add_tasks(state, model_choice): - + if state.get("validate_success",0) != 1: - return - + gr.Info("Validation failed, not adding tasks.") # Added Info + return gr.update() # Return an update to avoid downstream errors + state["validate_success"] = 0 model_filename = state["model_filename"] if model_choice != get_model_type(model_filename): - raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page") - + raise gr.Error("Webform model mismatch. The App's selected model has changed since the form was displayed. Please refresh the page or re-select the model.") + + # Get inputs specific to the current model type from the state inputs = state.get(get_model_type(model_filename), None) - inputs["state"] = state - if inputs == None: - return - prompt = inputs["prompt"] - if len(prompt) ==0: - return + if inputs is None: + gr.Warning(f"Could not find inputs for model type {get_model_type(model_filename)} in state.") + return gr.update() # Return empty update + + inputs["state"] = state # Re-add state for add_video_task + + prompt = inputs.get("prompt", "") # Use .get for safety + if not prompt: + gr.Info("Prompt is empty, not adding tasks.") + return gr.update() + prompt, errors = prompt_parser.process_template(prompt) - if len(errors) > 0: + if errors: gr.Info("Error processing prompt template: " + errors) - return - - inputs["model_filename"] = model_filename + return gr.update() + + inputs["model_filename"] = model_filename # Ensure model_filename is in inputs for add_video_task prompts = prompt.replace("\r", "").split("\n") - prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] - if len(prompts) ==0: - return + prompts = [p.strip() for p in prompts if p.strip() and not p.startswith("#")] + if not prompts: + gr.Info("No valid prompts found after processing, not adding tasks.") + return gr.update() - resolution = inputs["resolution"] - width, height = resolution.split("x") - width, height = int(width), int(height) + resolution = inputs.get("resolution", "832x480") # Use .get + width, height = map(int, resolution.split("x")) + + # --- Validation specific to model types --- if test_class_i2v(model_filename): - if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: + if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") - return - resolution = str(width) + "*" + str(height) - if resolution not in 
['720*1280', '1280*720', '480*832', '832*480']: - gr.Info(f"Resolution {resolution} not supported by image 2 video") - return - - if "1.3B" in model_filename and width * height > 848*480: - gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") - return + return gr.update() + # Ensure resolution format is correct for I2V, if needed (adjust based on actual requirements) + # resolution_str = f"{width}x{height}" # Use 'x' as separator consistently? Check MAX_AREA_CONFIGS keys + # if resolution_str not in MAX_AREA_CONFIGS: # Or a specific list for I2V + # gr.Info(f"Resolution {resolution} might not be directly supported by image 2 video. Check MAX_AREA_CONFIGS.") + # return gr.update() # Decide if this is a hard error + if "1.3B" in model_filename and width * height > 848*480: + # This check might be too strict depending on the model. Re-evaluate if needed. + gr.Info("You might need the 14B model to generate videos with a resolution equivalent to 720P") + # return gr.update() # Decide if this is a hard error + # --- Task generation based on model type --- + tasks_added = 0 if "Vace" in model_filename: - video_prompt_type = inputs["video_prompt_type"] - image_refs = inputs["image_refs"] - video_guide = inputs["video_guide"] - video_mask = inputs["video_mask"] - if "Vace" in model_filename and "1.3B" in model_filename : - resolution_reformated = str(height) + "*" + str(width) - if not resolution_reformated in VACE_SIZE_CONFIGS: - res = VACE_SIZE_CONFIGS.keys().join(" and ") - gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.") - return - if not "I" in video_prompt_type: - image_refs = None - if not "V" in video_prompt_type: - video_guide = None - if not "M" in video_prompt_type: - video_mask = None + video_prompt_type = inputs.get("video_prompt_type", "") + image_ref_paths = inputs.get("image_refs") # Now contains file paths + video_guide_path = inputs.get("video_guide") + video_mask_path = inputs.get("video_mask") - if isinstance(image_refs, list): - image_refs = [ convert_image(tup[0]) for tup in image_refs ] + # Input filtering based on type + if "I" not in video_prompt_type: image_ref_paths = None + if "V" not in video_prompt_type: video_guide_path = None + if "M" not in video_prompt_type: video_mask_path = None - from wan.utils.utils import resize_and_remove_background - image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1) - + # VACE specific validation (e.g., resolution) + if "1.3B" in model_filename: + resolution_reformated = f"{height}x{width}" # Check VACE_SIZE_CONFIGS format + if resolution_reformated not in VACE_SIZE_CONFIGS: + allowed_res = " and ".join(VACE_SIZE_CONFIGS.keys()) + gr.Info(f"Video Resolution {resolution} for Vace 1.3B model is not supported. 
Only {allowed_res} resolutions are allowed.") + return gr.update() - for single_prompt in prompts: - extra_inputs = { - "prompt" : single_prompt, - "image_refs": image_refs, - "video_guide" : video_guide, - "video_mask" : video_mask , - } - inputs.update(extra_inputs) - add_video_task(**inputs) - elif "image2video" in model_filename or "Fun_InP" in model_filename : - image_prompt_type = inputs["image_prompt_type"] + # --- VACE image refs are now paths, processing happens in generate_video --- + # Remove PIL processing here: + # if isinstance(image_ref_paths, list): + # # image_refs_pil = [ convert_image(tup[0]) for tup in image_ref_paths ] # This happens later + # # from wan.utils.utils import resize_and_remove_background # This happens later + # # image_refs_processed = resize_and_remove_background(image_refs_pil, width, height, inputs["remove_background_image_ref"] ==1) # This happens later + # pass # Just keep the paths - image_start = inputs["image_start"] - image_end = inputs["image_end"] - if image_start == None or isinstance(image_start, list) and len(image_start) == 0: - return - if not "E" in image_prompt_type: - image_end = None - if isinstance(image_start, list): - image_start = [ convert_image(tup[0]) for tup in image_start ] - else: - image_start = [convert_image(image_start)] - if image_end != None: - if isinstance(image_end , list): - image_end = [ convert_image(tup[0]) for tup in image_end ] + for single_prompt in prompts: + task_params = inputs.copy() # Start with base inputs + task_params.update({ + "prompt": single_prompt, + "image_refs": image_ref_paths, # Pass paths + "video_guide": video_guide_path, + "video_mask": video_mask_path, + }) + add_video_task(**task_params) + tasks_added += 1 + + elif "image2video" in model_filename or "Fun_InP" in model_filename: + image_prompt_type = inputs.get("image_prompt_type", "S") + image_start_paths = inputs.get("image_start") # Now list of file paths or single path + image_end_paths = inputs.get("image_end") # Now list of file paths or single path + + if not image_start_paths or (isinstance(image_start_paths, list) and not image_start_paths): + gr.Info("Image 2 Video requires at least one start image.") + return gr.update() + + # Ensure paths are lists + if image_start_paths and not isinstance(image_start_paths, list): + image_start_paths = [image_start_paths] + if image_end_paths and not isinstance(image_end_paths, list): + image_end_paths = [image_end_paths] + + # Input filtering based on type + if "E" not in image_prompt_type: + image_end_paths = None + + # Validation + if image_end_paths and len(image_start_paths) != len(image_end_paths): + gr.Info("The number of start and end images provided must be the same when using End Images.") + return gr.update() + + # --- I2V start/end images are now paths, processing happens in generate_video --- + # Remove PIL processing here + + # --- Handle multiple prompts/images (using paths) --- + combined_prompts = [] + combined_start_paths = [] + combined_end_paths = [] if image_end_paths else None + + multi_type = inputs.get("multi_images_gen_type", 0) + num_prompts = len(prompts) + num_images = len(image_start_paths) + + if multi_type == 0: # Cartesian product + for i in range(num_prompts * num_images): + prompt_idx = i % num_prompts + image_idx = i // num_prompts + combined_prompts.append(prompts[prompt_idx]) + combined_start_paths.append(image_start_paths[image_idx]) + if combined_end_paths is not None: + combined_end_paths.append(image_end_paths[image_idx]) + else: # Match/Repeat + if 
num_prompts >= num_images: + if num_prompts % num_images != 0: + gr.Error("If more prompts than images (matching type), prompt count must be multiple of image count.") + return gr.update() + rep = num_prompts // num_images + for i in range(num_prompts): + img_idx = i // rep + combined_prompts.append(prompts[i]) + combined_start_paths.append(image_start_paths[img_idx]) + if combined_end_paths is not None: + combined_end_paths.append(image_end_paths[img_idx]) else: - image_end = [convert_image(image_end) ] - if len(image_start) != len(image_end): - gr.Info("The number of start and end images should be the same ") - return - - if inputs["multi_images_gen_type"] == 0: - new_prompts = [] - new_image_start = [] - new_image_end = [] - for i in range(len(prompts) * len(image_start) ): - new_prompts.append( prompts[ i % len(prompts)] ) - new_image_start.append(image_start[i // len(prompts)] ) - if image_end != None: - new_image_end.append(image_end[i // len(prompts)] ) - prompts = new_prompts - image_start = new_image_start - if image_end != None: - image_end = new_image_end - else: - if len(prompts) >= len(image_start): - if len(prompts) % len(image_start) != 0: - raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - rep = len(prompts) // len(image_start) - new_image_start = [] - new_image_end = [] - for i, _ in enumerate(prompts): - new_image_start.append(image_start[i//rep] ) - if image_end != None: - new_image_end.append(image_end[i//rep] ) - image_start = new_image_start - if image_end != None: - image_end = new_image_end - else: - if len(image_start) % len(prompts) !=0: - raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - rep = len(image_start) // len(prompts) - new_prompts = [] - for i, _ in enumerate(image_start): - new_prompts.append( prompts[ i//rep] ) - prompts = new_prompts + if num_images % num_prompts != 0: + gr.Error("If more images than prompts (matching type), image count must be multiple of prompt count.") + return gr.update() + rep = num_images // num_prompts + for i in range(num_images): + prompt_idx = i // rep + combined_prompts.append(prompts[prompt_idx]) + combined_start_paths.append(image_start_paths[i]) + if combined_end_paths is not None: + combined_end_paths.append(image_end_paths[i]) - - if image_start == None: - image_start = [None] * len(prompts) - if image_end == None: - image_end = [None] * len(prompts) + # Create tasks + for i, single_prompt in enumerate(combined_prompts): + task_params = inputs.copy() + task_params.update({ + "prompt": single_prompt, + "image_start": combined_start_paths[i], # Pass single path + "image_end": combined_end_paths[i] if combined_end_paths is not None else None, # Pass single path or None + }) + # Ensure multi_images_gen_type doesn't cause issues later if it was 0/1 + task_params["multi_images_gen_type"] = -1 # Indicate already processed + add_video_task(**task_params) + tasks_added += 1 - for single_prompt, start, end in zip(prompts, image_start, image_end) : - extra_inputs = { - "prompt" : single_prompt, - "image_start": start, - "image_end" : end, - } - inputs.update(extra_inputs) - add_video_task(**inputs) - else: - for single_prompt in prompts : - extra_inputs = { - "prompt" : single_prompt, - } - inputs.update(extra_inputs) - add_video_task(**inputs) + else: # Text to Video (no image inputs specific to generation) + for single_prompt in prompts: + task_params 
= inputs.copy() + task_params.update({"prompt": single_prompt}) + add_video_task(**task_params) + tasks_added += 1 + # --- Update queue UI --- gen = get_gen_info(state) - gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0) + gen["prompts_max"] = tasks_added + gen.get("prompts_max", 0) state["validate_success"] = 1 - queue= gen.get("queue", []) + queue = gen.get("queue", []) return update_queue_data(queue) @@ -259,31 +288,93 @@ def add_video_task(**inputs): queue = gen["queue"] task_id += 1 current_task_id = task_id - inputs_to_query = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] - start_image_data = None - end_image_data = None - for name in inputs_to_query: - image= inputs.get(name, None) - if image != None: - image= [image] if not isinstance(image, list) else image - if start_image_data == None: - start_image_data = image - else: - end_image_data = image - break - queue.append({ + # --- Identify image paths from inputs --- + # Use .get() for safety + start_image_paths = inputs.get("image_start") # Could be single path or list + end_image_paths = inputs.get("image_end") # Could be single path or list + ref_image_paths = inputs.get("image_refs") # Could be list or None + + # Standardize to lists or None + if start_image_paths and not isinstance(start_image_paths, list): + start_image_paths = [start_image_paths] + if end_image_paths and not isinstance(end_image_paths, list): + end_image_paths = [end_image_paths] + # ref_image_paths is likely already a list if present + + # Prioritize which images to show as previews in the queue UI + # Typically start/end for I2V, refs for VACE? Or just first available? + primary_preview_paths = None + secondary_preview_paths = None + + if start_image_paths: + primary_preview_paths = start_image_paths + if end_image_paths: + secondary_preview_paths = end_image_paths + elif ref_image_paths: + primary_preview_paths = ref_image_paths + # Add logic for video previews if needed (e.g., video_guide) + + # --- Generate Base64 previews from paths --- + start_image_data_base64 = [] + if primary_preview_paths: + try: + # Load only the first image for preview if it's a list + path_to_load = primary_preview_paths[0] + if path_to_load and Path(path_to_load).is_file(): + loaded_image = Image.open(path_to_load) + b64 = pil_to_base64_uri(loaded_image, format="jpeg", quality=70) + if b64: + start_image_data_base64.append(b64) + else: + print(f"Warning: Primary preview image path not found or invalid: {path_to_load}") + start_image_data_base64.append(None) # Add placeholder if needed + except Exception as e: + print(f"Warning: Could not load primary preview image for UI: {e}") + start_image_data_base64.append(None) + + end_image_data_base64 = [] + if secondary_preview_paths: + try: + path_to_load = secondary_preview_paths[0] + if path_to_load and Path(path_to_load).is_file(): + loaded_image = Image.open(path_to_load) + b64 = pil_to_base64_uri(loaded_image, format="jpeg", quality=70) + if b64: + end_image_data_base64.append(b64) + else: + print(f"Warning: Secondary preview image path not found or invalid: {path_to_load}") + end_image_data_base64.append(None) + except Exception as e: + print(f"Warning: Could not load secondary preview image for UI: {e}") + end_image_data_base64.append(None) + + + # --- Prepare params for the queue (ensure paths are stored) --- + params_copy = inputs.copy() + # Remove state object before storing + if 'state' in params_copy: + del params_copy['state'] + # Ensure image keys contain the paths as received + # (No 
need to explicitly set image_start_paths etc. if inputs already has them correctly) + + queue_item = { "id": current_task_id, - "params": inputs.copy(), - "repeats": inputs["repeat_generation"], - "length": inputs["video_length"], - "steps": inputs["num_inference_steps"], - "prompt": inputs["prompt"], - "start_image_data": start_image_data, - "end_image_data": end_image_data, - "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, - "end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None - }) + "params": params_copy, # Contains paths for image_start, image_end, image_refs etc. + "repeats": inputs.get("repeat_generation", 1), + "length": inputs.get("video_length"), + "steps": inputs.get("num_inference_steps"), + "prompt": inputs.get("prompt"), + # Store the base64 previews separately for the UI + "start_image_data_base64": start_image_data_base64 if start_image_data_base64 else None, + "end_image_data_base64": end_image_data_base64 if end_image_data_base64 else None, + # Keep original paths in params for saving/loading consistency if needed, + # but the generate_video function will use the primary keys like 'image_start' + # "start_image_paths_ref": start_image_paths, # Example if needed for save/load + # "end_image_paths_ref": end_image_paths, # Example if needed for save/load + } + + queue.append(queue_item) return update_queue_data(queue) def move_up(queue, selected_indices): @@ -326,7 +417,258 @@ def remove_task(queue, selected_indices): del queue[idx] return update_queue_data(queue) +def maybe_trigger_processing(should_start, state): + if should_start: + yield from maybe_start_processing(state) + else: + gen = get_gen_info(state) + last_msg = gen.get("last_msg", "Idle") + yield last_msg +def maybe_start_processing(state, progress=gr.Progress()): + gen = get_gen_info(state) + queue = gen.get("queue", []) + in_progress = gen.get("in_progress", False) + initial_status = gen.get("last_msg", "Idle") + if queue and not in_progress: + initial_status = "Starting automatic processing..." + yield initial_status + try: + for status_update in process_tasks(state, progress): + print(f"*** Yielding from process_tasks: '{status_update}' ***") + yield status_update + print(f"*** Finished iterating process_tasks normally. 
***") + except Exception as e: + print(f"*** Error during maybe_start_processing -> process_tasks: {e} ***") + yield f"Error during processing: {str(e)}" + else: + last_msg = gen.get("last_msg", "Idle") + initial_status = last_msg + yield initial_status + +def save_queue_to_json(queue, filename=QUEUE_FILENAME): + """Saves the current task queue to a JSON file.""" + tasks_to_save = [] + max_id = 0 + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + + params_to_save = task.get('params', {}).copy() + + # --- REMOVE THESE INCORRECT setdefault calls --- + # params_to_save.setdefault('prompt', '') + # params_to_save.setdefault('repeats', task.get('repeats', 1)) # NO + # params_to_save.setdefault('length', task.get('length')) # NO + # params_to_save.setdefault('steps', task.get('steps')) # NO + # params_to_save.setdefault('model_filename', '') # NO + + # --- INSTEAD: Check if essential model_filename exists --- + if 'model_filename' not in params_to_save or not params_to_save['model_filename']: + print(f"Warning: Skipping task {task.get('id')} during save due to missing model_filename in params.") + continue # Don't save tasks without essential info + + # Remove non-serializable items explicitly if any remain + params_to_save.pop('state', None) + # Remove potentially large PIL objects if they slipped through + keys_to_remove = [k for k, v in params_to_save.items() if isinstance(v, Image.Image)] + for k in keys_to_remove: del params_to_save[k] + + + task_data = { + "id": task.get('id', 0), + "params": params_to_save, # This should now be clean + # Get these values from the original task dict OR the params dict as fallback + "repeats": task.get('repeats', params_to_save.get('repeat_generation', 1)), + "length": task.get('length', params_to_save.get('video_length')), + "steps": task.get('steps', params_to_save.get('num_inference_steps')), + "prompt": task.get('prompt', params_to_save.get('prompt', '')), + # Keep base64 for fast UI reload without reading files + "start_image_data_base64": task.get("start_image_data_base64"), + "end_image_data_base64": task.get("end_image_data_base64"), + } + tasks_to_save.append(task_data) + max_id = max(max_id, task_data["id"]) + + try: + # Use ensure_ascii=False for wider character support, though prompt sanitization helps + with open(filename, 'w', encoding='utf-8') as f: + json.dump(tasks_to_save, f, indent=4, ensure_ascii=False) + print(f"Queue saved successfully to {filename}") + return max_id + except Exception as e: + print(f"Error saving queue to {filename}: {e}") + gr.Warning(f"Failed to save queue: {e}") + return max_id + +def load_queue_from_json(filename=QUEUE_FILENAME): + """Loads tasks from a JSON file back into the queue format.""" + global task_id # To update the global counter + if not Path(filename).is_file(): + print(f"Queue file {filename} not found. 
Starting with empty queue.") + return [], 0 + try: + with open(filename, 'r', encoding='utf-8') as f: + loaded_tasks_data = json.load(f) + except Exception as e: + print(f"Error loading or parsing queue file {filename}: {e}") + gr.Warning(f"Failed to load queue: {e}") + return [], 0 + + reconstructed_queue = [] + max_id = 0 + print(f"Loading {len(loaded_tasks_data)} tasks from {filename}...") + + for task_data in loaded_tasks_data: + if task_data is None or not isinstance(task_data, dict): continue + + params = task_data.get('params', {}) + if not params or 'model_filename' not in params: + print(f"Skipping task {task_data.get('id')} due to missing params or model_filename.") + continue + + task_id_loaded = task_data.get('id', 0) + max_id = max(max_id, task_id_loaded) + + # Get base64 previews directly from saved data + start_image_data_base64 = task_data.get('start_image_data_base64') + end_image_data_base64 = task_data.get('end_image_data_base64') + + # --- Verify paths in params still exist (optional but good practice) --- + image_path_keys = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] + for key in image_path_keys: + paths = params.get(key) + if paths: + if isinstance(paths, list): + valid_paths = [p for p in paths if p and Path(p).exists()] + if len(valid_paths) != len(paths): + print(f"Warning: Some paths for '{key}' in loaded task {task_id_loaded} not found. Using only valid ones.") + params[key] = valid_paths # Update params with only existing paths + elif isinstance(paths, str): + if not Path(paths).exists(): + print(f"Warning: Path for '{key}' in loaded task {task_id_loaded} not found: {paths}. Setting to None.") + params[key] = None + # --- + + queue_item = { + "id": task_id_loaded, + "params": params, # Contains paths and all other settings + "repeats": task_data.get('repeats', params.get('repeat_generation', 1)), # Get from task_data or params + "length": task_data.get('length', params.get('video_length')), + "steps": task_data.get('steps', params.get('num_inference_steps')), + "prompt": task_data.get('prompt', params.get('prompt', '')), + # Store base64 previews for UI + "start_image_data_base64": start_image_data_base64, + "end_image_data_base64": end_image_data_base64, + # 'start_image_data' and 'end_image_data' (PIL) are not stored/loaded + } + reconstructed_queue.append(queue_item) + + print(f"Queue loaded successfully from {filename}. Max ID found: {max_id}") + # Update global task_id if needed (handled in load_queue_action) + return reconstructed_queue, max_id + +def save_queue_action(state): + gen = get_gen_info(state) + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is empty. 
Nothing to save.") + return None + tasks_to_save = [] + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + params_copy = task.get('params', {}).copy() + if 'state' in params_copy: + del params_copy['state'] + task_data = { + "id": task.get('id', 0), + "image2video": task.get('image2video', False), + "params": params_copy, + "repeats": task.get('repeats', 1), + "length": task.get('length', 0), + "steps": task.get('steps', 0), + "prompt": task.get('prompt', ''), + "start_image_paths": task.get('start_image_paths', []), + "end_image_path": task.get('end_image_path', None), + } + tasks_to_save.append(task_data) + try: + json_string = json.dumps(tasks_to_save, indent=4) + print("Queue data prepared as JSON string for client-side download.") + return json_string + except Exception as e: + print(f"Error converting queue to JSON string: {e}") + gr.Warning(f"Failed to prepare queue data for saving: {e}") + return None + +def load_queue_action(filepath, state_dict): + global task_id + if not filepath or not Path(filepath.name).is_file(): + gr.Warning(f"No file selected or file not found.") + return None + loaded_queue, max_id = load_queue_from_json(filepath.name) + + with lock: + gen = get_gen_info(state_dict) + + if "queue" not in gen or not isinstance(gen.get("queue"), list): + gen["queue"] = [] + + existing_queue = gen["queue"] + existing_queue.clear() + existing_queue.extend(loaded_queue) + + task_id = max(task_id, max_id) + gen["prompts_max"] = len(existing_queue) + + gr.Info(f"Queue loaded from {Path(filepath.name).name}") + return None + +def update_queue_ui_after_load(state_dict): + gen = get_gen_info(state_dict) + queue = gen.get("queue", []) + raw_data = get_queue_table(queue) + is_visible = len(raw_data) > 0 + return gr.update(value=raw_data, visible=is_visible), gr.update(visible=is_visible) + +def clear_queue_action(state): + gen = get_gen_info(state) + with lock: + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is already empty.") + return update_queue_data([]) + + queue.clear() + gen["prompts_max"] = 0 + gr.Info("Queue cleared.") + return update_queue_data([]) + +def autoload_queue(state_dict): + global task_id + gen = get_gen_info(state_dict) + queue_changed = False + if Path(QUEUE_FILENAME).exists(): + print(f"Autoloading queue from {QUEUE_FILENAME}...") + if not gen["queue"]: + loaded_queue, max_id = load_queue_from_json(QUEUE_FILENAME) + if loaded_queue: + with lock: + gen["queue"] = loaded_queue + task_id = max(task_id, max_id) + gen["prompts_max"] = len(loaded_queue) + queue_changed = True + +def autosave_queue(): + print("Attempting to autosave queue on exit...") + global global_dict + + if global_dict: + print(f"Autosaving queue ({len(global_dict)} items) from state dict...") + save_queue_to_json(global_dict, QUEUE_FILENAME) + else: + print("Queue is empty in the determined active state dictionary, autosave skipped.") def get_queue_table(queue): data = [] @@ -397,9 +739,9 @@ def update_queue_data(queue): # else: # return gr.HTML(value=data, visible= True) if len(data) == 0: - return gr.DataFrame(visible=False) + return gr.update(value=[], visible=False) else: - return gr.DataFrame(value=data, visible= True) + return gr.update(value=data, visible=True) def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom" @@ -1132,7 +1474,6 @@ def load_i2v_model(model_filename, value): def load_models(model_filename): global transformer_filename 
- transformer_filename = model_filename download_models(model_filename, text_encoder_filename) if test_class_i2v(model_filename): @@ -1431,9 +1772,20 @@ def finalize_generation(state): time.sleep(0.2) global gen_in_progress gen_in_progress = False - return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="") + queue = gen.get("queue", []) + queue_is_visible = bool(queue) + current_gen_column_visible = queue_is_visible + gen_info_visible = queue_is_visible + return ( + gr.Gallery(selected_index=choice), + gr.Button(interactive=True), + gr.Button(visible=True), + gr.Button(visible=False), + gr.Column(visible=current_gen_column_visible), + gr.HTML(visible=gen_info_visible, value="") + ) def refresh_gallery_on_trigger(state): gen = get_gen_info(state) @@ -1587,6 +1939,33 @@ def generate_video( trans = wan_model.model temp_filename = None + + loaded_image_start_pil = None + loaded_image_end_pil = None + loaded_image_refs_pil = [] + + try: + if image_start and isinstance(image_start, str) and Path(image_start).is_file(): + loaded_image_start_pil = convert_image(Image.open(image_start)) + elif image_start: + print(f"Warning: Start image path not found or invalid: {image_start}") + + if image_end and isinstance(image_end, str) and Path(image_end).is_file(): + loaded_image_end_pil = convert_image(Image.open(image_end)) + elif image_end: + print(f"Warning: End image path not found or invalid: {image_end}") + + if image_refs and isinstance(image_refs, list): + valid_ref_paths = [p for p in image_refs if p and Path(p).is_file()] + if len(valid_ref_paths) != len(image_refs): + print("Warning: Some VACE reference image paths were invalid.") + loaded_image_refs_pil = [convert_image(Image.open(p)) for p in valid_ref_paths] + if not loaded_image_refs_pil and "I" in (video_prompt_type or ""): + print("Warning: No valid VACE reference images loaded despite type 'I'.") + + except Exception as e: + print(f"ERROR loading image file: {e}") + raise gr.Error(f"Failed to load input image: {e}") loras = state["loras"] if len(loras) > 0: @@ -1725,13 +2104,12 @@ def generate_video( trans.teacache_skipped_steps = 0 trans.previous_residual_uncond = None trans.previous_residual_cond = None - video_no += 1 if image2video: samples = wan_model.generate( prompt, - image_start, - image_end if image_end != None else None, + loaded_image_start_pil, + loaded_image_end_pil if loaded_image_end_pil != None else None, frame_num=(video_length // 4)* 4 + 1, max_area=MAX_AREA_CONFIGS[resolution_reformated], shift=flow_shift, @@ -1958,7 +2336,7 @@ def process_tasks(state, progress=gr.Progress()): task = queue[0] task_id = task["id"] params = task['params'] - iterator = iter(generate_video(task_id, progress, **params)) + iterator = iter(generate_video(task_id, progress, **params, state=state)) while True: try: ok = False @@ -2324,41 +2702,83 @@ def switch_advanced(state, new_advanced, lset_name): def prepare_inputs_dict(target, inputs ): - - state = inputs.pop("state") - loras = state["loras"] - if "loras_choices" in inputs: - loras_choices = inputs.pop("loras_choices") - inputs.pop("model_filename", None) - activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ] - inputs["activated_loras"] = activated_loras + """Prepares the inputs dictionary for saving state, settings, or metadata.""" + + # Make a copy to avoid modifying the original dict + inputs_copy = inputs.copy() + + # --- Remove target 
early --- + inputs_copy.pop("target", None) # Remove the target key itself + + # Remove objects not suitable for JSON/saving + state = inputs_copy.pop("state", None) # Remove state object + + # Lora handling: activated_loras should already be a list of names/stems + # If loras_choices was passed (from UI save), convert it + if "loras_choices" in inputs_copy and state and "loras" in state: + loras_choices = inputs_copy.pop("loras_choices") + loras = state.get("loras", []) + try: + # Use basename to match how activated_loras is likely stored in settings + activated_lora_names = [os.path.basename(loras[int(no)]) for no in loras_choices] + inputs_copy["activated_loras"] = activated_lora_names + except (IndexError, ValueError, TypeError) as e: + print(f"Warning: Could not convert loras_choices to names: {e}") + inputs_copy["activated_loras"] = [] # Default to empty on error + elif "activated_loras" not in inputs_copy: + inputs_copy["activated_loras"] = [] # Ensure key exists + + + # --- Target-specific adjustments --- if target == "state": - return inputs + # For saving to the main state dict, we want the raw inputs as received + # including file paths. Target is already removed. + return inputs_copy - unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] - for k in unsaved_params: - inputs.pop(k) + # For settings and metadata, remove non-serializable or large data (like paths?) + # Keep paths for settings so they can be reloaded, but maybe not metadata? + + # Remove PIL objects if they accidentally got passed (shouldn't happen now) + keys_to_remove = [] + for k, v in inputs_copy.items(): + if isinstance(v, Image.Image): + keys_to_remove.append(k) + print(f"Warning: Removing unexpected PIL Image object for key '{k}' during {target} preparation.") + for k in keys_to_remove: + del inputs_copy[k] - model_filename = state["model_filename"] - inputs["type"] = "Wan2.1GP by DeepBeepMeep - " + get_model_name(model_filename) if target == "settings": - return inputs - - if not any(k in model_filename for k in ["image2video", "Fun_InP"]): - inputs.pop("image_prompt_type") + # Keep file paths (image_start, image_end, image_refs, video_guide, video_mask) + # Ensure prompt key exists if 'prompts' was used in UI defaults + if "prompts" in inputs_copy and "prompt" not in inputs_copy: + inputs_copy["prompt"] = inputs_copy.pop("prompts") + # Convert image_prompt_type back to S/SE if needed? Or keep string? String is fine. + return inputs_copy + elif target == "metadata": + # Add type information + # Get model filename safely + model_filename = inputs_copy.get("model_filename") + if not model_filename and state: + model_filename = state.get("model_filename", "unknown") + elif not model_filename: + model_filename = "unknown" - if not "Vace" in model_filename: - unsaved_params = ["video_prompt_type", "max_frames", "remove_background_image_ref"] - for k in unsaved_params: - inputs.pop(k) + inputs_copy["type"] = f"WanGP by DeepBeepMeep - {get_model_name(model_filename)}" - if target == "metadata": - inputs = {k: v for k,v in inputs.items() if v != None } + # Remove file paths from metadata? Or keep them? Keep for reproducibility. + # Remove None values? + metadata_dict = {k: v for k, v in inputs_copy.items() if v is not None} - return inputs + # Clean up keys not relevant for metadata? 
+ metadata_dict.pop("multi_images_gen_type", None) # This was processing logic + + return metadata_dict + + # Should not reach here if target is valid + return inputs_copy def get_function_arguments(func, locals): args_names = list(inspect.signature(func).parameters) @@ -2388,7 +2808,7 @@ def save_inputs( image_prompt_type, image_start, image_end, - video_prompt_type, + video_prompt_type, image_refs, video_guide, video_mask, @@ -2397,7 +2817,7 @@ def save_inputs( temporal_upsampling, spatial_upsampling, RIFLEx_setting, - slg_switch, + slg_switch, slg_layers, slg_start_perc, slg_end_perc, @@ -2406,22 +2826,27 @@ def save_inputs( state, ): - - # if state.get("validate_success",0) != 1: - # return - model_filename = state["model_filename"] - inputs = get_function_arguments(save_inputs, locals()) - inputs.pop("target") - cleaned_inputs = prepare_inputs_dict(target, inputs) + current_locals = locals() + inputs_dict = get_function_arguments(save_inputs, current_locals) + cleaned_inputs = prepare_inputs_dict(target, inputs_dict) + model_filename = state.get("model_filename") + if not model_filename: + print("Warning: Cannot save inputs, model_filename not found in state.") + return + if target == "settings": defaults_filename = get_settings_file_name(model_filename) + try: + with open(defaults_filename, "w", encoding="utf-8") as f: + json.dump(cleaned_inputs, f, indent=4) + gr.Info(f"Default Settings saved for {get_model_name(model_filename)}") + except Exception as e: + print(f"Error saving settings to {defaults_filename}: {e}") + gr.Error(f"Failed to save settings: {e}") - with open(defaults_filename, "w", encoding="utf-8") as f: - json.dump(cleaned_inputs, f, indent=4) - - gr.Info("New Default Settings saved") elif target == "state": - state[get_model_type(model_filename)] = cleaned_inputs + model_type_key = get_model_type(model_filename) + state[model_type_key] = cleaned_inputs def download_loras(): from huggingface_hub import snapshot_download @@ -2648,23 +3073,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if args.multiple_images: image_start = gr.Gallery( - label="Images as starting points for new videos", type ="pil", #file_types= "image", + label="Images as starting points for new videos", type ="filepath", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) else: - image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) + image_start = gr.Image(label= "Image as a starting point for a new video", type ="filepath",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) if args.multiple_images: image_end = gr.Gallery( - label="Images as ending points for new videos", type ="pil", #file_types= "image", + label="Images as ending points for new videos", type ="filepath", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) else: - image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + image_end = gr.Image(label= "Last Image for a new video", type ="filepath", visible="E" in image_prompt_type_value, value= 
ui_defaults.get("image_end", None)) with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","I") video_prompt_type = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =video_prompt_type_value, label="Location", show_label= False, scale= 3) image_refs = gr.Gallery( - label="Reference Images of Faces and / or Object to be found in the Video", type ="pil", + label="Reference Images of Faces and / or Object to be found in the Video", type ="filepath", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None) ) video_guide = gr.Video(label= "Reference Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None) ) @@ -2886,7 +3311,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non queue_df = gr.DataFrame( headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""], datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], - column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"], + column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "3%"], interactive=False, col_count=(9, "fixed"), wrap=True, @@ -2895,7 +3320,68 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non visible= False, elem_id="queue_df" ) - + with gr.Row(): + queue_json_output = gr.Text(visible=False, label="_queue_json") + save_queue_btn = gr.DownloadButton("Save Queue") + load_queue_btn = gr.UploadButton("Load Queue", file_types=[".json"], type="filepath") + clear_queue_btn = gr.Button("Clear Queue") + trigger_download_js = """ + (jsonString) => { + if (!jsonString) { + console.log("No JSON data received, skipping download."); + return; + } + const blob = new Blob([jsonString], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'queue.json'; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } + """ + if not update_form: + save_queue_btn.click( + fn=save_queue_action, + inputs=[state], + outputs=[queue_json_output] + ).then( + fn=None, + inputs=[queue_json_output], + outputs=None, + js=trigger_download_js + ) + load_queue_btn.upload( + fn=load_queue_action, + inputs=[load_queue_btn, state], + outputs=None + ).then( + fn=update_queue_ui_after_load, + inputs=[state], + outputs=[queue_df, current_gen_column], + ).then( + fn=maybe_start_processing, + inputs=[state], + outputs=[gen_status], + show_progress="minimal", + trigger_mode="always_last" + ).then( + fn=finalize_generation, + inputs=[state], + outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info], + trigger_mode="always_last" + ) + clear_queue_btn.click( + clear_queue_action, + inputs=[state], + outputs=[queue_df] + ).then( + fn=lambda: gr.update(visible=False), + inputs=None, + outputs=[current_gen_column] + ) extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column, if update_form: @@ -2976,6 +3462,10 @@ def 
generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then(unload_model_if_needed, inputs= [state], outputs= [] + ).then( + fn=lambda state_dict: gr.update(visible=bool(get_gen_info(state_dict).get("queue", []))), + inputs=[state], + outputs=[current_gen_column] ) add_to_queue_btn.click(fn=validate_wizard_prompt, @@ -2990,6 +3480,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then( fn=update_status, inputs = [state], + ).then( + fn=lambda state_dict: gr.update(visible=bool(get_gen_info(state_dict).get("queue", []))), + inputs=[state], + outputs=[current_gen_column] ) close_modal_button.click( @@ -2998,7 +3492,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non outputs=[modal_container] ) - return loras_choices, lset_name, state + return ( + loras_choices, + lset_name, + state, + queue_df, + current_gen_column, + gen_status, + output, + abort_btn, + generate_btn, + add_to_queue_btn, + gen_info, + prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var, + prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, + advanced_row, image_prompt_column, video_prompt_column, + *prompt_vars + ) def generate_download_tab(lset_name,loras_choices, state): with gr.Row(): @@ -3220,10 +3730,6 @@ def generate_info_tab(): gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear") gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") - - - - def create_demo(): css = """ .title-with-lines { @@ -3287,16 +3793,13 @@ def create_demo(): vertical-align: middle; font-size:11px; } - #xqueue_df table { + #queue_df table { width: 100%; - overflow: hidden !important; } - #xqueue_df::-webkit-scrollbar { - display: none !important; - } - #xqueue_df { + #queue_df { scrollbar-width: none !important; - -ms-overflow-style: none !important; + overflow-x: hidden !important; + overflow-y: auto; } .selection-button { display: none; @@ -3474,8 +3977,15 @@ def create_demo(): ) gr.Markdown("
") with gr.Row(): - - loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header) + ( + loras_choices, lset_name, state, queue_df, current_gen_column, + gen_status, output, abort_btn, generate_btn, add_to_queue_btn, + gen_info, + prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var, + prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, + advanced_row, image_prompt_column, video_prompt_column, + *prompt_vars_outputs + ) = generate_video_tab(model_choice=model_choice, header=header) with gr.Tab("Informations"): generate_info_tab() if not args.lock_config: @@ -3485,11 +3995,54 @@ def create_demo(): generate_configuration_tab() with gr.Tab("About"): generate_about_tab() + def run_autoload_and_update(current_state): + autoload_queue(current_state) + gen = get_gen_info(current_state) + queue = gen.get("queue", []) + global global_dict + global_dict = queue + raw_data = get_queue_table(queue) + is_visible = len(raw_data) > 0 + should_start_processing = bool(queue) + df_update = gr.update(value=raw_data, visible=is_visible) + col_update = gr.update(visible=is_visible) + return ( + df_update, + col_update, + should_start_processing + ) + + should_start_flag = gr.State(False) + + load_outputs_ui = [ + queue_df, + current_gen_column, + should_start_flag, + ] + + demo.load( + fn=run_autoload_and_update, + inputs=[state], + outputs=load_outputs_ui + ).then( + fn=maybe_trigger_processing, + inputs=[should_start_flag, state], + outputs=[gen_status], + ).then( + fn=finalize_generation, + inputs=[state], + outputs=[ + output, abort_btn, generate_btn, add_to_queue_btn, + current_gen_column, gen_info + ], + trigger_mode="always_last" + ) return demo if __name__ == "__main__": # threading.Thread(target=runner, daemon=True).start() + atexit.register(autosave_queue) os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" server_port = int(args.server_port) if os.name == "nt":