-
Notifications
You must be signed in to change notification settings - Fork 53
/
example.py
19 lines (18 loc) · 776 Bytes
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import os
import torch
import ImageReward as RM
if __name__ == "__main__":
    # Demo: rank a handful of generated images against one text prompt, then
    # print each image's individual score.
    prompt = "a painting of an ocean with clouds and birds, day time, low depth field effect"

    # Candidate images are 1.webp .. 4.webp under assets/images (path is
    # relative to the current working directory).
    img_prefix = "assets/images"
    generations = [f"{pic_id}.webp" for pic_id in range(1, 5)]
    img_list = [os.path.join(img_prefix, img) for img in generations]

    model = RM.load("ImageReward-v1.0")

    # Inference only: no gradients are needed for ranking or scoring.
    with torch.no_grad():
        # One call over all candidates; yields a ranking and a rewards value.
        # NOTE(review): presumably both are ordered like img_list - confirm
        # against the ImageReward API.
        ranking, rewards = model.inference_rank(prompt, img_list)

        # Print the result
        print("\nPreference predictions:\n")
        print(f"ranking = {ranking}")
        print(f"rewards = {rewards}")

        # Score each image on its own. File names and full paths are parallel
        # lists, so walk them in lockstep instead of indexing by position.
        for name, img_path in zip(generations, img_list):
            score = model.score(prompt, img_path)
            print(f"{name:>16s}: {score:.2f}")