fixed missing params, made parameter loading less hardcoded

This commit is contained in:
Chris Malone 2025-04-25 19:48:19 +10:00
parent 644492946c
commit f35d5954e0
1 changed files with 111 additions and 84 deletions

195
wgp.py
View File

@ -106,95 +106,109 @@ def extract_parameters_from_video(video_filepath):
traceback.print_exc()
return None
def get_lora_indices(activated_lora_filenames, state):
indices = []
loras_full_paths = state.get("loras") if isinstance(state.get("loras"), list) else []
if not loras_full_paths:
print("Warning: Lora list not found or invalid in state during parameter application.")
return []
lora_filenames_in_state = [os.path.basename(p) for p in loras_full_paths if isinstance(p, str)]
if not isinstance(activated_lora_filenames, list):
print(f"Warning: 'activated_loras' parameter is not a list ({type(activated_lora_filenames)}). Skipping Lora loading.")
return []
for filename in activated_lora_filenames:
if not isinstance(filename, str):
print(f"Warning: Non-string filename found in activated_loras: {filename}. Skipping.")
continue
try:
idx = lora_filenames_in_state.index(filename)
indices.append(str(idx))
except ValueError:
print(f"Warning: Loaded Lora '{filename}' not found in current Lora list. Skipping.")
except Exception as e:
print(f"Error processing Lora filename '{filename}': {e}")
return indices
def apply_parameters_to_ui(params_dict, state):
component_keys_map = [
('prompt', ''), ('negative_prompt', ''), ('resolution', '832x480'), ('video_length', 81),
('seed', -1), ('num_inference_steps', 30), ('guidance_scale', 5.0), ('flow_shift', 5.0),
('repeat_generation', 1), ('multi_images_gen_type', 0), ('tea_cache_setting', 0.0), ('tea_cache_start_step_perc', 0),
('loras_choices', []), ('loras_multipliers', ''),
('image_prompt_type', 'S'),
('video_prompt_type_video_guide', ''),
('video_prompt_type_image_refs', ''),
('camera_type', 1),
('keep_frames', ''), ('remove_background_image_ref', 1),
('sliding_window_repeat', 0), ('sliding_window_overlap', 16), ('sliding_window_discard_last_frames', 4),
('temporal_upsampling', ''), ('spatial_upsampling', ''),
('RIFLEx_setting', 0), ('slg_switch', 0), ('slg_layers', []),
('slg_start_perc', 10), ('slg_end_perc', 90),
('cfg_star_switch', 0), ('cfg_zero_step', -1)
]
if not params_dict or not isinstance(params_dict, dict):
print("No parameters provided or invalid format for UI update.")
return gr.Info("No parameters loaded or parameters were invalid.")
return tuple([gr.update()] * len(component_keys_map))
print("Applying parameters to UI...")
ui_updates = {}
current_model_filename = state["model_filename"]
ui_update_values = {key: default for key, default in component_keys_map}
def get_lora_indices(activated_lora_filenames, state):
indices = []
loras_full_paths = state.get("loras", [])
if not loras_full_paths:
print("Warning: Lora list not found in state during parameter application.")
return []
activated_loras = params_dict.get('activated_loras', [])
ui_update_values['loras_choices'] = get_lora_indices(activated_loras, state)
ui_update_values['loras_multipliers'] = params_dict.get('loras_multipliers', '')
lora_filenames_in_state = [os.path.basename(p) for p in loras_full_paths]
loaded_video_prompt_type = params_dict.get('video_prompt_type', '')
ui_update_values['video_prompt_type_image_refs'] = "I" if "I" in loaded_video_prompt_type else ""
for filename in activated_lora_filenames:
guide_dd_value = ""
guide_letters = "ODPCMV"
if "PV" in loaded_video_prompt_type: guide_dd_value = "PV"
elif "DV" in loaded_video_prompt_type: guide_dd_value = "DV"
elif "CV" in loaded_video_prompt_type: guide_dd_value = "CV"
elif "MV" in loaded_video_prompt_type: guide_dd_value = "MV"
elif "V" in loaded_video_prompt_type: guide_dd_value = "V"
ui_update_values['video_prompt_type_video_guide'] = guide_dd_value
handled_keys = {'activated_loras', 'loras_multipliers', 'video_prompt_type'}
for key, default in component_keys_map:
if key in handled_keys or key.startswith('video_prompt_type_'):
continue
if key in params_dict:
value = params_dict[key]
try:
idx = lora_filenames_in_state.index(filename)
indices.append(str(idx))
except ValueError:
print(f"Warning: Loaded Lora '{filename}' not found in current Lora list. Skipping.")
return indices
current_type = type(default)
if value is None:
value = default
print(f"Parameter '{key}': Received None, using default ({value}).")
if current_type == int:
value = int(float(value))
elif current_type == float:
value = float(value)
elif current_type == str:
value = str(value)
elif current_type == list:
if not isinstance(value, list):
print(f"Warning: Parameter '{key}' expected list, got {type(value)}. Using default.")
value = default
if key == 'remove_background_image_ref':
value = int(value)
ui_update_values[key] = value
except (ValueError, TypeError, Exception) as e:
print(f"Warning: Parameter '{key}': Error processing value '{value}' ({e}). Using default '{default}'.")
ui_update_values[key] = default
ui_updates['prompt'] = gr.update(value=params_dict.get('prompt', ''))
ui_updates['negative_prompt'] = gr.update(value=params_dict.get('negative_prompt', ''))
ui_updates['resolution'] = gr.update(value=params_dict.get('resolution'))
ui_updates['video_length'] = gr.update(value=params_dict.get('video_length'))
ui_updates['seed'] = gr.update(value=params_dict.get('seed', -1))
ui_updates['num_inference_steps'] = gr.update(value=params_dict.get('num_inference_steps'))
ui_updates['guidance_scale'] = gr.update(value=params_dict.get('guidance_scale'))
ui_updates['flow_shift'] = gr.update(value=params_dict.get('flow_shift'))
# ui_updates['embedded_guidance_scale'] = gr.update(value=params_dict.get('embedded_guidance_scale')) # Hidden? Check UI
ui_updates['repeat_generation'] = gr.update(value=params_dict.get('repeat_generation', 1))
ui_updates['multi_images_gen_type'] = gr.update(value=params_dict.get('multi_images_gen_type', 0))
ui_updates['tea_cache_setting'] = gr.update(value=float(params_dict.get('tea_cache_setting', 0)))
ui_updates['tea_cache_start_step_perc'] = gr.update(value=params_dict.get('tea_cache_start_step_perc', 0))
updates = []
for key, _ in component_keys_map:
value_to_set = ui_update_values.get(key)
updates.append(gr.update(value=value_to_set))
activated_lora_filenames = params_dict.get('activated_loras', [])
lora_indices = get_lora_indices(activated_lora_filenames, state)
ui_updates['loras_choices'] = gr.update(value=lora_indices)
ui_updates['loras_multipliers'] = gr.update(value=params_dict.get('loras_multipliers', ''))
ui_updates['image_prompt_type'] = gr.update(value=params_dict.get('image_prompt_type', 'S'))
ui_updates['video_prompt_type'] = gr.update(value=params_dict.get('video_prompt_type', ''))
ui_updates['keep_frames'] = gr.update(value=params_dict.get('keep_frames', ''))
ui_updates['remove_background_image_ref'] = gr.update(value=params_dict.get('remove_background_image_ref', 1))
ui_updates['sliding_window_repeat'] = gr.update(value=params_dict.get('sliding_window_repeat', 0))
ui_updates['sliding_window_overlap'] = gr.update(value=params_dict.get('sliding_window_overlap', 16))
ui_updates['sliding_window_discard_last_frames'] = gr.update(value=params_dict.get('sliding_window_discard_last_frames', 4))
ui_updates['temporal_upsampling'] = gr.update(value=params_dict.get('temporal_upsampling', ''))
ui_updates['spatial_upsampling'] = gr.update(value=params_dict.get('spatial_upsampling', ''))
ui_updates['RIFLEx_setting'] = gr.update(value=params_dict.get('RIFLEx_setting', 0))
ui_updates['slg_switch'] = gr.update(value=params_dict.get('slg_switch', 0))
slg_layers_val = params_dict.get('slg_layers', [9])
if slg_layers_val is None: slg_layers_val = []
if isinstance(slg_layers_val, list):
slg_layers_val = [str(i) for i in slg_layers_val]
ui_updates['slg_layers'] = gr.update(value=slg_layers_val)
ui_updates['slg_start_perc'] = gr.update(value=params_dict.get('slg_start_perc', 10))
ui_updates['slg_end_perc'] = gr.update(value=params_dict.get('slg_end_perc', 90))
ui_updates['cfg_star_switch'] = gr.update(value=params_dict.get('cfg_star_switch', 0))
ui_updates['cfg_zero_step'] = gr.update(value=params_dict.get('cfg_zero_step', -1))
ordered_keys = [
'prompt', 'negative_prompt', 'resolution', 'video_length', 'seed', 'num_inference_steps',
'guidance_scale', 'flow_shift',
'repeat_generation', 'multi_images_gen_type',
'loras_choices', 'loras_multipliers',
'image_prompt_type', 'video_prompt_type', 'keep_frames', 'remove_background_image_ref',
'sliding_window_repeat', 'sliding_window_overlap', 'sliding_window_discard_last_frames',
'temporal_upsampling', 'spatial_upsampling',
'RIFLEx_setting', 'slg_switch', 'slg_layers', 'slg_start_perc', 'slg_end_perc',
'cfg_star_switch', 'cfg_zero_step'
]
return_values = []
for key in ordered_keys:
update_instruction = ui_updates.get(key, gr.update())
return_values.append(update_instruction)
print("Parameter application mapping complete.")
return tuple(return_values)
print(f"Parameter application direct updates created ({len(updates)} updates).")
return tuple(updates)
def format_time(seconds):
if seconds < 60:
@ -4109,13 +4123,26 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
updatable_ui_components = [
prompt, negative_prompt, resolution, video_length, seed, num_inference_steps,
guidance_scale, flow_shift,
repeat_generation, multi_images_gen_type,
repeat_generation, multi_images_gen_type, tea_cache_setting, tea_cache_start_step_perc,
loras_choices, loras_multipliers,
image_prompt_type, video_prompt_type, keep_frames, remove_background_image_ref,
sliding_window_repeat, sliding_window_overlap, sliding_window_discard_last_frames,
temporal_upsampling, spatial_upsampling,
RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc,
cfg_star_switch, cfg_zero_step
image_prompt_type,
video_prompt_type_video_guide,
video_prompt_type_image_refs,
camera_type,
keep_frames,
remove_background_image_ref,
sliding_window_repeat,
sliding_window_overlap,
sliding_window_discard_last_frames,
temporal_upsampling,
spatial_upsampling,
RIFLEx_setting,
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
cfg_star_switch,
cfg_zero_step
]
load_params_video_input.upload(
fn=extract_parameters_from_video,