Skip to content

Commit

Permalink
fix compare_util
Browse files Browse the repository at this point in the history
  • Loading branch information
hejunchao committed Oct 23, 2023
1 parent 87fd930 commit 6f3ce35
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions tests/compare_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ def cosine(gt: np.ndarray, pred: np.ndarray, *args):

result = (gt @ pred) / (np.linalg.norm(gt, 2) * np.linalg.norm(pred, 2))

if not (isinstance(gt, bool) or isinstance(pred, bool)):
return -1 if math.isnan(result) or mse(gt, pred) >= 0.01 else result
else:
return -1 if math.isnan(result) else result
return -1 if math.isnan(result) else result


def compare_arrays(gt: np.ndarray, pred: np.ndarray):
Expand All @@ -65,8 +62,16 @@ def euclidean(gt: np.ndarray, pred: np.ndarray, *args):
return np.linalg.norm(gt.reshape(-1) - pred.reshape(-1))


def mse(gt: np.ndarray, pred: np.ndarray, *args):
return np.mean((gt - pred) ** 2)
# def mse(gt: np.ndarray, pred: np.ndarray, *args):
# return np.mean((gt - pred) ** 2)

def divide(gt: np.ndarray, pred: np.ndarray):
result = np.divide(gt, pred)
return result


def mean(gt: np.ndarray):
return np.mean(gt)


def allclose(gt: np.ndarray, pred: np.ndarray, thresh: float):
Expand Down Expand Up @@ -128,6 +133,8 @@ def compare_binfile(result_path: Tuple[str, str],
compare_op = gt
if compare_op(similarity, threshold):
return False, similarity_info
if (mean(divide(gt_arr, pred_arr)) > 1.5 or mean(divide(gt_arr, pred_arr)) < 0.6):
return False, similarity_info , f"\nmaybe a case of multiples"
return True, similarity_info


Expand Down

0 comments on commit 6f3ce35

Please sign in to comment.