Source code for ct.colormap

"""
Functions for querying matplotlib's colormaps.
"""

import matplotlib
import numpy as np
from jaxtyping import Float


[docs] def query( values: Float[np.ndarray, "*batch"], colormap: str = "viridis", ) -> Float[np.ndarray, "*batch 3"]: """ Query matplotlib's color map. Args: values: Scalar values to map to colors. Valid range is [0, 1]. colormap: Name of matplotlib color map. Returns: RGB colors corresponding to input values. Raises: ValueError: If values.dtype is not float32 or float64. """ assert isinstance(values, np.ndarray) if not values.dtype == np.float32 and not values.dtype == np.float64: raise ValueError( "Matplotlib's colormap has different behavior for ints and floats. " "To unify behavior, we require floats (between 0-1 if valid). " f"However, dtype of {values.dtype} is used." ) cmap = matplotlib.cm.get_cmap(colormap) colors = cmap(values)[..., :3] # Remove alpha. return colors.astype(np.float32)
[docs] def normalize( array: Float[np.ndarray, "*batch"], vmin: float = 0.0, vmax: float = 1.0, clip: bool = False, ) -> Float[np.ndarray, "*batch"]: """ Normalize array to [vmin, vmax]. Args: array: Input array to normalize. vmin: Minimum value in output range. vmax: Maximum value in output range. clip: If True, clip array to [vmin, vmax]. Returns: Normalized array with same shape as input. """ if clip: array = np.clip(array, vmin, vmax) else: amin = array.min() amax = array.max() array = (array - amin) / (amax - amin) * (vmax - vmin) + vmin return array