"""An object to hold flag information."""
from collections.abc import Sequence
from functools import cached_property
from pathlib import Path
from typing import Protocol
import hickle
import numpy as np
from attrs import converters as cnv
from attrs import define, evolve, field
from attrs import validators as vld
from hickleable import hickleable
from .attrs import npfield
from .history import History, Stamp
try:
    from typing import Self
except ImportError:  # typing.Self only exists on Python >= 3.11
    # Stdlib stand-in so the module still imports on older interpreters; it is only
    # used in annotations, where a TypeVar bound to the class is adequate.
    from typing import TypeVar

    Self = TypeVar("Self", bound="GSFlag")
class _GSDataSized(Protocol):
nloads: int | None
ntimes: int | None
npols: int | None
nfreqs: int | None
@hickleable()
@define(slots=False)
class GSFlag:
    """A generic container for Global-Signal flags data.

    Parameters
    ----------
    flags
        The flags as a boolean array. The array may have up to 4 dimensions -- load,
        pol, time, and freq -- but need not have all of the dimensions.
    axes
        A tuple of strings specifying the axes of the data array. The possible axes
        are "load", "pol", "time", and "freq". They must be unique and in that
        order, but not all must be present: there must be exactly as many as
        ``flags`` has dimensions. If omitted, ``flags`` must be 4-D and all four
        axes are assumed.
    history
        A :class:`History` recording previous processing steps. It is ignored when
        comparing two objects for equality.
    filename
        The filename from which the data was read (if any). Used for writing
        additional data if more is added (eg. flags, data model).
    """

    # Canonical axis names, in the only order in which they may appear in ``axes``.
    _axes = ("load", "pol", "time", "freq")

    flags: np.ndarray = npfield(dtype=bool)
    axes: tuple[str, ...] = field(converter=tuple)
    history: History = field(
        factory=History, validator=vld.instance_of(History), eq=False
    )
    filename: Path | None = field(default=None, converter=cnv.optional(Path))

    @flags.validator
    def _flags_vld(self, _, value):
        if value.ndim > 4:
            raise ValueError("Flag array must have at most 4 dimensions")

    @axes.validator
    def _axes_vld(self, _, value):
        # attrs runs validators after all fields are assigned, so self.flags is set.
        if len(set(value)) != len(value):
            raise ValueError(f"Axes must be unique, got {value}")
        if len(value) != self.flags.ndim:
            raise ValueError(
                f"Number of axes must match number of dimensions in flags. "
                f"Got {len(value)} axes and {self.flags.ndim} dimensions"
            )
        if any(ax not in self._axes for ax in value):
            raise ValueError("Axes must be a subset of load, pol, time, freq")
        # Positions of the canonical axes inside `value`; they must be increasing.
        idx = [value.index(ax) for ax in self._axes if ax in value]
        if idx != sorted(idx):
            raise ValueError(f"Axes must be in order {self._axes}")

    @axes.default
    def _axes_default(self):
        if self.flags.ndim == 4:
            return self._axes
        raise ValueError("Axes must be specified if flag array has fewer than 4 dims")

    def _axis_length(self, axis: str) -> int | None:
        """Return the length of ``axis`` in ``flags``, or None if it is absent."""
        if axis not in self.axes:
            return None
        return self.flags.shape[self.axes.index(axis)]

    # Plain (uncached) properties: ``flags`` and ``axes`` can be reassigned on this
    # mutable class, so a cached size could go stale.
    @property
    def nfreqs(self) -> int | None:
        """The number of frequency channels (None if there is no freq axis)."""
        return self._axis_length("freq")

    @property
    def nloads(self) -> int | None:
        """The number of loads (None if there is no load axis)."""
        return self._axis_length("load")

    @property
    def ntimes(self) -> int | None:
        """The number of times (None if there is no time axis)."""
        return self._axis_length("time")

    @property
    def npols(self) -> int | None:
        """The number of polarizations (None if there is no pol axis)."""
        return self._axis_length("pol")

    @classmethod
    def from_file(cls, filename: str | Path, filetype: str | None = None, **kw) -> Self:
        """Create a GSFlag instance from a file.

        Parameters
        ----------
        filename
            Path to the file.
        filetype
            The file type. If not given, it is taken from the file extension. The
            comparison is case-insensitive; only "gsflag" is currently supported.
        **kw
            Currently ignored.

        Raises
        ------
        ValueError
            If the file type is not recognized.
        """
        filename = Path(filename)
        if filetype is None:
            filetype = filename.suffix[1:]  # Remove the leading dot
        if filetype.lower() == "gsflag":
            return cls.read_gsflag(filename)
        raise ValueError(f"Unrecognized file type: {filetype}")

    @classmethod
    def read_gsflag(cls, filename: str | Path) -> Self:
        """Read a GSFlag file (as written by :meth:`write_gsflag`).

        The returned object records ``filename`` as the file it was read from and
        has a "Read GSFlag file" stamp appended to its history.
        """
        obj = hickle.load(filename)
        return obj.update(
            filename=filename,
            history=Stamp("Read GSFlag file", parameters={"filename": filename}),
        )

    def write_gsflag(self, filename: str | Path) -> Self:
        """Write this GSFlag object to a GSFlag file using hickle.

        The written object carries a "Wrote GSFlag file" history stamp. Returns a
        new object with that stamp and with ``filename`` set to the written file.
        """
        new = self.update(
            history=Stamp("Wrote GSFlag file", parameters={"filename": filename})
        )
        hickle.dump(new, filename, mode="w")
        return new.update(filename=filename)

    def update(self, **kwargs) -> Self:
        """Return a new GSFlag object with updated attributes.

        Any constructor keyword may be given and replaces the current value. The
        exception is ``history``: if given, it is appended to the current history
        via ``History.add`` rather than replacing it.
        """
        history = kwargs.pop("history", None)
        history = self.history if history is None else self.history.add(history)
        return evolve(self, history=history, **kwargs)

    @property
    def full_rank_flags(self) -> np.ndarray:
        """Return a copy of the flags with all four axes present.

        Missing axes are inserted with length 1 at their canonical position, so the
        result is ordered (load, pol, time, freq) and broadcasts against other
        objects' full-rank flags.
        """
        flg = self.flags.copy()
        # Walk canonical order so each inserted index is already its final position.
        for i, ax in enumerate(self._axes):
            if ax not in self.axes:
                flg = np.expand_dims(flg, axis=i)
        return flg

    def _check_compat(self, other: _GSDataSized) -> None:
        """Raise ValueError if any axis present in both objects differs in size."""
        for name in ("nloads", "npols", "ntimes", "nfreqs"):
            this, that = getattr(self, name), getattr(other, name)
            if this is not None and that is not None and this != that:
                raise ValueError(
                    f"Objects have different {name}. Got this={this} and that={that}"
                )

    def _combine(self, other: Self, op, opname: str) -> Self:
        """Elementwise-combine flags with ``other`` via ``op`` (shared by | and &).

        The result has every axis present in either operand.
        """
        if not isinstance(other, GSFlag):
            raise TypeError(f"can only '{opname}' GSFlag objects")
        self._check_compat(other)
        axes = tuple(ax for ax in self._axes if ax in self.axes + other.axes)
        combined = op(self.full_rank_flags, other.full_rank_flags)
        # Drop only the axes that neither operand has. A present axis of length 1
        # must survive so that the flags still match `axes`.
        drop = tuple(i for i, ax in enumerate(self._axes) if ax not in axes)
        new_flags = np.squeeze(combined, axis=drop) if drop else combined
        history = self.history.add(other.history).add(
            Stamp("Multiplied GSFlag objects")
        )
        # Call evolve directly: `update` would add `history` onto self.history again.
        return evolve(self, flags=new_flags, axes=axes, history=history, filename=None)

    def __or__(self, other: Self) -> Self:
        """Take the logical OR (union) of two GSFlag objects and return a new one."""
        return self._combine(other, np.logical_or, "or")

    def __and__(self, other: Self) -> Self:
        """Take the logical AND (product) of two GSFlag objects and return a new one."""
        return self._combine(other, np.logical_and, "and")

    def select(self, idx: np.ndarray | slice, axis: str, squeeze: bool = False) -> Self:
        """Select a subset of the data along the given axis.

        Parameters
        ----------
        idx
            A slice, an integer index array (any array-like; a scalar is treated as
            a length-1 array), or a boolean mask along ``axis``.
        axis
            Name of the axis to select along.
        squeeze
            If True and exactly one element is selected, drop ``axis`` entirely.

        Returns
        -------
        GSFlag
            A new object, or ``self`` unchanged if ``axis`` is not present.

        Raises
        ------
        ValueError
            If ``axis`` is not one of the recognized axis names.
        """
        if axis not in self._axes:
            raise ValueError(f"Axis {axis} not recognized")
        # Nothing to do if the axis is absent. This must be checked before resolving
        # a slice, which needs the axis length.
        if axis not in self.axes:
            return self

        iax = self.axes.index(axis)
        if isinstance(idx, slice):
            idx = np.arange(*idx.indices(self.flags.shape[iax]))
        else:
            idx = np.atleast_1d(idx)
            if idx.dtype == bool:
                idx = np.where(idx)[0]

        new_flags = np.take(self.flags, idx, axis=iax)  # np.take returns a new array
        axes = self.axes
        if squeeze and new_flags.shape[iax] == 1:
            new_flags = np.squeeze(new_flags, axis=iax)
            axes = tuple(ax for ax in self.axes if ax != axis)

        return self.update(
            flags=new_flags,
            axes=axes,
            history=Stamp(
                "Selected subset of data", parameters={"axis": axis, "idx": idx}
            ),
            filename=None,
        )

    def op_on_axis(self, op, axis: str) -> Self:
        """Reduce the flags over ``axis`` with ``op(flags, axis=i)`` (e.g. ``np.any``).

        The given axis is removed from ``axes``, so ``op`` must reduce over it.
        Returns ``self`` unchanged if ``axis`` is not present.

        Raises
        ------
        ValueError
            If ``axis`` is not one of the recognized axis names.
        """
        if axis not in self._axes:
            raise ValueError(f"Axis {axis} not recognized")
        if axis not in self.axes:
            return self
        # Hand `op` a copy so a misbehaving in-place op cannot alter self.flags.
        new_flags = op(self.flags.copy(), axis=self.axes.index(axis))
        axes = tuple(ax for ax in self.axes if ax != axis)
        return self.update(
            flags=new_flags,
            axes=axes,
            history=Stamp(
                "Applied operation to data", parameters={"axis": axis, "op": op}
            ),
            filename=None,
        )

    def any(self, axis: str | None = None) -> bool | Self:
        """Return whether any flag is True, or reduce over ``axis`` with ``np.any``."""
        return self.flags.any() if axis is None else self.op_on_axis(np.any, axis)

    def all(self, axis: str | None = None) -> bool | Self:
        """Return whether all flags are True, or reduce over ``axis`` with ``np.all``."""
        return self.flags.all() if axis is None else self.op_on_axis(np.all, axis)

    def concat(self, others: Self | Sequence[Self], axis: str) -> Self:
        """Get a new GSFlag by concatenating other flags to this one along ``axis``.

        Raises
        ------
        TypeError
            If any of ``others`` is not a GSFlag.
        ValueError
            If ``axis`` is unrecognized or absent, or if any of ``others`` has
            different ``axes`` from this object.
        """
        if not hasattr(others, "__len__"):
            others = [others]
        if not all(isinstance(o, GSFlag) for o in others):
            raise TypeError("can only concatenate GSFlag objects")
        if axis not in self._axes:
            raise ValueError(f"Axis {axis} not recognized")
        if axis not in self.axes:
            raise ValueError(f"Axis {axis} not present in this GSFlag object")
        # Without this check, same-ndim objects with different axes would be
        # concatenated silently and incorrectly.
        for o in others:
            if o.axes != self.axes:
                raise ValueError(
                    "Cannot concatenate GSFlag objects with different axes: "
                    f"{self.axes} and {o.axes}"
                )

        new_flags = np.concatenate(
            [self.flags] + [o.flags for o in others], axis=self.axes.index(axis)
        )
        return self.update(
            flags=new_flags,
            history=Stamp("Concatenated GSFlag objects", parameters={"axis": axis}),
            filename=None,
        )