"""
A module containing the class GSData, a variant of UVData specific to single antennas.
The GSData object simplifies handling of radio astronomy data taken from a single
antenna, adding self-consistent metadata along with the data itself, and providing
key methods for data selection, I/O, and analysis.
"""
import logging
import warnings
from collections.abc import Iterable
from copy import deepcopy
from functools import cached_property
from pathlib import Path
from typing import Any, Literal, Self
import astropy.units as un
import h5py
import hickle
import numpy as np
from astropy.coordinates import Longitude
from astropy.table import QTable
from astropy.time import Time
from attrs import cmp_using, define, evolve, field
from attrs import converters as cnv
from attrs import validators as vld
from . import coordinates as crd
from .attrs import cmp_qtable, lstfield, npfield, timefield
from .gsflag import GSFlag
from .history import History, Stamp
from .telescope import Telescope, _pol_converter
from .utils import time_concat
logger = logging.getLogger(__name__)
[docs]
@define(slots=False)
class GSData:
    """A generic container for Global-Signal data.

    Parameters
    ----------
    telescope
        The :class:`Telescope` that took the data. Its ``pols``,
        ``integration_time`` and ``location`` are used to build the defaults of
        ``pols``, ``effective_integration_time``/``time_ranges`` and
        ``lsts``/``lst_ranges`` respectively.
    data
        The data array (i.e. what the telescope measures). This must be a 4D array
        whose dimensions are (load, polarization, time, frequency). The data can be
        raw powers, calibrated temperatures, or even model residuals to such. Their
        type is specified by the ``data_unit`` attribute.
    freqs
        The frequency array. This must be a 1D array of frequencies (one per
        frequency channel of ``data``) specified as an astropy Quantity.
    times
        The time array. This must be a 2D array of shape (ntimes, nloads), given as
        an astropy Time object.
    pols
        The names of the polarizations. Defaults to ``telescope.pols``.
    effective_integration_time
        An astropy Quantity that specifies the amount of time going into a single
        "sample" of the data. This can either be a scalar, or a 3D array with shape
        (nloads, npols, ntimes). If it is a scalar, it is assumed to be the same for
        all data points. The default value is the integration time of the telescope.
        Note that this value is *only* meant to be used to track the expected noise
        level in the data, in conjunction with nsamples. It is not checked for
        whether the time_ranges match the integration time (since the effective time
        can be smaller than the time range due to windowing, or because the
        time_range includes multiple observations).
    nsamples
        An array with the same shape as the data array, specifying the number of
        samples that go into each data point. This is unitless, and can be used with
        the ``effective_integration_time`` attribute to compute the total effective
        integration time going into any measurement. Defaults to an array of ones.
    loads
        The names of the loads. Usually there is a single load ("ant"), but arbitrary
        loads may be specified. A default is only available when the data has one or
        three loads.
    flags
        A dictionary mapping filter names (strings) to :class:`GSFlag` objects, each
        of which must be compatible with this object.
    history
        A :class:`History` object recording previous processing steps.
    residuals
        An optional array of the same shape as data that holds the residuals of a
        model fit to the data.
    data_unit
        The kind of quantity held in ``data``. One of "power", "temperature",
        "uncalibrated" or "uncalibrated_temp".
    auxiliary_measurements
        An optional table (converted to an astropy QTable) of auxiliary
        measurements. Its length must equal the number of times.
    time_ranges
        The start and end of each integration, with shape (ntimes, nloads, 2). Every
        end must be later than its start. Defaults to ``times`` and
        ``times + telescope.integration_time``.
    lsts
        The LSTs of the observations, with shape (ntimes, nloads). Defaults to the
        apparent sidereal time of ``times`` at the telescope location.
    lst_ranges
        The LST equivalent of ``time_ranges``, with shape (ntimes, nloads, 2).
        Defaults to the apparent sidereal time of ``time_ranges``.
    filename
        The filename from which the data was read (if any). Used for writing
        additional data if more is added (eg. flags, data model).
    file_appendable
        A boolean flag stored as ``_file_appendable`` (default True).
    name
        An optional name for the object (default is the empty string).
    """

    telescope: Telescope = field(validator=vld.instance_of(Telescope))
    data: np.ndarray = npfield(dtype=float, possible_ndims=(4,))
    freqs: un.Quantity[un.MHz] = npfield(possible_ndims=(1,), unit=un.MHz)
    times: Time = timefield(possible_ndims=(2,))
    pols: tuple[str, ...] = field(converter=_pol_converter)
    _effective_integration_time: un.Quantity[un.s] = npfield(
        possible_ndims=(0, 3), unit=un.s
    )
    nsamples: np.ndarray = npfield(dtype=float, possible_ndims=(4,))
    loads: tuple[str, ...] = field(converter=tuple)
    flags: dict[str, GSFlag] = field(factory=dict)
    # History is excluded from equality comparisons (eq=False).
    history: History = field(
        factory=History, validator=vld.instance_of(History), eq=False
    )
    residuals: np.ndarray | None = npfield(
        default=None, possible_ndims=(4,), dtype=float
    )
    data_unit: Literal["power", "temperature", "uncalibrated", "uncalibrated_temp"] = (
        field(default="power")
    )
    auxiliary_measurements: QTable | None = field(
        default=None, converter=cnv.optional(QTable), eq=cmp_using(cmp_qtable)
    )
    time_ranges: Time = timefield(shape=(None, None, 2))
    lsts: Longitude = lstfield(possible_ndims=(2,))
    lst_ranges: Longitude = lstfield(possible_ndims=(3,))
    # The filename is excluded from equality comparisons (eq=False).
    filename: Path | None = field(default=None, converter=cnv.optional(Path), eq=False)
    _file_appendable: bool = field(default=True, converter=bool)
    name: str = field(default="", converter=str)
