# Source code for pygsdata.register

"""Register functions as processors for GSData objects."""

import contextlib
import functools
import inspect
from collections.abc import Callable, Sequence
from typing import Literal, get_args, get_origin

import attrs

from .gsdata import GSData
from .gsflag import GSFlag
from .history import Stamp

# Registry of every processor produced by ``_register``, keyed by the wrapped
# function's ``__name__``. Registering a second function under the same name
# replaces the earlier entry.
GSDATA_PROCESSORS: dict[str, Callable] = {}

# The processor categories accepted at registration time. ``_register`` treats
# "supplement" and "filter" specially: results of every other kind are updated
# with ``file_appendable=False``.
RegKind = Literal["gather", "calibrate", "filter", "reduce", "supplement"]


def _register(func: callable, kind: RegKind) -> callable:
    sig = inspect.signature(func)
    first_param = next(iter(sig.parameters.keys()))

    annotation = sig.parameters[first_param].annotation
    # Handle string annotations (forward references)
    if isinstance(annotation, str):
        with contextlib.suppress(Exception):
            annotation = eval(annotation, func.__globals__)
    allowed = False
    if annotation is GSData:
        allowed = True
    elif get_origin(annotation) in (list, Sequence):
        args = get_args(annotation)
        if args and args[0] is GSData:
            allowed = True
    if not allowed:
        raise TypeError(
            f"{func.__name__} must accept a GSData object or "
            "Sequence[GSData] as the first argument"
        )

    @functools.wraps(func)
    def wrapper(data: GSData, *args, message: str = "", **kw) -> GSData | list[GSData]:
        newdata = func(data, *args, **kw)

        history = Stamp(
            message=message,
            function=func.__name__,
            parameters=kw,
        )

        kw = {"history": history}

        if kind not in ("supplement", "filter"):
            # Any function that is not a supplement or filter is CHANGING data,
            # and should no longer be associated with the original file, in the sense
            # that new flags and data models should not be added to the file.
            kw["file_appendable"] = False

        if isinstance(newdata, GSData):
            return newdata.update(**kw)
        try:
            return [nd.update(**kw) for nd in newdata]
        except Exception as e:
            raise TypeError(
                f"{func.__name__} returned {type(newdata)} "
                f"instead of GSData or list thereof."
            ) from e

    GSDATA_PROCESSORS[func.__name__] = wrapper
    return wrapper


@attrs.define()
class gsregister:  # noqa: N801
    """Decorator to register a function as a processor for GSData objects.

    Instantiate with the processor ``kind`` and apply the instance to a
    function, e.g. ``@gsregister("filter")``. Applying it delegates to
    ``_register``, which validates the function's first parameter and returns
    the wrapped function.
    """

    # Category of processor being registered; validated on construction against
    # the same five names listed in ``RegKind``.
    kind: RegKind = attrs.field(
        validator=attrs.validators.in_(
            ["gather", "calibrate", "filter", "reduce", "supplement"]
        )
    )

    def __call__(self, func: Callable) -> Callable:
        """Register a function as a processor for GSData objects.

        Parameters
        ----------
        func
            The function to register.

        Returns
        -------
        Callable
            The wrapper returned by ``_register(func, self.kind)``.
        """
        return _register(func, self.kind)
# Some simple registered functions
@gsregister("supplement")
def add_flags(data: GSData, filt: str, flags: GSFlag) -> GSData:
    """Add flags to a GSData object.

    Registered as a ``"supplement"`` processor; the call is forwarded
    unchanged to ``data.add_flags(filt, flags)``.

    Parameters
    ----------
    data
        The object to add flags to.
    filt
        Passed as the first argument to ``data.add_flags`` (presumably the
        name under which the flags are stored -- confirm against
        ``GSData.add_flags``).
    flags
        The flags to add.

    Returns
    -------
    GSData
        The result of ``data.add_flags(filt, flags)``.
    """
    return data.add_flags(filt, flags)