Commit ad27f7d: added code

anmolduainter committed Oct 22, 2023
1 parent 0a8d801 commit ad27f7d
Showing 88 changed files with 279 additions and 64 deletions.
Binary file added DebugImages/Images/img4.jpg
Binary file added DebugImages/Images/img4.png
File renamed without changes
File renamed without changes
File renamed without changes
Binary file added DebugImages/TargetClothes/Shirts/test2/a.png
26 changes: 23 additions & 3 deletions VirtualTryOn/data.py
@@ -120,6 +120,7 @@ def __getitem__(self, index):
from HumanParser import HumanParser
from .utils import *
import albumentations as A
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

def MakeDir(path):
    if (not os.path.exists(path)):
@@ -143,8 +144,12 @@ def __init__(self, instance_dir, save_dir, target_number = 20) -> None:
print ("Number of Images to be generated : " + str(target_number))
self.transform = A.Compose([
A.HorizontalFlip(p=0.5),
A.Affine(scale=(0.2, 1.0), rotate = (-45, 45), translate_percent = 0.05, keep_ratio = True)
])
A.Affine(scale=(0.1, 0.5), translate_percent = 0.05, keep_ratio = True)
])

self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

print ("-------------------------------------- \n")

    def reset(self):
@@ -160,13 +165,28 @@ def create_augmentation(self, img, save_idx):
            self.c += 1
            c += 1

    def clipseg_masks(self, img):
        h, w = img.shape[:2]
        prompts = ["clothes"]
        inputs = self.processor(text=prompts, images=[img] * len(prompts), padding="max_length", return_tensors="pt")
        # predict with the instance's CLIPSeg model (the original called the module-level `model` from utils)
        with torch.no_grad():
            outputs = self.model(**inputs)
        preds = torch.sigmoid(outputs.logits) > 0.5
        preds = np.squeeze(preds.numpy()).astype(np.uint8) * 255  # squeeze in case the logits carry a batch dim
        preds = cv2.resize(preds, (w, h))  # cv2.resize takes (width, height), not shape[:2]
        preds = cv2.cvtColor(preds, cv2.COLOR_GRAY2RGB)
        # preds = Image.fromarray(preds)
        return preds


    def create(self):
        print ("Creating Dataset!!")
        self.reset()
        for idx, imf in enumerate(self.all_imf):
            img = read_img_rgb(imf, resize = (512, 512))
-            masked_img, cloth_mask = self.human_parser.infer(img)
+            # masked_img, cloth_mask = self.human_parser.infer(img)
+            cloth_mask = self.clipseg_masks(img)
            cloth_img = img * (cloth_mask/255.0)
            # cloth_img = cloth_img + ([255,255,255] - cloth_mask)
            non_zero_points = np.argwhere(cloth_mask)
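
The dataset builder above now derives its cloth mask from a CLIPSeg text prompt rather than the human parser. For reference, a minimal standalone sketch of that step (assuming transformers, torch, opencv-python and numpy are installed; the checkpoint name matches the one loaded in __init__, and the squeeze-before-resize handling is an assumption about the logits shape):

import cv2
import numpy as np
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def clothes_mask(img_rgb):
    # img_rgb: HxWx3 uint8 array; returns an HxWx3 mask of 0/255 values
    h, w = img_rgb.shape[:2]
    inputs = processor(text=["clothes"], images=[img_rgb], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    mask = (torch.sigmoid(logits) > 0.5).numpy()
    mask = np.squeeze(mask).astype(np.uint8) * 255   # drop any batch dim
    mask = cv2.resize(mask, (w, h))                  # cv2.resize expects (width, height)
    return cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)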
104 changes: 83 additions & 21 deletions VirtualTryOn/inference.py
@@ -38,40 +38,102 @@ def __init__(self, model_path,
        self.guidance_scale = guidance_scale
        self.seed = seed
        print ("---------------------------------------")

    def reset(self):
        self.oh = None
        self.ow = None
        self.rxmin = None
        self.rymin = None
        self.rh = None
        self.rw = None

    def preprocess(self, img, masked_img, cloth_mask):
        h,w = img.shape[:2]
        offset = 10
        non_zero_points = np.argwhere(cloth_mask)
        min_x = np.min(non_zero_points[:, 1]) - offset
        max_x = np.max(non_zero_points[:, 1]) + offset
        min_y = np.min(non_zero_points[:, 0]) - offset
        max_y = np.max(non_zero_points[:, 0]) + offset
        min_x = 0 if min_x < 0 else min_x
        min_y = 0 if min_y < 0 else min_y
        max_x = w-1 if max_x > w-1 else max_x
        max_y = h-1 if max_y > h-1 else max_y
        self.rxmin = min_x
        self.rymin = min_y
        img = img[min_y: max_y, min_x: max_x]
        masked_img = masked_img[min_y: max_y, min_x: max_x]
        cloth_mask = cloth_mask[min_y: max_y, min_x: max_x]
        new_height = max_y - min_y
        new_width = max_x - min_x
        self.rh = new_height
        self.rw = new_width

        masked_img[masked_img == 0] = 255.0
        masked_img = masked_img.astype(np.uint8)

        img = cv2.resize(img, (512, 512))
        masked_img = cv2.resize(masked_img, (512,512))
        cloth_mask = cv2.resize(cloth_mask, (512, 512))

        return img, masked_img, cloth_mask

    def postprocess(self, pred_image, img, pimg, pcloth_mask):
        print (pred_image.shape)
        print (pimg.shape)
        print (pcloth_mask.shape)
        pimg[pcloth_mask[:,:,0]==255] = pred_image[pcloth_mask[:,:,0]==255]
        pimg = cv2.resize(pimg, (self.rw, self.rh))
        img[self.rymin:self.rymin+self.rh , self.rxmin:self.rxmin+self.rw] = pimg
        cv2.imwrite("./pimg.jpg", img)
        img = cv2.resize(img, (self.ow, self.oh), interpolation = cv2.INTER_CUBIC)
        return img

    def _inference(self, pimg, pmasked_img, pcloth_mask, meta_prompt):
        pimg, pmasked_img, pcloth_mask = convert_numpy_to_PIL([pimg, pmasked_img, pcloth_mask])
        print ("Running the model!")
        new_img = self.model(
            prompt = meta_prompt,
            image = pimg if self.run_on == "original" else pmasked_img,
            mask_image = pcloth_mask,
            num_inference_steps = self.num_inference_steps,
            guidance_scale = self.guidance_scale,
            generator=torch.Generator(device=self.device).manual_seed(self.seed)
        ).images[0]

        final_image = np.array(new_img)
        return final_image

    def infer(self, person_img_path = None, meta_prompt = None):
        print ("Doing Inference ::: ")
        print ("Person Image Path : " + str(person_img_path))
        print ("Meta Prompt : " + str(meta_prompt))

        if (person_img_path is None or meta_prompt is None):
            print ("Please provide all inputs for generating Inference!")
            return None

-        img = read_img_rgb(person_img_path, resize = (512, 512))
+        img = read_img_rgb(person_img_path)
+        self.oh, self.ow = img.shape[:2]
+        img = cv2.resize(img, (512, 512))
        masked_img, cloth_mask = self.human_parser.infer(img)
        img = img.astype(np.uint8)
        masked_img = masked_img.astype(np.uint8)
        cloth_mask = cloth_mask.astype(np.uint8)

        cv2.imwrite("./masked_img1.jpg", masked_img)
        masked_img = img * (cloth_mask/255)
        masked_img = (masked_img/255) + (1 - (cloth_mask/255))
        masked_img = masked_img * 255.0
        masked_img = masked_img.astype(np.uint8)
        cv2.imwrite("./masked_img.jpg", masked_img)
        # exit()
        # cv2.imwrite("./Img.jpg", img)
        # cv2.imwrite("./Img1.jpg", masked_img)
        # cv2.imwrite("./Img2.jpg", cloth_mask)

-        img, masked_img, cloth_mask = convert_numpy_to_PIL([img, masked_img, cloth_mask])
-        print ("Running the model!")
-        new_img = self.model(
-            prompt = meta_prompt,
-            image = img if self.run_on == "original" else masked_img,
-            mask_image = cloth_mask,
-            num_inference_steps = self.num_inference_steps,
-            guidance_scale = self.guidance_scale,
-            generator=torch.Generator(device=self.device).manual_seed(self.seed)
-        ).images[0]
+        pimg, pmasked_img, pcloth_mask = self.preprocess(img, masked_img, cloth_mask)
        cv2.imwrite("./AImg.jpg", pimg)
        cv2.imwrite("./AImg1.jpg", pmasked_img)
        cv2.imwrite("./AImg2.jpg", pcloth_mask)

-        final_image = np.array(new_img)
+        final_image = self._inference(pimg, pmasked_img, pcloth_mask, meta_prompt)
+        final_image = self.postprocess(final_image, img, pimg, pcloth_mask)

        # # Refinement Process
        # img, cloth_mask = convert_PIL_to_numpy([img, cloth_mask])
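
The preprocess/postprocess pair added above implements a crop-and-paste strategy: the person image is cropped to the cloth-mask bounding box (with a 10 px margin), inpainted at 512x512, and only the masked pixels are pasted back at the original location and resolution. A condensed sketch of that round trip (inpaint_fn is a hypothetical stand-in for the diffusers inpainting call wrapped by _inference):

import cv2
import numpy as np

def tryon_region(img, cloth_mask, inpaint_fn, margin=10):
    # img, cloth_mask: HxWx3 uint8 arrays; cloth_mask holds 0/255 values
    h, w = img.shape[:2]
    ys, xs = np.nonzero(cloth_mask[:, :, 0])
    x0, x1 = max(xs.min() - margin, 0), min(xs.max() + margin, w - 1)
    y0, y1 = max(ys.min() - margin, 0), min(ys.max() + margin, h - 1)
    crop = cv2.resize(img[y0:y1, x0:x1], (512, 512))
    mask = cv2.resize(cloth_mask[y0:y1, x0:x1], (512, 512))
    pred = inpaint_fn(crop, mask)                              # 512x512 RGB result
    crop[mask[:, :, 0] == 255] = pred[mask[:, :, 0] == 255]    # keep only the masked pixels
    img[y0:y1, x0:x1] = cv2.resize(crop, (x1 - x0, y1 - y0))   # paste back at original size
    return img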
4 changes: 2 additions & 2 deletions VirtualTryOn/train.py
@@ -196,8 +196,8 @@ def collate_fn(examples):
        pil_image = example["PIL_images"]
        pil_image.save("./check.jpg")

-        mask = clipseg_masks(pil_image) # generate a random mask
-        # mask = random_mask(pil_image.size, 1, False) # generate a random mask
+        # mask = clipseg_masks(pil_image) # generate a random mask
+        mask = random_mask(pil_image, pil_image.size, 1, False) # generate a random mask
        mask.save("./check1.jpg")

        mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) # prepare mask and masked image
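
prepare_mask_and_masked_image itself is not part of this diff; in the upstream diffusers dreambooth_inpaint example it binarizes the mask and blanks out the region to be inpainted, roughly along these lines (an approximate sketch, not the exact upstream code):

import numpy as np
import torch

def prepare_mask_and_masked_image(image, mask):
    # image, mask: PIL images -> image tensor in [-1, 1], binary mask in {0, 1}
    image = np.array(image.convert("RGB"))[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    mask = np.array(mask.convert("L")).astype(np.float32) / 255.0
    mask = torch.from_numpy(mask[None, None])
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    masked_image = image * (mask < 0.5)  # zero out the masked (to-be-inpainted) region
    return mask, masked_image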
15 changes: 13 additions & 2 deletions VirtualTryOn/utils.py
@@ -8,6 +8,7 @@
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def clipseg_masks(img):
    size = img.size
    prompts = ["clothes"]
@@ -22,8 +23,18 @@ def clipseg_masks(img):
    return preds

# generate random masks
-def random_mask(im_shape, ratio=1, mask_full_image=False):
-    mask = Image.new("L", im_shape, 0)
+def random_mask(img, im_shape, ratio=1, mask_full_image=False):
+    size = img.size
+    prompts = ["clothes"]
+    inputs = processor(text=prompts, images=[img] * len(prompts), padding="max_length", return_tensors="pt")
+    # predict
+    with torch.no_grad():
+        outputs = model(**inputs)
+    preds = torch.sigmoid(outputs.logits) > 0.5
+    preds = np.squeeze(preds.numpy()).astype(np.uint8) * 255  # squeeze in case the logits carry a batch dim
+    preds = cv2.resize(preds, size)  # PIL's img.size is already (width, height)
+    mask = Image.fromarray(preds)
+    # mask = Image.new("L", im_shape, 0)
    draw = ImageDraw.Draw(mask)
    size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
    # use this to always mask the whole image
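
One subtlety in the repurposed random_mask: it passes img.size straight to cv2.resize, which is correct because PIL reports (width, height), exactly the dsize order cv2 expects, whereas a numpy array's shape[:2] is (height, width) and must be reversed first (the bug fixed in DataCreation.clipseg_masks above). A quick reference:

import numpy as np
from PIL import Image

pil_img = Image.new("RGB", (640, 480))   # PIL size is (width, height)
arr = np.asarray(pil_img)
assert pil_img.size == (640, 480)        # safe to pass to cv2.resize as-is
assert arr.shape[:2] == (480, 640)       # (height, width): reverse before cv2.resize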
51 changes: 51 additions & 0 deletions infer.sh
@@ -0,0 +1,51 @@
# Inference

# python main.py \
# --infer \
# --img_path "./DebugImages/Images/img4.jpg"\
# --model_dir "./Models/Shirts/test1/"\
# --prompt "UBIAA shirt"\
# --infer_output "./ResultsPredictions/" \
# --infer_output_prefix "shirt_test1"

# python main.py \
# --infer \
# --img_path "./DebugImages/Images/img4.jpg"\
# --model_dir "./Models/Shirts/test2/"\
# --prompt "UBIAA shirt"\
# --infer_output "./ResultsPredictions/" \
# --infer_output_prefix "shirt_test2"

# python main.py \
# --infer \
# --img_path "./DebugImages/Images/img4.jpg"\
# --model_dir "./Models/Shirts/test4/"\
# --prompt "UBIAA shirt"\
# --infer_output "./ResultsPredictions/" \
# --infer_output_prefix "shirt_test4"


# python main.py \
# --infer \
# --img_path "./DebugImages/Images/img4.jpg"\
# --model_dir "./Models/Shirts/test5/"\
# --prompt "UBIAA shirt"\
# --infer_output "./ResultsPredictions/" \
# --infer_output_prefix "shirt_test5"

# python main.py \
# --infer \
# --img_path "./DebugImages/Images/img1.jpg"\
# --model_dir "./Models/Jackets/test1/"\
# --prompt "JACKUUO jacket"\
# --infer_output "./ResultsPredictions/" \
# --infer_output_prefix "jacket_test1"

python main.py \
--infer \
--img_path "./DebugImages/Images/img2.jpg"\
--model_dir "./Models/Jackets/test3/"\
--prompt "JACKUUO jacket"\
--infer_output "./ResultsPredictions/" \
--infer_output_prefix "jacket_test3"
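
# The matching training invocation is not in this script; based on the flags
# defined in main.py below, it would look something like this (paths are placeholders):
# python main.py \
#     --train \
#     --instance_dir "./DataCreation/" \
#     --model_dir "/data/Kaggle/StableDiff/Shirt_Outputs/" \
#     --prompt "a photo of UBIAA shirt"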

100 changes: 64 additions & 36 deletions main.py
@@ -1,8 +1,14 @@
import cv2
from VirtualTryOn import VirtualTryOnInference, VirtualTryOnTrain, get_config_default, DataCreation
import argparse
import os

def MakeDir(path):
    if (not os.path.exists(path)):
        os.mkdir(path)

-def Train(model_name, output_dir, instance_dir, instance_prompt):
+def Train(output_dir, instance_dir, instance_prompt):
+    model_name = "runwayml/stable-diffusion-inpainting"
    params = get_config_default()
    params.pretrained_model_name_or_path = model_name
    params.instance_data_dir = instance_dir
@@ -15,52 +21,74 @@ def Train(model_name, output_dir, instance_dir, instance_prompt):
    params.lr_scheduler = "constant"
    params.lr_warmup_steps = 0
    params.use_8bit_adam = True
-    params.max_train_steps = 300
+    params.max_train_steps = 900
    # params.train_text_encoder = True
    vt = VirtualTryOnTrain(params)
    vt.train()
-def Inference(model_path, instance_prompt):
-    for idx in range(0, 10):
-        vt = VirtualTryOnInference(
-            model_path=model_path,
-            device = "cuda",
-            run_on="original",
-            num_inference_steps=50,
-            guidance_scale=10,
-            seed = idx
-        )
+def Inference(img_path, model_path, instance_prompt, output_path, prefix):
+    # for idx in range(0, 10):
+    idx = 0
+    vt = VirtualTryOnInference(
+        model_path=model_path,
+        device = "cuda",
+        run_on="moriginal",  # any value other than "original" makes infer() feed the masked image to the model
+        num_inference_steps=50,
+        guidance_scale=20,
+        seed = idx
+    )

+    MakeDir(output_path)
+    img_name = img_path.split("/")[-1].split(".")[-2]
+    prompt = instance_prompt
+    # img_path = "./DebugImages/Images/img2.jpg"
+    # # img_path = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt3/a.png"
+    res_img = vt.infer(img_path, prompt)
+    cv2.imwrite(output_path + "/" + prefix + img_name + ".jpg", res_img)

-        prompt = instance_prompt
-        img_path = "./DebugImages/Images/img1.jpg"
-        # img_path = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt3/a.png"
-        res_img = vt.infer(img_path, prompt)
-        cv2.imwrite("./res_" + str(idx) + ".jpg", res_img)

-model_name = "runwayml/stable-diffusion-inpainting"
-# instance_dir = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/CheckShirt/"
-# instance_dir = "./DataCreation/"
-output_dir = "/data/Kaggle/StableDiff/Shirt_Outputs/"
-instance_dir = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt3/"
-instance_prompt = "a photo of UBIAA shirt"
if (__name__ == "__main__"):

    parser = argparse.ArgumentParser()
    parser.add_argument("--train", help="do_training", action="store_true")
    parser.add_argument("--infer", help="do_inference", action="store_true")
    parser.add_argument("--instance_dir", help="instance dir", default=None)
    parser.add_argument("--model_dir", help="model_dir", default=None)
    parser.add_argument("--prompt", help="instance_prompt", default=None)
    parser.add_argument("--img_path", help="image path", default=None)
    parser.add_argument("--infer_output", help="infer output dir", default=None)
    parser.add_argument("--infer_output_prefix", help="infer output filename prefix", default="res")

-# Train(model_name, output_dir, instance_dir, instance_prompt)
-Inference(
-    model_path=output_dir,
-    instance_prompt=instance_prompt
-)
    args = parser.parse_args()
    if args.train:
        Train(output_dir = args.model_dir,
              instance_dir = args.instance_dir,
              instance_prompt = args.prompt
              )

    if (args.infer):

        # d = DataCreation(
        #     instance_dir= "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt3/",
        #     save_dir="./DataCreation/",
        #     target_number=100
        # )
        Inference(
            img_path = args.img_path,
            model_path=args.model_dir,
            instance_prompt=args.prompt,
            output_path=args.infer_output,
            prefix = args.infer_output_prefix
        )

    # d.create()
    # # instance_dir = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Jackets/"
    # instance_dir = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/CheckShirt/"
    # # instance_dir = "./DataCreation/"
    # # instance_dir = "/home/user/anmol/StableDiff/d1/diffusers/examples/research_projects/dreambooth_inpaint/Shirt2/"

    # exit()
    # output_dir = "/data/Kaggle/StableDiff/Shirt_Outputs1/"
    # instance_prompt = "UBIAA shirt"
    # # instance_prompt = "UBIAA jacket, high resolution"

    # # d = DataCreation(
    # #     instance_dir= instance_dir,
    # #     save_dir="./DataCreation/",
    # #     target_number=100
    # # )

    # instance_prompt = "UBIAA shirt"
    # # d.create()
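
Equivalently, the inference path can be driven without the CLI; a sketch mirroring what the --infer branch above does, with arguments taken from the jacket_test3 example in infer.sh:

import cv2
from VirtualTryOn import VirtualTryOnInference

vt = VirtualTryOnInference(
    model_path="./Models/Jackets/test3/",
    device="cuda",
    run_on="original",        # "original" inpaints over the raw photo; any other value uses the masked image
    num_inference_steps=50,
    guidance_scale=20,
    seed=0,
)
res_img = vt.infer("./DebugImages/Images/img2.jpg", "JACKUUO jacket")
cv2.imwrite("./ResultsPredictions/jacket_test3img2.jpg", res_img)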