@nsamples.validator
def _nsamples_validator(self, attribute, value):
    """Ensure that ``nsamples`` has exactly the shape of ``data``."""
    expected_shape = self.data.shape
    if value.shape == expected_shape:
        return
    raise ValueError("nsamples must have the same shape as data")
@nsamples.default
def _nsamples_default(self) -> np.ndarray:
    """Default ``nsamples``: an array of ones shaped like ``data``."""
    template = self.data
    return np.ones_like(template)
@flags.validator
def _flags_validator(self, attribute, value):
    """Check that ``flags`` is a dict of string keys to compatible GSFlag objects.

    Raises
    ------
    TypeError
        If ``value`` is not a dict, or any of its values is not a GSFlag.
    ValueError
        If any key is not a string.
    """
    if not isinstance(value, dict):
        raise TypeError("flags must be a dict")

    for name in value:
        entry = value[name]
        if not isinstance(entry, GSFlag):
            raise TypeError("flags values must be GSFlag instances")
        # Let the flag object itself verify that it fits this data.
        entry._check_compat(self)
        if not isinstance(name, str):
            raise ValueError("flags keys must be strings")
@residuals.validator
def _residuals_validator(self, attribute, value):
    """Ensure that ``residuals``, when given, has the shape of ``data``."""
    if value is None:
        return
    if value.shape != self.data.shape:
        raise ValueError("residuals must have the same shape as data")
@freqs.validator
def _freqs_validator(self, attribute, value):
    """Ensure that ``freqs`` is 1D with one entry per frequency channel of data."""
    expected_shape = (self.nfreqs,)
    if value.shape == expected_shape:
        return
    raise ValueError(
        "freqs must have the size nfreqs. "
        f"Got {value.shape} instead of {self.nfreqs}"
    )
@times.validator
def _times_validator(self, attribute, value):
    """Ensure that ``times`` has shape (ntimes, nloads)."""
    expected_shape = (self.ntimes, self.nloads)
    if value.shape == expected_shape:
        return
    raise ValueError(
        f"times must have the size (ntimes, nloads), got {value.shape} "
        f"instead of {(self.ntimes, self.nloads)}"
    )
@pols.default
def _pols_default(self) -> tuple[str]:
    """Default polarizations: those declared by the telescope."""
    telescope = self.telescope
    return telescope.pols
@time_ranges.default
def _time_ranges_default(self):
    """Default time ranges: from each time to that time plus one integration.

    The start and end arrays each get a trailing length-1 axis and are joined
    along it, giving a trailing axis of length 2.
    """
    starts = self.times[:, :, None]
    ends = self.times[:, :, None] + self.telescope.integration_time
    return time_concat((starts, ends), axis=-1)
