Source code for mirar.processors.sources.source_detector

"""
Module to detect candidates in an image
"""

import logging
import os
from pathlib import Path

import numpy as np
import pandas as pd
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS

from mirar.data import Image, ImageBatch, SourceBatch, SourceTable
from mirar.data.utils import encode_img, write_regions_file
from mirar.paths import (
    BASE_NAME_KEY,
    CAND_DEC_KEY,
    CAND_RA_KEY,
    LATEST_SAVE_KEY,
    LATEST_WEIGHT_SAVE_KEY,
    REF_IMG_KEY,
    SCI_IMG_KEY,
    SCOR_IMG_KEY,
    SOURCE_HISTORY_KEY,
    SOURCE_NAME_KEY,
    XPOS_KEY,
    YPOS_KEY,
    get_output_dir,
)
from mirar.processors.astromatic.sextractor.sourceextractor import run_sextractor_dual
from mirar.processors.base_processor import BaseSourceGenerator, PrerequisiteError
from mirar.processors.photometry.utils import make_cutouts
from mirar.processors.zogy.zogy import ZOGY
from mirar.utils.ldac_tools import get_table_from_ldac

logger = logging.getLogger(__name__)


def generate_candidates_table(
    diff: Image,
    scorr_catalog_path: str | Path,
    sci_resamp_image_path: str | Path,
    ref_resamp_image_path: str | Path,
    diff_scorr_path: str | Path,
    isdiffpos: bool = True,
) -> pd.DataFrame:
    """
    Generate a candidates table from a difference image and scorr catalog

    The catalog is read from ``scorr_catalog_path``, each detection is
    annotated with the scorr and weight values at its peak pixel, and only
    detections with a scorr peak above 4 and a weight above the 30th
    percentile of the positive weights are kept. The surviving detections
    are converted to a pandas dataframe, given sky coordinates from the
    WCS of ``diff``, and given encoded science/reference/difference cutouts.

    The header of ``diff`` must provide ``LATEST_SAVE_KEY`` (path of the
    difference image used for the cutouts) and ``LATEST_WEIGHT_SAVE_KEY``
    (path of the weight image used for the weight cut).

    :param diff: Difference image
    :param scorr_catalog_path: Path to the scorr catalog
    :param sci_resamp_image_path: Path to the resampled science image
    :param ref_resamp_image_path: Path to the resampled reference image
    :param diff_scorr_path: Path to the scorr image
    :param isdiffpos: Is the difference image positive?
    :return: Candidates table
    :raises PrerequisiteError: if a required sextractor column is missing
        from the scorr catalog
    """
    # Selection cuts and cutout size, named so that the log message below
    # cannot drift away from the values that are actually applied.
    scorr_peak_threshold = 4
    weight_percentile = 30
    cutout_size_display = 40

    det_srcs = get_table_from_ldac(scorr_catalog_path)

    # Check if required sextractor specific keys are present in table.
    # X_IMAGE / Y_IMAGE are read further down, so they are checked here too.
    required_keys = [
        "XPEAK_IMAGE",
        "YPEAK_IMAGE",
        "FWHM_IMAGE",
        "X_IMAGE",
        "Y_IMAGE",
        "ALPHAWIN_J2000",
        "DELTAWIN_J2000",
        "A_IMAGE",
        "B_IMAGE",
        "ELONGATION",
    ]
    for key in required_keys:
        if key not in det_srcs.colnames:
            raise PrerequisiteError(f"Required key {key} not found in scorr catalog.")

    diff_path = diff[LATEST_SAVE_KEY]
    diff_wght_path = diff[LATEST_WEIGHT_SAVE_KEY]

    logger.debug(f"Found {len(det_srcs)} candidates in image.")

    # Shift the catalog peak positions by one so that they can be used
    # directly as (row, column) indices into the image arrays below.
    xpeaks, ypeaks = det_srcs["XPEAK_IMAGE"] - 1, det_srcs["YPEAK_IMAGE"] - 1
    det_srcs["xpeak"] = xpeaks
    det_srcs["ypeak"] = ypeaks

    scorr_data = fits.getdata(diff_scorr_path)
    det_srcs["scorr"] = scorr_data[ypeaks, xpeaks]

    weight_data = fits.getdata(diff_wght_path)
    det_srcs["weight"] = weight_data[ypeaks, xpeaks]

    # Only strictly positive weights contribute to the percentile.
    weight_percentile_lolim = np.nanpercentile(
        weight_data[weight_data > 0], q=weight_percentile
    )

    det_srcs = det_srcs[
        (det_srcs["scorr"] > scorr_peak_threshold)
        & (det_srcs["weight"] > weight_percentile_lolim)
    ]
    det_srcs = det_srcs.to_pandas()

    det_srcs[SOURCE_HISTORY_KEY] = [pd.DataFrame() for _ in range(len(det_srcs))]
    det_srcs[SOURCE_NAME_KEY] = None

    logger.debug(
        f"Filtered to {len(det_srcs)} candidates in image with "
        f"scorr peak > {scorr_peak_threshold}."
    )

    # Rename sextractor keys
    ydims, xdims = diff.get_data().shape
    det_srcs["NAXIS1"] = xdims
    det_srcs["NAXIS2"] = ydims
    det_srcs[XPOS_KEY] = det_srcs["X_IMAGE"]
    det_srcs[YPOS_KEY] = det_srcs["Y_IMAGE"]
    det_srcs["SEXTR_RA"] = det_srcs["ALPHAWIN_J2000"]
    det_srcs["SEXTR_DEC"] = det_srcs["DELTAWIN_J2000"]

    # Sky position from the difference image WCS (origin=1 pixel convention)
    wcs = WCS(diff.get_header())
    img_ra, img_dec = wcs.all_pix2world(det_srcs[XPOS_KEY], det_srcs[YPOS_KEY], 1)
    det_srcs[CAND_RA_KEY] = img_ra
    det_srcs[CAND_DEC_KEY] = img_dec

    # Separation between the WCS-derived and the sextractor-reported position
    det_crds = SkyCoord(ra=img_ra, dec=img_dec, unit="deg")
    sextractor_crds = SkyCoord(
        ra=det_srcs["SEXTR_RA"], dec=det_srcs["SEXTR_DEC"], unit="deg"
    )
    det_srcs["crd_offset"] = sextractor_crds.separation(det_crds).arcsec

    det_srcs["fwhm"] = det_srcs["FWHM_IMAGE"]
    det_srcs["aimage"] = det_srcs["A_IMAGE"]
    det_srcs["bimage"] = det_srcs["B_IMAGE"]
    det_srcs["aimagerat"] = det_srcs["aimage"] / det_srcs["fwhm"]
    det_srcs["bimagerat"] = det_srcs["bimage"] / det_srcs["fwhm"]
    det_srcs["elong"] = det_srcs["ELONGATION"]

    display_sci_ims = []
    display_ref_ims = []
    display_diff_ims = []
    nan_fracs = []

    # Cutouts, centred on the peak pixel of every surviving candidate
    for xpeak, ypeak in zip(det_srcs["xpeak"], det_srcs["ypeak"]):
        display_sci_cutout, display_ref_cutout, display_diff_cutout = make_cutouts(
            [sci_resamp_image_path, ref_resamp_image_path, diff_path],
            (int(xpeak), int(ypeak)),
            cutout_size_display,
        )

        display_sci_ims.append(encode_img(display_sci_cutout))
        display_ref_ims.append(encode_img(display_ref_cutout))
        display_diff_ims.append(encode_img(display_diff_cutout))

        # Fraction of NaN pixels in the difference cutout
        nan_fracs.append(
            np.sum(np.isnan(display_diff_cutout)) / np.prod(display_diff_cutout.shape)
        )

    det_srcs["cutout_science"] = display_sci_ims
    det_srcs["cutout_template"] = display_ref_ims
    det_srcs["cutout_difference"] = display_diff_ims
    det_srcs["maskfrac"] = nan_fracs
    det_srcs["isdiffpos"] = isdiffpos

    return det_srcs
