Add push_to_hub #1048

Draft · wants to merge 1 commit into main
132 changes: 131 additions & 1 deletion flux/dreambooth.py
@@ -5,6 +5,7 @@
import time
from functools import partial
from pathlib import Path
import os

import mlx.core as mx
import mlx.nn as nn
@@ -15,6 +16,9 @@
from PIL import Image
from tqdm import tqdm

from huggingface_hub import HfApi, interpreter_login
from huggingface_hub.utils import HfFolder

from flux import FluxPipeline


@@ -156,6 +160,108 @@ def save_adapters(iteration, flux, args):
        },
    )

def push_to_hub(args):
    if args.hf_token is None:
        interpreter_login(new_session=False, write_permission=True)
    else:
        HfFolder.save_token(args.hf_token)

    api = HfApi()

    # HfFolder has no get_token_username(); resolve the username via whoami()
    username = api.whoami()["name"]
    repo_id = args.hf_repo_id or f"{username}/{args.output_dir}"

    # Write the generated model card into the output folder so it is uploaded too
    readme_content = generate_readme(args, repo_id)
    readme_path = os.path.join(args.output_dir, "README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(readme_content)

    api.create_repo(
        repo_id,
        private=args.hf_private,
        exist_ok=True,
    )

    api.upload_folder(
        repo_id=repo_id,
        folder_path=args.output_dir,
        ignore_patterns=["*.yaml", "*.pt"],
        repo_type="model",
    )

def generate_readme(args, repo_id):
    import yaml
    import re

    base_model = f"flux-{args.model}"
    tags = [
        "text-to-image",
        "flux",
        "lora",
        "diffusers",
        "template:sd-lora",
        "mlx",
        "mlx-trainer",
    ]

    widgets = []
    sample_image_paths = []
    # Look for progress images directly in the output directory
    for filename in os.listdir(args.output_dir):
        match = re.search(r"(\d+)_progress\.png$", filename)
        if match:
            iteration = int(match.group(1))
            sample_image_paths.append((iteration, filename))

    # Use the most recent progress image as the model card preview
    sample_image_paths.sort(key=lambda x: x[0], reverse=True)

    if sample_image_paths:
        widgets.append(
            {
                "text": args.progress_prompt,
                "output": {"url": sample_image_paths[0][1]},
            }
        )

    readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if sample_image_paths else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
license: other
---

# {os.path.basename(args.output_dir)}
Model trained with the MLX Flux DreamBooth script.

<Gallery />

## Use it with [MLX](https://github.com/ml-explore/mlx-examples)
```py
from flux import FluxPipeline
import mlx.core as mx

flux = FluxPipeline("flux-{args.model}")
flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks})
flux.flow.load_weights("{repo_id}")
image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps})
image.save("my_image.png")
```

## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/FLUX.1-{args.model}', torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('{repo_id}')
image = pipeline('{args.progress_prompt}').images[0]
image.save("my_image.png")
```

For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux).
"""
    return readme_content

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
@@ -245,7 +351,28 @@ def save_adapters(iteration, flux, args):
    parser.add_argument(
        "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
    )

    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the trained adapters to the Hugging Face Hub after training",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="Hugging Face token for pushing to the Hub (defaults to interactive login)",
    )
    parser.add_argument(
        "--hf-repo-id",
        type=str,
        default=None,
        help="Hugging Face repository ID to push to (defaults to <username>/<output-dir>)",
    )
    parser.add_argument(
        "--hf-private",
        action="store_true",
        help="Make the Hugging Face repository private",
    )
    parser.add_argument("dataset")

    args = parser.parse_args()
@@ -376,3 +503,6 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
        if (i + 1) % 10 == 0:
            losses = []
            tic = time.time()

    if args.push_to_hub:
        push_to_hub(args)
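
For reference, a minimal invocation sketch exercising the new options (run from the `flux/` directory; `username/my-flux-lora` and `path/to/dataset` are placeholders, and omitting `--hf-repo-id` falls back to `<username>/<output-dir>`):

```sh
python dreambooth.py \
    --push-to-hub \
    --hf-repo-id username/my-flux-lora \
    --hf-private \
    path/to/dataset
```

Without `--hf-token`, the script prompts for an interactive Hugging Face login before uploading, reusing a cached token if one exists.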