"""
Python script containing all IO functions.
All opening/writing of fits files should run via this script.
"""
import copy
import logging
import warnings
from pathlib import Path
from typing import Callable
import numpy as np
from astropy.io import fits
from astropy.utils.exceptions import AstropyUserWarning, AstropyWarning
from mirar.data import Image
from mirar.errors.exceptions import ProcessorError
from mirar.paths import BASE_NAME_KEY, LATEST_SAVE_KEY, RAW_IMG_KEY, core_fields
logger = logging.getLogger(__name__)
[docs]
class MissingCoreFieldError(KeyError, ProcessorError):
    """Error for images that lack one of the required core header fields"""
[docs]
class ExtensionParsingError(ProcessorError):
    """Error for fits files whose extensions are mislabelled or unusable"""
[docs]
def create_fits(data: np.ndarray, header: fits.Header | None) -> fits.PrimaryHDU:
    """
    Build an astropy PrimaryHDU from image data and an optional header

    :param data: numpy ndarray containing image data
    :param header: astropy Header object to attach; if None, the header
        generated by astropy for <data> is left in place
    :return: astropy PrimaryHDU object containing the image data and header
    """
    primary_hdu = fits.PrimaryHDU(data)
    if header is None:
        return primary_hdu
    primary_hdu.header = header
    return primary_hdu
[docs]
def create_compressed_fits(
    data: np.ndarray, header: fits.Header | None
) -> fits.CompImageHDU:
    """
    Build an astropy CompImageHDU from image data and an optional header

    :param data: numpy ndarray containing image data
    :param header: astropy Header object, passed straight to CompImageHDU
    :return: astropy CompImageHDU object containing the image data and header
    """
    return fits.CompImageHDU(data, header=header)
[docs]
def save_hdu_as_fits(
    hdu: fits.PrimaryHDU | fits.CompImageHDU, path: str | Path, overwrite: bool = True
):
    """
    Verify an astropy hdu and write it to file

    :param hdu: hdu to save
    :param path: destination path of the fits file
    :param overwrite: boolean whether to overwrite an existing file at <path>
    :return: None
    """
    # Verify before writing, using astropy's "silentfix+exception" option
    hdu.verify("silentfix+exception")
    hdu.writeto(path, overwrite=overwrite)
[docs]
def save_to_path(
    data: np.ndarray,
    header: fits.Header | None,
    path: str | Path,
    overwrite: bool = True,
    compress: bool = False,
):
    """
    Save an image made of <data> and <header> to <path>.

    :param data: numpy ndarray containing image data
    :param header: astropy Header object
    :param path: output path to save to
    :param overwrite: boolean variable on whether to overwrite if an
        image exists at <path>. Defaults to True.
    :param compress: boolean variable on whether to compress the image
    :return: None
    """
    # Pick the HDU builder matching the requested output format
    build_hdu = create_compressed_fits if compress else create_fits
    save_hdu_as_fits(
        hdu=build_hdu(data, header=header), path=path, overwrite=overwrite
    )
[docs]
def save_mef_to_path(
    data_list: list[np.ndarray],
    header_list: list[fits.Header],
    primary_header: fits.Header,
    path: str | Path,
):
    """
    Save a multi-extension fits (MEF) file to <path>.

    The primary HDU is created from <primary_header> only (no data). Each
    entry of <data_list> is written as an image extension, using the header
    at the same position in <header_list>. Any existing file at <path> is
    overwritten.

    :param data_list: image data of each extension, in output order
    :param header_list: header of each extension, same order and length
        as <data_list>
    :param primary_header: header of the primary HDU
    :param path: output path to save to
    :raises ValueError: if <data_list> and <header_list> differ in length
    :return: None
    """
    # Explicit check rather than `assert`, which is removed under `python -O`
    if len(data_list) != len(header_list):
        raise ValueError(
            f"Number of data arrays ({len(data_list)}) does not match "
            f"number of headers ({len(header_list)})."
        )

    hdus = [fits.PrimaryHDU(header=primary_header)]
    hdus.extend(
        fits.ImageHDU(data=ext_data, header=ext_header)
        for ext_data, ext_header in zip(data_list, header_list)
    )
    fits.HDUList(hdus).writeto(path, overwrite=True)
[docs]
def open_compressed_fits(path: str | Path) -> tuple[np.ndarray, fits.Header]:
    """
    Open a compressed fits file holding exactly one extension

    :param path: path to the compressed fits file
    :raises ValueError: if the file has no extension, or more than one
    :return: data, header of the single extension
    """
    _, data_list, header_list = open_mef_fits(path)
    n_extensions = len(data_list)

    if n_extensions == 1:
        return data_list[0], header_list[0]

    if n_extensions == 0:
        err = f"Compressed fits file {path} has no extensions."
    else:
        err = f"Compressed fits file {path} has more than one extension."
    logger.error(err)
    raise ValueError(err)
[docs]
def open_fits(path: str | Path) -> tuple[np.ndarray, fits.Header]:
    """
    Open a fits file saved to <path>

    If any HDU in the file is a compressed image HDU, the file is read via
    open_compressed_fits. Otherwise the first HDU is verified and its data
    and header are used. The header is then given a base-name entry and a
    raw-image-path entry if it does not already have them.

    :param path: path of fits file
    :return: tuple containing image data and image header
    """
    path = Path(path)

    with fits.open(path, memmap=False) as hdul:
        # Use the public CompImageHDU name, as elsewhere in this module,
        # rather than a private module path that varies between astropy versions
        is_compressed = any(isinstance(hdu, fits.CompImageHDU) for hdu in hdul)
        if not is_compressed:
            primary_hdu = hdul.pop(0)
            primary_hdu.verify("silentfix+ignore")
            # Read data/header while the file is still open
            data = primary_hdu.data
            header = primary_hdu.header

    if is_compressed:
        data, header = open_compressed_fits(path)

    if BASE_NAME_KEY not in header:
        header[BASE_NAME_KEY] = path.name

    if RAW_IMG_KEY not in header.keys():
        header[RAW_IMG_KEY] = path.as_posix()

    return data, header