@time_ranges.validator
def _time_ranges_validator(self, attribute, value):
    """Ensure ``time_ranges`` has shape (ntimes, nloads, 2) and positive durations."""
    expected_shape = (self.ntimes, self.nloads, 2)
    if value.shape != expected_shape:
        raise ValueError(
            f"time_ranges must have the size (ntimes, nloads, 2), got {value.shape}"
            f" instead of {(self.ntimes, self.nloads, 2)}."
        )

    durations = (value[..., 1] - value[..., 0]).value
    if np.all(durations > 0):
        return
    # TODO: properly check lst-type input, which can wrap...
    raise ValueError("time_ranges must all be greater than zero")
@loads.default
def _loads_default(self) -> tuple[str]:
    """Default load names, available only for data with one or three loads.

    Raises
    ------
    ValueError
        If the data has any other number of loads.
    """
    known_defaults = {
        1: ("ant",),
        3: ("ant", "internal_load", "internal_load_plus_noise_source"),
    }
    names = known_defaults.get(self.nloads)
    if names is None:
        raise ValueError(
            "If data has more than one source, loads must be specified"
        )
    return names
@loads.validator
def _loads_validator(self, attribute, value):
    """Ensure ``loads`` is one string per entry on the first axis of ``data``."""
    nloads_in_data = self.data.shape[0]
    if len(value) != nloads_in_data:
        raise ValueError(
            "loads must have the same length as the number of loads in data. Got "
            f"{len(value)} and {self.data.shape[0]}"
        )

    if any(not isinstance(entry, str) for entry in value):
        raise ValueError("loads must be a tuple of strings")
@auxiliary_measurements.validator
def _aux_meas_vld(self, attribute, value):
    """Ensure the auxiliary measurements, when given, have length ``ntimes``."""
    if value is None or len(value) == self.ntimes:
        return
    raise ValueError(
        "auxiliary_measurements must be length ntimes."
        f" Got {len(value)} instead of {self.ntimes}."
    )
@_effective_integration_time.default
def _eff_int_time_default(self) -> un.Quantity[un.s]:
    """Default: the telescope integration time for every (load, pol, time)."""
    shape = (self.nloads, self.npols, self.ntimes)
    return self.telescope.integration_time * np.ones(shape)
@_effective_integration_time.validator
def _eff_int_time_vld(self, attribute, value):
    """Ensure the effective integration time is positive and correctly shaped.

    The value must either hold a single element or have shape
    (nloads, npols, ntimes).
    """
    if np.any(value.value <= 0):
        raise ValueError("effective_integration_time must be greater than zero")

    single_valued = value.size == 1
    if not (single_valued or value.shape == (self.nloads, self.npols, self.ntimes)):
        raise ValueError(
            "effective_integration_time must be a scalar or have shape "
            f"(nloads, npols, ntimes), got {value.shape}"
        )
@cached_property
def effective_integration_time(self) -> un.Quantity[un.s]:
    """The effective integration time, expanded over all but the last data axis.

    A single-element value is broadcast to ``data.shape[:-1]``; anything else is
    returned as stored.
    """
    stored = self._effective_integration_time
    if stored.size != 1:
        return stored
    return stored * np.ones(self.data.shape[:-1])
@data_unit.validator
def _data_unit_validator(self, attribute, value):
    """Ensure ``data_unit`` is one of the recognised unit names."""
    allowed = ("power", "temperature", "uncalibrated", "uncalibrated_temp")
    if value in allowed:
        return
    raise ValueError(
        'data_unit must be one of "power", "temperature", "uncalibrated",'
        '"uncalibrated_temp"'
    )
@property
def nfreqs(self) -> int:
    """The number of frequency channels (length of the last data axis)."""
    data_shape = self.data.shape
    return data_shape[-1]
@property
def nloads(self) -> int:
    """The number of loads (length of the first data axis)."""
    data_shape = self.data.shape
    return data_shape[0]
@property
def ntimes(self) -> int:
    """The number of times (length of the second-to-last data axis)."""
    data_shape = self.data.shape
    return data_shape[-2]
