# Source code for ct.metric

"""
Functions for computing image metrics.
"""

import numpy as np
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity
from pathlib import Path
from typing import Tuple, Optional, Union, Literal
from jaxtyping import Float

from . import image
from . import io
from . import sanity
from .util import _safe_torch as torch


def image_psnr(
    im_pd: Float[np.ndarray, "h w 3"],
    im_gt: Float[np.ndarray, "h w 3"],
    im_mask: Optional[Float[np.ndarray, "h w"]] = None,
) -> float:
    """
    Computes PSNR given images in numpy arrays.

    Only pixels whose mask value is > 0.5 contribute to the score.

    Args:
        im_pd: numpy array, (h, w, 3), float32, range [0, 1], enforced.
        im_gt: numpy array, (h, w, 3), float32, range [0, 1], enforced.
        im_mask: numpy array, (h, w), float32, range [0, 1], enforced.
            Value > 0.5 means foreground. None means all foreground.

    Returns:
        PSNR value in float.

    Raises:
        ValueError: If the dtype, shape or range checks in ``_check_inputs``
            fail, or if the mask does not select any foreground pixel.
    """
    if im_mask is None:
        h, w = im_pd.shape[:2]
        im_mask = np.ones((h, w), dtype=np.float32)
    _check_inputs(im_pd, im_gt, im_mask)

    # Boolean (h, w) selector; indexing an (h, w, 3) image with it yields the
    # (n, 3) array of foreground pixels.
    foreground = im_mask > 0.5
    if not foreground.any():
        # Without this guard the metric would be evaluated on empty arrays,
        # which fails deep inside NumPy with an unrelated-looking message.
        raise ValueError("im_mask must contain at least one foreground pixel")

    pr = im_pd[foreground].ravel()
    gt = im_gt[foreground].ravel()

    # Inputs are validated to lie in [0, 1], so the dynamic range is passed
    # explicitly (as image_ssim does) instead of being guessed from the dtype.
    ans = peak_signal_noise_ratio(gt, pr, data_range=1.0)
    return float(ans)
def image_ssim(
    im_pd: Float[np.ndarray, "h w 3"],
    im_gt: Float[np.ndarray, "h w 3"],
    im_mask: Optional[Float[np.ndarray, "h w"]] = None,
) -> float:
    """
    Computes SSIM given images in numpy arrays.

    Both images are multiplied by the mask before the SSIM map is computed,
    and the returned score is the mean of that map over the foreground pixels
    (mask value > 0.5) across all three channels.

    Args:
        im_pd: numpy array, (h, w, 3), float32, range [0, 1], enforced.
        im_gt: numpy array, (h, w, 3), float32, range [0, 1], enforced.
        im_mask: numpy array, (h, w), float32, range [0, 1], enforced.
            Value > 0.5 means foreground. None means all foreground.

    Returns:
        SSIM value in float.
    """
    if im_mask is None:
        h, w = im_pd.shape[:2]
        im_mask = np.ones((h, w), dtype=np.float32)
    _check_inputs(im_pd, im_gt, im_mask)

    # (h, w) -> (h, w, 1) so the mask broadcasts over the channel axis.
    mask_3d = im_mask[..., np.newaxis]
    masked_pd = im_pd * mask_3d
    masked_gt = im_gt * mask_3d
    assert masked_pd.dtype == np.float32 and masked_gt.dtype == np.float32

    # full=True additionally returns the per-pixel SSIM map; the global mean
    # is discarded because it would include background pixels.
    _, ssim_map = structural_similarity(
        masked_pd,
        masked_gt,
        channel_axis=-1,
        data_range=1.0,
        full=True,
    )

    foreground = im_mask > 0.5
    return float(ssim_map[foreground].mean())
def image_lpips(
    im_pd: Float[np.ndarray, "h w 3"],
    im_gt: Float[np.ndarray, "h w 3"],
    im_mask: Optional[Float[np.ndarray, "h w"]] = None,
) -> float:
    """
    Computes LPIPS given images in numpy arrays.

    The images are rescaled from [0, 1] to [-1, 1] and multiplied by the
    mask, so masked-out pixels become 0 in the rescaled space for both
    images. The AlexNet-based LPIPS model is created on first use and cached
    in ``image_lpips.static_vars["loss_fn"]`` for subsequent calls.

    Args:
        im_pd: numpy array, (h, w, 3), float32, range [0, 1], enforced.
        im_gt: numpy array, (h, w, 3), float32, range [0, 1], enforced.
        im_mask: numpy array, (h, w), float32, range [0, 1], enforced.
            Value > 0.5 means foreground. None means all foreground.

    Returns:
        LPIPS value in float.
    """
    # Imported lazily so the module can be used without lpips installed.
    import lpips

    if im_mask is None:
        h, w = im_pd.shape[:2]
        im_mask = np.ones((h, w), dtype=np.float32)
    _check_inputs(im_pd, im_gt, im_mask)

    im_mask = im_mask[:, :, None]  # (h, w) -> (h, w, 1)
    pr = im_mask * (im_pd * 2 - 1)
    gt = im_mask * (im_gt * 2 - 1)

    # (h, w, 3) -> (1, 3, h, w): add a batch axis and move channels first.
    pr = pr.transpose(2, 0, 1)[None, ...]
    gt = gt.transpose(2, 0, 1)[None, ...]

    # Reuse the network across calls; constructing it is expensive.
    cache = image_lpips.static_vars
    if "loss_fn" in cache:
        loss_fn = cache["loss_fn"]
    else:
        loss_fn = lpips.LPIPS(net="alex")
        cache["loss_fn"] = loss_fn

    ans = loss_fn.forward(torch.tensor(pr), torch.tensor(gt)).cpu().detach().numpy()

    # The result is a one-element array with ndim > 0. Calling float() on
    # such an array is deprecated since NumPy 1.25, so extract the scalar
    # explicitly.
    return float(ans.item())
