Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bugs 、use lastest model、add ps1 to run #1

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# from wd14 tagger
IMAGE_SIZE = 448

WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger'
WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnextv2-tagger-v2'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
Expand Down Expand Up @@ -72,7 +72,7 @@ def run_batch(path_imgs):
# Everything else is tags: pick any where prediction confidence > threshold
tag_text = ""
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
if p >= args.thresh:
if p >= args.thresh and i < len(tags):
tag_text += ", " + tags[i]

if len(tag_text) > 0:
Expand Down
15 changes: 15 additions & 0 deletions tagger.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# tagger script by @bdsqlsz

# Train data path
$input_img = "./input" # input images path
$batch_size = 4
$thresh = 0.35

python tag_images_by_wd14_tagger.py `
$input_img `
--batch_size=$batch_size `
--thresh=$thresh `
--caption_extension .txt

Write-Output "Tagger finished"
Read-Host | Out-Null ;