@property
def npols(self) -> int:
    """The number of polarizations (length of the second data axis)."""
    data_shape = self.data.shape
    return data_shape[1]
@property
def model(self) -> np.ndarray | None:
"""The model of the data."""
if self.residuals is None:
return None
return self.data - self.residuals
@lsts.default
def _lsts_default(self) -> Longitude:
    """Default LSTs: apparent sidereal time of ``times`` at the telescope."""
    site = self.telescope.location
    return self.times.sidereal_time("apparent", site)
@lsts.validator
def _lsts_validator(self, attribute, value):
    """Ensure that ``lsts`` has shape (ntimes, nloads)."""
    expected_shape = (self.ntimes, self.nloads)
    if value.shape == expected_shape:
        return
    raise ValueError(
        f"lsts must have the size (ntimes, nloads), got {value.shape} "
        f"instead of {(self.ntimes, self.nloads)}"
    )
@lst_ranges.default
def _lst_ranges_default(self) -> Longitude:
    """Default LST ranges: apparent sidereal time of ``time_ranges``."""
    site = self.telescope.location
    return self.time_ranges.sidereal_time("apparent", site)
@lst_ranges.validator
def _lst_ranges_validator(self, attribute, value):
    """Ensure that ``lst_ranges`` has shape (ntimes, nloads, 2)."""
    expected_shape = (self.ntimes, self.nloads, 2)
    if value.shape == expected_shape:
        return
    raise ValueError(
        f"lst_ranges must have the size (ntimes, nloads, 2), got {value.shape} "
        f"instead of {(self.ntimes, self.nloads, 2)}"
    )
[docs]
@classmethod
def from_file(
    cls,
    filename: str | Path,
    reader: str | None = None,
    selectors: dict[str, Any] | None = None,
    concat_axis: Literal["load", "pol", "time", "freq"] | None = None,
    **kw,
) -> Self:
    """Create a GSData instance from one or more files.

    Parameters
    ----------
    filename
        A single path, or an iterable of paths. Each path is read separately.
    reader
        The file type to look up among the registered readers' ``suffices``.
        If None, the file suffix (without the leading dot) is used.
    selectors
        Optional mapping of selector name to keyword arguments. Readers that
        select on read receive the mapping as-is. Otherwise the recognised keys
        ``freq_selector``, ``time_selector``, ``lst_selector`` and
        ``load_selector`` are applied, in that order, after reading.
    concat_axis
        Passed to ``concat`` when more than one file is read.
    **kw
        Extra keyword arguments forwarded to the reader.

    Raises
    ------
    ValueError
        If no reader matches the file type, or if unknown selectors are given to
        a reader that does not select on read.
    """
    from .readers import GSDATA_READERS

    selectors = selectors or {}

    def _read_single(path, fmt):
        path = Path(path)
        if fmt is None:
            fmt = path.suffix[1:]

        # Take the first registered reader that claims this file type.
        read_func = None
        for candidate in GSDATA_READERS.values():
            if fmt in candidate.suffices:
                read_func = candidate
                break
        if read_func is None:
            raise ValueError(f"Unrecognized file type {fmt}")

        if read_func.select_on_read:
            return read_func(path, selectors=selectors, **kw)

        from .select import select_freqs, select_loads, select_lsts, select_times

        loaded = read_func(path, **kw)
        # Work on a copy so the caller's selectors survive for the next file.
        pending = deepcopy(selectors)
        ordered_selectors = (
            ("freq_selector", select_freqs),
            ("time_selector", select_times),
            ("lst_selector", select_lsts),
            ("load_selector", select_loads),
        )
        for key, apply_selection in ordered_selectors:
            if key in pending:
                loaded = apply_selection(loaded, **pending.pop(key))

        if pending:
            raise ValueError(
                f"Unrecognized selectors: {pending.keys()}. Available "
                "selectors: freq_selector, time_selector, lst_selector, "
                "load_selector"
            )
        return loaded

    paths = [filename] if isinstance(filename, str | Path) else filename
    datas = [_read_single(path, reader) for path in paths]
    if len(datas) == 1:
        return datas[0]

    from .concat import concat

    return concat(datas, concat_axis)
