Compare commits
13 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
6f34f0b862 | |
|
|
263ce31f26 | |
|
|
2cd9eb6c9c | |
|
|
f35d5954e0 | |
|
|
644492946c | |
|
|
466a0ee715 | |
|
|
a2761f21c5 | |
|
|
541ea19f3f | |
|
|
f00c6435ae | |
|
|
5d8f95b5d6 | |
|
|
1d3f4a2573 | |
|
|
aba3002b83 | |
|
|
3b066849de |
259
wgp.py
259
wgp.py
|
|
@ -33,6 +33,7 @@ import tempfile
|
|||
import atexit
|
||||
import shutil
|
||||
import glob
|
||||
from mutagen.mp4 import MP4, MP4Tags
|
||||
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
|
|
@ -52,29 +53,202 @@ task_id = 0
|
|||
# progress_tracker = {}
|
||||
# tracker_lock = threading.Lock()
|
||||
|
||||
# def download_ffmpeg():
|
||||
# if os.name != 'nt': return
|
||||
# exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
|
||||
# if all(os.path.exists(e) for e in exes): return
|
||||
# api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
|
||||
# r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
|
||||
# assets = r.json().get('assets', [])
|
||||
# zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
|
||||
# if not zip_asset: return
|
||||
# zip_url = zip_asset['browser_download_url']
|
||||
# zip_name = zip_asset['name']
|
||||
# with requests.get(zip_url, stream=True) as resp:
|
||||
# total = int(resp.headers.get('Content-Length', 0))
|
||||
# with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
|
||||
# for chunk in resp.iter_content(chunk_size=8192):
|
||||
# f.write(chunk)
|
||||
# pbar.update(len(chunk))
|
||||
# with zipfile.ZipFile(zip_name) as z:
|
||||
# for f in z.namelist():
|
||||
# if f.endswith(tuple(exes)) and '/bin/' in f:
|
||||
# z.extract(f)
|
||||
# os.rename(f, os.path.basename(f))
|
||||
# os.remove(zip_name)
|
||||
def download_ffmpeg():
|
||||
if os.name != 'nt': return
|
||||
exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
|
||||
if all(os.path.exists(e) for e in exes): return
|
||||
api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
|
||||
r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
|
||||
assets = r.json().get('assets', [])
|
||||
zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
|
||||
if not zip_asset: return
|
||||
zip_url = zip_asset['browser_download_url']
|
||||
zip_name = zip_asset['name']
|
||||
with requests.get(zip_url, stream=True) as resp:
|
||||
total = int(resp.headers.get('Content-Length', 0))
|
||||
with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
pbar.update(len(chunk))
|
||||
with zipfile.ZipFile(zip_name) as z:
|
||||
for f in z.namelist():
|
||||
if f.endswith(tuple(exes)) and '/bin/' in f:
|
||||
z.extract(f)
|
||||
os.rename(f, os.path.basename(f))
|
||||
os.remove(zip_name)
|
||||
|
||||
def extract_parameters_from_video(video_filepath):
|
||||
if not video_filepath or not hasattr(video_filepath, 'name') or not os.path.exists(video_filepath.name):
|
||||
print("No valid video file provided for parameter extraction.")
|
||||
return None, "No valid video file provided."
|
||||
|
||||
filepath = video_filepath.name
|
||||
print(f"Attempting to extract parameters from: {filepath}")
|
||||
|
||||
try:
|
||||
video = MP4(filepath)
|
||||
if isinstance(video.tags, MP4Tags) and '©cmt' in video.tags:
|
||||
comment_tag_value = video.tags['©cmt'][0]
|
||||
params = json.loads(comment_tag_value)
|
||||
print(f"Successfully extracted parameters: {list(params.keys())}")
|
||||
return params
|
||||
else:
|
||||
print("No '©cmt' metadata tag found in the video.")
|
||||
return None
|
||||
except mutagen.MutagenError as e:
|
||||
print(f"Error reading video file with mutagen: {e}")
|
||||
return None
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON from metadata: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during parameter extraction: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def get_lora_indices(activated_lora_filenames, state):
|
||||
indices = []
|
||||
loras_full_paths = state.get("loras") if isinstance(state.get("loras"), list) else []
|
||||
if not loras_full_paths:
|
||||
print("Warning: Lora list not found or invalid in state during parameter application.")
|
||||
return []
|
||||
|
||||
lora_filenames_in_state = [os.path.basename(p) for p in loras_full_paths if isinstance(p, str)]
|
||||
|
||||
if not isinstance(activated_lora_filenames, list):
|
||||
print(f"Warning: 'activated_loras' parameter is not a list ({type(activated_lora_filenames)}). Skipping Lora loading.")
|
||||
return []
|
||||
|
||||
for filename in activated_lora_filenames:
|
||||
if not isinstance(filename, str):
|
||||
print(f"Warning: Non-string filename found in activated_loras: {filename}. Skipping.")
|
||||
continue
|
||||
try:
|
||||
idx = lora_filenames_in_state.index(filename)
|
||||
indices.append(str(idx))
|
||||
except ValueError:
|
||||
print(f"Warning: Loaded Lora '{filename}' not found in current Lora list. Skipping.")
|
||||
except Exception as e:
|
||||
print(f"Error processing Lora filename '{filename}': {e}")
|
||||
return indices
|
||||
|
||||
def apply_parameters_to_ui(params_dict, state, *components):
|
||||
try:
|
||||
component_param_names = list(inspect.signature(save_inputs).parameters)[1:-1]
|
||||
except NameError:
|
||||
print("CRITICAL ERROR: save_inputs function not defined when apply_parameters_to_ui is called.")
|
||||
return tuple([gr.update()] * len(components))
|
||||
|
||||
num_expected_params = len(component_param_names)
|
||||
num_received_components = len(components)
|
||||
|
||||
updates_list = [gr.update()] * num_received_components
|
||||
|
||||
if num_expected_params != num_received_components:
|
||||
print(f"Warning in apply_parameters_to_ui: Mismatch between expected params ({num_expected_params}) and received components ({num_received_components}). Proceeding by matching names to the expected number of components.")
|
||||
|
||||
param_name_to_expected_index = {name: i for i, name in enumerate(component_param_names)}
|
||||
|
||||
if not params_dict or not isinstance(params_dict, dict):
|
||||
print("No parameters provided or invalid format for UI update.")
|
||||
return tuple(updates_list)
|
||||
|
||||
print(f"Applying parameters: {list(params_dict.keys())}")
|
||||
|
||||
lora_choices_comp_name = 'loras_choices'
|
||||
lora_mult_comp_name = 'loras_multipliers'
|
||||
if lora_choices_comp_name in param_name_to_expected_index and lora_mult_comp_name in param_name_to_expected_index:
|
||||
idx_choices = param_name_to_expected_index[lora_choices_comp_name]
|
||||
idx_mult = param_name_to_expected_index[lora_mult_comp_name]
|
||||
|
||||
if idx_choices < num_received_components and idx_mult < num_received_components:
|
||||
activated_loras = params_dict.get('activated_loras', [])
|
||||
lora_indices = get_lora_indices(activated_loras, state)
|
||||
updates_list[idx_choices] = gr.update(value=lora_indices)
|
||||
|
||||
loras_mult_value = params_dict.get('loras_multipliers', '')
|
||||
updates_list[idx_mult] = gr.update(value=loras_mult_value)
|
||||
else:
|
||||
print(f"Warning: Lora component indices ({idx_choices}, {idx_mult}) out of bounds for received components ({num_received_components}).")
|
||||
|
||||
vpt_key = 'video_prompt_type'
|
||||
vpt_guide_comp_name = 'video_prompt_type_video_guide'
|
||||
vpt_refs_comp_name = 'video_prompt_type_image_refs'
|
||||
|
||||
if vpt_key in params_dict and vpt_guide_comp_name in param_name_to_expected_index and vpt_refs_comp_name in param_name_to_expected_index:
|
||||
idx_guide = param_name_to_expected_index[vpt_guide_comp_name]
|
||||
idx_refs = param_name_to_expected_index[vpt_refs_comp_name]
|
||||
|
||||
if idx_guide < num_received_components and idx_refs < num_received_components:
|
||||
loaded_video_prompt_type = params_dict.get(vpt_key, '')
|
||||
|
||||
image_refs_value = "I" if "I" in loaded_video_prompt_type else ""
|
||||
updates_list[idx_refs] = gr.update(value=image_refs_value)
|
||||
|
||||
guide_dd_value = ""
|
||||
if "PV" in loaded_video_prompt_type: guide_dd_value = "PV"
|
||||
elif "DV" in loaded_video_prompt_type: guide_dd_value = "DV"
|
||||
elif "CV" in loaded_video_prompt_type: guide_dd_value = "CV"
|
||||
elif "MV" in loaded_video_prompt_type: guide_dd_value = "MV"
|
||||
elif "V" in loaded_video_prompt_type: guide_dd_value = "V"
|
||||
updates_list[idx_guide] = gr.update(value=guide_dd_value)
|
||||
else:
|
||||
print(f"Warning: Video prompt type component indices ({idx_guide}, {idx_refs}) out of bounds for received components ({num_received_components}).")
|
||||
|
||||
handled_keys = {'activated_loras', 'loras_multipliers', 'video_prompt_type'}
|
||||
for key, value in params_dict.items():
|
||||
if key in handled_keys:
|
||||
continue
|
||||
|
||||
if key in param_name_to_expected_index:
|
||||
idx = param_name_to_expected_index[key]
|
||||
|
||||
if idx >= num_received_components:
|
||||
print(f"Warning: Index {idx} for key '{key}' is out of bounds for received components ({num_received_components}). Skipping update.")
|
||||
continue
|
||||
|
||||
target_component = components[idx]
|
||||
processed_value = value
|
||||
|
||||
try:
|
||||
if key == 'remove_background_image_ref' and isinstance(target_component, gr.Checkbox):
|
||||
processed_value = 1 if value == 1 or str(value).lower() == 'true' else 0
|
||||
elif isinstance(target_component, (gr.Slider, gr.Number)):
|
||||
try:
|
||||
temp_val = float(value)
|
||||
processed_value = int(temp_val) if temp_val.is_integer() else temp_val
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
print(f"Warning: Could not convert {key} value '{value}' to number. Using raw value.")
|
||||
processed_value = value
|
||||
elif isinstance(target_component, gr.Dropdown):
|
||||
is_multiselect = getattr(target_component, 'multiselect', False)
|
||||
if is_multiselect:
|
||||
if not isinstance(value, list):
|
||||
print(f"Warning: Expected list for multiselect {key}, got {type(value)}. Resetting to empty list.")
|
||||
processed_value = []
|
||||
else:
|
||||
processed_value = [str(item) for item in value]
|
||||
else:
|
||||
if value is None:
|
||||
processed_value = ''
|
||||
else:
|
||||
processed_value = value
|
||||
elif isinstance(target_component, gr.Textbox):
|
||||
processed_value = str(value) if value is not None else ""
|
||||
elif isinstance(target_component, gr.Radio):
|
||||
processed_value = str(value) if value is not None else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during type processing for key '{key}' with value '{value}': {e}. Using raw value.")
|
||||
processed_value = value
|
||||
|
||||
if processed_value is not None:
|
||||
updates_list[idx] = gr.update(value=processed_value)
|
||||
else:
|
||||
updates_list[idx] = gr.update(value="")
|
||||
|
||||
print(f"Parameter application generated {len(updates_list)} updates.")
|
||||
return tuple(updates_list)
|
||||
|
||||
def format_time(seconds):
|
||||
if seconds < 60:
|
||||
|
|
@ -3799,6 +3973,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
load_params_video_input = gr.File(
|
||||
label="Load Parameters from Video Metadata",
|
||||
file_types=[".mp4"],
|
||||
type="filepath",
|
||||
)
|
||||
with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
|
||||
with gr.Row(elem_id="image-modal-close-button-row"):
|
||||
close_modal_button = gr.Button("❌", size="sm")
|
||||
|
|
@ -4217,6 +4397,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||
hidden_force_quit_trigger = gr.Button("force_quit", visible=False, elem_id="force_quit_btn_hidden")
|
||||
hidden_countdown_state = gr.Number(value=-1, visible=False, elem_id="hidden_countdown_state_num")
|
||||
single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn")
|
||||
extracted_params_state = gr.State({})
|
||||
|
||||
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
|
||||
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab,
|
||||
|
|
@ -4374,6 +4555,36 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||
|
||||
start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js = get_timer_js()
|
||||
|
||||
load_params_video_input.upload(
|
||||
fn=extract_parameters_from_video,
|
||||
inputs=[load_params_video_input],
|
||||
outputs=[extracted_params_state]
|
||||
).then(
|
||||
fn=apply_parameters_to_ui,
|
||||
inputs=[extracted_params_state, state] + gen_inputs,
|
||||
outputs=gen_inputs
|
||||
).then(
|
||||
fn=switch_prompt_type,
|
||||
inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars],
|
||||
outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
|
||||
).then(
|
||||
fn=refresh_image_prompt_type,
|
||||
inputs=[state, image_prompt_type],
|
||||
outputs=[image_start, image_end]
|
||||
).then(
|
||||
fn=lambda vpt_guide_val, vpt_refs_val: (
|
||||
gr.update(visible="I" in vpt_refs_val), gr.update(visible="I" in vpt_refs_val),
|
||||
gr.update(visible="V" in vpt_guide_val or "M" in vpt_guide_val or "P" in vpt_guide_val or "D" in vpt_guide_val or "C" in vpt_guide_val),
|
||||
gr.update(visible="V" in vpt_guide_val or "M" in vpt_guide_val or "P" in vpt_guide_val or "D" in vpt_guide_val or "C" in vpt_guide_val),
|
||||
gr.update(visible="M" in vpt_guide_val)
|
||||
),
|
||||
inputs=[video_prompt_type_video_guide, video_prompt_type_image_refs],
|
||||
outputs=[
|
||||
image_refs, remove_background_image_ref,
|
||||
video_guide, keep_frames_video_guide, video_mask
|
||||
]
|
||||
)
|
||||
|
||||
single_hidden_trigger_btn.click(
|
||||
fn=show_countdown_info_from_state,
|
||||
inputs=[hidden_countdown_state],
|
||||
|
|
@ -5076,7 +5287,7 @@ def create_demo():
|
|||
|
||||
if __name__ == "__main__":
|
||||
atexit.register(autosave_queue)
|
||||
# download_ffmpeg()
|
||||
download_ffmpeg()
|
||||
# threading.Thread(target=runner, daemon=True).start()
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
server_port = int(args.server_port)
|
||||
|
|
|
|||
Loading…
Reference in New Issue