From 3b066849de0aa226f581a37d7743a0ddcc36dfa9 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 6 Apr 2025 04:12:14 +1000 Subject: [PATCH 1/6] remove horizontal scrollbar, fix css issues on start/end frame thumbnails --- gradio_server.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index 7f4e345..366cf39 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -2805,7 +2805,7 @@ def generate_video_tab(image2video=False): queue_df = gr.DataFrame( headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""], datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], - column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"], + column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "3%"], interactive=False, col_count=(9, "fixed"), wrap=True, @@ -3375,16 +3375,13 @@ def create_demo(): vertical-align: middle; font-size:11px; } - #xqueue_df table { + #queue_df table { width: 100%; overflow: hidden !important; } - #xqueue_df::-webkit-scrollbar { - display: none !important; - } - #xqueue_df { - scrollbar-width: none !important; - -ms-overflow-style: none !important; + #queue_df { + overflow-x: hidden !important; + overflow-y: auto; } .selection-button { display: none; From aba3002b83f02b200e1bce786e4d531705629bbb Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 6 Apr 2025 04:32:57 +1000 Subject: [PATCH 2/6] fix gen referenced before assignment bug when switching tabs --- gradio_server.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index 366cf39..883f7e7 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -3240,11 +3240,18 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData): global_state["last_tab_was_image2video"] = new_i2v - if(server_config.get("reload_model",2) == 1): + if server_config.get("reload_model", 2) == 1: queue = gen.get("queue", []) + queue_empty = len(queue) == 0 - queue_empty = len(queue) == 0 - if queue_empty: + is_switching_between_gen_tabs = ( + last_tab_was_image2video is not None and + (new_t2v or new_i2v) and + last_tab_was_image2video != new_i2v + ) + + if is_switching_between_gen_tabs and queue_empty: + print("Reloading model due to switch between T2V/I2V tabs.") global wan_model, offloadobj if wan_model is not None: if offloadobj is not None: From 1d3f4a2573d954927238897ef6d3093df099add6 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 6 Apr 2025 17:14:36 +1000 Subject: [PATCH 3/6] add queue saving/loading/clearing --- gradio_server.py | 800 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 643 insertions(+), 157 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index 883f7e7..38cba9e 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -40,6 +40,10 @@ task_id = 0 # progress_tracker = {} # tracker_lock = threading.Lock() last_model_type = None +QUEUE_FILENAME = "queue.json" +t2v_state_ref = None +i2v_state_ref = None +global_state_ref = None def format_time(seconds): if seconds < 60: @@ -150,93 +154,120 @@ def process_prompt_and_add_tasks( return if not image2video: - if "Vace" in file_model_needed and "1.3B" in file_model_needed : - resolution_reformated = str(height) + "*" + str(width) - if not resolution_reformated in VACE_SIZE_CONFIGS: - res = VACE_SIZE_CONFIGS.keys().join(" and ") - gr.Info(f"Video Resolution for Vace model is not supported. 
Only {res} resolutions are allowed.") - return - if not "I" in image_prompt_type: image_source1 = None if not "V" in image_prompt_type: image_source2 = None if not "M" in image_prompt_type: image_source3 = None - - if isinstance(image_source1, list): - image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] - - from wan.utils.utils import resize_and_remove_background - image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1) - - image_source1 = [ image_source1 ] * len(prompts) - image_source2 = [ image_source2 ] * len(prompts) - image_source3 = [ image_source3 ] * len(prompts) - else: - if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0: - return + if image_source1 == None or (isinstance(image_source1, list) and len(image_source1) == 0): + gr.Info("Image 2 Video requires at least one start image.") + return if image_prompt_type == 0: image_source2 = None - if isinstance(image_source1, list): - image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] - else: - image_source1 = [convert_image(image_source1)] - if image_source2 != None: - if isinstance(image_source2 , list): - image_source2 = [ convert_image(tup[0]) for tup in image_source2 ] - else: - image_source2 = [convert_image(image_source2) ] - if len(image_source1) != len(image_source2): - gr.Info("The number of start and end images should be the same ") - return - + if image_source1 and not isinstance(image_source1, list): + image_source1 = [image_source1] + elif not image_source1: + image_source1 = [] + if image_source2 and isinstance(image_source2, list): + if len(image_source2) > 1: + gr.Info("Multiple end images selected, but only the first will be used for standard I2V end image.") + image_source2 = image_source2[0] + elif len(image_source2) == 1: + image_source2 = image_source2[0] + else: + image_source2 = None + + if image_source2 != None and image_prompt_type == 1: + if len(image_source1) != 1: + gr.Info("When using an end image, please provide only one start image.") + if multi_images_gen_type == 0: new_prompts = [] new_image_source1 = [] - new_image_source2 = [] - for i in range(len(prompts) * len(image_source1) ): - new_prompts.append( prompts[ i % len(prompts)] ) - new_image_source1.append(image_source1[i // len(prompts)] ) - if image_source2 != None: - new_image_source2.append(image_source2[i // len(prompts)] ) + new_image_source2_list = [] + num_prompts = len(prompts) + num_images = len(image_source1) + + for i in range(num_prompts * num_images): + prompt_idx = i % num_prompts + image_idx = i // num_prompts + new_prompts.append(prompts[prompt_idx]) + new_image_source1.append(image_source1[image_idx]) + if image_source2 != None and image_prompt_type == 1: + new_image_source2_list.append(image_source2) + else: + new_image_source2_list.append(None) + prompts = new_prompts - image_source1 = new_image_source1 - if image_source2 != None: - image_source2 = new_image_source2 - else: + image_source1 = new_image_source1 + + elif multi_images_gen_type == 1: if len(prompts) >= len(image_source1): - if len(prompts) % len(image_source1) !=0: - raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - rep = len(prompts) // len(image_source1) - new_image_source1 = [] - new_image_source2 = [] - for i, _ in enumerate(prompts): - new_image_source1.append(image_source1[i//rep] ) - if image_source2 != None: - 
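
(A minimal sketch of the pairing rule enforced here and re-implemented by the rewrite below, assuming plain lists: each image is repeated rep = len(prompts) // len(images) times, so prompt i pairs with image i // rep. Names are illustrative.)

    prompts = ["p1", "p2", "p3", "p4"]
    images = ["a.png", "b.png"]
    rep = len(prompts) // len(images)  # 2; len(prompts) must be an exact multiple
    pairs = [(p, images[i // rep]) for i, p in enumerate(prompts)]
    # -> [("p1", "a.png"), ("p2", "a.png"), ("p3", "b.png"), ("p4", "b.png")]
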
new_image_source2.append(image_source2[i//rep] ) - image_source1 = new_image_source1 - if image_source2 != None: - image_source2 = new_image_source2 - else: - if len(image_source1) % len(prompts) !=0: - raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - rep = len(image_source1) // len(prompts) - new_prompts = [] - for i, _ in enumerate(image_source1): - new_prompts.append( prompts[ i//rep] ) - prompts = new_prompts + if len(prompts) % len(image_source1) != 0: + raise gr.Error("If more prompts than images (matching type), prompt count must be multiple of image count.") + rep = len(prompts) // len(image_source1) + new_image_source1 = [] + new_image_source2_list = [] + for i in range(len(prompts)): + img_idx = i // rep + new_image_source1.append(image_source1[img_idx]) + if image_source2 != None and image_prompt_type == 1 and isinstance(image_source2, list) and img_idx < len(image_source2): + new_image_source2_list.append(image_source2[img_idx]) + elif image_source2 != None and image_prompt_type == 1 and not isinstance(image_source2, list): + new_image_source2_list.append(image_source2) + else: + new_image_source2_list.append(None) + image_source1 = new_image_source1 + else: + if len(image_source1) % len(prompts) != 0: + raise gr.Error("If more images than prompts (matching type), image count must be multiple of prompt count.") + rep = len(image_source1) // len(prompts) + new_prompts = [] + new_image_source2_list = [] + for i in range(len(image_source1)): + prompt_idx = i // rep + new_prompts.append(prompts[prompt_idx]) + if image_source2 != None and image_prompt_type == 1 and isinstance(image_source2, list) and i < len(image_source2): + new_image_source2_list.append(image_source2[i]) + elif image_source2 != None and image_prompt_type == 1 and not isinstance(image_source2, list): + new_image_source2_list.append(image_source2) + else: + new_image_source2_list.append(None) + prompts = new_prompts - - if image_source1 == None: - image_source1 = [None] * len(prompts) - if image_source2 == None: - image_source2 = [None] * len(prompts) - if image_source3 == None: - image_source3 = [None] * len(prompts) + img1_list = image_source1 if isinstance(image_source1, list) else ([image_source1] if image_source1 else [None]) + img2_list_for_zip = [] + if multi_images_gen_type == 0 and image2video and image_prompt_type == 1: + img2_list_for_zip = new_image_source2_list + elif multi_images_gen_type == 1 and image2video and image_prompt_type == 1: + img2_list_for_zip = new_image_source2_list + elif image_source2 and not image2video: + img2_list_for_zip = [image_source2] * len(prompts) + elif image_source2 and image2video and image_prompt_type == 1: + img2_list_for_zip = [image_source2] * len(prompts) + else: + img2_list_for_zip = [None] * len(prompts) + + if len(img1_list) == 1 and len(prompts) > 1: + img1_list = img1_list * len(prompts) + elif len(img1_list) != len(prompts): + gr.Warning(f"Mismatch between number of prompts ({len(prompts)}) and start images ({len(img1_list)}). 
Using first image for remaining prompts.") + img1_list = (img1_list + [img1_list[0]] * (len(prompts) - len(img1_list)))[:len(prompts)] + img3_list_for_zip = [image_source3] * len(prompts) if image_source3 and not image2video else [None] * len(prompts) + + for i, single_prompt in enumerate(prompts): + current_image_source1 = img1_list[i] + current_image_source2 = img2_list_for_zip[i] + current_image_source3 = img3_list_for_zip[i] + if current_image_source1 and not isinstance(current_image_source1, list): + current_image_source1_list = [current_image_source1] + elif isinstance(current_image_source1, list): + current_image_source1_list = current_image_source1 + else: + current_image_source1_list = [] - for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) : kwargs = { "prompt" : single_prompt, "negative_prompt" : negative_prompt, @@ -254,9 +285,9 @@ def process_prompt_and_add_tasks( "loras_choices" : loras_choices, "loras_mult_choices" : loras_mult_choices, "image_prompt_type" : image_prompt_type, - "image_source1": image_source1, - "image_source2" : image_source2, - "image_source3" : image_source3 , + "image_source1": current_image_source1_list, + "image_source2" : current_image_source2, + "image_source3" : current_image_source3, "max_frames" : max_frames, "remove_background_image_ref" : remove_background_image_ref, "temporal_upsampling" : temporal_upsampling, @@ -270,7 +301,7 @@ def process_prompt_and_add_tasks( "cfg_zero_step" : cfg_zero_step, "state" : state, "image2video" : image2video - } + } add_video_task(**kwargs) gen = get_gen_info(state) @@ -279,9 +310,6 @@ def process_prompt_and_add_tasks( queue= gen.get("queue", []) return update_queue_data(queue) - - - def add_video_task(**kwargs): global task_id state = kwargs["state"] @@ -289,23 +317,50 @@ def add_video_task(**kwargs): queue = gen["queue"] task_id += 1 current_task_id = task_id - start_image_data = kwargs["image_source1"] - start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data - end_image_data = kwargs["image_source2"] + start_image_paths = kwargs["image_source1"] + if start_image_paths and not isinstance(start_image_paths, list): + start_image_paths = [start_image_paths] + elif start_image_paths is None: + start_image_paths = [] + end_image_path = kwargs["image_source2"] + start_image_data_base64 = [] + if start_image_paths: + try: + loaded_images = [Image.open(p) for p in start_image_paths if p and Path(p).is_file()] + start_image_data_base64 = [pil_to_base64_uri(img, format="jpeg", quality=70) for img in loaded_images] + except Exception as e: + print(f"Warning: Could not load start image(s) for UI preview: {e}") + start_image_data_base64 = [None] * len(start_image_paths) - queue.append({ + end_image_data_base64 = None + if end_image_path and Path(end_image_path).is_file(): + try: + loaded_end_image = Image.open(end_image_path) + end_image_data_base64 = pil_to_base64_uri(loaded_end_image, format="jpeg", quality=70) + except Exception as e: + print(f"Warning: Could not load end image for UI preview: {e}") + + kwargs["image_source1"] = start_image_paths + kwargs["image_source2"] = end_image_path + params_copy = kwargs.copy() + if 'state' in params_copy: + del params_copy['state'] + + queue_item = { "id": current_task_id, "image2video": kwargs["image2video"], - "params": kwargs.copy(), + "params": params_copy, "repeats": kwargs["repeat_generation"], "length": kwargs["video_length"], "steps": 
kwargs["num_inference_steps"], "prompt": kwargs["prompt"], - "start_image_data": start_image_data, - "end_image_data": end_image_data, - "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data], - "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70) - }) + "start_image_paths": start_image_paths, + "end_image_path": end_image_path, + "start_image_data_base64": start_image_data_base64, + "end_image_data_base64": end_image_data_base64 + } + + queue.append(queue_item) return update_queue_data(queue) def move_up(queue, selected_indices): @@ -347,22 +402,257 @@ def remove_task(queue, selected_indices): wan_model._interrupt = True del queue[idx] return update_queue_data(queue) + +def maybe_trigger_processing(should_start, state): + if should_start: + yield from maybe_start_processing(state) + else: + gen = get_gen_info(state) + last_msg = gen.get("last_msg", "Idle") + yield last_msg +def maybe_start_processing(state, progress=gr.Progress()): + gen = get_gen_info(state) + queue = gen.get("queue", []) + in_progress = gen.get("in_progress", False) + initial_status = gen.get("last_msg", "Idle") + if queue and not in_progress: + initial_status = "Starting automatic processing..." + yield initial_status + try: + for status_update in process_tasks(state, progress): + print(f"*** Yielding from process_tasks: '{status_update}' ***") + yield status_update + print(f"*** Finished iterating process_tasks normally. ***") + except Exception as e: + print(f"*** Error during maybe_start_processing -> process_tasks: {e} ***") + yield f"Error during processing: {str(e)}" + else: + last_msg = gen.get("last_msg", "Idle") + initial_status = last_msg + yield initial_status +def save_queue_to_json(queue, filename=QUEUE_FILENAME): + tasks_to_save = [] + max_id = 0 + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + params_copy = task.get('params', {}).copy() + if 'state' in params_copy: + del params_copy['state'] + + task_data = { + "id": task.get('id', 0), + "image2video": task.get('image2video', False), + "params": params_copy, + "repeats": task.get('repeats', 1), + "length": task.get('length', 0), + "steps": task.get('steps', 0), + "prompt": task.get('prompt', ''), + "start_image_paths": task.get('start_image_paths', []), + "end_image_path": task.get('end_image_path', None), + } + tasks_to_save.append(task_data) + max_id = max(max_id, task_data["id"]) + + try: + with open(filename, 'w', encoding='utf-8') as f: + json.dump(tasks_to_save, f, indent=4) + print(f"Queue saved successfully to {filename}") + return max_id + except Exception as e: + print(f"Error saving queue to {filename}: {e}") + gr.Warning(f"Failed to save queue: {e}") + return max_id + +def load_queue_from_json(filename=QUEUE_FILENAME): + global task_id + if not Path(filename).is_file(): + print(f"Queue file {filename} not found. 
Starting with empty queue.") + return [], 0 + try: + with open(filename, 'r', encoding='utf-8') as f: + loaded_tasks = json.load(f) + except Exception as e: + print(f"Error loading or parsing queue file {filename}: {e}") + gr.Warning(f"Failed to load queue: {e}") + return [], 0 + + reconstructed_queue = [] + max_id = 0 + print(f"Loading {len(loaded_tasks)} tasks from {filename}...") + for task_data in loaded_tasks: + if task_data is None or not isinstance(task_data, dict): continue + + start_image_paths = task_data.get('start_image_paths', []) + end_image_path = task_data.get('end_image_path', None) + + start_image_data_base64 = [] + if start_image_paths: + try: + valid_paths = [p for p in start_image_paths if p and Path(p).is_file()] + if len(valid_paths) != len(start_image_paths): + print(f"Warning: Some start image paths in loaded task {task_data.get('id')} not found.") + loaded_images = [Image.open(p) for p in valid_paths] + start_image_data_base64 = [pil_to_base64_uri(img, format="jpeg", quality=70) for img in loaded_images] + except Exception as e: + print(f"Warning: Could not load start image(s) for UI preview from loaded task {task_data.get('id')}: {e}") + start_image_data_base64 = [None] * len(start_image_paths) + + end_image_data_base64 = None + if end_image_path and Path(end_image_path).is_file(): + try: + loaded_end_image = Image.open(end_image_path) + end_image_data_base64 = pil_to_base64_uri(loaded_end_image, format="jpeg", quality=70) + except Exception as e: + print(f"Warning: Could not load end image for UI preview from loaded task {task_data.get('id')}: {e}") + elif end_image_path: + print(f"Warning: End image path in loaded task {task_data.get('id')} not found: {end_image_path}") + + params = task_data.get('params', {}) + params['image_source1'] = start_image_paths + params['image_source2'] = end_image_path + + task_id_loaded = task_data.get('id', 0) + max_id = max(max_id, task_id_loaded) + + queue_item = { + "id": task_id_loaded, + "image2video": task_data.get('image2video', False), + "params": params, + "repeats": task_data.get('repeats', 1), + "length": task_data.get('length', 0), + "steps": task_data.get('steps', 0), + "prompt": task_data.get('prompt', ''), + "start_image_paths": start_image_paths, + "end_image_path": end_image_path, + "start_image_data_base64": start_image_data_base64, + "end_image_data_base64": end_image_data_base64 + } + reconstructed_queue.append(queue_item) + + print(f"Queue loaded successfully from {filename}. Max ID found: {max_id}") + return reconstructed_queue, max_id + +def save_queue_action(state): + gen = get_gen_info(state) + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is empty. 
Nothing to save.") + return None + tasks_to_save = [] + with lock: + for task in queue: + if task is None or not isinstance(task, dict): continue + params_copy = task.get('params', {}).copy() + if 'state' in params_copy: + del params_copy['state'] + task_data = { + "id": task.get('id', 0), + "image2video": task.get('image2video', False), + "params": params_copy, + "repeats": task.get('repeats', 1), + "length": task.get('length', 0), + "steps": task.get('steps', 0), + "prompt": task.get('prompt', ''), + "start_image_paths": task.get('start_image_paths', []), + "end_image_path": task.get('end_image_path', None), + } + tasks_to_save.append(task_data) + try: + json_string = json.dumps(tasks_to_save, indent=4) + print("Queue data prepared as JSON string for client-side download.") + return json_string + except Exception as e: + print(f"Error converting queue to JSON string: {e}") + gr.Warning(f"Failed to prepare queue data for saving: {e}") + return None + +def load_queue_action(filepath, state_dict): + global task_id + if not filepath or not Path(filepath.name).is_file(): + gr.Warning(f"No file selected or file not found.") + return None + loaded_queue, max_id = load_queue_from_json(filepath.name) + + with lock: + gen = get_gen_info(state_dict) + + if "queue" not in gen or not isinstance(gen.get("queue"), list): + gen["queue"] = [] + + existing_queue = gen["queue"] + existing_queue.clear() + existing_queue.extend(loaded_queue) + + task_id = max(task_id, max_id) + gen["prompts_max"] = len(existing_queue) + + gr.Info(f"Queue loaded from {Path(filepath.name).name}") + return None + +def update_queue_ui_after_load(state_dict): + gen = get_gen_info(state_dict) + queue = gen.get("queue", []) + raw_data = get_queue_table(queue) + is_visible = len(raw_data) > 0 + return gr.update(value=raw_data, visible=is_visible), gr.update(visible=is_visible) + +def clear_queue_action(state): + gen = get_gen_info(state) + with lock: + queue = gen.get("queue", []) + if not queue: + gr.Info("Queue is already empty.") + return update_queue_data([]) + + queue.clear() + gen["prompts_max"] = 0 + gr.Info("Queue cleared.") + return update_queue_data([]) + +def autoload_queue(state_dict): + global task_id + gen = get_gen_info(state_dict) + queue_changed = False + if Path(QUEUE_FILENAME).exists(): + print(f"Autoloading queue from {QUEUE_FILENAME}...") + if not gen["queue"]: + loaded_queue, max_id = load_queue_from_json(QUEUE_FILENAME) + if loaded_queue: + with lock: + gen["queue"] = loaded_queue + task_id = max(task_id, max_id) + gen["prompts_max"] = len(loaded_queue) + queue_changed = True + +def autosave_queue(): + print("Attempting to autosave queue on exit...") + if global_state_ref is None or t2v_state_ref is None or i2v_state_ref is None: + print("State references not available for autosave. 
Skipping.") + return + try: + last_tab_was_i2v = global_state_ref.get("last_tab_was_image2video", use_image2video) + active_state = i2v_state_ref if last_tab_was_i2v else t2v_state_ref + + if active_state: + gen = get_gen_info(active_state) + queue = gen.get("queue", []) + if queue: + print(f"Autosaving queue ({len(queue)} items) from {'i2v' if last_tab_was_i2v else 't2v'} state...") + save_queue_to_json(queue, QUEUE_FILENAME) + else: + print("Queue is empty, autosave skipped.") + else: + print("Could not determine active state for autosave.") + except Exception as e: + print(f"Error during autosave: {e}") def get_queue_table(queue): data = [] if len(queue) == 1: - return data - - # def td(l, content, width =None): - # if width !=None: - # l.append("" + content + "") - # else: - # l.append("" + content + "") - - # data.append("") - + return data for i, item in enumerate(queue): if i==0: continue @@ -381,22 +671,6 @@ def get_queue_table(queue): start_img_md = f'Start' if end_img_uri: end_img_md = f'End' - # if i % 2 == 1: - # data.append("") - # else: - # data.append("") - - # td(data,str(item.get('repeats', "1")) ) - # td(data, prompt_cell, "100%") - # td(data, num_steps, "100%") - # td(data, start_img_md) - # td(data, end_img_md) - # td(data, "↑") - # td(data, "↓") - # td(data, "✖") - # data.append("") - # data.append("
<table><tr><td>Qty</td><td>Prompt</td><td>Steps</td></tr>
") - # return ''.join(data) data.append([item.get('repeats', "1"), prompt_cell, @@ -409,18 +683,13 @@ def get_queue_table(queue): "✖" ]) return data + def update_queue_data(queue): - data = get_queue_table(queue) - - # if len(data) == 0: - # return gr.HTML(visible=False) - # else: - # return gr.HTML(value=data, visible= True) if len(data) == 0: - return gr.DataFrame(visible=False) + return gr.update(value=[], visible=False) else: - return gr.DataFrame(value=data, visible= True) + return gr.update(value=data, visible=True) def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom" @@ -1364,7 +1633,6 @@ def is_gen_location(state): if gen_location == None: return None return state["image2video"] == gen_location - def refresh_gallery(state, msg): gen = get_gen_info(state) @@ -1409,8 +1677,6 @@ def refresh_gallery(state, msg): html_output = gr.HTML(html, visible= True) return gr.Gallery(selected_index=choice, value = file_list), html_output, gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=True), update_queue_data(queue), gr.Button(interactive= abort_interactive) - - def finalize_generation(state): gen = get_gen_info(state) choice = gen.get("selected",0) @@ -1418,15 +1684,27 @@ def finalize_generation(state): del gen["in_progress"] if gen.get("last_selected", True): file_list = gen.get("file_list", []) - choice = len(file_list) - 1 - + choice = max(len(file_list) - 1, 0) if file_list else 0 gen["extra_orders"] = 0 time.sleep(0.2) global gen_in_progress gen_in_progress = False - return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="") + queue = gen.get("queue", []) + queue_is_visible = bool(queue) + + current_gen_column_visible = queue_is_visible + gen_info_visible = queue_is_visible + + return ( + gr.Gallery(selected_index=choice), + gr.Button(interactive=True), + gr.Button(visible=True), + gr.Button(visible=False), + gr.Column(visible=current_gen_column_visible), + gr.HTML(visible=gen_info_visible, value="") + ) def refresh_gallery_on_trigger(state): gen = get_gen_info(state) @@ -1577,6 +1855,47 @@ def generate_video( trans = wan_model.model temp_filename = None + loaded_image_source1_pil = None + loaded_image_source2_pil = None + + image_source1_paths = image_source1 + image_source2_path = image_source2 + + loaded_imgs = [] + + if image_source1_paths: + if not isinstance(image_source1_paths, list): + image_source1_paths = [image_source1_paths] + try: + valid_paths = [p for p in image_source1_paths if p and Path(p).is_file()] + if not valid_paths and image2video: + raise FileNotFoundError(f"Required start image file(s) not found at path(s): {image_source1_paths}") + elif not valid_paths and not image2video: + print(f"Warning: No valid reference images found for VACE at path(s): {image_source1_paths}") + else: + loaded_imgs = [convert_image(Image.open(p)) for p in valid_paths] + if image2video: + if not loaded_imgs: + raise ValueError(f"Could not load the required start image for I2V task from: {image_source1_paths}") + loaded_image_source1_pil = loaded_imgs[0] + if len(loaded_imgs) > 1: + print(f"Warning: Multiple start images ({len(loaded_imgs)}) found for a single I2V task. Using only the first one: {valid_paths[0]}") + + except FileNotFoundError as e: + print(f"Error: Image file not found. 
{e}") + raise gr.Error(f"Required image file not found: {e}") + except Exception as e: + print(f"Error loading start image file(s) {image_source1_paths}: {e}") + raise gr.Error(f"Error loading start image file: {e}") + if image_source2_path and isinstance(image_source2_path, str) and Path(image_source2_path).is_file(): + try: + loaded_image_source2_pil = convert_image(Image.open(image_source2_path)) + except FileNotFoundError: + print(f"Error: End image file not found at path: {image_source2_path}. Proceeding without it.") + except Exception as e: + print(f"Error loading end image file {image_source2_path}: {e}") + elif image_source2_path: + print(f"Warning: End image path provided but not found: {image_source2_path}") loras = state["loras"] if len(loras) > 0: @@ -1653,14 +1972,22 @@ def generate_video( raise gr.Error("Teacache not supported for this model") if "Vace" in model_filename: - resolution_reformated = str(height) + "*" + str(width) - src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2], - [image_source3], - [image_source1], - video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu", - trim_video=max_frames) + resolution_reformated = str(height) + "*" + str(width) + src_video_path = image_source2 # Assuming image_source2 holds the video path for Vace + src_mask_path = image_source3 # Assuming image_source3 holds the mask video path for Vace + src_ref_images_paths = image_source1 # Assuming image_source1 holds the ref image paths for Vace + + # Ensure prepare_source gets paths as expected. If it needs loaded data, load here. + # Let's assume it takes paths for now. + src_video, src_mask, src_ref_images = wan_model.prepare_source( + [src_video_path] if src_video_path else [], + [src_mask_path] if src_mask_path else [], + src_ref_images_paths if src_ref_images_paths else [], # Should be list of paths + video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu", + trim_video=max_frames + ) else: - src_video, src_mask, src_ref_images = None, None, None + src_video, src_mask, src_ref_images = None, None, None import random @@ -1700,6 +2027,7 @@ def generate_video( gen["progress_args"] = progress_args try: + print(f"[{datetime.now()}] generate_video: Entering generation loop, repeat_no={repeat_no}") start_time = time.time() # with tracker_lock: # progress_tracker[task_id] = { @@ -1718,11 +2046,12 @@ def generate_video( trans.previous_residual_cond = None video_no += 1 + print(f"[{datetime.now()}] generate_video: Calling wan_model.generate for seed {seed}...") if image2video: samples = wan_model.generate( prompt, - image_source1, - image_source2 if image_source2 != None else None, + loaded_image_source1_pil, + loaded_image_source2_pil if loaded_image_source2_pil != None else None, frame_num=(video_length // 4)* 4 + 1, max_area=MAX_AREA_CONFIGS[resolution], shift=flow_shift, @@ -1767,7 +2096,9 @@ def generate_video( cfg_zero_step = cfg_zero_step, ) # samples = torch.empty( (1,2)) #for testing + print(f"[{datetime.now()}] generate_video: wan_model.generate completed for seed {seed}. 
Samples is None: {samples is None}") except Exception as e: + print(f"[{datetime.now()}] generate_video: Exception during generation: {e}") if temp_filename!= None and os.path.isfile(temp_filename): os.remove(temp_filename) offload.last_offload_obj.unload_all() @@ -1801,6 +2132,7 @@ def generate_video( raise gr.Error(new_error, print_exception= False) finally: + print(f"[{datetime.now()}] generate_video: Exiting generation loop iteration for seed {seed}") pass # with tracker_lock: # if task_id in progress_tracker: @@ -1979,7 +2311,7 @@ def process_tasks(state, progress=gr.Progress()): task = queue[0] task_id = task["id"] params = task['params'] - iterator = iter(generate_video(task_id, progress, **params)) + iterator = iter(generate_video(task_id, progress, **params, state=state)) while True: try: ok = False @@ -2536,7 +2868,7 @@ def generate_video_tab(image2video=False): delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1) cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) - state = gr.State(state_dict) + state = gr.State(state_dict) vace_model = "Vace" in filename and not image2video trigger_refresh_input_type = gr.Text(interactive= False, visible= False) with gr.Column(visible= image2video or vace_model) as image_prompt_column: @@ -2548,17 +2880,17 @@ def generate_video_tab(image2video=False): if args.multiple_images: image_source1 = gr.Gallery( - label="Images as starting points for new videos", type ="pil", #file_types= "image", + label="Images as starting points for new videos", type ="filepath", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True) else: - image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil") + image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="filepath") if args.multiple_images: image_source2 = gr.Gallery( - label="Images as ending points for new videos", type ="pil", #file_types= "image", + label="Images as ending points for new videos", type ="filepath", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1) else: - image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1) + image_source2 = gr.Image(label= "Last Image for a new video", type ="filepath", visible=image_prompt_type==1) image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2]) @@ -2570,7 +2902,7 @@ def generate_video_tab(image2video=False): image_prompt_type ="I" image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model) image_source1 = gr.Gallery( - label="Reference Images of Faces and / or Object to be found in the Video", type ="pil", + label="Reference Images of Faces and / or Object to be found in the Video", type ="filepath", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type ) image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type ) @@ -2809,10 +3141,9 @@ def generate_video_tab(image2video=False): interactive=False, col_count=(9, "fixed"), wrap=True, - value=[], + value=[], # Set value=[] 
initially line_breaks= True, - visible= False, - # every=1, + visible= False, # Set visible=False initially elem_id="queue_df" ) # queue_df = gr.HTML("", @@ -2820,6 +3151,12 @@ def generate_video_tab(image2video=False): # elem_id="queue_df" # ) + with gr.Row(): + queue_json_output = gr.Text(visible=False, label="_queue_json") + save_queue_btn = gr.DownloadButton("Save Queue") + load_queue_btn = gr.UploadButton("Load Queue", file_types=[".json"], type="filepath") + clear_queue_btn = gr.Button("Clear Queue") + def handle_selection(state, evt: gr.SelectData): gen = get_gen_info(state) queue = gen.get("queue", []) @@ -2885,6 +3222,62 @@ def generate_video_tab(image2video=False): outputs=[gen_progress_html], show_progress="hidden" ) + trigger_download_js = """ + (jsonString) => { + if (!jsonString) { + console.log("No JSON data received, skipping download."); + return; + } + const blob = new Blob([jsonString], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'queue.json'; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } + """ + save_queue_btn.click( + fn=save_queue_action, + inputs=[state], + outputs=[queue_json_output] + ).then( + fn=None, + inputs=[queue_json_output], + outputs=None, + js=trigger_download_js + ) + load_queue_btn.upload( + fn=load_queue_action, + inputs=[load_queue_btn, state], + outputs=None + ).then( + fn=update_queue_ui_after_load, + inputs=[state], + outputs=[queue_df, current_gen_column], + ).then( + fn=maybe_start_processing, + inputs=[state], + outputs=[gen_status], + show_progress="minimal", + trigger_mode="always_last" + ).then( + fn=finalize_generation, + inputs=[state], + outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info], + trigger_mode="always_last" + ) + clear_queue_btn.click( + clear_queue_action, + inputs=[state], + outputs=[queue_df] + ).then( + fn=lambda: gr.update(visible=False), + inputs=None, + outputs=[current_gen_column] + ) save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers, @@ -2984,6 +3377,10 @@ def generate_video_tab(image2video=False): ).then(finalize_generation, inputs= [state], outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] + ).then( + fn=lambda state_dict: gr.update(visible=bool(get_gen_info(state_dict).get("queue", []))), + inputs=[state], + outputs=[current_gen_column] ) add_to_queue_btn.click(fn=validate_wizard_prompt, @@ -2996,15 +3393,35 @@ def generate_video_tab(image2video=False): ).then( fn=update_status, inputs = [state], + ).then( + fn=lambda state_dict: gr.update(visible=bool(get_gen_info(state_dict).get("queue", []))), + inputs=[state], + outputs=[current_gen_column] ) - close_modal_button.click( lambda: gr.update(visible=False), inputs=[], outputs=[modal_container] ) - return loras_column, loras_choices, presets_column, lset_name, header, light_sync, 
full_sync, state + return ( + loras_column, + loras_choices, + presets_column, + lset_name, + header, + light_sync, + full_sync, + state, + queue_df, + current_gen_column, + gen_status, + output, + abort_btn, + generate_btn, + add_to_queue_btn, + gen_info + ) def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state): with gr.Row(): @@ -3204,7 +3621,7 @@ def generate_about_tab(): gr.Markdown("- Remade_AI : for creating their awesome Loras collection") -def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData): +def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) @@ -3384,7 +3801,6 @@ def create_demo(): } #queue_df table { width: 100%; - overflow: hidden !important; } #queue_df { overflow-x: hidden !important; @@ -3553,15 +3969,20 @@ def create_demo(): gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM") gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear") gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") - global_dict = {} - global_dict["last_tab_was_image2video"] = use_image2video - global_state = gr.State(global_dict) with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs: with gr.Tab("Text To Video", id="t2v") as t2v_tab: - t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, t2v_state = generate_video_tab(False) + (t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, + t2v_header, t2v_light_sync, t2v_full_sync, t2v_state, t2v_queue_df, + t2v_current_gen_column, t2v_gen_status, t2v_output, t2v_abort_btn, + t2v_generate_btn, t2v_add_to_queue_btn, t2v_gen_info) = generate_video_tab(False) + t2v_state_ref = t2v_state.value with gr.Tab("Image To Video", id="i2v") as i2v_tab: - i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True) + (i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, + i2v_header, i2v_light_sync, i2v_full_sync, i2v_state, i2v_queue_df, + i2v_current_gen_column, i2v_gen_status, i2v_output, i2v_abort_btn, + i2v_generate_btn, i2v_add_to_queue_btn, i2v_gen_info) = generate_video_tab(True) + i2v_state_ref = i2v_state.value if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state) @@ -3569,9 +3990,73 @@ def create_demo(): generate_configuration_tab() with gr.Tab("About"): generate_about_tab() + def run_autoload_and_update(t2v_st_val, i2v_st_val): + active_state_dict = i2v_st_val if use_image2video else t2v_st_val + target_queue_df_update = gr.skip() + target_column_update = gr.skip() + other_queue_df_update = gr.skip() + other_column_update = gr.skip() + autoload_queue(active_state_dict) + gen = get_gen_info(active_state_dict) + queue = gen.get("queue", []) + raw_data = get_queue_table(queue) + is_visible = len(raw_data) > 0 + should_start_processing = bool(queue) + df_update = gr.update(value=raw_data, visible=is_visible) + col_update = 
gr.update(visible=is_visible) + + if use_image2video: + target_queue_df_update = df_update + target_column_update = col_update + other_queue_df_update = gr.DataFrame(value=[], visible=False) + other_column_update = gr.update(visible=False) + return ( + other_queue_df_update, other_column_update, + target_queue_df_update, target_column_update, + should_start_processing, active_state_dict + ) + else: + target_queue_df_update = df_update + target_column_update = col_update + other_queue_df_update = gr.DataFrame(value=[], visible=False) + other_column_update = gr.update(visible=False) + return ( + target_queue_df_update, target_column_update, + other_queue_df_update, other_column_update, + should_start_processing, active_state_dict + ) + + should_start_flag = gr.State(False) + updated_state_on_load = gr.State({}) + load_outputs_ui = [ + t2v_queue_df, t2v_current_gen_column, + i2v_queue_df, i2v_current_gen_column, + should_start_flag, updated_state_on_load + ] + demo.load( + fn=run_autoload_and_update, + inputs=[t2v_state, i2v_state], + outputs=load_outputs_ui + ).then( + fn=maybe_trigger_processing, + inputs=[should_start_flag, updated_state_on_load], + outputs=[i2v_gen_status if use_image2video else t2v_gen_status], + ).then( + fn=finalize_generation, + inputs=[updated_state_on_load], + outputs=[ + i2v_output if use_image2video else t2v_output, + i2v_abort_btn if use_image2video else t2v_abort_btn, + i2v_generate_btn if use_image2video else t2v_generate_btn, + i2v_add_to_queue_btn if use_image2video else t2v_add_to_queue_btn, + i2v_current_gen_column if use_image2video else t2v_current_gen_column, + i2v_gen_info if use_image2video else t2v_gen_info + ], + trigger_mode="always_last" + ) main_tabs.select( fn=on_tab_select, - inputs=[global_state, t2v_state, i2v_state], + inputs=[t2v_state, i2v_state], outputs=[ t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync @@ -3580,6 +4065,7 @@ def create_demo(): return demo if __name__ == "__main__": + autosave_queue() # threading.Thread(target=runner, daemon=True).start() os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" server_port = int(args.server_port) From 5d8f95b5d616dd82f63265884879d60083e0016b Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 6 Apr 2025 17:35:55 +1000 Subject: [PATCH 4/6] fix double scrollbar styling issue --- gradio_server.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index 38cba9e..d0c9b4e 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -1858,25 +1858,25 @@ def generate_video( loaded_image_source1_pil = None loaded_image_source2_pil = None - image_source1_paths = image_source1 + image_source1_path = image_source1 image_source2_path = image_source2 loaded_imgs = [] - if image_source1_paths: - if not isinstance(image_source1_paths, list): - image_source1_paths = [image_source1_paths] + if image_source1_path: + if not isinstance(image_source1_path, list): + image_source1_path = [image_source1_path] try: - valid_paths = [p for p in image_source1_paths if p and Path(p).is_file()] + valid_paths = [p for p in image_source1_path if p and Path(p).is_file()] if not valid_paths and image2video: - raise FileNotFoundError(f"Required start image file(s) not found at path(s): {image_source1_paths}") + raise FileNotFoundError(f"Required start image file(s) not found at path(s): 
{image_source1_path}") elif not valid_paths and not image2video: - print(f"Warning: No valid reference images found for VACE at path(s): {image_source1_paths}") + print(f"Warning: No valid reference images found for VACE at path(s): {image_source1_path}") else: loaded_imgs = [convert_image(Image.open(p)) for p in valid_paths] if image2video: if not loaded_imgs: - raise ValueError(f"Could not load the required start image for I2V task from: {image_source1_paths}") + raise ValueError(f"Could not load the required start image for I2V task from: {image_source1_path}") loaded_image_source1_pil = loaded_imgs[0] if len(loaded_imgs) > 1: print(f"Warning: Multiple start images ({len(loaded_imgs)}) found for a single I2V task. Using only the first one: {valid_paths[0]}") @@ -1885,7 +1885,7 @@ def generate_video( print(f"Error: Image file not found. {e}") raise gr.Error(f"Required image file not found: {e}") except Exception as e: - print(f"Error loading start image file(s) {image_source1_paths}: {e}") + print(f"Error loading start image file(s) {image_source1_path}: {e}") raise gr.Error(f"Error loading start image file: {e}") if image_source2_path and isinstance(image_source2_path, str) and Path(image_source2_path).is_file(): try: @@ -3803,6 +3803,7 @@ def create_demo(): width: 100%; } #queue_df { + scrollbar-width: none !important; overflow-x: hidden !important; overflow-y: auto; } From f00c6435ae7e0e8a7dd25a3863990f3db0e95276 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 6 Apr 2025 19:43:34 +1000 Subject: [PATCH 5/6] move autosave --- gradio_server.py | 71 ++++++++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index d0c9b4e..dc440a2 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -26,6 +26,7 @@ from wan.utils import prompt_parser import base64 import io from PIL import Image +import atexit PROMPT_VARS_MAX = 10 target_mmgp_version = "3.3.4" @@ -41,9 +42,9 @@ task_id = 0 # tracker_lock = threading.Lock() last_model_type = None QUEUE_FILENAME = "queue.json" -t2v_state_ref = None -i2v_state_ref = None -global_state_ref = None +global_t2v_state_dict = None +global_i2v_state_dict = None +global_state = {} def format_time(seconds): if seconds < 60: @@ -629,25 +630,40 @@ def autoload_queue(state_dict): def autosave_queue(): print("Attempting to autosave queue on exit...") - if global_state_ref is None or t2v_state_ref is None or i2v_state_ref is None: - print("State references not available for autosave. 
Skipping.") - return - try: - last_tab_was_i2v = global_state_ref.get("last_tab_was_image2video", use_image2video) - active_state = i2v_state_ref if last_tab_was_i2v else t2v_state_ref + global global_t2v_state_dict, global_i2v_state_dict, global_state - if active_state: - gen = get_gen_info(active_state) - queue = gen.get("queue", []) - if queue: - print(f"Autosaving queue ({len(queue)} items) from {'i2v' if last_tab_was_i2v else 't2v'} state...") - save_queue_to_json(queue, QUEUE_FILENAME) - else: - print("Queue is empty, autosave skipped.") + active_state_dict = None + last_tab_was_i2v = None + + if global_state and 'last_tab_was_image2video' in global_state: + last_tab_was_i2v = global_state['last_tab_was_image2video'] + active_state_dict = global_i2v_state_dict if last_tab_was_i2v else global_t2v_state_dict + print(f"Using last active tab info: {'i2v' if last_tab_was_i2v else 't2v'}") + else: + print("Last active tab info not found, using fallback logic.") + if use_image2video and global_i2v_state_dict: + active_state_dict = global_i2v_state_dict + last_tab_was_i2v = True + elif not use_image2video and global_t2v_state_dict: + active_state_dict = global_t2v_state_dict + last_tab_was_i2v = False + elif global_i2v_state_dict: + active_state_dict = global_i2v_state_dict + last_tab_was_i2v = True + elif global_t2v_state_dict: + active_state_dict = global_t2v_state_dict + last_tab_was_i2v = False + if active_state_dict: + gen = active_state_dict.get("gen", {}) + queue = gen.get("queue", []) + if queue: + tab_name = 'i2v' if last_tab_was_i2v else 't2v' + print(f"Autosaving queue ({len(queue)} items) from {tab_name} state dict...") + save_queue_to_json(queue, QUEUE_FILENAME) else: - print("Could not determine active state for autosave.") - except Exception as e: - print(f"Error during autosave: {e}") + print("Queue is empty in the determined active state dictionary, autosave skipped.") + else: + print(f"Could not determine active state dictionary for autosave. 
T2V dict exists: {global_t2v_state_dict is not None}, I2V dict exists: {global_i2v_state_dict is not None}, Last active tab known: {last_tab_was_i2v is not None}") def get_queue_table(queue): data = [] @@ -2787,7 +2803,7 @@ def check_refresh_input_type(state): def generate_video_tab(image2video=False): filename = transformer_filename_i2v if image2video else transformer_filename_t2v ui_defaults= get_default_settings(filename, image2video) - + global global_t2v_state_dict, global_i2v_state_dict state_dict = {} state_dict["advanced"] = advanced @@ -2797,6 +2813,10 @@ def generate_video_tab(image2video=False): gen = dict() gen["queue"] = [] state_dict["gen"] = gen + if image2video: + global_i2v_state_dict = state_dict + else: + global_t2v_state_dict = state_dict preset_to_load = lora_preselected_preset if use_image2video == image2video else "" @@ -3622,9 +3642,9 @@ def generate_about_tab(): def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): + global global_state t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) - new_t2v = evt.index == 0 new_i2v = evt.index == 1 i2v_light_sync = gr.Text() @@ -3639,7 +3659,8 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): else: gen = t2v_state["gen"] i2v_state["gen"] = gen - + if new_t2v or new_i2v: + global_state['last_tab_was_image2video'] = new_i2v if new_t2v or new_i2v: if last_tab_was_image2video != None and new_t2v != new_i2v: @@ -3977,13 +3998,11 @@ def create_demo(): t2v_header, t2v_light_sync, t2v_full_sync, t2v_state, t2v_queue_df, t2v_current_gen_column, t2v_gen_status, t2v_output, t2v_abort_btn, t2v_generate_btn, t2v_add_to_queue_btn, t2v_gen_info) = generate_video_tab(False) - t2v_state_ref = t2v_state.value with gr.Tab("Image To Video", id="i2v") as i2v_tab: (i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state, i2v_queue_df, i2v_current_gen_column, i2v_gen_status, i2v_output, i2v_abort_btn, i2v_generate_btn, i2v_add_to_queue_btn, i2v_gen_info) = generate_video_tab(True) - i2v_state_ref = i2v_state.value if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state) @@ -4066,8 +4085,8 @@ def create_demo(): return demo if __name__ == "__main__": - autosave_queue() # threading.Thread(target=runner, daemon=True).start() + atexit.register(autosave_queue) os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" server_port = int(args.server_port) if os.name == "nt": From 541ea19f3fe05071035f4b81f2193aac4fabb54c Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Mon, 7 Apr 2025 07:34:29 +1000 Subject: [PATCH 6/6] fix queue autosaving on exit, restore original global_state accidentally removed for reload_model 1 behaviour --- gradio_server.py | 59 +++++++++++++----------------------------------- 1 file changed, 16 insertions(+), 43 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index dc440a2..5c08f1b 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -42,9 +42,7 @@ task_id = 0 # tracker_lock = threading.Lock() last_model_type = None QUEUE_FILENAME = "queue.json" -global_t2v_state_dict = None -global_i2v_state_dict = None -global_state = {} +global_dict = [] def format_time(seconds): if seconds < 60: @@ -630,40 +628,13 @@ def autoload_queue(state_dict): def autosave_queue(): print("Attempting to autosave queue on 
exit...") - global global_t2v_state_dict, global_i2v_state_dict, global_state + global global_dict - active_state_dict = None - last_tab_was_i2v = None - - if global_state and 'last_tab_was_image2video' in global_state: - last_tab_was_i2v = global_state['last_tab_was_image2video'] - active_state_dict = global_i2v_state_dict if last_tab_was_i2v else global_t2v_state_dict - print(f"Using last active tab info: {'i2v' if last_tab_was_i2v else 't2v'}") + if global_dict: + print(f"Autosaving queue ({len(global_dict)} items) from state dict...") + save_queue_to_json(global_dict, QUEUE_FILENAME) else: - print("Last active tab info not found, using fallback logic.") - if use_image2video and global_i2v_state_dict: - active_state_dict = global_i2v_state_dict - last_tab_was_i2v = True - elif not use_image2video and global_t2v_state_dict: - active_state_dict = global_t2v_state_dict - last_tab_was_i2v = False - elif global_i2v_state_dict: - active_state_dict = global_i2v_state_dict - last_tab_was_i2v = True - elif global_t2v_state_dict: - active_state_dict = global_t2v_state_dict - last_tab_was_i2v = False - if active_state_dict: - gen = active_state_dict.get("gen", {}) - queue = gen.get("queue", []) - if queue: - tab_name = 'i2v' if last_tab_was_i2v else 't2v' - print(f"Autosaving queue ({len(queue)} items) from {tab_name} state dict...") - save_queue_to_json(queue, QUEUE_FILENAME) - else: - print("Queue is empty in the determined active state dictionary, autosave skipped.") - else: - print(f"Could not determine active state dictionary for autosave. T2V dict exists: {global_t2v_state_dict is not None}, I2V dict exists: {global_i2v_state_dict is not None}, Last active tab known: {last_tab_was_i2v is not None}") + print("Queue is empty in the determined active state dictionary, autosave skipped.") def get_queue_table(queue): data = [] @@ -2803,7 +2774,6 @@ def check_refresh_input_type(state): def generate_video_tab(image2video=False): filename = transformer_filename_i2v if image2video else transformer_filename_t2v ui_defaults= get_default_settings(filename, image2video) - global global_t2v_state_dict, global_i2v_state_dict state_dict = {} state_dict["advanced"] = advanced @@ -2813,10 +2783,6 @@ def generate_video_tab(image2video=False): gen = dict() gen["queue"] = [] state_dict["gen"] = gen - if image2video: - global_i2v_state_dict = state_dict - else: - global_t2v_state_dict = state_dict preset_to_load = lora_preselected_preset if use_image2video == image2video else "" @@ -3641,8 +3607,7 @@ def generate_about_tab(): gr.Markdown("- Remade_AI : for creating their awesome Loras collection") -def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): - global global_state +def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData): t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) new_t2v = evt.index == 0 @@ -3653,12 +3618,15 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): t2v_full_sync = gr.Text() last_tab_was_image2video =global_state.get("last_tab_was_image2video", None) + global global_dict if last_tab_was_image2video == None or last_tab_was_image2video: gen = i2v_state["gen"] t2v_state["gen"] = gen + global_dict = gen.get("queue", []) else: gen = t2v_state["gen"] i2v_state["gen"] = gen + global_dict = gen.get("queue", []) if new_t2v or new_i2v: global_state['last_tab_was_image2video'] = new_i2v @@ -3991,6 +3959,9 @@ def create_demo(): gr.Markdown("- 1280 x 720 
with a 14B model: 80 frames (5s): 11 GB of VRAM") gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear") gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") + state_dict = {} + state_dict["last_tab_was_image2video"] = use_image2video + global_state = gr.State(state_dict) with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs: with gr.Tab("Text To Video", id="t2v") as t2v_tab: @@ -4019,6 +3990,8 @@ def create_demo(): autoload_queue(active_state_dict) gen = get_gen_info(active_state_dict) queue = gen.get("queue", []) + global global_dict + global_dict = queue raw_data = get_queue_table(queue) is_visible = len(raw_data) > 0 should_start_processing = bool(queue) @@ -4076,7 +4049,7 @@ def create_demo(): ) main_tabs.select( fn=on_tab_select, - inputs=[t2v_state, i2v_state], + inputs=[global_state, t2v_state, i2v_state], outputs=[ t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync
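
(Condensed, the persistence pattern the series converges on (JSON-safe task fields only, image paths instead of PIL objects, and an atexit-registered save) looks roughly like the sketch below. QUEUE_FILENAME and the function names follow the patch; the bodies are simplified for illustration.)

    import atexit
    import json
    from pathlib import Path

    QUEUE_FILENAME = "queue.json"
    queue = []  # task dicts holding only JSON-safe fields (paths, not PIL images)

    def save_queue_to_json(queue, filename=QUEUE_FILENAME):
        # Strip anything non-serializable before dumping; the patch removes 'state'.
        tasks = [{k: v for k, v in t.items() if k != "state"} for t in queue]
        Path(filename).write_text(json.dumps(tasks, indent=4), encoding="utf-8")

    def load_queue_from_json(filename=QUEUE_FILENAME):
        if not Path(filename).is_file():
            return [], 0
        tasks = json.loads(Path(filename).read_text(encoding="utf-8"))
        # Resume the global task_id counter from the highest saved id.
        return tasks, max((t.get("id", 0) for t in tasks), default=0)

    # Registering the save at exit (patch 6) instead of calling it at startup
    # (patch 3 invoked autosave_queue() directly under __main__, which would
    # run before any tasks were queued) is what makes the autosave work.
    atexit.register(lambda: save_queue_to_json(queue))
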