[docs]
def write_gsh5(self, filename: str | Path, group: str = "/") -> Self:
    """Write the data in the GSData object to a GSH5 file.

    Parameters
    ----------
    filename
        The file to write. An existing file is opened in append mode; otherwise
        a new file is created.
    group
        The HDF5 group to write into. It is created if it is not already in the
        file.

    Returns
    -------
    GSData
        A new object, from ``self.update(filename=filename)``.

    Raises
    ------
    ValueError
        If the file exists and already contains ``group``.
    """
    filename = Path(filename)

    mode = "w"
    if filename.exists():
        with h5py.File(filename, "r") as existing:
            if group in existing:
                raise ValueError(
                    f"group {group} in file {filename} already exists!"
                )
        mode = "a"

    with h5py.File(filename, mode) as h5file:
        root = h5file if group in h5file else h5file.create_group(group)

        # The GSH5 file version: <major>.<minor>. The minor version is incremented
        # when the file format changes in a backwards-compatible way. The major
        # version is incremented when the file format changes in a way
        # that requires a new reader.
        root.attrs["version"] = "2.1"

        metadata = root.create_group("metadata")
        self.telescope.write(metadata.create_group("telescope"))
        metadata["freqs"] = self.freqs.to_value("MHz")
        metadata["freqs"].attrs["unit"] = "MHz"
        metadata["effective_integration_time"] = (
            self._effective_integration_time.to_value("s")
        )
        metadata["times"] = self.times.jd
        metadata["time_ranges"] = self.time_ranges.jd
        metadata["lsts"] = self.lsts.hour
        metadata["lst_ranges"] = self.lst_ranges.hour
        metadata.attrs["data_unit"] = self.data_unit
        metadata["loads"] = self.loads
        metadata.attrs["history"] = repr(self.history)
        metadata.attrs["name"] = self.name

        data_group = root.create_group("data")
        data_group["data"] = self.data
        data_group["nsamples"] = self.nsamples

        # The flags group is always created; it stays empty without flags.
        flag_group = data_group.create_group("flags")
        if self.flags:
            flag_group.attrs["names"] = tuple(self.flags.keys())
            for flag_name, flag in self.flags.items():
                hickle.dump(flag, flag_group.create_group(flag_name))

        # Data model
        if self.residuals is not None:
            data_group["residuals"] = self.residuals

        # Now aux measurements
        aux_group = root.create_group("auxiliary_measurements")
        if self.auxiliary_measurements is not None:
            for column_name, column in self.auxiliary_measurements.items():
                aux_group[column_name] = column

    return self.update(filename=filename)
[docs]
def update(self, **kwargs):
    """Return a new GSData object with updated attributes.

    A ``history`` keyword is treated specially: a Stamp, or a dictionary of Stamp
    keyword arguments, is appended to the existing history rather than replacing
    it. Without it the existing history is carried over.

    Raises
    ------
    ValueError
        If ``history`` is given but is neither a Stamp nor a dictionary.
    """
    entry = kwargs.pop("history", None)

    if entry is None:
        new_history = self.history
    elif isinstance(entry, Stamp):
        new_history = self.history.add(entry)
    elif isinstance(entry, dict):
        new_history = self.history.add(Stamp(**entry))
    else:
        raise ValueError("History must be a Stamp object or dictionary")

    return evolve(self, history=new_history, **kwargs)
