""" Gradio demo for MegaStyle-FLUX style transfer. Usage: python gradio_demo.py --ckpt_path models/megastyle_flux.safetensors # Then open http://localhost:8080 """ import os import argparse import random import torch import gradio as gr from PIL import Image from flux_image_mega import FluxImagePipeline, ModelConfig try: import torch_npu # noqa: F401 except Exception: pass def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--ckpt_path", type=str, default="models/megastyle_flux.safetensors", help="Path to MegaStyle-FLUX LoRA checkpoint.", ) parser.add_argument( "--ref_path", type=str, default="ref_styles", help="Directory with example reference style images.", ) parser.add_argument("--server_name", type=str, default="0.0.0.0") parser.add_argument("--server_port", type=int, default=8080) parser.add_argument("--share", action="store_true") return parser.parse_args() def pick_device(): if torch.cuda.is_available(): return torch.device("cuda") if getattr(torch, "npu", None) and torch.npu.is_available(): return torch.device("npu") return torch.device("cpu") def load_pipeline(ckpt_path: str, device: torch.device) -> FluxImagePipeline: pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device=device, model_configs=[ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ], ) pipe.load_lora(pipe.dit, ckpt_path, alpha=1) return pipe def build_examples(ref_path: str): if not os.path.isdir(ref_path): return [] files = sorted( os.path.join(ref_path, f) for f in os.listdir(ref_path) if f.lower().endswith(".jpg") ) default_prompts = ["A bench", "A car", "A house with a tree beside"] examples = [] for i, f in enumerate(files[:12]): examples.append([f, default_prompts[i % len(default_prompts)]]) return examples def main(): args = parse_args() device = pick_device() print(f"[MegaStyle] Loading pipeline on {device} ...") pipe = load_pipeline(args.ckpt_path, device) print("[MegaStyle] Pipeline ready.") @torch.no_grad() def generate(style_image, prompt, height, width, num_inference_steps, embedded_guidance, seed, randomize_seed): if style_image is None: raise gr.Error("Please provide a reference style image.") if not prompt or not prompt.strip(): raise gr.Error("Please provide a content prompt.") if randomize_seed or seed is None or int(seed) < 0: seed = random.randint(0, 2**31 - 1) seed = int(seed) style_image = style_image.convert("RGB").resize((int(width), int(height))) image = pipe( prompt=prompt, height=int(height), width=int(width), ipadapter_images=style_image, seed=seed, num_inference_steps=int(num_inference_steps), embedded_guidance=float(embedded_guidance), enable_shift_rope=True, ) return image, seed header_md = """ # MegaStyle-FLUX: Style Transfer Demo **MegaStyle** is a scalable data curation pipeline that explores the consistent text-to-image style mapping ability of modern T2I models to build an intra-style consistent, inter-style diverse and high-quality style dataset — **MegaStyle-1.4M** (1.4M images across 170K curated style prompts and 400K content prompts). Trained on MegaStyle-1.4M, **MegaStyle-FLUX** performs generalizable reference-based style transfer: given any *reference style image* and a *content prompt*, it synthesizes the described content while faithfully preserving the reference's style. > Upload a reference style image on the left and enter a content prompt, > then click **Generate**. **References** - Paper: arXiv:2604.08364 - Project page: jeoyal.github.io/MegaStyle - Code: github.com/Tencent/MegaStyle - Model weights (MegaStyle-FLUX / Encoder): HuggingFace · ModelScope - Dataset (MegaStyle-1.4M): HuggingFace · ModelScope - Base model: FLUX.1-dev · Style encoder vision backbone: SigLIP-so400m - Built on top of DiffSynth-Studio """ footer_md = r""" --- ### Citation If this work is helpful for your research, please consider citing: ```bibtex @article{gao2026megastyle, title = {MegaStyle: Constructing Diverse and Scalable Style Dataset via Consistent Text-to-Image Style Mapping}, author = {Gao, Junyao and Liu, Sibo and Li, Jiaxing and Sun, Yanan and Tu, Yuanpeng and Shen, Fei and Zhang, Weidong and Zhao, Cairong and Zhang, Jun}, journal = {arXiv preprint arXiv:2604.08364}, year = {2026} } ``` ### Acknowledgements Built on top of [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). All assets and code are released under the repository [LICENSE](https://github.com/Tencent/MegaStyle/blob/main/LICENSE.txt). """ with gr.Blocks(title="MegaStyle-FLUX Demo") as demo: gr.Markdown(header_md) with gr.Row(): with gr.Column(): style_image = gr.Image(label="Reference Style Image", type="pil") prompt = gr.Textbox(label="Content Prompt", value="A bench", placeholder="e.g. A house with a tree beside") with gr.Row(): height = gr.Slider(256, 1536, value=512, step=16, label="Height") width = gr.Slider(256, 1536, value=512, step=16, label="Width") with gr.Row(): num_inference_steps = gr.Slider(10, 50, value=30, step=1, label="Steps") embedded_guidance = gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="Embedded Guidance") with gr.Row(): seed = gr.Number(value=42, label="Seed", precision=0) randomize_seed = gr.Checkbox(value=False, label="Random seed") run_btn = gr.Button("Generate", variant="primary") with gr.Column(): out_image = gr.Image(label="Generated Image", type="pil") used_seed = gr.Number(label="Used Seed", precision=0, interactive=False) examples = build_examples(args.ref_path) if examples: gr.Examples( examples=examples, inputs=[style_image, prompt], label="Reference style examples", examples_per_page=12, ) run_btn.click( fn=generate, inputs=[style_image, prompt, height, width, num_inference_steps, embedded_guidance, seed, randomize_seed], outputs=[out_image, used_seed], ) gr.Markdown(footer_md) demo.queue().launch( server_name=args.server_name, server_port=args.server_port, share=args.share, ) if __name__ == "__main__": main()