From f00c6435ae7e0e8a7dd25a3863990f3db0e95276 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 6 Apr 2025 19:43:34 +1000 Subject: [PATCH] move autosave --- gradio_server.py | 71 ++++++++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index d0c9b4e..dc440a2 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -26,6 +26,7 @@ from wan.utils import prompt_parser import base64 import io from PIL import Image +import atexit PROMPT_VARS_MAX = 10 target_mmgp_version = "3.3.4" @@ -41,9 +42,9 @@ task_id = 0 # tracker_lock = threading.Lock() last_model_type = None QUEUE_FILENAME = "queue.json" -t2v_state_ref = None -i2v_state_ref = None -global_state_ref = None +global_t2v_state_dict = None +global_i2v_state_dict = None +global_state = {} def format_time(seconds): if seconds < 60: @@ -629,25 +630,40 @@ def autoload_queue(state_dict): def autosave_queue(): print("Attempting to autosave queue on exit...") - if global_state_ref is None or t2v_state_ref is None or i2v_state_ref is None: - print("State references not available for autosave. Skipping.") - return - try: - last_tab_was_i2v = global_state_ref.get("last_tab_was_image2video", use_image2video) - active_state = i2v_state_ref if last_tab_was_i2v else t2v_state_ref + global global_t2v_state_dict, global_i2v_state_dict, global_state - if active_state: - gen = get_gen_info(active_state) - queue = gen.get("queue", []) - if queue: - print(f"Autosaving queue ({len(queue)} items) from {'i2v' if last_tab_was_i2v else 't2v'} state...") - save_queue_to_json(queue, QUEUE_FILENAME) - else: - print("Queue is empty, autosave skipped.") + active_state_dict = None + last_tab_was_i2v = None + + if global_state and 'last_tab_was_image2video' in global_state: + last_tab_was_i2v = global_state['last_tab_was_image2video'] + active_state_dict = global_i2v_state_dict if last_tab_was_i2v else global_t2v_state_dict + print(f"Using last active tab info: {'i2v' if last_tab_was_i2v else 't2v'}") + else: + print("Last active tab info not found, using fallback logic.") + if use_image2video and global_i2v_state_dict: + active_state_dict = global_i2v_state_dict + last_tab_was_i2v = True + elif not use_image2video and global_t2v_state_dict: + active_state_dict = global_t2v_state_dict + last_tab_was_i2v = False + elif global_i2v_state_dict: + active_state_dict = global_i2v_state_dict + last_tab_was_i2v = True + elif global_t2v_state_dict: + active_state_dict = global_t2v_state_dict + last_tab_was_i2v = False + if active_state_dict: + gen = active_state_dict.get("gen", {}) + queue = gen.get("queue", []) + if queue: + tab_name = 'i2v' if last_tab_was_i2v else 't2v' + print(f"Autosaving queue ({len(queue)} items) from {tab_name} state dict...") + save_queue_to_json(queue, QUEUE_FILENAME) else: - print("Could not determine active state for autosave.") - except Exception as e: - print(f"Error during autosave: {e}") + print("Queue is empty in the determined active state dictionary, autosave skipped.") + else: + print(f"Could not determine active state dictionary for autosave. T2V dict exists: {global_t2v_state_dict is not None}, I2V dict exists: {global_i2v_state_dict is not None}, Last active tab known: {last_tab_was_i2v is not None}") def get_queue_table(queue): data = [] @@ -2787,7 +2803,7 @@ def check_refresh_input_type(state): def generate_video_tab(image2video=False): filename = transformer_filename_i2v if image2video else transformer_filename_t2v ui_defaults= get_default_settings(filename, image2video) - + global global_t2v_state_dict, global_i2v_state_dict state_dict = {} state_dict["advanced"] = advanced @@ -2797,6 +2813,10 @@ def generate_video_tab(image2video=False): gen = dict() gen["queue"] = [] state_dict["gen"] = gen + if image2video: + global_i2v_state_dict = state_dict + else: + global_t2v_state_dict = state_dict preset_to_load = lora_preselected_preset if use_image2video == image2video else "" @@ -3622,9 +3642,9 @@ def generate_about_tab(): def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): + global global_state t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) - new_t2v = evt.index == 0 new_i2v = evt.index == 1 i2v_light_sync = gr.Text() @@ -3639,7 +3659,8 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): else: gen = t2v_state["gen"] i2v_state["gen"] = gen - + if new_t2v or new_i2v: + global_state['last_tab_was_image2video'] = new_i2v if new_t2v or new_i2v: if last_tab_was_image2video != None and new_t2v != new_i2v: @@ -3977,13 +3998,11 @@ def create_demo(): t2v_header, t2v_light_sync, t2v_full_sync, t2v_state, t2v_queue_df, t2v_current_gen_column, t2v_gen_status, t2v_output, t2v_abort_btn, t2v_generate_btn, t2v_add_to_queue_btn, t2v_gen_info) = generate_video_tab(False) - t2v_state_ref = t2v_state.value with gr.Tab("Image To Video", id="i2v") as i2v_tab: (i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state, i2v_queue_df, i2v_current_gen_column, i2v_gen_status, i2v_output, i2v_abort_btn, i2v_generate_btn, i2v_add_to_queue_btn, i2v_gen_info) = generate_video_tab(True) - i2v_state_ref = i2v_state.value if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state) @@ -4066,8 +4085,8 @@ def create_demo(): return demo if __name__ == "__main__": - autosave_queue() # threading.Thread(target=runner, daemon=True).start() + atexit.register(autosave_queue) os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" server_port = int(args.server_port) if os.name == "nt":