# Function-attribute cache used by image_lpips: the lpips.LPIPS model is stored
# under the "loss_fn" key on first call so it is only constructed once.
image_lpips.static_vars = {}
def image_psnr_with_paths(
    im_pd_path: Union[str, Path],
    im_gt_path: Union[str, Path],
    im_mask_path: Optional[Union[str, Path]] = None,
) -> float:
    """
    Computes PSNR given image paths.

    The images are loaded with ``load_im_pd_im_gt_im_mask_for_eval`` using
    ``alpha_mode="white"`` and then scored with ``image_psnr``.

    Args:
        im_pd_path: Path to the rendered RGB image.
        im_gt_path: Path to the ground truth RGB image.
        im_mask_path: Path to the mask image. None means no mask file is
            read; the loader then decides the mask.

    Note:
        The loader resizes the prediction, the ground truth and the mask to
        a common (h, w): the smaller width and the smaller height of the
        prediction and the ground truth.

    Returns:
        PSNR value in float.
    """
    loaded = load_im_pd_im_gt_im_mask_for_eval(
        im_pd_path=im_pd_path,
        im_gt_path=im_gt_path,
        im_mask_path=im_mask_path,
        alpha_mode="white",
    )
    # loaded is the (im_pd, im_gt, im_mask) triple.
    return image_psnr(*loaded)
def image_ssim_with_paths(
    im_pd_path: Union[str, Path],
    im_gt_path: Union[str, Path],
    im_mask_path: Optional[Union[str, Path]] = None,
) -> float:
    """
    Computes SSIM given image paths.

    The images are loaded with ``load_im_pd_im_gt_im_mask_for_eval`` using
    ``alpha_mode="white"`` and then scored with ``image_ssim``.

    Args:
        im_pd_path: Path to the rendered RGB image.
        im_gt_path: Path to the ground truth RGB image.
        im_mask_path: Path to the mask image. None means no mask file is
            read; the loader then decides the mask.

    Note:
        The loader resizes the prediction, the ground truth and the mask to
        a common (h, w): the smaller width and the smaller height of the
        prediction and the ground truth.

    Returns:
        SSIM value in float.
    """
    loaded = load_im_pd_im_gt_im_mask_for_eval(
        im_pd_path=im_pd_path,
        im_gt_path=im_gt_path,
        im_mask_path=im_mask_path,
        alpha_mode="white",
    )
    # loaded is the (im_pd, im_gt, im_mask) triple.
    return image_ssim(*loaded)
def image_lpips_with_paths(
    im_pd_path: Union[str, Path],
    im_gt_path: Union[str, Path],
    im_mask_path: Optional[Union[str, Path]] = None,
) -> float:
    """
    Computes LPIPS given image paths.

    The images are loaded with ``load_im_pd_im_gt_im_mask_for_eval`` using
    ``alpha_mode="white"`` and then scored with ``image_lpips``.

    Args:
        im_pd_path: Path to the rendered RGB image.
        im_gt_path: Path to the ground truth RGB image.
        im_mask_path: Path to the mask image. None means no mask file is
            read; the loader then decides the mask.

    Note:
        The loader resizes the prediction, the ground truth and the mask to
        a common (h, w): the smaller width and the smaller height of the
        prediction and the ground truth.

    Returns:
        LPIPS value in float.
    """
    loaded = load_im_pd_im_gt_im_mask_for_eval(
        im_pd_path=im_pd_path,
        im_gt_path=im_gt_path,
        im_mask_path=im_mask_path,
        alpha_mode="white",
    )
    # loaded is the (im_pd, im_gt, im_mask) triple.
    return image_lpips(*loaded)
