Skip to content

Commit

Permalink
[AC]: support single channel foreground in background matting (#3683)
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova authored Jan 30, 2023
1 parent b9719a4 commit 2029eef
Showing 1 changed file with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class MeanOfAbsoluteDifference(BaseBackgroundMattingMetrics):
def update(self, annotation, prediction):
pred = self.get_prediction(prediction)
gt = self.get_annotation(annotation)
if pred.shape[-1] == 1 and pred.shape[-1] != gt.shape[-1]:
gt = cv2.cvtColor(gt, cv2.COLOR_RGB2GRAY)
value = np.mean(abs(pred - gt)) * 1e3
self.results.append(value)
return value
Expand All @@ -105,6 +107,8 @@ class SpatialGradient(BaseBackgroundMattingMetrics):
def update(self, annotation, prediction):
pred = self.get_prediction(prediction)
gt = self.get_annotation(annotation)
if pred.shape[-1] == 1 and pred.shape[-1] != gt.shape[-1]:
gt = cv2.cvtColor(gt, cv2.COLOR_RGB2GRAY)
gt_grad = self.gauss_gradient(gt)
pred_grad = self.gauss_gradient(pred)
value = np.sum((gt_grad - pred_grad) ** 2) / 1000
Expand Down Expand Up @@ -152,6 +156,8 @@ class MeanSquaredErrorWithMask(BaseBackgroundMattingMetrics):
def update(self, annotation, prediction):
pred = self.get_prediction(prediction)
gt = self.get_annotation(annotation)
if pred.shape[-1] == 1 and pred.shape[-1] != gt.shape[-1]:
gt = cv2.cvtColor(gt, cv2.COLOR_RGB2GRAY)
if self.use_mask:
mask = self.prepare_pha(annotation.value) > 0
pred = pred[mask]
Expand Down

0 comments on commit 2029eef

Please sign in to comment.