def __add__(self, other: Self) -> Self:
    """Combine two GSData objects taken at the same times and frequencies.

    The returned object holds the weighted mean of the two data arrays (and of
    the residuals, when both objects have them), weighted by each object's
    ``flagged_nsamples`` so that flagged samples contribute nothing. Its
    ``nsamples`` is the sum of the two flagged nsamples, and its only flag,
    ``"summed_flags"``, marks samples flagged in *both* inputs. Where neither
    object has any samples, data (and residuals) are set to zero.

    Auxiliary measurements are merged column-wise. If a column exists in both
    objects with differing values, the one from ``self`` is kept and a warning
    is emitted.

    Raises
    ------
    TypeError
        If ``other`` is not a GSData object.
    ValueError
        If the shapes, frequencies or times of the two objects differ.
    """
    if not isinstance(other, GSData):
        raise TypeError("can only add GSData objects")
    if self.data.shape != other.data.shape:
        raise ValueError("Cannot add GSData objects with different shapes")
    if not np.allclose(self.freqs, other.freqs):
        raise ValueError("Cannot add GSData objects with different frequencies")
    if not np.allclose(self.times.jd, other.times.jd, rtol=0, atol=1e-8):
        raise ValueError("Cannot add GSData objects with different times")

    aux_self = self.auxiliary_measurements
    aux_other = other.auxiliary_measurements
    if aux_self and not aux_other:
        aux = aux_self
    elif aux_other and not aux_self:
        aux = aux_other
    elif aux_self:
        # Only columns present in both tables can be compared.
        shared = [k for k in aux_self.columns if k in aux_other.columns]
        if not all(np.allclose(aux_other[k], aux_self[k]) for k in shared):
            warnings.warn(
                "Overlapping auxiliary measurements exist between objects,"
                " the ones in the first object will be retained.",
                stacklevel=2,
            )
        merged = dict(aux_other.items())
        merged.update(aux_self)
        aux = QTable(merged)
    else:
        aux = None

    # The two objects share times and frequencies, so combining them is a
    # weighted mean with the flagged nsamples as weights.
    w_self = self.flagged_nsamples
    w_other = other.flagged_nsamples
    nsamples = w_self + w_other
    has_samples = nsamples > 0

    def _weighted_mean(a, b):
        # Plain arrays are used rather than masked arrays: masked arithmetic
        # masks an element when either operand is masked, which would drop the
        # valid value of a sample flagged in only one object. Zero-weight
        # values are replaced by zero first so NaN/inf there cannot leak
        # through ``0 * nan``.
        total = w_self * np.where(w_self > 0, a, 0.0) + w_other * np.where(
            w_other > 0, b, 0.0
        )
        out = np.zeros(total.shape, dtype=float)
        np.divide(total, nsamples, out=out, where=has_samples)
        return out

    mean = _weighted_mean(self.data, other.data)

    if self.residuals is not None and other.residuals is not None:
        resids = _weighted_mean(self.residuals, other.residuals)
    else:
        resids = None

    total_flags = GSFlag(flags=self.complete_flags & other.complete_flags)
    return self.update(
        data=mean,
        residuals=resids,
        nsamples=nsamples,
        flags={"summed_flags": total_flags},
        auxiliary_measurements=aux,
    )
@cached_property
def gha(self) -> np.ndarray:
    """The GHAs of the observations, converted from ``lsts``."""
    lsts = self.lsts
    return crd.lst2gha(lsts)
[docs]
def get_moon_azel(self) -> tuple[np.ndarray, np.ndarray]:
    """Get the Moon's azimuth and elevation for each time in deg.

    The times of the load named "ant" are used.
    """
    ant_index = self.loads.index("ant")
    ant_times = self.times[:, ant_index]
    return crd.moon_azel(ant_times, self.telescope.location)
[docs]
def get_sun_azel(self) -> tuple[np.ndarray, np.ndarray]:
    """Get the Sun's azimuth and elevation for each time in deg.

    The times of the load named "ant" are used.
    """
    ant_index = self.loads.index("ant")
    ant_times = self.times[:, ant_index]
    return crd.sun_azel(ant_times, self.telescope.location)
@property
def nflagging_ops(self) -> int:
    """The number of flagging operations, i.e. entries in ``flags``."""
    flag_sets = self.flags
    return len(flag_sets)
[docs]
def get_cumulative_flags(
self, which_flags: tuple[str] | None = None, ignore_flags: tuple[str] = ()
) -> np.ndarray:
"""Return accumulated flags."""
if which_flags is None:
which_flags = self.flags.keys()
elif not which_flags or not self.flags:
return np.zeros(self.data.shape, dtype=bool)
which_flags = tuple(s for s in which_flags if s not in ignore_flags)
if not which_flags:
return np.zeros(self.data.shape, dtype=bool)
flg = self.flags[which_flags[0]].full_rank_flags
for k in which_flags[1:]:
flg = flg | self.flags[k].full_rank_flags
# Get into full data-shape
if flg.shape != self.data.shape:
flg = flg | np.zeros(self.data.shape, dtype=bool)
return flg
@cached_property
def complete_flags(self) -> np.ndarray:
    """The cumulative flags over all flagging operations (cached)."""
    all_flags = self.get_cumulative_flags()
    return all_flags