def load_im_pd_im_gt_im_mask_for_eval(
    im_pd_path: Union[str, Path],
    im_gt_path: Union[str, Path],
    im_mask_path: Optional[Union[str, Path]] = None,
    alpha_mode: Literal["white", "keep"] = "white",
) -> Tuple[
    Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w"]
]:
    """
    Load prediction, ground truth, and mask images for image metric evaluation.

    The prediction and the ground truth are both resized to a common size
    whose width is the smaller of the two widths and whose height is the
    smaller of the two heights. The mask is resized to that same size.

    Args:
        im_pd_path: Path to the rendered image.
        im_gt_path: Path to the ground truth RGB or RGBA image.
        im_mask_path: Path to the mask image. When given, it takes precedence
            over any alpha channel found in the ground truth.
        alpha_mode: How the alpha channel of the ground truth is handled.
            The value is forwarded to ``io.imread`` for the ground truth.
            Currently only "white" is supported; anything else raises.

            The alpha_mode parameter can be:
                - "white": If im_gt contains an alpha channel, im_gt is
                  expected to be converted to RGB on a white background and
                  the alpha channel is then ignored (behavior provided by
                  ``io.imread`` -- confirm there).
                - "keep": If im_gt contains an alpha channel, the alpha
                  channel is used as the mask unless im_mask_path is given.
                  (This option is not implemented yet.)

    Returns:
        [Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w"]]:
            - im_pd: prediction, expected (h, w, 3) float32 in [0, 1] as
              produced by ``io.imread`` / ``image.resize``.
            - im_gt: ground truth, same expectations as im_pd, with any
              fourth (alpha) channel removed.
            - im_mask: (h, w), float32, value only 0 or 1. Even if
              im_mask_path is None, im_mask is returned: from the ground
              truth alpha if one was separated, otherwise all 1s.

    Raises:
        NotImplementedError: If alpha_mode is not "white".
    """
    if alpha_mode != "white":
        raise NotImplementedError('Currently only alpha_mode="white" is supported.')

    # Ground truth: split off a fourth channel as alpha when present.
    im_gt = io.imread(im_gt_path, alpha_mode=alpha_mode)
    im_gt_alpha = None
    if im_gt.shape[2] == 4:
        im_gt_alpha, im_gt = im_gt[:, :, 3], im_gt[:, :, :3]

    # Prediction.
    im_pd = io.imread(im_pd_path)

    # Common target size, stored as (width, height).
    target_wh = (
        min(im_gt.shape[1], im_pd.shape[1]),
        min(im_gt.shape[0], im_pd.shape[0]),
    )
    im_gt = image.resize(im_gt, shape_wh=target_wh)
    if im_gt_alpha is not None:
        im_gt_alpha = image.resize(im_gt_alpha, shape_wh=target_wh)
    im_pd = image.resize(im_pd, shape_wh=target_wh)

    # Mask: explicit file > ground truth alpha > all foreground.
    # In every case the result is binarized to float32 0/1.
    if im_mask_path is not None:
        im_mask = io.imread(im_mask_path, alpha_mode="ignore")
        im_mask = image.resize(im_mask, shape_wh=target_wh)
        if im_mask.ndim == 3:
            # Multi-channel mask image: only the first channel is used.
            im_mask = im_mask[:, :, 0]
        im_mask = (im_mask > 0.5).astype(np.float32)
    elif im_gt_alpha is not None:
        im_mask = (im_gt_alpha > 0.5).astype(np.float32)
    else:
        width, height = target_wh
        im_mask = np.ones((height, width), dtype=np.float32)

    return im_pd, im_gt, im_mask
def _check_inputs(
    im_pd: Float[np.ndarray, "h w 3"],
    im_gt: Float[np.ndarray, "h w 3"],
    im_mask: Float[np.ndarray, "h w"],
) -> None:
    """
    Validates the inputs shared by the image metric functions.

    The checks run in stages (instance type, dtype, shape, value range) and
    within each stage in the order im_pd, im_gt, im_mask; the first failing
    check raises.

    Args:
        im_pd: Predicted image. Must be a float32 numpy array of shape
            (h, w, 3) with values in [0, 1].
        im_gt: Ground truth image. Same requirements as im_pd, and the same
            shape as im_pd.
        im_mask: Mask. Must be a float32 numpy array of shape (h, w) matching
            im_pd, with values in [0, 1].

    Raises:
        ValueError: If a dtype, shape-agreement or range check fails, if an
            array is empty, or if an array contains NaN. The ``sanity``
            helpers raise their own errors for the type and shape-pattern
            checks.
    """
    named_arrays = (("im_pd", im_pd), ("im_gt", im_gt), ("im_mask", im_mask))

    # Instance type.
    for name, arr in named_arrays:
        sanity.assert_numpy(arr, name=name)

    # Dtype.
    for name, arr in named_arrays:
        if arr.dtype != np.float32:
            raise ValueError(f"{name} must be float32")

    # Shape.
    sanity.assert_shape(im_pd, (None, None, 3), name="im_pd")
    sanity.assert_shape(im_gt, (None, None, 3), name="im_gt")
    sanity.assert_shape(im_mask, (None, None), name="im_mask")
    if im_pd.shape != im_gt.shape:
        raise ValueError("im_pd and im_gt must have same shape")
    if im_pd.shape[:2] != im_mask.shape:
        raise ValueError("im_pd and im_mask must have same (h, w)")

    # Range.
    for name, arr in named_arrays:
        if arr.size == 0:
            # min()/max() of an empty array fail with an unrelated message.
            raise ValueError(f"{name} must not be empty")
        # Written as a positive "is inside" test on purpose: every comparison
        # with NaN is False, so `max > 1 or min < 0` would let NaN through.
        if not (arr.min() >= 0.0 and arr.max() <= 1.0):
            raise ValueError(f"{name} must be in range [0, 1]")