move autosave
This commit is contained in:
parent
5d8f95b5d6
commit
f00c6435ae
|
|
@ -26,6 +26,7 @@ from wan.utils import prompt_parser
|
|||
import base64
|
||||
import io
|
||||
from PIL import Image
|
||||
import atexit
|
||||
PROMPT_VARS_MAX = 10
|
||||
|
||||
target_mmgp_version = "3.3.4"
|
||||
|
|
@ -41,9 +42,9 @@ task_id = 0
|
|||
# tracker_lock = threading.Lock()
|
||||
last_model_type = None
|
||||
QUEUE_FILENAME = "queue.json"
|
||||
t2v_state_ref = None
|
||||
i2v_state_ref = None
|
||||
global_state_ref = None
|
||||
global_t2v_state_dict = None
|
||||
global_i2v_state_dict = None
|
||||
global_state = {}
|
||||
|
||||
def format_time(seconds):
|
||||
if seconds < 60:
|
||||
|
|
@ -629,25 +630,40 @@ def autoload_queue(state_dict):
|
|||
|
||||
def autosave_queue():
|
||||
print("Attempting to autosave queue on exit...")
|
||||
if global_state_ref is None or t2v_state_ref is None or i2v_state_ref is None:
|
||||
print("State references not available for autosave. Skipping.")
|
||||
return
|
||||
try:
|
||||
last_tab_was_i2v = global_state_ref.get("last_tab_was_image2video", use_image2video)
|
||||
active_state = i2v_state_ref if last_tab_was_i2v else t2v_state_ref
|
||||
global global_t2v_state_dict, global_i2v_state_dict, global_state
|
||||
|
||||
if active_state:
|
||||
gen = get_gen_info(active_state)
|
||||
queue = gen.get("queue", [])
|
||||
if queue:
|
||||
print(f"Autosaving queue ({len(queue)} items) from {'i2v' if last_tab_was_i2v else 't2v'} state...")
|
||||
save_queue_to_json(queue, QUEUE_FILENAME)
|
||||
else:
|
||||
print("Queue is empty, autosave skipped.")
|
||||
active_state_dict = None
|
||||
last_tab_was_i2v = None
|
||||
|
||||
if global_state and 'last_tab_was_image2video' in global_state:
|
||||
last_tab_was_i2v = global_state['last_tab_was_image2video']
|
||||
active_state_dict = global_i2v_state_dict if last_tab_was_i2v else global_t2v_state_dict
|
||||
print(f"Using last active tab info: {'i2v' if last_tab_was_i2v else 't2v'}")
|
||||
else:
|
||||
print("Last active tab info not found, using fallback logic.")
|
||||
if use_image2video and global_i2v_state_dict:
|
||||
active_state_dict = global_i2v_state_dict
|
||||
last_tab_was_i2v = True
|
||||
elif not use_image2video and global_t2v_state_dict:
|
||||
active_state_dict = global_t2v_state_dict
|
||||
last_tab_was_i2v = False
|
||||
elif global_i2v_state_dict:
|
||||
active_state_dict = global_i2v_state_dict
|
||||
last_tab_was_i2v = True
|
||||
elif global_t2v_state_dict:
|
||||
active_state_dict = global_t2v_state_dict
|
||||
last_tab_was_i2v = False
|
||||
if active_state_dict:
|
||||
gen = active_state_dict.get("gen", {})
|
||||
queue = gen.get("queue", [])
|
||||
if queue:
|
||||
tab_name = 'i2v' if last_tab_was_i2v else 't2v'
|
||||
print(f"Autosaving queue ({len(queue)} items) from {tab_name} state dict...")
|
||||
save_queue_to_json(queue, QUEUE_FILENAME)
|
||||
else:
|
||||
print("Could not determine active state for autosave.")
|
||||
except Exception as e:
|
||||
print(f"Error during autosave: {e}")
|
||||
print("Queue is empty in the determined active state dictionary, autosave skipped.")
|
||||
else:
|
||||
print(f"Could not determine active state dictionary for autosave. T2V dict exists: {global_t2v_state_dict is not None}, I2V dict exists: {global_i2v_state_dict is not None}, Last active tab known: {last_tab_was_i2v is not None}")
|
||||
|
||||
def get_queue_table(queue):
|
||||
data = []
|
||||
|
|
@ -2787,7 +2803,7 @@ def check_refresh_input_type(state):
|
|||
def generate_video_tab(image2video=False):
|
||||
filename = transformer_filename_i2v if image2video else transformer_filename_t2v
|
||||
ui_defaults= get_default_settings(filename, image2video)
|
||||
|
||||
global global_t2v_state_dict, global_i2v_state_dict
|
||||
state_dict = {}
|
||||
|
||||
state_dict["advanced"] = advanced
|
||||
|
|
@ -2797,6 +2813,10 @@ def generate_video_tab(image2video=False):
|
|||
gen = dict()
|
||||
gen["queue"] = []
|
||||
state_dict["gen"] = gen
|
||||
if image2video:
|
||||
global_i2v_state_dict = state_dict
|
||||
else:
|
||||
global_t2v_state_dict = state_dict
|
||||
|
||||
preset_to_load = lora_preselected_preset if use_image2video == image2video else ""
|
||||
|
||||
|
|
@ -3622,9 +3642,9 @@ def generate_about_tab():
|
|||
|
||||
|
||||
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
||||
global global_state
|
||||
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
||||
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
||||
|
||||
new_t2v = evt.index == 0
|
||||
new_i2v = evt.index == 1
|
||||
i2v_light_sync = gr.Text()
|
||||
|
|
@ -3639,7 +3659,8 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
|||
else:
|
||||
gen = t2v_state["gen"]
|
||||
i2v_state["gen"] = gen
|
||||
|
||||
if new_t2v or new_i2v:
|
||||
global_state['last_tab_was_image2video'] = new_i2v
|
||||
|
||||
if new_t2v or new_i2v:
|
||||
if last_tab_was_image2video != None and new_t2v != new_i2v:
|
||||
|
|
@ -3977,13 +3998,11 @@ def create_demo():
|
|||
t2v_header, t2v_light_sync, t2v_full_sync, t2v_state, t2v_queue_df,
|
||||
t2v_current_gen_column, t2v_gen_status, t2v_output, t2v_abort_btn,
|
||||
t2v_generate_btn, t2v_add_to_queue_btn, t2v_gen_info) = generate_video_tab(False)
|
||||
t2v_state_ref = t2v_state.value
|
||||
with gr.Tab("Image To Video", id="i2v") as i2v_tab:
|
||||
(i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name,
|
||||
i2v_header, i2v_light_sync, i2v_full_sync, i2v_state, i2v_queue_df,
|
||||
i2v_current_gen_column, i2v_gen_status, i2v_output, i2v_abort_btn,
|
||||
i2v_generate_btn, i2v_add_to_queue_btn, i2v_gen_info) = generate_video_tab(True)
|
||||
i2v_state_ref = i2v_state.value
|
||||
if not args.lock_config:
|
||||
with gr.Tab("Downloads", id="downloads") as downloads_tab:
|
||||
generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
|
||||
|
|
@ -4066,8 +4085,8 @@ def create_demo():
|
|||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
autosave_queue()
|
||||
# threading.Thread(target=runner, daemon=True).start()
|
||||
atexit.register(autosave_queue)
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
server_port = int(args.server_port)
|
||||
if os.name == "nt":
|
||||
|
|
|
|||
Loading…
Reference in New Issue