removed hardcoded input params, moved batch queue generator over to main video generator tab to use it's params
This commit is contained in:
parent
4a56ffaf22
commit
52d0c5f3f9
297
wgp.py
297
wgp.py
|
|
@ -91,132 +91,185 @@ def batch_get_sorted_images(folder):
|
|||
except Exception as e:
|
||||
return None, f"Error accessing folder {folder}: {e}"
|
||||
|
||||
def batch_create_task_entry(task_id, start_img_name, end_img_name, prompt, lora, model):
|
||||
image_prompt_type = "SE" if end_img_name else "S"
|
||||
def batch_create_task_entry(task_id, start_img_pil, end_img_pil, params_from_ui):
|
||||
image_prompt_type = "SE" if end_img_pil else "S"
|
||||
params = params_from_ui.copy()
|
||||
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": "",
|
||||
"resolution": "1280x720",
|
||||
"video_length": 81,
|
||||
"seed": -1,
|
||||
"num_inference_steps": 30,
|
||||
"guidance_scale": 5,
|
||||
"flow_shift": 5,
|
||||
"embedded_guidance_scale": 6,
|
||||
"repeat_generation": 1,
|
||||
"multi_images_gen_type": 0,
|
||||
"tea_cache_setting": 0,
|
||||
"tea_cache_start_step_perc": 0,
|
||||
"loras_multipliers": "",
|
||||
start_img_pil_copy = start_img_pil.copy() if start_img_pil else None
|
||||
end_img_pil_copy = end_img_pil.copy() if end_img_pil else None
|
||||
|
||||
params.update({
|
||||
"image_prompt_type": image_prompt_type,
|
||||
"image_start": start_img_name,
|
||||
"image_end": end_img_name,
|
||||
"video_prompt_type": "I",
|
||||
"image_refs": None,
|
||||
"video_guide": None,
|
||||
"video_mask": None,
|
||||
"camera_type": 1,
|
||||
"video_source": None,
|
||||
"keep_frames": "",
|
||||
"sliding_window_repeat": 0,
|
||||
"sliding_window_overlap": 16,
|
||||
"sliding_window_discard_last_frames": 4,
|
||||
"remove_background_image_ref": True,
|
||||
"temporal_upsampling": "rife2",
|
||||
"spatial_upsampling": "",
|
||||
"RIFLEx_setting": 0,
|
||||
"slg_switch": 1,
|
||||
"slg_layers": [9],
|
||||
"slg_start_perc": 10,
|
||||
"slg_end_perc": 90,
|
||||
"cfg_star_switch": 1,
|
||||
"cfg_zero_step": -1,
|
||||
"activated_loras": [lora] if lora else [],
|
||||
"model_filename": model
|
||||
}
|
||||
"image_start": start_img_pil_copy,
|
||||
"image_end": end_img_pil_copy,
|
||||
})
|
||||
|
||||
start_b64 = [pil_to_base64_uri(start_img_pil, format="jpeg", quality=70)] if start_img_pil else None
|
||||
end_b64 = [pil_to_base64_uri(end_img_pil, format="jpeg", quality=70)] if end_img_pil else None
|
||||
|
||||
start_image_data_preview = [start_img_pil_copy] if start_img_pil_copy else None
|
||||
end_image_data_preview = [end_img_pil_copy] if end_img_pil_copy else None
|
||||
|
||||
return {
|
||||
"id": task_id,
|
||||
"params": params
|
||||
"params": params,
|
||||
"repeats": params.get("repeat_generation", 1),
|
||||
"length": params.get("video_length"),
|
||||
"steps": params.get("num_inference_steps"),
|
||||
"prompt": params.get("prompt", ''),
|
||||
"start_image_data": start_image_data_preview,
|
||||
"end_image_data": end_image_data_preview,
|
||||
"start_image_data_base64": start_b64,
|
||||
"end_image_data_base64": end_b64,
|
||||
}
|
||||
|
||||
def create_batch_queue(folder, prompt, lora_file, model_file, has_end_frames_checkbox, progress=gr.Progress()):
|
||||
if not all([folder, prompt, model_file]):
|
||||
return None, "Error: Folder, Prompt, and Model Filename are required."
|
||||
def create_batch_tasks_from_folder(batch_folder_input, batch_has_end_frames_cb, ui_params):
|
||||
global task_id
|
||||
|
||||
progress(0, desc="Starting batch queue creation...")
|
||||
if not batch_folder_input or not batch_folder_input.strip():
|
||||
return [], "Error: Batch Folder Path is required."
|
||||
if not os.path.isdir(batch_folder_input):
|
||||
return [], f"Error: Folder not found or is not a directory: {batch_folder_input}"
|
||||
|
||||
images, error = batch_get_sorted_images(folder)
|
||||
images, error = batch_get_sorted_images(batch_folder_input)
|
||||
if error:
|
||||
return None, error
|
||||
return [], error
|
||||
if not images:
|
||||
return None, "Error: No image files found in the specified folder."
|
||||
return [], "Error: No image files found in the specified folder."
|
||||
|
||||
if has_end_frames_checkbox:
|
||||
if batch_has_end_frames_cb:
|
||||
if len(images) < 2:
|
||||
return None, "Error: Need at least 2 images (start/end pair) to form a task when 'Folder contains end frames' is checked."
|
||||
return [], "Error: Need at least 2 images (start/end pair) to form a task when 'Folder contains start/end image pairs' is checked."
|
||||
if len(images) % 2 != 0:
|
||||
gr.Warning(f"Warning: Found an odd number of images ({len(images)}) in paired mode. The last image will be ignored.")
|
||||
images = images[:-1]
|
||||
|
||||
tasks = []
|
||||
zip_buffer = io.BytesIO()
|
||||
temp_zip_path = None
|
||||
tasks_to_add = []
|
||||
current_task_params = prepare_inputs_dict("state", ui_params.copy())
|
||||
current_task_params.pop('lset_name', None)
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||
if has_end_frames_checkbox:
|
||||
num_tasks = len(images) // 2
|
||||
progress(0.1, desc="Processing image pairs...")
|
||||
for i in range(0, len(images) - 1, 2):
|
||||
task_id = (i // 2) + 1
|
||||
start_img_path = images[i]
|
||||
end_img_path = images[i+1]
|
||||
num_tasks_created = 0
|
||||
if batch_has_end_frames_cb:
|
||||
num_tasks_total = len(images) // 2
|
||||
print(f"Processing {num_tasks_total} image pairs...")
|
||||
for i in range(0, len(images) - 1, 2):
|
||||
with lock: # Ensure task_id is unique
|
||||
current_task_id_local = task_id + 1
|
||||
task_id += 1
|
||||
|
||||
start_arcname = f"task{task_id}_image_start_0{start_img_path.suffix}"
|
||||
end_arcname = f"task{task_id}_image_end_0{end_img_path.suffix}"
|
||||
start_img_path = images[i]
|
||||
end_img_path = images[i+1]
|
||||
|
||||
task = batch_create_task_entry(task_id, start_arcname, end_arcname, prompt, lora_file, model_file)
|
||||
tasks.append(task)
|
||||
try:
|
||||
start_img_pil = Image.open(start_img_path)
|
||||
start_img_pil.load()
|
||||
start_img_pil = convert_image(start_img_pil)
|
||||
|
||||
zipf.write(start_img_path, arcname=start_arcname)
|
||||
zipf.write(end_img_path, arcname=end_arcname)
|
||||
progress(0.1 + (0.7 * (task_id / num_tasks)), desc=f"Adding pair {task_id}/{num_tasks}")
|
||||
else:
|
||||
num_tasks = len(images)
|
||||
progress(0.1, desc="Processing single images...")
|
||||
for i, img_path in enumerate(images):
|
||||
task_id = i + 1
|
||||
end_img_pil = Image.open(end_img_path)
|
||||
end_img_pil.load()
|
||||
end_img_pil = convert_image(end_img_pil)
|
||||
|
||||
start_arcname = f"task{task_id}_image_start_0{img_path.suffix}"
|
||||
params_for_entry = current_task_params.copy()
|
||||
params_for_entry['state'] = ui_params['state']
|
||||
params_for_entry['model_filename'] = ui_params['model_filename']
|
||||
|
||||
task = batch_create_task_entry(task_id, start_arcname, None, prompt, lora_file, model_file)
|
||||
tasks.append(task)
|
||||
task = batch_create_task_entry(current_task_id_local, start_img_pil, end_img_pil, params_for_entry)
|
||||
tasks_to_add.append(task)
|
||||
num_tasks_created += 1
|
||||
print(f"Added batch task {num_tasks_created}/{num_tasks_total} (Pair: {start_img_path.name}, {end_img_path.name})")
|
||||
|
||||
zipf.write(img_path, arcname=start_arcname)
|
||||
progress(0.1 + (0.7 * (task_id / num_tasks)), desc=f"Adding image {task_id}/{num_tasks}")
|
||||
except Exception as img_e:
|
||||
gr.Warning(f"Skipping pair due to error loading images ({start_img_path.name}, {end_img_path.name}): {img_e}")
|
||||
finally:
|
||||
if 'start_img_pil' in locals() and hasattr(start_img_pil, 'close'): start_img_pil.close()
|
||||
if 'end_img_pil' in locals() and hasattr(end_img_pil, 'close'): end_img_pil.close()
|
||||
|
||||
progress(0.8, desc="Writing queue manifest...")
|
||||
json_data = json.dumps(tasks, indent=4)
|
||||
zipf.writestr("queue.json", json_data)
|
||||
else:
|
||||
num_tasks_total = len(images)
|
||||
print(f"Processing {num_tasks_total} single images...")
|
||||
for i, img_path in enumerate(images):
|
||||
with lock:
|
||||
current_task_id_local = task_id + 1
|
||||
task_id += 1
|
||||
|
||||
progress(0.9, desc="Finalizing zip file...")
|
||||
zip_buffer.seek(0)
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip", prefix="batch_queue_") as tmp_file:
|
||||
tmp_file.write(zip_buffer.getvalue())
|
||||
temp_zip_path = tmp_file.name
|
||||
try:
|
||||
start_img_pil = Image.open(img_path)
|
||||
start_img_pil.load()
|
||||
start_img_pil = convert_image(start_img_pil)
|
||||
params_for_entry = current_task_params.copy()
|
||||
params_for_entry['state'] = ui_params['state']
|
||||
params_for_entry['model_filename'] = ui_params['model_filename']
|
||||
task = batch_create_task_entry(current_task_id_local, start_img_pil, None, params_for_entry)
|
||||
tasks_to_add.append(task)
|
||||
num_tasks_created += 1
|
||||
print(f"Added batch task {num_tasks_created}/{num_tasks_total} (Image: {img_path.name})")
|
||||
|
||||
progress(1, desc="Batch queue created.")
|
||||
return temp_zip_path, f"Successfully created queue.zip with {len(tasks)} task(s)."
|
||||
except Exception as img_e:
|
||||
gr.Warning(f"Skipping image due to error loading ({img_path.name}): {img_e}")
|
||||
finally:
|
||||
if 'start_img_pil' in locals() and hasattr(start_img_pil, 'close'): start_img_pil.close()
|
||||
|
||||
if not tasks_to_add:
|
||||
return [], "Error: No tasks could be created from the images found."
|
||||
|
||||
print(f"Successfully prepared {len(tasks_to_add)} batch task(s).")
|
||||
return tasks_to_add, None
|
||||
|
||||
except Exception as e:
|
||||
if temp_zip_path and os.path.exists(temp_zip_path):
|
||||
os.remove(temp_zip_path)
|
||||
traceback.print_exc()
|
||||
return None, f"Error during zip creation: {e}"
|
||||
finally:
|
||||
zip_buffer.close()
|
||||
return [], f"Error during batch task creation: {e}"
|
||||
|
||||
def handle_generate_or_add(
|
||||
state,
|
||||
model_choice,
|
||||
batch_folder_input,
|
||||
batch_has_end_frames_cb,
|
||||
*args
|
||||
):
|
||||
gen = get_gen_info(state)
|
||||
queue = gen.setdefault("queue", [])
|
||||
save_inputs_param_names = list(inspect.signature(save_inputs).parameters)[1:-1]
|
||||
|
||||
if len(args) != len(save_inputs_param_names):
|
||||
gr.Error(f"Internal Error: Mismatched number of arguments for handle_generate_or_add. Expected {len(save_inputs_param_names)}, got {len(args)}.")
|
||||
return update_queue_data(queue)
|
||||
|
||||
all_ui_params_dict = dict(zip(save_inputs_param_names, args))
|
||||
all_ui_params_dict['state'] = state
|
||||
all_ui_params_dict['model_filename'] = state["model_filename"]
|
||||
|
||||
batch_folder = batch_folder_input.strip() if batch_folder_input else ""
|
||||
|
||||
if batch_folder:
|
||||
print(f"Batch mode triggered with folder: {batch_folder}")
|
||||
batch_task_params = all_ui_params_dict.copy()
|
||||
batch_task_params.pop('image_start', None)
|
||||
batch_task_params.pop('image_end', None)
|
||||
new_tasks, error = create_batch_tasks_from_folder(
|
||||
batch_folder,
|
||||
batch_has_end_frames_cb,
|
||||
batch_task_params
|
||||
)
|
||||
|
||||
if error:
|
||||
gr.Error(error)
|
||||
return update_queue_data(queue)
|
||||
|
||||
if new_tasks:
|
||||
with lock:
|
||||
queue.extend(new_tasks)
|
||||
gen["prompts_max"] = len([t for t in queue if t is not None and 'id' in t])
|
||||
gr.Info(f"Added {len(new_tasks)} tasks from batch folder '{batch_folder}' to the queue.")
|
||||
else:
|
||||
gr.Warning("Batch folder was specified, but no tasks were added (check folder contents and permissions).")
|
||||
|
||||
return update_queue_data(queue)
|
||||
|
||||
else:
|
||||
print("Standard generation mode triggered.")
|
||||
state["validate_success"] = 1
|
||||
return process_prompt_and_add_tasks(state, model_choice)
|
||||
|
||||
def extract_parameters_from_video(video_filepath):
|
||||
if not video_filepath or not hasattr(video_filepath, 'name') or not os.path.exists(video_filepath.name):
|
||||
|
|
@ -3855,6 +3908,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
|
||||
else:
|
||||
image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
|
||||
with gr.Accordion("Batch Generator", open=False) as batch_accordion_ui:
|
||||
batch_folder_input = gr.Textbox(label="Image Folder Path", placeholder="/path/to/your/image_folder")
|
||||
batch_has_end_frames_cb = gr.Checkbox(label="Folder contains start/end image pairs", value=False)
|
||||
|
||||
with gr.Column(visible= "recam" in model_filename ) as recam_column:
|
||||
camera_type = gr.Dropdown(
|
||||
|
|
@ -4418,6 +4474,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||
inputs_names= list(inspect.signature(save_inputs).parameters)[1:-1]
|
||||
locals_dict = locals()
|
||||
gen_inputs = [locals_dict[k] for k in inputs_names] + [state]
|
||||
gen_inputs_list_for_handler = [locals_dict[k] for k in inputs_names]
|
||||
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
||||
save_inputs, inputs =[target_settings] + gen_inputs, outputs = [])
|
||||
|
||||
|
|
@ -4444,8 +4501,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||
).then(fn=save_inputs,
|
||||
inputs =[target_state] + gen_inputs,
|
||||
outputs= None
|
||||
).then(fn=process_prompt_and_add_tasks,
|
||||
inputs = [state, model_choice],
|
||||
).then(fn=handle_generate_or_add,
|
||||
inputs = [state, model_choice, batch_folder_input, batch_has_end_frames_cb,*gen_inputs_list_for_handler],
|
||||
outputs= queue_df
|
||||
).then(fn=prepare_generate_video,
|
||||
inputs= [state],
|
||||
|
|
@ -4471,9 +4528,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||
).then(fn=save_inputs,
|
||||
inputs =[target_state] + gen_inputs,
|
||||
outputs= None
|
||||
).then(fn=process_prompt_and_add_tasks,
|
||||
inputs = [state, model_choice],
|
||||
outputs=queue_df
|
||||
).then(fn=handle_generate_or_add,
|
||||
inputs = [state, model_choice, batch_folder_input, batch_has_end_frames_cb,*gen_inputs_list_for_handler],
|
||||
outputs= queue_df
|
||||
).then(
|
||||
fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
|
||||
inputs=[state],
|
||||
|
|
@ -4984,28 +5041,12 @@ def create_demo():
|
|||
else:
|
||||
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
|
||||
|
||||
initial_lora_dir = get_lora_dir(transformer_filename)
|
||||
try:
|
||||
initial_loras, initial_loras_names, _, _, _, _, _ = setup_loras(
|
||||
transformer_filename, None, initial_lora_dir, "", None
|
||||
)
|
||||
print(f"Found {len(initial_loras_names)} initial loras for default model.")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not initially load loras for default model '{transformer_filename}': {e}")
|
||||
initial_loras_names = []
|
||||
|
||||
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
|
||||
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
||||
global model_list
|
||||
|
||||
tab_state = gr.State({ "tab_no":0 })
|
||||
|
||||
dropdown_types = transformer_types if len(transformer_types) > 0 else model_types
|
||||
available_model_files = []
|
||||
for model_type in dropdown_types:
|
||||
choice = get_model_filename(model_type, transformer_quantization)
|
||||
available_model_files.append(choice)
|
||||
|
||||
with gr.Tabs(selected="video_gen", ) as main_tabs:
|
||||
with gr.Tab("Video Generator", id="video_gen"):
|
||||
with gr.Row():
|
||||
|
|
@ -5024,30 +5065,6 @@ def create_demo():
|
|||
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
|
||||
gen_info, queue_accordion, video_guide, video_mask, video_prompt_type_video_trigger
|
||||
) = generate_video_tab(model_choice=model_choice, header=header)
|
||||
with gr.Tab("Extras", id="extras"):
|
||||
gr.Markdown("## Batch Queue Creator")
|
||||
gr.Markdown("Create a `queue.zip` file from a folder containing pairs of start/end images, sorted by modification time (oldest first).")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_folder_input = gr.Textbox(label="Image Folder Path", placeholder="/path/to/your/image_pairs")
|
||||
batch_prompt_input = gr.Textbox(label="Prompt for all tasks", lines=2)
|
||||
batch_lora_input = gr.Dropdown(label="LoRA (Optional)", choices=[""] + initial_loras_names, value="")
|
||||
batch_model_input = gr.Dropdown(label="Model Filename", choices=available_model_files, value=transformer_filename)
|
||||
batch_has_end_frames_cb = gr.Checkbox(label="Folder contains end frames", value=False)
|
||||
batch_generate_button = gr.Button("Generate Batch Queue (.zip)")
|
||||
with gr.Column(scale=1):
|
||||
batch_status_output = gr.Markdown("")
|
||||
batch_download_output = gr.DownloadButton(label="Download queue.zip", visible=False, interactive=True)
|
||||
|
||||
batch_generate_button.click(
|
||||
fn=create_batch_queue,
|
||||
inputs=[batch_folder_input, batch_prompt_input, batch_lora_input, batch_model_input, batch_has_end_frames_cb],
|
||||
outputs=[batch_download_output, batch_status_output]
|
||||
).then(
|
||||
fn=lambda filepath: gr.update(visible=bool(filepath), value=filepath),
|
||||
inputs=[batch_download_output],
|
||||
outputs=[batch_download_output]
|
||||
)
|
||||
with gr.Tab("Informations", id="info"):
|
||||
generate_info_tab()
|
||||
with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
|
||||
|
|
|
|||
Loading…
Reference in New Issue