From f35d5954e0571ae899ee186216abcd68e216650d Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 25 Apr 2025 19:48:19 +1000 Subject: [PATCH] fixed missing params, made parameter loading less hardcoded --- wgp.py | 195 ++++++++++++++++++++++++++++++++------------------------- 1 file changed, 111 insertions(+), 84 deletions(-) diff --git a/wgp.py b/wgp.py index 60682d8..d9720b3 100644 --- a/wgp.py +++ b/wgp.py @@ -106,95 +106,109 @@ def extract_parameters_from_video(video_filepath): traceback.print_exc() return None +def get_lora_indices(activated_lora_filenames, state): + indices = [] + loras_full_paths = state.get("loras") if isinstance(state.get("loras"), list) else [] + if not loras_full_paths: + print("Warning: Lora list not found or invalid in state during parameter application.") + return [] + + lora_filenames_in_state = [os.path.basename(p) for p in loras_full_paths if isinstance(p, str)] + + if not isinstance(activated_lora_filenames, list): + print(f"Warning: 'activated_loras' parameter is not a list ({type(activated_lora_filenames)}). Skipping Lora loading.") + return [] + + for filename in activated_lora_filenames: + if not isinstance(filename, str): + print(f"Warning: Non-string filename found in activated_loras: {filename}. Skipping.") + continue + try: + idx = lora_filenames_in_state.index(filename) + indices.append(str(idx)) + except ValueError: + print(f"Warning: Loaded Lora '{filename}' not found in current Lora list. Skipping.") + except Exception as e: + print(f"Error processing Lora filename '{filename}': {e}") + return indices + def apply_parameters_to_ui(params_dict, state): + component_keys_map = [ + ('prompt', ''), ('negative_prompt', ''), ('resolution', '832x480'), ('video_length', 81), + ('seed', -1), ('num_inference_steps', 30), ('guidance_scale', 5.0), ('flow_shift', 5.0), + ('repeat_generation', 1), ('multi_images_gen_type', 0), ('tea_cache_setting', 0.0), ('tea_cache_start_step_perc', 0), + ('loras_choices', []), ('loras_multipliers', ''), + ('image_prompt_type', 'S'), + ('video_prompt_type_video_guide', ''), + ('video_prompt_type_image_refs', ''), + ('camera_type', 1), + ('keep_frames', ''), ('remove_background_image_ref', 1), + ('sliding_window_repeat', 0), ('sliding_window_overlap', 16), ('sliding_window_discard_last_frames', 4), + ('temporal_upsampling', ''), ('spatial_upsampling', ''), + ('RIFLEx_setting', 0), ('slg_switch', 0), ('slg_layers', []), + ('slg_start_perc', 10), ('slg_end_perc', 90), + ('cfg_star_switch', 0), ('cfg_zero_step', -1) + ] + if not params_dict or not isinstance(params_dict, dict): print("No parameters provided or invalid format for UI update.") - return gr.Info("No parameters loaded or parameters were invalid.") + return tuple([gr.update()] * len(component_keys_map)) - print("Applying parameters to UI...") - ui_updates = {} - current_model_filename = state["model_filename"] + ui_update_values = {key: default for key, default in component_keys_map} - def get_lora_indices(activated_lora_filenames, state): - indices = [] - loras_full_paths = state.get("loras", []) - if not loras_full_paths: - print("Warning: Lora list not found in state during parameter application.") - return [] + activated_loras = params_dict.get('activated_loras', []) + ui_update_values['loras_choices'] = get_lora_indices(activated_loras, state) + ui_update_values['loras_multipliers'] = params_dict.get('loras_multipliers', '') - lora_filenames_in_state = [os.path.basename(p) for p in loras_full_paths] + loaded_video_prompt_type = params_dict.get('video_prompt_type', '') + ui_update_values['video_prompt_type_image_refs'] = "I" if "I" in loaded_video_prompt_type else "" - for filename in activated_lora_filenames: + guide_dd_value = "" + guide_letters = "ODPCMV" + if "PV" in loaded_video_prompt_type: guide_dd_value = "PV" + elif "DV" in loaded_video_prompt_type: guide_dd_value = "DV" + elif "CV" in loaded_video_prompt_type: guide_dd_value = "CV" + elif "MV" in loaded_video_prompt_type: guide_dd_value = "MV" + elif "V" in loaded_video_prompt_type: guide_dd_value = "V" + + ui_update_values['video_prompt_type_video_guide'] = guide_dd_value + + handled_keys = {'activated_loras', 'loras_multipliers', 'video_prompt_type'} + for key, default in component_keys_map: + if key in handled_keys or key.startswith('video_prompt_type_'): + continue + if key in params_dict: + value = params_dict[key] try: - idx = lora_filenames_in_state.index(filename) - indices.append(str(idx)) - except ValueError: - print(f"Warning: Loaded Lora '{filename}' not found in current Lora list. Skipping.") - return indices + current_type = type(default) + if value is None: + value = default + print(f"Parameter '{key}': Received None, using default ({value}).") + if current_type == int: + value = int(float(value)) + elif current_type == float: + value = float(value) + elif current_type == str: + value = str(value) + elif current_type == list: + if not isinstance(value, list): + print(f"Warning: Parameter '{key}' expected list, got {type(value)}. Using default.") + value = default + if key == 'remove_background_image_ref': + value = int(value) + ui_update_values[key] = value + except (ValueError, TypeError, Exception) as e: + print(f"Warning: Parameter '{key}': Error processing value '{value}' ({e}). Using default '{default}'.") + ui_update_values[key] = default - ui_updates['prompt'] = gr.update(value=params_dict.get('prompt', '')) - ui_updates['negative_prompt'] = gr.update(value=params_dict.get('negative_prompt', '')) - ui_updates['resolution'] = gr.update(value=params_dict.get('resolution')) - ui_updates['video_length'] = gr.update(value=params_dict.get('video_length')) - ui_updates['seed'] = gr.update(value=params_dict.get('seed', -1)) - ui_updates['num_inference_steps'] = gr.update(value=params_dict.get('num_inference_steps')) - ui_updates['guidance_scale'] = gr.update(value=params_dict.get('guidance_scale')) - ui_updates['flow_shift'] = gr.update(value=params_dict.get('flow_shift')) - # ui_updates['embedded_guidance_scale'] = gr.update(value=params_dict.get('embedded_guidance_scale')) # Hidden? Check UI - ui_updates['repeat_generation'] = gr.update(value=params_dict.get('repeat_generation', 1)) - ui_updates['multi_images_gen_type'] = gr.update(value=params_dict.get('multi_images_gen_type', 0)) - ui_updates['tea_cache_setting'] = gr.update(value=float(params_dict.get('tea_cache_setting', 0))) - ui_updates['tea_cache_start_step_perc'] = gr.update(value=params_dict.get('tea_cache_start_step_perc', 0)) + updates = [] + for key, _ in component_keys_map: + value_to_set = ui_update_values.get(key) + updates.append(gr.update(value=value_to_set)) - activated_lora_filenames = params_dict.get('activated_loras', []) - lora_indices = get_lora_indices(activated_lora_filenames, state) - ui_updates['loras_choices'] = gr.update(value=lora_indices) - ui_updates['loras_multipliers'] = gr.update(value=params_dict.get('loras_multipliers', '')) - - ui_updates['image_prompt_type'] = gr.update(value=params_dict.get('image_prompt_type', 'S')) - ui_updates['video_prompt_type'] = gr.update(value=params_dict.get('video_prompt_type', '')) - ui_updates['keep_frames'] = gr.update(value=params_dict.get('keep_frames', '')) - ui_updates['remove_background_image_ref'] = gr.update(value=params_dict.get('remove_background_image_ref', 1)) - - ui_updates['sliding_window_repeat'] = gr.update(value=params_dict.get('sliding_window_repeat', 0)) - ui_updates['sliding_window_overlap'] = gr.update(value=params_dict.get('sliding_window_overlap', 16)) - ui_updates['sliding_window_discard_last_frames'] = gr.update(value=params_dict.get('sliding_window_discard_last_frames', 4)) - - ui_updates['temporal_upsampling'] = gr.update(value=params_dict.get('temporal_upsampling', '')) - ui_updates['spatial_upsampling'] = gr.update(value=params_dict.get('spatial_upsampling', '')) - - ui_updates['RIFLEx_setting'] = gr.update(value=params_dict.get('RIFLEx_setting', 0)) - ui_updates['slg_switch'] = gr.update(value=params_dict.get('slg_switch', 0)) - slg_layers_val = params_dict.get('slg_layers', [9]) - if slg_layers_val is None: slg_layers_val = [] - if isinstance(slg_layers_val, list): - slg_layers_val = [str(i) for i in slg_layers_val] - ui_updates['slg_layers'] = gr.update(value=slg_layers_val) - - ui_updates['slg_start_perc'] = gr.update(value=params_dict.get('slg_start_perc', 10)) - ui_updates['slg_end_perc'] = gr.update(value=params_dict.get('slg_end_perc', 90)) - ui_updates['cfg_star_switch'] = gr.update(value=params_dict.get('cfg_star_switch', 0)) - ui_updates['cfg_zero_step'] = gr.update(value=params_dict.get('cfg_zero_step', -1)) - - ordered_keys = [ - 'prompt', 'negative_prompt', 'resolution', 'video_length', 'seed', 'num_inference_steps', - 'guidance_scale', 'flow_shift', - 'repeat_generation', 'multi_images_gen_type', - 'loras_choices', 'loras_multipliers', - 'image_prompt_type', 'video_prompt_type', 'keep_frames', 'remove_background_image_ref', - 'sliding_window_repeat', 'sliding_window_overlap', 'sliding_window_discard_last_frames', - 'temporal_upsampling', 'spatial_upsampling', - 'RIFLEx_setting', 'slg_switch', 'slg_layers', 'slg_start_perc', 'slg_end_perc', - 'cfg_star_switch', 'cfg_zero_step' - ] - return_values = [] - for key in ordered_keys: - update_instruction = ui_updates.get(key, gr.update()) - return_values.append(update_instruction) - - print("Parameter application mapping complete.") - return tuple(return_values) + print(f"Parameter application direct updates created ({len(updates)} updates).") + return tuple(updates) def format_time(seconds): if seconds < 60: @@ -4109,13 +4123,26 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non updatable_ui_components = [ prompt, negative_prompt, resolution, video_length, seed, num_inference_steps, guidance_scale, flow_shift, - repeat_generation, multi_images_gen_type, + repeat_generation, multi_images_gen_type, tea_cache_setting, tea_cache_start_step_perc, loras_choices, loras_multipliers, - image_prompt_type, video_prompt_type, keep_frames, remove_background_image_ref, - sliding_window_repeat, sliding_window_overlap, sliding_window_discard_last_frames, - temporal_upsampling, spatial_upsampling, - RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, - cfg_star_switch, cfg_zero_step + image_prompt_type, + video_prompt_type_video_guide, + video_prompt_type_image_refs, + camera_type, + keep_frames, + remove_background_image_ref, + sliding_window_repeat, + sliding_window_overlap, + sliding_window_discard_last_frames, + temporal_upsampling, + spatial_upsampling, + RIFLEx_setting, + slg_switch, + slg_layers, + slg_start_perc, + slg_end_perc, + cfg_star_switch, + cfg_zero_step ] load_params_video_input.upload( fn=extract_parameters_from_video,