Source code for mirar.processors.flat

"""
Module containing processors for flat calibration
"""

import logging
import os.path
import sys
from collections.abc import Callable
from copy import copy

import numpy as np
from astropy.convolution import Tophat2DKernel, convolve_fft

from mirar.data import Image, ImageBatch
from mirar.errors import ImageNotFoundError
from mirar.paths import (
    BASE_NAME_KEY,
    COADD_KEY,
    EXPTIME_KEY,
    FLAT_FRAME_KEY,
    LATEST_SAVE_KEY,
    OBSCLASS_KEY,
)
from mirar.processors.base_processor import ProcessorPremadeCache, ProcessorWithCache
from mirar.processors.utils.image_selector import select_from_images

logger = logging.getLogger(__name__)


[docs] def get_convolution(data: np.ndarray, kernel_width: int) -> np.ndarray: """ Convolve data with a tophat kernel :param data: Image data :param kernel_width: Width of the kernel (pixels) :return: Smoothed image """ pad_top = np.array([data[0] for _ in range(kernel_width)]) pad_bottom = np.array([data[-1] for _ in range(kernel_width)]) extended = np.vstack([pad_top, data, pad_bottom]) pad_left = np.array([extended.T[0] for _ in range(kernel_width)]) pad_right = np.array([extended.T[-1] for _ in range(kernel_width)]) extended = np.hstack([pad_left.T, extended, pad_right.T]) tophat_kernel = Tophat2DKernel(kernel_width) smooth_illumination = convolve_fft( extended, tophat_kernel, nan_treatment="interpolate" )[kernel_width:-kernel_width, kernel_width:-kernel_width] return smooth_illumination
import warnings from astropy.stats import sigma_clipped_stats
[docs] def get_outlier_pixel_mask(img: np.ndarray, thresh: float = 3.0) -> np.ndarray: """ Get oulier pixels that are above or below a threshold :param img: np.ndarray :param thresh: float """ with warnings.catch_warnings(): warnings.simplefilter("ignore") _, median, std = sigma_clipped_stats(img, sigma=3.0) return (img < (median - thresh * std)) | (img > (median + thresh * std))
[docs] def construct_smooth_gradient_for_image(data: np.ndarray) -> np.ndarray: """ Construct a smooth gradient for the image :param data: np.ndarray :return: np.ndarray """ smooth_img = get_convolution(data, 100) return smooth_img
[docs] def smooth_and_normalize_image(data: np.ndarray) -> np.ndarray: """ Smooth and normalize the image :param data: np.ndarray :return: np.ndarray """ smooth_img = get_convolution(data, 100) return data / smooth_img
[docs] def get_smoothened_outlier_pixel_mask_from_list( img_data_list: list[np.ndarray], threshold: float = 3.0 ) -> np.ndarray: """ Take a list of images and return a mask of outlier pixels after removing a smooth gradient from them :param img_data_list: list[np.ndarray] :param threshold: float :return: np.ndarray """ masks = [] for data in img_data_list: smooth_img = smooth_and_normalize_image(data) mask = get_outlier_pixel_mask(smooth_img, thresh=threshold) masks.append(mask) return np.logical_and.reduce(masks)
[docs] class MissingFlatError(ImageNotFoundError): """ Error for when a dark image is missing """
[docs] def default_select_flat( images: ImageBatch, ) -> ImageBatch: """ Select images tagged as flat :param images: set of images :return: subset of flat images """ return select_from_images(images, key=OBSCLASS_KEY, target_values="flat")
[docs] class FlatCalibrator(ProcessorWithCache): """ Processor to apply flat calibration """ base_key = "flat" def __init__( self, *args, x_min: int = 0, x_max: int = sys.maxsize, y_min: int = 0, y_max: int = sys.maxsize, flat_nan_threshold: float = 0.0, select_flat_images: Callable[[ImageBatch], ImageBatch] = default_select_flat, flat_mask_key: str = None, flat_mode: str = "median", **kwargs, ): super().__init__(*args, **kwargs) self.x_min = x_min self.x_max = x_max self.y_min = y_min self.y_max = y_max self.flat_nan_threshold = flat_nan_threshold self.select_cache_images = select_flat_images self.flat_mask_key = flat_mask_key self.flat_mode = flat_mode if self.flat_mode not in ["median", "pixel", "structure"]: raise ValueError(f"Flat mode {self.flat_mode} not supported")
[docs] def description(self) -> str: return "Creates a flat image, divides other images by this image."
def _apply_to_images( self, batch: ImageBatch, ) -> ImageBatch: master_flat = self.get_cache_file(batch) master_flat_data = master_flat.get_data() mask = master_flat_data <= self.flat_nan_threshold if np.sum(mask) > 0: master_flat_data[mask] = np.nan for image in batch: data = image.get_data() data = data / master_flat_data image.set_data(data) image[FLAT_FRAME_KEY] = master_flat[LATEST_SAVE_KEY] return batch
[docs] def make_image( self, images: ImageBatch, ) -> Image: images = self.select_cache_images(images) logger.debug(f"Found {len(images)} suitable flats in batch") n_frames = len(images) if n_frames == 0: err = f"Found {n_frames} suitable flats in batch" logger.error(err) raise MissingFlatError(err) nx, ny = images[0].get_data().shape flats = np.zeros((nx, ny, n_frames)) flat_exptimes = [] for i, img in enumerate(images): data = img.get_data().copy() if self.flat_mask_key is not None: if self.flat_mask_key not in img.header.keys(): err = ( f"Image {img} does not have a mask with key " f"{self.flat_mask_key}" ) logger.error(err) raise KeyError(err) mask_file = img[self.flat_mask_key] logger.debug(f"Masking flat {img[BASE_NAME_KEY]} with mask {mask_file}") if not os.path.exists(mask_file): err = f"Mask file {mask_file} does not exist" logger.error(err) raise FileNotFoundError(err) mask_img = self.open_fits(mask_file) pixels_to_keep = mask_img.get_data().astype(bool) mask = ~pixels_to_keep logger.debug( f"Masking {np.sum(mask)} pixels in flat {img[BASE_NAME_KEY]}" ) data[mask] = np.nan flat_exptimes.append(img[EXPTIME_KEY]) median = np.nanmedian( data[self.x_min : self.x_max, self.y_min : self.y_max] ) flats[:, :, i] = data / median logger.debug(f"Median combining {n_frames} flats") master_flat = np.nanmedian(flats, axis=2) if self.flat_mode != "median": if self.flat_mode == "pixel": mask = get_smoothened_outlier_pixel_mask_from_list( [x.get_data() for x in images] ) # flatdata_norm_smooth = get_convolution(master_flat, 100) # # pixel_variation = master_flat / flatdata_norm_smooth # # # Clip outliers (they'll get worked out in stacking) # std = np.nanstd(pixel_variation) # sig = abs(pixel_variation - np.nanmedian(pixel_variation)) / std # # mask = sig > 1.0 # mask = get_outlier_pixel_mask(master_flat, thresh=1.5) frac = np.sum(mask) / len(mask.flatten()) logger.info( # FIXME: Change to debug f"Masking {100.*frac:.1f}% of pixels in image" ) # pixel_variation[mask] = np.nan # master_flat = pixel_variation / np.nanmedian(pixel_variation) master_flat = np.ones_like(master_flat) master_flat[mask] = np.nan elif self.flat_mode == "structure": flatdata_norm_smooth = get_convolution(master_flat, 100) flatdata_norm_smooth[np.isnan(master_flat)] = np.nan pixel_variation = master_flat / flatdata_norm_smooth # Clip outliers (they'll get worked out in stacking) std = np.nanstd(pixel_variation) sig = abs(pixel_variation - np.nanmedian(pixel_variation)) / std mask = sig > 1.0 logger.debug( f"Masking {np.sum(mask)} pixels " f"out of {len(mask.flatten()) }in flat" ) master_flat = np.ones_like(master_flat) master_flat[mask] = np.nan # pixel_variation[mask] = np.nan # master_flat = ( # pixel_variation # * flatdata_norm_smooth # / np.nanmedian(pixel_variation) # ) else: raise ValueError(f"Flat mode {self.flat_mode} not supported") master_flat_image = Image(master_flat, header=copy(images[0].get_header())) master_flat_image[COADD_KEY] = n_frames master_flat_image["INDIVEXP"] = ",".join( [str(x) for x in np.unique(flat_exptimes)] ) return master_flat_image
[docs] class SkyFlatCalibrator(FlatCalibrator): """ Processor to do flat calibration using sky flats """ def __init__(self, *args, flat_mask_key=None, **kwargs): super().__init__( *args, select_flat_images=self.select_sky_flat, flat_mask_key=flat_mask_key, **kwargs, )
[docs] @staticmethod def select_sky_flat( images: ImageBatch, ) -> ImageBatch: """ Select science images to use as sky flats :param images: set of images :return: subset of 'sky' images """ return select_from_images(images, key=OBSCLASS_KEY, target_values="science")
[docs] def description(self) -> str: return ( "Processor to create a sky flat image, divides other images by this image." )
[docs] class MasterFlatCalibrator(ProcessorPremadeCache, FlatCalibrator): """Processor to do flat calibration with a master flate"""