class ZOGYSourceDetector(BaseSourceGenerator):
    """
    Class to detect candidates by running sourceextractor on a difference
    image and scorr image from ZOGY
    """

    base_key = "DETCANDS"

    def __init__(
        self,
        cand_det_sextractor_config: str,
        cand_det_sextractor_filter: str,
        cand_det_sextractor_nnw: str,
        cand_det_sextractor_params: str,
        output_sub_dir: str = "candidates",
        write_regions: bool = False,
        detect_negative_sources: bool = False,
    ):
        """
        Parameters

        :param cand_det_sextractor_config: SExtractor configuration file
        :param cand_det_sextractor_filter: SExtractor filter file
        :param cand_det_sextractor_nnw: SExtractor neural network file
        :param cand_det_sextractor_params: SExtractor parameter file
        :param output_sub_dir: Directory to save output
        :param write_regions: Write regions file with detected sources
        :param detect_negative_sources: Detect negative sources in addition
            to positive sources? If true, sources are also detected in the
            negative scorr image, and are marked with isdiffpos=False
        """
        super().__init__()
        self.output_sub_dir = output_sub_dir
        self.cand_det_sextractor_config = cand_det_sextractor_config
        self.cand_det_sextractor_filter = cand_det_sextractor_filter
        self.cand_det_sextractor_nnw = cand_det_sextractor_nnw
        self.cand_det_sextractor_params = cand_det_sextractor_params
        self.write_regions = write_regions
        self.detect_negative_sources = detect_negative_sources

    def description(self) -> str:
        """
        Short human-readable description of this processor

        :return: Description string
        """
        return (
            "Extracts detected sources from images, "
            "and converts them to a pandas dataframe"
        )

    def get_sub_output_dir(self) -> Path:
        """
        Get output sub-directory

        :return: Result of ``get_output_dir`` for this processor's
            ``output_sub_dir`` and ``night_sub_dir``
        """
        return get_output_dir(self.output_sub_dir, self.night_sub_dir)

    def _run_candidate_sextractor(
        self,
        det_image: str | Path,
        measure_image: str | Path,
        catalog_name: str | Path,
        weight_image: str | Path,
    ) -> str | Path:
        """
        Run sourceextractor in dual mode with this processor's configuration

        :param det_image: Image used for detection
        :param measure_image: Image used for measurement
        :param catalog_name: Requested name of the output catalog
        :param weight_image: Weight image passed to sourceextractor
        :return: Catalog name, as returned by ``run_sextractor_dual``
        """
        catalog_path, _ = run_sextractor_dual(
            det_image=det_image,
            measure_image=measure_image,
            output_dir=self.get_sub_output_dir(),
            catalog_name=catalog_name,
            config=self.cand_det_sextractor_config,
            parameters_name=self.cand_det_sextractor_params,
            filter_name=self.cand_det_sextractor_filter,
            starnnw_name=self.cand_det_sextractor_nnw,
            weight_image=weight_image,
            gain=1.0,
        )
        return catalog_path

    def _apply_to_images(
        self,
        batch: ImageBatch,
    ) -> SourceBatch:
        """
        Detect candidates in every difference image of the batch

        For each image, sourceextractor is run with the scorr image as the
        detection image and the difference image as the measurement image,
        and the resulting catalog is converted to a candidates table. If
        ``detect_negative_sources`` is set, sign-flipped copies of both
        images are written next to the originals and processed the same
        way, with the resulting rows marked ``isdiffpos=False``.

        :param batch: Batch of difference images
        :return: Batch with one source table per input image
        """
        all_cands = SourceBatch()
        output_dir = self.get_sub_output_dir()

        for image in batch:
            scorr_image_path = os.path.join(output_dir, image[SCOR_IMG_KEY])
            diff_image_path = os.path.join(output_dir, image[BASE_NAME_KEY])
            sci_image_path = os.path.join(output_dir, image[SCI_IMG_KEY])
            ref_image_path = os.path.join(output_dir, image[REF_IMG_KEY])

            # The weight image is looked up in the header of the scorr image
            scorr_image = self.open_fits(scorr_image_path)
            scorr_mask_path = os.path.join(
                output_dir, scorr_image[LATEST_WEIGHT_SAVE_KEY]
            )

            cands_catalog_name = self._run_candidate_sextractor(
                det_image=scorr_image_path,
                measure_image=diff_image_path,
                catalog_name=diff_image_path.replace(".fits", ".dets"),
                weight_image=scorr_mask_path,
            )

            srcs_table = generate_candidates_table(
                diff=image,
                scorr_catalog_path=cands_catalog_name,
                sci_resamp_image_path=sci_image_path,
                ref_resamp_image_path=ref_image_path,
                diff_scorr_path=scorr_image_path,
            )
            logger.debug(f"Found {len(srcs_table)} candidates in positive image")

            if self.detect_negative_sources:
                # Write sign-flipped copies of the scorr and difference images
                negative_img_paths = []
                for img_path in [scorr_image_path, diff_image_path]:
                    negative_img_path = img_path.replace(".fits", ".neg.fits")
                    img = self.open_fits(img_path)
                    img.set_data(img.get_data() * -1)
                    self.save_fits(img, negative_img_path)
                    negative_img_paths.append(negative_img_path)
                negative_scorr_path, negative_diff_path = negative_img_paths

                # In-memory negated image; it reuses the header of the
                # original difference image.
                negative_image = Image(
                    header=image.get_header(), data=-1.0 * image.get_data()
                )

                negative_catalog_name = self._run_candidate_sextractor(
                    det_image=negative_scorr_path,
                    measure_image=negative_diff_path,
                    catalog_name=diff_image_path.replace(".fits", ".neg.dets"),
                    weight_image=scorr_mask_path,
                )

                negative_srcs_table = generate_candidates_table(
                    diff=negative_image,
                    scorr_catalog_path=negative_catalog_name,
                    sci_resamp_image_path=sci_image_path,
                    ref_resamp_image_path=ref_image_path,
                    diff_scorr_path=negative_scorr_path,
                    isdiffpos=False,
                )

                srcs_table = pd.concat(
                    [srcs_table, negative_srcs_table],
                    axis=0,
                    ignore_index=True,
                    sort=False,
                )
                logger.debug(
                    f"Found {len(negative_srcs_table)} sources in "
                    f"negative image {image[BASE_NAME_KEY]}"
                )
                logger.debug(srcs_table)
                logger.debug(f"Found {len(srcs_table)} total sources.")

            if len(srcs_table) > 0:
                # NOTE(review): numpy's shape is (rows, columns), so X_SHAPE
                # receives the first array axis here, whereas
                # generate_candidates_table unpacks the same shape as
                # (ydims, xdims). Left unchanged because the consumers of
                # X_SHAPE / Y_SHAPE are not visible from this module -
                # confirm which order is intended.
                x_shape, y_shape = image.get_data().shape
                srcs_table["X_SHAPE"] = x_shape
                srcs_table["Y_SHAPE"] = y_shape

            metadata = self.get_metadata(image)

            if len(srcs_table) == 0:
                # Logged once, at warning level only
                logger.warning(f"No sources found in image {image[BASE_NAME_KEY]}")
            else:
                if self.write_regions:
                    scorr_labels = [str(round(x, 2)) for x in srcs_table["scorr"]]

                    # Region file in image (pixel) coordinates
                    write_regions_file(
                        regions_path=diff_image_path.replace(
                            ".fits", ".cands.img.reg"
                        ),
                        x_coords=srcs_table["X_IMAGE"],
                        y_coords=srcs_table["Y_IMAGE"],
                        text=scorr_labels,
                    )

                    # Region file in sky coordinates
                    write_regions_file(
                        regions_path=diff_image_path.replace(
                            ".fits", ".cands.wcs.reg"
                        ),
                        x_coords=srcs_table[CAND_RA_KEY],
                        y_coords=srcs_table[CAND_DEC_KEY],
                        region_radius=5.0 / 3600.0,
                        system="wcs",
                        text=scorr_labels,
                    )

                logger.debug(
                    f"Found {len(srcs_table)} sources in image {image[BASE_NAME_KEY]}"
                )

            all_cands.append(SourceTable(srcs_table, metadata=metadata))

        return all_cands

    def check_prerequisites(self):
        """
        Warn unless exactly one ZOGY processor precedes this one

        Only a warning is logged; no exception is raised.
        """
        check = sum(isinstance(x, ZOGY) for x in self.preceding_steps)
        if check != 1:
            logger.warning("ZOGY must be run before ZOGYSourceDetector")