Source code for mirar.processors.database.database_updater

"""
Module for processors to modify database entries
"""

import logging
from abc import ABC
from typing import Optional

import pandas as pd

from mirar.data import ImageBatch
from mirar.database.constraints import DBQueryConstraints
from mirar.database.transactions import select_from_table
from mirar.database.utils import get_sequence_key_names_from_table
from mirar.processors.database.database_inserter import DatabaseImageInserter
from mirar.processors.database.database_selector import (
    BaseDatabaseSelector,
    BaseImageDatabaseSelector,
)

logger = logging.getLogger(__name__)


[docs] class BaseDatabaseUpdater(BaseDatabaseSelector, ABC): """ Base class for database updaters """ base_key = "dbupdater" def __init__(self, db_alter_columns: str | list[str], **kwargs): super().__init__(db_output_columns=db_alter_columns, **kwargs) if not isinstance(db_alter_columns, list): db_alter_columns = [db_alter_columns] self.db_alter_columns = db_alter_columns
[docs] class ImageDatabaseUpdater(BaseDatabaseUpdater, BaseImageDatabaseSelector, ABC): """Base Class for updating image entries in a database""" def _apply_to_images( self, batch: ImageBatch, ) -> ImageBatch: for image in batch: val_dict = self.generate_value_dict(image) new = self.db_table(**val_dict) new.update_entry(update_keys=self.db_alter_columns) return batch
[docs] @staticmethod def generate_value_dict(image): """ Get the value dictionary for an image :param image: Image :return: Value dictionary """ return DatabaseImageInserter.generate_value_dict(image)
[docs] class ImageSequenceDatabaseUpdater(ImageDatabaseUpdater): """ Processor to modify images in a database with a sequence """ def __init__(self, sequence_key: Optional[str | list[str]] = None, **kwargs): super().__init__(**kwargs) self.sequence_key = sequence_key
[docs] def description(self) -> str: return ( f"Update entries in '{self.db_table.__name__}' " f"table of db using {self.sequence_key}" )
[docs] def get_constraints(self, data) -> DBQueryConstraints: """ Function to get the constraints for a database query :param data: Image :return: Constraints for a database query """ if self.sequence_key is None: self.sequence_key = get_sequence_key_names_from_table( self.db_table.sql_model.__tablename__, self.db_name ) accepted_values = [data[x.lower()] for x in self.sequence_key] comparison_types = ["="] * len(accepted_values) query_constraints = DBQueryConstraints( columns=self.sequence_key, accepted_values=accepted_values, comparison_types=comparison_types, ) return query_constraints
[docs] class ImageDatabaseMultiEntryUpdater(ImageSequenceDatabaseUpdater): """ Processor to modify multiple entries specified by a list of sequences in an image database """ def __init__(self, sequence_key: str, **kwargs): super().__init__(**kwargs) self.sequence_key = sequence_key.lower() def _apply_to_images( self, batch: ImageBatch, ) -> ImageBatch: for image in batch: try: unique_key_vals = [int(y) for y in image[self.sequence_key].split(",")] except ValueError as exc: raise ValueError("Sequence keys must be integers") from exc data_df = pd.DataFrame(unique_key_vals, columns=[self.sequence_key]) for _, row in data_df.iterrows(): constraints = DBQueryConstraints( columns=self.sequence_key, accepted_values=row[self.sequence_key] ) old = select_from_table( db_constraints=constraints, sql_table=self.db_table.sql_model, ) assert ( len(old) == 1 ), f"Multiple entries found for unique key {self.sequence_key}" for key in self.db_alter_columns: old[key] = image[key] new = self.db_table(**old.to_dict(orient="records")[0]) new.update_entry(update_keys=self.db_alter_columns) return batch