# Source code for mirar.database.constraints

"""
Module for DBQueryConstraints to carefully specify postgres query constraints
"""

import numpy as np

# Comparison operators accepted by DBQueryConstraints.add_constraint; any other
# operator string fails its assertion. "between" additionally requires the
# accepted value to be a 2-tuple (lower, upper).
POSTGRES_ACCEPTED_COMPARISONS = ["=", "<", ">", "<=", ">=", "between", "<>", "!="]


[docs] class DBQueryConstraints: """ Object containing one or more postgres query constraints """ def __init__( self, columns: str | list[str] | None = None, accepted_values: ( str | int | float | list[str | float | int | list] | None ) = None, comparison_types: str | list[str] | None = None, ): self.columns = [] self.accepted_values = [] self.comparison_types = [] self.q3c_query = None if columns is None: assert accepted_values is None else: if not isinstance(columns, list): columns = [columns] if not isinstance(accepted_values, list): accepted_values = [accepted_values] assert len(columns) == len(accepted_values) if comparison_types is None: comparison_types = ["="] * len(accepted_values) assert len(comparison_types) == len(accepted_values) for i, column in enumerate(columns): self.add_constraint( column=column, accepted_values=accepted_values[i], comparison_type=comparison_types[i], )
[docs] def add_constraint( self, column: str, accepted_values: str | int | float | tuple[float, float] | tuple[int, int], comparison_type: str = "=", ): """ Add a new constraint :param column: column :param accepted_values: accepted value for comparison :param comparison_type: type of comparison, e.g '=' :return: None """ assert comparison_type in POSTGRES_ACCEPTED_COMPARISONS if comparison_type == "between": assert np.logical_and( isinstance(accepted_values, tuple), len(accepted_values) == 2 ) self.columns.append(column) self.accepted_values.append(accepted_values) self.comparison_types.append(comparison_type)
[docs] def add_q3c_constraint( self, ra: float, dec: float, crossmatch_radius_arcsec: float, ra_field_name: str = "ra", dec_field_name: str = "dec", ): """ Add a q3c constraint :param ra: ra of source :param dec: dec of source :param crossmatch_radius_arcsec: crossmatch radius in arcsec :param ra_field_name: ra field name in database :param dec_field_name: dec field name in database :return: None """ crossmatch_radius_deg = crossmatch_radius_arcsec / 3600.0 constraints = ( f"q3c_radial_query({ra_field_name},{dec_field_name}," f"{ra},{dec},{crossmatch_radius_deg}) " ) self.q3c_query = constraints
def __add__(self, other): new = self.__class__(self.columns, self.accepted_values, self.comparison_types) new.q3c_query = self.q3c_query for args in other: new.add_constraint(*args) if other.q3c_query is not None: new.add_q3c_constraint(*other.q3c_query) return new def __iadd__(self, other): for args in other: self.add_constraint(*args) if other.q3c_query is not None: self.add_q3c_constraint(*other.q3c_query) return self def __len__(self): return self.columns.__len__() def __iter__(self): return iter(zip(self.columns, self.accepted_values, self.comparison_types))
[docs] def parse_constraints( self, ) -> str: """ Converts the list of constraints to sql :return: sql string """ constraints = [] if self.q3c_query is not None: constraints.append(self.q3c_query) for i, column in enumerate(self.columns): if self.comparison_types[i] == "between": constraints.append( f"{column.lower()} between {self.accepted_values[i][0]} " f"and {self.accepted_values[i][1]}" ) else: constraints.append( f"{column.lower()} {self.comparison_types[i]} " f"'{self.accepted_values[i]}'" ) return " AND ".join(constraints)