Source code for mirar.processors.xmatch

"""
Module to cross-match a candidate_table with different catalogs
"""

import logging

import astropy.units as u
import numpy as np
import pandas as pd
from astropy.coordinates import SkyCoord

from mirar.catalog.base.base_xmatch_catalog import BaseXMatchCatalog
from mirar.data import SourceBatch
from mirar.processors.base_processor import BaseSourceProcessor

logger = logging.getLogger(__name__)


class XMatch(BaseSourceProcessor):
    """
    Class to cross-match a candidate_table to a catalog.

    For every source table in a batch, the ``ra``/``dec`` of each candidate is
    sent to ``catalog.query``. Matched catalog values are written into numbered
    columns (``<column_name>1`` ... ``<column_name>N``, with
    ``N = catalog.num_sources``), together with the number of matches
    (``nmtch<abbreviation>``) and the on-sky separation in arcseconds between
    the candidate and each match (``dist<abbreviation>nr<N>``).
    """

    max_n_cpu = 4

    base_key = "XMATCH"

    def __init__(
        self,
        catalog: BaseXMatchCatalog,
    ):
        """
        :param catalog: Catalog to cross-match candidates against
        """
        self.catalog = catalog
        super().__init__()

    def description(self):
        """
        Return a human-readable description of this processor.

        :return: Description string naming the catalog
        """
        return (
            f"Processor to cross-match sources with "
            f"'{self.catalog.catalog_name}' catalog."
        )

    def _apply_to_sources(
        self,
        batch: SourceBatch,
    ) -> SourceBatch:
        """
        Cross-match every source table in the batch with the catalog.

        :param batch: Batch of source tables; each table needs ``ra`` and
            ``dec`` columns, interpreted in degrees
        :return: The same batch, with match, match-count and distance columns
            added to each table (NaN values replaced by None)
        """
        catalog = self.catalog

        for source_list in batch:
            candidate_table = source_list.get_data()

            # Work on plain arrays so that per-row access is positional and
            # does not depend on the DataFrame's index labels.
            ras = candidate_table["ra"].to_numpy()
            decs = candidate_table["dec"].to_numpy()
            crds = SkyCoord(ras, decs, unit=u.deg)

            query_names = [f"q{ind}" for ind in range(len(ras))]
            query_coords = {
                name: [ra, dec] for name, ra, dec in zip(query_names, ras, decs)
            }

            logger.debug(
                "Querying %s for %d sources.", catalog.catalog_name, len(ras)
            )
            # Nothing to look up for an empty table: skip the catalog query.
            query_results = catalog.query(query_coords) if query_coords else {}

            self._add_placeholder_columns(candidate_table)
            self._fill_match_columns(candidate_table, query_names, query_results)
            self._add_distance_columns(candidate_table, crds)

            candidate_table = candidate_table.replace({np.nan: None})

            source_list.set_data(candidate_table)

        return batch

    def _add_placeholder_columns(self, candidate_table: pd.DataFrame) -> None:
        """
        Add NaN-filled columns for every projected catalog column, one per
        possible match (``<column_name>1`` ... ``<column_name>N``).

        :param candidate_table: Table to modify in place
        """
        catalog = self.catalog
        projected_keys = [
            key for key, flag in catalog.projection.items() if flag == 1
        ]
        for key in projected_keys:
            colname = catalog.column_names[key]
            for num in range(1, catalog.num_sources + 1):
                candidate_table[f"{colname}{num}"] = np.array(
                    np.nan,
                    dtype=catalog.column_dtypes[colname],
                )

    def _fill_match_columns(
        self,
        candidate_table: pd.DataFrame,
        query_names,
        query_results,
    ) -> None:
        """
        Copy catalog matches into the numbered columns and record the number
        of matches per candidate in ``nmtch<abbreviation>``.

        :param candidate_table: Table to modify in place; the row at position
            ``i`` corresponds to ``query_names[i]``
        :param query_names: Query name used for each row, in row order
        :param query_results: Per-query lists of matches, indexed by query
            name, as returned by ``catalog.query``
        """
        catalog = self.catalog
        nmatch_colname = f"nmtch{catalog.abbreviation}"
        candidate_table[nmatch_colname] = 0

        # Pair each query with the row *label* at the same position, so that
        # ``.at`` updates the right row (rather than appending a new one) even
        # when the index is not 0..n-1. Assumes unique index labels.
        for row_label, query_name in zip(candidate_table.index, query_names):
            results = query_results[query_name]
            for match_num, result in enumerate(results, start=1):
                for key in result.keys():
                    colname = f"{catalog.column_names[key]}{match_num}"
                    candidate_table.at[row_label, colname] = result[key]
            candidate_table.at[row_label, nmatch_colname] = len(results)

    def _add_distance_columns(
        self,
        candidate_table: pd.DataFrame,
        crds: SkyCoord,
    ) -> None:
        """
        Add ``dist<abbreviation>nr<N>`` columns holding the separation, in
        arcsec, between each candidate and its N-th match (NaN where there is
        no N-th match).

        :param candidate_table: Table to modify in place
        :param crds: Candidate coordinates, in the same row order as the table
        """
        catalog = self.catalog
        for num in range(1, catalog.num_sources + 1):
            result_ra_colname = f"{catalog.ra_column_name}{num}"
            result_dec_colname = f"{catalog.dec_column_name}{num}"
            dist_colname = f"dist{catalog.abbreviation}nr{num}"
            candidate_table[dist_colname] = np.array(np.nan, dtype=float)

            # Positional boolean mask of rows that have an N-th match.
            has_match = pd.notnull(candidate_table[result_ra_colname]).to_numpy()
            if not has_match.any():
                # No matches in this slot: leave the distances as NaN.
                continue

            result_crds = SkyCoord(
                ra=candidate_table[result_ra_colname].to_numpy()[has_match],
                dec=candidate_table[result_dec_colname].to_numpy()[has_match],
                unit=u.deg,
            )
            candidate_table.loc[has_match, dist_colname] = (
                crds[has_match].separation(result_crds).arcsec
            )