[docs]
def get_flagged_nsamples(
self,
which_flags: tuple[str, ...] | None = None,
ignore_flags: tuple[str, ...] = (),
) -> np.ndarray:
"""Get the nsamples of the data after accounting for flags."""
cumflags = self.get_cumulative_flags(which_flags, ignore_flags)
return self.nsamples * (~cumflags).astype(int)
@cached_property
def flagged_nsamples(self) -> np.ndarray:
    """Weights accounting for all flags (cached)."""
    weights = self.get_flagged_nsamples()
    return weights
[docs]
def get_initial_yearday(self, hours: bool = False, minutes: bool = False) -> str:
    """Return the year-day representation of the first time-sample in the data.

    The first time of the load named "ant" is used.

    Parameters
    ----------
    hours
        Whether to include the hour in the result.
    minutes
        Whether to also include the minutes. Requires ``hours``.

    Raises
    ------
    ValueError
        If ``minutes`` is requested without ``hours``.
    """
    if minutes and not hours:
        raise ValueError("Cannot return minutes without hours")

    first_time = self.times[0, self.loads.index("ant")]
    if not hours:
        return first_time.to_value("yday", "date")

    stamp = first_time.to_value("yday", "date_hm")
    if minutes:
        return stamp
    # Drop the trailing (minutes) component.
    return ":".join(stamp.split(":")[:-1])
[docs]
def add_flags(
    self,
    filt: str,
    flags: np.ndarray | GSFlag | Path,
    append_to_file: bool = False,
):
    """Return a new object with a set of flags added under the name ``filt``.

    Parameters
    ----------
    filt
        The name under which to store the flags. Must not already exist.
    flags
        The flags to add. A numpy array is wrapped in a GSFlag with axes
        ("load", "pol", "time", "freq"); a string or Path is read with
        ``GSFlag.from_file``; anything else is used as given.
    append_to_file
        NOTE(review): accepted but not used anywhere in this method, so nothing
        is written to file here.

    Raises
    ------
    ValueError
        If flags named ``filt`` already exist.
    """
    if isinstance(flags, np.ndarray):
        flags = GSFlag(flags=flags, axes=("load", "pol", "time", "freq"))
    elif isinstance(flags, str | Path):
        flags = GSFlag.from_file(flags)

    flags._check_compat(self)

    if filt in self.flags:
        raise ValueError(f"Flags for filter '{filt}' already exist")

    extended = dict(self.flags)
    extended[filt] = flags
    return self.update(flags=extended)
[docs]
def remove_flags(self, filt: str) -> Self:
"""Remove flags for a given filter."""
if filt not in self.flags:
raise ValueError(f"No flags for filter '{filt}'")
return self.update(flags={k: v for k, v in self.flags.items() if k != filt})
[docs]
def time_iter(self) -> Iterable[tuple[slice, slice, int, slice]]:
    """Return an iterator over the time axis of data-shape arrays.

    Yields
    ------
    tuple
        A 4-element index ``(:, :, i, :)`` for each time index ``i``, suitable
        for indexing any array with the shape of ``data``.
    """
    everything = slice(None)
    for idx in range(self.ntimes):
        yield (everything, everything, idx, everything)
[docs]
def load_iter(self) -> Iterable[tuple[int]]:
    """Return an iterator over the load axis of data-shape arrays.

    Each item is a 1-tuple holding a load index.
    """
    yield from ((idx,) for idx in range(self.nloads))
[docs]
def freq_iter(self) -> Iterable[tuple[slice, slice, slice, int]]:
    """Return an iterator over the frequency axis of data-shape arrays.

    Yields
    ------
    tuple
        A 4-element index ``(:, :, :, i)`` for each frequency index ``i``,
        suitable for indexing any array with the shape of ``data``.
    """
    everything = slice(None)
    for idx in range(self.nfreqs):
        yield (everything, everything, everything, idx)