[docs]
def save_fits(
    image: Image,
    path: str | Path,
    compress: bool = False,
):
    """
    Write an Image to <path>, after checking it has all core fields

    The output path is recorded in the image header (when there is one)
    before writing.

    :param image: Image to save
    :param path: output path
    :param compress: boolean on whether to compress the image
    :return: None
    """
    out_path = Path(path) if isinstance(path, str) else path

    check_image_has_core_fields(image)

    img_data = image.get_data()
    img_header = image.get_header()
    if img_header is not None:
        img_header[LATEST_SAVE_KEY] = out_path.as_posix()

    logger.debug(f"Saving to {out_path.as_posix()}")
    save_to_path(img_data, img_header, out_path, compress=compress)
[docs]
def open_raw_image(
    path: str | Path,
    open_f: Callable[[str | Path], tuple[np.ndarray, fits.Header]] = open_fits,
) -> Image:
    """
    Load a raw image from <path> and wrap it in an Image object

    The data returned by <open_f> is cast to float64, and the resulting
    Image is checked for the core fields before being returned.

    :param path: path of raw image
    :param open_f: function used to read the data and header from <path>
    :return: Image object
    """
    raw_path = Path(path) if isinstance(path, str) else path
    raw_data, raw_header = open_f(raw_path)

    image = Image(raw_data.astype(np.float64), raw_header)
    check_image_has_core_fields(image)
    return image
[docs]
def open_mef_fits(
    path: str | Path,
) -> tuple[fits.Header, list[np.ndarray], list[fits.Header]]:
    """
    Open a multi-extension fits (MEF) file saved to <path>

    Only the header of the first HDU is returned; data and headers are
    collected from every subsequent HDU.

    :param path: path of fits file
    :return: tuple containing the primary header, the list of extension
        data and the list of extension headers
    """
    ext_data_list, ext_header_list = [], []

    with fits.open(path, memmap=False) as hdul:
        primary_header = hdul[0].header  # pylint: disable=no-member

        for idx in range(1, len(hdul)):
            ext_hdu = hdul[idx]
            # Data that cannot be cast to float64 is kept in its original form.
            # NOTE(review): an extension without data (data is None) would raise
            # AttributeError here, which is not caught - confirm this is intended
            try:
                ext_data = ext_hdu.data.astype(np.float64)
            except TypeError:
                ext_data = ext_hdu.data
            ext_data_list.append(ext_data)
            ext_header_list.append(ext_hdu.header)

    return primary_header, ext_data_list, ext_header_list
[docs]
def open_mef_image(
    path: str | Path,
    open_f: Callable[
        [str | Path], tuple[fits.Header, list[np.ndarray], list[fits.Header]]
    ] = open_mef_fits,
    extension_key: str | None = None,
) -> list[Image]:
    """
    Open a multi-extension fits (MEF) file as a list of Image objects,
    one per extension

    :param path: path of raw image
    :param open_f: function used to read the primary header, extension data
        and extension headers from <path>
    :param extension_key: key to use to number the MEF frames
    :raises ExtensionParsingError: if two of the resulting images share a name
    :return: list of Image objects
    """
    primary_header, ext_data_list, ext_header_list = open_f(path)

    ext_header_list = tag_mef_extension_file_headers(
        primary_header=primary_header,
        extension_headers=ext_header_list,
        extension_key=extension_key,
    )

    float_data_list = [arr.astype(np.float64) for arr in ext_data_list]

    images = []
    for idx, arr in enumerate(float_data_list):
        # Each Image gets its own copies of the data and header
        ext_image = Image(
            data=copy.deepcopy(arr), header=copy.deepcopy(ext_header_list[idx])
        )
        check_image_has_core_fields(ext_image)
        images.append(ext_image)

    names = [ext_image.get_name() for ext_image in images]
    if len(set(names)) != len(names):
        raise ExtensionParsingError(f"Found duplicate image names in {names}")

    return images
[docs]
def check_file_is_complete(path: str) -> bool:
    """
    Check whether a fits file is as large as expected.

    Useful to verify files arriving via e.g. rsync, where they can be
    partially transferred. The file is opened with astropy, and the current
    position of the underlying file object of the last HDU is compared
    against that file object's size. Astropy user warnings are silenced
    while doing so, and a file which cannot be opened (OSError) counts as
    incomplete.

    :param path: path of file to check
    :return: boolean file complete
    """
    is_complete = False

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=AstropyUserWarning)
        try:
            with fits.open(path) as hdul:
                file_obj = hdul[-1].fileinfo()["file"]
                is_complete = file_obj.tell() == file_obj.size
        except OSError:
            # Best effort: an unreadable file is reported as incomplete
            pass

    return is_complete
[docs]
def check_image_has_core_fields(img: Image):
    """
    Ensure that an image has all the core fields

    The core fields are checked in order, and the first one found missing
    is reported.

    :param img: Image object to check
    :raises MissingCoreFieldError: if a core field is absent from the image
    :return: None
    """
    for key in core_fields:
        if key in img.keys():
            continue

        # Identify the image by name in the message whenever a name is available
        name_tag = f"({img[BASE_NAME_KEY]}) " if BASE_NAME_KEY in img.keys() else ""

        # Always report the field that actually failed the check
        err = (
            f"New image {name_tag}is missing the core field {key}. "
            f"Available fields are {list(img.keys())}."
        )
        logger.error(err)
        raise MissingCoreFieldError(err)