Skip to content

Commit

Permalink
added code
Browse files Browse the repository at this point in the history
  • Loading branch information
anmolduainter committed Oct 21, 2023
1 parent 712522a commit 0a8d801
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 46 deletions.
28 changes: 16 additions & 12 deletions VirtualTryOn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from io import BytesIO
import numpy as np
from HumanParser import HumanParser
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from .utils import *


Expand All @@ -22,13 +25,13 @@ def __init__(self, model_path,
print ("Model Path : " + str(model_path))
print ("Device : " + str(device))
self.device = device
self.model = StableDiffusionInpaintPipeline.from_pretrained(
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
model_path,
torch_dtype = torch.float32,
safety_checker = None
)
self.model = pipeline
self.model = self.model.to(device)

self.human_parser = HumanParser()
self.run_on = run_on
self.num_inference_steps = num_inference_steps
Expand All @@ -46,15 +49,16 @@ def infer(self, person_img_path = None, meta_prompt = None):
img = img.astype(np.uint8)
masked_img = masked_img.astype(np.uint8)
cloth_mask = cloth_mask.astype(np.uint8)
# print (img.shape)
# print (masked_img.shape)
# print (cloth_mask.shape)

# cv2.imwrite("./img.jpg", img)
# cv2.imwrite("./masked.jpg", masked_img)
# cv2.imwrite("./cloth_mask.jpg", cloth_mask)
cv2.imwrite("./masked_img1.jpg", masked_img)
masked_img = img * (cloth_mask/255)
masked_img = (masked_img/255) + (1 - (cloth_mask/255))
masked_img = masked_img * 255.0
masked_img = masked_img.astype(np.uint8)
cv2.imwrite("./masked_img.jpg", masked_img)
# exit()


img, masked_img, cloth_mask = convert_numpy_to_PIL([img, masked_img, cloth_mask])
print ("Running the model!")
new_img = self.model(
Expand All @@ -67,11 +71,11 @@ def infer(self, person_img_path = None, meta_prompt = None):
).images[0]


new_img = np.array(new_img)
final_image = np.array(new_img)

# Refinement Process
img, cloth_mask = convert_PIL_to_numpy([img, cloth_mask])
final_image = new_img * (cloth_mask/255.0) + img * ((np.array([255, 255, 255]) - cloth_mask)/255.0)
# # Refinement Process
# img, cloth_mask = convert_PIL_to_numpy([img, cloth_mask])
# final_image = new_img * (cloth_mask/255.0) + img * ((np.array([255, 255, 255]) - cloth_mask)/255.0)
return final_image


Expand Down
79 changes: 45 additions & 34 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,55 @@
from VirtualTryOn import VirtualTryOnInference, VirtualTryOnTrain, get_config_default, DataCreation


def Train(model_name, output_dir, instance_dir, instance_prompt):
params = get_config_default()
params.pretrained_model_name_or_path = model_name
params.instance_data_dir = instance_dir
params.output_dir = output_dir
params.instance_prompt = instance_prompt
params.resolution = 512
params.train_batch_size = 1
params.gradient_accumulation_steps = 1
params.learning_rate = 2e-6
params.lr_scheduler = "constant"
params.lr_warmup_steps = 0
params.use_8bit_adam = True
params.max_train_steps = 300
# params.train_text_encoder = True
vt = VirtualTryOnTrain(params)
vt.train()

def Inference(model_path, instance_prompt):
for idx in range(0, 10):
vt = VirtualTryOnInference(
model_path=model_path,
device = "cuda",
run_on="original",
num_inference_steps=50,
guidance_scale=10,
seed = idx
)

prompt = instance_prompt
img_path = "./DebugImages/Images/img1.jpg"
# img_path = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt3/a.png"
res_img = vt.infer(img_path, prompt)
cv2.imwrite("./res_" + str(idx) + ".jpg", res_img)

model_name = "runwayml/stable-diffusion-inpainting"
# instance_dir = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/CheckShirt/"
# instance_dir = "./DataCreation/"
output_dir = "/data/Kaggle/StableDiff/Shirt_Outputs/"
instance_dir = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt3/"
instance_prompt = "Light blue/Green Patterned Trees UBIAA shirt, Lyocell 100%"
instance_prompt = "a photo of UBIAA shirt"


# Train(model_name, output_dir, instance_dir, instance_prompt)
Inference(
model_path=output_dir,
instance_prompt=instance_prompt
)


# d = DataCreation(
# instance_dir= "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt3/",
Expand All @@ -20,36 +62,5 @@

# exit()

params = get_config_default()
params.pretrained_model_name_or_path = model_name
params.instance_data_dir = instance_dir
params.output_dir = output_dir
params.instance_prompt = instance_prompt
params.resolution = 512
params.train_batch_size = 1
params.gradient_accumulation_steps = 1
params.learning_rate = 5e-6
params.lr_scheduler = "constant"
params.lr_warmup_steps = 0
params.use_8bit_adam = True
params.max_train_steps = 500
params.train_text_encoder = True
# vt = VirtualTryOnTrain(params)
# vt.train()

instance_prompt = "Light blue/Green Patterned Trees UBIAA shirt"
for idx in range(0, 10):
# idx = -1
vt = VirtualTryOnInference(
model_path=output_dir,
device = "cuda",
run_on="moriginal",
num_inference_steps=50,
guidance_scale=7,
seed = idx
)

prompt = instance_prompt
img_path = "./DebugImages/Images/img4.png"
res_img = vt.infer(img_path, prompt)
cv2.imwrite("./res_" + str(idx) + ".jpg", res_img)

# instance_prompt = "UBIAA shirt"

0 comments on commit 0a8d801

Please sign in to comment.