removed hardcoded input params; moved the batch queue generator over to the main video generator tab so it uses its params

Chris Malone 2025-04-25 17:13:53 +10:00
parent 4a56ffaf22
commit 52d0c5f3f9
1 changed file with 157 additions and 140 deletions
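The gist of the change: instead of building each batch task from a hardcoded settings dict, batch_create_task_entry() now starts from a copy of whatever parameters are currently set in the main Video Generator tab and only overrides the per-task image fields. A minimal sketch of that pattern, assuming dict-style UI params and PIL images as in the diff below (the helper name make_task_entry and the trimmed return dict are illustrative, not the actual function):

def make_task_entry(task_id, start_img_pil, end_img_pil, params_from_ui):
    # Start from the parameters the user configured in the main tab,
    # rather than a hardcoded settings dict (sketch of the new pattern).
    params = params_from_ui.copy()

    # Only the image-related fields are overridden per task.
    params.update({
        "image_prompt_type": "SE" if end_img_pil else "S",
        "image_start": start_img_pil.copy() if start_img_pil else None,
        "image_end": end_img_pil.copy() if end_img_pil else None,
    })

    # Trimmed-down task record; field names follow the diff below.
    return {
        "id": task_id,
        "params": params,
        "prompt": params.get("prompt", ""),
        "length": params.get("video_length"),
        "steps": params.get("num_inference_steps"),
    }

The real batch_create_task_entry() in the diff additionally stores base64 JPEG previews and repeat/length/step metadata for the queue table.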

wgp.py

@@ -91,132 +91,185 @@ def batch_get_sorted_images(folder):
except Exception as e:
return None, f"Error accessing folder {folder}: {e}"
def batch_create_task_entry(task_id, start_img_name, end_img_name, prompt, lora, model):
image_prompt_type = "SE" if end_img_name else "S"
def batch_create_task_entry(task_id, start_img_pil, end_img_pil, params_from_ui):
image_prompt_type = "SE" if end_img_pil else "S"
params = params_from_ui.copy()
params = {
"prompt": prompt,
"negative_prompt": "",
"resolution": "1280x720",
"video_length": 81,
"seed": -1,
"num_inference_steps": 30,
"guidance_scale": 5,
"flow_shift": 5,
"embedded_guidance_scale": 6,
"repeat_generation": 1,
"multi_images_gen_type": 0,
"tea_cache_setting": 0,
"tea_cache_start_step_perc": 0,
"loras_multipliers": "",
start_img_pil_copy = start_img_pil.copy() if start_img_pil else None
end_img_pil_copy = end_img_pil.copy() if end_img_pil else None
params.update({
"image_prompt_type": image_prompt_type,
"image_start": start_img_name,
"image_end": end_img_name,
"video_prompt_type": "I",
"image_refs": None,
"video_guide": None,
"video_mask": None,
"camera_type": 1,
"video_source": None,
"keep_frames": "",
"sliding_window_repeat": 0,
"sliding_window_overlap": 16,
"sliding_window_discard_last_frames": 4,
"remove_background_image_ref": True,
"temporal_upsampling": "rife2",
"spatial_upsampling": "",
"RIFLEx_setting": 0,
"slg_switch": 1,
"slg_layers": [9],
"slg_start_perc": 10,
"slg_end_perc": 90,
"cfg_star_switch": 1,
"cfg_zero_step": -1,
"activated_loras": [lora] if lora else [],
"model_filename": model
}
"image_start": start_img_pil_copy,
"image_end": end_img_pil_copy,
})
start_b64 = [pil_to_base64_uri(start_img_pil, format="jpeg", quality=70)] if start_img_pil else None
end_b64 = [pil_to_base64_uri(end_img_pil, format="jpeg", quality=70)] if end_img_pil else None
start_image_data_preview = [start_img_pil_copy] if start_img_pil_copy else None
end_image_data_preview = [end_img_pil_copy] if end_img_pil_copy else None
return {
"id": task_id,
"params": params
"params": params,
"repeats": params.get("repeat_generation", 1),
"length": params.get("video_length"),
"steps": params.get("num_inference_steps"),
"prompt": params.get("prompt", ''),
"start_image_data": start_image_data_preview,
"end_image_data": end_image_data_preview,
"start_image_data_base64": start_b64,
"end_image_data_base64": end_b64,
}
def create_batch_queue(folder, prompt, lora_file, model_file, has_end_frames_checkbox, progress=gr.Progress()):
if not all([folder, prompt, model_file]):
return None, "Error: Folder, Prompt, and Model Filename are required."
def create_batch_tasks_from_folder(batch_folder_input, batch_has_end_frames_cb, ui_params):
global task_id
progress(0, desc="Starting batch queue creation...")
if not batch_folder_input or not batch_folder_input.strip():
return [], "Error: Batch Folder Path is required."
if not os.path.isdir(batch_folder_input):
return [], f"Error: Folder not found or is not a directory: {batch_folder_input}"
images, error = batch_get_sorted_images(folder)
images, error = batch_get_sorted_images(batch_folder_input)
if error:
return None, error
return [], error
if not images:
return None, "Error: No image files found in the specified folder."
return [], "Error: No image files found in the specified folder."
if has_end_frames_checkbox:
if batch_has_end_frames_cb:
if len(images) < 2:
return None, "Error: Need at least 2 images (start/end pair) to form a task when 'Folder contains end frames' is checked."
return [], "Error: Need at least 2 images (start/end pair) to form a task when 'Folder contains start/end image pairs' is checked."
if len(images) % 2 != 0:
gr.Warning(f"Warning: Found an odd number of images ({len(images)}) in paired mode. The last image will be ignored.")
images = images[:-1]
tasks = []
zip_buffer = io.BytesIO()
temp_zip_path = None
tasks_to_add = []
current_task_params = prepare_inputs_dict("state", ui_params.copy())
current_task_params.pop('lset_name', None)
try:
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
if has_end_frames_checkbox:
num_tasks = len(images) // 2
progress(0.1, desc="Processing image pairs...")
for i in range(0, len(images) - 1, 2):
task_id = (i // 2) + 1
start_img_path = images[i]
end_img_path = images[i+1]
num_tasks_created = 0
if batch_has_end_frames_cb:
num_tasks_total = len(images) // 2
print(f"Processing {num_tasks_total} image pairs...")
for i in range(0, len(images) - 1, 2):
with lock: # Ensure task_id is unique
current_task_id_local = task_id + 1
task_id += 1
start_arcname = f"task{task_id}_image_start_0{start_img_path.suffix}"
end_arcname = f"task{task_id}_image_end_0{end_img_path.suffix}"
start_img_path = images[i]
end_img_path = images[i+1]
task = batch_create_task_entry(task_id, start_arcname, end_arcname, prompt, lora_file, model_file)
tasks.append(task)
try:
start_img_pil = Image.open(start_img_path)
start_img_pil.load()
start_img_pil = convert_image(start_img_pil)
zipf.write(start_img_path, arcname=start_arcname)
zipf.write(end_img_path, arcname=end_arcname)
progress(0.1 + (0.7 * (task_id / num_tasks)), desc=f"Adding pair {task_id}/{num_tasks}")
else:
num_tasks = len(images)
progress(0.1, desc="Processing single images...")
for i, img_path in enumerate(images):
task_id = i + 1
end_img_pil = Image.open(end_img_path)
end_img_pil.load()
end_img_pil = convert_image(end_img_pil)
start_arcname = f"task{task_id}_image_start_0{img_path.suffix}"
params_for_entry = current_task_params.copy()
params_for_entry['state'] = ui_params['state']
params_for_entry['model_filename'] = ui_params['model_filename']
task = batch_create_task_entry(task_id, start_arcname, None, prompt, lora_file, model_file)
tasks.append(task)
task = batch_create_task_entry(current_task_id_local, start_img_pil, end_img_pil, params_for_entry)
tasks_to_add.append(task)
num_tasks_created += 1
print(f"Added batch task {num_tasks_created}/{num_tasks_total} (Pair: {start_img_path.name}, {end_img_path.name})")
zipf.write(img_path, arcname=start_arcname)
progress(0.1 + (0.7 * (task_id / num_tasks)), desc=f"Adding image {task_id}/{num_tasks}")
except Exception as img_e:
gr.Warning(f"Skipping pair due to error loading images ({start_img_path.name}, {end_img_path.name}): {img_e}")
finally:
if 'start_img_pil' in locals() and hasattr(start_img_pil, 'close'): start_img_pil.close()
if 'end_img_pil' in locals() and hasattr(end_img_pil, 'close'): end_img_pil.close()
progress(0.8, desc="Writing queue manifest...")
json_data = json.dumps(tasks, indent=4)
zipf.writestr("queue.json", json_data)
else:
num_tasks_total = len(images)
print(f"Processing {num_tasks_total} single images...")
for i, img_path in enumerate(images):
with lock:
current_task_id_local = task_id + 1
task_id += 1
progress(0.9, desc="Finalizing zip file...")
zip_buffer.seek(0)
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip", prefix="batch_queue_") as tmp_file:
tmp_file.write(zip_buffer.getvalue())
temp_zip_path = tmp_file.name
try:
start_img_pil = Image.open(img_path)
start_img_pil.load()
start_img_pil = convert_image(start_img_pil)
params_for_entry = current_task_params.copy()
params_for_entry['state'] = ui_params['state']
params_for_entry['model_filename'] = ui_params['model_filename']
task = batch_create_task_entry(current_task_id_local, start_img_pil, None, params_for_entry)
tasks_to_add.append(task)
num_tasks_created += 1
print(f"Added batch task {num_tasks_created}/{num_tasks_total} (Image: {img_path.name})")
progress(1, desc="Batch queue created.")
return temp_zip_path, f"Successfully created queue.zip with {len(tasks)} task(s)."
except Exception as img_e:
gr.Warning(f"Skipping image due to error loading ({img_path.name}): {img_e}")
finally:
if 'start_img_pil' in locals() and hasattr(start_img_pil, 'close'): start_img_pil.close()
if not tasks_to_add:
return [], "Error: No tasks could be created from the images found."
print(f"Successfully prepared {len(tasks_to_add)} batch task(s).")
return tasks_to_add, None
except Exception as e:
if temp_zip_path and os.path.exists(temp_zip_path):
os.remove(temp_zip_path)
traceback.print_exc()
return None, f"Error during zip creation: {e}"
finally:
zip_buffer.close()
return [], f"Error during batch task creation: {e}"
def handle_generate_or_add(
state,
model_choice,
batch_folder_input,
batch_has_end_frames_cb,
*args
):
gen = get_gen_info(state)
queue = gen.setdefault("queue", [])
save_inputs_param_names = list(inspect.signature(save_inputs).parameters)[1:-1]
if len(args) != len(save_inputs_param_names):
gr.Error(f"Internal Error: Mismatched number of arguments for handle_generate_or_add. Expected {len(save_inputs_param_names)}, got {len(args)}.")
return update_queue_data(queue)
all_ui_params_dict = dict(zip(save_inputs_param_names, args))
all_ui_params_dict['state'] = state
all_ui_params_dict['model_filename'] = state["model_filename"]
batch_folder = batch_folder_input.strip() if batch_folder_input else ""
if batch_folder:
print(f"Batch mode triggered with folder: {batch_folder}")
batch_task_params = all_ui_params_dict.copy()
batch_task_params.pop('image_start', None)
batch_task_params.pop('image_end', None)
new_tasks, error = create_batch_tasks_from_folder(
batch_folder,
batch_has_end_frames_cb,
batch_task_params
)
if error:
gr.Error(error)
return update_queue_data(queue)
if new_tasks:
with lock:
queue.extend(new_tasks)
gen["prompts_max"] = len([t for t in queue if t is not None and 'id' in t])
gr.Info(f"Added {len(new_tasks)} tasks from batch folder '{batch_folder}' to the queue.")
else:
gr.Warning("Batch folder was specified, but no tasks were added (check folder contents and permissions).")
return update_queue_data(queue)
else:
print("Standard generation mode triggered.")
state["validate_success"] = 1
return process_prompt_and_add_tasks(state, model_choice)
def extract_parameters_from_video(video_filepath):
if not video_filepath or not hasattr(video_filepath, 'name') or not os.path.exists(video_filepath.name):
@@ -3855,6 +3908,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
else:
image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
with gr.Accordion("Batch Generator", open=False) as batch_accordion_ui:
batch_folder_input = gr.Textbox(label="Image Folder Path", placeholder="/path/to/your/image_folder")
batch_has_end_frames_cb = gr.Checkbox(label="Folder contains start/end image pairs", value=False)
with gr.Column(visible= "recam" in model_filename ) as recam_column:
camera_type = gr.Dropdown(
@@ -4418,6 +4474,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
inputs_names= list(inspect.signature(save_inputs).parameters)[1:-1]
locals_dict = locals()
gen_inputs = [locals_dict[k] for k in inputs_names] + [state]
gen_inputs_list_for_handler = [locals_dict[k] for k in inputs_names]
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_inputs, inputs =[target_settings] + gen_inputs, outputs = [])
@@ -4444,8 +4501,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
).then(fn=handle_generate_or_add,
inputs = [state, model_choice, batch_folder_input, batch_has_end_frames_cb,*gen_inputs_list_for_handler],
outputs= queue_df
).then(fn=prepare_generate_video,
inputs= [state],
@@ -4471,9 +4528,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
outputs=queue_df
).then(fn=handle_generate_or_add,
inputs = [state, model_choice, batch_folder_input, batch_has_end_frames_cb,*gen_inputs_list_for_handler],
outputs= queue_df
).then(
fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
inputs=[state],
@@ -4984,28 +5041,12 @@ def create_demo():
else:
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
initial_lora_dir = get_lora_dir(transformer_filename)
try:
initial_loras, initial_loras_names, _, _, _, _, _ = setup_loras(
transformer_filename, None, initial_lora_dir, "", None
)
print(f"Found {len(initial_loras_names)} initial loras for default model.")
except Exception as e:
print(f"Warning: Could not initially load loras for default model '{transformer_filename}': {e}")
initial_loras_names = []
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })
dropdown_types = transformer_types if len(transformer_types) > 0 else model_types
available_model_files = []
for model_type in dropdown_types:
choice = get_model_filename(model_type, transformer_quantization)
available_model_files.append(choice)
with gr.Tabs(selected="video_gen", ) as main_tabs:
with gr.Tab("Video Generator", id="video_gen"):
with gr.Row():
@@ -5024,30 +5065,6 @@ def create_demo():
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
gen_info, queue_accordion, video_guide, video_mask, video_prompt_type_video_trigger
) = generate_video_tab(model_choice=model_choice, header=header)
with gr.Tab("Extras", id="extras"):
gr.Markdown("## Batch Queue Creator")
gr.Markdown("Create a `queue.zip` file from a folder containing pairs of start/end images, sorted by modification time (oldest first).")
with gr.Row():
with gr.Column(scale=3):
batch_folder_input = gr.Textbox(label="Image Folder Path", placeholder="/path/to/your/image_pairs")
batch_prompt_input = gr.Textbox(label="Prompt for all tasks", lines=2)
batch_lora_input = gr.Dropdown(label="LoRA (Optional)", choices=[""] + initial_loras_names, value="")
batch_model_input = gr.Dropdown(label="Model Filename", choices=available_model_files, value=transformer_filename)
batch_has_end_frames_cb = gr.Checkbox(label="Folder contains end frames", value=False)
batch_generate_button = gr.Button("Generate Batch Queue (.zip)")
with gr.Column(scale=1):
batch_status_output = gr.Markdown("")
batch_download_output = gr.DownloadButton(label="Download queue.zip", visible=False, interactive=True)
batch_generate_button.click(
fn=create_batch_queue,
inputs=[batch_folder_input, batch_prompt_input, batch_lora_input, batch_model_input, batch_has_end_frames_cb],
outputs=[batch_download_output, batch_status_output]
).then(
fn=lambda filepath: gr.update(visible=bool(filepath), value=filepath),
inputs=[batch_download_output],
outputs=[batch_download_output]
)
with gr.Tab("Informations", id="info"):
generate_info_tab()
with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: