Source code for jwst.refpix.optimized_convolution

"""
Functions for Simple Improved Reference Subtraction (SIRS).

These functions use the SIRS algorithm to improve the correction
of 1/f noise; they apply only to full-frame, non-IRS2 NIR data.
"""

import logging

import numpy as np

log = logging.getLogger(__name__)

__all__ = ["make_kernels", "get_conv_kernel_coeffs", "apply_conv_kernel"]


def make_kernels(sirs_kernel_model, detector, gaussmooth, halfwidth):
    """
    Make convolution kernels from Fourier coefficients in the reference file.

    For every channel, the Fourier coefficients are inverse-transformed,
    shifted so that the zero-lag sample is centered, cropped to
    ``2 * halfwidth + 1`` samples around that center, and finally smoothed
    with a normalized Gaussian window.

    Parameters
    ----------
    sirs_kernel_model : `~stdatamodels.jwst.datamodels.SIRSKernelModel`
        Data model containing the Fourier coefficients from the reference files for
        Simple Improved Reference Subtraction (SIRS).
    detector : str
        Name of the detector of the input data.
    gaussmooth : float
        Width of Gaussian smoothing kernel to use as a low-pass filter on
        reference file's coefficients.
    halfwidth : int
        Half-width of convolution kernel to build from reference file's
        coefficients.

    Returns
    -------
    kernels : list or None
        Two-element list ``[kernels_left, kernels_right]``; each element is a
        list holding one 1-D kernel per channel, appropriate for convolution
        with the left and right reference pixels. ``None`` is returned when
        no coefficients are found for ``detector``.

    Raises
    ------
    ValueError
        If ``halfwidth`` is negative or too large for the full
        ``2 * halfwidth + 1`` sample window to fit inside the
        inverse-transformed coefficients.
    """
    gamma, zeta = get_conv_kernel_coeffs(sirs_kernel_model, detector)
    if gamma is None or zeta is None:
        log.info(f"Optimized convolution kernel coefficients NOT found for detector {detector}")
        return None

    # The Gaussian low-pass window depends only on halfwidth and gaussmooth,
    # so build it once instead of once per channel.
    x = np.arange(-halfwidth, halfwidth + 1)
    window = np.exp(-(x**2) / (2 * gaussmooth**2))
    window /= np.sum(window)

    kernels_left = []
    kernels_right = []
    for chan in range(gamma.shape[0]):
        # irfft of n + 1 coefficients yields 2 * n samples; after fftshift
        # the zero-lag sample sits at index n.
        n = len(gamma[chan]) - 1

        # Without this check an oversized halfwidth makes the slice below
        # wrap around (negative start) or run off the end, silently
        # producing a truncated, off-center kernel.
        if not 0 <= halfwidth < n:
            raise ValueError(
                f"halfwidth must satisfy 0 <= halfwidth < {n} "
                f"for detector {detector}; got {halfwidth}"
            )

        crop = slice(n - halfwidth, n + halfwidth + 1)
        kernel_left = np.fft.fftshift(np.fft.irfft(gamma[chan]))[crop]
        kernel_right = np.fft.fftshift(np.fft.irfft(zeta[chan]))[crop]

        kernels_right.append(np.convolve(kernel_right, window, mode="same"))
        kernels_left.append(np.convolve(kernel_left, window, mode="same"))

    return [kernels_left, kernels_right]


def get_conv_kernel_coeffs(sirs_kernel_model, detector):
    """
    Get the convolution kernels coefficients from the reference file.

    The model is flattened to a dictionary whose keys are dot-separated
    paths. A key is used when its first component matches ``detector``
    (case-insensitively) and its second component is ``gamma`` or ``zeta``.
    Keys without a second component are ignored.

    Parameters
    ----------
    sirs_kernel_model : `~stdatamodels.jwst.datamodels.SIRSKernelModel`
        Data model containing the Fourier coefficients from the reference files for
        Simple Improved Reference Subtraction (SIRS)
    detector : str
        Name of the detector of the input data

    Returns
    -------
    gamma, zeta : ndarray or None
        Fourier coefficients. Either value is ``None`` when the matching
        entry is not present in the model.
    """
    wanted = detector.lower()
    gamma, zeta = None, None
    for key, value in sirs_kernel_model.to_flat_dict().items():
        # partition never raises, so a key with no "." separator simply
        # yields an empty remainder instead of an IndexError.
        det, _, remainder = key.partition(".")
        if det.lower() != wanted:
            continue

        arr_name = remainder.split(".")[0]
        if arr_name == "gamma":
            gamma = np.array(value)
        elif arr_name == "zeta":
            zeta = np.array(value)

        # Stop scanning as soon as both coefficient arrays are in hand.
        if gamma is not None and zeta is not None:
            break
    return gamma, zeta


def apply_conv_kernel(data, kernels, sigreject=4.0):
    """
    Subtract a kernel-convolved reference-pixel template from a 2-D frame.

    The four leftmost and four rightmost columns of ``data`` are used as
    reference pixels. Outlying reference pixels are masked, each row is
    averaged, and the row averages are convolved with the per-channel
    kernels to build a template that is subtracted from every column of
    the corresponding quarter of the image.

    Parameters
    ----------
    data : ndarray
        2-D data to be corrected
    kernels : list
        Two-element list containing the per-channel left kernels followed
        by the per-channel right kernels
    sigreject : float
        Number of sigmas beyond which a reference pixel is rejected as an
        outlier

    Returns
    -------
    data : ndarray
        Floating-point copy of the input with the template subtracted
    """
    corrected = data.astype(float)
    ncols = corrected.shape[-1]
    left_kernels, right_kernels = kernels
    n_channels = len(left_kernels)

    # Scale factor that converts a Median Absolute Deviation (MAD),
    # np.median(np.abs(x - np.median(x))), into an approximate standard
    # deviation for normally distributed values. The MAD is used because it
    # is far less sensitive to outliers than a direct standard deviation.
    # Reference: https://www.interstellarmedium.org/numerical_tools/mad/
    mad_to_sigma = 1.48

    def _clipped_row_mean(ref):
        """Return per-row masked mean and per-row fraction of good pixels."""
        center = np.median(ref)
        deviation = np.abs(ref - center)
        sigma = mad_to_sigma * np.median(deviation)
        # Pixels further than sigreject sigma from the median are masked.
        keep = (deviation <= sigreject * sigma).astype(int)
        # Masked pixels contribute zero to the sum, which is still divided
        # by the full column count; the good fraction returned alongside
        # lets the caller renormalize afterwards instead of replacing the
        # masked values directly.
        row_mean = np.sum(ref * keep, axis=1) / 4
        good_fraction = np.sum(keep, axis=1) / 4
        return row_mean, good_fraction

    left_mean, left_fraction = _clipped_row_mean(corrected[:, :4])
    right_mean, right_fraction = _clipped_row_mean(corrected[:, -4:])
    unit_rows = np.ones(left_fraction.shape)

    for chan in range(n_channels):
        k_left = left_kernels[chan]
        k_right = right_kernels[chan]

        # Ratio of the kernel response to an all-good column over its
        # response to the actual good-pixel fraction: this rescales the
        # template wherever reference pixels were masked.
        scale_left = np.convolve(unit_rows, k_left, mode="same") / np.convolve(
            left_fraction, k_left, mode="same"
        )
        scale_right = np.convolve(unit_rows, k_right, mode="same") / np.convolve(
            right_fraction, k_right, mode="same"
        )

        template = (
            np.convolve(left_mean, k_left, mode="same") * scale_left
            + np.convolve(right_mean, k_right, mode="same") * scale_right
        )

        # Each channel corrects one quarter of the image columns; the
        # per-row template is broadcast across that band.
        first_col = chan * ncols // 4
        last_col = (chan + 1) * ncols // 4
        corrected[:, first_col:last_col] -= template[:, np.newaxis]

    log.debug("Optimized convolution kernel applied")
    return corrected