Skip to content

Commit

Permalink
upload code for next version
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Jun 29, 2023
1 parent 6842d65 commit 1ab085f
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
Binary file modified autodistill_dinov2/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions autodistill_dinov2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dinov2_model import DINOv2
45 changes: 32 additions & 13 deletions autodistill_dinov2/dinov2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
import supervision as sv
import torch
import torchvision.transforms as T
from autodistill.detection import DetectionTargetModel
from autodistill.detection import CaptionOntology
from autodistill.classification import ClassificationBaseModel
from PIL import Image
from sklearn import svm
from tqdm import tqdm

import warnings

warnings.filterwarnings("ignore", category=UserWarning)

HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down Expand Up @@ -48,42 +53,56 @@ def compute_embeddings(files: list, dinov2_vits14) -> dict:


@dataclass
class DINOv2(DetectionTargetModel):
def __init__(self):
class DINOv2(ClassificationBaseModel):
ontology: CaptionOntology

def __init__(self, ontology: CaptionOntology):
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dinov2_vits14.to(device)

self.dinov2_model = dinov2_vits14
self.ontology = ontology

def predict(self, input: str) -> sv.Detections:
def predict(self, input: str) -> sv.Classifications:
embedding = compute_embeddings([input], self.dinov2_model)

return self.model.predict(np.array(embedding[input]).reshape(-1, 384))
class_id = self.model.predict(np.array(embedding[input]).reshape(-1, 384))

def train(self, dataset_location: str):
dataset = sv.ClassificationDataset.from_multiclass_folder_structure(
dataset_location
return sv.Classifications(
class_id=np.array([self.ontology.classes().index(class_id)]),
confidence=np.array([1]),
)

def train(self, dataset_location: str):
dataset = sv.ClassificationDataset.from_folder_structure(dataset_location)

clf = svm.SVC(gamma="scale")

classes = dataset.classes
images = list(dataset.images.keys())[:500]
images = list(dataset.images.keys())
annotations = dataset.annotations

images = [file for file in images if file.endswith(".jpg")]
all_images = []

for image in images:
class_label = classes[annotations[image].class_id[0]]

all_images.append(os.path.join(dataset_location, class_label, image))

embeddings = compute_embeddings(images, self.dinov2_model)
embeddings = compute_embeddings(all_images, self.dinov2_model)

with open("embeddings.json", "w") as f:
json.dump(embeddings, f)

y = [classes[annotations[file].class_id[0]] for file in images]
y = [
classes[annotations[os.path.basename(file)].class_id[0]]
for file in all_images
]

embedding_list = [embeddings[file] for file in images]
embedding_list = [embeddings[file] for file in all_images]

# svm needs at least 2 classes
unqiue_classes = list(set(y))
Expand Down

0 comments on commit 1ab085f

Please sign in to comment.