diff --git a/gradio_server.py b/gradio_server.py index 883f7e7..38cba9e 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -40,6 +40,10 @@ task_id = 0 # progress_tracker = {} # tracker_lock = threading.Lock() last_model_type = None +QUEUE_FILENAME = "queue.json" +t2v_state_ref = None +i2v_state_ref = None +global_state_ref = None def format_time(seconds): if seconds < 60: @@ -150,93 +154,120 @@ def process_prompt_and_add_tasks( return if not image2video: - if "Vace" in file_model_needed and "1.3B" in file_model_needed : - resolution_reformated = str(height) + "*" + str(width) - if not resolution_reformated in VACE_SIZE_CONFIGS: - res = VACE_SIZE_CONFIGS.keys().join(" and ") - gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.") - return - if not "I" in image_prompt_type: image_source1 = None if not "V" in image_prompt_type: image_source2 = None if not "M" in image_prompt_type: image_source3 = None - - if isinstance(image_source1, list): - image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] - - from wan.utils.utils import resize_and_remove_background - image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1) - - image_source1 = [ image_source1 ] * len(prompts) - image_source2 = [ image_source2 ] * len(prompts) - image_source3 = [ image_source3 ] * len(prompts) - else: - if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0: - return + if image_source1 == None or (isinstance(image_source1, list) and len(image_source1) == 0): + gr.Info("Image 2 Video requires at least one start image.") + return if image_prompt_type == 0: image_source2 = None - if isinstance(image_source1, list): - image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] - else: - image_source1 = [convert_image(image_source1)] - if image_source2 != None: - if isinstance(image_source2 , list): - image_source2 = [ convert_image(tup[0]) for tup in image_source2 ] - else: - image_source2 = [convert_image(image_source2) ] - if len(image_source1) != len(image_source2): - gr.Info("The number of start and end images should be the same ") - return - + if image_source1 and not isinstance(image_source1, list): + image_source1 = [image_source1] + elif not image_source1: + image_source1 = [] + if image_source2 and isinstance(image_source2, list): + if len(image_source2) > 1: + gr.Info("Multiple end images selected, but only the first will be used for standard I2V end image.") + image_source2 = image_source2[0] + elif len(image_source2) == 1: + image_source2 = image_source2[0] + else: + image_source2 = None + + if image_source2 != None and image_prompt_type == 1: + if len(image_source1) != 1: + gr.Info("When using an end image, please provide only one start image.") + if multi_images_gen_type == 0: new_prompts = [] new_image_source1 = [] - new_image_source2 = [] - for i in range(len(prompts) * len(image_source1) ): - new_prompts.append( prompts[ i % len(prompts)] ) - new_image_source1.append(image_source1[i // len(prompts)] ) - if image_source2 != None: - new_image_source2.append(image_source2[i // len(prompts)] ) + new_image_source2_list = [] + num_prompts = len(prompts) + num_images = len(image_source1) + + for i in range(num_prompts * num_images): + prompt_idx = i % num_prompts + image_idx = i // num_prompts + new_prompts.append(prompts[prompt_idx]) + new_image_source1.append(image_source1[image_idx]) + if image_source2 != None and image_prompt_type == 1: + new_image_source2_list.append(image_source2) + else: + new_image_source2_list.append(None) + prompts = new_prompts - image_source1 = new_image_source1 - if image_source2 != None: - image_source2 = new_image_source2 - else: + image_source1 = new_image_source1 + + elif multi_images_gen_type == 1: if len(prompts) >= len(image_source1): - if len(prompts) % len(image_source1) !=0: - raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - rep = len(prompts) // len(image_source1) - new_image_source1 = [] - new_image_source2 = [] - for i, _ in enumerate(prompts): - new_image_source1.append(image_source1[i//rep] ) - if image_source2 != None: - new_image_source2.append(image_source2[i//rep] ) - image_source1 = new_image_source1 - if image_source2 != None: - image_source2 = new_image_source2 - else: - if len(image_source1) % len(prompts) !=0: - raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - rep = len(image_source1) // len(prompts) - new_prompts = [] - for i, _ in enumerate(image_source1): - new_prompts.append( prompts[ i//rep] ) - prompts = new_prompts + if len(prompts) % len(image_source1) != 0: + raise gr.Error("If more prompts than images (matching type), prompt count must be multiple of image count.") + rep = len(prompts) // len(image_source1) + new_image_source1 = [] + new_image_source2_list = [] + for i in range(len(prompts)): + img_idx = i // rep + new_image_source1.append(image_source1[img_idx]) + if image_source2 != None and image_prompt_type == 1 and isinstance(image_source2, list) and img_idx < len(image_source2): + new_image_source2_list.append(image_source2[img_idx]) + elif image_source2 != None and image_prompt_type == 1 and not isinstance(image_source2, list): + new_image_source2_list.append(image_source2) + else: + new_image_source2_list.append(None) + image_source1 = new_image_source1 + else: + if len(image_source1) % len(prompts) != 0: + raise gr.Error("If more images than prompts (matching type), image count must be multiple of prompt count.") + rep = len(image_source1) // len(prompts) + new_prompts = [] + new_image_source2_list = [] + for i in range(len(image_source1)): + prompt_idx = i // rep + new_prompts.append(prompts[prompt_idx]) + if image_source2 != None and image_prompt_type == 1 and isinstance(image_source2, list) and i < len(image_source2): + new_image_source2_list.append(image_source2[i]) + elif image_source2 != None and image_prompt_type == 1 and not isinstance(image_source2, list): + new_image_source2_list.append(image_source2) + else: + new_image_source2_list.append(None) + prompts = new_prompts - - if image_source1 == None: - image_source1 = [None] * len(prompts) - if image_source2 == None: - image_source2 = [None] * len(prompts) - if image_source3 == None: - image_source3 = [None] * len(prompts) + img1_list = image_source1 if isinstance(image_source1, list) else ([image_source1] if image_source1 else [None]) + img2_list_for_zip = [] + if multi_images_gen_type == 0 and image2video and image_prompt_type == 1: + img2_list_for_zip = new_image_source2_list + elif multi_images_gen_type == 1 and image2video and image_prompt_type == 1: + img2_list_for_zip = new_image_source2_list + elif image_source2 and not image2video: + img2_list_for_zip = [image_source2] * len(prompts) + elif image_source2 and image2video and image_prompt_type == 1: + img2_list_for_zip = [image_source2] * len(prompts) + else: + img2_list_for_zip = [None] * len(prompts) + + if len(img1_list) == 1 and len(prompts) > 1: + img1_list = img1_list * len(prompts) + elif len(img1_list) != len(prompts): + gr.Warning(f"Mismatch between number of prompts ({len(prompts)}) and start images ({len(img1_list)}). Using first image for remaining prompts.") + img1_list = (img1_list + [img1_list[0]] * (len(prompts) - len(img1_list)))[:len(prompts)] + img3_list_for_zip = [image_source3] * len(prompts) if image_source3 and not image2video else [None] * len(prompts) + + for i, single_prompt in enumerate(prompts): + current_image_source1 = img1_list[i] + current_image_source2 = img2_list_for_zip[i] + current_image_source3 = img3_list_for_zip[i] + if current_image_source1 and not isinstance(current_image_source1, list): + current_image_source1_list = [current_image_source1] + elif isinstance(current_image_source1, list): + current_image_source1_list = current_image_source1 + else: + current_image_source1_list = [] - for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) : kwargs = { "prompt" : single_prompt, "negative_prompt" : negative_prompt, @@ -254,9 +285,9 @@ def process_prompt_and_add_tasks( "loras_choices" : loras_choices, "loras_mult_choices" : loras_mult_choices, "image_prompt_type" : image_prompt_type, - "image_source1": image_source1, - "image_source2" : image_source2, - "image_source3" : image_source3 , + "image_source1": current_image_source1_list, + "image_source2" : current_image_source2, + "image_source3" : current_image_source3, "max_frames" : max_frames, "remove_background_image_ref" : remove_background_image_ref, "temporal_upsampling" : temporal_upsampling, @@ -270,7 +301,7 @@ def process_prompt_and_add_tasks( "cfg_zero_step" : cfg_zero_step, "state" : state, "image2video" : image2video - } + } add_video_task(**kwargs) gen = get_gen_info(state) @@ -279,9 +310,6 @@ def process_prompt_and_add_tasks( queue= gen.get("queue", []) return update_queue_data(queue) - - - def add_video_task(**kwargs): global task_id state = kwargs["state"] @@ -289,23 +317,50 @@ def add_video_task(**kwargs): queue = gen["queue"] task_id += 1 current_task_id = task_id - start_image_data = kwargs["image_source1"] - start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data - end_image_data = kwargs["image_source2"] + start_image_paths = kwargs["image_source1"] + if start_image_paths and not isinstance(start_image_paths, list): + start_image_paths = [start_image_paths] + elif start_image_paths is None: + start_image_paths = [] + end_image_path = kwargs["image_source2"] + start_image_data_base64 = [] + if start_image_paths: + try: + loaded_images = [Image.open(p) for p in start_image_paths if p and Path(p).is_file()] + start_image_data_base64 = [pil_to_base64_uri(img, format="jpeg", quality=70) for img in loaded_images] + except Exception as e: + print(f"Warning: Could not load start image(s) for UI preview: {e}") + start_image_data_base64 = [None] * len(start_image_paths) - queue.append({ + end_image_data_base64 = None + if end_image_path and Path(end_image_path).is_file(): + try: + loaded_end_image = Image.open(end_image_path) + end_image_data_base64 = pil_to_base64_uri(loaded_end_image, format="jpeg", quality=70) + except Exception as e: + print(f"Warning: Could not load end image for UI preview: {e}") + + kwargs["image_source1"] = start_image_paths + kwargs["image_source2"] = end_image_path + params_copy = kwargs.copy() + if 'state' in params_copy: + del params_copy['state'] + + queue_item = { "id": current_task_id, "image2video": kwargs["image2video"], - "params": kwargs.copy(), + "params": params_copy, "repeats": kwargs["repeat_generation"], "length": kwargs["video_length"], "steps": kwargs["num_inference_steps"], "prompt": kwargs["prompt"], - "start_image_data": start_image_data, - "end_image_data": end_image_data, - "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data], - "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70) - }) + "start_image_paths": start_image_paths, + "end_image_path": end_image_path, + "start_image_data_base64": start_image_data_base64, + "end_image_data_base64": end_image_data_base64 + } + + queue.append(queue_item) return update_queue_data(queue) def move_up(queue, selected_indices): @@ -347,22 +402,257 @@ def remove_task(queue, selected_indices): wan_model._interrupt = True del queue[idx] return update_queue_data(queue) + +def maybe_trigger_processing(should_start, state): + if should_start: + yield from maybe_start_processing(state) + else: + gen = get_gen_info(state) + last_msg = gen.get("last_msg", "Idle") + yield last_msg +def maybe_start_processing(state, progress=gr.Progress()): + gen = get_gen_info(state) + queue = gen.get("queue", []) + in_progress = gen.get("in_progress", False) + initial_status = gen.get("last_msg", "Idle") + if queue and not in_progress: + initial_status = "Starting automatic processing..." + yield initial_status + try: + for status_update in process_tasks(state, progress): + print(f"*** Yielding from process_tasks: '{status_update}' ***") + yield status_update + print(f"*** Finished iterating process_tasks normally. ***") + except Exception as e: + print(f"*** Error during maybe_start_processing -> process_tasks: {e} ***") + yield f"Error during processing: {str(e)}" + else: + last_msg = gen.get("last_msg", "Idle") + initial_status = last_msg + yield initial_status +def save_queue_to_json(queue, filename=QUEUE_FILENAME): + tasks_to_save = [] + max_id = 0 + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + params_copy = task.get('params', {}).copy() + if 'state' in params_copy: + del params_copy['state'] + + task_data = { + "id": task.get('id', 0), + "image2video": task.get('image2video', False), + "params": params_copy, + "repeats": task.get('repeats', 1), + "length": task.get('length', 0), + "steps": task.get('steps', 0), + "prompt": task.get('prompt', ''), + "start_image_paths": task.get('start_image_paths', []), + "end_image_path": task.get('end_image_path', None), + } + tasks_to_save.append(task_data) + max_id = max(max_id, task_data["id"]) + + try: + with open(filename, 'w', encoding='utf-8') as f: + json.dump(tasks_to_save, f, indent=4) + print(f"Queue saved successfully to {filename}") + return max_id + except Exception as e: + print(f"Error saving queue to {filename}: {e}") + gr.Warning(f"Failed to save queue: {e}") + return max_id + +def load_queue_from_json(filename=QUEUE_FILENAME): + global task_id + if not Path(filename).is_file(): + print(f"Queue file {filename} not found. Starting with empty queue.") + return [], 0 + try: + with open(filename, 'r', encoding='utf-8') as f: + loaded_tasks = json.load(f) + except Exception as e: + print(f"Error loading or parsing queue file {filename}: {e}") + gr.Warning(f"Failed to load queue: {e}") + return [], 0 + + reconstructed_queue = [] + max_id = 0 + print(f"Loading {len(loaded_tasks)} tasks from {filename}...") + for task_data in loaded_tasks: + if task_data is None or not isinstance(task_data, dict): continue + + start_image_paths = task_data.get('start_image_paths', []) + end_image_path = task_data.get('end_image_path', None) + + start_image_data_base64 = [] + if start_image_paths: + try: + valid_paths = [p for p in start_image_paths if p and Path(p).is_file()] + if len(valid_paths) != len(start_image_paths): + print(f"Warning: Some start image paths in loaded task {task_data.get('id')} not found.") + loaded_images = [Image.open(p) for p in valid_paths] + start_image_data_base64 = [pil_to_base64_uri(img, format="jpeg", quality=70) for img in loaded_images] + except Exception as e: + print(f"Warning: Could not load start image(s) for UI preview from loaded task {task_data.get('id')}: {e}") + start_image_data_base64 = [None] * len(start_image_paths) + + end_image_data_base64 = None + if end_image_path and Path(end_image_path).is_file(): + try: + loaded_end_image = Image.open(end_image_path) + end_image_data_base64 = pil_to_base64_uri(loaded_end_image, format="jpeg", quality=70) + except Exception as e: + print(f"Warning: Could not load end image for UI preview from loaded task {task_data.get('id')}: {e}") + elif end_image_path: + print(f"Warning: End image path in loaded task {task_data.get('id')} not found: {end_image_path}") + + params = task_data.get('params', {}) + params['image_source1'] = start_image_paths + params['image_source2'] = end_image_path + + task_id_loaded = task_data.get('id', 0) + max_id = max(max_id, task_id_loaded) + + queue_item = { + "id": task_id_loaded, + "image2video": task_data.get('image2video', False), + "params": params, + "repeats": task_data.get('repeats', 1), + "length": task_data.get('length', 0), + "steps": task_data.get('steps', 0), + "prompt": task_data.get('prompt', ''), + "start_image_paths": start_image_paths, + "end_image_path": end_image_path, + "start_image_data_base64": start_image_data_base64, + "end_image_data_base64": end_image_data_base64 + } + reconstructed_queue.append(queue_item) + + print(f"Queue loaded successfully from {filename}. Max ID found: {max_id}") + return reconstructed_queue, max_id + +def save_queue_action(state): + gen = get_gen_info(state) + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is empty. Nothing to save.") + return None + tasks_to_save = [] + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + params_copy = task.get('params', {}).copy() + if 'state' in params_copy: + del params_copy['state'] + task_data = { + "id": task.get('id', 0), + "image2video": task.get('image2video', False), + "params": params_copy, + "repeats": task.get('repeats', 1), + "length": task.get('length', 0), + "steps": task.get('steps', 0), + "prompt": task.get('prompt', ''), + "start_image_paths": task.get('start_image_paths', []), + "end_image_path": task.get('end_image_path', None), + } + tasks_to_save.append(task_data) + try: + json_string = json.dumps(tasks_to_save, indent=4) + print("Queue data prepared as JSON string for client-side download.") + return json_string + except Exception as e: + print(f"Error converting queue to JSON string: {e}") + gr.Warning(f"Failed to prepare queue data for saving: {e}") + return None + +def load_queue_action(filepath, state_dict): + global task_id + if not filepath or not Path(filepath.name).is_file(): + gr.Warning(f"No file selected or file not found.") + return None + loaded_queue, max_id = load_queue_from_json(filepath.name) + + with lock: + gen = get_gen_info(state_dict) + + if "queue" not in gen or not isinstance(gen.get("queue"), list): + gen["queue"] = [] + + existing_queue = gen["queue"] + existing_queue.clear() + existing_queue.extend(loaded_queue) + + task_id = max(task_id, max_id) + gen["prompts_max"] = len(existing_queue) + + gr.Info(f"Queue loaded from {Path(filepath.name).name}") + return None + +def update_queue_ui_after_load(state_dict): + gen = get_gen_info(state_dict) + queue = gen.get("queue", []) + raw_data = get_queue_table(queue) + is_visible = len(raw_data) > 0 + return gr.update(value=raw_data, visible=is_visible), gr.update(visible=is_visible) + +def clear_queue_action(state): + gen = get_gen_info(state) + with lock: + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is already empty.") + return update_queue_data([]) + + queue.clear() + gen["prompts_max"] = 0 + gr.Info("Queue cleared.") + return update_queue_data([]) + +def autoload_queue(state_dict): + global task_id + gen = get_gen_info(state_dict) + queue_changed = False + if Path(QUEUE_FILENAME).exists(): + print(f"Autoloading queue from {QUEUE_FILENAME}...") + if not gen["queue"]: + loaded_queue, max_id = load_queue_from_json(QUEUE_FILENAME) + if loaded_queue: + with lock: + gen["queue"] = loaded_queue + task_id = max(task_id, max_id) + gen["prompts_max"] = len(loaded_queue) + queue_changed = True + +def autosave_queue(): + print("Attempting to autosave queue on exit...") + if global_state_ref is None or t2v_state_ref is None or i2v_state_ref is None: + print("State references not available for autosave. Skipping.") + return + try: + last_tab_was_i2v = global_state_ref.get("last_tab_was_image2video", use_image2video) + active_state = i2v_state_ref if last_tab_was_i2v else t2v_state_ref + + if active_state: + gen = get_gen_info(active_state) + queue = gen.get("queue", []) + if queue: + print(f"Autosaving queue ({len(queue)} items) from {'i2v' if last_tab_was_i2v else 't2v'} state...") + save_queue_to_json(queue, QUEUE_FILENAME) + else: + print("Queue is empty, autosave skipped.") + else: + print("Could not determine active state for autosave.") + except Exception as e: + print(f"Error during autosave: {e}") def get_queue_table(queue): data = [] if len(queue) == 1: - return data - - # def td(l, content, width =None): - # if width !=None: - # l.append("
| Qty | Prompt | Steps | |||||