# Source code for mirar.processors.split

"""
Module for splitting images into sub-images
"""

import copy
import logging

import numpy as np
from astropy.wcs import WCS

from mirar.data import Dataset, Image, ImageBatch
from mirar.errors import ProcessorError
from mirar.paths import BASE_NAME_KEY, LATEST_SAVE_KEY, LATEST_WEIGHT_SAVE_KEY
from mirar.processors.astromatic.swarp import Swarp
from mirar.processors.base_processor import BaseImageProcessor

logger = logging.getLogger(__name__)

# Header keyword holding the integer running index of a sub-image in its grid
SUB_ID_KEY = "SUBDETID"
# Header keyword holding the "x_y" grid coordinate of a sub-image
SUB_COORD_KEY = "SUBCOORD"


class ImageSplittingError(ProcessorError):
    """
    Raised when an image cannot be split into sub-images, e.g. because the
    WCS needed to place Swarp-resampled sub-images cannot be parsed.
    """
class SplitImage(BaseImageProcessor):
    """
    Processor for splitting images into a grid of smaller sub-images.

    Each image is cut into ``n_x`` chunks along numpy axis 0 (rows) and
    ``n_y`` chunks along numpy axis 1 (columns). Each chunk is optionally
    widened by ``buffer_pixels`` on every side. Every sub-image receives
    header keys recording its position in the grid, and ``update_dataset``
    regroups the output so that each grid position forms its own batch.
    """

    base_key = "split"

    def __init__(self, buffer_pixels: int = 0, n_x: int = 1, n_y: int = 1):
        """
        :param buffer_pixels: extra pixels added to each side of every chunk
            (clipped at the image edges), so neighbouring sub-images overlap
        :param n_x: number of chunks along numpy axis 0 (rows)
        :param n_y: number of chunks along numpy axis 1 (columns)
        """
        super().__init__()
        self.buffer_pixels = buffer_pixels
        self.n_x = n_x
        self.n_y = n_y

    def description(self) -> str:
        return (
            f"Processor to split images into "
            f"{self.n_x}x{self.n_y}={self.n_x*self.n_y} smaller images."
        )

    def get_range(
        self,
        n_chunks: int,
        pixel_width: int,
        i: int,
    ) -> tuple[int, int]:
        """
        Function to return pixel index range for sub images

        The axis is divided into ``n_chunks`` chunks of
        ``pixel_width // n_chunks`` pixels. The final chunk is extended to the
        end of the axis, so no trailing pixels are lost when ``pixel_width``
        is not a multiple of ``n_chunks``. Each chunk is then widened by
        ``self.buffer_pixels`` on both sides and clipped to
        ``[0, pixel_width]``.

        :param n_chunks: number of chunks to divide axis into
        :param pixel_width: total pixel width of axis
        :param i: index of chunk to evaluate
        :return: lower pixel index (inclusive) and upper pixel index
            (exclusive) of chunk
        """
        chunk_width = pixel_width // n_chunks
        lower = max(0, i * chunk_width - self.buffer_pixels)
        if i == n_chunks - 1:
            # Absorb the remainder pixels into the last chunk
            upper = pixel_width
        else:
            upper = min(pixel_width, (i + 1) * chunk_width + self.buffer_pixels)
        return lower, upper

    def _apply_to_images(
        self,
        batch: ImageBatch,
    ) -> ImageBatch:
        """
        Split every image in the batch into ``n_x * n_y`` sub-images.

        :param batch: images to split
        :return: batch of all sub-images, ordered image by image and, within
            each image, with the axis-0 chunk index varying slowest
        """
        new_images = ImageBatch()

        logger.debug(f"Splitting each data into {self.n_x*self.n_y} sub-images")

        for image in batch:
            data = image.get_data()
            # Naming convention: "x" is numpy axis 0 (rows, FITS axis 2) and
            # "y" is numpy axis 1 (columns, FITS axis 1)
            pix_width_x, pix_width_y = data.shape

            k = 0
            for index_x in range(self.n_x):
                x_0, x_1 = self.get_range(self.n_x, pix_width_x, index_x)
                for index_y in range(self.n_y):
                    y_0, y_1 = self.get_range(self.n_y, pix_width_y, index_y)

                    # Copy so the sub-image does not share memory with the parent
                    new_data = np.array(data[x_0:x_1, y_0:y_1])

                    new_header = copy.copy(image.get_header())

                    # Section keywords presumably describe the parent frame's
                    # layout, so they are dropped from the cut-out
                    for key in ["DETSIZE", "INFOSEC", "TRIMSEC", "DATASEC"]:
                        if key in new_header.keys():
                            del new_header[key]

                    sub_img_id = f"{index_x}_{index_y}"

                    new_header[SUB_COORD_KEY] = (
                        sub_img_id,
                        "Sub-data coordinate, in form x_y",
                    )
                    new_header["SUBNX"] = (index_x + 1, "Sub-data x index")
                    new_header["SUBNY"] = (index_y + 1, "Sub-data y index")
                    new_header["SUBNXTOT"] = (self.n_x, "Total number of sub-data in x")
                    new_header["SUBNYTOT"] = (self.n_y, "Total number of sub-data in y")
                    new_header[SUB_ID_KEY] = k
                    k += 1

                    new_header["SRCIMAGE"] = (
                        image[BASE_NAME_KEY],
                        "Source data name, from which sub-data was made",
                    )

                    # FITS NAXIS1 is the fastest-varying axis, i.e. numpy axis 1
                    new_header["NAXIS2"], new_header["NAXIS1"] = new_data.shape

                    # Cropping moves the pixel origin, so shift the WCS
                    # reference pixel to keep it at the same sky position.
                    # FITS axis 1 <-> numpy axis 1 (y), axis 2 <-> numpy axis 0 (x)
                    if "CRPIX1" in new_header:
                        new_header["CRPIX1"] = new_header["CRPIX1"] - y_0
                    if "CRPIX2" in new_header:
                        new_header["CRPIX2"] = new_header["CRPIX2"] - x_0

                    new_header[BASE_NAME_KEY] = image[BASE_NAME_KEY].replace(
                        ".fits", f"_{sub_img_id}.fits"
                    )

                    # The sub-image has not been saved yet, so any save paths
                    # inherited from the parent would be wrong
                    for key in [LATEST_SAVE_KEY, LATEST_WEIGHT_SAVE_KEY]:
                        if key in new_header.keys():
                            del new_header[key]

                    new_images.append(Image(data=new_data, header=new_header))

        return new_images

    def update_dataset(self, dataset: Dataset) -> Dataset:
        """
        Regroup sub-images so that each grid position forms its own batch.

        For every input batch, ``n_x * n_y`` new batches are created. Each
        image is placed in the batch indexed by its ``SUB_ID_KEY`` value.

        :param dataset: dataset whose images carry ``SUB_ID_KEY``
        :return: dataset with one batch per (input batch, sub-image index)
        """
        all_new_batches = []
        for batch in dataset:
            new_images = [[] for _ in range(self.n_x * self.n_y)]
            for image in batch:
                new_images[image[SUB_ID_KEY]].append(image)
            all_new_batches += new_images
        return Dataset([ImageBatch(x) for x in all_new_batches])
class SwarpImageSplitter(SplitImage):
    """
    Processor for splitting images using Swarp

    Instead of slicing the pixel array, each sub-image is produced by running
    Swarp on the full image. The output centre is set manually to the sky
    position (from the image WCS) of the sub-image centre. The output size
    is the pixel extent returned by ``get_range``.
    """

    def __init__(
        self,
        swarp_config_path: str,
        output_sub_dir: str = "swarp_split",
        buffer_pixels: int = 0,
        n_x: int = 1,
        n_y: int = 1,
    ):
        """
        :param swarp_config_path: path to the Swarp configuration file
        :param output_sub_dir: passed to Swarp as ``temp_output_sub_dir``
        :param buffer_pixels: extra pixels added to each side of every chunk
        :param n_x: number of chunks along numpy axis 0 (rows)
        :param n_y: number of chunks along numpy axis 1 (columns)
        """
        super().__init__(buffer_pixels=buffer_pixels, n_x=n_x, n_y=n_y)
        self.swarp_config_path = swarp_config_path
        self.output_sub_dir = output_sub_dir

    def _apply_to_images(
        self,
        batch: ImageBatch,
    ) -> ImageBatch:
        """
        Resample every image in the batch into ``n_x * n_y`` sub-images.

        :param batch: images to split; each header must contain a parseable WCS
        :return: batch containing all resampled sub-images
        :raises ImageSplittingError: if an image header's WCS cannot be parsed
        """
        new_images = ImageBatch()

        for image in batch:
            # "x" is numpy axis 0 (rows) and "y" is numpy axis 1 (columns)
            pix_width_x, pix_width_y = image.get_data().shape
            src_imagename = image[BASE_NAME_KEY]

            try:
                old_wcs = WCS(image.get_header())
            except Exception as exc:
                logger.error(f"Could not parse WCS from header: {exc}")
                raise ImageSplittingError(
                    f"Could not parse WCS from header of {src_imagename}"
                ) from exc

            try:
                k = 0
                for index_x in range(self.n_x):
                    x_0, x_1 = self.get_range(self.n_x, pix_width_x, index_x)
                    for index_y in range(self.n_y):
                        y_0, y_1 = self.get_range(self.n_y, pix_width_y, index_y)

                        new_image_center_x = (x_1 + x_0) / 2
                        new_image_center_y = (y_1 + y_0) / 2

                        # all_pix2world takes pixel coords in FITS axis order
                        # (column, row), hence "y" is passed first
                        new_image_center_radec = old_wcs.all_pix2world(
                            [new_image_center_y], [new_image_center_x], 0
                        )

                        sub_img_id = f"{index_x}_{index_y}"

                        # NOTE(review): x_imgpixsize receives the row extent and
                        # y_imgpixsize the column extent. If Swarp maps these to
                        # IMAGE_SIZE = NAXIS1,NAXIS2 (columns, rows), non-square
                        # sub-images would be transposed; confirm against Swarp.
                        resampler = Swarp(
                            swarp_config_path=self.swarp_config_path,
                            temp_output_sub_dir=self.output_sub_dir,
                            center_type="MANUAL",
                            center_ra=new_image_center_radec[0][0],
                            center_dec=new_image_center_radec[1][0],
                            x_imgpixsize=x_1 - x_0,
                            y_imgpixsize=y_1 - y_0,
                            cache=False,
                            include_scamp=False,
                        )
                        resampler.set_night(night_sub_dir=self.night_sub_dir)

                        # Rename the input per sub-image, presumably so that
                        # Swarp's output is named after the sub-image
                        image[BASE_NAME_KEY] = src_imagename.replace(
                            ".fits", f"_{sub_img_id}.fits"
                        )
                        resampled_image = resampler.apply(ImageBatch(image))[0]

                        resampled_image[SUB_ID_KEY] = k
                        resampled_image[SUB_COORD_KEY] = (
                            sub_img_id,
                            "Sub-data coordinate, in form x_y",
                        )
                        resampled_image["SUBNX"] = (index_x + 1, "Sub-data x index")
                        resampled_image["SUBNY"] = (index_y + 1, "Sub-data y index")
                        resampled_image["SUBNXTOT"] = (
                            self.n_x,
                            "Total number of sub-data in x",
                        )
                        resampled_image["SUBNYTOT"] = (
                            self.n_y,
                            "Total number of sub-data in y",
                        )
                        resampled_image["SRCIMAGE"] = (
                            src_imagename,
                            "Source data name, from which sub-data was made",
                        )
                        k += 1

                        new_images.append(resampled_image)
            finally:
                # Undo the temporary renaming so the caller's image is unchanged,
                # even if Swarp fails part-way through
                image[BASE_NAME_KEY] = src_imagename

        return new_images