Utilities

toast.ops.Delete

Bases: Operator

Class to purge data from observations.

This operator takes lists of shared, detdata, intervals and meta keys to delete from observations.
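
A minimal usage sketch (assuming a populated toast.Data instance named data; the key names are hypothetical, and the standard Operator.apply() entry point is assumed to run _exec followed by _finalize):

from toast import ops

# Purge hypothetical keys from every observation.  Keys that do not
# exist are silently ignored.
deleter = ops.Delete(
    detdata=["signal_copy"],
    shared=["boresight_backup"],
    meta=["old_noise_model"],
)
deleter.apply(data)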

Source code in toast/ops/delete.py
@trait_docs
class Delete(Operator):
    """Class to purge data from observations.

    This operator takes lists of shared, detdata, intervals and meta keys to delete from
    observations.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    meta = List([], help="List of Observation dictionary keys to delete")

    detdata = List([], help="List of Observation detdata keys to delete")

    shared = List([], help="List of Observation shared keys to delete")

    intervals = List(
        [],
        help="List of tuples of Observation intervals keys to delete",
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()
        for ob in data.obs:
            for key in self.detdata:
                # This ignores non-existent keys
                del ob.detdata[key]
            for key in self.shared:
                # This ignores non-existent keys
                del ob.shared[key]
            for key in self.intervals:
                # This ignores non-existent keys
                del ob.intervals[key]
            for key in self.meta:
                try:
                    del ob[key]
                except KeyError:
                    pass
        return

    def _finalize(self, data, **kwargs):
        return None

    def _requires(self):
        # Although we could require nothing, since we are deleting keys only if they
        # exist, providing these as requirements allows us to catch dependency issues
        # in pipelines.
        req = dict()
        if self.meta is not None:
            req["meta"] = list(self.meta)
        if self.detdata is not None:
            req["detdata"] = list(self.detdata)
        if self.shared is not None:
            req["shared"] = list(self.shared)
        if self.intervals is not None:
            req["intervals"] = list(self.intervals)
        return req

    def _provides(self):
        return dict()


toast.ops.Reset

Bases: Operator

Class to reset data from observations.

This operator takes lists of shared, detdata, intervals, and meta keys to reset. Numerical data objects and arrays are set to zero. String objects are set to an empty string. Any object that defines a clear() method will have that called. Any object not matching those criteria will be set to None. Since an IntervalList is not mutable, any specified intervals will simply be deleted.
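
A minimal sketch of resetting data in place (as above, data is assumed to be a populated toast.Data instance and the key names are hypothetical):

from toast import ops

resetter = ops.Reset(
    detdata=["signal"],        # detector arrays are set to zero
    meta=["scan_parameters"],  # cleared, zeroed, or set to None by type
    intervals=["turnarounds"], # IntervalList objects are deleted
)
resetter.apply(data)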

Source code in toast/ops/reset.py
@trait_docs
class Reset(Operator):
    """Class to reset data from observations.

    This operator takes lists of shared, detdata, intervals, and meta keys to reset.
    Numerical data objects and arrays are set to zero.  String objects are set to an
    empty string.  Any object that defines a `clear()` method will have that called.
    Any object not matching those criteria will be set to None.  Since an IntervalList
    is not mutable, any specified intervals will simply be deleted.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    meta = List([], help="List of Observation dictionary keys to reset")

    detdata = List([], help="List of Observation detdata keys to reset")

    shared = List([], help="List of Observation shared keys to reset")

    intervals = List(
        [],
        help="List of tuples of Observation intervals keys to reset",
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()
        for ob in data.obs:
            if len(self.detdata) > 0:
                # Get the detectors we are using for this observation
                dets = ob.select_local_detectors(detectors)
                if len(dets) == 0:
                    # Nothing to do for this observation
                    continue
                for key in self.detdata:
                    for d in dets:
                        ob.detdata[key][d, :] = 0
            for key in self.shared:
                scomm = ob.shared[key].nodecomm
                if scomm is None:
                    # No MPI, just set to zero
                    ob.shared[key].data[:] = 0
                else:
                    # Only rank zero on each node resets
                    if scomm.rank == 0:
                        ob.shared[key]._flat[:] = 0
                    scomm.barrier()
            for key in self.intervals:
                # This ignores non-existent keys
                del ob.intervals[key]
            for key in self.meta:
                if isinstance(ob[key], np.ndarray):
                    # This is an array, set to zero
                    ob[key][:] = 0
                elif hasattr(ob[key], "clear"):
                    # This is some kind of container (list, dict, etc).  Clear it.
                    ob[key].clear()
                elif isinstance(ob[key], bool):
                    # Boolean scalar, set to False
                    ob[key] = False
                elif isinstance(ob[key], numbers.Number):
                    # This is a scalar numeric value
                    ob[key] = 0
                elif isinstance(ob[key], (str, bytes)):
                    # This is string-like
                    ob[key] = ""
                else:
                    # This is something else.  Set to None
                    ob[key] = None
        return

    def _finalize(self, data, **kwargs):
        return None

    def _requires(self):
        req = dict()
        if self.meta is not None:
            req["meta"] = list(self.meta)
        if self.detdata is not None:
            req["detdata"] = list(self.detdata)
        if self.shared is not None:
            req["shared"] = list(self.shared)
        if self.intervals is not None:
            req["intervals"] = list(self.intervals)
        return req

    def _provides(self):
        return dict()


toast.ops.Copy

Bases: Operator

Class to copy data.

This operator takes lists of shared, detdata, and meta keys to copy to a new location in each observation.

Each list contains tuples specifying the input and output key names.
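
A minimal sketch (hypothetical key names; data is assumed to be a populated toast.Data instance):

from toast import ops

# Each entry is an (input, output) tuple of key names.
copier = ops.Copy(
    detdata=[("signal", "signal_backup")],
    meta=[("noise_model", "noise_model_input")],
)
copier.apply(data)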

Source code in toast/ops/copy.py
@trait_docs
class Copy(Operator):
    """Class to copy data.

    This operator takes lists of shared, detdata, and meta keys to copy to a new
    location in each observation.

    Each list contains tuples specifying the input and output key names.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    meta = List([], help="List of tuples of Observation meta keys to copy")

    detdata = List([], help="List of tuples of Observation detdata keys to copy")

    shared = List([], help="List of tuples of Observation shared keys to copy")

    intervals = List(
        [],
        help="List of tuples of Observation intervals keys to copy",
    )

    @traitlets.validate("meta")
    def _check_meta(self, proposal):
        val = proposal["value"]
        for v in val:
            if not isinstance(v, (tuple, list)):
                raise traitlets.TraitError("trait should be a list of tuples")
            if len(v) != 2:
                raise traitlets.TraitError("key tuples should have 2 values")
            if not isinstance(v[0], str) or not isinstance(v[1], str):
                raise traitlets.TraitError("key tuples should have string values")
        return val

    @traitlets.validate("detdata")
    def _check_detdata(self, proposal):
        val = proposal["value"]
        for v in val:
            if not isinstance(v, (tuple, list)):
                raise traitlets.TraitError("trait should be a list of tuples")
            if len(v) != 2:
                raise traitlets.TraitError("key tuples should have 2 values")
            if not isinstance(v[0], str) or not isinstance(v[1], str):
                raise traitlets.TraitError("key tuples should have string values")
        return val

    @traitlets.validate("shared")
    def _check_shared(self, proposal):
        val = proposal["value"]
        for v in val:
            if not isinstance(v, (tuple, list)):
                raise traitlets.TraitError("trait should be a list of tuples")
            if len(v) != 2:
                raise traitlets.TraitError("key tuples should have 2 values")
            if not isinstance(v[0], str) or not isinstance(v[1], str):
                raise traitlets.TraitError("key tuples should have string values")
        return val

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()
        for ob in data.obs:
            for in_key, out_key in self.meta:
                if out_key in ob:
                    # The key exists; issue a warning before overwriting.
                    msg = "Observation key {} already exists; overwriting".format(
                        out_key
                    )
                    log.warning(msg)
                ob[out_key] = ob[in_key]

            for in_key, out_key in self.shared:
                # Although this is an internal function, the input arguments come
                # from existing shared objects and so should already be valid.
                ob.shared.assign_mpishared(
                    out_key, ob.shared[in_key], ob.shared.comm_type(in_key)
                )

            if len(self.detdata) > 0:
                # Get the detectors we are using for this observation.
                # We copy the full set of detectors, even if they are flagged.
                dets = ob.select_local_detectors(
                    detectors,
                    flagmask=0,
                )
                if len(dets) == 0:
                    # Nothing to do for this observation
                    continue
                for in_key, out_key in self.detdata:
                    if in_key not in ob.detdata:
                        continue
                    if out_key in ob.detdata:
                        # The key exists; verify that dimensions / dtype match
                        in_dtype = ob.detdata[in_key].dtype
                        out_dtype = ob.detdata[out_key].dtype
                        if out_dtype != in_dtype:
                            msg = f"Cannot copy to existing detdata key {out_key}"
                            msg += f" with different dtype ({out_dtype}) != {in_dtype}"
                            log.error(msg)
                            raise RuntimeError(msg)
                        in_shape = ob.detdata[in_key].detector_shape
                        out_shape = ob.detdata[out_key].detector_shape
                        if out_shape != in_shape:
                            msg = f"Cannot copy to existing detdata key {out_key}"
                            msg += f" with different detector shape ({out_shape})"
                            msg += f" != {in_shape}"
                            log.error(msg)
                            raise RuntimeError(msg)
                        if ob.detdata[out_key].detectors != dets:
                            # The output has a different set of detectors.  Reallocate.
                            ob.detdata[out_key].change_detectors(dets)
                        # Copy units
                        ob.detdata[out_key].update_units(ob.detdata[in_key].units)
                    else:
                        sample_shape = None
                        shp = ob.detdata[in_key].detector_shape
                        if len(shp) > 1:
                            sample_shape = shp[1:]
                        ob.detdata.create(
                            out_key,
                            sample_shape=sample_shape,
                            dtype=ob.detdata[in_key].dtype,
                            detectors=dets,
                            units=ob.detdata[in_key].units,
                        )
                    # Copy detector data
                    for d in dets:
                        ob.detdata[out_key][d, :] = ob.detdata[in_key][d, :]
        return

    def _finalize(self, data, **kwargs):
        return None

    def _requires(self):
        req = dict()
        if self.meta is not None:
            req["meta"] = [x[0] for x in self.meta]
        if self.detdata is not None:
            req["detdata"] = [x[0] for x in self.detdata]
        if self.shared is not None:
            req["shared"] = [x[0] for x in self.shared]
        if self.intervals is not None:
            req["intervals"] = [x[0] for x in self.intervals]
        return req

    def _provides(self):
        prov = dict()
        if self.meta is not None:
            prov["meta"] = [x[1] for x in self.meta]
        if self.detdata is not None:
            prov["detdata"] = [x[1] for x in self.detdata]
        if self.shared is not None:
            prov["shared"] = [x[1] for x in self.shared]
        if self.intervals is not None:
            prov["intervals"] = [x[1] for x in self.intervals]
        return prov


toast.ops.Combine

Bases: Operator

Arithmetic with detector data.

Two detdata objects are combined element-wise using addition, subtraction, multiplication, or division. The desired operation is specified by the "op" trait as a string. The result is stored in the specified detdata object:

result = first (op) second

If the result name is the same as the first or second input, then this input will be overwritten.
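
A minimal sketch of element-wise subtraction (hypothetical detdata key names; data is assumed to be a populated toast.Data instance):

from toast import ops

diff = ops.Combine(
    op="subtract",
    first="signal",
    second="template",
    result="cleaned",  # use result="signal" to overwrite the first input
)
diff.apply(data)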

Source code in toast/ops/arithmetic.py
@trait_docs
class Combine(Operator):
    """Arithmetic with detector data.

    Two detdata objects are combined element-wise using addition, subtraction,
    multiplication, or division.  The desired operation is specified by the "op"
    trait as a string.  The result is stored in the specified detdata object:

    result = first (op) second

    If the result name is the same as the first or second input, then this
    input will be overwritten.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    op = Unicode(
        None,
        allow_none=True,
        help="Operation on the timestreams: 'subtract', 'add', 'multiply', or 'divide'",
    )

    first = Unicode(None, allow_none=True, help="The first detdata object")

    second = Unicode(None, allow_none=True, help="The second detdata object")

    result = Unicode(None, allow_none=True, help="The resulting detdata object")

    @traitlets.validate("op")
    def _check_op(self, proposal):
        val = proposal["value"]
        if val is not None:
            if val not in ["add", "subtract", "multiply", "divide"]:
                raise traitlets.TraitError("op must be one of the 4 allowed strings")
        return val

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        for check_name, check_val in [
            ("first", self.first),
            ("second", self.second),
            ("result", self.result),
            ("op", self.op),
        ]:
            if check_val is None:
                msg = f"The {check_name} trait must be set before calling exec"
                log.error(msg)
                raise RuntimeError(msg)

        for ob in data.obs:
            # Get the detectors we are using for this observation
            local_dets = ob.select_local_detectors(detectors)
            if len(local_dets) == 0:
                # Nothing to do for this observation
                continue
            if self.first not in ob.detdata:
                msg = f"The first detdata key '{self.first}' does not exist in"
                msg += f" observation {ob.name}, skipping"
                log.verbose(msg)
                continue
            if self.second not in ob.detdata:
                msg = f"The second detdata key '{self.first}' does not exist in"
                msg += f" observation {ob.name}, skipping"
                log.verbose(msg)
                continue

            first_units = ob.detdata[self.first].units
            second_units = ob.detdata[self.second].units

            # Operate on the intersection of detectors
            dets = list(
                sorted(
                    set.intersection(
                        set(ob.detdata[self.first].detectors),
                        set(ob.detdata[self.second].detectors),
                    )
                )
            )

            if self.result == self.first:
                result_units = first_units
                scale_first = 1.0
                scale_second = unit_conversion(second_units, result_units)
            elif self.result == self.second:
                result_units = second_units
                scale_first = unit_conversion(first_units, result_units)
                scale_second = 1.0
            else:
                # We are creating a new field for the output.  Use units of first field.
                result_units = first_units
                scale_first = 1.0
                scale_second = unit_conversion(second_units, result_units)
                exists = ob.detdata.ensure(
                    self.result,
                    sample_shape=ob.detdata[self.first].detector_shape[1:],
                    dtype=ob.detdata[self.first].dtype,
                    detectors=ob.detdata[self.first].detectors,
                    create_units=result_units,
                )
            if self.op == "add":
                for d in dets:
                    ob.detdata[self.result][d, :] = (
                        scale_first * ob.detdata[self.first][d, :]
                    ) + (scale_second * ob.detdata[self.second][d, :])
            elif self.op == "subtract":
                for d in dets:
                    ob.detdata[self.result][d, :] = (
                        scale_first * ob.detdata[self.first][d, :]
                    ) - (scale_second * ob.detdata[self.second][d, :])
            elif self.op == "multiply":
                for d in dets:
                    ob.detdata[self.result][d, :] = (
                        scale_first * ob.detdata[self.first][d, :]
                    ) * (scale_second * ob.detdata[self.second][d, :])
            elif self.op == "divide":
                for d in dets:
                    ob.detdata[self.result][d, :] = (
                        scale_first * ob.detdata[self.first][d, :]
                    ) / (scale_second * ob.detdata[self.second][d, :])

    def _finalize(self, data, **kwargs):
        return None

    def _requires(self):
        req = {"detdata": [self.first, self.second]}
        if self.result is not None:
            if (self.result != self.first) and (self.result != self.second):
                req["detdata"].append(self.result)
        return req

    def _provides(self):
        prov = {"detdata": list()}
        if (self.result != self.first) and (self.result != self.second):
            prov["detdata"].append(self.result)
        return prov


toast.ops.CalibrateDetectors

Bases: Operator

Multiply detector data by factors in the observation dictionary.

Given a dictionary in each observation, apply the per-detector scaling factors to the timestreams. Detectors that do not exist in the dictionary are flagged.
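
A minimal sketch (assuming each observation carries a per-detector factor dictionary under the hypothetical key "calibration"):

from toast import ops

calibrate = ops.CalibrateDetectors(
    cal_name="calibration",  # ob["calibration"] maps detector name -> factor
    det_data="signal",
)
calibrate.apply(data)  # detectors missing from the dictionary are flagged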

Source code in toast/ops/calibrate.py
@trait_docs
class CalibrateDetectors(Operator):
    """Multiply detector data by factors in the observation dictionary.

    Given a dictionary in each observation, apply the per-detector scaling factors
    to the timestreams.  Detectors that do not exist in the dictionary are flagged.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    cal_name = Unicode(
        "calibration", help="The observation key containing the calibration dictionary"
    )

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key for data to calibrate",
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    cal_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask to apply to detectors with no calibration information",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("cal_mask")
    def _check_cal_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Calibration mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        for ob in data.obs:
            if self.det_data not in ob.detdata:
                continue
            if self.cal_name not in ob:
                msg = f"{ob.name}: Calibration dictionary {self.cal_name} does "
                msg += f"not exist, skipping"
                if data.comm.group_rank == 0:
                    log.warning(msg)
                continue
            cal = ob[self.cal_name]

            dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
            if len(dets) == 0:
                continue

            # Process all detectors
            det_flags = dict(ob.local_detector_flags)
            for det in dets:
                if det not in cal:
                    # Flag this detector
                    det_flags[det] |= self.cal_mask
                    continue
                ob.detdata[self.det_data][det] *= cal[det]

            # Update flags
            ob.update_local_detector_flags(det_flags)

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "detdata": [self.det_data],
        }
        return req

    def _provides(self):
        return {"detdata": [self.det_data]}


toast.ops.MemoryCounter

Bases: Operator

Compute total memory used by Observations in a Data object.

Every process group iterates over its observations and sums the total memory used by detector and shared data. Metadata and interval lists are assumed to be negligible and are not counted.
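
A minimal sketch (here apply() is assumed to return the value from _finalize, i.e. the accumulated byte count):

from toast import ops

counter = ops.MemoryCounter(prefix="After simulation")
total_bytes = counter.apply(data)  # also logs a report unless silent=True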

Source code in toast/ops/memory_counter.py
@trait_docs
class MemoryCounter(Operator):
    """Compute total memory used by Observations in a Data object.

    Every process group iterates over its observations and sums the total memory used
    by detector and shared data.  Metadata and interval lists are assumed to be
    negligible and are not counted.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    silent = Bool(
        False,
        help="If True, return the memory used but do not log the result",
    )

    prefix = Unicode("", help="Prefix for log messages")

    def __init__(self, **kwargs):
        self.total_bytes = 0
        self.sys_mem_str = ""
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        for ob in data.obs:
            self.total_bytes += ob.memory_use()
        self.sys_mem_str = memreport(
            msg="(whole node)", comm=data.comm.comm_world, silent=True
        )
        return

    def _finalize(self, data, **kwargs):
        log = Logger.get()
        if not self.silent:
            total_gb = self.total_bytes / 2**30
            if data.comm.comm_group_rank is not None:
                total_gb = data.comm.comm_group_rank.allreduce(total_gb)
            if data.comm.world_rank == 0:
                msg = f"Total timestream memory use = {total_gb:.3f} GB"
                log.info(f"{self.prefix}:  {msg}")
                log.info(f"{self.prefix}:  {self.sys_mem_str}")
        total = self.total_bytes
        self.total_bytes = 0
        return total

    def _requires(self):
        return dict()

    def _provides(self):
        return dict()


toast.ops.Statistics

Bases: Operator

Operator to measure and write out data statistics
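
A minimal sketch (hypothetical output directory; the operator writes one HDF5 file of per-detector statistics per observation, named after the operator and the observation):

from toast import ops

stats = ops.Statistics(
    det_data="signal",
    output_dir="stats_output",
)
stats.apply(data)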

Source code in toast/ops/statistics.py
@trait_docs
class Statistics(Operator):
    """Operator to measure and write out data statistics"""

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(defaults.det_data, help="Observation detdata key to analyze")

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional shared flagging",
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    output_dir = Unicode(
        None,
        allow_none=True,
        help="If specified, write output data products to this directory",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        """Measure the statistics

        Args:
            data (toast.Data): The distributed data.

        """
        log = Logger.get()
        nstat = 3  # Variance, Skewness, Kurtosis

        if self.output_dir is not None:
            if not os.path.isdir(self.output_dir):
                os.makedirs(self.output_dir, exist_ok=True)

        for obs in data.obs:
            # NOTE:  We could use the session name / uid in the filename
            # too for easy sorting.
            if obs.name is None:
                fname_out = f"{self.name}_{obs.uid}.h5"
            else:
                fname_out = f"{self.name}_{obs.name}.h5"
            if self.output_dir is not None:
                fname_out = os.path.join(self.output_dir, fname_out)

            # Get the list of all detectors that are not cut
            obs_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
            if obs.comm.group_size == 1:
                all_dets = obs_dets
            else:
                proc_dets = obs.comm.comm_group.gather(obs_dets, root=0)
                all_dets = None
                if obs.comm.group_rank == 0:
                    all_set = set()
                    for pdets in proc_dets:
                        for d in pdets:
                            all_set.add(d)
                    all_dets = list(sorted(all_set))
                all_dets = obs.comm.comm_group.bcast(all_dets, root=0)

            ndet = len(all_dets)
            hits = np.zeros([ndet], dtype=int)
            means = np.zeros([ndet], dtype=float)
            stats = np.zeros([nstat, ndet], dtype=float)

            views = obs.view[self.view]

            # Measure the mean separately to simplify the math
            for iview, view in enumerate(views):
                if view.start is None:
                    # This is a view of the whole obs
                    nsample = obs.n_local_samples
                else:
                    nsample = view.stop - view.start
                if self.shared_flags is not None:
                    shared_flags = views.shared[self.shared_flags][iview]
                    shared_mask = (shared_flags & self.shared_flag_mask) == 0
                else:
                    shared_mask = np.ones(nsample, dtype=bool)

                for det in obs_dets:
                    if self.det_flags is not None:
                        det_flags = views.detdata[self.det_flags][iview][det]
                        det_mask = (det_flags & self.det_flag_mask) == 0
                        mask = np.logical_and(shared_mask, det_mask)
                    else:
                        mask = shared_mask
                    ngood = np.sum(mask)
                    if ngood == 0:
                        continue
                    signal = views.detdata[self.det_data][iview][det]
                    good_signal = signal[mask].copy()
                    idet = all_dets.index(det)
                    # Valid samples
                    hits[idet] += ngood
                    # Mean
                    means[idet] += np.sum(good_signal)

            if obs.comm.comm_group is not None:
                hits = obs.comm.comm_group.allreduce(hits, op=MPI.SUM)
                means = obs.comm.comm_group.allreduce(means, op=MPI.SUM)

            good = hits != 0
            means[good] /= hits[good]

            # Now evaluate the moments

            for iview, view in enumerate(views):
                if view.start is None:
                    # This is a view of the whole obs
                    nsample = obs.n_local_samples
                else:
                    nsample = view.stop - view.start
                if self.shared_flags is not None:
                    shared_flags = views.shared[self.shared_flags][iview]
                    shared_mask = (shared_flags & self.shared_flag_mask) == 0
                else:
                    shared_mask = np.ones(nsample, dtype=bool)

                for det in obs_dets:
                    if self.det_flags is not None:
                        det_flags = views.detdata[self.det_flags][iview][det]
                        det_mask = (det_flags & self.det_flag_mask) == 0
                        mask = np.logical_and(shared_mask, det_mask)
                    else:
                        mask = shared_mask
                    ngood = np.sum(mask)
                    if ngood == 0:
                        continue
                    idet = all_dets.index(det)
                    signal = views.detdata[self.det_data][iview][det]
                    good_signal = signal[mask].copy() - means[idet]
                    # Variance
                    stats[0, idet] += np.sum(good_signal**2)
                    # Skewness
                    stats[1, idet] += np.sum(good_signal**3)
                    # Kurtosis
                    stats[2, idet] += np.sum(good_signal**4)

            if obs.comm.comm_group is not None:
                stats = obs.comm.comm_group.reduce(stats, op=MPI.SUM)

            if obs.comm.group_rank == 0:
                # Central moments
                m2 = stats[0]
                m3 = stats[1]
                m4 = stats[2]
                for m in m2, m3, m4:
                    m[good] /= hits[good]
                # Variance
                var = m2.copy()
                # Skewness
                skew = m3.copy()
                skew[good] /= m2[good] ** 1.5
                # Kurtosis
                kurt = m4.copy()
                kurt[good] /= m2[good] ** 2
                # Write the results
                with h5py.File(fname_out, "w") as fout:
                    fout.attrs["UID"] = obs.uid
                    if obs.name is not None:
                        fout.attrs["name"] = obs.name
                    fout.attrs["nsample"] = obs.n_all_samples
                    fout.create_dataset(
                        "detectors", data=all_dets, dtype=h5py.string_dtype()
                    )
                    fout["ngood"] = hits
                    fout["mean"] = means
                    fout["variance"] = var
                    fout["skewness"] = skew
                    fout["kurtosis"] = kurt
                log.debug(f"Wrote data statistics to {fname_out}")

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
        }
        return prov

API = Int(0, help='Internal interface version for this operator')

det_data = Unicode(defaults.det_data, help='Observation detdata key to analyze')

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging')

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use')

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging')

output_dir = Unicode(None, allow_none=True, help='If specified, write output data products to this directory')

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional shared flagging')

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use')

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations')
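
The traits above are the full configuration surface of this operator. A minimal usage sketch follows; the populated `data` object, the trait values, and the `apply()` convenience call are assumptions from the surrounding toast API, not part of this reference.

# Sketch: compute per-detector statistics and write one HDF5 file per
# observation (file names combine the operator name and obs name/uid).
import toast.ops

stats = toast.ops.Statistics(
    name="signal_stats",      # used as the output filename prefix
    det_data="signal",        # detdata key to analyze
    output_dir="stats_out",   # created if it does not exist
)
stats.apply(data)             # `data` is an existing toast.Data instance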

__init__(**kwargs)

Source code in toast/ops/statistics.py (lines 90-92)
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    return

_check_det_flag_mask(proposal)

Source code in toast/ops/statistics.py (lines 83-88)
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/statistics.py (lines 69-74)
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/statistics.py (lines 76-81)
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Measure the statistics

Parameters:

    data (Data): The distributed data. (required)
Source code in toast/ops/statistics.py (lines 94-247)
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    """Measure the statistics

    Args:
        data (toast.Data): The distributed data.

    """
    log = Logger.get()
    nstat = 3  # Variance, Skewness, Kurtosis

    if self.output_dir is not None:
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)

    for obs in data.obs:
        # NOTE:  We could use the session name / uid in the filename
        # too for easy sorting.
        if obs.name is None:
            fname_out = f"{self.name}_{obs.uid}.h5"
        else:
            fname_out = f"{self.name}_{obs.name}.h5"
        if self.output_dir is not None:
            fname_out = os.path.join(self.output_dir, fname_out)

        # Get the list of all detectors that are not cut
        obs_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
        if obs.comm.group_size == 1:
            all_dets = obs_dets
        else:
            proc_dets = obs.comm.comm_group.gather(obs_dets, root=0)
            all_dets = None
            if obs.comm.group_rank == 0:
                all_set = set()
                for pdets in proc_dets:
                    for d in pdets:
                        all_set.add(d)
                all_dets = list(sorted(all_set))
            all_dets = obs.comm.comm_group.bcast(all_dets, root=0)

        ndet = len(all_dets)
        hits = np.zeros([ndet], dtype=int)
        means = np.zeros([ndet], dtype=float)
        stats = np.zeros([nstat, ndet], dtype=float)

        views = obs.view[self.view]

        # Measure the mean separately to simplify the math
        for iview, view in enumerate(views):
            if view.start is None:
                # This is a view of the whole obs
                nsample = obs.n_local_samples
            else:
                nsample = view.stop - view.start
            if self.shared_flags is not None:
                shared_flags = views.shared[self.shared_flags][iview]
                shared_mask = (shared_flags & self.shared_flag_mask) == 0
            else:
                shared_mask = np.ones(nsample, dtype=bool)

            for det in obs_dets:
                if self.det_flags is not None:
                    det_flags = views.detdata[self.det_flags][iview][det]
                    det_mask = (det_flags & self.det_flag_mask) == 0
                    mask = np.logical_and(shared_mask, det_mask)
                else:
                    mask = shared_mask
                ngood = np.sum(mask)
                if ngood == 0:
                    continue
                signal = views.detdata[self.det_data][iview][det]
                good_signal = signal[mask].copy()
                idet = all_dets.index(det)
                # Valid samples
                hits[idet] += ngood
                # Mean
                means[idet] += np.sum(good_signal)

        if obs.comm.comm_group is not None:
            hits = obs.comm.comm_group.allreduce(hits, op=MPI.SUM)
            means = obs.comm.comm_group.allreduce(means, op=MPI.SUM)

        good = hits != 0
        means[good] /= hits[good]

        # Now evaluate the moments

        for iview, view in enumerate(views):
            if view.start is None:
                # This is a view of the whole obs
                nsample = obs.n_local_samples
            else:
                nsample = view.stop - view.start
            if self.shared_flags is not None:
                shared_flags = views.shared[self.shared_flags][iview]
                shared_mask = (shared_flags & self.shared_flag_mask) == 0
            else:
                shared_mask = np.ones(nsample, dtype=bool)

            for det in obs_dets:
                if self.det_flags is not None:
                    det_flags = views.detdata[self.det_flags][iview][det]
                    det_mask = (det_flags & self.det_flag_mask) == 0
                    mask = np.logical_and(shared_mask, det_mask)
                else:
                    mask = shared_mask
                ngood = np.sum(mask)
                if ngood == 0:
                    continue
                idet = all_dets.index(det)
                signal = views.detdata[self.det_data][iview][det]
                good_signal = signal[mask].copy() - means[idet]
                # Variance
                stats[0, idet] += np.sum(good_signal**2)
                # Skewness
                stats[1, idet] += np.sum(good_signal**3)
                # Kurtosis
                stats[2, idet] += np.sum(good_signal**4)

        if obs.comm.comm_group is not None:
            stats = obs.comm.comm_group.reduce(stats, op=MPI.SUM)

        if obs.comm.group_rank == 0:
            # Central moments
            m2 = stats[0]
            m3 = stats[1]
            m4 = stats[2]
            for m in m2, m3, m4:
                m[good] /= hits[good]
            # Variance
            var = m2.copy()
            # Skewness
            skew = m3.copy()
            skew[good] /= m2[good] ** 1.5
            # Kurtosis
            kurt = m4.copy()
            kurt[good] /= m2[good] ** 2
            # Write the results
            with h5py.File(fname_out, "w") as fout:
                fout.attrs["UID"] = obs.uid
                if obs.name is not None:
                    fout.attrs["name"] = obs.name
                fout.attrs["nsample"] = obs.n_all_samples
                fout.create_dataset(
                    "detectors", data=all_dets, dtype=h5py.string_dtype()
                )
                fout["ngood"] = hits
                fout["mean"] = means
                fout["variance"] = var
                fout["skewness"] = skew
                fout["kurtosis"] = kurt
            log.debug(f"Wrote data statistics to {fname_out}")

    return
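
The HDF5 layout written above is easy to read back. A hedged sketch (the filename is hypothetical; the dataset and attribute names are exactly those written by the code above):

# Read one Statistics output file (filename is illustrative).
import h5py

with h5py.File("signal_stats_myobs.h5", "r") as f:
    dets = [d.decode() for d in f["detectors"][:]]  # detector names
    ngood = f["ngood"][:]    # unflagged samples per detector
    mean = f["mean"][:]
    var = f["variance"][:]   # m2
    skew = f["skewness"][:]  # m3 / m2**1.5
    kurt = f["kurtosis"][:]  # m4 / m2**2
    print(f.attrs["UID"], f.attrs["nsample"])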

_finalize(data, **kwargs)

Source code in toast/ops/statistics.py (lines 249-250)
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/statistics.py (lines 267-273)
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
    }
    return prov

_requires()

Source code in toast/ops/statistics.py (lines 252-265)
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

toast.ops.Pipeline

Bases: Operator

Class representing a sequence of Operators.

This runs a list of other operators over sets of detectors (default is all detectors in one shot). By default all observations are passed to each operator, but the observation_key and observation_value traits can be used to run the operators only on observations that have a matching key / value pair.
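
A hedged usage sketch (the operator choices and trait values are illustrative, and `apply()` plus a populated `data` object are assumed from the toast API):

# Run two operators one detector at a time to bound memory use.
import toast.ops

pipe = toast.ops.Pipeline(
    operators=[
        toast.ops.SimNoise(det_data="signal"),    # illustrative choice
        toast.ops.Statistics(det_data="signal"),  # illustrative choice
    ],
    detector_sets=["SINGLE"],  # or ["ALL"], or explicit detector lists
)
pipe.apply(data)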

Source code in toast/ops/pipeline.py (lines 15-388)
@trait_docs
class Pipeline(Operator):
    """Class representing a sequence of Operators.

    This runs a list of other operators over sets of detectors (default is all
    detectors in one shot).  By default all observations are passed to each operator,
    but the `observation_key` and `observation_value` traits can be used to run the
    operators on only observations which have a matching key / value pair.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    operators = List([], help="List of Operator instances to run.")

    detector_sets = List(
        ["ALL"],
        help="List of detector sets.  ['ALL'] and ['SINGLE'] are also valid values.",
    )

    use_hybrid = Bool(
        True,
        help="Should the pipeline be allowed to use the GPU when it has some cpu-only operators.",
    )

    @traitlets.validate("detector_sets")
    def _check_detsets(self, proposal):
        detsets = proposal["value"]
        if len(detsets) == 0:
            msg = "detector_sets must be a list with at least one entry "
            msg += "('ALL' and 'SINGLE' are valid entries)"
            raise traitlets.TraitError(msg)
        for dset in detsets:
            if (dset != "ALL") and (dset != "SINGLE"):
                # Not a built-in name, must be an actual list of detectors
                if isinstance(dset, str) or len(dset) == 0:
                    raise traitlets.TraitError(
                        "A detector set must be a list of detectors or 'ALL' / 'SINGLE'"
                    )
                for d in dset:
                    if not isinstance(d, str):
                        raise traitlets.TraitError(
                            "Each element of a det set should be a detector name"
                        )
        return detsets

    @traitlets.validate("operators")
    def _check_operators(self, proposal):
        ops = proposal["value"]
        for op in ops:
            if not isinstance(op, Operator):
                raise traitlets.TraitError(
                    "operators must be a list of Operator instances or None"
                )
        return ops

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # keeps track of the data that is on device
        self._staged_data = None
        # keep track of the data that had to move back to host due to a cpu-only operator
        # (for display / debugging purposes)
        self._unstaged_data = None

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()
        pstr = f"Proc ({data.comm.world_rank}, {data.comm.group_rank})"

        if len(self.operators) == 0:
            log.debug_rank(
                "Pipeline has no operators, nothing to do", comm=data.comm.comm_world
            )
            return

        # If the calling code passed use_accel=True, we assume that it will move
        # the data for us.  Otherwise, if possible / allowed, use the accelerator and
        # deal with data movement ourselves.
        self._staged_data = None
        self._unstaged_data = None
        pipe_accel = self._pipe_accel(use_accel)

        if pipe_accel:
            # some of our operators support using the accelerator
            msg = f"{self} supports accelerators."
            log.verbose_rank(msg, comm=data.comm.comm_world)
            use_accel = True
            # keeps track of the data that is on device
            self._staged_data = SetDict(
                {
                    key: set()
                    for key in ["global", "meta", "detdata", "shared", "intervals"]
                }
            )
            # keep track of the data that had to move back from device
            # (for display / debugging purposes)
            self._unstaged_data = SetDict(
                {
                    key: set()
                    for key in ["global", "meta", "detdata", "shared", "intervals"]
                }
            )

        if len(data.obs) == 0:
            # No observations for this group
            msg = f"{self} data, group {data.comm.group} has no observations."
            log.verbose_rank(msg, comm=data.comm.comm_group)

        # Ensure that all operators with a detector mask are using the same
        # mask.
        det_mask = None
        for op in self.operators:
            if hasattr(op, "det_mask"):
                if det_mask is None:
                    det_mask = op.det_mask
                else:
                    if op.det_mask != det_mask:
                        msg = "All operators in a Pipeline which use a det_mask"
                        msg += " must have the same mask value"
                        log.error(msg)
                        raise RuntimeError(msg)
        if det_mask is None:
            det_mask = 0

        if len(self.detector_sets) == 1 and self.detector_sets[0] == "ALL":
            # Run the operators with all detectors at once
            for op in self.operators:
                self._exec_operator(
                    op,
                    data,
                    detectors=None,
                    pipe_accel=pipe_accel,
                )
        elif len(self.detector_sets) == 1 and self.detector_sets[0] == "SINGLE":
            # Get superset of detectors across all observations
            all_local_dets = data.all_local_detectors(
                selection=detectors, flagmask=det_mask
            )
            if len(all_local_dets) == 0:
                all_local_dets = [None]
            # Run operators one detector at a time
            for det in all_local_dets:
                msg = f"{pstr} {self} SINGLE detector {det}"
                log.verbose(msg)
                if det is None:
                    dets = []
                else:
                    dets = [det]
                for op in self.operators:
                    self._exec_operator(
                        op,
                        data,
                        detectors=dets,
                        pipe_accel=pipe_accel,
                    )
        else:
            # We have explicit detector sets
            # Guard against detectors=None before building the lookup set
            det_check = set() if detectors is None else set(detectors)
            for det_set in self.detector_sets:
                selected_set = det_set
                if detectors is not None:
                    selected_set = list()
                    for det in det_set:
                        if det in det_check:
                            selected_set.append(det)
                if len(selected_set) == 0:
                    # Nothing in this detector set is being used, skip it
                    continue
                msg = f"{pstr} {self} detector set {selected_set}"
                log.verbose(msg)
                for op in self.operators:
                    self._exec_operator(
                        op,
                        data,
                        detectors=selected_set,
                        pipe_accel=pipe_accel,
                    )

        # notify user of device->host data movements introduced by CPU operators
        if (self._unstaged_data is not None) and (not self._unstaged_data.is_empty()):
            cpu_ops = {
                op.__class__.__qualname__
                for op in self.operators
                if not op.supports_accel()
            }
            log.debug(
                f"{pstr} {self}, had to move {self._unstaged_data} back to host as {cpu_ops} do not support accel."
            )

    @function_timer
    def _exec_operator(self, op, data, detectors, pipe_accel):
        """Runs an operator, dealing with data movement to/from device if needed."""
        # For this operator, we run on the accelerator if the pipeline has some
        # operators enabled and if this operator supports it.
        run_accel = pipe_accel and op.supports_accel()

        log = Logger.get()
        msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
        msg += f"calling operator '{op.name}' exec(accelerator={run_accel})"
        if detectors is None:
            msg += " with ALL dets"
        log.verbose(msg)

        # Ensures data is where it should be for this operator
        if self._staged_data is not None:
            requires = SetDict(op.requires())
            if run_accel:
                # This operator will use the accelerator, stage data
                msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
                msg += f"BEFORE staged = {self._staged_data}, unstaged = {self._unstaged_data}"
                log.verbose(msg)
                requires -= self._staged_data
                msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
                msg += f"Staging objects {requires}"
                log.verbose(msg)
                data.accel_create(requires)
                data.accel_update_device(requires)
                # Update our record of data on device
                self._unstaged_data -= requires
                self._staged_data |= requires
                self._staged_data |= op.provides()
                self._unstaged_data -= op.provides()
                msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
                msg += f"AFTER staged = {self._staged_data}, unstaged = {self._unstaged_data}"
                log.verbose(msg)
            else:
                # This operator is running on the host, unstage data
                msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
                msg += f"BEFORE staged = {self._staged_data}, unstaged = {self._unstaged_data}"
                log.verbose(msg)
                requires &= self._staged_data  # intersection
                msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
                msg += f"Un-staging objects {requires}"
                log.verbose(msg)
                data.accel_update_host(requires)
                # Update our record of data on the device
                self._staged_data -= requires
                self._unstaged_data |= requires  # union
                self._unstaged_data |= op.provides()
                # let the operator decide whether to move data and run on the device itself
                run_accel = None
                msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
                msg += f"AFTER staged = {self._staged_data}, unstaged = {self._unstaged_data}"
                log.verbose(msg)
        # runs operator
        op.exec(data, detectors=detectors, use_accel=run_accel)

    @function_timer
    def _finalize(self, data, use_accel=None, **kwargs):
        # FIXME:  We need to clarify in documentation that if using the
        # accelerator in _finalize() to produce output products, these
        # outputs should remain on the device so that they can be copied
        # out at the end automatically.

        log = Logger.get()
        pstr = f"Proc ({data.comm.world_rank}, {data.comm.group_rank})"
        msg = f"{pstr} {self} finalize"
        log.verbose(msg)

        # Are we running on the accelerator?
        pipe_accel = self._pipe_accel(use_accel)

        # run finalize on all the operators in the pipeline
        # NOTE: this might produce some output products
        result = list()
        if self.operators is not None:
            for op in self.operators:
                # Did we set use_accel to true when running with this operator
                use_accel_op = pipe_accel and op.supports_accel()
                result.append(op.finalize(data, use_accel=use_accel_op, **kwargs))

        # get outputs back and clean up data
        # if we are in charge of the data movement
        if self._staged_data is not None:
            # get outputs back from device
            provides = SetDict(self.provides())
            provides &= self._staged_data  # intersection
            log.verbose(f"{pstr} {self} copying out accel data outputs: {provides}")
            data.accel_update_host(provides)
            # deleting all data on device
            log.verbose(f"{pstr} {self} deleting accel data: {self._staged_data}")
            data.accel_delete(self._staged_data)
            self._staged_data = None
            self._unstaged_data = None

        return result

    def _pipe_accel(self, use_accel):
        if (use_accel is None) and accel_enabled():
            # Only allow hybrid pipelines if the environment variable and the
            # pipeline trait both agree to it (each defaults to True)
            use_hybrid = self.use_hybrid and use_hybrid_pipelines
            # Can we run this pipeline on the accelerator?
            supports_accel = (
                self._supports_accel_partial() if use_hybrid else self._supports_accel()
            )
            return supports_accel
        else:
            return use_accel

    def _requires(self):
        """
        Work through the operator list in reverse order and prune intermediate products
        (that will be provided by a previous operator).
        """
        # constructs the union of the requires minus the provides (in reverse order)
        req = SetDict(
            {key: set() for key in ["global", "meta", "detdata", "shared", "intervals"]}
        )
        for op in reversed(self.operators):
            # remove provides first as there can be an overlap between provides and requires
            req -= op.provides()
            req |= op.requires()
        # converts into a dictionary of lists
        req = {k: list(v) for (k, v) in req.items()}
        return req

    def _provides(self):
        """
        Work through the operator list and prune intermediate products
        (that are provided to an intermediate operator).
        FIXME could a final result also be used by an intermediate operator?
        """
        # constructs the union of the provides minus the requires
        prov = SetDict(
            {key: set() for key in ["global", "meta", "detdata", "shared", "intervals"]}
        )
        for op in self.operators:
            # remove requires first as there can be an overlap between provides and requires
            prov -= op.requires()
            prov |= op.provides()
        # converts into a dictionary of lists
        prov = {k: list(v) for (k, v) in prov.items()}
        return prov

    def _implementations(self):
        """
        Find implementations supported by all the operators
        """
        implementations = {
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        }
        for op in self.operators:
            implementations.intersection_update(op.implementations())
        return list(implementations)

    def _supports_accel(self):
        """
        Returns True if *all* the operators are accelerator compatible.
        """
        for op in self.operators:
            if not op.supports_accel():
                return False
        return True

    def _supports_accel_partial(self):
        """
        Returns True if *at least one* of the operators is accelerator compatible.
        """
        for op in self.operators:
            if op.supports_accel():
                return True
        return False

    def __str__(self):
        """
        Converts the pipeline into a human-readable string.
        """
        return f"Pipeline{[op.__class__.__qualname__ for op in self.operators]}"

API = Int(0, help='Internal interface version for this operator')

_staged_data = None

_unstaged_data = None

detector_sets = List(['ALL'], help="List of detector sets. ['ALL'] and ['SINGLE'] are also valid values.")

operators = List([], help='List of Operator instances to run.')

use_hybrid = Bool(True, help='Should the pipeline be allowed to use the GPU when it has some cpu-only operators.')

__init__(**kwargs)

Source code in toast/ops/pipeline.py (lines 73-79)
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # keeps track of the data that is on device
    self._staged_data = None
    # keep track of the data that had to move back to host due to a cpu-only operator
    # (for display / debugging purposes)
    self._unstaged_data = None

__str__()

Converts the pipeline into a human-readable string.

Source code in toast/ops/pipeline.py (lines 384-388)
def __str__(self):
    """
    Converts the pipeline into a human-readable string.
    """
    return f"Pipeline{[op.__class__.__qualname__ for op in self.operators]}"

_check_detsets(proposal)

Source code in toast/ops/pipeline.py (lines 42-61)
@traitlets.validate("detector_sets")
def _check_detsets(self, proposal):
    detsets = proposal["value"]
    if len(detsets) == 0:
        msg = "detector_sets must be a list with at least one entry "
        msg += "('ALL' and 'SINGLE' are valid entries)"
        raise traitlets.TraitError(msg)
    for dset in detsets:
        if (dset != "ALL") and (dset != "SINGLE"):
            # Not a built-in name, must be an actual list of detectors
            if isinstance(dset, str) or len(dset) == 0:
                raise traitlets.TraitError(
                    "A detector set must be a list of detectors or 'ALL' / 'SINGLE'"
                )
            for d in dset:
                if not isinstance(d, str):
                    raise traitlets.TraitError(
                        "Each element of a det set should be a detector name"
                    )
    return detsets

_check_operators(proposal)

Source code in toast/ops/pipeline.py (lines 63-71)
@traitlets.validate("operators")
def _check_operators(self, proposal):
    ops = proposal["value"]
    for op in ops:
        if not isinstance(op, Operator):
            raise traitlets.TraitError(
                "operators must be a list of Operator instances or None"
            )
    return ops

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/pipeline.py (lines 81-204)
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()
    pstr = f"Proc ({data.comm.world_rank}, {data.comm.group_rank})"

    if len(self.operators) == 0:
        log.debug_rank(
            "Pipeline has no operators, nothing to do", comm=data.comm.comm_world
        )
        return

    # If the calling code passed use_accel=True, we assume that it will move
    # the data for us.  Otherwise, if possible / allowed, use the accelerator and
    # deal with data movement ourselves.
    self._staged_data = None
    self._unstaged_data = None
    pipe_accel = self._pipe_accel(use_accel)

    if pipe_accel:
        # some of our operators support using the accelerator
        msg = f"{self} supports accelerators."
        log.verbose_rank(msg, comm=data.comm.comm_world)
        use_accel = True
        # keeps track of the data that is on device
        self._staged_data = SetDict(
            {
                key: set()
                for key in ["global", "meta", "detdata", "shared", "intervals"]
            }
        )
        # keep track of the data that had to move back from device
        # (for display / debugging purposes)
        self._unstaged_data = SetDict(
            {
                key: set()
                for key in ["global", "meta", "detdata", "shared", "intervals"]
            }
        )

    if len(data.obs) == 0:
        # No observations for this group
        msg = f"{self} data, group {data.comm.group} has no observations."
        log.verbose_rank(msg, comm=data.comm.comm_group)

    # Ensure that all operators with a detector mask are using the same
    # mask.
    det_mask = None
    for op in self.operators:
        if hasattr(op, "det_mask"):
            if det_mask is None:
                det_mask = op.det_mask
            else:
                if op.det_mask != det_mask:
                    msg = "All operators in a Pipeline which use a det_mask"
                    msg += " must have the same mask value"
                    log.error(msg)
                    raise RuntimeError(msg)
    if det_mask is None:
        det_mask = 0

    if len(self.detector_sets) == 1 and self.detector_sets[0] == "ALL":
        # Run the operators with all detectors at once
        for op in self.operators:
            self._exec_operator(
                op,
                data,
                detectors=None,
                pipe_accel=pipe_accel,
            )
    elif len(self.detector_sets) == 1 and self.detector_sets[0] == "SINGLE":
        # Get superset of detectors across all observations
        all_local_dets = data.all_local_detectors(
            selection=detectors, flagmask=det_mask
        )
        if len(all_local_dets) == 0:
            all_local_dets = [None]
        # Run operators one detector at a time
        for det in all_local_dets:
            msg = f"{pstr} {self} SINGLE detector {det}"
            log.verbose(msg)
            if det is None:
                dets = []
            else:
                dets = [det]
            for op in self.operators:
                self._exec_operator(
                    op,
                    data,
                    detectors=dets,
                    pipe_accel=pipe_accel,
                )
    else:
        # We have explicit detector sets
        # Guard against detectors=None before building the lookup set
        det_check = set() if detectors is None else set(detectors)
        for det_set in self.detector_sets:
            selected_set = det_set
            if detectors is not None:
                selected_set = list()
                for det in det_set:
                    if det in det_check:
                        selected_set.append(det)
            if len(selected_set) == 0:
                # Nothing in this detector set is being used, skip it
                continue
            msg = f"{pstr} {self} detector set {selected_set}"
            log.verbose(msg)
            for op in self.operators:
                self._exec_operator(
                    op,
                    data,
                    detectors=selected_set,
                    pipe_accel=pipe_accel,
                )

    # notify user of device->host data movements introduced by CPU operators
    if (self._unstaged_data is not None) and (not self._unstaged_data.is_empty()):
        cpu_ops = {
            op.__class__.__qualname__
            for op in self.operators
            if not op.supports_accel()
        }
        log.debug(
            f"{pstr} {self}, had to move {self._unstaged_data} back to host as {cpu_ops} do not support accel."
        )

_exec_operator(op, data, detectors, pipe_accel)

Runs an operator, dealing with data movement to/from device if needed.

Source code in toast/ops/pipeline.py (lines 206-262)
@function_timer
def _exec_operator(self, op, data, detectors, pipe_accel):
    """Runs an operator, dealing with data movement to/from device if needed."""
    # For this operator, we run on the accelerator if the pipeline has some
    # operators enabled and if this operator supports it.
    run_accel = pipe_accel and op.supports_accel()

    log = Logger.get()
    msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
    msg += f"calling operator '{op.name}' exec(accelerator={run_accel})"
    if detectors is None:
        msg += " with ALL dets"
    log.verbose(msg)

    # Ensures data is where it should be for this operator
    if self._staged_data is not None:
        requires = SetDict(op.requires())
        if run_accel:
            # This operator will use the accelerator, stage data
            msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
            msg += f"BEFORE staged = {self._staged_data}, unstaged = {self._unstaged_data}"
            log.verbose(msg)
            requires -= self._staged_data
            msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
            msg += f"Staging objects {requires}"
            log.verbose(msg)
            data.accel_create(requires)
            data.accel_update_device(requires)
            # Update our record of data on device
            self._unstaged_data -= requires
            self._staged_data |= requires
            self._staged_data |= op.provides()
            self._unstaged_data -= op.provides()
            msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
            msg += f"AFTER staged = {self._staged_data}, unstaged = {self._unstaged_data}"
            log.verbose(msg)
        else:
            # This operator is running on the host, unstage data
            msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
            msg += f"BEFORE staged = {self._staged_data}, unstaged = {self._unstaged_data}"
            log.verbose(msg)
            requires &= self._staged_data  # intersection
            msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
            msg += f"Un-staging objects {requires}"
            log.verbose(msg)
            data.accel_update_host(requires)
            # Update our record of data on the device
            self._staged_data -= requires
            self._unstaged_data |= requires  # union
            self._unstaged_data |= op.provides()
            # let the operator decide whether to move data and run on the device itself
            run_accel = None
            msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} "
            msg += f"AFTER staged = {self._staged_data}, unstaged = {self._unstaged_data}"
            log.verbose(msg)
    # runs operator
    op.exec(data, detectors=detectors, use_accel=run_accel)
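
The staging bookkeeping above is just per-category set arithmetic. SetDict is a toast internal; the following sketch only mirrors its observable semantics with plain dictionaries of sets:

# Objects already on the device need no staging ("requires -= staged"),
# and everything the operator touches is then recorded as resident.
staged = {"detdata": {"signal"}, "shared": set()}
requires = {"detdata": {"signal", "flags"}, "shared": {"times"}}

to_stage = {k: requires[k] - staged.get(k, set()) for k in requires}
for k, v in requires.items():
    staged.setdefault(k, set()).update(v)

print(to_stage)  # {'detdata': {'flags'}, 'shared': {'times'}}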

_finalize(data, use_accel=None, **kwargs)

Source code in toast/ops/pipeline.py (lines 264-302)
@function_timer
def _finalize(self, data, use_accel=None, **kwargs):
    # FIXME:  We need to clarify in documentation that if using the
    # accelerator in _finalize() to produce output products, these
    # outputs should remain on the device so that they can be copied
    # out at the end automatically.

    log = Logger.get()
    pstr = f"Proc ({data.comm.world_rank}, {data.comm.group_rank})"
    msg = f"{pstr} {self} finalize"
    log.verbose(msg)

    # Are we running on the accelerator?
    pipe_accel = self._pipe_accel(use_accel)

    # run finalize on all the operators in the pipeline
    # NOTE: this might produce some output products
    result = list()
    if self.operators is not None:
        for op in self.operators:
            # Did we set use_accel to true when running with this operator
            use_accel_op = pipe_accel and op.supports_accel()
            result.append(op.finalize(data, use_accel=use_accel_op, **kwargs))

    # get outputs back and clean up data
    # if we are in charge of the data movement
    if self._staged_data is not None:
        # get outputs back from device
        provides = SetDict(self.provides())
        provides &= self._staged_data  # intersection
        log.verbose(f"{pstr} {self} copying out accel data outputs: {provides}")
        data.accel_update_host(provides)
        # deleting all data on device
        log.verbose(f"{pstr} {self} deleting accel data: {self._staged_data}")
        data.accel_delete(self._staged_data)
        self._staged_data = None
        self._unstaged_data = None

    return result

_implementations()

Find implementations supported by all the operators

Source code in toast/ops/pipeline.py (lines 352-364)
def _implementations(self):
    """
    Find implementations supported by all the operators
    """
    implementations = {
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    }
    for op in self.operators:
        implementations.intersection_update(op.implementations())
    return list(implementations)

_pipe_accel(use_accel)

Source code in toast/ops/pipeline.py (lines 304-315)
def _pipe_accel(self, use_accel):
    if (use_accel is None) and accel_enabled():
        # Only allow hybrid pipelines if the environment variable and the
        # pipeline trait both agree to it (each defaults to True)
        use_hybrid = self.use_hybrid and use_hybrid_pipelines
        # Can we run this pipeline on the accelerator?
        supports_accel = (
            self._supports_accel_partial() if use_hybrid else self._supports_accel()
        )
        return supports_accel
    else:
        return use_accel

_provides()

Work through the operator list and prune intermediate products (that are provided to an intermediate operator). FIXME: could a final result also be used by an intermediate operator?

Source code in toast/ops/pipeline.py (lines 334-350)
def _provides(self):
    """
    Work through the operator list and prune intermediate products
    (that are provided to an intermediate operator).
    FIXME could a final result also be used by an intermediate operator?
    """
    # constructs the union of the provides minus the requires
    prov = SetDict(
        {key: set() for key in ["global", "meta", "detdata", "shared", "intervals"]}
    )
    for op in self.operators:
        # remove requires first as there can be an overlap between provides and requires
        prov -= op.requires()
        prov |= op.provides()
    # converts into a dictionary of lists
    prov = {k: list(v) for (k, v) in prov.items()}
    return prov

_requires()

Work through the operator list in reverse order and prune intermediate products (that will be provided by a previous operator).

Source code in toast/ops/pipeline.py (lines 317-332)
def _requires(self):
    """
    Work through the operator list in reverse order and prune intermediate products
    (that will be provided by a previous operator).
    """
    # constructs the union of the requires minus the provides (in reverse order)
    req = SetDict(
        {key: set() for key in ["global", "meta", "detdata", "shared", "intervals"]}
    )
    for op in reversed(self.operators):
        # remove provides first as there can be an overlap between provides and requires
        req -= op.provides()
        req |= op.requires()
    # converts into a dictionary of lists
    req = {k: list(v) for (k, v) in req.items()}
    return req
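
A worked example of the reverse pruning, with hypothetical product keys: if op1 provides "noise_model" and op2 requires "noise_model" and "signal", walking the list backwards leaves only "signal" as an external requirement.

# Reverse-order pruning of intermediate products (hypothetical keys).
ops_provides = [{"noise_model"}, set()]            # op1, op2
ops_requires = [set(), {"noise_model", "signal"}]  # op1, op2

req = set()
for prov, reqs in zip(reversed(ops_provides), reversed(ops_requires)):
    req -= prov  # drop products created earlier in the pipeline
    req |= reqs
print(req)  # {'signal'}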

_supports_accel()

Returns True if all the operators are accelerator compatible.

Source code in toast/ops/pipeline.py (lines 366-373)
def _supports_accel(self):
    """
    Returns True if *all* the operators are accelerator compatible.
    """
    for op in self.operators:
        if not op.supports_accel():
            return False
    return True

_supports_accel_partial()

Returns True if at least one of the operators is accelerator compatible.

Source code in toast/ops/pipeline.py (lines 375-382)
def _supports_accel_partial(self):
    """
    Returns True if *at least one* of the operators is accelerator compatible.
    """
    for op in self.operators:
        if op.supports_accel():
            return True
    return False

toast.ops.RunSpt3g

Bases: Operator

Operator which runs a G3Pipeline.

This operator converts each observation to a stream of frames on each process and then runs the specified G3 pipeline on the local frames. If the obs_import trait is specified, the resulting frames are re-imported to a toast observation at the end.
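
A hedged configuration sketch (requires toast built with spt3g support; the export/import class names below are placeholders, and spt3g.core.Dump is used only as an example module):

# Run a G3 pipeline over each observation's frames.
import toast.ops
from spt3g import core as c3g
from my_project.spt3g_io import MyExport, MyImport  # hypothetical classes

runner = toast.ops.RunSpt3g(
    obs_export=MyExport(),     # creates frames from an observation
    obs_import=MyImport(),     # rebuilds an observation from frames
    modules=[(c3g.Dump, {})],  # (callable, kwargs) per the trait help
)
runner.apply(data)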

Source code in toast/ops/run_spt3g.py (lines 20-135)
@trait_docs
class RunSpt3g(Operator):
    """Operator which runs a G3Pipeline.

    This operator converts each observation to a stream of frames on each process
    and then runs the specified G3 pipeline on the local frames.  If the `obs_import`
    trait is specified, the resulting frames are re-imported to a toast observation
    at the end.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    obs_export = Instance(
        klass=object,
        allow_none=True,
        help="Export class to create frames from an observation",
    )

    obs_import = Instance(
        klass=object,
        allow_none=True,
        help="Import class to create observations from frame files",
    )

    modules = List(
        [],
        help="List of tuples of (callable, **kwargs) that will passed to G3Pipeline.Add()",
    )

    @traitlets.validate("obs_export")
    def _check_export(self, proposal):
        ex = proposal["value"]
        if ex is not None:
            # Check that this class is callable.
            if not callable(ex):
                raise traitlets.TraitError("obs_export class must be callable")
        return ex

    @traitlets.validate("obs_import")
    def _check_import(self, proposal):
        im = proposal["value"]
        if im is not None:
            # Check that this class is callable.
            if not callable(im):
                raise traitlets.TraitError("obs_import class must be callable")
        return im

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not available:
            raise RuntimeError("spt3g is not available")

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        # Check that the export class is set
        if self.obs_export is None:
            raise RuntimeError(
                "You must set the obs_export trait before calling exec()"
            )

        # Check that the import class is set
        if self.obs_import is None:
            raise RuntimeError(
                "You must set the obs_import trait before calling exec()"
            )

        if len(self.modules) == 0:
            log.debug_rank(
                "No modules specified, nothing to do.", comm=data.comm.comm_world
            )
            return

        n_obs = len(data.obs)

        for iobs in range(n_obs):
            ob = data.obs[iobs]

            # Export observation to frames on all processes
            frames = self.obs_export(ob)

            # Helper class that emits frames
            emitter = frame_emitter(frames=frames)

            # Optional frame collection afterwards
            collector = frame_collector()

            # Set up pipeline
            run_pipe = c3g.G3Pipeline()
            run_pipe.Add(emitter)
            for callable, args in self.modules:
                run_pipe.Add(callable, args)
            if self.obs_import is not None:
                run_pipe.Add(collector)

            # Run it
            run_pipe.Run()

            # Optionally convert back and replace the input observation
            if self.obs_import is not None:
                data.obs[iobs] = self.obs_import(collector.frames)

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        return dict()

    def _provides(self):
        return dict()

API = Int(0, help='Internal interface version for this operator')

modules = List([], help='List of tuples of (callable, **kwargs) that will be passed to G3Pipeline.Add()')

obs_export = Instance(klass=object, allow_none=True, help='Export class to create frames from an observation')

obs_import = Instance(klass=object, allow_none=True, help='Import class to create observations from frame files')

__init__(**kwargs)

Source code in toast/ops/run_spt3g.py (lines 70-73)
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    if not available:
        raise RuntimeError("spt3g is not available")

_check_export(proposal)

Source code in toast/ops/run_spt3g.py (lines 52-59)
@traitlets.validate("obs_export")
def _check_export(self, proposal):
    ex = proposal["value"]
    if ex is not None:
        # Check that this class is callable.
        if not callable(ex):
            raise traitlets.TraitError("obs_export class must be callable")
    return ex

_check_import(proposal)

Source code in toast/ops/run_spt3g.py (lines 61-68)
@traitlets.validate("obs_import")
def _check_import(self, proposal):
    im = proposal["value"]
    if im is not None:
        # Check that this class is callable.
        if not callable(im):
            raise traitlets.TraitError("obs_import class must be callable")
    return im

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/run_spt3g.py (lines 75-126)
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    # Check that the export class is set
    if self.obs_export is None:
        raise RuntimeError(
            "You must set the obs_export trait before calling exec()"
        )

    # Check that the import class is set
    if self.obs_import is None:
        raise RuntimeError(
            "You must set the obs_import trait before calling exec()"
        )

    if len(self.modules) == 0:
        log.debug_rank(
            "No modules specified, nothing to do.", comm=data.comm.comm_world
        )
        return

    n_obs = len(data.obs)

    for iobs in range(n_obs):
        ob = data.obs[iobs]

        # Export observation to frames on all processes
        frames = self.obs_export(ob)

        # Helper class that emits frames
        emitter = frame_emitter(frames=frames)

        # Optional frame collection afterwards
        collector = frame_collector()

        # Set up pipeline
        run_pipe = c3g.G3Pipeline()
        run_pipe.Add(emitter)
        for callable, args in self.modules:
            run_pipe.Add(callable, args)
        if self.obs_import is not None:
            run_pipe.Add(collector)

        # Run it
        run_pipe.Run()

        # Optionally convert back and replace the input observation
        if self.obs_import is not None:
            data.obs[iobs] = self.obs_import(collector.frames)

    return

_finalize(data, **kwargs)

Source code in toast/ops/run_spt3g.py (lines 128-129)
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/run_spt3g.py (lines 134-135)
def _provides(self):
    return dict()

_requires()

Source code in toast/ops/run_spt3g.py (lines 131-132)
def _requires(self):
    return dict()

Flagging

toast.ops.AzimuthIntervals

Bases: Operator

Build intervals that describe the scanning motion in azimuth.

This operator passes through the azimuth angle and builds the list of intervals for standard types of scanning / turnaround motion. Note that it only makes sense to use this operator for ground-based telescopes that primarily scan in azimuth rather than in more complicated (e.g. Lissajous) patterns.
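
A hedged usage sketch (the shared key names shown are the typical defaults, and `data` is assumed to hold ground-based observations):

# Flag scans and turnarounds from the boresight azimuth signal.
import toast.ops

az_ivals = toast.ops.AzimuthIntervals(
    times="times",          # shared timestamps key
    azimuth="azimuth",      # shared azimuth key
    window_seconds=0.5,     # smoothing window applied to the azimuth
    cut_short=True,         # drop very short scan intervals
    cut_long=True,          # drop very long scan intervals
)
az_ivals.apply(data)
# Each observation then carries interval lists (for example scanning and
# turnaround intervals) under the names given by the *_interval traits.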

Source code in toast/ops/azimuth_intervals.py (lines 25-581)
@trait_docs
class AzimuthIntervals(Operator):
    """Build intervals that describe the scanning motion in azimuth.

    This operator passes through the azimuth angle and builds the list of
    intervals for standard types of scanning / turnaround motion.  Note
    that it only makes sense to use this operator for ground-based
    telescopes that primarily scan in azimuth rather than more complicated (e.g.
    lissajous) patterns.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    azimuth = Unicode(defaults.azimuth, help="Observation shared key for Azimuth")

    cut_short = Bool(True, help="If True, remove very short scanning intervals")

    cut_long = Bool(True, help="If True, remove very long scanning intervals")

    short_limit = Quantity(
        0.25 * u.dimensionless_unscaled,
        help="Minimum length of a scan.  Either the minimum length in time or a "
        "fraction of median scan length",
    )

    long_limit = Quantity(
        1.25 * u.dimensionless_unscaled,
        help="Maximum length of a scan.  Either the maximum length in time or a "
        "fraction of median scan length",
    )

    scanning_interval = Unicode(
        defaults.scanning_interval, help="Interval name for scanning"
    )

    turnaround_interval = Unicode(
        defaults.turnaround_interval, help="Interval name for turnarounds"
    )

    throw_leftright_interval = Unicode(
        defaults.throw_leftright_interval,
        help="Interval name for left to right scans + turnarounds",
    )

    throw_rightleft_interval = Unicode(
        defaults.throw_rightleft_interval,
        help="Interval name for right to left scans + turnarounds",
    )

    throw_interval = Unicode(
        defaults.throw_interval, help="Interval name for scan + turnaround intervals"
    )

    scan_leftright_interval = Unicode(
        defaults.scan_leftright_interval, help="Interval name for left to right scans"
    )

    turn_leftright_interval = Unicode(
        defaults.turn_leftright_interval,
        help="Interval name for turnarounds after left to right scans",
    )

    scan_rightleft_interval = Unicode(
        defaults.scan_rightleft_interval, help="Interval name for right to left scans"
    )

    turn_rightleft_interval = Unicode(
        defaults.turn_rightleft_interval,
        help="Interval name for turnarounds after right to left scans",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for bad azimuth pointing",
    )

    window_seconds = Float(0.5, help="Smoothing window in seconds")

    debug_root = Unicode(
        None,
        allow_none=True,
        help="If not None, dump debug plots to this root file name",
    )

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a non-negative integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        env = Environment.get()
        log = Logger.get()

        for obs in data.obs:
            # For now, we just have the first process row do the calculation.  It
            # is relatively fast.

            throw_times = None
            throw_leftright_times = None
            throw_rightleft_times = None
            stable_times = None
            stable_leftright_times = None
            stable_rightleft_times = None
            have_scanning = True

            # Sample rate
            stamps = obs.shared[self.times].data
            (rate, dt, dt_min, dt_max, dt_std) = rate_from_times(stamps)

            # Smoothing window in samples
            window = int(rate * self.window_seconds)

            if obs.comm_col_rank == 0:
                # The azimuth angle
                azimuth = np.array(obs.shared[self.azimuth].data)

                # The azimuth flags
                flags = np.array(obs.shared[self.shared_flags].data)
                flags &= self.shared_flag_mask

                # Scan velocity
                scan_vel = self._gradient(azimuth, window, flags=flags)

                # The peak to peak range of the scan velocity
                vel_range = np.amax(scan_vel) - np.amin(scan_vel)

                # Scan acceleration
                scan_accel = self._gradient(scan_vel, window)

                # Peak to peak acceleration range
                accel_range = np.amax(scan_accel) - np.amin(scan_accel)

                # When the acceleration is zero to some tolerance, we are
                # scanning.  However, we also need to only consider times where
                # the velocity is non-zero.
                stable = (np.absolute(scan_accel) < 0.1 * accel_range) * np.ones(
                    len(scan_accel), dtype=np.int8
                )
                stable *= np.absolute(scan_vel) > 0.1 * vel_range

                # The first estimate of the samples where stable pointing
                # begins and ends.
                begin_stable = np.where(stable[1:] - stable[:-1] == 1)[0]
                end_stable = np.where(stable[:-1] - stable[1:] == 1)[0]

                if len(begin_stable) == 0 or len(end_stable) == 0:
                    msg = f"Observation {obs.name} has no stable scanning"
                    msg += f" periods.  You should cut this observation or"
                    msg += f" change the filter window.  Flagging all samples"
                    msg += f" as unstable pointing."
                    log.warning(msg)
                    have_scanning = False

                if have_scanning:
                    # Refine our list of stable periods
                    if begin_stable[0] > end_stable[0]:
                        # We start in the middle of a scan
                        begin_stable = np.concatenate(([0], begin_stable))
                    if begin_stable[-1] > end_stable[-1]:
                        # We end in the middle of a scan
                        end_stable = np.concatenate((end_stable, [obs.n_local_samples]))

                    # In some situations there are very short stable scans detected at
                    # the beginning and end of observations.  Here we cut any short
                    # throw and stable periods.
                    cut_threshold = 4
                    if (self.cut_short or self.cut_long) and (
                        len(begin_stable) >= cut_threshold
                    ):
                        if self.cut_short:
                            stable_timespans = np.array(
                                [
                                    stamps[y - 1] - stamps[x]
                                    for x, y in zip(begin_stable, end_stable)
                                ]
                            )
                            try:
                                # First try short limit as time
                                stable_bad = (
                                    stable_timespans < self.short_limit.to_value(u.s)
                                )
                            except u.UnitConversionError:
                                # Try short limit as fraction
                                median_stable = np.median(stable_timespans)
                                stable_bad = (
                                    stable_timespans < self.short_limit * median_stable
                                )
                            begin_stable = np.array(
                                [x for (x, y) in zip(begin_stable, stable_bad) if not y]
                            )
                            end_stable = np.array(
                                [x for (x, y) in zip(end_stable, stable_bad) if not y]
                            )
                        if self.cut_long:
                            stable_timespans = np.array(
                                [
                                    stamps[y - 1] - stamps[x]
                                    for x, y in zip(begin_stable, end_stable)
                                ]
                            )
                            try:
                                # First try long limit as time
                                stable_bad = (
                                    stable_timespans > self.long_limit.to_value(u.s)
                                )
                            except u.UnitConversionError:
                                # Try long limit as fraction
                                median_stable = np.median(stable_timespans)
                                stable_bad = (
                                    stable_timespans > self.long_limit * median_stable
                                )
                            begin_stable = np.array(
                                [x for (x, y) in zip(begin_stable, stable_bad) if not y]
                            )
                            end_stable = np.array(
                                [x for (x, y) in zip(end_stable, stable_bad) if not y]
                            )
                    if len(begin_stable) == 0:
                        have_scanning = False

                # The "throw" intervals extend from one turnaround to the next.
                # We start the first throw at the beginning of the first stable scan
                # and then find the sample between stable scans where the turnaround
                # happens.  This reduces false detections of turnarounds before or
                # after the stable scanning within the observation.
                #
                # If no turnaround is found between stable scans, we log a warning
                # and choose the sample midway between stable scans to be the throw
                # boundary.
                if have_scanning:
                    begin_throw = [begin_stable[0]]
                    end_throw = list()
                    vel_switch = list()
                    for start_turn, end_turn in zip(end_stable[:-1], begin_stable[1:]):
                        # Fit a polynomial to the velocity and find the zero-crossing sample
                        vel_turn = self._find_turnaround(scan_vel[start_turn:end_turn])
                        if vel_turn is None:
                            msg = f"{obs.name}: Turnaround not found between"
                            msg += " end of stable scan at"
                            msg += f" sample {start_turn} and next start at"
                            msg += f" {end_turn}. Selecting midpoint as turnaround."
                            log.warning(msg)
                            half_gap = (end_turn - start_turn) // 2
                            end_throw.append(start_turn + half_gap)
                        else:
                            end_throw.append(start_turn + vel_turn)
                        vel_switch.append(end_throw[-1])
                        begin_throw.append(end_throw[-1] + 1)
                    end_throw.append(end_stable[-1])
                    begin_throw = np.array(begin_throw)
                    end_throw = np.array(end_throw)
                    vel_switch = np.array(vel_switch)

                    stable_times = [
                        (stamps[x[0]], stamps[x[1]])
                        for x in zip(begin_stable, end_stable)
                    ]
                    throw_times = [
                        (stamps[x[0]], stamps[x[1]])
                        for x in zip(begin_throw, end_throw)
                    ]

                    throw_leftright_times = list()
                    throw_rightleft_times = list()
                    stable_leftright_times = list()
                    stable_rightleft_times = list()

                    # Split scans into left and right-going intervals
                    for iscan, (first, last) in enumerate(
                        zip(begin_stable, end_stable)
                    ):
                        # Check the velocity at the middle of the scan
                        mid = first + (last - first) // 2
                        if scan_vel[mid] >= 0:
                            stable_leftright_times.append(stable_times[iscan])
                            throw_leftright_times.append(throw_times[iscan])
                        else:
                            stable_rightleft_times.append(stable_times[iscan])
                            throw_rightleft_times.append(throw_times[iscan])

                if self.debug_root is not None:
                    set_matplotlib_backend()

                    import matplotlib.pyplot as plt

                    # Dump some plots
                    out_file = f"{self.debug_root}_{obs.name}_{obs.comm_row_rank}.pdf"
                    if have_scanning:
                        if len(end_throw) >= 5:
                            # Plot a few scans
                            plot_start = 0
                            n_plot = end_throw[4]
                        else:
                            # Plot it all
                            plot_start = 0
                            n_plot = obs.n_local_samples
                        pslc = slice(plot_start, plot_start + n_plot, 1)
                        px = np.arange(plot_start, plot_start + n_plot, 1)

                        swplot = vel_switch[
                            np.logical_and(
                                vel_switch <= plot_start + n_plot,
                                vel_switch >= plot_start,
                            )
                        ]
                        bstable = begin_stable[
                            np.logical_and(
                                begin_stable <= plot_start + n_plot,
                                begin_stable >= plot_start,
                            )
                        ]
                        estable = end_stable[
                            np.logical_and(
                                end_stable <= plot_start + n_plot,
                                end_stable >= plot_start,
                            )
                        ]
                        bthrow = begin_throw[
                            np.logical_and(
                                begin_throw <= plot_start + n_plot,
                                begin_throw >= plot_start,
                            )
                        ]
                        ethrow = end_throw[
                            np.logical_and(
                                end_throw <= plot_start + n_plot,
                                end_throw >= plot_start,
                            )
                        ]

                        fig = plt.figure(dpi=100, figsize=(8, 16))

                        ax = fig.add_subplot(4, 1, 1)
                        ax.plot(px, azimuth[pslc], "-", label="Azimuth")
                        ax.legend(loc="best")
                        ax.set_xlabel("Samples")
                        ax.set_ylabel("Azimuth (Radians)")

                        ax = fig.add_subplot(4, 1, 2)
                        ax.plot(px, stable[pslc], "-", label="Stable Pointing")
                        ax.plot(px, flags[pslc], color="black", label="Flags")
                        ax.vlines(
                            bstable,
                            ymin=-1,
                            ymax=2,
                            color="green",
                            label="Begin Stable",
                        )
                        ax.vlines(
                            estable, ymin=-1, ymax=2, color="red", label="End Stable"
                        )
                        ax.vlines(
                            bthrow, ymin=-2, ymax=1, color="cyan", label="Begin Throw"
                        )
                        ax.vlines(
                            ethrow, ymin=-2, ymax=1, color="purple", label="End Throw"
                        )
                        ax.legend(loc="best")
                        ax.set_xlabel("Samples")
                        ax.set_ylabel("Stable Scan / Throw")

                        ax = fig.add_subplot(4, 1, 3)
                        ax.plot(px, scan_vel[pslc], "-", label="Velocity")
                        ax.vlines(
                            swplot,
                            ymin=np.amin(scan_vel),
                            ymax=np.amax(scan_vel),
                            color="red",
                            label="Velocity Switch",
                        )
                        ax.legend(loc="best")
                        ax.set_xlabel("Samples")
                        ax.set_ylabel("Scan Velocity (Radians / s)")

                        ax = fig.add_subplot(4, 1, 4)
                        ax.plot(px, scan_accel[pslc], "-", label="Acceleration")
                        ax.legend(loc="best")
                        ax.set_xlabel("Samples")
                        ax.set_ylabel("Scan Acceleration")
                    else:
                        n_plot = obs.n_local_samples
                        fig = plt.figure(dpi=100, figsize=(8, 12))

                        ax = fig.add_subplot(3, 1, 1)
                        ax.plot(
                            np.arange(n_plot),
                            azimuth[:n_plot],
                            "-",
                        )
                        ax.set_xlabel("Samples")
                        ax.set_ylabel("Azimuth")

                        ax = fig.add_subplot(3, 1, 2)
                        ax.plot(np.arange(n_plot), scan_vel[:n_plot], "-")
                        ax.vlines(
                            swplot,
                            ymin=np.amin(scan_vel),
                            ymax=np.amax(scan_vel),
                        )
                        ax.set_xlabel("Samples")
                        ax.set_ylabel("Scan Velocity")

                        ax = fig.add_subplot(3, 1, 3)
                        ax.plot(np.arange(n_plot), scan_accel[:n_plot], "-")
                        ax.set_xlabel("Samples")
                        ax.set_ylabel("Scan Acceleration")
                    plt.savefig(out_file)
                    plt.close()

            # Now create the intervals across each process column
            if obs.comm_col is not None:
                have_scanning = obs.comm_col.bcast(have_scanning, root=0)

            if have_scanning:
                # The throw intervals are between turnarounds
                obs.intervals.create_col(
                    self.throw_interval, throw_times, stamps, fromrank=0
                )
                obs.intervals.create_col(
                    self.throw_leftright_interval,
                    throw_leftright_times,
                    stamps,
                    fromrank=0,
                )
                obs.intervals.create_col(
                    self.throw_rightleft_interval,
                    throw_rightleft_times,
                    stamps,
                    fromrank=0,
                )

                # Stable scanning intervals
                obs.intervals.create_col(
                    self.scanning_interval, stable_times, stamps, fromrank=0
                )
                obs.intervals.create_col(
                    self.scan_leftright_interval,
                    stable_leftright_times,
                    stamps,
                    fromrank=0,
                )
                obs.intervals.create_col(
                    self.scan_rightleft_interval,
                    stable_rightleft_times,
                    stamps,
                    fromrank=0,
                )

                # Turnarounds are the inverse of stable scanning
                obs.intervals[self.turnaround_interval] = ~obs.intervals[
                    self.scanning_interval
                ]
            else:
                # Flag all samples as unstable
                if self.shared_flags not in obs.shared:
                    obs.shared.create_column(
                        self.shared_flags,
                        shape=(obs.n_local_samples,),
                        dtype=np.uint8,
                    )
                if obs.comm_col_rank == 0:
                    obs.shared[self.shared_flags].set(
                        np.zeros_like(obs.shared[self.shared_flags].data),
                        offset=(0,),
                        fromrank=0,
                    )
                else:
                    obs.shared[self.shared_flags].set(None, offset=(0,), fromrank=0)

        # Additionally flag turnarounds as unstable pointing
        flag_intervals = FlagIntervals(
            shared_flags=self.shared_flags,
            shared_flag_bytes=1,
            view_mask=[
                (self.turnaround_interval, defaults.shared_mask_unstable_scanrate),
            ],
        )
        flag_intervals.apply(data, detectors=None)

    def _find_turnaround(self, vel):
        """Fit a polynomial and find the turnaround sample."""
        x = np.arange(len(vel))
        fit_poly = np.polynomial.polynomial.Polynomial.fit(x, vel, 5)
        fit_vel = fit_poly(x)
        vel_switch = np.where(fit_vel[:-1] * fit_vel[1:] < 0)[0]
        if len(vel_switch) != 1:
            return None
        else:
            return vel_switch[0]

    def _gradient(self, data, window, flags=None):
        """Compute the numerical derivative with smoothing.

        Args:
            data (array):  The local data buffer to process.
            window (int):  The number of samples in the smoothing window.
            flags (array):  The optional array of sample flags.

        Returns:
            (array):  The result.

        """
        if flags is not None:
            # Fill flags with noise
            flagged_noise_fill(data, flags, window // 4, poly_order=5)
        # Smooth the data
        smoothed = uniform_filter1d(
            data,
            size=window,
            mode="nearest",
        )
        # Derivative
        result = np.gradient(smoothed)
        return result

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "shared": [self.times, self.azimuth],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        return req

    def _provides(self):
        return {
            "intervals": [
                self.scanning_interval,
                self.turnaround_interval,
                self.scan_leftright_interval,
                self.scan_rightleft_interval,
                self.turn_leftright_interval,
                self.turn_rightleft_interval,
                self.throw_interval,
                self.throw_leftright_interval,
                self.throw_rightleft_interval,
            ]
        }
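
To make the thresholding scheme above concrete, here is a self-contained NumPy / SciPy sketch of the same detection on a synthetic triangle-wave azimuth scan. This is an illustration of the technique only, not the TOAST API:

import numpy as np
from scipy.ndimage import uniform_filter1d

rate = 100.0  # assumed sample rate in Hz
t = np.arange(0.0, 60.0, 1.0 / rate)
# Triangle-like azimuth scan with a 10 second period
az = 0.2 * np.arcsin(np.sin(2.0 * np.pi * t / 10.0))

window = int(rate * 0.5)  # 0.5 s smoothing window, as in window_seconds
vel = np.gradient(uniform_filter1d(az, size=window, mode="nearest"))
accel = np.gradient(uniform_filter1d(vel, size=window, mode="nearest"))

# Stable scanning: acceleration near zero AND velocity away from zero
stable = (np.abs(accel) < 0.1 * np.ptp(accel)) & (np.abs(vel) > 0.1 * np.ptp(vel))
begin_stable = np.where(~stable[:-1] & stable[1:])[0] + 1
end_stable = np.where(stable[:-1] & ~stable[1:])[0] + 1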

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

azimuth = Unicode(defaults.azimuth, help='Observation shared key for Azimuth') class-attribute instance-attribute

cut_long = Bool(True, help='If True, remove very long scanning intervals') class-attribute instance-attribute

cut_short = Bool(True, help='If True, remove very short scanning intervals') class-attribute instance-attribute

debug_root = Unicode(None, allow_none=True, help='If not None, dump debug plots to this root file name') class-attribute instance-attribute

long_limit = Quantity(1.25 * u.dimensionless_unscaled, help='Maximum length of a scan. Either the maximum length in time or a fraction of median scan length') class-attribute instance-attribute

scan_leftright_interval = Unicode(defaults.scan_leftright_interval, help='Interval name for left to right scans') class-attribute instance-attribute

scan_rightleft_interval = Unicode(defaults.scan_rightleft_interval, help='Interval name for right to left scans') class-attribute instance-attribute

scanning_interval = Unicode(defaults.scanning_interval, help='Interval name for scanning') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for bad azimuth pointing') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

short_limit = Quantity(0.25 * u.dimensionless_unscaled, help='Minimum length of a scan. Either the minimum length in time or a fraction of median scan length') class-attribute instance-attribute

throw_interval = Unicode(defaults.throw_interval, help='Interval name for scan + turnaround intervals') class-attribute instance-attribute

throw_leftright_interval = Unicode(defaults.throw_leftright_interval, help='Interval name for left to right scans + turnarounds') class-attribute instance-attribute

throw_rightleft_interval = Unicode(defaults.throw_rightleft_interval, help='Interval name for right to left scans + turnarounds') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

turn_leftright_interval = Unicode(defaults.turn_leftright_interval, help='Interval name for turnarounds after left to right scans') class-attribute instance-attribute

turn_rightleft_interval = Unicode(defaults.turn_rightleft_interval, help='Interval name for turnarounds after right to left scans') class-attribute instance-attribute

turnaround_interval = Unicode(defaults.turnaround_interval, help='Interval name for turnarounds') class-attribute instance-attribute

window_seconds = Float(0.5, help='Smoothing window in seconds') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/azimuth_intervals.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_shared_flag_mask(proposal)

Source code in toast/ops/azimuth_intervals.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a non-negative integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/azimuth_intervals.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    env = Environment.get()
    log = Logger.get()

    for obs in data.obs:
        # For now, we just have the first process row do the calculation.  It
        # is relatively fast.

        throw_times = None
        throw_leftright_times = None
        throw_rightleft_times = None
        stable_times = None
        stable_leftright_times = None
        stable_rightleft_times = None
        have_scanning = True

        # Sample rate
        stamps = obs.shared[self.times].data
        (rate, dt, dt_min, dt_max, dt_std) = rate_from_times(stamps)

        # Smoothing window in samples
        window = int(rate * self.window_seconds)

        if obs.comm_col_rank == 0:
            # The azimuth angle
            azimuth = np.array(obs.shared[self.azimuth].data)

            # The azimuth flags
            flags = np.array(obs.shared[self.shared_flags].data)
            flags &= self.shared_flag_mask

            # Scan velocity
            scan_vel = self._gradient(azimuth, window, flags=flags)

            # The peak to peak range of the scan velocity
            vel_range = np.amax(scan_vel) - np.amin(scan_vel)

            # Scan acceleration
            scan_accel = self._gradient(scan_vel, window)

            # Peak to peak acceleration range
            accel_range = np.amax(scan_accel) - np.amin(scan_accel)

            # When the acceleration is zero to some tolerance, we are
            # scanning.  However, we also need to only consider times where
            # the velocity is non-zero.
            stable = (np.absolute(scan_accel) < 0.1 * accel_range) * np.ones(
                len(scan_accel), dtype=np.int8
            )
            stable *= np.absolute(scan_vel) > 0.1 * vel_range

            # The first estimate of the samples where stable pointing
            # begins and ends.
            begin_stable = np.where(stable[1:] - stable[:-1] == 1)[0]
            end_stable = np.where(stable[:-1] - stable[1:] == 1)[0]

            if len(begin_stable) == 0 or len(end_stable) == 0:
                msg = f"Observation {obs.name} has no stable scanning"
                msg += f" periods.  You should cut this observation or"
                msg += f" change the filter window.  Flagging all samples"
                msg += f" as unstable pointing."
                log.warning(msg)
                have_scanning = False

            if have_scanning:
                # Refine our list of stable periods
                if begin_stable[0] > end_stable[0]:
                    # We start in the middle of a scan
                    begin_stable = np.concatenate(([0], begin_stable))
                if begin_stable[-1] > end_stable[-1]:
                    # We end in the middle of a scan
                    end_stable = np.concatenate((end_stable, [obs.n_local_samples]))

                # In some situations there are very short stable scans detected at
                # the beginning and end of observations.  Here we cut any short
                # throw and stable periods.
                cut_threshold = 4
                if (self.cut_short or self.cut_long) and (
                    len(begin_stable) >= cut_threshold
                ):
                    if self.cut_short:
                        stable_timespans = np.array(
                            [
                                stamps[y - 1] - stamps[x]
                                for x, y in zip(begin_stable, end_stable)
                            ]
                        )
                        try:
                            # First try short limit as time
                            stable_bad = (
                                stable_timespans < self.short_limit.to_value(u.s)
                            )
                        except u.UnitConversionError:
                            # Try short limit as fraction
                            median_stable = np.median(stable_timespans)
                            stable_bad = (
                                stable_timespans < self.short_limit * median_stable
                            )
                        begin_stable = np.array(
                            [x for (x, y) in zip(begin_stable, stable_bad) if not y]
                        )
                        end_stable = np.array(
                            [x for (x, y) in zip(end_stable, stable_bad) if not y]
                        )
                    if self.cut_long:
                        stable_timespans = np.array(
                            [
                                stamps[y - 1] - stamps[x]
                                for x, y in zip(begin_stable, end_stable)
                            ]
                        )
                        try:
                            # First try long limit as time
                            stable_bad = (
                                stable_timespans > self.long_limit.to_value(u.s)
                            )
                        except u.UnitConversionError:
                            # Try long limit as fraction
                            median_stable = np.median(stable_timespans)
                            stable_bad = (
                                stable_timespans > self.long_limit * median_stable
                            )
                        begin_stable = np.array(
                            [x for (x, y) in zip(begin_stable, stable_bad) if not y]
                        )
                        end_stable = np.array(
                            [x for (x, y) in zip(end_stable, stable_bad) if not y]
                        )
                if len(begin_stable) == 0:
                    have_scanning = False

            # The "throw" intervals extend from one turnaround to the next.
            # We start the first throw at the beginning of the first stable scan
            # and then find the sample between stable scans where the turnaround
            # happens.  This reduces false detections of turnarounds before or
            # after the stable scanning within the observation.
            #
            # If no turnaround is found between stable scans, we log a warning
            # and choose the sample midway between stable scans to be the throw
            # boundary.
            if have_scanning:
                begin_throw = [begin_stable[0]]
                end_throw = list()
                vel_switch = list()
                for start_turn, end_turn in zip(end_stable[:-1], begin_stable[1:]):
                    # Fit a polynomial to the velocity and find the zero-crossing sample
                    vel_turn = self._find_turnaround(scan_vel[start_turn:end_turn])
                    if vel_turn is None:
                        msg = f"{obs.name}: Turnaround not found between"
                        msg += " end of stable scan at"
                        msg += f" sample {start_turn} and next start at"
                        msg += f" {end_turn}. Selecting midpoint as turnaround."
                        log.warning(msg)
                        half_gap = (end_turn - start_turn) // 2
                        end_throw.append(start_turn + half_gap)
                    else:
                        end_throw.append(start_turn + vel_turn)
                    vel_switch.append(end_throw[-1])
                    begin_throw.append(end_throw[-1] + 1)
                end_throw.append(end_stable[-1])
                begin_throw = np.array(begin_throw)
                end_throw = np.array(end_throw)
                vel_switch = np.array(vel_switch)

                stable_times = [
                    (stamps[x[0]], stamps[x[1]])
                    for x in zip(begin_stable, end_stable)
                ]
                throw_times = [
                    (stamps[x[0]], stamps[x[1]])
                    for x in zip(begin_throw, end_throw)
                ]

                throw_leftright_times = list()
                throw_rightleft_times = list()
                stable_leftright_times = list()
                stable_rightleft_times = list()

                # Split scans into left and right-going intervals
                for iscan, (first, last) in enumerate(
                    zip(begin_stable, end_stable)
                ):
                    # Check the velocity at the middle of the scan
                    mid = first + (last - first) // 2
                    if scan_vel[mid] >= 0:
                        stable_leftright_times.append(stable_times[iscan])
                        throw_leftright_times.append(throw_times[iscan])
                    else:
                        stable_rightleft_times.append(stable_times[iscan])
                        throw_rightleft_times.append(throw_times[iscan])

            if self.debug_root is not None:
                set_matplotlib_backend()

                import matplotlib.pyplot as plt

                # Dump some plots
                out_file = f"{self.debug_root}_{obs.name}_{obs.comm_row_rank}.pdf"
                if have_scanning:
                    if len(end_throw) >= 5:
                        # Plot a few scans
                        plot_start = 0
                        n_plot = end_throw[4]
                    else:
                        # Plot it all
                        plot_start = 0
                        n_plot = obs.n_local_samples
                    pslc = slice(plot_start, plot_start + n_plot, 1)
                    px = np.arange(plot_start, plot_start + n_plot, 1)

                    swplot = vel_switch[
                        np.logical_and(
                            vel_switch <= plot_start + n_plot,
                            vel_switch >= plot_start,
                        )
                    ]
                    bstable = begin_stable[
                        np.logical_and(
                            begin_stable <= plot_start + n_plot,
                            begin_stable >= plot_start,
                        )
                    ]
                    estable = end_stable[
                        np.logical_and(
                            end_stable <= plot_start + n_plot,
                            end_stable >= plot_start,
                        )
                    ]
                    bthrow = begin_throw[
                        np.logical_and(
                            begin_throw <= plot_start + n_plot,
                            begin_throw >= plot_start,
                        )
                    ]
                    ethrow = end_throw[
                        np.logical_and(
                            end_throw <= plot_start + n_plot,
                            end_throw >= plot_start,
                        )
                    ]

                    fig = plt.figure(dpi=100, figsize=(8, 16))

                    ax = fig.add_subplot(4, 1, 1)
                    ax.plot(px, azimuth[pslc], "-", label="Azimuth")
                    ax.legend(loc="best")
                    ax.set_xlabel("Samples")
                    ax.set_ylabel("Azimuth (Radians)")

                    ax = fig.add_subplot(4, 1, 2)
                    ax.plot(px, stable[pslc], "-", label="Stable Pointing")
                    ax.plot(px, flags[pslc], color="black", label="Flags")
                    ax.vlines(
                        bstable,
                        ymin=-1,
                        ymax=2,
                        color="green",
                        label="Begin Stable",
                    )
                    ax.vlines(
                        estable, ymin=-1, ymax=2, color="red", label="End Stable"
                    )
                    ax.vlines(
                        bthrow, ymin=-2, ymax=1, color="cyan", label="Begin Throw"
                    )
                    ax.vlines(
                        ethrow, ymin=-2, ymax=1, color="purple", label="End Throw"
                    )
                    ax.legend(loc="best")
                    ax.set_xlabel("Samples")
                    ax.set_ylabel("Stable Scan / Throw")

                    ax = fig.add_subplot(4, 1, 3)
                    ax.plot(px, scan_vel[pslc], "-", label="Velocity")
                    ax.vlines(
                        swplot,
                        ymin=np.amin(scan_vel),
                        ymax=np.amax(scan_vel),
                        color="red",
                        label="Velocity Switch",
                    )
                    ax.legend(loc="best")
                    ax.set_xlabel("Samples")
                    ax.set_ylabel("Scan Velocity (Radians / s)")

                    ax = fig.add_subplot(4, 1, 4)
                    ax.plot(px, scan_accel[pslc], "-", label="Acceleration")
                    ax.legend(loc="best")
                    ax.set_xlabel("Samples")
                    ax.set_ylabel("Scan Acceleration")
                else:
                    n_plot = obs.n_local_samples
                    fig = plt.figure(dpi=100, figsize=(8, 12))

                    ax = fig.add_subplot(3, 1, 1)
                    ax.plot(
                        np.arange(n_plot),
                        azimuth[:n_plot],
                        "-",
                    )
                    ax.set_xlabel("Samples")
                    ax.set_ylabel("Azimuth")

                    ax = fig.add_subplot(3, 1, 2)
                    ax.plot(np.arange(n_plot), scan_vel[:n_plot], "-")
                    ax.vlines(
                        swplot,
                        ymin=np.amin(scan_vel),
                        ymax=np.amax(scan_vel),
                    )
                    ax.set_xlabel("Samples")
                    ax.set_ylabel("Scan Velocity")

                    ax = fig.add_subplot(3, 1, 3)
                    ax.plot(np.arange(n_plot), scan_accel[:n_plot], "-")
                    ax.set_xlabel("Samples")
                    ax.set_ylabel("Scan Acceleration")
                plt.savefig(out_file)
                plt.close()

        # Now create the intervals across each process column
        if obs.comm_col is not None:
            have_scanning = obs.comm_col.bcast(have_scanning, root=0)

        if have_scanning:
            # The throw intervals are between turnarounds
            obs.intervals.create_col(
                self.throw_interval, throw_times, stamps, fromrank=0
            )
            obs.intervals.create_col(
                self.throw_leftright_interval,
                throw_leftright_times,
                stamps,
                fromrank=0,
            )
            obs.intervals.create_col(
                self.throw_rightleft_interval,
                throw_rightleft_times,
                stamps,
                fromrank=0,
            )

            # Stable scanning intervals
            obs.intervals.create_col(
                self.scanning_interval, stable_times, stamps, fromrank=0
            )
            obs.intervals.create_col(
                self.scan_leftright_interval,
                stable_leftright_times,
                stamps,
                fromrank=0,
            )
            obs.intervals.create_col(
                self.scan_rightleft_interval,
                stable_rightleft_times,
                stamps,
                fromrank=0,
            )

            # Turnarounds are the inverse of stable scanning
            obs.intervals[self.turnaround_interval] = ~obs.intervals[
                self.scanning_interval
            ]
        else:
            # Flag all samples as unstable
            if self.shared_flags not in obs.shared:
                obs.shared.create_column(
                    self.shared_flags,
                    shape=(obs.n_local_samples,),
                    dtype=np.uint8,
                )
            if obs.comm_col_rank == 0:
                obs.shared[self.shared_flags].set(
                    np.zeros_like(obs.shared[self.shared_flags].data),
                    offset=(0,),
                    fromrank=0,
                )
            else:
                obs.shared[self.shared_flags].set(None, offset=(0,), fromrank=0)

    # Additionally flag turnarounds as unstable pointing
    flag_intervals = FlagIntervals(
        shared_flags=self.shared_flags,
        shared_flag_bytes=1,
        view_mask=[
            (self.turnaround_interval, defaults.shared_mask_unstable_scanrate),
        ],
    )
    flag_intervals.apply(data, detectors=None)

_finalize(data, **kwargs)

Source code in toast/ops/azimuth_intervals.py
def _finalize(self, data, **kwargs):
    return

_find_turnaround(vel)

Fit a polynomial and find the turnaround sample.

Source code in toast/ops/azimuth_intervals.py
def _find_turnaround(self, vel):
    """Fit a polynomial and find the turnaround sample."""
    x = np.arange(len(vel))
    fit_poly = np.polynomial.polynomial.Polynomial.fit(x, vel, 5)
    fit_vel = fit_poly(x)
    vel_switch = np.where(fit_vel[:-1] * fit_vel[1:] < 0)[0]
    if len(vel_switch) != 1:
        return None
    else:
        return vel_switch[0]
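
The sign-change test is easy to verify in isolation. A standalone sketch on synthetic turnaround data (illustration only, not the TOAST API):

import numpy as np

# One smooth sign change, as in a real turnaround
vel = np.linspace(0.8, -0.8, 50) + 0.02 * np.sin(np.arange(50))
x = np.arange(len(vel))
fit_vel = np.polynomial.polynomial.Polynomial.fit(x, vel, 5)(x)
# Samples where consecutive fitted values change sign
switch = np.where(fit_vel[:-1] * fit_vel[1:] < 0)[0]
# switch should contain a single sample near the midpoint; zero or
# multiple crossings would mean no clean turnaround was found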

_gradient(data, window, flags=None)

Compute the numerical derivative with smoothing.

Parameters:

    data (array): The local data buffer to process.  Required.
    window (int): The number of samples in the smoothing window.  Required.
    flags (array): The optional array of sample flags.  Default: None.

Returns:

    array: The result.

Source code in toast/ops/azimuth_intervals.py
def _gradient(self, data, window, flags=None):
    """Compute the numerical derivative with smoothing.

    Args:
        data (array):  The local data buffer to process.
        window (int):  The number of samples in the smoothing window.
        flags (array):  The optional array of sample flags.

    Returns:
        (array):  The result.

    """
    if flags is not None:
        # Fill flags with noise
        flagged_noise_fill(data, flags, window // 4, poly_order=5)
    # Smooth the data
    smoothed = uniform_filter1d(
        data,
        size=window,
        mode="nearest",
    )
    # Derivative
    result = np.gradient(smoothed)
    return result

_provides()

Source code in toast/ops/azimuth_intervals.py
def _provides(self):
    return {
        "intervals": [
            self.scanning_interval,
            self.turnaround_interval,
            self.scan_leftright_interval,
            self.scan_rightleft_interval,
            self.turn_leftright_interval,
            self.turn_rightleft_interval,
            self.throw_interval,
            self.throw_leftright_interval,
            self.throw_rightleft_interval,
        ]
    }

_requires()

Source code in toast/ops/azimuth_intervals.py
def _requires(self):
    req = {
        "shared": [self.times, self.azimuth],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    return req

toast.ops.FlagIntervals

Bases: Operator

Operator which updates shared flags from interval lists.

This operator can be used in cases where interval information needs to be combined with shared flags. The view_mask trait is a list of tuples; each tuple contains the name of the view (i.e. interval) to apply and the bit mask to use for that view. For each interval view, flag values in the shared_flags object are bitwise-OR'd with the specified mask for samples in the view. If the name of the view is prefixed with '~', the bit mask is applied to all samples outside the view.
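
A configuration sketch with one direct view and one inverted ('~') view. The interval and flag key names are placeholders:

# Sketch: OR bit value 1 into samples inside "turnaround" intervals, and
# bit value 2 into samples outside "scanning" intervals.  Names are hypothetical.
import toast.ops

flagger = toast.ops.FlagIntervals(
    shared_flags="flags",
    view_mask=[
        ("turnaround", 1),
        ("~scanning", 2),
    ],
)
flagger.apply(data)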

Source code in toast/ops/flag_intervals.py
@trait_docs
class FlagIntervals(Operator):
    """Operator which updates shared flags from interval lists.

    This operator can be used in cases where interval information needs to be combined
    with shared flags.  The view_mask trait is a list of tuples.  Each tuple contains
    the name of the view (i.e. interval) to apply and the bitmask to use for that
    view.  For each interval view, flag values in the shared_flags object are bitwise-
    OR'd with the specified mask for samples in the view.  If the name of the view is
    prefixed with '~' the bitmask is applied to all samples outside the view.
    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    view_mask = List([], help="List of tuples of (view name, bit mask)")

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_bytes = Int(
        1, help="If creating shared key, use this many bytes per sample"
    )

    reset = Bool(
        False, help="If True, flag bits are first set to 0 for the entire observation"
    )

    @traitlets.validate("shared_flag_bytes")
    def _check_flag_bytes(self, proposal):
        check = proposal["value"]
        if check not in [1, 2, 4, 8]:
            raise traitlets.TraitError("shared flag byte width should be 1, 2, 4, or 8")
        return check

    @traitlets.validate("view_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        for vname, vmask in check:
            if vmask < 0:
            raise traitlets.TraitError("Flag masks should be non-negative integers")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        if self.shared_flags is None:
            log.debug_rank(
                "shared_flags trait is None, nothing to do.", comm=data.comm.comm_world
            )
            return

        if self.view_mask is None or len(self.view_mask) == 0:
            log.debug_rank(
                "view_mask trait is empty or not set, nothing to do.",
                comm=data.comm.comm_world,
            )
            return

        fdtype = None
        if self.shared_flag_bytes == 8:
            fdtype = np.uint64
        elif self.shared_flag_bytes == 4:
            fdtype = np.uint32
        elif self.shared_flag_bytes == 2:
            fdtype = np.uint16
        else:
            fdtype = np.uint8

        for ob in data.obs:
            # If the shared flag object already exists, then use it with whatever
            # byte width is in place.  Otherwise create it.

            if self.shared_flags not in ob.shared:
                ob.shared.create_column(
                    self.shared_flags,
                    shape=(ob.n_local_samples,),
                    dtype=fdtype,
                )

            # The intervals / view is common between all processes in a column of the
            # process grid.  Only the rank zero process in each column builds the new
            # flags for the synchronous call to the set() method.  Note that views
            # of shared data are read-only, so we build the full flag vector and only
            # modify samples inside the view.

            new_flags = None
            if ob.comm_col_rank == 0:
                new_flags = np.array(ob.shared[self.shared_flags])
                if self.reset:
                    for vname, vmask in self.view_mask:
                        new_flags &= ~vmask
                for vname, vmask in self.view_mask:
                    try:
                        for vw in ob.view[vname]:
                            # Note that a View acts like a slice
                            new_flags[vw] |= vmask
                    except KeyError as e:
                        msg = f"{e}; Intervals '{vname}' does not exist in {ob.name}"
                        msg += "; skipping flagging"
                        log.warning(msg)
            ob.shared[self.shared_flags].set(new_flags, offset=(0,), fromrank=0)

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
            "intervals": list(),
        }
        if self.view_mask is not None:
            req["intervals"] = [x[0] for x in self.view_mask]
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": [self.shared_flags],
            "detdata": list(),
            "intervals": list(),
        }
        return prov
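
The reset and apply steps above use standard bit-mask semantics; a tiny standalone illustration:

import numpy as np

mask = np.uint8(0b10)
flags = np.zeros(6, dtype=np.uint8)
flags[2:5] |= mask  # samples inside a view: OR the mask bit in
flags &= ~mask      # a reset pass clears that bit everywhere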

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

reset = Bool(False, help='If True, flag bits are first set to 0 for the entire observation') class-attribute instance-attribute

shared_flag_bytes = Int(1, help='If creating shared key, use this many bytes per sample') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

view_mask = List([], help='List of tuples of (view name, bit mask)') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/flag_intervals.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_flag_bytes(proposal)

Source code in toast/ops/flag_intervals.py
@traitlets.validate("shared_flag_bytes")
def _check_flag_bytes(self, proposal):
    check = proposal["value"]
    if check not in [1, 2, 4, 8]:
        raise traitlets.TraitError("shared flag byte width should be 1, 2, 4, or 8")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/flag_intervals.py
@traitlets.validate("view_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    for vname, vmask in check:
        if vmask < 0:
            raise traitlets.TraitError("Flag masks should be non-negative integers")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/flag_intervals.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    if self.shared_flags is None:
        log.debug_rank(
            "shared_flags trait is None, nothing to do.", comm=data.comm.comm_world
        )
        return

    if self.view_mask is None or len(self.view_mask) == 0:
        log.debug_rank(
            "view_mask trait is empty or not set, nothing to do.",
            comm=data.comm.comm_world,
        )
        return

    fdtype = None
    if self.shared_flag_bytes == 8:
        fdtype = np.uint64
    elif self.shared_flag_bytes == 4:
        fdtype = np.uint32
    elif self.shared_flag_bytes == 2:
        fdtype = np.uint16
    else:
        fdtype = np.uint8

    for ob in data.obs:
        # If the shared flag object already exists, then use it with whatever
        # byte width is in place.  Otherwise create it.

        if self.shared_flags not in ob.shared:
            ob.shared.create_column(
                self.shared_flags,
                shape=(ob.n_local_samples,),
                dtype=fdtype,
            )

        # The intervals / view is common between all processes in a column of the
        # process grid.  Only the rank zero process in each column builds the new
        # flags for the synchronous call to the set() method.  Note that views
        # of shared data are read-only, so we build the full flag vector and only
        # modify samples inside the view.

        new_flags = None
        if ob.comm_col_rank == 0:
            new_flags = np.array(ob.shared[self.shared_flags])
            if self.reset:
                for vname, vmask in self.view_mask:
                    new_flags &= ~vmask
            for vname, vmask in self.view_mask:
                try:
                    for vw in ob.view[vname]:
                        # Note that a View acts like a slice
                        new_flags[vw] |= vmask
                except KeyError as e:
                    msg = f"{e}; Intervals '{vname}' does not exist in {ob.name}"
                    msg += " skipping flagging"
                    log.warning(msg)
        ob.shared[self.shared_flags].set(new_flags, offset=(0,), fromrank=0)

_finalize(data, **kwargs)

Source code in toast/ops/flag_intervals.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/flag_intervals.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": [self.shared_flags],
        "detdata": list(),
        "intervals": list(),
    }
    return prov

_requires()

Source code in toast/ops/flag_intervals.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
        "intervals": list(),
    }
    if self.view_mask is not None:
        req["intervals"] = [x[0] for x in self.view_mask]
    return req
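
As a quick orientation, here is a minimal usage sketch of this operator. The shared key name, the interval name "turnarounds", and the toast.Data instance `data` are illustrative assumptions, not values taken from the source above:

import toast.ops

# Raise bit 0 of the shared telescope flags for every sample that
# falls inside the "turnarounds" interval list of each observation.
flag_intervals = toast.ops.FlagIntervals(
    shared_flags="flags",            # shared key to create or update
    shared_flag_bytes=1,             # uint8 flag words
    view_mask=[("turnarounds", 1)],  # (interval name, bit mask) pairs
)
flag_intervals.apply(data)           # "data" is an existing toast.Data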

toast.ops.FlagSSO

Bases: Operator

Operator which flags detector data in the vicinity of solar system objects

Source code in toast/ops/flag_sso.py
@trait_docs
class FlagSSO(Operator):
    """Operator which flags detector data in the vicinity of solar system objects"""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    detector_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="Operator that translates boresight Az/El pointing into detector frame",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(defaults.det_mask_sso, help="Bit mask to raise flags with")

    sso_names = List(
        [],
        help="Names of the SSOs, must be recognized by pyEphem",
    )

    sso_radii = List(
        [],
        help="Radii around the sources to flag",
    )

    @traitlets.validate("detector_pointing")
    def _check_detector_pointing(self, proposal):
        detpointing = proposal["value"]
        if detpointing is not None:
            if not isinstance(detpointing, Operator):
                raise traitlets.TraitError(
                    "detector_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in [
                "view",
                "boresight",
                "shared_flags",
                "shared_flag_mask",
                "quats",
                "coord_in",
                "coord_out",
            ]:
                if not detpointing.has_trait(trt):
                    msg = f"detector_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return detpointing

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        if len(self.sso_names) != len(self.sso_radii):
            raise RuntimeError("Each SSO must have a radius")

        if len(self.sso_names) == 0:
            log.debug_rank(
                "Empty sso_names, nothing to flag", comm=data.comm.comm_world
            )
            return

        self.ssos = []
        for sso_name in self.sso_names:
            self.ssos.append(getattr(ephem, sso_name)())
        self.nsso = len(self.ssos)

        for obs in data.obs:
            dets = obs.select_local_detectors(detectors)
            if len(dets) == 0:
                # Nothing to do for this observation
                continue
            site = obs.telescope.site
            observer = ephem.Observer()
            observer.lon = site.earthloc.lon.to_value(u.radian)
            observer.lat = site.earthloc.lat.to_value(u.radian)
            observer.elevation = site.earthloc.height.to_value(u.meter)
            observer.epoch = ephem.J2000
            observer.temp = 0  # in Celsius
            observer.compute_pressure()

            # Get the observation time span and compute the horizontal
            # position of the SSO
            times = obs.shared[self.times].data
            sso_azs, sso_els = self._get_sso_positions(times, observer)

            self._flag_ssos(data, obs, dets, sso_azs, sso_els)

        return

    @function_timer
    def _get_sso_positions(self, times, observer):
        """
        Calculate the SSO horizontal position
        """
        sso_azs = np.zeros([self.nsso, times.size])
        sso_els = np.zeros([self.nsso, times.size])
        # Only evaluate the position every second and interpolate
        # in between
        n = max(int(times[-1] - times[0]), 2)
        tvec = np.linspace(times[0], times[-1], n)
        for isso, sso in enumerate(self.ssos):
            azvec = np.zeros(n)
            elvec = np.zeros(n)
            for i, t in enumerate(tvec):
                observer.date = to_DJD(t)
                sso.compute(observer)
                azvec[i] = sso.az
                elvec[i] = sso.alt
            azvec = np.unwrap(azvec)
            sso_azs[isso] = np.interp(times, tvec, azvec) % (2 * np.pi)
            sso_els[isso] = np.interp(times, tvec, elvec)
        return sso_azs, sso_els

    @function_timer
    def _flag_ssos(self, data, obs, dets, sso_azs, sso_els):
        """
        Flag the SSO for each detector in tod
        """
        log = Logger.get()

        exists_flags = obs.detdata.ensure(
            self.det_flags, dtype=np.uint8, detectors=dets
        )

        for det in dets:
            try:
                # Use cached detector quaternions
                quats = obs.detdata[self.detector_pointing.quats][det]
            except KeyError:
                # Compute the detector quaternions
                obs_data = data.select(obs_uid=obs.uid)
                self.detector_pointing.apply(obs_data, detectors=[det])
                quats = obs.detdata[self.detector_pointing.quats][det]

            det_vec = qa.rotate(quats, ZAXIS)

            flags = obs.detdata[self.det_flags][det]

            for sso_name, sso_az, sso_el, sso_radius in zip(
                self.sso_names, sso_azs, sso_els, self.sso_radii
            ):
                radius = sso_radius.to_value(u.radian)
                sso_vec = hp.dir2vec(np.pi / 2 - sso_el, -sso_az).T
                dp = np.sum(det_vec * sso_vec, 1)
                inside = dp > np.cos(radius)
                frac = np.sum(inside) / inside.size
                if frac > 0:
                    log.debug(
                        f"Flagged {frac * 100:.1f} % samples for "
                        f"{det} due to {sso_name} in {obs.name}"
                    )
                flags[inside] |= self.det_flag_mask

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
            "intervals": list(),
        }
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_flags],
        }
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_sso, help='Bit mask to raise flags with') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

detector_pointing = Instance(klass=Operator, allow_none=True, help='Operator that translates boresight Az/El pointing into detector frame') class-attribute instance-attribute

sso_names = List([], help='Names of the SSOs, must be recognized by pyEphem') class-attribute instance-attribute

sso_radii = List([], help='Radii around the sources to flag') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/flag_sso.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_flag_mask(proposal)

Source code in toast/ops/flag_sso.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_detector_pointing(proposal)

Source code in toast/ops/flag_sso.py
@traitlets.validate("detector_pointing")
def _check_detector_pointing(self, proposal):
    detpointing = proposal["value"]
    if detpointing is not None:
        if not isinstance(detpointing, Operator):
            raise traitlets.TraitError(
                "detector_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in [
            "view",
            "boresight",
            "shared_flags",
            "shared_flag_mask",
            "quats",
            "coord_in",
            "coord_out",
        ]:
            if not detpointing.has_trait(trt):
                msg = f"detector_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return detpointing

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/flag_sso.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    if len(self.sso_names) != len(self.sso_radii):
        raise RuntimeError("Each SSO must have a radius")

    if len(self.sso_names) == 0:
        log.debug_rank(
            "Empty sso_names, nothing to flag", comm=data.comm.comm_world
        )
        return

    self.ssos = []
    for sso_name in self.sso_names:
        self.ssos.append(getattr(ephem, sso_name)())
    self.nsso = len(self.ssos)

    for obs in data.obs:
        dets = obs.select_local_detectors(detectors)
        if len(dets) == 0:
            # Nothing to do for this observation
            continue
        site = obs.telescope.site
        observer = ephem.Observer()
        observer.lon = site.earthloc.lon.to_value(u.radian)
        observer.lat = site.earthloc.lat.to_value(u.radian)
        observer.elevation = site.earthloc.height.to_value(u.meter)
        observer.epoch = ephem.J2000
        observer.temp = 0  # in Celsius
        observer.compute_pressure()

        # Get the observation time span and compute the horizontal
        # position of the SSO
        times = obs.shared[self.times].data
        sso_azs, sso_els = self._get_sso_positions(times, observer)

        self._flag_ssos(data, obs, dets, sso_azs, sso_els)

    return

_finalize(data, **kwargs)

Source code in toast/ops/flag_sso.py
def _finalize(self, data, **kwargs):
    return

_flag_ssos(data, obs, dets, sso_azs, sso_els)

Flag the SSO for each detector in tod

Source code in toast/ops/flag_sso.py
@function_timer
def _flag_ssos(self, data, obs, dets, sso_azs, sso_els):
    """
    Flag the SSO for each detector in tod
    """
    log = Logger.get()

    exists_flags = obs.detdata.ensure(
        self.det_flags, dtype=np.uint8, detectors=dets
    )

    for det in dets:
        try:
            # Use cached detector quaternions
            quats = obs.detdata[self.detector_pointing.quats][det]
        except KeyError:
            # Compute the detector quaternions
            obs_data = data.select(obs_uid=obs.uid)
            self.detector_pointing.apply(obs_data, detectors=[det])
            quats = obs.detdata[self.detector_pointing.quats][det]

        det_vec = qa.rotate(quats, ZAXIS)

        flags = obs.detdata[self.det_flags][det]

        for sso_name, sso_az, sso_el, sso_radius in zip(
            self.sso_names, sso_azs, sso_els, self.sso_radii
        ):
            radius = sso_radius.to_value(u.radian)
            sso_vec = hp.dir2vec(np.pi / 2 - sso_el, -sso_az).T
            dp = np.sum(det_vec * sso_vec, 1)
            inside = dp > np.cos(radius)
            frac = np.sum(inside) / inside.size
            if frac > 0:
                log.debug(
                    f"Flagged {frac * 100:.1f} % samples for "
                    f"{det} due to {sso_name} in {obs.name}"
                )
            flags[inside] |= self.det_flag_mask

    return
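
The core of the proximity test above is a dot product between detector and source unit vectors: a sample is inside the avoidance radius exactly when that dot product exceeds the cosine of the radius. A self-contained sketch of just this step:

import numpy as np

# Three detector pointings: on the source, 0.5 degrees off, 2 degrees off
det_vec = np.array([
    [0.0, 0.0, 1.0],
    [0.0, np.sin(np.radians(0.5)), np.cos(np.radians(0.5))],
    [0.0, np.sin(np.radians(2.0)), np.cos(np.radians(2.0))],
])
sso_vec = np.array([0.0, 0.0, 1.0])  # source along the z axis
radius = np.radians(1.0)             # 1 degree avoidance radius

dp = np.sum(det_vec * sso_vec, 1)    # cosine of each angular separation
inside = dp > np.cos(radius)
print(inside)                        # [ True  True False]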

_get_sso_positions(times, observer)

Calculate the SSO horizontal position

Source code in toast/ops/flag_sso.py
@function_timer
def _get_sso_positions(self, times, observer):
    """
    Calculate the SSO horizontal position
    """
    sso_azs = np.zeros([self.nsso, times.size])
    sso_els = np.zeros([self.nsso, times.size])
    # Only evaluate the position every second and interpolate
    # in between
    n = max(int(times[-1] - times[0]), 2)
    tvec = np.linspace(times[0], times[-1], n)
    for isso, sso in enumerate(self.ssos):
        azvec = np.zeros(n)
        elvec = np.zeros(n)
        for i, t in enumerate(tvec):
            observer.date = to_DJD(t)
            sso.compute(observer)
            azvec[i] = sso.az
            elvec[i] = sso.alt
        azvec = np.unwrap(azvec)
        sso_azs[isso] = np.interp(times, tvec, azvec) % (2 * np.pi)
        sso_els[isso] = np.interp(times, tvec, elvec)
    return sso_azs, sso_els
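
Because each ephemeris evaluation is expensive, the method computes positions on a coarse grid of roughly one point per second and interpolates to the sample times, unwrapping the azimuth first so the interpolation never crosses the 0 / 2-pi seam. A standalone sketch of the same pattern, with a linear ramp standing in for the ephemeris call:

import numpy as np

times = 1.6e9 + np.arange(60000) / 100.0   # 100 Hz samples over 10 minutes
n = max(int(times[-1] - times[0]), 2)      # about one point per second
tvec = np.linspace(times[0], times[-1], n)

# Expensive evaluation on the coarse grid (a ramp stands in for ephem)
coarse_az = (0.01 * (tvec - tvec[0])) % (2 * np.pi)

# Unwrap, interpolate densely, then fold back into [0, 2*pi)
az = np.interp(times, tvec, np.unwrap(coarse_az)) % (2 * np.pi)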

_provides()

Source code in toast/ops/flag_sso.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_flags],
    }
    return prov

_requires()

Source code in toast/ops/flag_sso.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
        "intervals": list(),
    }
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req
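
A minimal configuration sketch for this operator. The pointing operator `detpointing_azel`, the radii, and the `data` object are illustrative assumptions; note that `sso_radii` must carry angular units, since the code converts each entry with `to_value(u.radian)`:

import astropy.units as u
import toast.ops

flag_sso = toast.ops.FlagSSO(
    detector_pointing=detpointing_azel,    # existing Az/El pointing operator
    sso_names=["Sun", "Moon"],             # names resolvable by pyEphem
    sso_radii=[45.0 * u.deg, 5.0 * u.deg],
)
flag_sso.apply(data)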

toast.ops.SimpleDeglitch

Bases: Operator

An operator that flags extreme detector samples.

Source code in toast/ops/simple_deglitch.py
@trait_docs
class SimpleDeglitch(Operator):
    """An operator that flags extreme detector samples."""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(defaults.det_data, help="Observation detdata key to analyze")

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    reset_det_flags = Bool(
        False,
        help="Replace existing detector flags",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional shared flagging",
    )

    view = Unicode(
        None,
        allow_none=True,
        help="Find glitches in this view",
    )

    glitch_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value to apply at glitch positions",
    )

    glitch_radius = Int(
        5,
        help="Number of additional samples to flag around a glitch",
    )

    glitch_limit = Float(
        5.0,
        help="Glitch detection threshold in units of RMS",
    )

    nsample_min = Int(
        100,
        help="Minimum number of good samples in an interval.",
    )

    medfilt_kernel_size = Int(
        101,
        help="Median filter kernel width.  Either 0 (full interval) "
        "or a positive odd number",
    )

    fill_gaps = Bool(
        True,
        help="Fill gaps with a trend line and white noise",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("medfilt_kernel_size")
    def _check_medfilt_kernel_size(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("medfilt_kernel_size cannot be negative")
        if check > 0 and check % 2 == 0:
            raise traitlets.TraitError("medfilt_kernel_size cannot be even")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.net_factors = []
        self.total_factors = []
        self.weights_in = []
        self.weights_out = []
        self.rates = []

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        for ob in data.obs:
            if not ob.is_distributed_by_detector:
                msg = "Observation data must be distributed by detector, not samples"
                log.error(msg)
                raise RuntimeError(msg)
            views = ob.intervals[self.view]
            focalplane = ob.telescope.focalplane

            local_dets = ob.select_local_detectors(flagmask=self.det_mask)
            shared_flags = ob.shared[self.shared_flags].data & self.shared_flag_mask
            for name in local_dets:
                sig = ob.detdata[self.det_data][name]
                det_flags = ob.detdata[self.det_flags][name]
                if self.reset_det_flags:
                    det_flags[:] = 0
                bad = np.logical_or(
                    shared_flags != 0,
                    (det_flags & self.det_flag_mask) != 0,
                )
                for iview, view in enumerate(views):
                    nsample = view.last - view.first
                    ind = slice(view.first, view.last)
                    sig_view = sig[ind].copy()
                    w = self.medfilt_kernel_size
                    if w > 0 and nsample > 2 * w:
                        # Remove the running median
                        sig_view[w:-w] -= medfilt(sig_view, kernel_size=w)[w:-w]
                        # Special treatment for the ends
                        sig_view[:w] -= np.median(sig_view[:w])
                        sig_view[-w:] -= np.median(sig_view[-w:])
                    trend = sig[ind] - sig_view
                    sig_view[bad[ind]] = np.nan
                    offset = np.nanmedian(sig_view)
                    sig_view -= offset
                    trend += offset
                    rms = np.nanstd(sig_view)
                    nglitch = 0
                    while True:
                        if (
                            np.isnan(rms)
                            or np.sum(np.isfinite(sig_view)) < self.nsample_min
                        ):
                            # Not enough statistics; flag the entire view
                            sig_view[:] = np.nan
                            break
                        # See if the brightest remaining sample still stands out
                        i = np.nanargmax(np.abs(sig_view))
                        sig_view_test = sig_view.copy()
                        istart = max(0, i - self.glitch_radius)
                        istop = min(nsample, i + self.glitch_radius + 1)
                        sig_view_test[istart:istop] = np.nan
                        rms_test = np.nanstd(sig_view_test)
                        if np.abs(sig_view[i]) < self.glitch_limit * rms_test:
                            # Not significant enough
                            break
                        nglitch += 1
                        sig_view = sig_view_test
                        rms = rms_test
                    if nglitch == 0:
                        continue
                    bad_view = np.isnan(sig_view)
                    det_flags[ind][bad_view] |= self.glitch_mask
                if self.fill_gaps:
                    # 1 second buffer
                    buffer = int(focalplane.sample_rate.to_value(u.Hz))
                    flagged_noise_fill(
                        sig,
                        det_flags,
                        buffer,
                        poly_order=1,
                    )

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
            "intervals": list(),
        }
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key to analyze') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

fill_gaps = Bool(True, help='Fill gaps with a trend line and white noise') class-attribute instance-attribute

glitch_limit = Float(5.0, help='Glitch detection threshold in units of RMS') class-attribute instance-attribute

glitch_mask = Int(defaults.det_mask_invalid, help='Bit mask value to apply at glitch positions') class-attribute instance-attribute

glitch_radius = Int(5, help='Number of additional samples to flag around a glitch') class-attribute instance-attribute

medfilt_kernel_size = Int(101, help='Median filter kernel width. Either 0 (full interval) or a positive odd number') class-attribute instance-attribute

net_factors = [] instance-attribute

nsample_min = Int(100, help='Minimum number of good samples in an interval.') class-attribute instance-attribute

rates = [] instance-attribute

reset_det_flags = Bool(False, help='Replace existing detector flags') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

total_factors = [] instance-attribute

view = Unicode(None, allow_none=True, help='Find glitches in this view') class-attribute instance-attribute

weights_in = [] instance-attribute

weights_out = [] instance-attribute

__init__(**kwargs)

Source code in toast/ops/simple_deglitch.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.net_factors = []
    self.total_factors = []
    self.weights_in = []
    self.weights_out = []
    self.rates = []

_check_det_flag_mask(proposal)

Source code in toast/ops/simple_deglitch.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/simple_deglitch.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_medfilt_kernel_size(proposal)

Source code in toast/ops/simple_deglitch.py
@traitlets.validate("medfilt_kernel_size")
def _check_medfilt_kernel_size(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("medfilt_kernel_size cannot be negative")
    if check > 0 and check % 2 == 0:
        raise traitlets.TraitError("medfilt_kernel_size cannot be even")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/simple_deglitch.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/simple_deglitch.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    for ob in data.obs:
        if not ob.is_distributed_by_detector:
            msg = "Observation data must be distributed by detector, not samples"
            log.error(msg)
            raise RuntimeError(msg)
        views = ob.intervals[self.view]
        focalplane = ob.telescope.focalplane

        local_dets = ob.select_local_detectors(flagmask=self.det_mask)
        shared_flags = ob.shared[self.shared_flags].data & self.shared_flag_mask
        for name in local_dets:
            sig = ob.detdata[self.det_data][name]
            det_flags = ob.detdata[self.det_flags][name]
            if self.reset_det_flags:
                det_flags[:] = 0
            bad = np.logical_or(
                shared_flags != 0,
                (det_flags & self.det_flag_mask) != 0,
            )
            for iview, view in enumerate(views):
                nsample = view.last - view.first
                ind = slice(view.first, view.last)
                sig_view = sig[ind].copy()
                w = self.medfilt_kernel_size
                if w > 0 and nsample > 2 * w:
                    # Remove the running median
                    sig_view[w:-w] -= medfilt(sig_view, kernel_size=w)[w:-w]
                    # Special treatment for the ends
                    sig_view[:w] -= np.median(sig_view[:w])
                    sig_view[-w:] -= np.median(sig_view[-w:])
                trend = sig[ind] - sig_view
                sig_view[bad[ind]] = np.nan
                offset = np.nanmedian(sig_view)
                sig_view -= offset
                trend += offset
                rms = np.nanstd(sig_view)
                nglitch = 0
                while True:
                    if (
                        np.isnan(rms)
                        or np.sum(np.isfinite(sig_view)) < self.nsample_min
                    ):
                        # Not enough statistics; flag the entire view
                        sig_view[:] = np.nan
                        break
                    # See if the brightest remaining sample still stands out
                    i = np.nanargmax(np.abs(sig_view))
                    sig_view_test = sig_view.copy()
                    istart = max(0, i - self.glitch_radius)
                    istop = min(nsample, i + self.glitch_radius + 1)
                    sig_view_test[istart:istop] = np.nan
                    rms_test = np.nanstd(sig_view_test)
                    if np.abs(sig_view[i]) < self.glitch_limit * rms_test:
                        # Not significant enough
                        break
                    nglitch += 1
                    sig_view = sig_view_test
                    rms = rms_test
                if nglitch == 0:
                    continue
                bad_view = np.isnan(sig_view)
                det_flags[ind][bad_view] |= self.glitch_mask
            if self.fill_gaps:
                # 1 second buffer
                buffer = int(focalplane.sample_rate.to_value(u.Hz))
                flagged_noise_fill(
                    sig,
                    det_flags,
                    buffer,
                    poly_order=1,
                )

    return
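
The detrending step above subtracts a running median so that slow drifts are not mistaken for glitches, with plain medians patching the regions near the ends where the filter window is incomplete. A self-contained sketch of that step, using the same kernel width:

import numpy as np
from scipy.signal import medfilt

rng = np.random.default_rng(0)
sig = np.cumsum(rng.normal(size=2000)) * 0.05 + rng.normal(size=2000)
sig[1000] += 50.0                   # inject a glitch on top of the drift

w = 101
detrended = sig.copy()
# Remove the running median away from the edges
detrended[w:-w] -= medfilt(sig, kernel_size=w)[w:-w]
# Special treatment for the ends
detrended[:w] -= np.median(detrended[:w])
detrended[-w:] -= np.median(detrended[-w:])
# The glitch now stands out against roughly unit-RMS residuals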

_finalize(data, **kwargs)

Source code in toast/ops/simple_deglitch.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/simple_deglitch.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
        "intervals": list(),
    }
    return prov

_requires()

Source code in toast/ops/simple_deglitch.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req
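
A minimal usage sketch of the operator as a whole. The view name "scanning" and the `data` object are illustrative assumptions:

import toast.ops

deglitch = toast.ops.SimpleDeglitch(
    det_data="signal",      # detdata key to analyze
    view="scanning",        # deglitch each scanning interval separately
    glitch_limit=5.0,       # flag samples beyond 5 x RMS
    glitch_radius=5,        # plus 5 samples on either side
    fill_gaps=True,         # replace flagged samples with trend + noise
)
deglitch.apply(data)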

toast.ops.SimpleJumpCorrect

Bases: Operator

An operator that identifies and corrects jumps in the data

Source code in toast/ops/simple_jumpcorrect.py
@trait_docs
class SimpleJumpCorrect(Operator):
    """An operator that identifies and corrects jumps in the data"""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(defaults.det_data, help="Observation detdata key to analyze")

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    reset_det_flags = Bool(
        False,
        help="Replace existing detector flags",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional shared flagging",
    )

    view = Unicode(
        None,
        allow_none=True,
        help="Find jumps in this view",
    )

    jump_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value to apply at glitch positions",
    )

    jump_radius = Int(
        5,
        help="Number of additional samples to flag around a jump",
    )

    jump_limit = Float(
        5.0,
        help="Jump detection threshold in units of RMS",
    )

    filterlen = Int(
        100,
        help="Matched filter length",
    )

    nsample_min = Int(
        100,
        help="Minimum number of good samples in an interval",
    )

    njump_limit = Int(
        10,
        help="If the detector has more than `njump_limit` jumps the detector "
        "the detector and time stream will be flagged as invalid.",
    )

    save_jumps = Unicode(
        None,
        allow_none=True,
        help="Save the jump corrections to a dictionary of values per observation",
    )

    apply_jumps = Unicode(
        None,
        allow_none=True,
        help="Do not compute jumps, instead apply the specified dictionary of values",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("njump_limit")
    def _check_njump_limit(self, proposal):
        check = proposal["value"]
        if check <= 0:
            raise traitlets.TraitError("njump limit should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.net_factors = []
        self.total_factors = []
        self.weights_in = []
        self.weights_out = []
        self.rates = []

    def _get_stepfilter(self, m):
        """
        Return the time domain matched filter kernel of length m.
        """
        h = np.zeros(m)
        h[: m // 2] = 1
        h[m // 2 :] = -1
        # This turns the interpretation of the peak amplitude directly
        # into the step amplitude
        h /= m // 2
        return h

    def _find_peaks(self, toi, flag, flag_out, lim=3.0, tol=10000, sigma_in=None):
        """
        Find the peaks and their amplitudes in the match-filtered TOI.
        Inputs:
        lim -- threshold for jump detection in units of filtered TOI RMS.
        tol -- radius of a region to mask from further peak finding upon
            detecting a peak.
        sigma_in -- an estimate of filtered TOI RMS that overrides the
             sample variance otherwise used.

        """
        peaks = []
        mytoi = np.ma.masked_array(toi)
        nsample = len(mytoi)
        # Do not accept jumps at the ends due to boundary effects
        lbound = tol
        rbound = tol
        mytoi[:lbound] = np.ma.masked
        mytoi[-rbound:] = np.ma.masked
        if sigma_in is None:
            sigma = self._get_sigma(mytoi, flag_out, tol)
        else:
            sigma = sigma_in

        if np.isnan(sigma) or sigma == 0:
            npeak = 0
        else:
            npeak = np.ma.sum(np.abs(mytoi) > sigma * lim)

        # Remove one jump per iteration and stop once more than
        # `njump_limit` jumps have been found
        while (npeak > 0) and (len(peaks) <= self.njump_limit):
            imax = np.argmax(np.abs(mytoi))
            amplitude = mytoi[imax]
            significance = np.abs(amplitude) / sigma

            # Mask out the vicinity to avoid false detections near the peak
            istart = max(0, imax - tol)
            istop = min(nsample, imax + tol)
            mytoi[istart:istop] = np.ma.masked
            flag_out[istart:istop] = True
            # Excessive flagging is a sign of false detection
            if significance > 5 or (
                float(np.sum(flag_out[istart:istop])) / (istop - istart) < 0.5
            ):
                peaks.append((imax, significance, amplitude))

            # Find additional peaks
            if sigma_in is None:
                sigma = self._get_sigma(mytoi, flag_out, tol)
            if np.isnan(sigma) or sigma == 0:
                npeak = 0
            else:
                npeak = np.ma.sum(np.abs(mytoi) > sigma * lim)

        return peaks

    def _get_sigma(self, toi, flag, tol):

        full_flag = np.logical_or(flag, toi == 0)

        sigmas = []
        nn = len(toi)
        # Ignore tol samples at the edge
        for start in range(tol, nn - 3 * tol + 1, 2 * tol):
            stop = start + 2 * tol
            ind = slice(start, stop)
            x = toi[ind][full_flag[ind] == 0]
            if len(x) != 0:
                rms = np.sqrt(np.mean(x.data**2))
                sigmas.append(rms)

        if len(sigmas) != 0:
            sigma = np.median(sigmas)
        else:
            sigma = 0.0
        return sigma

    def _remove_jumps(self, signal, flag, peaks, tol):
        """
        Removes the jumps described by peaks from the signal.
        Adds a buffer of flags with radius of tol.

        """
        corrected_signal = signal.copy()
        nsample = len(signal)
        flag_out = flag.copy()
        for peak, _, amplitude in peaks:
            corrected_signal[peak:] -= amplitude
            pstart = max(0, peak - tol)
            pstop = min(nsample, peak + tol)
            flag_out[pstart:pstop] = True
        return corrected_signal, flag_out

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        if self.save_jumps is not None and self.apply_jumps is not None:
            msg = "Cannot both save to and apply pre-existing jumps"
            raise RuntimeError(msg)

        stepfilter = self._get_stepfilter(self.filterlen)

        for ob in data.obs:
            if not ob.is_distributed_by_detector:
                msg = "Observation data must be distributed by detector, not samples"
                log.error(msg)
                raise RuntimeError(msg)
            views = ob.intervals[self.view]
            focalplane = ob.telescope.focalplane

            local_dets = ob.select_local_detectors(flagmask=self.det_mask)
            shared_flags = ob.shared[self.shared_flags].data & self.shared_flag_mask
            if self.save_jumps is not None:
                jump_props = dict()
            for name in local_dets:
                if self.save_jumps is not None:
                    jump_dets = list()
                sig = ob.detdata[self.det_data][name]
                det_flags = ob.detdata[self.det_flags][name]
                if self.reset_det_flags:
                    det_flags[:] = 0
                bad = np.logical_or(
                    shared_flags != 0,
                    (det_flags & self.det_flag_mask) != 0,
                )
                if self.apply_jumps is not None:
                    corrected_signal, flag_out = self._remove_jumps(
                        sig, bad, ob[self.apply_jumps][name], self.jump_radius
                    )
                    sig[:] = corrected_signal
                    det_flags[flag_out] |= self.jump_mask
                else:
                    for iview, view in enumerate(views):
                        nsample = view.last - view.first
                        ind = slice(view.first, view.last)
                        sig_view = sig[ind].copy()
                        bad_view = bad[ind]
                        bad_view_out = bad_view.copy()
                        sig_filtered = convolve(sig_view, stepfilter, mode="same")
                        peaks = self._find_peaks(
                            sig_filtered,
                            bad_view,
                            bad_view_out,
                            lim=self.jump_limit,
                            tol=self.filterlen // 2,
                        )
                        if self.save_jumps is not None:
                            jump_dets.extend(
                                [(x + view.first, y, z) for x, y, z in peaks]
                            )
                        njump = len(peaks)
                        if njump == 0:
                            continue
                        if njump > self.njump_limit:
                            ob._detflags[name] |= self.det_mask
                            det_flags[ind] |= self.det_flag_mask
                            continue

                        corrected_signal, flag_out = self._remove_jumps(
                            sig_view, bad_view, peaks, self.jump_radius
                        )
                        sig[ind] = corrected_signal
                        det_flags[ind][flag_out] |= self.jump_mask
                    if self.save_jumps is not None:
                        jump_props[name] = jump_dets
            if self.save_jumps is not None:
                ob[self.save_jumps] = jump_props
        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
            "intervals": list(),
        }
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

apply_jumps = Unicode(None, allow_none=True, help='Do not compute jumps, instead apply the specified dictionary of values') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key to analyze') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

filterlen = Int(100, help='Matched filter length') class-attribute instance-attribute

jump_limit = Float(5.0, help='Jump detection threshold in units of RMS') class-attribute instance-attribute

jump_mask = Int(defaults.det_mask_invalid, help='Bit mask value to apply at glitch positions') class-attribute instance-attribute

jump_radius = Int(5, help='Number of additional samples to flag around a jump') class-attribute instance-attribute

net_factors = [] instance-attribute

njump_limit = Int(10, help='If the detector has more than `njump_limit` jumps, the detector and its time stream will be flagged as invalid.') class-attribute instance-attribute

nsample_min = Int(100, help='Minimum number of good samples in an interval') class-attribute instance-attribute

rates = [] instance-attribute

reset_det_flags = Bool(False, help='Replace existing detector flags') class-attribute instance-attribute

save_jumps = Unicode(None, allow_none=True, help='Save the jump corrections to a dictionary of values per observation') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

total_factors = [] instance-attribute

view = Unicode(None, allow_none=True, help='Find jumps in this view') class-attribute instance-attribute

weights_in = [] instance-attribute

weights_out = [] instance-attribute

__init__(**kwargs)

Source code in toast/ops/simple_jumpcorrect.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.net_factors = []
    self.total_factors = []
    self.weights_in = []
    self.weights_out = []
    self.rates = []

_check_njump_limit(proposal)

Source code in toast/ops/simple_jumpcorrect.py
@traitlets.validate("njump_limit")
def _check_njump_limit(self, proposal):
    check = proposal["value"]
    if check <= 0:
        raise traitlets.TraitError("njump limit should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/simple_jumpcorrect.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/simple_jumpcorrect.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/simple_jumpcorrect.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    if self.save_jumps is not None and self.apply_jumps is not None:
        msg = "Cannot both save to and apply pre-existing jumps"
        raise RuntimeError(msg)

    stepfilter = self._get_stepfilter(self.filterlen)

    for ob in data.obs:
        if not ob.is_distributed_by_detector:
            msg = "Observation data must be distributed by detector, not samples"
            log.error(msg)
            raise RuntimeError(msg)
        views = ob.intervals[self.view]
        focalplane = ob.telescope.focalplane

        local_dets = ob.select_local_detectors(flagmask=self.det_mask)
        shared_flags = ob.shared[self.shared_flags].data & self.shared_flag_mask
        if self.save_jumps is not None:
            jump_props = dict()
        for name in local_dets:
            if self.save_jumps is not None:
                jump_dets = list()
            sig = ob.detdata[self.det_data][name]
            det_flags = ob.detdata[self.det_flags][name]
            if self.reset_det_flags:
                det_flags[:] = 0
            bad = np.logical_or(
                shared_flags != 0,
                (det_flags & self.det_flag_mask) != 0,
            )
            if self.apply_jumps is not None:
                corrected_signal, flag_out = self._remove_jumps(
                    sig, bad, ob[self.apply_jumps][name], self.jump_radius
                )
                sig[:] = corrected_signal
                det_flags[flag_out] |= self.jump_mask
            else:
                for iview, view in enumerate(views):
                    nsample = view.last - view.first
                    ind = slice(view.first, view.last)
                    sig_view = sig[ind].copy()
                    bad_view = bad[ind]
                    bad_view_out = bad_view.copy()
                    sig_filtered = convolve(sig_view, stepfilter, mode="same")
                    peaks = self._find_peaks(
                        sig_filtered,
                        bad_view,
                        bad_view_out,
                        lim=self.jump_limit,
                        tol=self.filterlen // 2,
                    )
                    if self.save_jumps is not None:
                        jump_dets.extend(
                            [(x + view.first, y, z) for x, y, z in peaks]
                        )
                    njump = len(peaks)
                    if njump == 0:
                        continue
                    if njump > self.njump_limit:
                        ob._detflags[name] |= self.det_mask
                        det_flags[ind] |= self.det_flag_mask
                        continue

                    corrected_signal, flag_out = self._remove_jumps(
                        sig_view, bad_view, peaks, self.jump_radius
                    )
                    sig[ind] = corrected_signal
                    det_flags[ind][flag_out] |= self.jump_mask
                if self.save_jumps is not None:
                    jump_props[name] = jump_dets
        if self.save_jumps is not None:
            ob[self.save_jumps] = jump_props
    return

_finalize(data, **kwargs)

Source code in toast/ops/simple_jumpcorrect.py
def _finalize(self, data, **kwargs):
    return

_find_peaks(toi, flag, flag_out, lim=3.0, tol=10000, sigma_in=None)

Find the peaks and their amplitudes in the match-filtered TOI.

Inputs:
lim -- threshold for jump detection in units of filtered TOI RMS.
tol -- radius of a region to mask from further peak finding upon detecting a peak.
sigma_in -- an estimate of filtered TOI RMS that overrides the sample variance otherwise used.

Source code in toast/ops/simple_jumpcorrect.py
def _find_peaks(self, toi, flag, flag_out, lim=3.0, tol=10000, sigma_in=None):
    """
    Find the peaks and their amplitudes in the match-filtered TOI.
    Inputs:
    lim -- threshold for jump detection in units of filtered TOI RMS.
    tol -- radius of a region to mask from further peak finding upon
        detecting a peak.
    sigma_in -- an estimate of filtered TOI RMS that overrides the
         sample variance otherwise used.

    """
    peaks = []
    mytoi = np.ma.masked_array(toi)
    nsample = len(mytoi)
    # Do not accept jumps at the ends due to boundary effects
    lbound = tol
    rbound = tol
    mytoi[:lbound] = np.ma.masked
    mytoi[-rbound:] = np.ma.masked
    if sigma_in is None:
        sigma = self._get_sigma(mytoi, flag_out, tol)
    else:
        sigma = sigma_in

    if np.isnan(sigma) or sigma == 0:
        npeak = 0
    else:
        npeak = np.ma.sum(np.abs(mytoi) > sigma * lim)

    # Remove one jump per iteration and stop once more than
    # `njump_limit` jumps have been found
    while (npeak > 0) and (len(peaks) <= self.njump_limit):
        imax = np.argmax(np.abs(mytoi))
        amplitude = mytoi[imax]
        significance = np.abs(amplitude) / sigma

        # Mask out the vicinity to avoid false detections near the peak
        istart = max(0, imax - tol)
        istop = min(nsample, imax + tol)
        mytoi[istart:istop] = np.ma.masked
        flag_out[istart:istop] = True
        # Excessive flagging is a sign of false detection
        if significance > 5 or (
            float(np.sum(flag_out[istart:istop])) / (istop - istart) < 0.5
        ):
            peaks.append((imax, significance, amplitude))

        # Find additional peaks
        if sigma_in is None:
            sigma = self._get_sigma(mytoi, flag_out, tol)
        if np.isnan(sigma) or sigma == 0:
            npeak = 0
        else:
            npeak = np.ma.sum(np.abs(mytoi) > sigma * lim)

    return peaks

_get_sigma(toi, flag, tol)

Source code in toast/ops/simple_jumpcorrect.py
def _get_sigma(self, toi, flag, tol):

    full_flag = np.logical_or(flag, toi == 0)

    sigmas = []
    nn = len(toi)
    # Ignore tol samples at the edge
    for start in range(tol, nn - 3 * tol + 1, 2 * tol):
        stop = start + 2 * tol
        ind = slice(start, stop)
        x = toi[ind][full_flag[ind] == 0]
        if len(x) != 0:
            rms = np.sqrt(np.mean(x.data**2))
            sigmas.append(rms)

    if len(sigmas) != 0:
        sigma = np.median(sigmas)
    else:
        sigma = 0.0
    return sigma
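
The estimate is robust because the RMS is computed over chunks of 2 * tol samples and the median of the per-chunk values is returned, so a single bright feature cannot inflate the result. A standalone sketch:

import numpy as np

rng = np.random.default_rng(1)
toi = rng.normal(scale=2.0, size=10000)
toi[4000:4100] += 100.0             # one contaminated stretch

tol = 500
chunks = [
    toi[start : start + 2 * tol]
    for start in range(tol, len(toi) - 3 * tol + 1, 2 * tol)
]
sigma = np.median([np.sqrt(np.mean(c**2)) for c in chunks])
print(sigma)                        # close to 2.0 despite the contamination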

_get_stepfilter(m)

Return the time domain matched filter kernel of length m.

Source code in toast/ops/simple_jumpcorrect.py
def _get_stepfilter(self, m):
    """
    Return the time domain matched filter kernel of length m.
    """
    h = np.zeros(m)
    h[: m // 2] = 1
    h[m // 2 :] = -1
    # Normalize so that the peak amplitude of the filtered TOI equals
    # the step amplitude directly
    h /= m // 2
    return h
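
Convolving a timestream with this kernel turns a step discontinuity into a localized peak whose height is the step amplitude, which is what makes the thresholding in _find_peaks meaningful. A quick standalone check (all values illustrative):

import numpy as np

rng = np.random.default_rng(0)
toi = rng.normal(0.0, 1.0, 10000)
toi[5000:] += 10.0  # inject a step of amplitude 10

m = 200
h = np.zeros(m)
h[: m // 2] = 1
h[m // 2 :] = -1
h /= m // 2  # normalize so the peak height estimates the step amplitude

filtered = np.convolve(toi, h, mode="same")
imax = np.argmax(np.abs(filtered))
# imax lands within ~m/2 samples of 5000 and filtered[imax] is close to 10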

_provides()

Source code in toast/ops/simple_jumpcorrect.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
        "intervals": list(),
    }
    return prov

_remove_jumps(signal, flag, peaks, tol)

Removes the jumps described by peaks from the signal. Adds a buffer of flags with radius tol around each jump.

Source code in toast/ops/simple_jumpcorrect.py
def _remove_jumps(self, signal, flag, peaks, tol):
    """
    Removes the jumps described by peaks from the signal.
    Adds a buffer of flags with radius tol around each jump.

    """
    corrected_signal = signal.copy()
    nsample = len(signal)
    flag_out = flag.copy()
    for peak, _, amplitude in peaks:
        corrected_signal[peak:] -= amplitude
        pstart = max(0, peak - tol)
        pstop = min(nsample, peak + tol)
        flag_out[pstart:pstop] = True
    return corrected_signal, flag_out
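
Applying the correction is then just a matter of subtracting each fitted amplitude from the jump sample onward and flagging a buffer around the jump. A toy example (values illustrative):

import numpy as np

signal = np.zeros(1000)
signal[400:] += 5.0  # one jump at sample 400
flags = np.zeros(1000, dtype=bool)
peaks = [(400, 50.0, 5.0)]  # (index, significance, amplitude)
tol = 10

corrected = signal.copy()
for peak, _, amplitude in peaks:
    corrected[peak:] -= amplitude  # undo the step
    flags[max(0, peak - tol) : min(len(signal), peak + tol)] = True

# corrected is now identically zero; 2 * tol samples around the jump are flagged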

_requires()

Source code in toast/ops/simple_jumpcorrect.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

toast.ops.FillGaps

Bases: Operator

Operator that fills flagged samples with noise.

Currently this operator just fills flagged samples with a simple polynomial plus white noise. It is mostly used for visualization. No attempt is made yet to fill the gaps with a constrained noise realization.
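
A minimal usage sketch (assuming an existing, populated toast.Data instance named data; the detdata key "signal" is the library default):

import astropy.units as u
import toast

# Fill flagged samples with a linear trend plus white noise, fitting the
# trend to a 2 s buffer of good samples on either side of each gap
fill = toast.ops.FillGaps(
    det_data="signal",
    buffer=2.0 * u.s,
    poly_order=1,
)
fill.apply(data)

The buffer is converted to samples with the focalplane sample rate, so at 100 Hz a 2 s buffer corresponds to 200 samples on either side of each gap.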

Source code in toast/ops/fill_gaps.py
@trait_docs
class FillGaps(Operator):
    """Operator that fills flagged samples with noise.

    Currently this operator just fills flagged samples with a simple polynomial
    plus white noise.  It is mostly used for visualization.  No attempt is made
    yet to fill the gaps with a constrained noise realization.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key",
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for optional shared flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    buffer = Quantity(
        1.0 * u.s,
        help="Buffer of time on either side of each gap",
    )

    poly_order = Int(
        1,
        help="Order of the polynomial to fit across each gap",
    )

    @traitlets.validate("poly_order")
    def _check_poly_order(self, proposal):
        check = proposal["value"]
        if check <= 0:
            raise traitlets.TraitError("poly_order should be >= 1")
        return check

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        env = Environment.get()
        log = Logger.get()

        for ob in data.obs:
            timer = Timer()
            timer.start()

            # Sample rate for this observation
            rate = ob.telescope.focalplane.sample_rate.to_value(u.Hz)

            # The buffer size in samples
            buf_samp = int(self.buffer.to_value(u.second) * rate)

            # Check that parameters make sense
            if self.poly_order > buf_samp + 1:
                msg = f"Cannot fit an order {self.poly_order} polynomial "
                msg += f"to {buf_samp} samples"
                raise RuntimeError(msg)

            if buf_samp > ob.n_local_samples // 4:
                msg = f"Using {buf_samp} samples of buffer around gaps is"
                msg += f" not reasonable for an observation with {ob.n_local_samples}"
                msg += " local samples"
                raise RuntimeError(msg)

            # Local detectors we are considering
            local_dets = ob.select_local_detectors(flagmask=self.det_mask)
            n_dets = len(local_dets)

            # The shared flags
            if self.shared_flags is None:
                shared_flags = np.zeros(ob.n_local_samples, dtype=bool)
            else:
                shared_flags = (
                    ob.shared[self.shared_flags].data & self.shared_flag_mask
                ) != 0

            for idet, det in enumerate(local_dets):
                if self.det_flags is None:
                    flags = shared_flags
                else:
                    flags = np.logical_or(
                        shared_flags,
                        (ob.detdata[self.det_flags][det, :] & self.det_flag_mask) != 0,
                    )
                flagged_noise_fill(
                    ob.detdata[self.det_data][det],
                    flags,
                    buf_samp,
                    poly_order=self.poly_order,
                )
            msg = f"FillGaps {ob.name}: completed in"
            log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        # Note that the times field is not strictly required by this
        # operator - listing it here is effectively a no-op.
        req = {
            "shared": [self.times],
            "detdata": [self.det_data],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        prov = {
            "meta": [],
            "detdata": [self.det_data],
        }
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

buffer = Quantity(1.0 * u.s, help='Buffer of time on either side of each gap') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

poly_order = Int(1, help='Order of the polynomial to fit across each gap') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/fill_gaps.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_flag_mask(proposal)

Source code in toast/ops/fill_gaps.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/fill_gaps.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_poly_order(proposal)

Source code in toast/ops/fill_gaps.py
@traitlets.validate("poly_order")
def _check_poly_order(self, proposal):
    check = proposal["value"]
    if check <= 0:
        raise traitlets.TraitError("poly_order should be >= 1")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/fill_gaps.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/fill_gaps.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    env = Environment.get()
    log = Logger.get()

    for ob in data.obs:
        timer = Timer()
        timer.start()

        # Sample rate for this observation
        rate = ob.telescope.focalplane.sample_rate.to_value(u.Hz)

        # The buffer size in samples
        buf_samp = int(self.buffer.to_value(u.second) * rate)

        # Check that parameters make sense
        if self.poly_order > buf_samp + 1:
            msg = f"Cannot fit an order {self.poly_order} polynomial "
            msg += f"to {buf_samp} samples"
            raise RuntimeError(msg)

        if buf_samp > ob.n_local_samples // 4:
            msg = f"Using {buf_samp} samples of buffer around gaps is"
            msg += f" not reasonable for an observation with {ob.n_local_samples}"
            msg += " local samples"
            raise RuntimeError(msg)

        # Local detectors we are considering
        local_dets = ob.select_local_detectors(flagmask=self.det_mask)
        n_dets = len(local_dets)

        # The shared flags
        if self.shared_flags is None:
            shared_flags = np.zeros(ob.n_local_samples, dtype=bool)
        else:
            shared_flags = (
                ob.shared[self.shared_flags].data & self.shared_flag_mask
            ) != 0

        for idet, det in enumerate(local_dets):
            if self.det_flags is None:
                flags = shared_flags
            else:
                flags = np.logical_or(
                    shared_flags,
                    (ob.detdata[self.det_flags][det, :] & self.det_flag_mask) != 0,
                )
            flagged_noise_fill(
                ob.detdata[self.det_data][det],
                flags,
                buf_samp,
                poly_order=self.poly_order,
            )
        msg = f"FillGaps {ob.name}: completed in"
        log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

_finalize(data, **kwargs)

Source code in toast/ops/fill_gaps.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/fill_gaps.py
def _provides(self):
    prov = {
        "meta": [],
        "detdata": [self.det_data],
    }
    return prov

_requires()

Source code in toast/ops/fill_gaps.py
def _requires(self):
    # Note that the times field is not strictly required by this
    # operator - listing it here is effectively a no-op.
    req = {
        "shared": [self.times],
        "detdata": [self.det_data],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

Filtering

Polynomial Filters

toast.ops.CommonModeFilter

Bases: Operator

Operator to regress out common mode at each time stamp.
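
A minimal usage sketch (assuming a populated toast.Data instance named data; the focalplane key "wafer" is illustrative and must exist in the focalplane table):

import toast

# At each time stamp, estimate the mean signal over all detectors that
# share a wafer and regress it out of every detector.  With the default
# regress=False the common mode is subtracted directly instead.
common = toast.ops.CommonModeFilter(
    det_data="signal",
    focalplane_key="wafer",
    regress=True,
)
common.apply(data)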

Source code in toast/ops/polyfilter/polyfilter.py
class CommonModeFilter(Operator):
    """Operator to regress out common mode at each time stamp."""

    API = Int(0, help="Internal interface version for this operator")

    times = Unicode(
        defaults.times,
        help="Observation shared key for timestamps",
    )

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key to apply filtering to"
    )

    pattern = Unicode(
        f".*",
        allow_none=True,
        help="Regex pattern to match against detector names. Only detectors that "
        "match the pattern are filtered.",
    )

    det_mask = Int(
        defaults.det_mask_invalid | defaults.det_mask_processing,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid | defaults.det_mask_processing,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for optional shared flagging",
    )

    focalplane_key = Unicode(
        None, allow_none=True, help="Which focalplane key to match"
    )

    redistribute = Bool(
        False,
        help="If True, redistribute data before and after filtering for "
        "optimal data locality.",
    )

    regress = Bool(
        False,
        help="If True, regress the common mode rather than subtract",
    )

    plot = Bool(
        False,
        help="If True, plot regression coefficients",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    @function_timer
    def _redistribute(self, data, obs, timer, log):
        if self.redistribute:
            # Redistribute the data so each process has all detectors for some sample
            # range.  Duplicate just the fields of the observation we will use.
            dup_shared = list()
            if self.shared_flags is not None:
                dup_shared.append(self.shared_flags)
            dup_detdata = [self.det_data]
            if self.det_flags is not None:
                dup_detdata.append(self.det_flags)
            dup_intervals = list()
            temp_ob = obs.duplicate(
                times=self.times,
                meta=list(),
                shared=dup_shared,
                detdata=dup_detdata,
                intervals=dup_intervals,
            )
            log.debug_rank(
                f"{data.comm.group:4} : Duplicated observation in",
                comm=temp_ob.comm.comm_group,
                timer=timer,
            )
            # Redistribute this temporary observation to be distributed by sample sets
            temp_ob.redistribute(1, times=self.times, override_sample_sets=None)
            log.debug_rank(
                f"{data.comm.group:4} : Redistributed observation in",
                comm=temp_ob.comm.comm_group,
                timer=timer,
            )
            comm = None
        else:
            comm = obs.comm_col
            temp_ob = obs

        return comm, temp_ob

    @function_timer
    def _re_redistribute(self, data, obs, timer, log, temp_ob):
        if self.redistribute:
            # Redistribute data back
            temp_ob.redistribute(
                obs.dist.process_rows,
                times=self.times,
                override_sample_sets=obs.dist.sample_sets,
            )
            log.debug_rank(
                f"{data.comm.group:4} : Re-redistributed observation in",
                comm=temp_ob.comm.comm_group,
                timer=timer,
            )
            # Copy data to original observation
            obs.detdata[self.det_data][:] = temp_ob.detdata[self.det_data][:]
            log.debug_rank(
                f"{data.comm.group:4} : Copied observation data in",
                comm=temp_ob.comm.comm_group,
                timer=timer,
            )
        return

    @function_timer
    def _plot_coeff(self, ob, coeffs, comm, value):
        # Make a plot of the coupling coefficients
        import matplotlib.pyplot as plt

        ndet = len(coeffs)
        lon = np.zeros(ndet)
        lat = np.zeros(ndet)
        yrot = qa.rotation([0, 1, 0], np.pi / 2)
        for idet in range(ndet):
            name = ob.local_detectors[idet]
            quat = ob.telescope.focalplane[name]["quat"]
            theta, phi, psi = qa.to_iso_angles(qa.mult(yrot, quat))
            lon[idet] = np.degrees(phi)
            lat[idet] = np.degrees(theta - np.pi / 2)
            top = (psi % np.pi) < (np.pi / 2)
            offset = 0.002  # Need a smarter offset...
            if top:
                lon[idet] += offset
                lat[idet] += offset
            else:
                lon[idet] -= offset
                lat[idet] -= offset
        if comm is not None:
            all_lon = comm.Gather(lon)
            all_lat = comm.Gather(lat)
            all_coeffs = comm.Gather(coeffs)
        else:
            all_lon = [lon]
            all_lat = [lat]
            all_coeffs = [coeffs]
        if comm is None or comm.rank == 0:
            lon = np.hstack(all_lon)
            lat = np.hstack(all_lat)
            coeffs = np.hstack(all_coeffs)
            fig = plt.figure(figsize=[12, 8])
            ax = fig.add_subplot(1, 1, 1)
            ax.set_title(f"obs = {ob.name}, key = {value}")
            amp = 0.15  # Need a smarter amplitude...
            p = ax.scatter(
                lon,
                lat,
                c=coeffs,
                vmin=1 - amp,
                vmax=1 + amp,
                edgecolors="k",
                cmap="bwr",
            )
            fig.colorbar(p)
            fig.savefig(f"coeffs_{ob.name}_{value}.png")
            plt.close()
        return

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        """Apply the common mode filter to the signal.

        Args:
            data (toast.Data): The distributed data.

        """
        if detectors is not None:
            raise RuntimeError("CommonModeFilter cannot be run in batch mode")

        log = Logger.get()
        timer = Timer()
        timer.start()
        pat = re.compile(self.pattern)

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        for obs in data.obs:
            comm, temp_ob = self._redistribute(data, obs, timer, log)

            focalplane = temp_ob.telescope.focalplane

            detectors = temp_ob.all_detectors
            if self.focalplane_key is None:
                values = [None]
            else:
                values = set()
                for det in detectors:
                    if pat.match(det) is None:
                        continue
                    values.add(focalplane[det][self.focalplane_key])
                values = sorted(values)

            nsample = temp_ob.n_local_samples
            ndet = len(temp_ob.local_detectors)

            # Loop over all values of the focalplane key
            for value in values:
                local_dets = []
                for det in temp_ob.local_detectors:
                    if temp_ob.local_detector_flags[det] & self.det_mask:
                        continue
                    if pat.match(det) is None:
                        continue
                    if (
                        value is not None
                        and focalplane[det][self.focalplane_key] != value
                    ):
                        continue
                    local_dets.append(det)

                # The indices into the detector data, which may be different than
                # the index into the full set of local detectors.
                data_indices = temp_ob.detdata[self.det_data].indices(local_dets)
                flag_indices = temp_ob.detdata[self.det_flags].indices(local_dets)

                # Average all detectors that match the key
                template = np.zeros(nsample)
                hits = np.zeros(nsample, dtype=np.int64)

                if self.shared_flags is not None:
                    shared_flags = temp_ob.shared[self.shared_flags].data
                else:
                    shared_flags = np.zeros(nsample, dtype=np.uint8)
                if self.det_flags is not None:
                    det_flags = temp_ob.detdata[self.det_flags].data
                else:
                    det_flags = np.zeros([ndet, nsample], dtype=np.uint8)

                sum_detectors(
                    data_indices,
                    flag_indices,
                    shared_flags,
                    self.shared_flag_mask,
                    temp_ob.detdata[self.det_data].data,
                    det_flags,
                    self.det_flag_mask,
                    template,
                    hits,
                )

                if comm is not None:
                    comm.Barrier()
                    comm.Allreduce(MPI.IN_PLACE, template, op=MPI.SUM)
                    comm.Allreduce(MPI.IN_PLACE, hits, op=MPI.SUM)

                if self.regress:
                    good = hits != 0
                    ngood = np.sum(good)
                    mean_template = template.copy()
                    mean_template[good] /= hits[good]
                    ndet, nsample = temp_ob.detdata[self.det_data].data.shape
                    coeffs = np.zeros(ndet)
                    templates = np.vstack([np.ones(ngood), mean_template[good]])
                    invcov = np.dot(templates, templates.T)
                    cov = np.linalg.inv(invcov)
                    for idet, iflag in zip(data_indices, flag_indices):
                        sig = temp_ob.detdata[self.det_data].data[idet]
                        sig_copy = sig[good].copy()
                        flg = det_flags[idet][good]
                        sig_copy[flg & self.det_flag_mask != 0] = 0
                        proj = np.dot(templates, sig_copy)
                        coeff = np.dot(cov, proj)
                        coeffs[idet] = coeff[1]
                        sig -= coeff[0] + coeff[1] * mean_template
                    if self.plot:
                        self._plot_coeff(temp_ob, coeffs, comm, value)
                else:
                    subtract_mean(
                        data_indices,
                        temp_ob.detdata[self.det_data].data,
                        template,
                        hits,
                    )

            log.debug_rank(
                f"{data.comm.group:4} : Commonfiltered observation in",
                comm=temp_ob.comm.comm_group,
                timer=timer,
            )

            self._re_redistribute(data, obs, timer, log, temp_ob)
            if self.redistribute:
                # In this case our temp_ob holds a copied subset of the
                # observation.  Clear it.
                temp_ob.clear()
                del temp_ob

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
        }
        return prov
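
When regress=True, each detector is fit against a constant offset and the mean template by solving the 2x2 normal equations, exactly as in the loop above. Stripped of the toast bookkeeping, the fit amounts to the following (the helper name is illustrative):

import numpy as np

def regress_common_mode(sig, template):
    # Model sig ~ c0 + c1 * template and subtract the best fit
    templates = np.vstack([np.ones_like(template), template])
    invcov = templates @ templates.T  # 2x2 normal matrix
    proj = templates @ sig
    c0, c1 = np.linalg.solve(invcov, proj)
    return sig - (c0 + c1 * template), c1  # cleaned signal, coupling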

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key to apply filtering to') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid | defaults.det_mask_processing, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid | defaults.det_mask_processing, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

focalplane_key = Unicode(None, allow_none=True, help='Which focalplane key to match') class-attribute instance-attribute

pattern = Unicode(f'.*', allow_none=True, help='Regex pattern to match against detector names. Only detectors that match the pattern are filtered.') class-attribute instance-attribute

plot = Bool(False, help='If True, plot regression coefficients') class-attribute instance-attribute

redistribute = Bool(False, help='If True, redistribute data before and after filtering for optimal data locality.') class-attribute instance-attribute

regress = Bool(False, help='If True, regress the common mode rather than subtract') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    return

_check_det_flag_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Apply the common mode filter to the signal.

Parameters:

data (toast.Data) -- The distributed data. Required.

Source code in toast/ops/polyfilter/polyfilter.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    """Apply the common mode filter to the signal.

    Args:
        data (toast.Data): The distributed data.

    """
    if detectors is not None:
        raise RuntimeError("CommonModeFilter cannot be run in batch mode")

    log = Logger.get()
    timer = Timer()
    timer.start()
    pat = re.compile(self.pattern)

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    for obs in data.obs:
        comm, temp_ob = self._redistribute(data, obs, timer, log)

        focalplane = temp_ob.telescope.focalplane

        detectors = temp_ob.all_detectors
        if self.focalplane_key is None:
            values = [None]
        else:
            values = set()
            for det in detectors:
                if pat.match(det) is None:
                    continue
                values.add(focalplane[det][self.focalplane_key])
            values = sorted(values)

        nsample = temp_ob.n_local_samples
        ndet = len(temp_ob.local_detectors)

        # Loop over all values of the focalplane key
        for value in values:
            local_dets = []
            for det in temp_ob.local_detectors:
                if temp_ob.local_detector_flags[det] & self.det_mask:
                    continue
                if pat.match(det) is None:
                    continue
                if (
                    value is not None
                    and focalplane[det][self.focalplane_key] != value
                ):
                    continue
                local_dets.append(det)

            # The indices into the detector data, which may be different than
            # the index into the full set of local detectors.
            data_indices = temp_ob.detdata[self.det_data].indices(local_dets)
            flag_indices = temp_ob.detdata[self.det_flags].indices(local_dets)

            # Average all detectors that match the key
            template = np.zeros(nsample)
            hits = np.zeros(nsample, dtype=np.int64)

            if self.shared_flags is not None:
                shared_flags = temp_ob.shared[self.shared_flags].data
            else:
                shared_flags = np.zeros(nsample, dtype=np.uint8)
            if self.det_flags is not None:
                det_flags = temp_ob.detdata[self.det_flags].data
            else:
                det_flags = np.zeros([ndet, nsample], dtype=np.uint8)

            sum_detectors(
                data_indices,
                flag_indices,
                shared_flags,
                self.shared_flag_mask,
                temp_ob.detdata[self.det_data].data,
                det_flags,
                self.det_flag_mask,
                template,
                hits,
            )

            if comm is not None:
                comm.Barrier()
                comm.Allreduce(MPI.IN_PLACE, template, op=MPI.SUM)
                comm.Allreduce(MPI.IN_PLACE, hits, op=MPI.SUM)

            if self.regress:
                good = hits != 0
                ngood = np.sum(good)
                mean_template = template.copy()
                mean_template[good] /= hits[good]
                ndet, nsample = temp_ob.detdata[self.det_data].data.shape
                coeffs = np.zeros(ndet)
                templates = np.vstack([np.ones(ngood), mean_template[good]])
                invcov = np.dot(templates, templates.T)
                cov = np.linalg.inv(invcov)
                for idet, iflag in zip(data_indices, flag_indices):
                    sig = temp_ob.detdata[self.det_data].data[idet]
                    sig_copy = sig[good].copy()
                    flg = det_flags[idet][good]
                    sig_copy[flg & self.det_flag_mask != 0] = 0
                    proj = np.dot(templates, sig_copy)
                    coeff = np.dot(cov, proj)
                    coeffs[idet] = coeff[1]
                    sig -= coeff[0] + coeff[1] * mean_template
                if self.plot:
                    self._plot_coeff(temp_ob, coeffs, comm, value)
            else:
                subtract_mean(
                    data_indices,
                    temp_ob.detdata[self.det_data].data,
                    template,
                    hits,
                )

        log.debug_rank(
            f"{data.comm.group:4} : Commonfiltered observation in",
            comm=temp_ob.comm.comm_group,
            timer=timer,
        )

        self._re_redistribute(data, obs, timer, log, temp_ob)
        if self.redistribute:
            # In this case our temp_ob holds a copied subset of the
            # observation.  Clear it.
            temp_ob.clear()
            del temp_ob

    return

_finalize(data, **kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
def _finalize(self, data, **kwargs):
    return

_plot_coeff(ob, coeffs, comm, value)

Source code in toast/ops/polyfilter/polyfilter.py
@function_timer
def _plot_coeff(self, ob, coeffs, comm, value):
    # Make a plot of the coupling coefficients
    import matplotlib.pyplot as plt

    ndet = len(coeffs)
    lon = np.zeros(ndet)
    lat = np.zeros(ndet)
    yrot = qa.rotation([0, 1, 0], np.pi / 2)
    for idet in range(ndet):
        name = ob.local_detectors[idet]
        quat = ob.telescope.focalplane[name]["quat"]
        theta, phi, psi = qa.to_iso_angles(qa.mult(yrot, quat))
        lon[idet] = np.degrees(phi)
        lat[idet] = np.degrees(theta - np.pi / 2)
        top = (psi % np.pi) < (np.pi / 2)
        offset = 0.002  # Need a smarter offset...
        if top:
            lon[idet] += offset
            lat[idet] += offset
        else:
            lon[idet] -= offset
            lat[idet] -= offset
    if comm is not None:
        all_lon = comm.Gather(lon)
        all_lat = comm.Gather(lat)
        all_coeffs = comm.Gather(coeffs)
    else:
        all_lon = [lon]
        all_lat = [lat]
        all_coeffs = [coeffs]
    if comm is None or comm.rank == 0:
        lon = np.hstack(all_lon)
        lat = np.hstack(all_lat)
        coeffs = np.hstack(all_coeffs)
        fig = plt.figure(figsize=[12, 8])
        ax = fig.add_subplot(1, 1, 1)
        ax.set_title(f"obs = {ob.name}, key = {value}")
        amp = 0.15  # Need a smarter amplitude...
        p = ax.scatter(
            lon,
            lat,
            c=coeffs,
            vmin=1 - amp,
            vmax=1 + amp,
            edgecolors="k",
            cmap="bwr",
        )
        fig.colorbar(p)
        fig.savefig(f"coeffs_{ob.name}_{value}.png")
        plt.close()
    return

_provides()

Source code in toast/ops/polyfilter/polyfilter.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
    }
    return prov

_re_redistribute(data, obs, timer, log, temp_ob)

Source code in toast/ops/polyfilter/polyfilter.py
@function_timer
def _re_redistribute(self, data, obs, timer, log, temp_ob):
    if self.redistribute:
        # Redistribute data back
        temp_ob.redistribute(
            obs.dist.process_rows,
            times=self.times,
            override_sample_sets=obs.dist.sample_sets,
        )
        log.debug_rank(
            f"{data.comm.group:4} : Re-redistributed observation in",
            comm=temp_ob.comm.comm_group,
            timer=timer,
        )
        # Copy data to original observation
        obs.detdata[self.det_data][:] = temp_ob.detdata[self.det_data][:]
        log.debug_rank(
            f"{data.comm.group:4} : Copied observation data in",
            comm=temp_ob.comm.comm_group,
            timer=timer,
        )
    return

_redistribute(data, obs, timer, log)

Source code in toast/ops/polyfilter/polyfilter.py
@function_timer
def _redistribute(self, data, obs, timer, log):
    if self.redistribute:
        # Redistribute the data so each process has all detectors for some sample
        # range.  Duplicate just the fields of the observation we will use.
        dup_shared = list()
        if self.shared_flags is not None:
            dup_shared.append(self.shared_flags)
        dup_detdata = [self.det_data]
        if self.det_flags is not None:
            dup_detdata.append(self.det_flags)
        dup_intervals = list()
        temp_ob = obs.duplicate(
            times=self.times,
            meta=list(),
            shared=dup_shared,
            detdata=dup_detdata,
            intervals=dup_intervals,
        )
        log.debug_rank(
            f"{data.comm.group:4} : Duplicated observation in",
            comm=temp_ob.comm.comm_group,
            timer=timer,
        )
        # Redistribute this temporary observation to be distributed by sample sets
        temp_ob.redistribute(1, times=self.times, override_sample_sets=None)
        log.debug_rank(
            f"{data.comm.group:4} : Redistributed observation in",
            comm=temp_ob.comm.comm_group,
            timer=timer,
        )
        comm = None
    else:
        comm = obs.comm_col
        temp_ob = obs

    return comm, temp_ob

_requires()

Source code in toast/ops/polyfilter/polyfilter.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

toast.ops.PolyFilter

Bases: Operator

Operator which applies polynomial filtering to the TOD.
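
A minimal usage sketch (assuming a populated toast.Data instance named data whose observations define the "throw" interval list, which is the default view):

import toast

# Independently remove an order-5 polynomial from each detector within
# every interval of the "throw" view; samples outside the view are
# flagged in the shared flags with poly_flag_mask
polyfilter = toast.ops.PolyFilter(
    det_data="signal",
    order=5,
    view="throw",
)
polyfilter.apply(data)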

Source code in toast/ops/polyfilter/polyfilter.py
@trait_docs
class PolyFilter(Operator):
    """Operator which applies polynomial filtering to the TOD."""

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key to apply filtering to"
    )

    pattern = Unicode(
        f".*",
        allow_none=True,
        help="Regex pattern to match against detector names. Only detectors that "
        "match the pattern are filtered.",
    )

    order = Int(1, allow_none=False, help="Polynomial order")

    det_mask = Int(
        defaults.det_mask_invalid | defaults.det_mask_processing,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid | defaults.det_mask_processing,
        help="Bit mask value for detector sample flagging",
    )

    poly_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Shared flag bit mask for samples outside of filtering view",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional shared flagging",
    )

    view = Unicode(
        "throw", allow_none=True, help="Use this view of the data in all observations"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        if self.pattern is None:
            pat = None
        else:
            pat = re.compile(self.pattern)

        for obs in data.obs:
            # Get the detectors we are using for this observation
            dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)

            if self.view is not None:
                if self.view not in obs.intervals:
                    msg = (
                        f"PolyFilter is configured to apply in the '{self.view}' view "
                        f"but it is not defined for observation '{obs.name}'"
                    )
                    raise RuntimeError(msg)
                local_starts = []
                local_stops = []
                for interval in obs.intervals[self.view]:
                    local_starts.append(interval.first)
                    local_stops.append(interval.last)
            else:
                local_starts = [0]
                local_stops = [obs.n_local_samples]

            local_starts = np.array(local_starts)
            local_stops = np.array(local_stops)

            if self.shared_flags is not None:
                shared_flags = (
                    obs.shared[self.shared_flags].data & self.shared_flag_mask
                )
            else:
                shared_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

            signals = []
            filter_dets = []
            last_flags = None
            in_place = True
            for det in dets:
                # Test the detector pattern
                if pat.match(det) is None:
                    continue

                ref = obs.detdata[self.det_data][det]
                if isinstance(ref[0], np.float64):
                    signal = ref
                else:
                    in_place = False
                    signal = np.array(ref, dtype=np.float64)
                if self.det_flags is not None:
                    det_flags = obs.detdata[self.det_flags][det] & self.det_flag_mask
                    flags = shared_flags | det_flags
                else:
                    flags = shared_flags

                if last_flags is None or np.all(last_flags == flags):
                    filter_dets.append(det)
                    signals.append(signal)
                else:
                    filter_polynomial(
                        self.order,
                        last_flags,
                        signals,
                        local_starts,
                        local_stops,
                        impl=implementation,
                        use_accel=use_accel,
                    )
                    if not in_place:
                        for fdet, x in zip(filter_dets, signals):
                            obs.detdata[self.det_data][fdet] = x
                    signals = [signal]
                    filter_dets = [det]
                last_flags = flags.copy()

            if len(signals) > 0:
                filter_polynomial(
                    self.order,
                    last_flags,
                    signals,
                    local_starts,
                    local_stops,
                    impl=implementation,
                    use_accel=use_accel,
                )
                if not in_place:
                    for fdet, x in zip(filter_dets, signals):
                        obs.detdata[self.det_data][fdet] = x

            # Optionally flag unfiltered data
            if self.shared_flags is not None and self.poly_flag_mask is not None:
                if obs.comm_col_rank != 0:
                    shared_flags = None
                else:
                    shared_flags = np.array(obs.shared[self.shared_flags])
                    not_filtered = np.ones(shared_flags.size, dtype=bool)
                    for start, stop in zip(local_starts, local_stops):
                        not_filtered[start : stop] = False
                    shared_flags[not_filtered] |= self.poly_flag_mask
                obs.shared[self.shared_flags].set(shared_flags, fromrank=0)
            if obs.comm.comm_group is not None:
                obs.comm.comm_group.barrier()

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
            "intervals": [self.view],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
        }
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key to apply filtering to') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid | defaults.det_mask_processing, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid | defaults.det_mask_processing, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

order = Int(1, allow_none=False, help='Polynomial order') class-attribute instance-attribute

pattern = Unicode(f'.*', allow_none=True, help='Regex pattern to match against detector names. Only detectors that match the pattern are filtered.') class-attribute instance-attribute

poly_flag_mask = Int(defaults.shared_mask_invalid, help='Shared flag bit mask for samples outside of filtering view') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

view = Unicode('throw', allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    return

_check_det_flag_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    if self.pattern is None:
        pat = None
    else:
        pat = re.compile(self.pattern)

    for obs in data.obs:
        # Get the detectors we are using for this observation
        dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)

        if self.view is not None:
            if self.view not in obs.intervals:
                msg = (
                    f"PolyFilter is configured to apply in the '{self.view}' view "
                    f"but it is not defined for observation '{obs.name}'"
                )
                raise RuntimeError(msg)
            local_starts = []
            local_stops = []
            for interval in obs.intervals[self.view]:
                local_starts.append(interval.first)
                local_stops.append(interval.last)
        else:
            local_starts = [0]
            local_stops = [obs.n_local_samples]

        local_starts = np.array(local_starts)
        local_stops = np.array(local_stops)

        if self.shared_flags is not None:
            shared_flags = (
                obs.shared[self.shared_flags].data & self.shared_flag_mask
            )
        else:
            shared_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

        signals = []
        filter_dets = []
        last_flags = None
        in_place = True
        for det in dets:
            # Test the detector pattern
            if pat.match(det) is None:
                continue

            ref = obs.detdata[self.det_data][det]
            if isinstance(ref[0], np.float64):
                signal = ref
            else:
                in_place = False
                signal = np.array(ref, dtype=np.float64)
            if self.det_flags is not None:
                det_flags = obs.detdata[self.det_flags][det] & self.det_flag_mask
                flags = shared_flags | det_flags
            else:
                flags = shared_flags

            if last_flags is None or np.all(last_flags == flags):
                filter_dets.append(det)
                signals.append(signal)
            else:
                filter_polynomial(
                    self.order,
                    last_flags,
                    signals,
                    local_starts,
                    local_stops,
                    impl=implementation,
                    use_accel=use_accel,
                )
                if not in_place:
                    for fdet, x in zip(filter_dets, signals):
                        obs.detdata[self.det_data][fdet] = x
                signals = [signal]
                filter_dets = [det]
            last_flags = flags.copy()

        if len(signals) > 0:
            filter_polynomial(
                self.order,
                last_flags,
                signals,
                local_starts,
                local_stops,
                impl=implementation,
                use_accel=use_accel,
            )
            if not in_place:
                for fdet, x in zip(filter_dets, signals):
                    obs.detdata[self.det_data][fdet] = x

        # Optionally flag unfiltered data
        if self.shared_flags is not None and self.poly_flag_mask is not None:
            if obs.comm_col_rank != 0:
                shared_flags = None
            else:
                shared_flags = np.array(obs.shared[self.shared_flags])
                not_filtered = np.ones(shared_flags.size, dtype=bool)
                for start, stop in zip(local_starts, local_stops):
                    not_filtered[start : stop] = False
                shared_flags[not_filtered] |= self.poly_flag_mask
            obs.shared[self.shared_flags].set(shared_flags, fromrank=0)
        if obs.comm.comm_group is not None:
            obs.comm.comm_group.barrier()

    return
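
The loop above batches consecutive detectors whose combined flag vectors are identical, so a single filter_polynomial call can process the whole batch against one mask. As a hedged illustration only (not the compiled kernel itself), a pure-NumPy sketch of the per-interval regression might look like the following, assuming half-open [start, stop) sample ranges:

import numpy as np
from numpy.polynomial import polynomial as P

def filter_polynomial_sketch(order, flags, signals, starts, stops):
    # Illustrative stand-in for the compiled kernel: for each range, fit an
    # order-`order` polynomial to the unflagged samples of every signal and
    # subtract the evaluated fit in place.
    for start, stop in zip(starts, stops):
        good = flags[start:stop] == 0
        if np.sum(good) <= order:
            continue
        x = np.linspace(-1.0, 1.0, stop - start)
        for sig in signals:
            coef = P.polyfit(x[good], sig[start:stop][good], order)
            sig[start:stop] -= P.polyval(x, coef)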

_finalize(data, **kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/polyfilter/polyfilter.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
    }
    return prov

_requires()

Source code in toast/ops/polyfilter/polyfilter.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
        "intervals": [self.view],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req
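
A minimal usage sketch, assuming `data` is an already-populated toast.Data instance and that each observation defines a `scanning` interval list (both names are placeholders for whatever your pipeline provides):

import toast.ops

polyfilter = toast.ops.PolyFilter(
    order=1,            # subtract a linear baseline
    det_data="signal",  # detdata key to filter in place
    view="scanning",    # filter each interval independently
    pattern=".*",       # apply to every detector
)
polyfilter.apply(data)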

toast.ops.PolyFilter2D

Bases: Operator

Operator to regress out 2D polynomials across the focal plane.

Source code in toast/ops/polyfilter/polyfilter.py
@trait_docs
class PolyFilter2D(Operator):
    """Operator to regress out 2D polynomials across the focal plane."""

    API = Int(0, help="Internal interface version for this operator")

    times = Unicode(
        defaults.times,
        help="Observation shared key for timestamps",
    )

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key apply filtering to",
    )

    pattern = Unicode(
        f".*",
        allow_none=True,
        help="Regex pattern to match against detector names. Only detectors that "
        "match the pattern are filtered.",
    )

    order = Int(1, allow_none=False, help="Polynomial order")

    det_mask = Int(
        defaults.det_mask_invalid | defaults.det_mask_processing,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid | defaults.det_mask_processing,
        help="Bit mask value for detector sample flagging",
    )

    poly_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for samples that fail to filter",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for optional shared flagging",
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    focalplane_key = Unicode(
        None, allow_none=True, help="Which focalplane key to match"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        gt = GlobalTimers.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        if detectors is not None:
            raise RuntimeError("PolyFilter2D cannot be run on subsets of detectors")
        norder = self.order + 1
        nmode = norder**2
        pat = re.compile(self.pattern)

        for obs in data.obs:
            # Get the original number of process rows in the observation
            proc_rows = obs.dist.process_rows

            # Duplicate just the fields of the observation we will use
            gt.start("Poly2D:  Duplicate obs")
            dup_shared = list()
            if self.shared_flags is not None:
                dup_shared.append(self.shared_flags)
            dup_detdata = [self.det_data]
            if self.det_flags is not None:
                dup_detdata.append(self.det_flags)
            dup_intervals = list()
            if self.view is not None:
                dup_intervals.append(self.view)
            temp_ob = obs.duplicate(
                times=self.times,
                meta=list(),
                shared=dup_shared,
                detdata=dup_detdata,
                intervals=dup_intervals,
            )
            gt.stop("Poly2D:  Duplicate obs")

            # Redistribute this temporary observation to be distributed by samples
            gt.start("Poly2D:  Forward redistribution")
            temp_ob.redistribute(1, times=self.times, override_sample_sets=None)
            gt.stop("Poly2D:  Forward redistribution")

            gt.start("Poly2D:  Detector setup")

            # Detectors to process.  We apply the detector flag mask to this
            # selection, in order to avoid bad detectors in the fit.
            detectors = []
            for det in temp_ob.select_local_detectors(
                selection=None, flagmask=self.det_mask
            ):
                if pat.match(det) is None:
                    continue
                detectors.append(det)
            ndet = len(detectors)
            if ndet == 0:
                continue

            # Detector positions

            detector_position = {}
            for det in detectors:
                x, y, z = qa.rotate(temp_ob.telescope.focalplane[det]["quat"], ZAXIS)
                theta, phi = np.arcsin([x, y])
                detector_position[det] = [theta, phi]

            # Enumerate detector groups (e.g. wafers) to filter

            # The integer group ID for a given detector
            group_index = {}

            # The list of detectors for each group key
            groups = {}

            # The integer group ID for each group key
            group_ids = {}

            if self.focalplane_key is None:
                # We have just one group of all detectors
                groups[None] = []
                group_ids[None] = 0
                ngroup = 1
                for det in detectors:
                    group_index[det] = 0
                    groups[None].append(det)
            else:
                focalplane = temp_ob.telescope.focalplane
                if self.focalplane_key not in focalplane.detector_data.colnames:
                    msg = (
                        f"Cannot divide detectors by {self.focalplane_key} because "
                        "it is not defined in the focalplane detector data."
                    )
                    raise RuntimeError(msg)
                for det in detectors:
                    value = focalplane[det][self.focalplane_key]
                    if value not in groups:
                        groups[value] = []
                    groups[value].append(det)
                ngroup = len(groups)
                for igroup, group in enumerate(sorted(groups)):
                    group_ids[group] = igroup
                    for det in groups[group]:
                        group_index[det] = igroup

            # Enumerate detectors to process

            # Mapping from good detector name to index
            detector_index = {y: x for x, y in enumerate(detectors)}

            # Measure offset for each group, translate and scale
            # detector positions to [-1, 1]

            all_positions = []
            for group, detectors_group in groups.items():
                ndet_group = len(detectors_group)
                theta_offset, phi_offset = 0, 0
                for det in detectors_group:
                    theta, phi = detector_position[det]
                    theta_offset += theta
                    phi_offset += phi
                theta_offset /= ndet_group
                phi_offset /= ndet_group
                for det in detectors_group:
                    theta, phi = detector_position[det]
                    detector_position[det] = [theta - theta_offset, phi - phi_offset]
                    all_positions.append(detector_position[det])

            thetavec, phivec = np.vstack(all_positions).T
            thetamax = np.amax(np.abs(thetavec))
            phimax = np.amax(np.abs(phivec))
            scale = 0.999 / max(thetamax, phimax)

            for det in detectors:
                theta, phi = detector_position[det]
                detector_position[det] = [theta * scale, phi * scale]

            # Now evaluate the polynomial templates at the sites of
            # each detector

            orders = np.arange(norder)
            xorders, yorders = np.meshgrid(orders, orders, indexing="ij")
            xorders = xorders.ravel()
            yorders = yorders.ravel()

            detector_templates = np.zeros([ndet, nmode])
            for det in detectors:
                idet = detector_index[det]
                theta, phi = detector_position[det]
                detector_templates[idet] = theta**xorders * phi**yorders

            gt.stop("Poly2D:  Detector setup")

            # Iterate over each interval

            # Aligned memory objects using C-allocated memory so that we
            # can explicitly free it after processing.
            template_mem = AlignedF64()
            mask_mem = AlignedU8()
            signal_mem = AlignedF64()
            coeff_mem = AlignedF64()

            views = temp_ob.intervals[self.view]
            for iview, view in enumerate(views):
                nsample = view.last - view.first
                vslice = slice(view.first, view.last)

                # Accumulate the linear regression templates

                gt.start("Poly2D:  Accumulate templates")

                template_mem.resize(ndet * nmode)
                template_mem[:] = 0
                templates = template_mem.array().reshape((ndet, nmode))

                mask_mem.resize(nsample * ndet)
                mask_mem[:] = 0
                masks = mask_mem.array().reshape((nsample, ndet))

                signal_mem.resize(nsample * ndet)
                signal_mem[:] = 0
                signals = signal_mem.array().reshape((nsample, ndet))

                coeff_mem.resize(nsample * ngroup * nmode)
                coeff_mem[:] = 0
                coeff = coeff_mem.array().reshape((nsample, ngroup, nmode))

                det_groups = -1 * np.ones(ndet, dtype=np.int32)

                if self.shared_flags is not None:
                    shared_flags = temp_ob.shared[self.shared_flags][vslice]
                    shared_mask = (shared_flags & self.shared_flag_mask) == 0
                else:
                    shared_mask = np.ones(nsample, dtype=bool)

                for det in detectors:
                    ind_det = detector_index[det]
                    ind_group = group_index[det]
                    det_groups[ind_det] = ind_group

                    signal = temp_ob.detdata[self.det_data][det, vslice]
                    if self.det_flags is not None:
                        det_flags = temp_ob.detdata[self.det_flags][det, vslice]
                        det_mask = (det_flags & self.det_flag_mask) == 0
                        mask = np.logical_and(shared_mask, det_mask)
                    else:
                        mask = shared_mask

                    templates[ind_det, :] = detector_templates[ind_det]
                    masks[:, ind_det] = mask
                    signals[:, ind_det] = signal * mask

                gt.stop("Poly2D:  Accumulate templates")

                gt.start("Poly2D:  Solve templates")
                filter_poly2D(
                    det_groups,
                    templates,
                    signals,
                    masks,
                    coeff,
                    impl=implementation,
                    use_accel=use_accel,
                )
                gt.stop("Poly2D:  Solve templates")

                gt.start("Poly2D:  Update detector flags")

                for igroup in range(ngroup):
                    dets_in_group = np.zeros(ndet, dtype=bool)
                    for idet, det in enumerate(detectors):
                        if group_index[det] == igroup:
                            dets_in_group[idet] = True
                    if not np.any(dets_in_group):
                        continue
                    if self.det_flags is not None:
                        sample_flags = np.ones(
                            ndet,
                            dtype=temp_ob.detdata[self.det_flags].dtype,
                        )
                        sample_flags *= self.poly_flag_mask
                        sample_flags *= dets_in_group
                        for isample in range(nsample):
                            if np.all(coeff[isample, igroup] == 0):
                                for idet, det in enumerate(detectors):
                                    temp_ob.detdata[self.det_flags][
                                        det, view.first + isample
                                    ] |= sample_flags[idet]

                gt.stop("Poly2D:  Update detector flags")

                gt.start("Poly2D:  Clean timestreams")

                trcoeff = np.transpose(
                    np.array(coeff), [1, 0, 2]
                )  # ngroup x nsample x nmode
                trmasks = np.array(masks).T  # ndet x nsample
                for idet, det in enumerate(detectors):
                    igroup = group_index[det]
                    ind = detector_index[det]
                    signal = temp_ob.detdata[self.det_data][det, vslice]
                    mask = trmasks[idet]
                    signal -= np.sum(trcoeff[igroup] * templates[ind], 1) * mask

                gt.stop("Poly2D:  Clean timestreams")

            # Redistribute back
            gt.start("Poly2D:  Backward redistribution")
            temp_ob.redistribute(
                proc_rows, times=self.times, override_sample_sets=obs.dist.sample_sets
            )
            gt.stop("Poly2D:  Backward redistribution")

            # Copy data to original observation
            gt.start("Poly2D:  Copy output")
            for det in obs.select_local_detectors(
                selection=None, flagmask=self.det_mask
            ):
                obs.detdata[self.det_data][det] = temp_ob.detdata[self.det_data][det]
                if self.det_flags is not None:
                    obs.detdata[self.det_flags][det] = temp_ob.detdata[self.det_flags][
                        det
                    ]
            gt.stop("Poly2D:  Copy output")

            # Free data copy
            temp_ob.clear()
            del temp_ob

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
        }
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key to apply filtering to') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid | defaults.det_mask_processing, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid | defaults.det_mask_processing, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

focalplane_key = Unicode(None, allow_none=True, help='Which focalplane key to match') class-attribute instance-attribute

order = Int(1, allow_none=False, help='Polynomial order') class-attribute instance-attribute

pattern = Unicode('.*', allow_none=True, help='Regex pattern to match against detector names. Only detectors that match the pattern are filtered.') class-attribute instance-attribute

poly_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for samples that fail to filter') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    return

_check_det_flag_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/polyfilter/polyfilter.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    gt = GlobalTimers.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    if detectors is not None:
        raise RuntimeError("PolyFilter2D cannot be run on subsets of detectors")
    norder = self.order + 1
    nmode = norder**2
    pat = re.compile(self.pattern)

    for obs in data.obs:
        # Get the original number of process rows in the observation
        proc_rows = obs.dist.process_rows

        # Duplicate just the fields of the observation we will use
        gt.start("Poly2D:  Duplicate obs")
        dup_shared = list()
        if self.shared_flags is not None:
            dup_shared.append(self.shared_flags)
        dup_detdata = [self.det_data]
        if self.det_flags is not None:
            dup_detdata.append(self.det_flags)
        dup_intervals = list()
        if self.view is not None:
            dup_intervals.append(self.view)
        temp_ob = obs.duplicate(
            times=self.times,
            meta=list(),
            shared=dup_shared,
            detdata=dup_detdata,
            intervals=dup_intervals,
        )
        gt.stop("Poly2D:  Duplicate obs")

        # Redistribute this temporary observation to be distributed by samples
        gt.start("Poly2D:  Forward redistribution")
        temp_ob.redistribute(1, times=self.times, override_sample_sets=None)
        gt.stop("Poly2D:  Forward redistribution")

        gt.start("Poly2D:  Detector setup")

        # Detectors to process.  We apply the detector flag mask to this
        # selection, in order to avoid bad detectors in the fit.
        detectors = []
        for det in temp_ob.select_local_detectors(
            selection=None, flagmask=self.det_mask
        ):
            if pat.match(det) is None:
                continue
            detectors.append(det)
        ndet = len(detectors)
        if ndet == 0:
            continue

        # Detector positions

        detector_position = {}
        for det in detectors:
            x, y, z = qa.rotate(temp_ob.telescope.focalplane[det]["quat"], ZAXIS)
            theta, phi = np.arcsin([x, y])
            detector_position[det] = [theta, phi]

        # Enumerate detector groups (e.g. wafers) to filter

        # The integer group ID for a given detector
        group_index = {}

        # The list of detectors for each group key
        groups = {}

        # The integer group ID for each group key
        group_ids = {}

        if self.focalplane_key is None:
            # We have just one group of all detectors
            groups[None] = []
            group_ids[None] = 0
            ngroup = 1
            for det in detectors:
                group_index[det] = 0
                groups[None].append(det)
        else:
            focalplane = temp_ob.telescope.focalplane
            if self.focalplane_key not in focalplane.detector_data.colnames:
                msg = (
                    f"Cannot divide detectors by {self.focalplane_key} because "
                    "it is not defined in the focalplane detector data."
                )
                raise RuntimeError(msg)
            for det in detectors:
                value = focalplane[det][self.focalplane_key]
                if value not in groups:
                    groups[value] = []
                groups[value].append(det)
            ngroup = len(groups)
            for igroup, group in enumerate(sorted(groups)):
                group_ids[group] = igroup
                for det in groups[group]:
                    group_index[det] = igroup

        # Enumerate detectors to process

        # Mapping from good detector name to index
        detector_index = {y: x for x, y in enumerate(detectors)}

        # Measure offset for each group, translate and scale
        # detector positions to [-1, 1]

        all_positions = []
        for group, detectors_group in groups.items():
            ndet_group = len(detectors_group)
            theta_offset, phi_offset = 0, 0
            for det in detectors_group:
                theta, phi = detector_position[det]
                theta_offset += theta
                phi_offset += phi
            theta_offset /= ndet_group
            phi_offset /= ndet_group
            for det in detectors_group:
                theta, phi = detector_position[det]
                detector_position[det] = [theta - theta_offset, phi - phi_offset]
                all_positions.append(detector_position[det])

        thetavec, phivec = np.vstack(all_positions).T
        thetamax = np.amax(np.abs(thetavec))
        phimax = np.amax(np.abs(phivec))
        scale = 0.999 / max(thetamax, phimax)

        for det in detectors:
            theta, phi = detector_position[det]
            detector_position[det] = [theta * scale, phi * scale]

        # Now evaluate the polynomial templates at the sites of
        # each detector

        orders = np.arange(norder)
        xorders, yorders = np.meshgrid(orders, orders, indexing="ij")
        xorders = xorders.ravel()
        yorders = yorders.ravel()

        detector_templates = np.zeros([ndet, nmode])
        for det in detectors:
            idet = detector_index[det]
            theta, phi = detector_position[det]
            detector_templates[idet] = theta**xorders * phi**yorders

        gt.stop("Poly2D:  Detector setup")

        # Iterate over each interval

        # Aligned memory objects using C-allocated memory so that we
        # can explicitly free it after processing.
        template_mem = AlignedF64()
        mask_mem = AlignedU8()
        signal_mem = AlignedF64()
        coeff_mem = AlignedF64()

        views = temp_ob.intervals[self.view]
        for iview, view in enumerate(views):
            nsample = view.last - view.first
            vslice = slice(view.first, view.last)

            # Accumulate the linear regression templates

            gt.start("Poly2D:  Accumulate templates")

            template_mem.resize(ndet * nmode)
            template_mem[:] = 0
            templates = template_mem.array().reshape((ndet, nmode))

            mask_mem.resize(nsample * ndet)
            mask_mem[:] = 0
            masks = mask_mem.array().reshape((nsample, ndet))

            signal_mem.resize(nsample * ndet)
            signal_mem[:] = 0
            signals = signal_mem.array().reshape((nsample, ndet))

            coeff_mem.resize(nsample * ngroup * nmode)
            coeff_mem[:] = 0
            coeff = coeff_mem.array().reshape((nsample, ngroup, nmode))

            det_groups = -1 * np.ones(ndet, dtype=np.int32)

            if self.shared_flags is not None:
                shared_flags = temp_ob.shared[self.shared_flags][vslice]
                shared_mask = (shared_flags & self.shared_flag_mask) == 0
            else:
                shared_mask = np.ones(nsample, dtype=bool)

            for det in detectors:
                ind_det = detector_index[det]
                ind_group = group_index[det]
                det_groups[ind_det] = ind_group

                signal = temp_ob.detdata[self.det_data][det, vslice]
                if self.det_flags is not None:
                    det_flags = temp_ob.detdata[self.det_flags][det, vslice]
                    det_mask = (det_flags & self.det_flag_mask) == 0
                    mask = np.logical_and(shared_mask, det_mask)
                else:
                    mask = shared_mask

                templates[ind_det, :] = detector_templates[ind_det]
                masks[:, ind_det] = mask
                signals[:, ind_det] = signal * mask

            gt.stop("Poly2D:  Accumulate templates")

            gt.start("Poly2D:  Solve templates")
            filter_poly2D(
                det_groups,
                templates,
                signals,
                masks,
                coeff,
                impl=implementation,
                use_accel=use_accel,
            )
            gt.stop("Poly2D:  Solve templates")

            gt.start("Poly2D:  Update detector flags")

            for igroup in range(ngroup):
                dets_in_group = np.zeros(ndet, dtype=bool)
                for idet, det in enumerate(detectors):
                    if group_index[det] == igroup:
                        dets_in_group[idet] = True
                if not np.any(dets_in_group):
                    continue
                if self.det_flags is not None:
                    sample_flags = np.ones(
                        ndet,
                        dtype=temp_ob.detdata[self.det_flags].dtype,
                    )
                    sample_flags *= self.poly_flag_mask
                    sample_flags *= dets_in_group
                    for isample in range(nsample):
                        if np.all(coeff[isample, igroup] == 0):
                            for idet, det in enumerate(detectors):
                                temp_ob.detdata[self.det_flags][
                                    det, view.first + isample
                                ] |= sample_flags[idet]

            gt.stop("Poly2D:  Update detector flags")

            gt.start("Poly2D:  Clean timestreams")

            trcoeff = np.transpose(
                np.array(coeff), [1, 0, 2]
            )  # ngroup x nsample x nmode
            trmasks = np.array(masks).T  # ndet x nsample
            for idet, det in enumerate(detectors):
                igroup = group_index[det]
                ind = detector_index[det]
                signal = temp_ob.detdata[self.det_data][det, vslice]
                mask = trmasks[idet]
                signal -= np.sum(trcoeff[igroup] * templates[ind], 1) * mask

            gt.stop("Poly2D:  Clean timestreams")

        # Redistribute back
        gt.start("Poly2D:  Backward redistribution")
        temp_ob.redistribute(
            proc_rows, times=self.times, override_sample_sets=obs.dist.sample_sets
        )
        gt.stop("Poly2D:  Backward redistribution")

        # Copy data to original observation
        gt.start("Poly2D:  Copy output")
        for det in obs.select_local_detectors(
            selection=None, flagmask=self.det_mask
        ):
            obs.detdata[self.det_data][det] = temp_ob.detdata[self.det_data][det]
            if self.det_flags is not None:
                obs.detdata[self.det_flags][det] = temp_ob.detdata[self.det_flags][
                    det
                ]
        gt.stop("Poly2D:  Copy output")

        # Free data copy
        temp_ob.clear()
        del temp_ob

_finalize(data, **kwargs)

Source code in toast/ops/polyfilter/polyfilter.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/polyfilter/polyfilter.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
    }
    return prov

_requires()

Source code in toast/ops/polyfilter/polyfilter.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req
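
A minimal usage sketch, assuming `data` is a populated toast.Data instance and that the focalplane detector data contains a `wafer` column to group detectors by (both assumptions depend on your instrument model):

import toast.ops

poly2d = toast.ops.PolyFilter2D(
    order=2,                 # 2D polynomial order across the focal plane
    focalplane_key="wafer",  # fit each wafer separately
)
poly2d.apply(data)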

Specialized Filters

toast.ops.GroundFilter

Bases: Operator

Operator that applies ground template filtering to azimuthal scans.
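
A minimal usage sketch, assuming `data` holds ground-based observations with an azimuth shared field; all trait values shown are illustrative choices, not requirements:

import toast.ops

groundfilter = toast.ops.GroundFilter(
    trend_order=5,         # Legendre trend fit along with the template
    filter_order=5,        # Legendre order in azimuth
    detrend=True,          # also subtract the fitted trend
    split_template=False,  # one template for both scan directions
)
groundfilter.apply(data)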

Source code in toast/ops/groundfilter.py
@trait_docs
class GroundFilter(Operator):
    """Operator that applies ground template filtering to azimuthal scans."""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key",
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for optional shared flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    ground_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask to use when adding flags based on ground filter failures.",
    )

    azimuth = Unicode(
        defaults.azimuth, allow_none=True, help="Observation shared key for Azimuth"
    )

    boresight_azel = Unicode(
        defaults.boresight_azel,
        allow_none=True,
        help="Observation shared key for boresight Az/El",
    )

    trend_order = Int(
        5, help="Order of a Legendre polynomial to fit along with the ground template."
    )

    filter_order = Int(
        5, help="Order of a Legendre polynomial to fit as a function of azimuth."
    )

    detrend = Bool(
        False, help="Subtract the fitted trend along with the ground template"
    )

    split_template = Bool(
        False, help="Apply a different template for left and right scans"
    )

    leftright_interval = Unicode(
        defaults.throw_leftright_interval,
        help="Intervals for left-to-right scans",
    )

    rightleft_interval = Unicode(
        defaults.throw_rightleft_interval,
        help="Intervals for right-to-left scans",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("trend_order")
    def _check_trend_order(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Trend order should be a non-negative integer")
        return check

    @traitlets.validate("filter_order")
    def _check_filter_order(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Filter order should be a non-negative integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def build_templates(self, obs):
        """Construct the local ground template hierarchy"""

        views = obs.view[self.view]

        # Construct trend templates.  Full domain for x is [-1, 1]

        my_offset = obs.local_index_offset
        my_nsamp = obs.n_local_samples
        nsamp_tot = obs.n_all_samples
        x = np.arange(my_offset, my_offset + my_nsamp) / nsamp_tot * 2 - 1

        # Do not include the offset in the trend.  It will be part
        # of the ground template
        legendre_trend = np.zeros([self.trend_order, x.size])
        legendre(x, legendre_trend, 1, self.trend_order + 1)

        try:
            if self.azimuth is not None:
                az = obs.shared[self.azimuth]
            else:
                quats = obs.shared[self.boresight_azel]
                theta, phi, _ = qa.to_iso_angles(quats)
                az = 2 * np.pi - phi
            if "scan_min_az" in obs:
                azmin = obs["scan_min_az"].to_value(u.radian)
                azmax = obs["scan_max_az"].to_value(u.radian)
            else:
                azmin = np.amin(az)
                azmax = np.amax(az)
                comm = obs.comm.comm_group
                if comm is not None:
                    azmin = comm.allreduce(azmin, op=MPI.MIN)
                    azmax = comm.allreduce(azmax, op=MPI.MAX)
                obs["scan_min_az"] = azmin * u.radian
                obs["scan_max_az"] = azmax * u.radian
        except Exception as e:
            msg = (
                f"Failed to get boresight azimuth from TOD.  "
                f"Perhaps it is not ground TOD? '{e}'"
            )
            raise RuntimeError(msg)

        # The azimuth vector is assumed to be arranged so that the
        # azimuth increases monotonically even across the zero meridian.

        phase = (np.unwrap(az) - azmin) / (azmax - azmin) * 2 - 1
        nfilter = self.filter_order + 1
        legendre_templates = np.zeros([nfilter, phase.size])
        legendre(phase, legendre_templates, 0, nfilter)
        if not self.split_template:
            legendre_filter = legendre_templates
        else:
            # Create separate templates for alternating scans
            common_flags = obs.shared[self.shared_flags].data
            legendre_filter = []
            masks = []
            for name in self.leftright_interval, self.rightleft_interval:
                mask = np.zeros(phase.size, dtype=bool)
                for ival in obs.intervals[name]:
                    mask[ival.first : ival.last] = True
                masks.append(mask)
            for template in legendre_templates:
                for mask in masks:
                    temp = template.copy()
                    temp[mask] = 0
                    legendre_filter.append(temp)
            legendre_filter = np.vstack(legendre_filter)

        templates = np.vstack([legendre_trend, legendre_filter])

        return templates, legendre_trend, legendre_filter

    @function_timer
    def fit_templates(
        self,
        obs,
        templates,
        ref,
        good,
        last_good,
        last_invcov,
        last_cov,
        last_rcond,
    ):
        log = Logger.get()
        # communicator for processes with the same detectors
        comm = obs.comm_row
        ngood = np.sum(good)
        ntask = 1
        if comm is not None:
            ngood = comm.allreduce(ngood)
            ntask = comm.size
        if ngood == 0:
            return None, None, None, None

        ntemplate = len(templates)
        invcov = np.zeros([ntemplate, ntemplate])
        proj = np.zeros(ntemplate)

        bin_proj_fast(ref, templates, good.astype(np.uint8), proj)
        if last_good is not None and np.all(good == last_good) and ntask == 1:
            # Flags have not changed, we can re-use the last inverse covariance
            invcov = last_invcov
            cov = last_cov
            rcond = last_rcond
        else:
            bin_invcov_fast(templates, good.astype(np.uint8), invcov)
            if comm is not None:
                # Reduce the binned data.  The detector signal is
                # distributed across the group communicator.
                comm.Allreduce(MPI.IN_PLACE, invcov, op=MPI.SUM)
                comm.Allreduce(MPI.IN_PLACE, proj, op=MPI.SUM)
            rcond = get_rcond(invcov)
            cov = None

        self.rcondsum += rcond
        if rcond > 1e-6:
            self.ngood += 1
            if cov is None:
                cov = get_inverse(invcov)
        else:
            self.nsingular += 1
            log.debug(
                f"Ground template matrix is poorly conditioned, "
                f"rcond = {rcond}, using pseudoinverse."
            )
            if cov is None:
                cov = get_pseudoinverse(invcov)
        coeff = np.dot(cov, proj)

        return coeff, invcov, cov, rcond

    @function_timer
    def subtract_templates(self, ref, good, coeff, legendre_trend, legendre_filter):
        # Trend
        if self.detrend:
            trend = np.zeros_like(ref)
            add_templates(trend, legendre_trend, coeff[: self.trend_order])
            ref -= trend
        # Ground template
        grtemplate = np.zeros(ref.size, dtype=np.float64)
        add_templates(grtemplate, legendre_filter, coeff[self.trend_order :])
        ref -= grtemplate
        return

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        t0 = time()
        env = Environment.get()
        log = Logger.get()

        wcomm = data.comm.comm_world
        gcomm = data.comm.comm_group

        self.nsingular = 0
        self.ngood = 0
        self.rcondsum = 0

        # Each group loops over its own CESes (constant elevation scans)
        nobs = len(data.obs)
        for iobs, obs in enumerate(data.obs):
            # Prefix for logging
            log_prefix = f"{data.comm.group} : {obs.name} :"

            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} OpGroundFilter: "
                    f"Processing observation {iobs + 1} / {nobs}"
                )
                log.debug(msg)

            # Cache the output common flags
            if self.shared_flags is not None:
                common_flags = (
                    obs.shared[self.shared_flags].data & self.shared_flag_mask
                )
            else:
                common_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

            t1 = time()
            templates, legendre_trend, legendre_filter = self.build_templates(obs)
            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} OpGroundFilter: "
                    f"Built templates in {time() - t1:.1f}s"
                )
                log.debug(msg)

            last_good = None
            last_invcov = None
            last_cov = None
            last_rcond = None

            for det in obs.select_local_detectors(detectors, flagmask=self.det_mask):
                if data.comm.group_rank == 0:
                    msg = f"{log_prefix} OpGroundFilter: " f"Processing detector {det}"
                    log.verbose(msg)

                ref = obs.detdata[self.det_data][det]
                if self.det_flags is not None:
                    test_flags = obs.detdata[self.det_flags][det] & self.det_flag_mask
                    good = np.logical_and(common_flags == 0, test_flags == 0)
                else:
                    good = common_flags == 0

                t1 = time()
                coeff, last_invcov, last_cov, last_rcond = self.fit_templates(
                    obs,
                    templates,
                    ref,
                    good,
                    last_good,
                    last_invcov,
                    last_cov,
                    last_rcond,
                )
                last_good = good
                if data.comm.group_rank == 0:
                    msg = (
                        f"{log_prefix} OpGroundFilter: "
                        f"Fit templates in {time() - t1:.1f}s"
                    )
                    log.verbose(msg)

                if coeff is None:
                    # All samples flagged or template fit failed.
                    curflag = obs.local_detector_flags[det]
                    obs.update_local_detector_flags(
                        {det: curflag | self.ground_flag_mask}
                    )
                    continue

                t1 = time()
                self.subtract_templates(
                    ref, good, coeff, legendre_trend, legendre_filter
                )
                if data.comm.group_rank == 0:
                    msg = (
                        f"{log_prefix} OpGroundFilter: "
                        f"Subtract templates in {time() - t1:.1f}s"
                    )
                    log.verbose(msg)
            del last_good
            del last_invcov
            del last_cov
            del last_rcond

        if wcomm is not None:
            self.nsingular = wcomm.allreduce(self.nsingular)
            self.ngood = wcomm.allreduce(self.ngood)
            self.rcondsum = wcomm.allreduce(self.rcondsum)

        if wcomm is None or wcomm.rank == 0:
            rcond_mean = self.rcondsum / (self.nsingular + self.ngood)
            msg = (
                f"Applied ground filter in {time() - t0:.1f} s.  "
                f"Average rcond of template matrix was {rcond_mean}"
            )
            log.debug(msg)

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "shared": list(),
            "detdata": [self.det_data],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.azimuth is not None:
            req["shared"].append(self.azimuth)
        if self.boresight_azel is not None:
            req["shared"].append(self.boresight_azel)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        return dict()
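
The fit_templates method above accumulates the projection of each detector timestream onto the templates and the template inverse covariance, reduces both across the row communicator, and falls back to a pseudoinverse when the matrix is poorly conditioned. A serial NumPy sketch of the core regression (ignoring MPI, caching, and the rcond fallback; the helper name is illustrative):

import numpy as np

def fit_and_subtract_sketch(templates, tod, good):
    # templates: (ntemplate, nsamp) trend + ground templates
    # tod:       (nsamp,) detector timestream
    # good:      (nsamp,) boolean mask of unflagged samples
    masked = templates * good              # zero out flagged samples
    invcov = masked @ templates.T          # template covariance on good samples
    proj = masked @ tod                    # projection of the data onto templates
    coeff = np.linalg.lstsq(invcov, proj, rcond=None)[0]
    return tod - templates.T @ coeff       # cleaned timestream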

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

azimuth = Unicode(defaults.azimuth, allow_none=True, help='Observation shared key for Azimuth') class-attribute instance-attribute

boresight_azel = Unicode(defaults.boresight_azel, allow_none=True, help='Observation shared key for boresight Az/El') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

detrend = Bool(False, help='Subtract the fitted trend along with the ground template') class-attribute instance-attribute

filter_order = Int(5, help='Order of a Legendre polynomial to fit as a function of azimuth.') class-attribute instance-attribute

ground_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask to use when adding flags based on ground filter failures.') class-attribute instance-attribute

leftright_interval = Unicode(defaults.throw_leftright_interval, help='Intervals for left-to-right scans') class-attribute instance-attribute

rightleft_interval = Unicode(defaults.throw_rightleft_interval, help='Intervals for right-to-left scans') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

split_template = Bool(False, help='Apply a different template for left and right scans') class-attribute instance-attribute

trend_order = Int(5, help='Order of a Legendre polynomial to fit along with the ground template.') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/groundfilter.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_flag_mask(proposal)

Source code in toast/ops/groundfilter.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/groundfilter.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_filter_order(proposal)

Source code in toast/ops/groundfilter.py
@traitlets.validate("filter_order")
def _check_filter_order(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Filter order should be a non-negative integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/groundfilter.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_trend_order(proposal)

Source code in toast/ops/groundfilter.py
@traitlets.validate("trend_order")
def _check_trend_order(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Trend order should be a non-negative integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/groundfilter.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    t0 = time()
    env = Environment.get()
    log = Logger.get()

    wcomm = data.comm.comm_world
    gcomm = data.comm.comm_group

    self.nsingular = 0
    self.ngood = 0
    self.rcondsum = 0

    # Each group loops over its own CESes (constant elevation scans)
    nobs = len(data.obs)
    for iobs, obs in enumerate(data.obs):
        # Prefix for logging
        log_prefix = f"{data.comm.group} : {obs.name} :"

        if data.comm.group_rank == 0:
            msg = (
                f"{log_prefix} OpGroundFilter: "
                f"Processing observation {iobs + 1} / {nobs}"
            )
            log.debug(msg)

        # Cache the output common flags
        if self.shared_flags is not None:
            common_flags = (
                obs.shared[self.shared_flags].data & self.shared_flag_mask
            )
        else:
            common_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

        t1 = time()
        templates, legendre_trend, legendre_filter = self.build_templates(obs)
        if data.comm.group_rank == 0:
            msg = (
                f"{log_prefix} OpGroundFilter: "
                f"Built templates in {time() - t1:.1f}s"
            )
            log.debug(msg)

        last_good = None
        last_invcov = None
        last_cov = None
        last_rcond = None

        for det in obs.select_local_detectors(detectors, flagmask=self.det_mask):
            if data.comm.group_rank == 0:
                msg = f"{log_prefix} OpGroundFilter: " f"Processing detector {det}"
                log.verbose(msg)

            ref = obs.detdata[self.det_data][det]
            if self.det_flags is not None:
                test_flags = obs.detdata[self.det_flags][det] & self.det_flag_mask
                good = np.logical_and(common_flags == 0, test_flags == 0)
            else:
                good = common_flags == 0

            t1 = time()
            coeff, last_invcov, last_cov, last_rcond = self.fit_templates(
                obs,
                templates,
                ref,
                good,
                last_good,
                last_invcov,
                last_cov,
                last_rcond,
            )
            last_good = good
            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} OpGroundFilter: "
                    f"Fit templates in {time() - t1:.1f}s"
                )
                log.verbose(msg)

            if coeff is None:
                # All samples flagged or template fit failed.
                curflag = obs.local_detector_flags[det]
                obs.update_local_detector_flags(
                    {det: curflag | self.ground_flag_mask}
                )
                continue

            t1 = time()
            self.subtract_templates(
                ref, good, coeff, legendre_trend, legendre_filter
            )
            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} OpGroundFilter: "
                    f"Subtract templates in {time() - t1:.1f}s"
                )
                log.verbose(msg)
        del last_good
        del last_invcov
        del last_cov
        del last_rcond

    if wcomm is not None:
        self.nsingular = wcomm.allreduce(self.nsingular)
        self.ngood = wcomm.allreduce(self.ngood)
        self.rcondsum = wcomm.allreduce(self.rcondsum)

    if wcomm is None or wcomm.rank == 0:
        rcond_mean = self.rcondsum / (self.nsingular + self.ngood)
        msg = (
            f"Applied ground filter in {time() - t0:.1f} s.  "
            f"Average rcond of template matrix was {rcond_mean}"
        )
        log.debug(msg)

    return

_finalize(data, **kwargs)

Source code in toast/ops/groundfilter.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/groundfilter.py
def _provides(self):
    return dict()

_requires()

Source code in toast/ops/groundfilter.py
def _requires(self):
    req = {
        "shared": list(),
        "detdata": [self.det_data],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.azimuth is not None:
        req["shared"].append(self.azimuth)
    if self.boresight_azel is not None:
        req["shared"].append(self.boresight_azel)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

build_templates(obs)

Construct the local ground template hierarchy

Source code in toast/ops/groundfilter.py
@function_timer
def build_templates(self, obs):
    """Construct the local ground template hierarchy"""

    views = obs.view[self.view]

    # Construct trend templates.  Full domain for x is [-1, 1]

    my_offset = obs.local_index_offset
    my_nsamp = obs.n_local_samples
    nsamp_tot = obs.n_all_samples
    x = np.arange(my_offset, my_offset + my_nsamp) / nsamp_tot * 2 - 1

    # Do not include the offset in the trend.  It will be part
    # of the ground template.
    legendre_trend = np.zeros([self.trend_order, x.size])
    legendre(x, legendre_trend, 1, self.trend_order + 1)

    try:
        if self.azimuth is not None:
            az = obs.shared[self.azimuth]
        else:
            quats = obs.shared[self.boresight_azel]
            theta, phi, _ = qa.to_iso_angles(quats)
            az = 2 * np.pi - phi
        if "scan_min_az" in obs:
            azmin = obs["scan_min_az"].to_value(u.radian)
            azmax = obs["scan_max_az"].to_value(u.radian)
        else:
            azmin = np.amin(az)
            azmax = np.amax(az)
            comm = obs.comm.comm_group
            if comm is not None:
                azmin = comm.allreduce(azmin, op=MPI.MIN)
                azmax = comm.allreduce(azmax, op=MPI.MAX)
            obs["scan_min_az"] = azmin * u.radian
            obs["scan_max_az"] = azmax * u.radian
    except Exception as e:
        msg = (
            f"Failed to get boresight azimuth from TOD.  "
            f"Perhaps it is not ground TOD? '{e}'"
        )
        raise RuntimeError(msg)

    # The azimuth vector is assumed to be arranged so that the
    # azimuth increases monotonically even across the zero meridian.

    phase = (np.unwrap(az) - azmin) / (azmax - azmin) * 2 - 1
    nfilter = self.filter_order + 1
    legendre_templates = np.zeros([nfilter, phase.size])
    legendre(phase, legendre_templates, 0, nfilter)
    if not self.split_template:
        legendre_filter = legendre_templates
    else:
        # Create separate templates for alternating scans
        common_flags = obs.shared[self.shared_flags].data
        legendre_filter = []
        masks = []
        for name in self.leftright_interval, self.rightleft_interval:
            mask = np.zeros(phase.size, dtype=bool)
            for ival in obs.intervals[name]:
                mask[ival.first : ival.last] = True
            masks.append(mask)
        for template in legendre_templates:
            for mask in masks:
                temp = template.copy()
                temp[mask] = 0
                legendre_filter.append(temp)
        legendre_filter = np.vstack(legendre_filter)

    templates = np.vstack([legendre_trend, legendre_filter])

    return templates, legendre_trend, legendre_filter

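To make the template construction concrete, here is a minimal standalone sketch (an assumption-laden illustration, not TOAST code) that reproduces the azimuth-to-phase mapping and the Legendre evaluation using numpy's polynomial module in place of the internal legendre() helper:

import numpy as np
from numpy.polynomial import legendre as npleg

# Fake boresight azimuth sweep (radians)
az = np.linspace(0.2, 0.9, 1000)
azmin, azmax = az.min(), az.max()

# Map the unwrapped azimuth onto the Legendre domain [-1, 1]
phase = (np.unwrap(az) - azmin) / (azmax - azmin) * 2 - 1

filter_order = 5
nfilter = filter_order + 1

# Row i holds the degree-i Legendre polynomial evaluated at each sample
legendre_templates = np.vstack(
    [npleg.legval(phase, np.eye(nfilter)[i]) for i in range(nfilter)]
)
print(legendre_templates.shape)  # (6, 1000)
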
fit_templates(obs, templates, ref, good, last_good, last_invcov, last_cov, last_rcond)

Source code in toast/ops/groundfilter.py
@function_timer
def fit_templates(
    self,
    obs,
    templates,
    ref,
    good,
    last_good,
    last_invcov,
    last_cov,
    last_rcond,
):
    log = Logger.get()
    # communicator for processes with the same detectors
    comm = obs.comm_row
    ngood = np.sum(good)
    ntask = 1
    if comm is not None:
        ngood = comm.allreduce(ngood)
        ntask = comm.size
    if ngood == 0:
        return None, None, None, None

    ntemplate = len(templates)
    invcov = np.zeros([ntemplate, ntemplate])
    proj = np.zeros(ntemplate)

    bin_proj_fast(ref, templates, good.astype(np.uint8), proj)
    if last_good is not None and np.all(good == last_good) and ntask == 1:
        # Flags have not changed, we can re-use the last inverse covariance
        invcov = last_invcov
        cov = last_cov
        rcond = last_rcond
    else:
        bin_invcov_fast(templates, good.astype(np.uint8), invcov)
        if comm is not None:
            # Reduce the binned data.  The detector signal is
            # distributed across the group communicator.
            comm.Allreduce(MPI.IN_PLACE, invcov, op=MPI.SUM)
            comm.Allreduce(MPI.IN_PLACE, proj, op=MPI.SUM)
        rcond = get_rcond(invcov)
        cov = None

    self.rcondsum += rcond
    if rcond > 1e-6:
        self.ngood += 1
        if cov is None:
            cov = get_inverse(invcov)
    else:
        self.nsingular += 1
        log.debug(
            f"Ground template matrix is poorly conditioned, "
            f"rcond = {rcond}, using pseudoinverse."
        )
        if cov is None:
            cov = get_pseudoinverse(invcov)
    coeff = np.dot(cov, proj)

    return coeff, invcov, cov, rcond

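The fit above is a flag-weighted least-squares solve: project the signal onto the templates, accumulate the template-template matrix over unflagged samples, and invert it, falling back to a pseudoinverse when the matrix is poorly conditioned. A minimal numpy sketch of the same algebra, using np.linalg.pinv as a stand-in for the internal get_inverse / get_pseudoinverse helpers (all names below are illustrative):

import numpy as np

rng = np.random.default_rng(0)
nsamp, ntemplate = 1000, 4
templates = rng.standard_normal((ntemplate, nsamp))
truth = np.array([1.0, -2.0, 0.5, 3.0])
signal = truth @ templates + 0.01 * rng.standard_normal(nsamp)

good = np.ones(nsamp, dtype=bool)
good[100:200] = False  # flagged samples are excluded from the fit

tgood = templates[:, good]
proj = tgood @ signal[good]  # projection of the data onto the templates
invcov = tgood @ tgood.T     # template-template (inverse covariance) matrix
coeff = np.linalg.pinv(invcov) @ proj
print(np.round(coeff, 2))    # close to `truth`
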
subtract_templates(ref, good, coeff, legendre_trend, legendre_filter)

Source code in toast/ops/groundfilter.py
@function_timer
def subtract_templates(self, ref, good, coeff, legendre_trend, legendre_filter):
    # Trend
    if self.detrend:
        trend = np.zeros_like(ref)
        add_templates(trend, legendre_trend, coeff[: self.trend_order])
        ref -= trend
    # Ground template
    grtemplate = np.zeros(ref.size, dtype=np.float64)
    add_templates(grtemplate, legendre_filter, coeff[self.trend_order :])
    ref -= grtemplate
    return

toast.ops.MitigateCrossTalk

Bases: Operator

  1. The cross talk matrix can just be a dictionary of dictionaries of values (i.e. a sparse matrix) on every process. It does not need to be a dense matrix loaded from an HDF5 file. The calling code can create this however it likes.

  2. Each process has a DetectorData object representing the local data for some detectors and some timespan (e.g. obs.detdata["signal"]). It can make a copy of this and pass it to the next rank in the grid column. Each process receives a copy from the previous process in the column, accumulates to its local detectors, and passes it along. This continues until every process has accumulated the data from the other processes in the column.

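A minimal sketch of the dictionary-of-dictionaries format described above; the detector names and coupling values are purely illustrative, not TOAST API:

import numpy as np

# Sparse crosstalk matrix: the outer key is the affected detector, the inner
# keys are the detectors coupling into it, and the values are the strengths.
xtalk_mat = {
    "det_00": {"det_00": 1.0, "det_01": 1.0e-3},
    "det_01": {"det_01": 1.0, "det_00": 5.0e-4, "det_02": 2.0e-4},
    "det_02": {"det_02": 1.0, "det_01": 1.0e-4},
}

# Apply the couplings to per-detector timestreams held in a plain dict
signal = {det: np.random.default_rng(i).standard_normal(100)
          for i, det in enumerate(xtalk_mat)}
coupled = {det: sum(w * signal[src] for src, w in row.items())
           for det, row in xtalk_mat.items()}
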
Source code in toast/ops/sim_crosstalk.py
@trait_docs
class MitigateCrossTalk(Operator):
    """
    1.  The cross talk matrix can just be a dictionary of
    dictionaries of values (i.e. a sparse matrix) on every process.
    It does not need to be a dense matrix loaded from an HDF5 file.
    The calling code can create this however it likes.

    2. Each process has a DetectorData object representing the local data for some
    detectors and some timespan (e.g. obs.detdata["signal"]).
    It can make a copy of this and pass it to the next rank in the grid column.
    Each process receives a copy from the previous process in the column,
    accumulates to its local detectors, and passes it along.
    This continues until every process has accumulated the data
    from the other processes in the column.
    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_data = Unicode(
        None, allow_none=True, help="Observation detdata key for the timestream data"
    )

    xtalk_mat_file = Unicode(
        None, allow_none=True, help="CrossTalk matrix dictionary of dictionaries"
    )

    realization = Int(0, help="Integer to set a different random seed")
    error_coefficients = Float(
        0, help="Relative amplitude used to simulate crosstalk errors on the inverse matrix"
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, **kwargs):
        env = Environment.get()
        log = Logger.get()

        ## Read the XTalk matrix from file or initialize it randomly
        if self.xtalk_mat_file is None:
            self.xtalk_mat = init_xtalk_matrix(data, realization=self.realization)
        else:
            self.xtalk_mat = read_xtalk_matrix(self.xtalk_mat_file, data)

        ## Inject an error into the matrix coefficients
        if self.error_coefficients:
            self.xtalk_mat = inject_error_in_xtalk_matrix(
                self.xtalk_mat, self.error_coefficients, realization=self.realization
            )
        # Invert the Xtalk matrix (encoding the error)
        self.inv_xtalk_mat = invert_xtalk_mat(self.xtalk_mat)

        for ob in data.obs:
            # Get the detectors we are using for this observation
            comm = ob.comm.comm_group
            rank = ob.comm.group_rank
            # Redistribute data as in CrossTalk operator
            if ob.comm.group_size > 1:
                old_data_shape = ob.detdata[self.det_data].data.shape
                ob.redistribute(1, times=ob.shared["times"])
                new_data_shape = ob.detdata[self.det_data].data.shape
                assert new_data_shape[0] == len(ob.all_detectors)

            # We store the crosstalked data in a temporary array
            tmp = np.zeros_like(ob.detdata[self.det_data].data)

            for idet, det in enumerate(ob.all_detectors):
                # For a given detector, only a subset of
                # detectors can be crosstalking

                xtalklist = list(self.xtalk_mat[det].keys())
                intersect_local = np.intersect1d(ob.all_detectors, xtalklist)
                ind1 = [xtalklist.index(k) for k in intersect_local]
                ind2 = [
                    ob.detdata[self.det_data].detectors.index(k)
                    for k in intersect_local
                ]

                xtalk_weights = np.array(
                    [self.inv_xtalk_mat[det][kk] for kk in np.array(xtalklist)[ind1]]
                )
                tmp[idet] += np.dot(
                    xtalk_weights, ob.detdata[self.det_data].data[ind2, :]
                )

            for idet, det in enumerate(ob.all_detectors):
                ob.detdata[self.det_data][det] = tmp[idet]
            # We distribute the data back to the previous distribution
            if ob.comm.group_size > 1:
                ob.redistribute(ob.comm.group_size, times=ob.shared["times"])

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": [
                self.boresight,
            ],
            "detdata": list(),
            "intervals": list(),
        }
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": [
                self.det_data,
            ],
        }
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(None, allow_none=True, help='Observation detdata key for the timestream data') class-attribute instance-attribute

error_coefficients = Float(0, help='Relative amplitude used to simulate crosstalk errors on the inverse matrix') class-attribute instance-attribute

realization = Int(0, help='Integer to set a different random seed') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

xtalk_mat_file = Unicode(None, allow_none=True, help='CrossTalk matrix dictionary of dictionaries') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/sim_crosstalk.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_exec(data, **kwargs)

Source code in toast/ops/sim_crosstalk.py
@function_timer
def _exec(self, data, **kwargs):
    env = Environment.get()
    log = Logger.get()

    ## Read the XTalk matrix from file or initialize it randomly
    if self.xtalk_mat_file is None:
        self.xtalk_mat = init_xtalk_matrix(data, realization=self.realization)
    else:
        self.xtalk_mat = read_xtalk_matrix(self.xtalk_mat_file, data)

    ## Inject an error into the matrix coefficients
    if self.error_coefficients:
        self.xtalk_mat = inject_error_in_xtalk_matrix(
            self.xtalk_mat, self.error_coefficients, realization=self.realization
        )
    # Invert the Xtalk matrix (encoding the error)
    self.inv_xtalk_mat = invert_xtalk_mat(self.xtalk_mat)

    for ob in data.obs:
        # Get the detectors we are using for this observation
        comm = ob.comm.comm_group
        rank = ob.comm.group_rank
        # Redistribute data as in CrossTalk operator
        if ob.comm.group_size > 1:
            old_data_shape = ob.detdata[self.det_data].data.shape
            ob.redistribute(1, times=ob.shared["times"])
            new_data_shape = ob.detdata[self.det_data].data.shape
            assert new_data_shape[0] == len(ob.all_detectors)

        # We store the crosstalked data in a temporary array
        tmp = np.zeros_like(ob.detdata[self.det_data].data)

        for idet, det in enumerate(ob.all_detectors):
            # For a given detector, only a subset of
            # detectors can be crosstalking

            xtalklist = list(self.xtalk_mat[det].keys())
            intersect_local = np.intersect1d(ob.all_detectors, xtalklist)
            ind1 = [xtalklist.index(k) for k in intersect_local]
            ind2 = [
                ob.detdata[self.det_data].detectors.index(k)
                for k in intersect_local
            ]

            xtalk_weights = np.array(
                [self.inv_xtalk_mat[det][kk] for kk in np.array(xtalklist)[ind1]]
            )
            tmp[idet] += np.dot(
                xtalk_weights, ob.detdata[self.det_data].data[ind2, :]
            )

        for idet, det in enumerate(ob.all_detectors):
            ob.detdata[self.det_data][det] = tmp[idet]
        # We distribute the data back to the previous distribution
        if ob.comm.group_size > 1:
            ob.redistribute(ob.comm.group_size, times=ob.shared["times"])

    return

_finalize(data, **kwargs)

Source code in toast/ops/sim_crosstalk.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/sim_crosstalk.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": [
            self.det_data,
        ],
    }
    return prov

_requires()

Source code in toast/ops/sim_crosstalk.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": [
            self.boresight,
        ],
        "detdata": list(),
        "intervals": list(),
    }
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

Half Wave Plate Tools

toast.ops.HWPSynchronousModel

Bases: Operator

Operator that models and removes HWP synchronous signal.

This fits and optionally subtracts a Maxipol / EBEX style model for the HWPSS. The time-dependent drift term is optional. See the details in toast.hwp_utils.hwpss_compute_coeff_covariance().

The 2f component of the model is optionally used to build a relative calibration between detectors, either as a fixed table per observation or as continuously varying factors.

The HWPSS model can be constructed either with one set of template coefficients for the entire observation, or one set per time interval smoothly interpolated across the observation.

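A hedged usage sketch, assuming an existing toast.Data container named data: fit the model in overlapping 60 second chunks, subtract it from the timestreams, and store a fixed relative calibration table per observation. All traits used here are documented below.

import toast.ops
from astropy import units as u

hwpss = toast.ops.HWPSynchronousModel(
    det_data="signal",
    harmonics=9,
    chunk_time=60.0 * u.second,
    subtract_model=True,
    relcal_fixed="hwpss_relcal",
    time_drift=False,
)
hwpss.apply(data)  # `data` is an existing toast.Data instance (assumed)
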
Source code in toast/ops/hwpss_model.py
@trait_docs
class HWPSynchronousModel(Operator):
    """Operator that models and removes HWP synchronous signal.

    This fits and optionally subtracts a Maxipol / EBEX style model for the HWPSS.
    The time-dependent drift term is optional.  See the details in
    `toast.hwp_utils.hwpss_compute_coeff_covariance()`.

    The 2f component of the model is optionally used to build a relative calibration
    between detectors, either as a fixed table per observation or as continuously
    varying factors.

    The HWPSS model can be constructed either with one set of template coefficients
    for the entire observation, or one set per time interval smoothly interpolated
    across the observation.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key",
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for optional shared flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    hwp_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask to use when adding flags based on HWP filter failures.",
    )

    hwp_angle = Unicode(
        defaults.hwp_angle, allow_none=True, help="Observation shared key for HWP angle"
    )

    harmonics = Int(9, help="Number of harmonics to consider in the expansion")

    subtract_model = Bool(False, help="Subtract the model from the input data")

    save_model = Unicode(
        None, allow_none=True, help="Save the model to this observation key"
    )

    chunk_view = Unicode(
        None,
        allow_none=True,
        help="The intervals over which to independently compute the HWPSS template",
    )

    chunk_time = Quantity(
        None,
        allow_none=True,
        help="The overlapping time chunks over which to compute the HWPSS template",
    )

    relcal_fixed = Unicode(
        None,
        allow_none=True,
        help="Build a relative calibration dictionary in this observation key",
    )

    relcal_continuous = Unicode(
        None,
        allow_none=True,
        help="Build interpolated relative calibration timestreams",
    )

    relcal_cut_sigma = Float(
        5.0, help="Sigma cut for outlier rejection based on relative calibration"
    )

    time_drift = Bool(False, help="If True, include time drift terms in the model")

    fill_gaps = Bool(False, help="If True, fill gaps with a simple noise model")

    debug = Unicode(
        None,
        allow_none=True,
        help="Path to directory for generating debug plots",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("harmonics")
    def _check_harmonics(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Harmonics should be a non-negative integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        env = Environment.get()
        log = Logger.get()

        if self.relcal_continuous is not None and self.relcal_fixed is not None:
            msg = "Only one of continuous and fixed relative calibration can be used"
            raise RuntimeError(msg)

        if self.chunk_view is not None and self.chunk_time is not None:
            msg = "Only one of chunk_view and chunk_time can be used"
            raise RuntimeError(msg)

        do_cal = self.relcal_continuous or self.relcal_fixed
        if not self.subtract_model and (self.save_model is None) and not do_cal:
            msg = "Nothing to do.  You should enable at least one of the options"
            msg += " to subtract or save the model or to generate calibrations."
            raise RuntimeError(msg)

        if self.debug is not None:
            if data.comm.world_rank == 0:
                os.makedirs(self.debug)
            if data.comm.comm_world is not None:
                data.comm.comm_world.barrier()

        for ob in data.obs:
            timer = Timer()
            timer.start()

            if not ob.is_distributed_by_detector:
                msg = f"{ob.name} is not distributed by detector"
                raise RuntimeError(msg)

            if self.hwp_angle not in ob.shared:
                # Nothing to do, but if a relative calibration
                # was requested, make a fake one.
                if self.relcal_fixed is not None:
                    ob[self.relcal_fixed] = {x: 1.0 for x in ob.local_detectors}
                if self.relcal_continuous is not None:
                    ob.detdata.ensure(
                        self.relcal_continuous,
                        dtype=np.float32,
                        create_units=ob.detdata[self.det_data].units,
                    )
                    ob.detdata[self.relcal_continuous][:, :] = 1.0
                msg = f"{ob.name} has no '{self.hwp_angle}' field, skipping"
                log.warning_rank(msg, comm=data.comm.comm_group)
                continue

            # Compute quantities we need for all detectors and which we
            # might re-use for overlapping chunks.

            # Local detectors we are considering
            local_dets = ob.select_local_detectors(flagmask=self.det_mask)
            n_dets = len(local_dets)

            # Get the timestamps relative to the observation start
            reltime = np.array(ob.shared[self.times].data, copy=True)
            time_offset = reltime[0]
            reltime -= time_offset

            # Compute the properties of the chunks we are using
            chunks = self._compute_chunking(ob, reltime)
            n_chunk = len(chunks)

            # Compute shared and per-detector flags.  These already have
            # masks applied and have values of either zero or one.
            sh_flags, det_flags = self._compute_flags(ob, local_dets)

            msg = f"HWPSS Model {ob.name}: compute flags and chunking in"
            log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

            # Trig quantities of the HWP angle
            sincos = hwpss_sincos_buffer(
                ob.shared[self.hwp_angle].data,
                sh_flags,
                self.harmonics,
                comm=ob.comm.comm_group,
            )
            msg = f"HWPSS Model {ob.name}: built sincos buffer in"
            log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

            # The coefficients for all detectors and chunks
            if self.time_drift:
                n_coeff = 4 * self.harmonics
            else:
                n_coeff = 2 * self.harmonics
            coeff = np.zeros((n_dets, n_coeff, n_chunk), dtype=np.float64)
            coeff_flags = np.zeros(n_chunk, dtype=np.uint8)

            for ichunk, chunk in enumerate(chunks):
                self._fit_chunk(
                    ob,
                    local_dets,
                    ichunk,
                    chunk["start"],
                    chunk["end"],
                    sincos,
                    sh_flags,
                    det_flags,
                    reltime,
                    coeff,
                    coeff_flags,
                )

            msg = f"HWPSS Model {ob.name}: fit model to all chunks in"
            log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

            if self.save_model is not None:
                self._store_model(ob, local_dets, chunks, coeff, coeff_flags)

            # Even if we are not saving a fixed relative calibration table, compute
            # the mean 2f magnitude in order to cut outlier detectors.  The
            # calibration factors are relative to the mean of the distribution
            # of good detector values.
            mag_table = self._average_magnitude(local_dets, coeff, coeff_flags)
            good_dets, cal_center = self._cut_outliers(ob, mag_table)
            relcal_table = dict()
            for det in good_dets:
                relcal_table[det] = cal_center / mag_table[det]
            if self.relcal_fixed is not None:
                ob[self.relcal_fixed] = relcal_table

            # If we are generating relative calibration timestreams create that now.
            if self.relcal_continuous is not None:
                ob.detdata.ensure(
                    self.relcal_continuous,
                    dtype=np.float32,
                    create_units=ob.detdata[self.det_data].units,
                )
                ob.detdata[self.relcal_continuous][:, :] = 1.0

            # For each detector, compute the model and subtract from the data.  Also
            # compute the interpolated calibration timestream if requested.  We
            # assume that the model coefficients are slowly varying and just do a
            # linear interpolation.
            if not self.subtract_model and not self.relcal_continuous:
                # No need to compute the full time-domain templates
                continue

            good_check = set(good_dets)
            for idet, det in enumerate(local_dets):
                if det not in good_check:
                    continue
                model, det_mag = self._build_model(
                    ob,
                    reltime,
                    sincos,
                    sh_flags,
                    det_flags,
                    det,
                    mag_table[det],
                    chunks,
                    coeff[idet],
                    coeff_flags,
                )
                # Update flags
                ob.detdata[self.det_flags][det] |= det_flags[det] * self.hwp_flag_mask
                if model is None:
                    # The model construction failed due to flagged samples.  Nothing to
                    # subtract, since the detector has been flagged.
                    continue
                # Subtract model from good samples
                if self.subtract_model:
                    good = det_flags[det] == 0
                    ob.detdata[self.det_data][det][good] -= model[good]
                    dc = np.mean(ob.detdata[self.det_data][det][good])
                    ob.detdata[self.det_data][det][good] -= dc
                if self.fill_gaps:
                    rate = ob.telescope.focalplane.sample_rate.to_value(u.Hz)
                    # 1 second buffer
                    buffer = int(rate)
                    flagged_noise_fill(
                        ob.detdata[self.det_data][det],
                        det_flags[det],
                        buffer,
                        poly_order=1,
                    )
                if self.relcal_continuous is not None:
                    ob.detdata[self.relcal_continuous][det, :] = cal_center / det_mag

    def _plot_model(
        self,
        obs,
        det_name,
        reltime,
        sincos,
        sh_flags,
        det_flags,
        model,
        chunks,
        chunk_coeff,
        first,
        last,
    ):
        if self.debug is None:
            return
        import matplotlib.pyplot as plt

        slc = slice(first, last, 1)
        # If we are plotting per-chunk quantities, find the overlap of every chunk
        # with our plot range
        chunk_slc = None
        if len(chunks) > 1:
            chunk_slc = list()
            for ich, chk in enumerate(chunks):
                ch_start = chk["start"]
                ch_end = chk["end"]
                ch_size = ch_end - ch_start
                prp = dict()
                prp["abs_slc"] = slice(ch_start, ch_end, 1)
                if ch_start < last and ch_end > first:
                    # some overlap
                    if ch_start < first:
                        ch_first = first - ch_start
                        plt_first = first
                    else:
                        ch_first = 0
                        plt_first = ch_start
                    if ch_end > last:
                        ch_last = ch_size - (ch_end - last)
                        plt_last = last
                    else:
                        ch_last = ch_size
                        plt_last = ch_end

                    prp["ch_overlap"] = slice(int(ch_first), int(ch_last), 1)
                    prp["plt_overlap"] = slice(int(plt_first), int(plt_last), 1)
                else:
                    prp["ch_overlap"] = None
                    prp["plt_overlap"] = None
                chunk_slc.append(prp)
        cmap = plt.get_cmap("tab10")
        plt_file = os.path.join(
            self.debug,
            f"{obs.name}_model_{det_name}_{first}-{last}.png",
        )
        fig = plt.figure(figsize=(12, 12), dpi=100)
        ax = fig.add_subplot(2, 1, 1, aspect="auto")
        # Plot original signal
        ax.plot(
            reltime[slc],
            obs.detdata[self.det_data][det_name, slc],
            color="black",
            label=f"Signal {det_name}",
        )
        # Plot per chunk models
        if len(chunks) > 1:
            for ich, chk in enumerate(chunks):
                if chunk_slc[ich]["ch_overlap"] is None:
                    # No overlap
                    continue
                ch_coeff = chunk_coeff[:, ich]
                if np.count_nonzero(ch_coeff) == 0:
                    continue
                ch_model = hwpss_build_model(
                    sincos[chunk_slc[ich]["abs_slc"]],
                    sh_flags[chunk_slc[ich]["abs_slc"]],
                    ch_coeff,
                    times=reltime[chunk_slc[ich]["abs_slc"]],
                    time_drift=self.time_drift,
                )
                ax.plot(
                    reltime[chunk_slc[ich]["plt_overlap"]],
                    ch_model[chunk_slc[ich]["ch_overlap"]],
                    color=cmap(ich),
                    label=f"Model {det_name}",
                )
        # Plot full model
        ax.plot(
            reltime[slc],
            model[slc],
            color="red",
            label=f"Model {det_name}",
        )
        ax.legend(loc="best")

        cmap = plt.get_cmap("tab10")
        ax = fig.add_subplot(2, 1, 2, aspect="auto")
        # Plot flags
        ax.plot(
            reltime[slc],
            det_flags[det_name][slc],
            color="black",
            label=f"Flags {det_name}",
        )
        # Plot chunk boundaries
        if len(chunks) > 1:
            incr = 1 / (len(chunks) + 1)
            for ich, chk in enumerate(chunks):
                if chunk_slc[ich]["ch_overlap"] is None:
                    # No overlap
                    continue
                ax.plot(
                    reltime[chunk_slc[ich]["plt_overlap"]],
                    incr * ich * np.ones_like(reltime[chunk_slc[ich]["plt_overlap"]]),
                    color=cmap(ich),
                    linewidth=3,
                    label=f"Chunk {ich}",
                )
        ax.legend(loc="best")
        fig.suptitle(f"Obs {obs.name} Samples {first} - {last}")
        fig.savefig(plt_file)
        plt.close(fig)

    def _build_model(
        self,
        obs,
        reltime,
        sincos,
        sh_flags,
        det_flags,
        det_name,
        det_mag,
        chunks,
        ch_coeff,
        coeff_flags,
        min_smooth=4,
    ):
        log = Logger.get()
        nsamp = len(reltime)
        if len(chunks) == 1:
            if coeff_flags[0] != 0:
                msg = f"{obs.name}[{det_name}]: only one chunk, which is flagged"
                log.warning(msg)
                # Flag this detector
                current = obs.local_detector_flags[det_name]
                obs.update_local_detector_flags(
                    {det_name: current | self.hwp_flag_mask}
                )
                return (None, None)
            det_coeff = ch_coeff[:, 0]
            model = hwpss_build_model(
                sincos,
                sh_flags,
                det_coeff,
                times=reltime,
                time_drift=self.time_drift,
            )
            self._plot_model(
                obs,
                det_name,
                reltime,
                sincos,
                sh_flags,
                det_flags,
                model,
                chunks,
                ch_coeff,
                0,
                nsamp,
            )
            self._plot_model(
                obs,
                det_name,
                reltime,
                sincos,
                sh_flags,
                det_flags,
                model,
                chunks,
                ch_coeff,
                nsamp // 2 - 500,
                nsamp // 2 + 500,
            )
        else:
            n_coeff = ch_coeff.shape[0]
            n_chunk = ch_coeff.shape[1]
            good_chunk = [
                np.count_nonzero(ch_coeff[:, x]) > 0 and coeff_flags[x] == 0
                for x in range(n_chunk)
            ]
            if np.count_nonzero(good_chunk) == 0:
                msg = f"{obs.name}[{det_name}]: All {len(good_chunk)} chunks"
                msg += f" are flagged."
                log.warning(msg)
                # Flag this detector
                current = obs.local_detector_flags[det_name]
                obs.update_local_detector_flags(
                    {det_name: current | self.hwp_flag_mask}
                )
                return (None, None)
            ch_times = np.array(
                [x["time"] for y, x in enumerate(chunks) if good_chunk[y]]
            )
            smoothing = max(n_chunk // 16, min_smooth)
            if smoothing >= n_chunk:
                msg = f"Only {n_chunk} chunks for interpolation. "
                msg += f"Reduce the split time or use different intervals"
                raise RuntimeError(msg)
            det_coeff = np.zeros((len(reltime), n_coeff), dtype=np.float64)
            for icoeff in range(n_coeff):
                coeff_spl = scipy.interpolate.splrep(
                    ch_times, ch_coeff[icoeff, good_chunk], s=smoothing
                )
                det_coeff[:, icoeff] = scipy.interpolate.splev(
                    reltime, coeff_spl, ext=0
                )
            model = hwpss_build_model(
                sincos,
                sh_flags,
                det_coeff,
                times=reltime,
                time_drift=self.time_drift,
            )
            if self.relcal_continuous is not None:
                if self.time_drift:
                    det_mag = np.sqrt(det_coeff[:, 4] ** 2 + det_coeff[:, 6] ** 2)
                else:
                    det_mag = np.sqrt(det_coeff[:, 2] ** 2 + det_coeff[:, 3] ** 2)
                det_mag[det_flags[det_name] != 0] = 1.0
                if self.debug is not None:
                    import matplotlib.pyplot as plt

                    def plot_2f(first, last):
                        slc = slice(first, last, 1)
                        plt_file = os.path.join(
                            self.debug,
                            f"{obs.name}_model_{det_name}_2f_{first}-{last}.png",
                        )
                        fig = plt.figure(figsize=(12, 12), dpi=100)
                        ax = fig.add_subplot(2, 1, 1, aspect="auto")
                        ax.plot(
                            reltime[slc],
                            det_mag[slc],
                            color="red",
                            label=f"Interpolated 2f Magnitude {det_name}",
                        )
                        if self.time_drift:
                            ch_mag = np.sqrt(
                                ch_coeff[4, good_chunk] ** 2
                                + ch_coeff[6, good_chunk] ** 2
                            )
                        else:
                            ch_mag = np.sqrt(
                                ch_coeff[2, good_chunk] ** 2
                                + ch_coeff[3, good_chunk] ** 2
                            )
                        ax.scatter(
                            ch_times,
                            ch_mag,
                            marker="*",
                            color="blue",
                            label="Estimated Chunk 2f Magnitude",
                        )
                        ax.legend(loc="best")
                        ax.set_xlim(left=reltime[first], right=reltime[last - 1])
                        ax = fig.add_subplot(2, 1, 2, aspect="auto")
                        ax.plot(
                            reltime[slc],
                            det_flags[det_name][slc],
                            color="black",
                            label=f"Flags {det_name}",
                        )
                        fig.suptitle(f"Obs {obs.name} Samples {first} - {last}")
                        fig.savefig(plt_file)
                        plt.close(fig)

                    plot_2f(0, nsamp)
                    plot_2f(nsamp // 2 - 500, nsamp // 2 + 500)
            self._plot_model(
                obs,
                det_name,
                reltime,
                sincos,
                sh_flags,
                det_flags,
                model,
                chunks,
                ch_coeff,
                0,
                nsamp,
            )
            self._plot_model(
                obs,
                det_name,
                reltime,
                sincos,
                sh_flags,
                det_flags,
                model,
                chunks,
                ch_coeff,
                nsamp // 2 - 500,
                nsamp // 2 + 500,
            )
        return model, det_mag

    def _store_model(self, obs, dets, chunks, coeff, coeff_flags):
        log = Logger.get()
        if self.save_model in obs:
            msg = "observation {obs.name} already has something at "
            msg += "key {self.save_model}.  Overwriting."
            log.warning(msg)
        # Repackage the coefficients and chunk information
        ob_start = obs.shared[self.times].data[0]
        model = list()
        for ichk, chk in enumerate(chunks):
            props = {
                "start": chk["start"],
                "end": chk["end"],
                "time": ob_start + chk["time"],
                "flag": coeff_flags[ichk],
            }
            props["dets"] = dict()
            for idet, det in enumerate(dets):
                props["dets"][det] = np.array(coeff[idet, :, ichk])
            model.append(props)
        obs[self.save_model] = model

    def _cut_outliers(self, obs, det_mag):
        log = Logger.get()
        cut_timer = Timer()
        cut_timer.start()

        dets = list(det_mag.keys())
        mag = np.array([det_mag[x] for x in dets])

        # Communicate magnitudes
        all_dets = None
        all_mag = None
        if obs.comm_col is None:
            all_dets = dets
            all_mag = mag
        else:
            all_dets = flatten(obs.comm_col.gather(dets, root=0))
            all_mag = np.array(flatten(obs.comm_col.gather(mag, root=0)))

        # One process does the trivial calculation
        all_flags = None
        central_mag = None
        if obs.comm_col_rank == 0:
            all_good = [True for x in all_dets]
            n_cut = 1
            while n_cut > 0:
                n_cut = 0
                mn = np.mean(all_mag[all_good])
                std = np.std(all_mag[all_good])
                for idet, det in enumerate(all_dets):
                    if not all_good[idet]:
                        continue
                    if np.absolute(all_mag[idet] - mn) > self.relcal_cut_sigma * std:
                        all_good[idet] = False
                        n_cut += 1
            central_mag = np.mean(all_mag[all_good])
            all_flags = {
                x: self.hwp_flag_mask for i, x in enumerate(all_dets) if not all_good[i]
            }
        if obs.comm_col is not None:
            all_flags = obs.comm_col.bcast(all_flags, root=0)
            central_mag = obs.comm_col.bcast(central_mag, root=0)

        # Every process flags its local detectors
        det_check = set(dets)
        local_flags = dict(obs.local_detector_flags)
        for det, val in all_flags.items():
            if det in det_check:
                local_flags[det] |= val
        obs.update_local_detector_flags(local_flags)
        local_good = [x for x in dets if x not in all_flags]

        return local_good, central_mag

    def _average_magnitude(self, dets, coeff, coeff_flags):
        mag = dict()
        if self.time_drift:
            # 4 values per harmonic, 2f is index 1
            re_comp = 4 * 1 + 0
            im_comp = 4 * 1 + 2
        else:
            # 2 values per harmonic, 2f is index 1
            re_comp = 2 * 1 + 0
            im_comp = 2 * 1 + 1
        n_chunk = coeff.shape[2]
        for idet, det in enumerate(dets):
            ch_mag = list()
            for ch in range(n_chunk):
                if coeff_flags[ch] != 0:
                    # All detectors in this chunk were flagged
                    continue
                if coeff[idet, re_comp, ch] == 0 and coeff[idet, im_comp, ch] == 0:
                    # This detector data was flagged
                    continue
                ch_mag.append(
                    np.sqrt(
                        coeff[idet, re_comp, ch] ** 2 + coeff[idet, im_comp, ch] ** 2
                    )
                )
            mag[det] = np.mean(ch_mag)
        return mag

    def _fit_chunk(
        self,
        obs,
        dets,
        indx,
        start,
        end,
        sincos,
        sh_flags,
        det_flags,
        reltime,
        coeff,
        coeff_flags,
    ):
        log = Logger.get()
        ch_timer = Timer()
        ch_timer.start()

        # The sample slice
        slc = slice(start, end, 1)
        slc_samps = end - start

        if reltime is None:
            ch_reltime = None
        else:
            ch_reltime = reltime[slc]

        obs_cov = hwpss_compute_coeff_covariance(
            sincos[slc],
            sh_flags[slc],
            comm=obs.comm.comm_group,
            times=ch_reltime,
            time_drift=self.time_drift,
        )
        if obs_cov is None:
            msg = f"HWPSS Model {obs.name}[{indx}] ({slc_samps} samples)"
            msg += " failed to compute coefficient"
            msg += " covariance.  Flagging this chunk when building model."
            log.verbose_rank(msg, comm=obs.comm.comm_group)
            coeff_flags[indx] = 1
            return

        msg = f"HWPSS Model {obs.name}[{indx}]: built coefficient covariance in"
        log.verbose_rank(msg, comm=obs.comm.comm_group, timer=ch_timer)

        for idet, det in enumerate(dets):
            good_samp = det_flags[det][slc] == 0
            if np.count_nonzero(good_samp) < coeff.shape[1]:
                # Not very many good samples, set coefficients to zero
                msg = f"HWPSS Model {obs.name}[{indx}] {det}: insufficient good "
                msg += "samples, setting coefficients to zero"
                log.verbose(msg)
                coeff[idet, :, indx] = 0
                continue
            sig = np.array(obs.detdata[self.det_data][det, slc])
            dc = np.mean(sig[good_samp])
            sig -= dc

            cf = hwpss_compute_coeff(
                sincos[slc],
                sig,
                det_flags[det][slc],
                obs_cov[0],
                obs_cov[1],
                times=ch_reltime,
                time_drift=self.time_drift,
            )
            if idet == 0:
                cfstr = ""
                for ic in cf:
                    cfstr += f"{ic} "
            coeff[idet, :, indx] = cf

        msg = f"HWPSS Model {obs.name}[{indx}]: compute detector coefficients in"
        log.verbose_rank(msg, comm=obs.comm.comm_group, timer=ch_timer)

    def _compute_chunking(self, obs, reltime):
        chunks = list()
        if self.chunk_view is None:
            if self.chunk_time is None:
                # One chunk for the whole observation
                chunks.append(
                    {
                        "start": 0,
                        "end": obs.n_local_samples,
                        "time": reltime[obs.n_local_samples // 2],
                    }
                )
            else:
                # Overlapping chunks
                duration = reltime[-1] - reltime[0]
                non_overlap = int(duration / self.chunk_time.to_value(u.second))
                # Adjust the chunk size to evenly divide into the obs range
                adjusted_time = duration / non_overlap
                # Convert to samples.  Round up so that the final chunk has
                # a few fewer samples rather than having a short chunk at the
                # end.
                rate = obs.telescope.focalplane.sample_rate.to_value(u.Hz)
                chunk_samples = int(adjusted_time * rate + 0.5)
                half_chunk = chunk_samples // 2
                ch_start = 0
                for ch in range(non_overlap):
                    ch_mid = ch_start + half_chunk
                    chunks.append(
                        {
                            "start": ch_start,
                            "end": ch_start + chunk_samples,
                            "time": reltime[ch_mid],
                        }
                    )
                    if ch != non_overlap - 1:
                        # Add the overlapping chunk
                        chunks.append(
                            {
                                "start": ch_start + half_chunk,
                                "end": ch_start + half_chunk + chunk_samples,
                                "time": reltime[ch_start + chunk_samples],
                            }
                        )
                    ch_start += chunk_samples
        else:
            # Use the specified interval list for the chunks.  Cut any
            # chunks that are tiny.
            for intr in obs.intervals[self.chunk_view]:
                ch_size = intr.last - intr.first
                ch_mid = intr.first + ch_size // 2
                if ch_size > 10 * self.harmonics:
                    chunks.append(
                        {
                            "start": intr.first,
                            "end": intr.last,
                            "time": reltime[ch_mid],
                        }
                    )
        return chunks

    def _compute_flags(self, obs, dets):
        # The shared flags
        if self.shared_flags is None:
            shared_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)
        else:
            shared_flags = np.array(obs.shared[self.shared_flags].data)
            shared_flags &= self.shared_flag_mask

        # Compute flags for samples where the hwp is stopped
        stopped = self._stopped_flags(obs)
        shared_flags |= stopped

        # If we are chunking based on intervals, flag the regions between valid
        # intervals.
        if self.chunk_view is not None:
            not_modelled = np.ones_like(shared_flags)
            for intr in obs.intervals[self.chunk_view]:
                not_modelled[intr.first : intr.last] = 0
            shared_flags |= not_modelled

        # Per-detector flags.  We merge in the shared flags to these since the
        # detector flags will be written out at the end if the model is subtracted.
        det_flags = dict()
        for idet, det in enumerate(dets):
            if self.det_flags is None:
                det_flags[det] = shared_flags
            else:
                det_flags[det] = np.copy(obs.detdata[self.det_flags][det])
                det_flags[det] &= self.det_flag_mask
                det_flags[det] |= shared_flags
        return (shared_flags, det_flags)

    def _stopped_flags(self, obs):
        hdata = np.unwrap(obs.shared[self.hwp_angle].data, period=2 * np.pi)
        hvel = np.gradient(hdata)
        moving = np.absolute(hvel) > 1.0e-6
        nominal = np.median(hvel[moving])
        unstable = np.absolute(hvel - nominal) > 1.0e-3 * nominal
        stopped = np.array(unstable, dtype=np.uint8)
        return stopped

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        # Note that the hwp_angle is not strictly required; if it
        # is absent, this operator is a no-op.
        req = {
            "shared": [self.times],
            "detdata": [self.det_data],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        prov = {
            "meta": [],
            "detdata": [self.det_data],
        }
        if self.relcal_continuous is not None:
            prov["detdata"].append(self.relcal_continuous)
        if self.save_model is not None:
            prov["meta"].append(self.save_model)
        if self.relcal_fixed is not None:
            prov["meta"].append(self.relcal_fixed)
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

chunk_time = Quantity(None, allow_none=True, help='The overlapping time chunks over which to compute the HWPSS template') class-attribute instance-attribute

chunk_view = Unicode(None, allow_none=True, help='The intervals over which to independently compute the HWPSS template') class-attribute instance-attribute

debug = Unicode(None, allow_none=True, help='Path to directory for generating debug plots') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

fill_gaps = Bool(False, help='If True, fill gaps with a simple noise model') class-attribute instance-attribute

harmonics = Int(9, help='Number of harmonics to consider in the expansion') class-attribute instance-attribute

hwp_angle = Unicode(defaults.hwp_angle, allow_none=True, help='Observation shared key for HWP angle') class-attribute instance-attribute

hwp_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask to use when adding flags based on HWP filter failures.') class-attribute instance-attribute

relcal_continuous = Unicode(None, allow_none=True, help='Build interpolated relative calibration timestreams') class-attribute instance-attribute

relcal_cut_sigma = Float(5.0, help='Sigma cut for outlier rejection based on relative calibration') class-attribute instance-attribute

relcal_fixed = Unicode(None, allow_none=True, help='Build a relative calibration dictionary in this observation key') class-attribute instance-attribute

save_model = Unicode(None, allow_none=True, help='Save the model to this observation key') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

subtract_model = Bool(False, help='Subtract the model from the input data') class-attribute instance-attribute

time_drift = Bool(False, help='If True, include time drift terms in the model') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute
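
Taken together, the traits above control how the model is chunked, fit, subtracted, and turned into calibration products. A minimal usage sketch follows; the class name HWPSynchronousModel and the apply() entry point are assumptions based on the usual toast operator conventions and the source file name, so check your toast version for the exact symbol.

import astropy.units as u
import toast.ops

# Hypothetical usage: every keyword is one of the traits documented above,
# but the class name itself is an assumption.
hwpss = toast.ops.HWPSynchronousModel(
    det_data="signal",
    chunk_time=10.0 * u.minute,
    subtract_model=True,
    relcal_fixed="hwpss_relcal",
)
hwpss.apply(data)  # "data" is a toast.Data instance built earlier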

__init__(**kwargs)

Source code in toast/ops/hwpss_model.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_average_magnitude(dets, coeff, coeff_flags)

Source code in toast/ops/hwpss_model.py
def _average_magnitude(self, dets, coeff, coeff_flags):
    mag = dict()
    if self.time_drift:
        # 4 values per harmonic, 2f is index 1
        re_comp = 4 * 1 + 0
        im_comp = 4 * 1 + 2
    else:
        # 2 values per harmonic, 2f is index 1
        re_comp = 2 * 1 + 0
        im_comp = 2 * 1 + 1
    n_chunk = coeff.shape[2]
    for idet, det in enumerate(dets):
        ch_mag = list()
        for ch in range(n_chunk):
            if coeff_flags[ch] != 0:
                # All detectors in this chunk were flagged
                continue
            if coeff[idet, re_comp, ch] == 0 and coeff[idet, im_comp, ch] == 0:
                # This detector data was flagged
                continue
            ch_mag.append(
                np.sqrt(
                    coeff[idet, re_comp, ch] ** 2 + coeff[idet, im_comp, ch] ** 2
                )
            )
        mag[det] = np.mean(ch_mag)
    return mag
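
The index arithmetic above assumes a fixed coefficient packing: two values per harmonic (real then imaginary), or four per harmonic when time_drift is enabled (apparently real, real-drift, imaginary, imaginary-drift, given the offsets used). A small worked illustration:

import numpy as np

# Fake fit results for one detector and chunk, without time drift:
# layout is [re_1f, im_1f, re_2f, im_2f, re_3f, im_3f, ...]
n_harmonics = 9
coeff = np.arange(2 * n_harmonics, dtype=np.float64)

h = 1  # the 2f harmonic has index 1
re, im = coeff[2 * h + 0], coeff[2 * h + 1]
mag_2f = np.sqrt(re**2 + im**2)  # the same magnitude formula as the loop above

# With time_drift=True there are 4 values per harmonic, so the 2f components
# sit at 4*1 + 0 and 4*1 + 2, matching re_comp and im_comp in the method.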

_build_model(obs, reltime, sincos, sh_flags, det_flags, det_name, det_mag, chunks, ch_coeff, coeff_flags, min_smooth=4)

Source code in toast/ops/hwpss_model.py
def _build_model(
    self,
    obs,
    reltime,
    sincos,
    sh_flags,
    det_flags,
    det_name,
    det_mag,
    chunks,
    ch_coeff,
    coeff_flags,
    min_smooth=4,
):
    log = Logger.get()
    nsamp = len(reltime)
    if len(chunks) == 1:
        if coeff_flags[0] != 0:
            msg = f"{obs.name}[{det_name}]: only one chunk, which is flagged"
            log.warning(msg)
            # Flag this detector
            current = obs.local_detector_flags[det_name]
            obs.update_local_detector_flags(
                {det_name: current | self.hwp_flag_mask}
            )
            return (None, None)
        det_coeff = ch_coeff[:, 0]
        model = hwpss_build_model(
            sincos,
            sh_flags,
            det_coeff,
            times=reltime,
            time_drift=self.time_drift,
        )
        self._plot_model(
            obs,
            det_name,
            reltime,
            sincos,
            sh_flags,
            det_flags,
            model,
            chunks,
            ch_coeff,
            0,
            nsamp,
        )
        self._plot_model(
            obs,
            det_name,
            reltime,
            sincos,
            sh_flags,
            det_flags,
            model,
            chunks,
            ch_coeff,
            nsamp // 2 - 500,
            nsamp // 2 + 500,
        )
    else:
        n_coeff = ch_coeff.shape[0]
        n_chunk = ch_coeff.shape[1]
        good_chunk = [
            np.count_nonzero(ch_coeff[:, x]) > 0 and coeff_flags[x] == 0
            for x in range(n_chunk)
        ]
        if np.count_nonzero(good_chunk) == 0:
            msg = f"{obs.name}[{det_name}]: All {len(good_chunk)} chunks"
            msg += f" are flagged."
            log.warning(msg)
            # Flag this detector
            current = obs.local_detector_flags[det_name]
            obs.update_local_detector_flags(
                {det_name: current | self.hwp_flag_mask}
            )
            return (None, None)
        ch_times = np.array(
            [x["time"] for y, x in enumerate(chunks) if good_chunk[y]]
        )
        smoothing = max(n_chunk // 16, min_smooth)
        if smoothing >= n_chunk:
            msg = f"Only {n_chunk} chunks for interpolation. "
            msg += f"Reduce the split time or use different intervals"
            raise RuntimeError(msg)
        det_coeff = np.zeros((len(reltime), n_coeff), dtype=np.float64)
        for icoeff in range(n_coeff):
            coeff_spl = scipy.interpolate.splrep(
                ch_times, ch_coeff[icoeff, good_chunk], s=smoothing
            )
            det_coeff[:, icoeff] = scipy.interpolate.splev(
                reltime, coeff_spl, ext=0
            )
        model = hwpss_build_model(
            sincos,
            sh_flags,
            det_coeff,
            times=reltime,
            time_drift=self.time_drift,
        )
        if self.relcal_continuous is not None:
            if self.time_drift:
                det_mag = np.sqrt(det_coeff[:, 4] ** 2 + det_coeff[:, 6] ** 2)
            else:
                det_mag = np.sqrt(det_coeff[:, 2] ** 2 + det_coeff[:, 3] ** 2)
            det_mag[det_flags[det_name] != 0] = 1.0
            if self.debug is not None:
                import matplotlib.pyplot as plt

                def plot_2f(first, last):
                    slc = slice(first, last, 1)
                    plt_file = os.path.join(
                        self.debug,
                        f"{obs.name}_model_{det_name}_2f_{first}-{last}.png",
                    )
                    fig = plt.figure(figsize=(12, 12), dpi=100)
                    ax = fig.add_subplot(2, 1, 1, aspect="auto")
                    ax.plot(
                        reltime[slc],
                        det_mag[slc],
                        color="red",
                        label=f"Interpolated 2f Magnitude {det_name}",
                    )
                    if self.time_drift:
                        ch_mag = np.sqrt(
                            ch_coeff[4, good_chunk] ** 2
                            + ch_coeff[6, good_chunk] ** 2
                        )
                    else:
                        ch_mag = np.sqrt(
                            ch_coeff[2, good_chunk] ** 2
                            + ch_coeff[3, good_chunk] ** 2
                        )
                    ax.scatter(
                        ch_times,
                        ch_mag,
                        marker="*",
                        color="blue",
                        label="Estimated Chunk 2f Magnitude",
                    )
                    ax.legend(loc="best")
                    ax.set_xlim(left=reltime[first], right=reltime[last - 1])
                    ax = fig.add_subplot(2, 1, 2, aspect="auto")
                    ax.plot(
                        reltime[slc],
                        det_flags[det_name][slc],
                        color="black",
                        label=f"Flags {det_name}",
                    )
                    fig.suptitle(f"Obs {obs.name} Samples {first} - {last}")
                    fig.savefig(plt_file)
                    plt.close(fig)

                plot_2f(0, nsamp)
                plot_2f(nsamp // 2 - 500, nsamp // 2 + 500)
        self._plot_model(
            obs,
            det_name,
            reltime,
            sincos,
            sh_flags,
            det_flags,
            model,
            chunks,
            ch_coeff,
            0,
            nsamp,
        )
        self._plot_model(
            obs,
            det_name,
            reltime,
            sincos,
            sh_flags,
            det_flags,
            model,
            chunks,
            ch_coeff,
            nsamp // 2 - 500,
            nsamp // 2 + 500,
        )
    return model, det_mag
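
When there is more than one chunk, the per-chunk coefficients are smoothed and interpolated to every sample with a spline. A standalone sketch of that step on synthetic numbers, using scipy directly rather than the operator's internal state:

import numpy as np
import scipy.interpolate

# One coefficient's per-chunk fit values at the chunk center times
ch_times = np.linspace(0.0, 3600.0, 24)
ch_vals = 1.0 + 0.05 * np.sin(ch_times / 900.0)
# Dense relative timestamps to interpolate onto
reltime = np.linspace(0.0, 3600.0, 36000)

smoothing = max(len(ch_times) // 16, 4)  # same heuristic as min_smooth=4 above
spl = scipy.interpolate.splrep(ch_times, ch_vals, s=smoothing)
dense = scipy.interpolate.splev(reltime, spl, ext=0)  # ext=0: extrapolate at the ends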

_check_det_flag_mask(proposal)

Source code in toast/ops/hwpss_model.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/hwpss_model.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_harmonics(proposal)

Source code in toast/ops/hwpss_model.py
@traitlets.validate("harmonics")
def _check_harmonics(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Harmonics should be a non-negative integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/hwpss_model.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_compute_chunking(obs, reltime)

Source code in toast/ops/hwpss_model.py
def _compute_chunking(self, obs, reltime):
    chunks = list()
    if self.chunk_view is None:
        if self.chunk_time is None:
            # One chunk for the whole observation
            chunks.append(
                {
                    "start": 0,
                    "end": obs.n_local_samples,
                    "time": reltime[obs.n_local_samples // 2],
                }
            )
        else:
            # Overlapping chunks
            duration = reltime[-1] - reltime[0]
            non_overlap = int(duration / self.chunk_time.to_value(u.second))
            # Adjust the chunk size to evenly divide into the obs range
            adjusted_time = duration / non_overlap
            # Convert to samples.  Round up so that the final chunk has
            # a few less samples rather than having a short chunk at the
            # end.
            rate = obs.telescope.focalplane.sample_rate.to_value(u.Hz)
            chunk_samples = int(adjusted_time * rate + 0.5)
            half_chunk = chunk_samples // 2
            ch_start = 0
            for ch in range(non_overlap):
                ch_mid = ch_start + half_chunk
                chunks.append(
                    {
                        "start": ch_start,
                        "end": ch_start + chunk_samples,
                        "time": reltime[ch_mid],
                    }
                )
                if ch != non_overlap - 1:
                    # Add the overlapping chunk
                    chunks.append(
                        {
                            "start": ch_start + half_chunk,
                            "end": ch_start + half_chunk + chunk_samples,
                            "time": reltime[ch_start + chunk_samples],
                        }
                    )
                ch_start += chunk_samples
    else:
        # Use the specified interval list for the chunks.  Cut any
        # chunks that are tiny.
        for intr in obs.intervals[self.chunk_view]:
            ch_size = intr.last - intr.first
            ch_mid = intr.first + ch_size // 2
            if ch_size > 10 * self.harmonics:
                chunks.append(
                    {
                        "start": intr.first,
                        "end": intr.last,
                        "time": reltime[ch_mid],
                    }
                )
    return chunks
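
For time-based chunking, the layout above first places non-overlapping chunks and then inserts a half-shifted chunk between each adjacent pair. A worked example with round numbers:

# Worked example of the overlapping-chunk arithmetic (illustration only)
duration = 3600.0    # seconds of data
chunk_time = 600.0   # requested chunk length (the chunk_time trait)
rate = 10.0          # sample rate in Hz

non_overlap = int(duration / chunk_time)          # 6 primary chunks
adjusted_time = duration / non_overlap            # 600.0 s, evenly dividing
chunk_samples = int(adjusted_time * rate + 0.5)   # 6000 samples, rounded
half = chunk_samples // 2

starts = []
for ch in range(non_overlap):
    begin = ch * chunk_samples
    starts.append(begin)                # primary chunk
    if ch != non_overlap - 1:
        starts.append(begin + half)     # half-shifted overlapping chunk
# starts -> [0, 3000, 6000, ..., 30000]: 2 * non_overlap - 1 = 11 chunks total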

_compute_flags(obs, dets)

Source code in toast/ops/hwpss_model.py
def _compute_flags(self, obs, dets):
    # The shared flags
    if self.shared_flags is None:
        shared_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)
    else:
        shared_flags = np.array(obs.shared[self.shared_flags].data)
        shared_flags &= self.shared_flag_mask

    # Compute flags for samples where the hwp is stopped
    stopped = self._stopped_flags(obs)
    shared_flags |= stopped

    # If we are chunking based on intervals, flag the regions between valid
    # intervals.
    if self.chunk_view is not None:
        not_modelled = np.ones_like(shared_flags)
        for intr in obs.intervals[self.chunk_view]:
            not_modelled[intr.first : intr.last] = 0
        shared_flags |= not_modelled

    # Per-detector flags.  We merge in the shared flags to these since the
    # detector flags will be written out at the end if the model is subtracted.
    det_flags = dict()
    for idet, det in enumerate(dets):
        if self.det_flags is None:
            det_flags[det] = shared_flags
        else:
            det_flags[det] = np.copy(obs.detdata[self.det_flags][det])
            det_flags[det] &= self.det_flag_mask
            det_flags[det] |= shared_flags
    return (shared_flags, det_flags)
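
The pattern here is standard bit masking: AND the raw flags with the configured mask to keep only the significant bits, then OR the shared result into each detector's flags. A tiny illustration:

import numpy as np

shared = np.array([0, 1, 2, 4, 3], dtype=np.uint8)
shared_mask = np.uint8(1)        # only bit 0 matters in this toy
masked = shared & shared_mask    # -> [0, 1, 0, 0, 1]

det = np.array([0, 0, 1, 0, 0], dtype=np.uint8)
det |= masked                    # nonzero wherever either set of flags fired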

_cut_outliers(obs, det_mag)

Source code in toast/ops/hwpss_model.py
def _cut_outliers(self, obs, det_mag):
    log = Logger.get()
    cut_timer = Timer()
    cut_timer.start()

    dets = list(det_mag.keys())
    mag = np.array([det_mag[x] for x in dets])

    # Communicate magnitudes
    all_dets = None
    all_mag = None
    if obs.comm_col is None:
        all_dets = dets
        all_mag = mag
    else:
        all_dets = flatten(obs.comm_col.gather(dets, root=0))
        all_mag = np.array(flatten(obs.comm_col.gather(mag, root=0)))

    # One process does the trivial calculation
    all_flags = None
    central_mag = None
    if obs.comm_col_rank == 0:
        all_good = [True for x in all_dets]
        n_cut = 1
        while n_cut > 0:
            n_cut = 0
            mn = np.mean(all_mag[all_good])
            std = np.std(all_mag[all_good])
            for idet, det in enumerate(all_dets):
                if not all_good[idet]:
                    continue
                if np.absolute(all_mag[idet] - mn) > self.relcal_cut_sigma * std:
                    all_good[idet] = False
                    n_cut += 1
        central_mag = np.mean(all_mag[all_good])
        all_flags = {
            x: self.hwp_flag_mask for i, x in enumerate(all_dets) if not all_good[i]
        }
    if obs.comm_col is not None:
        all_flags = obs.comm_col.bcast(all_flags, root=0)
        central_mag = obs.comm_col.bcast(central_mag, root=0)

    # Every process flags its local detectors
    det_check = set(dets)
    local_flags = dict(obs.local_detector_flags)
    for det, val in all_flags.items():
        if det in det_check:
            local_flags[det] |= val
    obs.update_local_detector_flags(local_flags)
    local_good = [x for x in dets if x not in all_flags]

    return local_good, central_mag
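
Stripped of the MPI layers, the outlier cut above is an iterative sigma clip about the mean. A standalone numpy sketch of the same loop:

import numpy as np

# Note a lone outlier among n values can sit at most sqrt(n-1) sigma from the
# mean, so this toy uses a looser cut than the operator default of 5.
def sigma_clip(mag, cut_sigma):
    good = np.ones(len(mag), dtype=bool)
    n_cut = 1
    while n_cut > 0:
        mn = np.mean(mag[good])
        std = np.std(mag[good])
        newly_bad = good & (np.absolute(mag - mn) > cut_sigma * std)
        n_cut = np.count_nonzero(newly_bad)
        good[newly_bad] = False
    return good, np.mean(mag[good])

rng = np.random.default_rng(42)
mags = np.append(1.0 + 0.01 * rng.standard_normal(25), 2.0)
good, central = sigma_clip(mags, cut_sigma=4.0)  # the last entry is rejected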

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/hwpss_model.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    env = Environment.get()
    log = Logger.get()

    if self.relcal_continuous is not None and self.relcal_fixed is not None:
        msg = "Only one of continuous and fixed relative calibration can be used"
        raise RuntimeError(msg)

    if self.chunk_view is not None and self.chunk_time is not None:
        msg = "Only one of chunk_view and chunk_time can be used"
        raise RuntimeError(msg)

    do_cal = self.relcal_continuous or self.relcal_fixed
    if not self.subtract_model and (self.save_model is None) and not do_cal:
        msg = "Nothing to do.  You should enable at least one of the options"
        msg += " to subtract or save the model or to generate calibrations."
        raise RuntimeError(msg)

    if self.debug is not None:
        if data.comm.world_rank == 0:
            os.makedirs(self.debug)
        if data.comm.comm_world is not None:
            data.comm.comm_world.barrier()

    for ob in data.obs:
        timer = Timer()
        timer.start()

        if not ob.is_distributed_by_detector:
            msg = f"{ob.name} is not distributed by detector"
            raise RuntimeError(msg)

        if self.hwp_angle not in ob.shared:
            # Nothing to do, but if a relative calibration
            # was requested, make a fake one.
            if self.relcal_fixed is not None:
                ob[self.relcal_fixed] = {x: 1.0 for x in ob.local_detectors}
            if self.relcal_continuous is not None:
                ob.detdata.ensure(
                    self.relcal_continuous,
                    dtype=np.float32,
                    create_units=ob.detdata[self.det_data].units,
                )
                ob.detdata[self.relcal_continuous][:, :] = 1.0
            msg = f"{ob.name} has no '{self.hwp_angle}' field, skipping"
            log.warning_rank(msg, comm=data.comm.comm_group)
            continue

        # Compute quantities we need for all detectors and which we
        # might re-use for overlapping chunks.

        # Local detectors we are considering
        local_dets = ob.select_local_detectors(flagmask=self.det_mask)
        n_dets = len(local_dets)

        # Get the timestamps relative to the observation start
        reltime = np.array(ob.shared[self.times].data, copy=True)
        time_offset = reltime[0]
        reltime -= time_offset

        # Compute the properties of the chunks we are using
        chunks = self._compute_chunking(ob, reltime)
        n_chunk = len(chunks)

        # Compute shared and per-detector flags.  These already have
        # masks applied and have values of either zero or one.
        sh_flags, det_flags = self._compute_flags(ob, local_dets)

        msg = f"HWPSS Model {ob.name}: compute flags and chunking in"
        log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

        # Trig quantities of the HWP angle
        sincos = hwpss_sincos_buffer(
            ob.shared[self.hwp_angle].data,
            sh_flags,
            self.harmonics,
            comm=ob.comm.comm_group,
        )
        msg = f"HWPSS Model {ob.name}: built sincos buffer in"
        log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

        # The coefficients for all detectors and chunks
        if self.time_drift:
            n_coeff = 4 * self.harmonics
        else:
            n_coeff = 2 * self.harmonics
        coeff = np.zeros((n_dets, n_coeff, n_chunk), dtype=np.float64)
        coeff_flags = np.zeros(n_chunk, dtype=np.uint8)

        for ichunk, chunk in enumerate(chunks):
            self._fit_chunk(
                ob,
                local_dets,
                ichunk,
                chunk["start"],
                chunk["end"],
                sincos,
                sh_flags,
                det_flags,
                reltime,
                coeff,
                coeff_flags,
            )

        msg = f"HWPSS Model {ob.name}: fit model to all chunks in"
        log.debug_rank(msg, comm=data.comm.comm_group, timer=timer)

        if self.save_model is not None:
            self._store_model(ob, local_dets, chunks, coeff, coeff_flags)

        # Even if we are not saving a fixed relative calibration table, compute
        # the mean 2f magnitude in order to cut outlier detectors.  The
        # calibration factors are relative to the mean of the distribution
        # of good detectors values.
        mag_table = self._average_magnitude(local_dets, coeff, coeff_flags)
        good_dets, cal_center = self._cut_outliers(ob, mag_table)
        relcal_table = dict()
        for det in good_dets:
            relcal_table[det] = cal_center / mag_table[det]
        if self.relcal_fixed is not None:
            ob[self.relcal_fixed] = relcal_table

        # If we are generating relative calibration timestreams create that now.
        if self.relcal_continuous is not None:
            ob.detdata.ensure(
                self.relcal_continuous,
                dtype=np.float32,
                create_units=ob.detdata[self.det_data].units,
            )
            ob.detdata[self.relcal_continuous][:, :] = 1.0

        # For each detector, compute the model and subtract from the data.  Also
        # compute the interpolated calibration timestream if requested.  We
        # assume that the model coefficients are slowly varying and just do a
        # linear interpolation.
        if not self.subtract_model and not self.relcal_continuous:
            # No need to compute the full time-domain templates
            continue

        good_check = set(good_dets)
        for idet, det in enumerate(local_dets):
            if det not in good_check:
                continue
            model, det_mag = self._build_model(
                ob,
                reltime,
                sincos,
                sh_flags,
                det_flags,
                det,
                mag_table[det],
                chunks,
                coeff[idet],
                coeff_flags,
            )
            # Update flags
            ob.detdata[self.det_flags][det] |= det_flags[det] * self.hwp_flag_mask
            if model is None:
                # The model construction failed due to flagged samples.  Nothing to
                # subtract, since the detector has been flagged.
                continue
            # Subtract model from good samples
            if self.subtract_model:
                good = det_flags[det] == 0
                ob.detdata[self.det_data][det][good] -= model[good]
                dc = np.mean(ob.detdata[self.det_data][det][good])
                ob.detdata[self.det_data][det][good] -= dc
            if self.fill_gaps:
                rate = ob.telescope.focalplane.sample_rate.to_value(u.Hz)
                # 1 second buffer
                buffer = int(rate)
                flagged_noise_fill(
                    ob.detdata[self.det_data][det],
                    det_flags[det],
                    buffer,
                    poly_order=1,
                )
            if self.relcal_continuous is not None:
                ob.detdata[self.relcal_continuous][det, :] = cal_center / det_mag

_finalize(data, **kwargs)

Source code in toast/ops/hwpss_model.py
def _finalize(self, data, **kwargs):
    return

_fit_chunk(obs, dets, indx, start, end, sincos, sh_flags, det_flags, reltime, coeff, coeff_flags)

Source code in toast/ops/hwpss_model.py
def _fit_chunk(
    self,
    obs,
    dets,
    indx,
    start,
    end,
    sincos,
    sh_flags,
    det_flags,
    reltime,
    coeff,
    coeff_flags,
):
    log = Logger.get()
    ch_timer = Timer()
    ch_timer.start()

    # The sample slice
    slc = slice(start, end, 1)
    slc_samps = end - start

    if reltime is None:
        ch_reltime = None
    else:
        ch_reltime = reltime[slc]

    obs_cov = hwpss_compute_coeff_covariance(
        sincos[slc],
        sh_flags[slc],
        comm=obs.comm.comm_group,
        times=ch_reltime,
        time_drift=self.time_drift,
    )
    if obs_cov is None:
        msg = f"HWPSS Model {obs.name}[{indx}] ({slc_samps} samples)"
        msg += " failed to compute coefficient"
        msg += " covariance.  Flagging this chunk when building model."
        log.verbose_rank(msg, comm=obs.comm.comm_group)
        coeff_flags[indx] = 1
        return

    msg = f"HWPSS Model {obs.name}[{indx}]: built coefficient covariance in"
    log.verbose_rank(msg, comm=obs.comm.comm_group, timer=ch_timer)

    for idet, det in enumerate(dets):
        good_samp = det_flags[det][slc] == 0
        if np.count_nonzero(good_samp) < coeff.shape[1]:
            # Not very many good samples, set coefficients to zero
            msg = f"HWPSS Model {obs.name}[{indx}] {det}: insufficient good "
            msg += "samples, setting coefficients to zero"
            log.verbose(msg)
            coeff[idet, :, indx] = 0
            continue
        sig = np.array(obs.detdata[self.det_data][det, slc])
        dc = np.mean(sig[good_samp])
        sig -= dc

        cf = hwpss_compute_coeff(
            sincos[slc],
            sig,
            det_flags[det][slc],
            obs_cov[0],
            obs_cov[1],
            times=ch_reltime,
            time_drift=self.time_drift,
        )
        if idet == 0:
            cfstr = ""
            for ic in cf:
                cfstr += f"{ic} "
        coeff[idet, :, indx] = cf

    msg = f"HWPSS Model {obs.name}[{indx}]: compute detector coefficients in"
    log.verbose_rank(msg, comm=obs.comm.comm_group, timer=ch_timer)

_plot_model(obs, det_name, reltime, sincos, sh_flags, det_flags, model, chunks, chunk_coeff, first, last)

Source code in toast/ops/hwpss_model.py
def _plot_model(
    self,
    obs,
    det_name,
    reltime,
    sincos,
    sh_flags,
    det_flags,
    model,
    chunks,
    chunk_coeff,
    first,
    last,
):
    if self.debug is None:
        return
    import matplotlib.pyplot as plt

    slc = slice(first, last, 1)
    # If we are plotting per-chunk quantities, find the overlap of every chunk
    # with our plot range
    chunk_slc = None
    if len(chunks) > 1:
        chunk_slc = list()
        for ich, chk in enumerate(chunks):
            ch_start = chk["start"]
            ch_end = chk["end"]
            ch_size = ch_end - ch_start
            prp = dict()
            prp["abs_slc"] = slice(ch_start, ch_end, 1)
            if ch_start < last and ch_end > first:
                # some overlap
                if ch_start < first:
                    ch_first = first - ch_start
                    plt_first = first
                else:
                    ch_first = 0
                    plt_first = ch_start
                if ch_end > last:
                    ch_last = ch_size - (ch_end - last)
                    plt_last = last
                else:
                    ch_last = ch_size
                    plt_last = ch_end

                prp["ch_overlap"] = slice(int(ch_first), int(ch_last), 1)
                prp["plt_overlap"] = slice(int(plt_first), int(plt_last), 1)
            else:
                prp["ch_overlap"] = None
                prp["plt_overlap"] = None
            chunk_slc.append(prp)
    cmap = plt.get_cmap("tab10")
    plt_file = os.path.join(
        self.debug,
        f"{obs.name}_model_{det_name}_{first}-{last}.png",
    )
    fig = plt.figure(figsize=(12, 12), dpi=100)
    ax = fig.add_subplot(2, 1, 1, aspect="auto")
    # Plot original signal
    ax.plot(
        reltime[slc],
        obs.detdata[self.det_data][det_name, slc],
        color="black",
        label=f"Signal {det_name}",
    )
    # Plot per chunk models
    if len(chunks) > 1:
        for ich, chk in enumerate(chunks):
            if chunk_slc[ich]["ch_overlap"] is None:
                # No overlap
                continue
            ch_coeff = chunk_coeff[:, ich]
            if np.count_nonzero(ch_coeff) == 0:
                continue
            ch_model = hwpss_build_model(
                sincos[chunk_slc[ich]["abs_slc"]],
                sh_flags[chunk_slc[ich]["abs_slc"]],
                ch_coeff,
                times=reltime[chunk_slc[ich]["abs_slc"]],
                time_drift=self.time_drift,
            )
            ax.plot(
                reltime[chunk_slc[ich]["plt_overlap"]],
                ch_model[chunk_slc[ich]["ch_overlap"]],
                color=cmap(ich),
                label=f"Model {det_name}",
            )
    # Plot full model
    ax.plot(
        reltime[slc],
        model[slc],
        color="red",
        label=f"Model {det_name}",
    )
    ax.legend(loc="best")

    cmap = plt.get_cmap("tab10")
    ax = fig.add_subplot(2, 1, 2, aspect="auto")
    # Plot flags
    ax.plot(
        reltime[slc],
        det_flags[det_name][slc],
        color="black",
        label=f"Flags {det_name}",
    )
    # Plot chunk boundaries
    if len(chunks) > 1:
        incr = 1 / (len(chunks) + 1)
        for ich, chk in enumerate(chunks):
            if chunk_slc[ich]["ch_overlap"] is None:
                # No overlap
                continue
            ax.plot(
                reltime[chunk_slc[ich]["plt_overlap"]],
                incr * ich * np.ones_like(reltime[chunk_slc[ich]["plt_overlap"]]),
                color=cmap(ich),
                linewidth=3,
                label=f"Chunk {ich}",
            )
    ax.legend(loc="best")
    fig.suptitle(f"Obs {obs.name} Samples {first} - {last}")
    fig.savefig(plt_file)
    plt.close(fig)

_provides()

Source code in toast/ops/hwpss_model.py
def _provides(self):
    prov = {
        "meta": [],
        "detdata": [self.det_data],
    }
    if self.relcal_continuous is not None:
        prov["detdata"].append(self.relcal_continuous)
    if self.save_model is not None:
        prov["meta"].append(self.save_model)
    if self.relcal_fixed is not None:
        prov["meta"].append(self.relcal_fixed)
    return prov

_requires()

Source code in toast/ops/hwpss_model.py
def _requires(self):
    # Note that the hwp_angle is not strictly required:  if it does not
    # exist in an observation, this operator is a no-op for that data.
    req = {
        "shared": [self.times],
        "detdata": [self.det_data],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

_stopped_flags(obs)

Source code in toast/ops/hwpss_model.py
def _stopped_flags(self, obs):
    hdata = np.unwrap(obs.shared[self.hwp_angle].data, period=2 * np.pi)
    hvel = np.gradient(hdata)
    moving = np.absolute(hvel) > 1.0e-6
    nominal = np.median(hvel[moving])
    # Compare against the magnitude of the nominal rate so a reversed
    # spin direction is handled correctly
    unstable = np.absolute(hvel - nominal) > 1.0e-3 * np.absolute(nominal)
    stopped = np.array(unstable, dtype=np.uint8)
    return stopped
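
A synthetic demonstration of this stopped-HWP detection: the angle advances at a constant rate and then parks, and the parked span comes out flagged.

import numpy as np

n = 1000
rate = 2.0 * np.pi / 50.0                  # radians per sample
ang = np.cumsum(np.full(n, rate))
ang[600:] = ang[600]                        # the HWP stops spinning here
ang = ang % (2.0 * np.pi)                   # wrapped, like the shared data

hvel = np.gradient(np.unwrap(ang, period=2 * np.pi))
moving = np.absolute(hvel) > 1.0e-6
nominal = np.median(hvel[moving])
unstable = np.absolute(hvel - nominal) > 1.0e-3 * np.absolute(nominal)
stopped = unstable.astype(np.uint8)
# stopped is 0 while spinning and 1 across the parked span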

_store_model(obs, dets, chunks, coeff, coeff_flags)

Source code in toast/ops/hwpss_model.py
def _store_model(self, obs, dets, chunks, coeff, coeff_flags):
    log = Logger.get()
    if self.save_model in obs:
        msg = "observation {obs.name} already has something at "
        msg += "key {self.save_model}.  Overwriting."
        log.warning(msg)
    # Repackage the coefficients and chunk information
    ob_start = obs.shared[self.times].data[0]
    model = list()
    for ichk, chk in enumerate(chunks):
        props = {
            "start": chk["start"],
            "end": chk["end"],
            "time": ob_start + chk["time"],
            "flag": coeff_flags[ichk],
        }
        props["dets"] = dict()
        for idet, det in enumerate(dets):
            props["dets"][det] = np.array(coeff[idet, :, ichk])
        model.append(props)
    obs[self.save_model] = model
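
The saved model is a plain list with one dictionary per chunk, so downstream code can consume it without any toast-specific machinery. A sketch of the structure (the key and detector names here are illustrative):

import numpy as np

model = [
    {
        "start": 0,                 # first local sample of the chunk
        "end": 60000,               # last local sample of the chunk
        "time": 1.6e9 + 300.0,      # absolute mid-chunk timestamp
        "flag": 0,                  # nonzero means the chunk fit failed
        "dets": {"D0A": np.zeros(18)},  # 2 (or 4) coefficients per harmonic
    },
]
# An observation key set via save_model would hold such a list:
for chunk in model:
    if chunk["flag"] != 0:
        continue
    coeffs = chunk["dets"]["D0A"]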

toast.ops.HWPFilter

Bases: Operator

Operator that applies HWP-synchronous signal filtering.

Source code in toast/ops/hwpfilter.py
@trait_docs
class HWPFilter(Operator):
    """Operator that applies HWP-synchronous signal filtering."""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key",
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for optional shared flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    hwp_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask to use when adding flags based on HWP filter failures.",
    )

    hwp_angle = Unicode(
        defaults.hwp_angle, allow_none=True, help="Observation shared key for HWP angle"
    )

    trend_order = Int(
        5, help="Order of a Legendre polynomial to fit along with the HWPSS template."
    )

    filter_order = Int(
        5, help="Order of a Fourier expansion to fit as a function of HWP angle."
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    detrend = Bool(
        False, help="Subtract the fitted trend along with the ground template"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("trend_order")
    def _check_trend_order(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Trend order should be a non-negative integer")
        return check

    @traitlets.validate("filter_order")
    def _check_filter_order(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Filter order should be a non-negative integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def build_templates(self, obs):
        """Construct the local HWPSS template hierarchy"""

        # Construct trend templates.  Full domain for x is [-1, 1]

        my_offset = obs.local_index_offset
        my_nsamp = obs.n_local_samples
        nsamp_tot = obs.n_all_samples
        x = np.arange(my_offset, my_offset + my_nsamp) / nsamp_tot * 2 - 1

        legendre_trend = np.zeros([self.trend_order + 1, my_nsamp])
        legendre(x, legendre_trend, 0, self.trend_order + 1)

        # Fourier templates

        hwp_angle = obs.shared[self.hwp_angle].data
        nfilter = 2 * self.filter_order
        fourier_templates = np.zeros([nfilter, my_nsamp])
        fourier(hwp_angle, fourier_templates, 1, self.filter_order + 1)

        templates = np.vstack([legendre_trend, fourier_templates])

        return templates, legendre_trend, fourier_templates

    @function_timer
    def fit_templates(
        self,
        obs,
        templates,
        ref,
        good,
        last_good,
        last_invcov,
        last_cov,
        last_rcond,
    ):
        log = Logger.get()
        # communicator for processes with the same detectors
        comm = obs.comm_row
        ngood = np.sum(good)
        ntask = 1
        if comm is not None:
            ngood = comm.allreduce(ngood)
            ntask = comm.size
        if ngood == 0:
            return None, None, None, None

        ntemplate = len(templates)
        invcov = np.zeros([ntemplate, ntemplate])
        proj = np.zeros(ntemplate)

        bin_proj_fast(ref, templates, good.astype(np.uint8), proj)
        if last_good is not None and np.all(good == last_good) and ntask == 1:
            # Flags have not changed, we can re-use the last inverse covariance
            invcov = last_invcov
            cov = last_cov
            rcond = last_rcond
        else:
            bin_invcov_fast(templates, good.astype(np.uint8), invcov)
            if comm is not None:
                # Reduce the binned data.  The detector signal is
                # distributed across the group communicator.
                comm.Allreduce(MPI.IN_PLACE, invcov, op=MPI.SUM)
                comm.Allreduce(MPI.IN_PLACE, proj, op=MPI.SUM)
            rcond = get_rcond(invcov)
            cov = None

        self.rcondsum += rcond
        if rcond > 1e-6:
            self.ngood += 1
            if cov is None:
                cov = get_inverse(invcov)
        else:
            self.nsingular += 1
            log.debug(
                f"HWP template matrix is poorly conditioned, "
                f"rcond = {rcond}, using pseudoinverse."
            )
            if cov is None:
                cov = get_pseudoinverse(invcov)
        coeff = np.dot(cov, proj)
        return coeff, invcov, cov, rcond

    @function_timer
    def subtract_templates(self, ref, good, coeff, legendre_trend, fourier_filter):
        # Trend
        if self.detrend:
            trend = np.zeros_like(ref)
            add_templates(trend, legendre_trend, coeff[: self.trend_order + 1])
            ref[:] -= trend
        # HWP template
        hwptemplate = np.zeros_like(ref)
        add_templates(hwptemplate, fourier_filter, coeff[self.trend_order + 1 :])
        ref[:] -= hwptemplate
        return

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        t0 = time()
        env = Environment.get()
        log = Logger.get()

        wcomm = data.comm.comm_world
        gcomm = data.comm.comm_group

        self.nsingular = 0
        self.ngood = 0
        self.rcondsum = 0

        # Each group loops over its own CESes (constant elevation scans)
        nobs = len(data.obs)
        for iobs, obs in enumerate(data.obs):
            # Prefix for logging
            log_prefix = f"{data.comm.group} : {obs.name} :"

            if self.hwp_angle in obs.shared:
                if data.comm.group_rank == 0:
                    msg = f"{log_prefix} HWPSS Filter: "
                    msg += f"Processing observation {iobs + 1} / {nobs}"
                    msg += f" ({obs.name})"
                    log.debug(msg)
            else:
                # This observation has no HWP
                if data.comm.group_rank == 0:
                    msg = (
                        f"{log_prefix} HWPSS Filter:  skipping observation {obs.name},"
                    )
                    msg += f" which has no HWP"
                    log.debug(msg)
                continue

            # Cache the output common flags
            if self.shared_flags is not None:
                common_flags = (
                    obs.shared[self.shared_flags].data & self.shared_flag_mask
                )
            else:
                common_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

            t1 = time()
            templates, legendre_trend, fourier_filter = self.build_templates(obs)
            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} HWPSS Filter: "
                    f"Built templates in {time() - t1:.1f}s"
                )
                log.debug(msg)

            last_good = None
            last_invcov = None
            last_cov = None
            last_rcond = None
            for det in obs.select_local_detectors(detectors, flagmask=self.det_mask):
                if data.comm.group_rank == 0:
                    msg = f"{log_prefix} HWPSS Filter: " f"Processing detector {det}"
                    log.verbose(msg)

                ref = obs.detdata[self.det_data][det]
                if self.det_flags is not None:
                    test_flags = obs.detdata[self.det_flags][det] & self.det_flag_mask
                    good = np.logical_and(common_flags == 0, test_flags == 0)
                else:
                    good = common_flags == 0

                t1 = time()
                coeff, last_invcov, last_cov, last_rcond = self.fit_templates(
                    obs,
                    templates,
                    ref,
                    good,
                    last_good,
                    last_invcov,
                    last_cov,
                    last_rcond,
                )
                last_good = good
                if data.comm.group_rank == 0:
                    msg = (
                        f"{log_prefix} HWPSS Filter: "
                        f"Fit templates in {time() - t1:.1f}s"
                    )
                    log.verbose(msg)

                if coeff is None:
                    # All samples flagged or template fit failed.
                    curflag = obs.local_detector_flags[det]
                    obs.update_local_detector_flags({det: curflag | self.hwp_flag_mask})
                    continue

                t1 = time()
                self.subtract_templates(
                    ref, good, coeff, legendre_trend, fourier_filter
                )
                if data.comm.group_rank == 0:
                    msg = (
                        f"{log_prefix} HWPSS Filter: "
                        f"Subtract templates in {time() - t1:.1f}s"
                    )
                    log.verbose(msg)
            del last_good
            del last_invcov
            del last_cov
            del last_rcond

        if wcomm is not None:
            self.nsingular = wcomm.allreduce(self.nsingular)
            self.ngood = wcomm.allreduce(self.ngood)
            self.rcondsum = wcomm.allreduce(self.rcondsum)

        if wcomm is None or wcomm.rank == 0:
            denominator = self.nsingular + self.ngood
            if denominator == 0:
                msg = f"HWPSS filter had no observations with a HWP"
                log.debug(msg)
            else:
                rcond_mean = self.rcondsum / denominator
                msg = (
                    f"Applied HWPSS filter in {time() - t0:.1f} s.  "
                    f"Average rcond of template matrix was {rcond_mean}"
                )
                log.debug(msg)

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "shared": list(),
            "detdata": [self.det_data],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        return dict()

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

detrend = Bool(False, help='Subtract the fitted trend along with the HWPSS template') class-attribute instance-attribute

filter_order = Int(5, help='Order of a Fourier expansion to fit as a function of HWP angle.') class-attribute instance-attribute

hwp_angle = Unicode(defaults.hwp_angle, allow_none=True, help='Observation shared key for HWP angle') class-attribute instance-attribute

hwp_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask to use when adding flags based on HWP filter failures.') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

trend_order = Int(5, help='Order of a Legendre polynomial to fit along with the HWPSS template.') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute
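
As a usage sketch, the operator can be configured with the traits above and run through the standard Operator apply() entry point; the trait values here are illustrative.

import toast.ops

hwpfilter = toast.ops.HWPFilter(
    trend_order=5,     # Legendre trend fit alongside the HWPSS template
    filter_order=5,    # Fourier harmonics of the HWP angle to project out
    detrend=False,     # fit the trend but do not subtract it
)
hwpfilter.apply(data)  # "data" is a toast.Data instance built earlier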

__init__(**kwargs)

Source code in toast/ops/hwpfilter.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_flag_mask(proposal)

Source code in toast/ops/hwpfilter.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/hwpfilter.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_filter_order(proposal)

Source code in toast/ops/hwpfilter.py
@traitlets.validate("filter_order")
def _check_filter_order(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Filter order should be a non-negative integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/hwpfilter.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_trend_order(proposal)

Source code in toast/ops/hwpfilter.py
@traitlets.validate("trend_order")
def _check_trend_order(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Trend order should be a non-negative integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/hwpfilter.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    t0 = time()
    env = Environment.get()
    log = Logger.get()

    wcomm = data.comm.comm_world
    gcomm = data.comm.comm_group

    self.nsingular = 0
    self.ngood = 0
    self.rcondsum = 0

    # Each group loops over its own CESes (constant elevation scans)
    nobs = len(data.obs)
    for iobs, obs in enumerate(data.obs):
        # Prefix for logging
        log_prefix = f"{data.comm.group} : {obs.name} :"

        if self.hwp_angle in obs.shared:
            if data.comm.group_rank == 0:
                msg = f"{log_prefix} HWPSS Filter: "
                msg += f"Processing observation {iobs + 1} / {nobs}"
                msg += f" ({obs.name})"
                log.debug(msg)
        else:
            # This observation has no HWP
            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} HWPSS Filter:  skipping observation {obs.name},"
                )
                msg += f" which has no HWP"
                log.debug(msg)
            continue

        # Cache the output common flags
        if self.shared_flags is not None:
            common_flags = (
                obs.shared[self.shared_flags].data & self.shared_flag_mask
            )
        else:
            common_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

        t1 = time()
        templates, legendre_trend, fourier_filter = self.build_templates(obs)
        if data.comm.group_rank == 0:
            msg = (
                f"{log_prefix} HWPSS Filter: "
                f"Built templates in {time() - t1:.1f}s"
            )
            log.debug(msg)

        last_good = None
        last_invcov = None
        last_cov = None
        last_rcond = None
        for det in obs.select_local_detectors(detectors, flagmask=self.det_mask):
            if data.comm.group_rank == 0:
                msg = f"{log_prefix} HWPSS Filter: " f"Processing detector {det}"
                log.verbose(msg)

            ref = obs.detdata[self.det_data][det]
            if self.det_flags is not None:
                test_flags = obs.detdata[self.det_flags][det] & self.det_flag_mask
                good = np.logical_and(common_flags == 0, test_flags == 0)
            else:
                good = common_flags == 0

            t1 = time()
            coeff, last_invcov, last_cov, last_rcond = self.fit_templates(
                obs,
                templates,
                ref,
                good,
                last_good,
                last_invcov,
                last_cov,
                last_rcond,
            )
            last_good = good
            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} HWPSS Filter: "
                    f"Fit templates in {time() - t1:.1f}s"
                )
                log.verbose(msg)

            if coeff is None:
                # All samples flagged or template fit failed.
                curflag = obs.local_detector_flags[det]
                obs.update_local_detector_flags({det: curflag | self.hwp_flag_mask})
                continue

            t1 = time()
            self.subtract_templates(
                ref, good, coeff, legendre_trend, fourier_filter
            )
            if data.comm.group_rank == 0:
                msg = (
                    f"{log_prefix} HWPSS Filter: "
                    f"Subtract templates in {time() - t1:.1f}s"
                )
                log.verbose(msg)
        del last_good
        del last_invcov
        del last_cov
        del last_rcond

    if wcomm is not None:
        self.nsingular = wcomm.allreduce(self.nsingular)
        self.ngood = wcomm.allreduce(self.ngood)
        self.rcondsum = wcomm.allreduce(self.rcondsum)

    if wcomm is None or wcomm.rank == 0:
        denominator = self.nsingular + self.ngood
        if denominator == 0:
            msg = f"HWPSS filter had no observations with a HWP"
            log.debug(msg)
        else:
            rcond_mean = self.rcondsum / denominator
            msg = (
                f"Applied HWPSS filter in {time() - t0:.1f} s.  "
                f"Average rcond of template matrix was {rcond_mean}"
            )
            log.debug(msg)

    return

_finalize(data, **kwargs)

Source code in toast/ops/hwpfilter.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/hwpfilter.py
def _provides(self):
    return dict()

_requires()

Source code in toast/ops/hwpfilter.py
def _requires(self):
    req = {
        "shared": list(),
        "detdata": [self.det_data],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.azimuth is not None:
        req["shared"].append(self.azimuth)
    if self.boresight_azel is not None:
        req["shared"].append(self.boresight_azel)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

build_templates(obs)

Construct the local HWPSS template hierarchy

Source code in toast/ops/hwpfilter.py, lines 157–180
@function_timer
def build_templates(self, obs):
    """Construct the local HWPSS template hierarchy"""

    # Construct trend templates.  Full domain for x is [-1, 1]

    my_offset = obs.local_index_offset
    my_nsamp = obs.n_local_samples
    nsamp_tot = obs.n_all_samples
    x = np.arange(my_offset, my_offset + my_nsamp) / nsamp_tot * 2 - 1

    legendre_trend = np.zeros([self.trend_order + 1, my_nsamp])
    legendre(x, legendre_trend, 0, self.trend_order + 1)

    # Fourier templates

    hwp_angle = obs.shared[self.hwp_angle].data
    nfilter = 2 * self.filter_order
    fourier_templates = np.zeros([nfilter, my_nsamp])
    fourier(hwp_angle, fourier_templates, 1, self.filter_order + 1)

    templates = np.vstack([legendre_trend, fourier_templates])

    return templates, legendre_trend, fourier_templates
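
For intuition, the same construction can be sketched in plain NumPy, replacing TOAST's compiled legendre and fourier kernels with numpy.polynomial and explicit sine/cosine evaluations (all sizes below are made-up toy values, not TOAST defaults):

import numpy as np

# Toy sizes; in TOAST these come from the observation and the traits
nsamp = 1000
trend_order = 2   # Legendre orders 0 .. trend_order
filter_order = 4  # HWP harmonics 1 .. filter_order

# Trend templates: Legendre polynomials evaluated on x in [-1, 1]
x = np.linspace(-1.0, 1.0, nsamp)
legendre_trend = np.array(
    [np.polynomial.legendre.Legendre.basis(n)(x) for n in range(trend_order + 1)]
)

# HWPSS templates: a sine/cosine pair per HWP harmonic
hwp_angle = np.linspace(0.0, 20 * 2 * np.pi, nsamp)  # fake spinning HWP
fourier_templates = np.vstack(
    [f(m * hwp_angle) for m in range(1, filter_order + 1) for f in (np.cos, np.sin)]
)

# Stacked hierarchy: (trend_order + 1 + 2 * filter_order) x nsamp
templates = np.vstack([legendre_trend, fourier_templates])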

fit_templates(obs, templates, ref, good, last_good, last_invcov, last_cov, last_rcond)

Source code in toast/ops/hwpfilter.py, lines 182–239
@function_timer
def fit_templates(
    self,
    obs,
    templates,
    ref,
    good,
    last_good,
    last_invcov,
    last_cov,
    last_rcond,
):
    log = Logger.get()
    # communicator for processes with the same detectors
    comm = obs.comm_row
    ngood = np.sum(good)
    ntask = 1
    if comm is not None:
        ngood = comm.allreduce(ngood)
        ntask = comm.size
    if ngood == 0:
        return None, None, None, None

    ntemplate = len(templates)
    invcov = np.zeros([ntemplate, ntemplate])
    proj = np.zeros(ntemplate)

    bin_proj_fast(ref, templates, good.astype(np.uint8), proj)
    if last_good is not None and np.all(good == last_good) and ntask == 1:
        # Flags have not changed, we can re-use the last inverse covariance
        invcov = last_invcov
        cov = last_cov
        rcond = last_rcond
    else:
        bin_invcov_fast(templates, good.astype(np.uint8), invcov)
        if comm is not None:
            # Reduce the binned data.  The detector signal is
            # distributed across the group communicator.
            comm.Allreduce(MPI.IN_PLACE, invcov, op=MPI.SUM)
            comm.Allreduce(MPI.IN_PLACE, proj, op=MPI.SUM)
        rcond = get_rcond(invcov)
        cov = None

    self.rcondsum += rcond
    if rcond > 1e-6:
        self.ngood += 1
        if cov is None:
            cov = get_inverse(invcov)
    else:
        self.nsingular += 1
        log.debug(
            f"HWP template matrix is poorly conditioned, "
            f"rcond = {rcond}, using pseudoinverse."
        )
        if cov is None:
            cov = get_pseudoinverse(invcov)
    coeff = np.dot(cov, proj)
    return coeff, invcov, cov, rcond
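
In effect, fit_templates solves the flag-weighted normal equations: with template matrix T, sample mask M = diag(good) and signal d, it accumulates proj = T M d and invcov = T M T^T, inverts (or pseudo-inverts, when rcond is poor) and returns coeff = invcov^-1 proj. A dense NumPy sketch of the same solve, standing in for the compiled bin_proj_fast/bin_invcov_fast kernels:

import numpy as np

def fit_templates_dense(signal, templates, good, rcond_limit=1e-6):
    """Flag-weighted least-squares template fit (dense sketch)."""
    w = good.astype(float)
    proj = templates @ (w * signal)           # T M d
    invcov = (templates * w) @ templates.T    # T M T^T
    rcond = 1.0 / np.linalg.cond(invcov)
    if rcond > rcond_limit:
        cov = np.linalg.inv(invcov)
    else:
        # Poorly conditioned template matrix: fall back to the pseudoinverse
        cov = np.linalg.pinv(invcov)
    return cov @ proj

# Toy check: recover an offset and a cosine amplitude
n = 1000
t = np.linspace(0.0, 1.0, n)
templates = np.vstack([np.ones(n), np.cos(2 * np.pi * 5 * t)])
signal = 0.3 + 2.0 * templates[1]
good = np.ones(n, dtype=bool)
print(fit_templates_dense(signal, templates, good))  # ~ [0.3, 2.0]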

subtract_templates(ref, good, coeff, legendre_trend, fourier_filter)

Source code in toast/ops/hwpfilter.py, lines 241–252
@function_timer
def subtract_templates(self, ref, good, coeff, legendre_trend, fourier_filter):
    # Trend
    if self.detrend:
        trend = np.zeros_like(ref)
        add_templates(trend, legendre_trend, coeff[: self.trend_order + 1])
        ref[:] -= trend
    # HWP template
    hwptemplate = np.zeros_like(ref)
    add_templates(hwptemplate, fourier_filter, coeff[self.trend_order + 1 :])
    ref[:] -= hwptemplate
    return

toast.ops.Demodulate

Bases: Operator

Demodulate and downsample HWP-modulated data
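
A typical pipeline runs this operator on the modulated data and then continues on the demodulated Data object it returns. A minimal, hedged usage sketch (the data object and the det_pointing operator are assumed to exist already; only the operator and trait names come from this page and the TOAST API):

import toast.ops

# Stokes weights of the *modulated* data, required by the demodulation.
# `det_pointing` is an assumed, previously constructed detector pointing
# operator (e.g. toast.ops.PointingDetectorSimple).
weights = toast.ops.StokesWeights(
    mode="IQU", hwp_angle="hwp_angle", detector_pointing=det_pointing
)

demod = toast.ops.Demodulate(
    stokes_weights=weights,
    noise_model="noise_model",  # must exist in every observation
    nskip=3,                    # downsample by a factor of three
)

# apply() runs _exec() and _finalize(); _finalize() (shown below) returns
# the new Data object holding the demodulated observations.
demod_data = demod.apply(data)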

Source code in toast/ops/demodulation.py, lines 57–695
@trait_docs
class Demodulate(Operator):
    """Demodulate and downsample HWP-modulated data"""

    API = Int(0, help="Internal interface version for this operator")

    stokes_weights = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a Stokes weights operator",
    )

    times = Unicode(
        defaults.times,
        help="Observation shared key for timestamps",
    )

    hwp_angle = Unicode(defaults.hwp_angle, help="Observation shared key for HWP angle")

    azimuth = Unicode(defaults.azimuth, help="Observation shared key for Azimuth")

    elevation = Unicode(defaults.elevation, help="Observation shared key for Elevation")

    boresight = Unicode(
        defaults.boresight_radec, help="Observation shared key for boresight"
    )

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key apply filtering to.  Use ';' if multiple "
        "signal flavors should be demodulated.",
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid, help="Bit mask value for detector sample flagging"
    )

    demod_flag_mask = Int(
        defaults.det_mask_invalid, help="Bit mask value for demod & downsample flagging"
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid, help="Bit mask value for optional shared flagging"
    )

    noise_model = Unicode(
        "noise_model",
        allow_none=True,
        help="Observation key containing the noise model",
    )

    wkernel = Int(None, allow_none=True, help="Override automatic filter kernel size")

    fmax = Quantity(
        None, allow_none=True, help="Override automatic lowpass cut-off frequency"
    )

    nskip = Int(3, help="Downsampling factor")

    window = Unicode(
        "hamming", help="Window function name recognized by scipy.signal.firwin"
    )

    purge = Bool(False, help="Remove inputs after demodulation")

    do_2f = Bool(False, help="also cache the 2f-demodulated signal")

    # Intervals?

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("stokes_weights")
    def _check_stokes_weights(self, proposal):
        weights = proposal["value"]
        if weights is not None:
            if not isinstance(weights, Operator):
                raise traitlets.TraitError(
                    "stokes_weights should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["weights", "view"]:
                if not weights.has_trait(trt):
                    msg = f"stokes_weights operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return weights

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.demod_data = Data()
        return

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        for trait in "noise_model", "stokes_weights":
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        # Demodulation only applies to observations with HWP.  Verify
        # that there are such observations in `data`

        demodulate_obs = []
        for obs in data.obs:
            if self.hwp_angle not in obs.shared:
                continue
            hwp_angle = obs.shared[self.hwp_angle]
            if np.abs(np.median(np.diff(hwp_angle))) < 1e-6:
                # Stepped or stationary HWP
                continue
            demodulate_obs.append(obs)
        n_obs = len(demodulate_obs)
        if data.comm.comm_world is not None:
            n_obs = data.comm.comm_world.allreduce(n_obs)
        if n_obs == 0:
            raise RuntimeError(
                "None of the observations have a spinning HWP.  Nothing to demodulate."
            )

        # Each modulated detector demodulates into 3 or 5 pseudo detectors

        self.prefixes = ["demod0", "demod4r", "demod4i"]
        if self.do_2f:
            self.prefixes.extend(["demod2r", "demod2i"])

        timer = Timer()
        timer.start()
        for obs in demodulate_obs:
            # Get the detectors which are not cut with per-detector flags
            local_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
            if obs.comm.comm_group is None:
                all_dets = local_dets
            else:
                proc_dets = obs.comm.comm_group.gather(local_dets, root=0)
                all_dets = None
                if obs.comm.comm_group.rank == 0:
                    all_dets = set()
                    for pdets in proc_dets:
                        for d in pdets:
                            all_dets.add(d)
                    all_dets = list(sorted(all_dets))
                all_dets = obs.comm.comm_group.bcast(all_dets, root=0)

            offset = obs.local_index_offset
            nsample = obs.n_local_samples

            fsample = obs.telescope.focalplane.sample_rate
            fmax, hwp_rate = self._get_fmax(obs)

            wkernel = self._get_wkernel(fmax, fsample)
            lowpass = Lowpass(wkernel, fmax, fsample, offset, self.nskip, self.window)

            # Create a new observation to hold the demodulated and downsampled data

            demod_telescope = self._demodulate_telescope(obs, all_dets)
            demod_times = self._demodulate_times(obs)
            demod_detsets = self._demodulate_detsets(obs, all_dets)
            demod_sample_sets = self._demodulate_sample_sets(obs)
            demod_process_rows = obs.dist.process_rows

            demod_name = f"demod_{obs.name}"
            demod_obs = Observation(
                obs.comm,
                demod_telescope,
                demod_times.size,
                name=demod_name,
                uid=name_UID(demod_name),
                session=obs.session,
                detector_sets=demod_detsets,
                process_rows=demod_process_rows,
                sample_sets=demod_sample_sets,
            )
            for key, value in obs.items():
                if key == self.noise_model:
                    # Will be generated later
                    continue
                demod_obs[key] = value

            # Allocate storage

            demod_dets = []
            for det in local_dets:
                for prefix in self.prefixes:
                    demod_dets.append(f"{prefix}_{det}")
            n_local = demod_obs.n_local_samples

            demod_obs.shared.create_column(self.times, (n_local,))
            demod_obs.shared[self.times].set(demod_times, offset=(0,), fromrank=0)
            demod_obs.shared.create_column(self.boresight, (n_local, 4))
            demod_obs.shared.create_column(
                self.shared_flags, (n_local,), dtype=np.uint8
            )

            self._demodulate_shared_data(obs, demod_obs)

            for det_data in self.det_data.split(";"):
                exists_data = demod_obs.detdata.ensure(
                    det_data,
                    detectors=demod_dets,
                    dtype=np.float64,
                    create_units=obs.detdata[det_data].units,
                )
            exists_flags = demod_obs.detdata.ensure(
                self.det_flags, detectors=demod_dets, dtype=np.uint8
            )

            self._demodulate_flags(obs, demod_obs, local_dets, wkernel, offset)
            self._demodulate_signal(data, obs, demod_obs, local_dets, lowpass)
            self._demodulate_pointing(data, obs, demod_obs, local_dets, lowpass, offset)
            self._demodulate_noise(
                obs, demod_obs, local_dets, fsample, hwp_rate, lowpass
            )

            self._demodulate_intervals(obs, demod_obs)

            self.demod_data.obs.append(demod_obs)

            if self.purge:
                obs.clear()

            log.debug_rank(
                "Demodulated observation in", comm=data.comm.comm_group, timer=timer
            )
        if self.purge:
            data.clear()

        return

    @function_timer
    def _get_fmax(self, obs):
        times = obs.shared[self.times].data
        hwp_angle = np.unwrap(obs.shared[self.hwp_angle].data)
        hwp_rate = np.absolute(
            np.mean(np.diff(hwp_angle) / np.diff(times)) / (2 * np.pi) * u.Hz
        )
        if self.fmax is not None:
            fmax = self.fmax
        else:
            # set the low-pass cut-off frequency equal to the HWP 1f frequency
            fmax = hwp_rate
        return fmax, hwp_rate

    @function_timer
    def _get_wkernel(self, fmax, fsample):
        if self.wkernel is not None:
            wkernel = self.wkernel
        else:
            # set kernel size longer than low-pass filter time scale
            wkernel = (1 << int(np.ceil(np.log(fsample / fmax * 10) / np.log(2)))) - 1
        return wkernel

    @function_timer
    def _demodulate_telescope(self, obs, all_dets):
        focalplane = obs.telescope.focalplane
        det_data = focalplane.detector_data
        field_names = det_data.colnames
        # Initialize fields to empty lists
        fields = {name: list() for name in field_names}
        all_set = set(all_dets)
        for row, det in enumerate(det_data["name"]):
            if det not in all_set:
                continue
            for field_name in field_names:
                # Each detector translates into 3 or 5 new entries
                for prefix in self.prefixes:
                    if field_name == "name":
                        fields[field_name].append(f"{prefix}_{det}")
                    else:
                        fields[field_name].append(det_data[field_name][row])
        demod_det_data = QTable(
            [fields[field_name] for field_name in field_names], names=field_names
        )
        my_all = list()
        for name in demod_det_data["name"]:
            my_all.append(name)

        demod_focalplane = Focalplane(
            detector_data=demod_det_data,
            field_of_view=focalplane.field_of_view,
            sample_rate=focalplane.sample_rate / self.nskip,
        )
        demod_name = f"demod_{obs.telescope.name}"
        demod_telescope = Telescope(
            name=demod_name,
            uid=name_UID(demod_name),
            focalplane=demod_focalplane,
            site=obs.telescope.site,
        )
        return demod_telescope

    @function_timer
    def _demodulate_times(self, obs):
        """Downsample timestamps"""
        times = obs.shared[self.times].data.copy()
        if self.nskip != 1:
            offset = obs.local_index_offset
            times = np.array(times[offset % self.nskip :: self.nskip])
        return times

    @function_timer
    def _demodulate_shared_data(self, obs, demod_obs):
        """Downsample shared data"""
        n_local = demod_obs.n_local_samples
        for key in self.azimuth, self.elevation:
            if key is None:
                continue
            values = obs.shared[key].data.copy()
            if self.nskip != 1:
                offset = obs.local_index_offset
                values = np.array(values[offset % self.nskip :: self.nskip])
            demod_obs.shared.create_column(key, (n_local,))
            demod_obs.shared[key].set(
                values,
                offset=(0,),
                fromrank=0,
            )
        return

    @function_timer
    def _demodulate_detsets(self, obs, all_dets):
        """In order to force local detectors to remain on their original
        process, we create a detector set for each row of the process
        grid.
        """
        log = Logger.get()
        if obs.comm_col_size == 1:
            # One process row
            detsets = [all_dets]
        else:
            local_proc_dets = obs.comm_col.gather(obs.local_detectors, root=0)
            detsets = None
            if obs.comm_col_rank == 0:
                all_set = set(all_dets)
                detsets = list()
                for iprow, pdets in enumerate(local_proc_dets):
                    plocal = list()
                    for d in pdets:
                        if d in all_set:
                            plocal.append(d)
                    if len(plocal) == 0:
                        msg = f"obs {obs.name}, process row {iprow} has no"
                        msg += " unflagged detectors.  This may cause an error."
                        log.warning(msg)
                    detsets.append(plocal)
            detsets = obs.comm_col.bcast(detsets, root=0)

        demod_detsets = list()
        for dset in detsets:
            demod_detset = list()
            for det in dset:
                for prefix in self.prefixes:
                    demod_detset.append(f"{prefix}_{det}")
            demod_detsets.append(demod_detset)
        return demod_detsets

    @function_timer
    def _demodulate_sample_sets(self, obs):
        sample_sets = obs.all_sample_sets
        if sample_sets is None:
            return None
        demod_sample_sets = []
        offset = 0
        for sample_set in sample_sets:
            demod_sample_set = []
            for chunksize in sample_set:
                first_sample = offset
                last_sample = offset + chunksize
                demod_first_sample = int(np.ceil(first_sample / self.nskip))
                demod_last_sample = int(np.ceil(last_sample / self.nskip))
                demod_chunksize = demod_last_sample - demod_first_sample
                demod_sample_set.append(demod_chunksize)
                offset += chunksize
            demod_sample_sets.append(demod_sample_set)
        return demod_sample_sets

    @function_timer
    def _demodulate_intervals(self, obs, demod_obs):
        if self.nskip == 1:
            demod_obs.intervals = obs.intervals
            return
        times = demod_obs.shared[self.times]
        for name, ivals in obs.intervals.items():
            timespans = [[ival.start, ival.stop] for ival in ivals]
            demod_obs.intervals[name] = IntervalList(times, timespans=timespans)
        # Force the creation of a new "all" interval
        del demod_obs.intervals[None]
        return

    @function_timer
    def _demodulate_flag(self, flags, wkernel, offset):
        """Collapse flags inside the filter window and downsample"""
        """
        # FIXME: this is horribly inefficient but optimization may require
        # FIXME: a compiled kernel
        n = flags.size
        new_flags = []
        width = wkernel // 2 + 1
        for i in range(0, n, self.nskip):
            ind = slice(max(0, i - width), min(n, i + width + 1))
            buf = flags[ind]
            flag = buf[0]
            for flag2 in buf[1:]:
                flag |= flag2
            new_flags.append(flag)
        new_flags = np.array(new_flags)
        """
        # FIXME: for now, just downsample the flags.  Real data will require
        # FIXME:    measuring the total flag within the filter window
        flags = flags.copy()
        # flag invalid samples in both ends
        flags[: wkernel // 2] |= self.demod_flag_mask
        flags[-(wkernel // 2) :] |= self.demod_flag_mask
        new_flags = np.array(flags[offset % self.nskip :: self.nskip])
        return new_flags

    @function_timer
    def _demodulate_signal(self, data, obs, demod_obs, dets, lowpass):
        """demodulate signal TOD"""

        for det in dets:
            # Get weights
            obs_data = data.select(obs_uid=obs.uid)
            self.stokes_weights.apply(obs_data, dets=[det])
            weights = obs.detdata[self.stokes_weights.weights][det]
            # iweights = 1
            # qweights = eta * cos(2 * psi_det + 4 * psi_hwp)
            # uweights = eta * sin(2 * psi_det + 4 * psi_hwp)
            iweights, qweights, uweights = weights.T
            etainv = 1 / np.sqrt(qweights**2 + uweights**2)

            for flavor in self.det_data.split(";"):
                signal = obs.detdata[flavor][det]
                det_data = demod_obs.detdata[flavor]
                det_data[f"demod0_{det}"] = lowpass(signal)
                det_data[f"demod4r_{det}"] = lowpass(signal * 2 * qweights * etainv)
                det_data[f"demod4i_{det}"] = lowpass(signal * 2 * uweights * etainv)

                if self.do_2f:
                    # Start by evaluating the 2f demodulation factors from the
                    # pointing matrix.  We use the half-angle formulas and some
                    # extra logic to identify the right branch
                    #
                    # |cos(psi/2)| and |sin(psi/2)|:
                    signal_demod2r = np.sqrt(0.5 * (1 + qweights * etainv))
                    signal_demod2i = np.sqrt(0.5 * (1 - qweights * etainv))
                    # invert the sign of every second mode
                    for sig in signal_demod2r, signal_demod2i:
                        dsig = np.diff(sig)
                        dsig[sig[1:] > 0.5] = 0
                        starts = np.where(dsig[:-1] * dsig[1:] < 0)[0]
                        for start, stop in zip(starts[::2], starts[1::2]):
                            sig[start + 1 : stop + 2] *= -1
                        # handle some corner cases
                        dsig = np.diff(sig)
                        dstep = np.median(np.abs(dsig[sig[1:] < 0.5]))
                        bad = np.abs(dsig) > 2 * dstep
                        bad = np.hstack([bad, False])
                        sig[bad] *= -1
                    # Demodulate and lowpass for 2f
                    det_data[f"demod2r_{det}"] = lowpass(signal * signal_demod2r)
                    det_data[f"demod2i_{det}"] = lowpass(signal * signal_demod2i)

        return

    @function_timer
    def _demodulate_flags(self, obs, demod_obs, dets, wkernel, offset):
        """Demodulate and downsample flags"""

        shared_flags = obs.shared[self.shared_flags].data
        demod_shared_flags = self._demodulate_flag(shared_flags, wkernel, offset)
        demod_obs.shared[self.shared_flags].set(
            demod_shared_flags, offset=(0,), fromrank=0
        )

        input_det_flags = obs.local_detector_flags
        output_det_flags = dict()

        for det in dets:
            flags = obs.detdata[self.det_flags][det]
            # Downsample flags
            demod_flags = self._demodulate_flag(flags, wkernel, offset)
            for prefix in self.prefixes:
                demod_det = f"{prefix}_{det}"
                demod_obs.detdata[self.det_flags][demod_det] = demod_flags
                output_det_flags[demod_det] = input_det_flags[det]
        demod_obs.update_local_detector_flags(output_det_flags)
        return

    @function_timer
    def _demodulate_pointing(self, data, obs, demod_obs, dets, lowpass, offset):
        """demodulate pointing matrix"""

        # Pointing matrix is now computed on the fly.  We only need to
        # demodulate the boresight quaternions

        quats = obs.shared[self.boresight].data
        demod_obs.shared[self.boresight].set(
            np.array(quats[offset % self.nskip :: self.nskip]),
            offset=(0, 0),
            fromrank=0,
        )

        return

    @function_timer
    def _demodulate_noise(
        self,
        obs,
        demod_obs,
        dets,
        fsample,
        hwp_rate,
        lowpass,
    ):
        """Add Noise objects for the new detectors"""
        noise = obs[self.noise_model]

        demod_detectors = []
        demod_freqs = {}
        demod_psds = {}
        demod_indices = {}
        demod_weights = {}

        lpf = lowpass.lpf
        lpf_freq = np.fft.rfftfreq(lpf.size, 1 / fsample.to_value(u.Hz))
        lpf_value = np.abs(np.fft.rfft(lpf)) ** 2
        for det in dets:
            # weight -- ignored
            # index  - ignored
            # rate
            rate_in = noise.rate(det)
            # freq
            freq_in = noise.freq(det)
            # Lowpass transfer function
            tf = np.interp(freq_in.to_value(u.Hz), lpf_freq, lpf_value)
            # Find the highest frequency without significant suppression
            # to measure noise weights at
            iweight = tf.size - 1
            while iweight > 0 and tf[iweight] < 0.99:
                iweight -= 1
            # psd
            psd_in = noise.psd(det)
            n_mode = len(self.prefixes)
            for indexoff, prefix in enumerate(self.prefixes):
                demod_det = f"{prefix}_{det}"
                # Get the demodulated PSD
                if prefix == "demod0":
                    # this PSD does not change
                    psd_out = psd_in.copy()
                elif prefix.startswith("demod2"):
                    # get noise at 2f
                    psd_out = np.zeros_like(psd_in)
                    psd_out[:] = np.interp(2 * hwp_rate, freq_in, psd_in)
                else:
                    # get noise at 4f
                    psd_out = np.zeros_like(psd_in)
                    psd_out[:] = np.interp(4 * hwp_rate, freq_in, psd_in)
                # Lowpass
                psd_out *= tf
                # Downsample
                rate_out = rate_in / self.nskip
                ind = freq_in <= rate_out / 2
                freq_out = freq_in[ind]
                # Last bin must equal the new Nyquist frequency
                freq_out[-1] = rate_out / 2
                psd_out = psd_out[ind] / self.nskip
                # Calculate noise weight
                noisevar = psd_out[iweight].to_value(u.K**2 * u.second)
                invvar = 1.0 / noisevar / rate_out.to_value(u.Hz)
                # Insert
                demod_detectors.append(demod_det)
                demod_freqs[demod_det] = freq_out
                demod_psds[demod_det] = psd_out
                demod_indices[demod_det] = noise.index(det) * n_mode + indexoff
                demod_weights[demod_det] = invvar / u.K**2
        demod_obs[self.noise_model] = Noise(
            detectors=demod_detectors,
            freqs=demod_freqs,
            psds=demod_psds,
            indices=demod_indices,
            detweights=demod_weights,
        )
        return

    def _finalize(self, data, **kwargs):
        return self.demod_data

    def _requires(self):
        req = {
            "shared": [self.times, self.boresight],
            "detdata": [self.det_data],
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.boresight_azel is not None:
            req["shared"].append(self.boresight_azel)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        return dict()

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

azimuth = Unicode(defaults.azimuth, help='Observation shared key for Azimuth') class-attribute instance-attribute

boresight = Unicode(defaults.boresight_radec, help='Observation shared key for boresight') class-attribute instance-attribute

demod_data = Data() instance-attribute

demod_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for demod & downsample flagging') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help="Observation detdata key to apply filtering to. Use ';' if multiple signal flavors should be demodulated.") class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

do_2f = Bool(False, help='also cache the 2f-demodulated signal') class-attribute instance-attribute

elevation = Unicode(defaults.elevation, help='Observation shared key for Elevation') class-attribute instance-attribute

fmax = Quantity(None, allow_none=True, help='Override automatic lowpass cut-off frequency') class-attribute instance-attribute

hwp_angle = Unicode(defaults.hwp_angle, help='Observation shared key for HWP angle') class-attribute instance-attribute

noise_model = Unicode('noise_model', allow_none=True, help='Observation key containing the noise model') class-attribute instance-attribute

nskip = Int(3, help='Downsampling factor') class-attribute instance-attribute

purge = Bool(False, help='Remove inputs after demodulation') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

stokes_weights = Instance(klass=Operator, allow_none=True, help='This must be an instance of a Stokes weights operator') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

window = Unicode('hamming', help='Window function name recognized by scipy.signal.firwin') class-attribute instance-attribute

wkernel = Int(None, allow_none=True, help='Override automatic filter kernel size') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/demodulation.py, lines 179–182
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.demod_data = Data()
    return

_check_det_flag_mask(proposal)

Source code in toast/ops/demodulation.py, lines 150–155
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/demodulation.py, lines 143–148
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_mask(proposal)

Source code in toast/ops/demodulation.py, lines 157–162
@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_check_stokes_weights(proposal)

Source code in toast/ops/demodulation.py, lines 164–177
@traitlets.validate("stokes_weights")
def _check_stokes_weights(self, proposal):
    weights = proposal["value"]
    if weights is not None:
        if not isinstance(weights, Operator):
            raise traitlets.TraitError(
                "stokes_weights should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["weights", "view"]:
            if not weights.has_trait(trt):
                msg = f"stokes_weights operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return weights

_demodulate_detsets(obs, all_dets)

In order to force local detectors to remain on their original process, we create a detector set for each row of the process grid.

Source code in toast/ops/demodulation.py, lines 412–447
@function_timer
def _demodulate_detsets(self, obs, all_dets):
    """In order to force local detectors to remain on their original
    process, we create a detector set for each row of the process
    grid.
    """
    log = Logger.get()
    if obs.comm_col_size == 1:
        # One process row
        detsets = [all_dets]
    else:
        local_proc_dets = obs.comm_col.gather(obs.local_detectors, root=0)
        detsets = None
        if obs.comm_col_rank == 0:
            all_set = set(all_dets)
            detsets = list()
            for iprow, pdets in enumerate(local_proc_dets):
                plocal = list()
                for d in pdets:
                    if d in all_set:
                        plocal.append(d)
                if len(plocal) == 0:
                    msg = f"obs {obs.name}, process row {iprow} has no"
                    msg += " unflagged detectors.  This may cause an error."
                    log.warning(msg)
                detsets.append(plocal)
        detsets = obs.comm_col.bcast(detsets, root=0)

    demod_detsets = list()
    for dset in detsets:
        demod_detset = list()
        for det in dset:
            for prefix in self.prefixes:
                demod_detset.append(f"{prefix}_{det}")
        demod_detsets.append(demod_detset)
    return demod_detsets

_demodulate_flag(flags, wkernel, offset)

Collapse flags inside the filter window and downsample

Source code in toast/ops/demodulation.py, lines 482–507
@function_timer
def _demodulate_flag(self, flags, wkernel, offset):
    """Collapse flags inside the filter window and downsample"""
    """
    # FIXME: this is horribly inefficient but optimization may require
    # FIXME: a compiled kernel
    n = flags.size
    new_flags = []
    width = wkernel // 2 + 1
    for i in range(0, n, self.nskip):
        ind = slice(max(0, i - width), min(n, i + width + 1))
        buf = flags[ind]
        flag = buf[0]
        for flag2 in buf[1:]:
            flag |= flag2
        new_flags.append(flag)
    new_flags = np.array(new_flags)
    """
    # FIXME: for now, just downsample the flags.  Real data will require
    # FIXME:    measuring the total flag within the filter window
    flags = flags.copy()
    # flag invalid samples in both ends
    flags[: wkernel // 2] |= self.demod_flag_mask
    flags[-(wkernel // 2) :] |= self.demod_flag_mask
    new_flags = np.array(flags[offset % self.nskip :: self.nskip])
    return new_flags
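
As the FIXME notes, the current implementation does not yet OR flags across the whole filter window; it only marks the half-kernel of invalid samples at each end of the buffer and then decimates. A tiny standalone illustration:

import numpy as np

flags = np.zeros(20, dtype=np.uint8)
wkernel, nskip, offset, demod_flag_mask = 7, 3, 0, 1

# Flag the half-kernel of samples at both ends, then downsample
flags[: wkernel // 2] |= demod_flag_mask
flags[-(wkernel // 2) :] |= demod_flag_mask
print(flags[offset % nskip :: nskip])  # [1 1 0 0 0 0 1]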

_demodulate_flags(obs, demod_obs, dets, wkernel, offset)

Demodulate and downsample flags

Source code in toast/ops/demodulation.py, lines 558–580
@function_timer
def _demodulate_flags(self, obs, demod_obs, dets, wkernel, offset):
    """Demodulate and downsample flags"""

    shared_flags = obs.shared[self.shared_flags].data
    demod_shared_flags = self._demodulate_flag(shared_flags, wkernel, offset)
    demod_obs.shared[self.shared_flags].set(
        demod_shared_flags, offset=(0,), fromrank=0
    )

    input_det_flags = obs.local_detector_flags
    output_det_flags = dict()

    for det in dets:
        flags = obs.detdata[self.det_flags][det]
        # Downsample flags
        demod_flags = self._demodulate_flag(flags, wkernel, offset)
        for prefix in self.prefixes:
            demod_det = f"{prefix}_{det}"
            demod_obs.detdata[self.det_flags][demod_det] = demod_flags
            output_det_flags[demod_det] = input_det_flags[det]
    demod_obs.update_local_detector_flags(output_det_flags)
    return

_demodulate_intervals(obs, demod_obs)

Source code in toast/ops/demodulation.py, lines 469–480
@function_timer
def _demodulate_intervals(self, obs, demod_obs):
    if self.nskip == 1:
        demod_obs.intervals = obs.intervals
        return
    times = demod_obs.shared[self.times]
    for name, ivals in obs.intervals.items():
        timespans = [[ival.start, ival.stop] for ival in ivals]
        demod_obs.intervals[name] = IntervalList(times, timespans=timespans)
    # Force the creation of a new "all" interval
    del demod_obs.intervals[None]
    return

_demodulate_noise(obs, demod_obs, dets, fsample, hwp_rate, lowpass)

Add Noise objects for the new detectors

Source code in toast/ops/demodulation.py, lines 598–676
@function_timer
def _demodulate_noise(
    self,
    obs,
    demod_obs,
    dets,
    fsample,
    hwp_rate,
    lowpass,
):
    """Add Noise objects for the new detectors"""
    noise = obs[self.noise_model]

    demod_detectors = []
    demod_freqs = {}
    demod_psds = {}
    demod_indices = {}
    demod_weights = {}

    lpf = lowpass.lpf
    lpf_freq = np.fft.rfftfreq(lpf.size, 1 / fsample.to_value(u.Hz))
    lpf_value = np.abs(np.fft.rfft(lpf)) ** 2
    for det in dets:
        # weight -- ignored
        # index  - ignored
        # rate
        rate_in = noise.rate(det)
        # freq
        freq_in = noise.freq(det)
        # Lowpass transfer function
        tf = np.interp(freq_in.to_value(u.Hz), lpf_freq, lpf_value)
        # Find the highest frequency without significant suppression
        # to measure noise weights at
        iweight = tf.size - 1
        while iweight > 0 and tf[iweight] < 0.99:
            iweight -= 1
        # psd
        psd_in = noise.psd(det)
        n_mode = len(self.prefixes)
        for indexoff, prefix in enumerate(self.prefixes):
            demod_det = f"{prefix}_{det}"
            # Get the demodulated PSD
            if prefix == "demod0":
                # this PSD does not change
                psd_out = psd_in.copy()
            elif prefix.startswith("demod2"):
                # get noise at 2f
                psd_out = np.zeros_like(psd_in)
                psd_out[:] = np.interp(2 * hwp_rate, freq_in, psd_in)
            else:
                # get noise at 4f
                psd_out = np.zeros_like(psd_in)
                psd_out[:] = np.interp(4 * hwp_rate, freq_in, psd_in)
            # Lowpass
            psd_out *= tf
            # Downsample
            rate_out = rate_in / self.nskip
            ind = freq_in <= rate_out / 2
            freq_out = freq_in[ind]
            # Last bin must equal the new Nyquist frequency
            freq_out[-1] = rate_out / 2
            psd_out = psd_out[ind] / self.nskip
            # Calculate noise weight
            noisevar = psd_out[iweight].to_value(u.K**2 * u.second)
            invvar = 1.0 / noisevar / rate_out.to_value(u.Hz)
            # Insert
            demod_detectors.append(demod_det)
            demod_freqs[demod_det] = freq_out
            demod_psds[demod_det] = psd_out
            demod_indices[demod_det] = noise.index(det) * n_mode + indexoff
            demod_weights[demod_det] = invvar / u.K**2
    demod_obs[self.noise_model] = Noise(
        detectors=demod_detectors,
        freqs=demod_freqs,
        psds=demod_psds,
        indices=demod_indices,
        detweights=demod_weights,
    )
    return
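
The lowpass suppression applied to each PSD is the squared magnitude of the kernel's Fourier transform, interpolated onto the PSD frequency grid. A standalone sketch of that step, with a generic scipy.signal.firwin window standing in for the Lowpass object's kernel (all numbers assumed):

import numpy as np
from scipy.signal import firwin

fsample = 100.0                      # Hz, assumed sample rate
wkernel, fmax = 511, 2.0             # kernel length and cut-off
lpf = firwin(wkernel, fmax, fs=fsample, window="hamming")

# Power transfer function of the FIR kernel
lpf_freq = np.fft.rfftfreq(lpf.size, 1 / fsample)
lpf_value = np.abs(np.fft.rfft(lpf)) ** 2

# Interpolate onto a PSD frequency grid and apply the suppression
psd_freq = np.linspace(0.01, fsample / 2, 200)
tf = np.interp(psd_freq, lpf_freq, lpf_value)
print(tf[0], tf[-1])                 # ~1 in the passband, ~0 beyond fmax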

_demodulate_pointing(data, obs, demod_obs, dets, lowpass, offset)

demodulate pointing matrix

Source code in toast/ops/demodulation.py, lines 582–596
@function_timer
def _demodulate_pointing(self, data, obs, demod_obs, dets, lowpass, offset):
    """demodulate pointing matrix"""

    # Pointing matrix is now computed on the fly.  We only need to
    # demodulate the boresight quaternions

    quats = obs.shared[self.boresight].data
    demod_obs.shared[self.boresight].set(
        np.array(quats[offset % self.nskip :: self.nskip]),
        offset=(0, 0),
        fromrank=0,
    )

    return

_demodulate_sample_sets(obs)

Source code in toast/ops/demodulation.py, lines 449–467
@function_timer
def _demodulate_sample_sets(self, obs):
    sample_sets = obs.all_sample_sets
    if sample_sets is None:
        return None
    demod_sample_sets = []
    offset = 0
    for sample_set in sample_sets:
        demod_sample_set = []
        for chunksize in sample_set:
            first_sample = offset
            last_sample = offset + chunksize
            demod_first_sample = int(np.ceil(first_sample / self.nskip))
            demod_last_sample = int(np.ceil(last_sample / self.nskip))
            demod_chunksize = demod_last_sample - demod_first_sample
            demod_sample_set.append(demod_chunksize)
            offset += chunksize
        demod_sample_sets.append(demod_sample_set)
    return demod_sample_sets
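
The ceil-division keeps neighboring chunks contiguous after downsampling, because each chunk boundary maps to the same downsampled index when approached from either side. A worked example with nskip = 3 and made-up chunk sizes:

import numpy as np

nskip = 3
sample_set = [10, 10, 7]  # assumed chunk sizes, 27 samples in total

offset = 0
demod_sample_set = []
for chunksize in sample_set:
    first = int(np.ceil(offset / nskip))
    last = int(np.ceil((offset + chunksize) / nskip))
    demod_sample_set.append(last - first)
    offset += chunksize
print(demod_sample_set)  # [4, 3, 2] -> 9 samples = ceil(27 / 3)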

_demodulate_shared_data(obs, demod_obs)

Downsample shared data

Source code in toast/ops/demodulation.py, lines 393–410
@function_timer
def _demodulate_shared_data(self, obs, demod_obs):
    """Downsample shared data"""
    n_local = demod_obs.n_local_samples
    for key in self.azimuth, self.elevation:
        if key is None:
            continue
        values = obs.shared[key].data.copy()
        if self.nskip != 1:
            offset = obs.local_index_offset
            values = np.array(values[offset % self.nskip :: self.nskip])
        demod_obs.shared.create_column(key, (n_local,))
        demod_obs.shared[key].set(
            values,
            offset=(0,),
            fromrank=0,
        )
    return

_demodulate_signal(data, obs, demod_obs, dets, lowpass)

demodulate signal TOD

Source code in toast/ops/demodulation.py, lines 509–556
@function_timer
def _demodulate_signal(self, data, obs, demod_obs, dets, lowpass):
    """demodulate signal TOD"""

    for det in dets:
        # Get weights
        obs_data = data.select(obs_uid=obs.uid)
        self.stokes_weights.apply(obs_data, dets=[det])
        weights = obs.detdata[self.stokes_weights.weights][det]
        # iweights = 1
        # qweights = eta * cos(2 * psi_det + 4 * psi_hwp)
        # uweights = eta * sin(2 * psi_det + 4 * psi_hwp)
        iweights, qweights, uweights = weights.T
        etainv = 1 / np.sqrt(qweights**2 + uweights**2)

        for flavor in self.det_data.split(";"):
            signal = obs.detdata[flavor][det]
            det_data = demod_obs.detdata[flavor]
            det_data[f"demod0_{det}"] = lowpass(signal)
            det_data[f"demod4r_{det}"] = lowpass(signal * 2 * qweights * etainv)
            det_data[f"demod4i_{det}"] = lowpass(signal * 2 * uweights * etainv)

            if self.do_2f:
                # Start by evaluating the 2f demodulation factors from the
                # pointing matrix.  We use the half-angle formulas and some
                # extra logic to identify the right branch
                #
                # |cos(psi/2)| and |sin(psi/2)|:
                signal_demod2r = np.sqrt(0.5 * (1 + qweights * etainv))
                signal_demod2i = np.sqrt(0.5 * (1 - qweights * etainv))
                # invert the sign of every second mode
                for sig in signal_demod2r, signal_demod2i:
                    dsig = np.diff(sig)
                    dsig[sig[1:] > 0.5] = 0
                    starts = np.where(dsig[:-1] * dsig[1:] < 0)[0]
                    for start, stop in zip(starts[::2], starts[1::2]):
                        sig[start + 1 : stop + 2] *= -1
                    # handle some corner cases
                    dsig = np.diff(sig)
                    dstep = np.median(np.abs(dsig[sig[1:] < 0.5]))
                    bad = np.abs(dsig) > 2 * dstep
                    bad = np.hstack([bad, False])
                    sig[bad] *= -1
                # Demodulate and lowpass for 2f
                det_data[f"demod2r_{det}"] = lowpass(signal * signal_demod2r)
                det_data[f"demod2i_{det}"] = lowpass(signal * signal_demod2i)

    return
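
The factor of 2 in the 4f quadratures compensates for the power split between sine and cosine: multiplying d = I + Q cos(4ψ) + U sin(4ψ) by 2 cos(4ψ) yields Q plus terms oscillating at 4f and 8f, which the lowpass removes. A self-contained numerical check, with a crude moving average standing in for the Lowpass object:

import numpy as np

n = 10000
t = np.linspace(0.0, 10.0, n)
psi = 2 * np.pi * 2.0 * t            # HWP angle for an assumed 2 Hz spin
I, Q, U = 1.0, 0.3, -0.2
d = I + Q * np.cos(4 * psi) + U * np.sin(4 * psi)

def lowpass(x, width=2000):
    # Crude moving-average lowpass, wide enough to kill the 4f/8f terms
    return np.convolve(x, np.ones(width) / width, mode="same")

demod0 = lowpass(d)
demod4r = lowpass(d * 2 * np.cos(4 * psi))
demod4i = lowpass(d * 2 * np.sin(4 * psi))
print(demod0[n // 2], demod4r[n // 2], demod4i[n // 2])  # ~ 1.0, 0.3, -0.2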

_demodulate_telescope(obs, all_dets)

Source code in toast/ops/demodulation.py, lines 345–382
@function_timer
def _demodulate_telescope(self, obs, all_dets):
    focalplane = obs.telescope.focalplane
    det_data = focalplane.detector_data
    field_names = det_data.colnames
    # Initialize fields to empty lists
    fields = {name: list() for name in field_names}
    all_set = set(all_dets)
    for row, det in enumerate(det_data["name"]):
        if det not in all_set:
            continue
        for field_name in field_names:
            # Each detector translates into 3 or 5 new entries
            for prefix in self.prefixes:
                if field_name == "name":
                    fields[field_name].append(f"{prefix}_{det}")
                else:
                    fields[field_name].append(det_data[field_name][row])
    demod_det_data = QTable(
        [fields[field_name] for field_name in field_names], names=field_names
    )
    my_all = list()
    for name in demod_det_data["name"]:
        my_all.append(name)

    demod_focalplane = Focalplane(
        detector_data=demod_det_data,
        field_of_view=focalplane.field_of_view,
        sample_rate=focalplane.sample_rate / self.nskip,
    )
    demod_name = f"demod_{obs.telescope.name}"
    demod_telescope = Telescope(
        name=demod_name,
        uid=name_UID(demod_name),
        focalplane=demod_focalplane,
        site=obs.telescope.site,
    )
    return demod_telescope

_demodulate_times(obs)

Downsample timestamps

Source code in toast/ops/demodulation.py, lines 384–391
391
@function_timer
def _demodulate_times(self, obs):
    """Downsample timestamps"""
    times = obs.shared[self.times].data.copy()
    if self.nskip != 1:
        offset = obs.local_index_offset
        times = np.array(times[offset % self.nskip :: self.nskip])
    return times

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/demodulation.py, lines 184–320
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    for trait in "noise_model", "stokes_weights":
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    # Demodulation only applies to observations with HWP.  Verify
    # that there are such observations in `data`

    demodulate_obs = []
    for obs in data.obs:
        if self.hwp_angle not in obs.shared:
            continue
        hwp_angle = obs.shared[self.hwp_angle]
        if np.abs(np.median(np.diff(hwp_angle))) < 1e-6:
            # Stepped or stationary HWP
            continue
        demodulate_obs.append(obs)
    n_obs = len(demodulate_obs)
    if data.comm.comm_world is not None:
        n_obs = data.comm.comm_world.allreduce(n_obs)
    if n_obs == 0:
        raise RuntimeError(
            "None of the observations have a spinning HWP.  Nothing to demodulate."
        )

    # Each modulated detector demodulates into 3 or 5 pseudo detectors

    self.prefixes = ["demod0", "demod4r", "demod4i"]
    if self.do_2f:
        self.prefixes.extend(["demod2r", "demod2i"])

    timer = Timer()
    timer.start()
    for obs in demodulate_obs:
        # Get the detectors which are not cut with per-detector flags
        local_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
        if obs.comm.comm_group is None:
            all_dets = local_dets
        else:
            proc_dets = obs.comm.comm_group.gather(local_dets, root=0)
            all_dets = None
            if obs.comm.comm_group.rank == 0:
                all_dets = set()
                for pdets in proc_dets:
                    for d in pdets:
                        all_dets.add(d)
                all_dets = list(sorted(all_dets))
            all_dets = obs.comm.comm_group.bcast(all_dets, root=0)

        offset = obs.local_index_offset
        nsample = obs.n_local_samples

        fsample = obs.telescope.focalplane.sample_rate
        fmax, hwp_rate = self._get_fmax(obs)

        wkernel = self._get_wkernel(fmax, fsample)
        lowpass = Lowpass(wkernel, fmax, fsample, offset, self.nskip, self.window)

        # Create a new observation to hold the demodulated and downsampled data

        demod_telescope = self._demodulate_telescope(obs, all_dets)
        demod_times = self._demodulate_times(obs)
        demod_detsets = self._demodulate_detsets(obs, all_dets)
        demod_sample_sets = self._demodulate_sample_sets(obs)
        demod_process_rows = obs.dist.process_rows

        demod_name = f"demod_{obs.name}"
        demod_obs = Observation(
            obs.comm,
            demod_telescope,
            demod_times.size,
            name=demod_name,
            uid=name_UID(demod_name),
            session=obs.session,
            detector_sets=demod_detsets,
            process_rows=demod_process_rows,
            sample_sets=demod_sample_sets,
        )
        for key, value in obs.items():
            if key == self.noise_model:
                # Will be generated later
                continue
            demod_obs[key] = value

        # Allocate storage

        demod_dets = []
        for det in local_dets:
            for prefix in self.prefixes:
                demod_dets.append(f"{prefix}_{det}")
        n_local = demod_obs.n_local_samples

        demod_obs.shared.create_column(self.times, (n_local,))
        demod_obs.shared[self.times].set(demod_times, offset=(0,), fromrank=0)
        demod_obs.shared.create_column(self.boresight, (n_local, 4))
        demod_obs.shared.create_column(
            self.shared_flags, (n_local,), dtype=np.uint8
        )

        self._demodulate_shared_data(obs, demod_obs)

        for det_data in self.det_data.split(";"):
            exists_data = demod_obs.detdata.ensure(
                det_data,
                detectors=demod_dets,
                dtype=np.float64,
                create_units=obs.detdata[det_data].units,
            )
        exists_flags = demod_obs.detdata.ensure(
            self.det_flags, detectors=demod_dets, dtype=np.uint8
        )

        self._demodulate_flags(obs, demod_obs, local_dets, wkernel, offset)
        self._demodulate_signal(data, obs, demod_obs, local_dets, lowpass)
        self._demodulate_pointing(data, obs, demod_obs, local_dets, lowpass, offset)
        self._demodulate_noise(
            obs, demod_obs, local_dets, fsample, hwp_rate, lowpass
        )

        self._demodulate_intervals(obs, demod_obs)

        self.demod_data.obs.append(demod_obs)

        if self.purge:
            obs.clear()

        log.debug_rank(
            "Demodulated observation in", comm=data.comm.comm_group, timer=timer
        )
    if self.purge:
        data.clear()

    return

_finalize(data, **kwargs)

Source code in toast/ops/demodulation.py, lines 678–679
def _finalize(self, data, **kwargs):
    return self.demod_data

_get_fmax(obs)

Source code in toast/ops/demodulation.py, lines 322–334
@function_timer
def _get_fmax(self, obs):
    times = obs.shared[self.times].data
    hwp_angle = np.unwrap(obs.shared[self.hwp_angle].data)
    hwp_rate = np.absolute(
        np.mean(np.diff(hwp_angle) / np.diff(times)) / (2 * np.pi) * u.Hz
    )
    if self.fmax is not None:
        fmax = self.fmax
    else:
        # set the low-pass cut-off frequency equal to the HWP 1f frequency
        fmax = hwp_rate
    return fmax, hwp_rate
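
The rate estimate is just the mean unwrapped angular velocity over 2π. A quick standalone check with an assumed 2 Hz spin:

import numpy as np

fsample, spin_hz = 100.0, 2.0
times = np.arange(0.0, 10.0, 1.0 / fsample)
angle = (2 * np.pi * spin_hz * times) % (2 * np.pi)  # wrapped, as stored

unwrapped = np.unwrap(angle)
rate = np.abs(np.mean(np.diff(unwrapped) / np.diff(times)) / (2 * np.pi))
print(rate)  # ~2.0 (Hz)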

_get_wkernel(fmax, fsample)

Source code in toast/ops/demodulation.py, lines 336–343
@function_timer
def _get_wkernel(self, fmax, fsample):
    if self.wkernel is not None:
        wkernel = self.wkernel
    else:
        # set kernel size longer than low-pass filter time scale
        wkernel = (1 << int(np.ceil(np.log(fsample / fmax * 10) / np.log(2)))) - 1
    return wkernel
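
The automatic kernel is the next power of two at or above ten times fsample / fmax, minus one, so the FIR filter gets an odd length with a symmetric center tap. Worked numbers for an assumed fsample = 100 Hz and fmax = 2 Hz:

import numpy as np

fsample, fmax = 100.0, 2.0
# 10 * fsample / fmax = 500 -> next power of two is 512 -> wkernel = 511
wkernel = (1 << int(np.ceil(np.log(fsample / fmax * 10) / np.log(2)))) - 1
print(wkernel)  # 511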

_provides()

Source code in toast/ops/demodulation.py, lines 694–695
def _provides(self):
    return dict()

_requires()

Source code in toast/ops/demodulation.py, lines 681–692
def _requires(self):
    req = {
        "shared": [self.times, self.boresight],
        "detdata": [self.det_data],
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.boresight_azel is not None:
        req["shared"].append(self.boresight_azel)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

toast.ops.StokesWeightsDemod

Bases: Operator

Compute the Stokes pointing weights for demodulated data

Source code in toast/ops/demodulation.py, lines 698-783
@trait_docs
class StokesWeightsDemod(Operator):
    """Compute the Stokes pointing weights for demodulated data"""

    API = Int(0, help="Internal interface version for this operator")

    mode = Unicode("IQU", help="The Stokes weights to generate")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    weights = Unicode(
        defaults.weights, help="Observation detdata key for output weights"
    )

    single_precision = Bool(False, help="If True, use 32bit float in output")

    @traitlets.validate("mode")
    def _check_mode(self, proposal):
        mode = proposal["value"]
        if mode not in ["IQU"]:
            raise traitlets.TraitError("Invalid mode (must be 'IQU')")
        return mode

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        nnz = len(self.mode)

        if self.single_precision:
            dtype = np.float32
        else:
            dtype = np.float64

        for obs in data.obs:
            dets = obs.select_local_detectors(detectors)
            if len(dets) == 0:
                continue

            exists_weights = obs.detdata.ensure(
                self.weights,
                sample_shape=(nnz,),
                dtype=dtype,
                detectors=dets,
            )
            nsample = obs.n_local_samples
            ones = np.ones(nsample, dtype=dtype)
            zeros = np.zeros(nsample, dtype=dtype)
            weights = obs.detdata[self.weights]
            for det in dets:
                props = obs.telescope.focalplane[det]
                if "pol_efficiency" in props.colnames:
                    eta = props["pol_efficiency"]
                else:
                    eta = 1.0
                if det.startswith("demod0"):
                    # Stokes I only
                    weights[det] = np.column_stack([ones, zeros, zeros])
                elif det.startswith("demod4r"):
                    # Stokes Q only
                    weights[det] = np.column_stack([zeros, eta * ones, zeros])
                elif det.startswith("demod4i"):
                    # Stokes U only
                    weights[det] = np.column_stack([zeros, zeros, eta * ones])
                else:
                    # 2f, systematics only
                    weights[det] = np.column_stack([zeros, zeros, zeros])
        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "shared": list(),
            "detdata": list(),
        }
        return req

    def _provides(self):
        return {"detdata": self.weights}

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

mode = Unicode('IQU', help='The Stokes weights to generate') class-attribute instance-attribute

single_precision = Bool(False, help='If True, use 32bit float in output') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

weights = Unicode(defaults.weights, help='Observation detdata key for output weights') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/demodulation.py, lines 723-724
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_mode(proposal)

Source code in toast/ops/demodulation.py, lines 716-721
@traitlets.validate("mode")
def _check_mode(self, proposal):
    mode = proposal["value"]
    if mode not in ["IQU"]:
        raise traitlets.TraitError("Invalid mode (must be 'IQU')")
    return mode

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/demodulation.py, lines 726-770
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    nnz = len(self.mode)

    if self.single_precision:
        dtype = np.float32
    else:
        dtype = np.float64

    for obs in data.obs:
        dets = obs.select_local_detectors(detectors)
        if len(dets) == 0:
            continue

        exists_weights = obs.detdata.ensure(
            self.weights,
            sample_shape=(nnz,),
            dtype=dtype,
            detectors=dets,
        )
        nsample = obs.n_local_samples
        ones = np.ones(nsample, dtype=dtype)
        zeros = np.zeros(nsample, dtype=dtype)
        weights = obs.detdata[self.weights]
        for det in dets:
            props = obs.telescope.focalplane[det]
            if "pol_efficiency" in props.colnames:
                eta = props["pol_efficiency"]
            else:
                eta = 1.0
            if det.startswith("demod0"):
                # Stokes I only
                weights[det] = np.column_stack([ones, zeros, zeros])
            elif det.startswith("demod4r"):
                # Stokes Q only
                weights[det] = np.column_stack([zeros, eta * ones, zeros])
            elif det.startswith("demod4i"):
                # Stokes U only
                weights[det] = np.column_stack([zeros, zeros, eta * ones])
            else:
                # 2f, systematics only
                weights[det] = np.column_stack([zeros, zeros, zeros])
    return

_finalize(data, **kwargs)

Source code in toast/ops/demodulation.py, lines 772-773
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/demodulation.py, lines 782-783
def _provides(self):
    return {"detdata": self.weights}

_requires()

Source code in toast/ops/demodulation.py, lines 775-780
def _requires(self):
    req = {
        "shared": list(),
        "detdata": list(),
    }
    return req
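
To make the weight assignment in _exec concrete, here is the per-sample (I, Q, U) vector each class of demodulated detector receives (eta is the polarization efficiency, defaulting to 1.0 when the focalplane has no "pol_efficiency" column; the 0.95 below is an illustrative value):

import numpy as np

eta = 0.95  # illustrative polarization efficiency

stokes_weights = {
    "demod0": np.array([1.0, 0.0, 0.0]),   # Stokes I only
    "demod4r": np.array([0.0, eta, 0.0]),  # Stokes Q only
    "demod4i": np.array([0.0, 0.0, eta]),  # Stokes U only
    "demod2": np.array([0.0, 0.0, 0.0]),   # 2f, systematics only
}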

Pointing Matrix

toast.ops.PointingDetectorSimple

Bases: Operator

Operator which translates boresight pointing into detector frame

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 19-324
@trait_docs
class PointingDetectorSimple(Operator):
    """Operator which translates boresight pointing into detector frame"""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_invalid, help="Bit mask value for optional flagging"
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    boresight = Unicode(
        defaults.boresight_radec, help="Observation shared key for boresight"
    )

    hwp_angle = Unicode(
        defaults.hwp_angle, allow_none=True, help="Observation shared key for HWP angle"
    )

    hwp_angle_offset = Quantity(
        0 * u.deg, help="HWP angle offset to apply when constructing deflection"
    )

    hwp_deflection_radius = Quantity(
        None,
        allow_none=True,
        help="If non-zero, nominal detector pointing will be deflected in a circular "
        "pattern according to HWP phase.",
    )

    quats = Unicode(
        defaults.quats,
        allow_none=True,
        help="Observation detdata key for output quaternions",
    )

    coord_in = Unicode(
        None,
        allow_none=True,
        help="The input boresight coordinate system ('C', 'E', 'G')",
    )

    coord_out = Unicode(
        None,
        allow_none=True,
        help="The output coordinate system ('C', 'E', 'G')",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("coord_in")
    def _check_coord_in(self, proposal):
        check = proposal["value"]
        if check is not None:
            if check not in ["E", "C", "G"]:
                raise traitlets.TraitError("coordinate system must be 'E', 'C', or 'G'")
        return check

    @traitlets.validate("coord_out")
    def _check_coord_out(self, proposal):
        check = proposal["value"]
        if check is not None:
            if check not in ["E", "C", "G"]:
                raise traitlets.TraitError("coordinate system must be 'E', 'C', or 'G'")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        coord_rot = None
        bore_suffix = ""
        if self.coord_in is None:
            if self.coord_out is not None:
                msg = "Input and output coordinate systems should both be None or valid"
                raise RuntimeError(msg)
        else:
            if self.coord_out is None:
                msg = "Input and output coordinate systems should both be None or valid"
                raise RuntimeError(msg)
            if self.coord_in == "C":
                if self.coord_out == "E":
                    coord_rot = qa.equ2ecl()
                    bore_suffix = "_C2E"
                elif self.coord_out == "G":
                    coord_rot = qa.equ2gal()
                    bore_suffix = "_C2G"
            elif self.coord_in == "E":
                if self.coord_out == "G":
                    coord_rot = qa.ecl2gal()
                    bore_suffix = "_E2G"
                elif self.coord_out == "C":
                    coord_rot = qa.inv(qa.equ2ecl())
                    bore_suffix = "_E2C"
            elif self.coord_in == "G":
                if self.coord_out == "C":
                    coord_rot = qa.inv(qa.equ2gal())
                    bore_suffix = "_G2C"
                if self.coord_out == "E":
                    coord_rot = qa.inv(qa.ecl2gal())
                    bore_suffix = "_G2E"

        # Ensure that we have boresight pointing in the required coordinate
        # frame.  We will potentially re-use this boresight pointing for every
        # iteration of the amplitude solver, so it makes sense to compute and
        # store this.
        bore_name = self.boresight
        if bore_suffix != "":
            bore_name = f"{self.boresight}{bore_suffix}"
            for ob in data.obs:
                if bore_name not in ob.shared:
                    # Does not yet exist, create it
                    ob.shared.create_column(
                        bore_name,
                        ob.shared[self.boresight].shape,
                        ob.shared[self.boresight].dtype,
                    )
                    # First process in each column computes the quaternions
                    bore_rot = None
                    if ob.comm_col_rank == 0:
                        bore_rot = qa.mult(coord_rot, ob.shared[self.boresight].data)
                    ob.shared[bore_name].set(bore_rot, fromrank=0)

        # Ensure that our boresight data is on the right device.  In the case of
        # no coordinate rotation, this would already be done by the outer pipeline.
        for ob in data.obs:
            if use_accel:
                if not ob.shared.accel_in_use(bore_name):
                    # Not currently on the device
                    if not ob.shared.accel_exists(bore_name):
                        # Does not even exist yet on the device
                        ob.shared.accel_create(bore_name)
                    ob.shared.accel_update_device(bore_name)
            else:
                if ob.shared.accel_in_use(bore_name):
                    # Back to host
                    ob.shared.accel_update_host(bore_name)

        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            exists = ob.detdata.ensure(
                self.quats,
                sample_shape=(4,),
                dtype=np.float64,
                detectors=dets,
                accel=use_accel,
            )

            if exists:
                if data.comm.group_rank == 0:
                    msg = (
                        f"Group {data.comm.group}, ob {ob.name}, det quats "
                        f"already computed for {dets}"
                    )
                    log.verbose(msg)
                continue

            # FIXME:  temporary hack until instrument classes are also pre-staged to GPU
            focalplane = ob.telescope.focalplane
            fp_quats = np.zeros((len(dets), 4), dtype=np.float64)
            for idet, d in enumerate(dets):
                fp_quats[idet, :] = focalplane[d]["quat"]

            quat_indx = ob.detdata[self.quats].indices(dets)

            if self.shared_flags is None:
                flags = np.zeros(1, dtype=np.uint8)
            else:
                flags = ob.shared[self.shared_flags].data

            log.verbose_rank(
                f"Operator {self.name}, observation {ob.name}, use_accel = {use_accel}",
                comm=data.comm.comm_group,
            )

            # Optionally apply HWP deflection.  This is effectively a deflection
            # of the boresight prior to the rotation by the detector quaternion.
            if (
                self.hwp_deflection_radius is not None
                and self.hwp_deflection_radius.value != 0
            ):
                if use_accel:
                    # The data objects are on an accelerator.  Raise an exception
                    # until we can move this code into the kernel.
                    raise NotImplementedError("HWP deflection only works on CPU")
                # Copy node-shared object so that we can modify it.  Starting point
                # is the HWP fast axis.
                deflection_orientation = np.array(ob.shared[self.hwp_angle].data)

                # Apply any phase offset from the fast axis.
                deflection_orientation += self.hwp_angle_offset.to_value(u.rad)

                # The orientation of the deflection is 90 degrees from
                # the axis of rotation going from the boresight to the deflected
                # boresight.
                deflection_orientation += np.pi / 2

                # The rotation axis of the deflection
                deflection_axis = np.zeros(
                    3 * len(deflection_orientation),
                    dtype=np.float64,
                ).reshape((len(deflection_orientation), 3))
                deflection_axis[:, 0] = np.cos(deflection_orientation)
                deflection_axis[:, 1] = np.sin(deflection_orientation)

                # Angle of deflection
                deflection_angle = self.hwp_deflection_radius.to_value(u.radian)

                # Deflection quaternion
                deflection = qa.rotation(
                    deflection_axis,
                    deflection_angle,
                )

                # Apply deflection to the boresight
                boresight = qa.mult(ob.shared[bore_name].data, deflection)
            else:
                boresight = ob.shared[bore_name].data

            pointing_detector(
                fp_quats,
                boresight,
                quat_indx,
                ob.detdata[self.quats].data,
                ob.intervals[self.view].data,
                flags,
                self.shared_flag_mask,
                impl=implementation,
                use_accel=use_accel,
            )

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": [self.boresight],
            "detdata": [self.quats],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.quats],
        }
        return prov

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        return True

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

boresight = Unicode(defaults.boresight_radec, help='Observation shared key for boresight') class-attribute instance-attribute

coord_in = Unicode(None, allow_none=True, help="The input boresight coordinate system ('C', 'E', 'G')") class-attribute instance-attribute

coord_out = Unicode(None, allow_none=True, help="The output coordinate system ('C', 'E', 'G')") class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

hwp_angle = Unicode(defaults.hwp_angle, allow_none=True, help='Observation shared key for HWP angle') class-attribute instance-attribute

hwp_angle_offset = Quantity(0 * u.deg, help='HWP angle offset to apply when constructing deflection') class-attribute instance-attribute

hwp_deflection_radius = Quantity(None, allow_none=True, help='If non-zero, nominal detector pointing will be deflected in a circular pattern according to HWP phase.') class-attribute instance-attribute

quats = Unicode(defaults.quats, allow_none=True, help='Observation detdata key for output quaternions') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for optional flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 113-114
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_coord_in(proposal)

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 90-96
@traitlets.validate("coord_in")
def _check_coord_in(self, proposal):
    check = proposal["value"]
    if check is not None:
        if check not in ["E", "C", "G"]:
            raise traitlets.TraitError("coordinate system must be 'E', 'C', or 'G'")
    return check

_check_coord_out(proposal)

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 98-104
@traitlets.validate("coord_out")
def _check_coord_out(self, proposal):
    check = proposal["value"]
    if check is not None:
        if check not in ["E", "C", "G"]:
            raise traitlets.TraitError("coordinate system must be 'E', 'C', or 'G'")
    return check

_check_det_mask(proposal)

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 83-88
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 106-111
@traitlets.validate("shared_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 116-289
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    coord_rot = None
    bore_suffix = ""
    if self.coord_in is None:
        if self.coord_out is not None:
            msg = "Input and output coordinate systems should both be None or valid"
            raise RuntimeError(msg)
    else:
        if self.coord_out is None:
            msg = "Input and output coordinate systems should both be None or valid"
            raise RuntimeError(msg)
        if self.coord_in == "C":
            if self.coord_out == "E":
                coord_rot = qa.equ2ecl()
                bore_suffix = "_C2E"
            elif self.coord_out == "G":
                coord_rot = qa.equ2gal()
                bore_suffix = "_C2G"
        elif self.coord_in == "E":
            if self.coord_out == "G":
                coord_rot = qa.ecl2gal()
                bore_suffix = "_E2G"
            elif self.coord_out == "C":
                coord_rot = qa.inv(qa.equ2ecl())
                bore_suffix = "_E2C"
        elif self.coord_in == "G":
            if self.coord_out == "C":
                coord_rot = qa.inv(qa.equ2gal())
                bore_suffix = "_G2C"
            if self.coord_out == "E":
                coord_rot = qa.inv(qa.ecl2gal())
                bore_suffix = "_G2E"

    # Ensure that we have boresight pointing in the required coordinate
    # frame.  We will potentially re-use this boresight pointing for every
    # iteration of the amplitude solver, so it makes sense to compute and
    # store this.
    bore_name = self.boresight
    if bore_suffix != "":
        bore_name = f"{self.boresight}{bore_suffix}"
        for ob in data.obs:
            if bore_name not in ob.shared:
                # Does not yet exist, create it
                ob.shared.create_column(
                    bore_name,
                    ob.shared[self.boresight].shape,
                    ob.shared[self.boresight].dtype,
                )
                # First process in each column computes the quaternions
                bore_rot = None
                if ob.comm_col_rank == 0:
                    bore_rot = qa.mult(coord_rot, ob.shared[self.boresight].data)
                ob.shared[bore_name].set(bore_rot, fromrank=0)

    # Ensure that our boresight data is on the right device.  In the case of
    # no coordinate rotation, this would already be done by the outer pipeline.
    for ob in data.obs:
        if use_accel:
            if not ob.shared.accel_in_use(bore_name):
                # Not currently on the device
                if not ob.shared.accel_exists(bore_name):
                    # Does not even exist yet on the device
                    ob.shared.accel_create(bore_name)
                ob.shared.accel_update_device(bore_name)
        else:
            if ob.shared.accel_in_use(bore_name):
                # Back to host
                ob.shared.accel_update_host(bore_name)

    for ob in data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        exists = ob.detdata.ensure(
            self.quats,
            sample_shape=(4,),
            dtype=np.float64,
            detectors=dets,
            accel=use_accel,
        )

        if exists:
            if data.comm.group_rank == 0:
                msg = (
                    f"Group {data.comm.group}, ob {ob.name}, det quats "
                    f"already computed for {dets}"
                )
                log.verbose(msg)
            continue

        # FIXME:  temporary hack until instrument classes are also pre-staged to GPU
        focalplane = ob.telescope.focalplane
        fp_quats = np.zeros((len(dets), 4), dtype=np.float64)
        for idet, d in enumerate(dets):
            fp_quats[idet, :] = focalplane[d]["quat"]

        quat_indx = ob.detdata[self.quats].indices(dets)

        if self.shared_flags is None:
            flags = np.zeros(1, dtype=np.uint8)
        else:
            flags = ob.shared[self.shared_flags].data

        log.verbose_rank(
            f"Operator {self.name}, observation {ob.name}, use_accel = {use_accel}",
            comm=data.comm.comm_group,
        )

        # Optionally apply HWP deflection.  This is effectively a deflection
        # of the boresight prior to the rotation by the detector quaternion.
        if (
            self.hwp_deflection_radius is not None
            and self.hwp_deflection_radius.value != 0
        ):
            if use_accel:
                # The data objects are on an accelerator.  Raise an exception
                # until we can move this code into the kernel.
                raise NotImplementedError("HWP deflection only works on CPU")
            # Copy node-shared object so that we can modify it.  Starting point
            # is the HWP fast axis.
            deflection_orientation = np.array(ob.shared[self.hwp_angle].data)

            # Apply any phase offset from the fast axis.
            deflection_orientation += self.hwp_angle_offset.to_value(u.rad)

            # The orientation of the deflection is 90 degrees from
            # the axis of rotation going from the boresight to the deflected
            # boresight.
            deflection_orientation += np.pi / 2

            # The rotation axis of the deflection
            deflection_axis = np.zeros(
                3 * len(deflection_orientation),
                dtype=np.float64,
            ).reshape((len(deflection_orientation), 3))
            deflection_axis[:, 0] = np.cos(deflection_orientation)
            deflection_axis[:, 1] = np.sin(deflection_orientation)

            # Angle of deflection
            deflection_angle = self.hwp_deflection_radius.to_value(u.radian)

            # Deflection quaternion
            deflection = qa.rotation(
                deflection_axis,
                deflection_angle,
            )

            # Apply deflection to the boresight
            boresight = qa.mult(ob.shared[bore_name].data, deflection)
        else:
            boresight = ob.shared[bore_name].data

        pointing_detector(
            fp_quats,
            boresight,
            quat_indx,
            ob.detdata[self.quats].data,
            ob.intervals[self.view].data,
            flags,
            self.shared_flag_mask,
            impl=implementation,
            use_accel=use_accel,
        )

    return

_finalize(data, **kwargs)

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 291-292
def _finalize(self, data, **kwargs):
    return

_implementations()

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 315-321
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_provides()

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 307-313
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.quats],
    }
    return prov

_requires()

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 294-305
def _requires(self):
    req = {
        "meta": list(),
        "shared": [self.boresight],
        "detdata": [self.quats],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_supports_accel()

Source code in toast/ops/pointing_detector/pointing_detector.py, lines 323-324
def _supports_accel(self):
    return True
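
A minimal configuration sketch for this operator (the trait names come from this page; the coordinate choice and the Data object data are placeholders): rotate equatorial boresight pointing into galactic detector quaternions.

import toast.ops

detpointing = toast.ops.PointingDetectorSimple(
    coord_in="C",   # boresight stored in equatorial coordinates
    coord_out="G",  # emit detector quaternions in galactic coordinates
)
detpointing.apply(data)

# As described in _exec above, the rotated boresight is cached per
# observation under the key "<boresight>_C2G" so that iterative solvers
# can re-use it.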

toast.ops.PixelsHealpix

Bases: Operator

Operator which generates healpix pixel numbers.

If the view trait is not specified, then this operator will use the same data view as the detector pointing operator when computing the pointing matrix pixels.

Any samples with "bad" pointing should have already been set to a "safe" quaternion value by the detector pointing operator. We use the same shared flags as the detector pointing operator.

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 18-326
@trait_docs
class PixelsHealpix(Operator):
    """Operator which generates healpix pixel numbers.

    If the view trait is not specified, then this operator will use the same data
    view as the detector pointing operator when computing the pointing matrix pixels.

    Any samples with "bad" pointing should have already been set to a "safe" quaternion
    value by the detector pointing operator.  We use the same shared flags as the
    detector pointing operator.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    detector_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="Operator that translates boresight pointing into detector frame",
    )

    nside = Int(64, help="The NSIDE resolution")

    nside_submap = Int(16, help="The NSIDE of the submap resolution")

    nest = Bool(True, help="If True, use NESTED ordering instead of RING")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    pixels = Unicode(
        defaults.pixels, help="Observation detdata key for output pixel indices"
    )

    create_dist = Unicode(
        None,
        allow_none=True,
        help="Create the submap distribution for all detectors and store in the "
        "Data key specified",
    )

    single_precision = Bool(False, help="If True, use 32bit int in output")

    @traitlets.validate("detector_pointing")
    def _check_detector_pointing(self, proposal):
        detpointing = proposal["value"]
        if detpointing is not None:
            if not isinstance(detpointing, Operator):
                raise traitlets.TraitError(
                    "detector_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in [
                "view",
                "boresight",
                "shared_flags",
                "shared_flag_mask",
                "det_mask",
                "quats",
                "coord_in",
                "coord_out",
            ]:
                if not detpointing.has_trait(trt):
                    msg = f"detector_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return detpointing

    @traitlets.validate("nside")
    def _check_nside(self, proposal):
        check = proposal["value"]
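        # A valid NSIDE is a power of two.  For such values the set bits of
        # (check - 1) never overlap check, so ~check & (check - 1) == check - 1.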
        if ~check & (check - 1) != check - 1:
            raise traitlets.TraitError("Invalid NSIDE value")
        if check < self.nside_submap:
            raise traitlets.TraitError("NSIDE value is less than nside_submap")
        return check

    @traitlets.validate("nside_submap")
    def _check_nside_submap(self, proposal):
        check = proposal["value"]
        if ~check & (check - 1) != check - 1:
            raise traitlets.TraitError("Invalid NSIDE submap value")
        if check > self.nside:
            newval = 16
            if newval > self.nside:
                newval = 1
            log = Logger.get()
            log.warning(
                "nside_submap greater than NSIDE.  Setting to {} instead".format(newval)
            )
            check = newval
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Check that healpix pixels are set up.  If the nside is left as
        # default, then the 'observe' function will not have been called yet.
        if not hasattr(self, "_local_submaps"):
            self._set_hpix(self.nside, self.nside_submap)

    @traitlets.observe("nside", "nside_submap")
    def _reset_hpix(self, change):
        # (Re-)initialize the healpix pixels object when one of these traits change.
        # Current values:
        nside = self.nside
        nside_submap = self.nside_submap
        # Update to the trait that changed
        if change["name"] == "nside":
            nside = change["new"]
        if change["name"] == "nside_submap":
            nside_submap = change["new"]
        self._set_hpix(nside, nside_submap)

    def _set_hpix(self, nside, nside_submap):
        self._n_pix = 12 * nside**2
        self._n_pix_submap = 12 * nside_submap**2
        self._n_submap = (nside // nside_submap) ** 2
        self._local_submaps = None

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        env = Environment.get()
        log = Logger.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        if self.detector_pointing is None:
            raise RuntimeError("The detector_pointing trait must be set")

        if self._local_submaps is None and self.create_dist is not None:
            self._local_submaps = np.zeros(self._n_submap, dtype=np.uint8)

        # Expand detector pointing
        quats_name = self.detector_pointing.quats

        view = self.view
        if view is None:
            # Use the same data view as detector pointing
            view = self.detector_pointing.view

        # Expand detector pointing
        self.detector_pointing.apply(data, detectors=detectors, use_accel=use_accel)

        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                detectors, flagmask=self.detector_pointing.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            # Check that our view is fully covered by detector pointing.  If the
            # detector_pointing view is None, then it has all samples.  If our own
            # view was None, then it would have been set to the detector_pointing
            # view above.
            if (view is not None) and (self.detector_pointing.view is not None):
                if ob.intervals[view] != ob.intervals[self.detector_pointing.view]:
                    # We need to check intersection
                    intervals = ob.intervals[self.view]
                    detector_intervals = ob.intervals[self.detector_pointing.view]
                    intersection = detector_intervals & intervals
                    if intersection != intervals:
                        msg = (
                            f"view {self.view} is not fully covered by valid "
                            "detector pointing"
                        )
                        raise RuntimeError(msg)

            # Create (or re-use) output data for the pixels.
            if self.single_precision:
                exists = ob.detdata.ensure(
                    self.pixels,
                    sample_shape=(),
                    dtype=np.int32,
                    detectors=dets,
                    accel=use_accel,
                )
            else:
                exists = ob.detdata.ensure(
                    self.pixels,
                    sample_shape=(),
                    dtype=np.int64,
                    detectors=dets,
                    accel=use_accel,
                )

            hit_submaps = self._local_submaps
            if hit_submaps is None:
                hit_submaps = np.zeros(self._n_submap, dtype=np.uint8)

            quat_indx = ob.detdata[quats_name].indices(dets)
            pix_indx = ob.detdata[self.pixels].indices(dets)

            view_slices = [slice(x.first, x.last, 1) for x in ob.intervals[view]]

            # Do we already have pointing for all requested detectors?
            if exists:
                # Yes...
                if self.create_dist is not None:
                    # but the caller wants the pixel distribution
                    restore_dev = False
                    if ob.detdata[self.pixels].accel_in_use():
                        # The data is on the accelerator; copy back to host for
                        # this calculation.  This could eventually be a kernel.
                        ob.detdata[self.pixels].accel_update_host()
                        restore_dev = True
                    for det in ob.select_local_detectors(
                        detectors, flagmask=self.detector_pointing.det_mask
                    ):
                        for vslice in view_slices:
                            good = ob.detdata[self.pixels][det, vslice] >= 0
                            self._local_submaps[
                                ob.detdata[self.pixels][det, vslice][good]
                                // self._n_pix_submap
                            ] = 1
                    if restore_dev:
                        ob.detdata[self.pixels].accel_update_device()

                if data.comm.group_rank == 0:
                    msg = (
                        f"Group {data.comm.group}, ob {ob.name}, healpix pixels "
                        f"already computed for {dets}"
                    )
                    log.verbose(msg)
                continue

            # Get the flags if needed.  Use the same flags as
            # detector pointing.
            if self.detector_pointing.shared_flags is None:
                flags = np.zeros(1, dtype=np.uint8)
            else:
                flags = ob.shared[self.detector_pointing.shared_flags].data

            pixels_healpix(
                quat_indx,
                ob.detdata[quats_name].data,
                flags,
                self.detector_pointing.shared_flag_mask,
                pix_indx,
                ob.detdata[self.pixels].data,
                ob.intervals[self.view].data,
                hit_submaps,
                self._n_pix_submap,
                self.nside,
                self.nest,
                impl=implementation,
                use_accel=use_accel,
            )

            if self._local_submaps is not None:
                self._local_submaps[:] |= hit_submaps

        return

    def _finalize(self, data, use_accel=None, **kwargs):
        if self.create_dist is not None:
            submaps = None
            if self.single_precision:
                submaps = np.arange(self._n_submap, dtype=np.int32)[
                    self._local_submaps == 1
                ]
            else:
                submaps = np.arange(self._n_submap, dtype=np.int64)[
                    self._local_submaps == 1
                ]

            data[self.create_dist] = PixelDistribution(
                n_pix=self._n_pix,
                n_submap=self._n_submap,
                local_submaps=submaps,
                comm=data.comm.comm_world,
            )
        return

    def _requires(self):
        req = self.detector_pointing.requires()
        if "detdata" not in req:
            req["detdata"] = list()
        req["detdata"].append(self.pixels)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = self.detector_pointing.provides()
        prov["detdata"].append(self.pixels)
        if self.create_dist is not None:
            prov["global"].append(self.create_dist)
        return prov

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        if (self.detector_pointing is not None) and (
            self.detector_pointing.supports_accel()
        ):
            return True
        else:
            return False

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

create_dist = Unicode(None, allow_none=True, help='Create the submap distribution for all detectors and store in the Data key specified') class-attribute instance-attribute

detector_pointing = Instance(klass=Operator, allow_none=True, help='Operator that translates boresight pointing into detector frame') class-attribute instance-attribute

nest = Bool(True, help='If True, use NESTED ordering instead of RING') class-attribute instance-attribute

nside = Int(64, help='The NSIDE resolution') class-attribute instance-attribute

nside_submap = Int(16, help='The NSIDE of the submap resolution') class-attribute instance-attribute

pixels = Unicode(defaults.pixels, help='Observation detdata key for output pixel indices') class-attribute instance-attribute

single_precision = Bool(False, help='If True, use 32bit int in output') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 113-118
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Check that healpix pixels are set up.  If the nside is left as
    # default, then the 'observe' function will not have been called yet.
    if not hasattr(self, "_local_submaps"):
        self._set_hpix(self.nside, self.nside_submap)

_check_detector_pointing(proposal)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 64-86
@traitlets.validate("detector_pointing")
def _check_detector_pointing(self, proposal):
    detpointing = proposal["value"]
    if detpointing is not None:
        if not isinstance(detpointing, Operator):
            raise traitlets.TraitError(
                "detector_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in [
            "view",
            "boresight",
            "shared_flags",
            "shared_flag_mask",
            "det_mask",
            "quats",
            "coord_in",
            "coord_out",
        ]:
            if not detpointing.has_trait(trt):
                msg = f"detector_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return detpointing

_check_nside(proposal)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 88-95
@traitlets.validate("nside")
def _check_nside(self, proposal):
    check = proposal["value"]
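    # A valid NSIDE is a power of two.  For such values the set bits of
    # (check - 1) never overlap check, so ~check & (check - 1) == check - 1.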
    if ~check & (check - 1) != check - 1:
        raise traitlets.TraitError("Invalid NSIDE value")
    if check < self.nside_submap:
        raise traitlets.TraitError("NSIDE value is less than nside_submap")
    return check

_check_nside_submap(proposal)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 97-111
@traitlets.validate("nside_submap")
def _check_nside_submap(self, proposal):
    check = proposal["value"]
    if ~check & (check - 1) != check - 1:
        raise traitlets.TraitError("Invalid NSIDE submap value")
    if check > self.nside:
        newval = 16
        if newval > self.nside:
            newval = 1
        log = Logger.get()
        log.warning(
            "nside_submap greater than NSIDE.  Setting to {} instead".format(newval)
        )
        check = newval
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 139-274
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    env = Environment.get()
    log = Logger.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    if self.detector_pointing is None:
        raise RuntimeError("The detector_pointing trait must be set")

    if self._local_submaps is None and self.create_dist is not None:
        self._local_submaps = np.zeros(self._n_submap, dtype=np.uint8)

    # Expand detector pointing
    quats_name = self.detector_pointing.quats

    view = self.view
    if view is None:
        # Use the same data view as detector pointing
        view = self.detector_pointing.view

    # Expand detector pointing
    self.detector_pointing.apply(data, detectors=detectors, use_accel=use_accel)

    for ob in data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(
            detectors, flagmask=self.detector_pointing.det_mask
        )
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        # Check that our view is fully covered by detector pointing.  If the
        # detector_pointing view is None, then it has all samples.  If our own
        # view was None, then it would have been set to the detector_pointing
        # view above.
        if (view is not None) and (self.detector_pointing.view is not None):
            if ob.intervals[view] != ob.intervals[self.detector_pointing.view]:
                # We need to check intersection
                intervals = ob.intervals[self.view]
                detector_intervals = ob.intervals[self.detector_pointing.view]
                intersection = detector_intervals & intervals
                if intersection != intervals:
                    msg = (
                        f"view {self.view} is not fully covered by valid "
                        "detector pointing"
                    )
                    raise RuntimeError(msg)

        # Create (or re-use) output data for the pixels.
        if self.single_precision:
            exists = ob.detdata.ensure(
                self.pixels,
                sample_shape=(),
                dtype=np.int32,
                detectors=dets,
                accel=use_accel,
            )
        else:
            exists = ob.detdata.ensure(
                self.pixels,
                sample_shape=(),
                dtype=np.int64,
                detectors=dets,
                accel=use_accel,
            )

        hit_submaps = self._local_submaps
        if hit_submaps is None:
            hit_submaps = np.zeros(self._n_submap, dtype=np.uint8)

        quat_indx = ob.detdata[quats_name].indices(dets)
        pix_indx = ob.detdata[self.pixels].indices(dets)

        view_slices = [slice(x.first, x.last, 1) for x in ob.intervals[view]]

        # Do we already have pointing for all requested detectors?
        if exists:
            # Yes...
            if self.create_dist is not None:
                # but the caller wants the pixel distribution
                restore_dev = False
                if ob.detdata[self.pixels].accel_in_use():
                    # The data is on the accelerator; copy back to host for
                    # this calculation.  This could eventually be a kernel.
                    ob.detdata[self.pixels].accel_update_host()
                    restore_dev = True
                for det in ob.select_local_detectors(
                    detectors, flagmask=self.detector_pointing.det_mask
                ):
                    for vslice in view_slices:
                        good = ob.detdata[self.pixels][det, vslice] >= 0
                        self._local_submaps[
                            ob.detdata[self.pixels][det, vslice][good]
                            // self._n_pix_submap
                        ] = 1
                if restore_dev:
                    ob.detdata[self.pixels].accel_update_device()

            if data.comm.group_rank == 0:
                msg = (
                    f"Group {data.comm.group}, ob {ob.name}, healpix pixels "
                    f"already computed for {dets}"
                )
                log.verbose(msg)
            continue

        # Get the flags if needed.  Use the same flags as
        # detector pointing.
        if self.detector_pointing.shared_flags is None:
            flags = np.zeros(1, dtype=np.uint8)
        else:
            flags = ob.shared[self.detector_pointing.shared_flags].data

        pixels_healpix(
            quat_indx,
            ob.detdata[quats_name].data,
            flags,
            self.detector_pointing.shared_flag_mask,
            pix_indx,
            ob.detdata[self.pixels].data,
            ob.intervals[self.view].data,
            hit_submaps,
            self._n_pix_submap,
            self.nside,
            self.nest,
            impl=implementation,
            use_accel=use_accel,
        )

        if self._local_submaps is not None:
            self._local_submaps[:] |= hit_submaps

    return

_finalize(data, use_accel=None, **kwargs)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 276-294
def _finalize(self, data, use_accel=None, **kwargs):
    if self.create_dist is not None:
        submaps = None
        if self.single_precision:
            submaps = np.arange(self._n_submap, dtype=np.int32)[
                self._local_submaps == 1
            ]
        else:
            submaps = np.arange(self._n_submap, dtype=np.int64)[
                self._local_submaps == 1
            ]

        data[self.create_dist] = PixelDistribution(
            n_pix=self._n_pix,
            n_submap=self._n_submap,
            local_submaps=submaps,
            comm=data.comm.comm_world,
        )
    return

_implementations()

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 312-318
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_provides()

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 305-310
def _provides(self):
    prov = self.detector_pointing.provides()
    prov["detdata"].append(self.pixels)
    if self.create_dist is not None:
        prov["global"].append(self.create_dist)
    return prov

_requires()

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 296-303
def _requires(self):
    req = self.detector_pointing.requires()
    if "detdata" not in req:
        req["detdata"] = list()
    req["detdata"].append(self.pixels)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_reset_hpix(change)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 120-131
@traitlets.observe("nside", "nside_submap")
def _reset_hpix(self, change):
    # (Re-)initialize the healpix pixels object when one of these traits change.
    # Current values:
    nside = self.nside
    nside_submap = self.nside_submap
    # Update to the trait that changed
    if change["name"] == "nside":
        nside = change["new"]
    if change["name"] == "nside_submap":
        nside_submap = change["new"]
    self._set_hpix(nside, nside_submap)

_set_hpix(nside, nside_submap)

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 133-137
def _set_hpix(self, nside, nside_submap):
    self._n_pix = 12 * nside**2
    self._n_pix_submap = 12 * nside_submap**2
    self._n_submap = (nside // nside_submap) ** 2
    self._local_submaps = None

_supports_accel()

Source code in toast/ops/pixels_healpix/pixels_healpix.py, lines 320-326
def _supports_accel(self):
    if (self.detector_pointing is not None) and (
        self.detector_pointing.supports_accel()
    ):
        return True
    else:
        return False
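
A minimal sketch chaining this operator with a detector pointing operator (the Data object data, the nside value, and the "pixel_dist" key are illustrative; apply() runs _exec and _finalize):

import toast.ops

detpointing = toast.ops.PointingDetectorSimple()

pixels = toast.ops.PixelsHealpix(
    detector_pointing=detpointing,
    nside=512,
    nest=True,
    create_dist="pixel_dist",  # also build the submap distribution
)
pixels.apply(data)

# Each observation now has pixel indices in obs.detdata["pixels"], and
# _finalize has stored the PixelDistribution under data["pixel_dist"].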

toast.ops.PixelsWCS

Bases: Operator

Operator which generates detector pixel indices defined on a flat projection.

When placing the projection on the sky, either the center or bounds traits must be specified, but not both.

When determining the pixel density in the projection, exactly two traits from the set of bounds, resolution and dimensions must be specified.

If the view trait is not specified, then this operator will use the same data view as the detector pointing operator when computing the pointing matrix pixels.

This uses the astropy wcs utilities to build the projection parameters. Eventually this operator will use internal kernels for the projection unless use_astropy is set to True.

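A minimal construction sketch, assuming det_pointing is an already-configured detector pointing operator and data is an existing toast.Data instance (all other values here are hypothetical):

import astropy.units as u
import toast.ops

pixels = toast.ops.PixelsWCS(
    detector_pointing=det_pointing,
    projection="CAR",
    center=(120 * u.degree, -30 * u.degree),
    resolution=(0.05 * u.degree, 0.05 * u.degree),
    dimensions=(400, 400),
    auto_bounds=False,
    create_dist="pixel_dist",
)
pixels.apply(data)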
Source code in toast/ops/pixels_wcs.py
@trait_docs
class PixelsWCS(Operator):
    """Operator which generates detector pixel indices defined on a flat projection.

    When placing the projection on the sky, either the `center` or `bounds`
    traits must be specified, but not both.

    When determining the pixel density in the projection, exactly two traits from the
    set of `bounds`, `resolution` and `dimensions` must be specified.

    If the view trait is not specified, then this operator will use the same data
    view as the detector pointing operator when computing the pointing matrix pixels.

    This uses the astropy wcs utilities to build the projection parameters.  Eventually
    this operator will use internal kernels for the projection unless `use_astropy`
    is set to True.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    detector_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="Operator that translates boresight pointing into detector frame",
    )

    fits_header = Unicode(
        None,
        allow_none=True,
        help="FITS file containing header to use with pre-existing WCS parameters",
    )

    coord_frame = Unicode("EQU", help="Supported values are AZEL, EQU, GAL, ECL")

    projection = Unicode(
        "CAR", help="Supported values are CAR, CEA, MER, ZEA, TAN, SFL"
    )

    center = Tuple(
        tuple(),
        help="The center Lon/Lat coordinates (Quantities) of the projection",
    )

    center_offset = Unicode(
        None,
        allow_none=True,
        help="Optional name of shared field with lon, lat offset in degrees",
    )

    bounds = Tuple(
        tuple(),
        help="The (lon_min, lon_max, lat_min, lat_max) values (Quantities)",
    )

    auto_bounds = Bool(
        True,
        help="If True, set the bounding box based on boresight and field of view",
    )

    dimensions = Tuple(
        (1000, 1000),
        help="The Lon/Lat pixel dimensions of the projection",
    )

    resolution = Tuple(
        tuple(),
        help="The Lon/Lat projection resolution (Quantities) along the 2 axes",
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    pixels = Unicode("pixels", help="Observation detdata key for output pixel indices")

    submaps = Int(1, help="Number of submaps to use")

    create_dist = Unicode(
        None,
        allow_none=True,
        help="Create the submap distribution for all detectors and store in the Data key specified",
    )

    single_precision = Bool(False, help="If True, use 32bit int in output")

    use_astropy = Bool(True, help="If True, use astropy for world to pix conversion")

    @traitlets.validate("detector_pointing")
    def _check_detector_pointing(self, proposal):
        detpointing = proposal["value"]
        if detpointing is not None:
            if not isinstance(detpointing, Operator):
                raise traitlets.TraitError(
                    "detector_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in [
                "view",
                "boresight",
                "shared_flags",
                "shared_flag_mask",
                "quats",
                "coord_in",
                "coord_out",
            ]:
                if not detpointing.has_trait(trt):
                    msg = f"detector_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return detpointing

    @traitlets.validate("wcs_projection")
    def _check_wcs_projection(self, proposal):
        check = proposal["value"]
        if check not in ["CAR", "CEA", "MER", "ZEA", "TAN", "SFL"]:
            raise traitlets.TraitError("Invalid WCS projection name")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Track whether we need to recompute autobounds
        self._done_auto = False
        # Track whether we need to recompute the WCS projection
        self._done_wcs = False

    @traitlets.observe("auto_bounds")
    def _reset_auto_bounds(self, change):
        # Track whether we need to recompute the bounds.
        old_val = change["old"]
        new_val = change["new"]
        if new_val != old_val:
            self._done_auto = False
            self._done_wcs = False

    @traitlets.observe("center_offset")
    def _reset_auto_center(self, change):
        old_val = change["old"]
        new_val = change["new"]
        # Track whether we need to recompute the projection
        if new_val != old_val:
            self._done_wcs = False
            self._done_auto = False

    @traitlets.observe("projection", "center", "bounds", "dimensions", "resolution")
    def _reset_wcs(self, change):
        # (Re-)initialize the WCS projection when one of these traits change.
        old_val = change["old"]
        new_val = change["new"]
        if old_val != new_val:
            self._done_wcs = False
            self._done_auto = False

    @classmethod
    def create_wcs(
        cls,
        coord="EQU",
        proj="CAR",
        center_deg=None,
        bounds_deg=None,
        res_deg=None,
        dims=None,
    ):
        """Create a WCS object given projection parameters.

        Either the `center_deg` or `bounds_deg` parameters must be specified,
        but not both.

        When determining the pixel density in the projection, exactly two
        parameters from the set of `bounds_deg`, `res_deg` and `dims` must be
        specified.

        Args:
            coord (str):  The coordinate frame name.
            proj (str):  The projection type.
            center_deg (tuple):  The (lon, lat) projection center in degrees.
            bounds_deg (tuple):  The (lon_min, lon_max, lat_min, lat_max)
                values in degrees.
            res_deg (tuple):  The (lon, lat) resolution in degrees.
            dims (tuple):  The (lon, lat) projection size in pixels.

        Returns:
            (WCS, shape): The instantiated WCS object and final shape.

        """
        log = Logger.get()

        # Compute projection center
        if center_deg is not None:
            # We are specifying the center.  Bounds should not be set and we should
            # have both resolution and dimensions
            if bounds_deg is not None:
                msg = f"PixelsWCS: only one of center and bounds should be set."
                log.error(msg)
                raise RuntimeError(msg)
            if res_deg is None or dims is None:
                msg = f"PixelsWCS: when center is set, both resolution and dimensions"
                msg += f" are required."
                log.error(msg)
                raise RuntimeError(msg)
            crval = np.array(center_deg, dtype=np.float64)
        else:
            # Not using center, bounds is required
            if bounds_deg is None:
                msg = f"PixelsWCS: when center is not specified, bounds required."
                log.error(msg)
                raise RuntimeError(msg)
            mid_lon = 0.5 * (bounds_deg[1] + bounds_deg[0])
            mid_lat = 0.5 * (bounds_deg[3] + bounds_deg[2])
            crval = np.array([mid_lon, mid_lat], dtype=np.float64)
            # Either resolution or dimensions should be specified
            if res_deg is not None:
                # Using resolution
                if dims is not None:
                    msg = f"PixelsWCS: when using bounds, only one of resolution or"
                    msg += f" dimensions must be specified."
                    log.error(msg)
                    raise RuntimeError(msg)
            else:
                # Using dimensions; resolution was not given, so dims must be
                if dims is None:
                    msg = f"PixelsWCS: when using bounds, one of resolution or"
                    msg += f" dimensions is required."
                    log.error(msg)
                    raise RuntimeError(msg)

        # Create the WCS object.
        # CTYPE1 = Longitude
        # CTYPE2 = Latitude
        wcs = WCS(naxis=2)

        if coord == "AZEL":
            # For local Azimuth and Elevation coordinate frame, we
            # use the generic longitude and latitude string.
            coordstr = ("TLON", "TLAT")
        elif coord == "EQU":
            coordstr = ("RA--", "DEC-")
        elif coord == "GAL":
            coordstr = ("GLON", "GLAT")
        elif coord == "ECL":
            coordstr = ("ELON", "ELAT")
        else:
            msg = f"Unsupported coordinate frame '{coord}'"
            raise RuntimeError(msg)

        if proj == "CAR":
            wcs.wcs.ctype = [f"{coordstr[0]}-CAR", f"{coordstr[1]}-CAR"]
            wcs.wcs.crval = crval
        elif proj == "CEA":
            wcs.wcs.ctype = [f"{coordstr[0]}-CEA", f"{coordstr[1]}-CEA"]
            wcs.wcs.crval = crval
            lam = np.cos(np.deg2rad(crval[1])) ** 2
            wcs.wcs.set_pv([(2, 1, lam)])
        elif proj == "MER":
            wcs.wcs.ctype = [f"{coordstr[0]}-MER", f"{coordstr[1]}-MER"]
            wcs.wcs.crval = crval
        elif proj == "ZEA":
            wcs.wcs.ctype = [f"{coordstr[0]}-ZEA", f"{coordstr[1]}-ZEA"]
            wcs.wcs.crval = crval
        elif proj == "TAN":
            wcs.wcs.ctype = [f"{coordstr[0]}-TAN", f"{coordstr[1]}-TAN"]
            wcs.wcs.crval = crval
        elif proj == "SFL":
            wcs.wcs.ctype = [f"{coordstr[0]}-SFL", f"{coordstr[1]}-SFL"]
            wcs.wcs.crval = crval
        else:
            msg = f"Invalid WCS projection name '{proj}'"
            raise ValueError(msg)

        # Compute resolution.  Note that we negate the longitudinal
        # coordinate so that the resulting projections match expectations
        # for plotting, etc.
        if center_deg is not None:
            wcs.wcs.cdelt = np.array([-res_deg[0], res_deg[1]])
        else:
            if res_deg is not None:
                wcs.wcs.cdelt = np.array([-res_deg[0], res_deg[1]])
            else:
                # Compute CDELT from the bounding box and image size.
                wcs.wcs.cdelt = np.array(
                    [
                        -(bounds_deg[1] - bounds_deg[0]) / dims[0],
                        (bounds_deg[3] - bounds_deg[2]) / dims[1],
                    ]
                )

        # Compute shape of the projection
        if dims is not None:
            wcs_shape = tuple(dims)
        else:
            # Compute from the bounding box corners
            lower_left = wcs.wcs_world2pix(
                np.array([[bounds_deg[0], bounds_deg[2]]]), 0
            )[0]
            upper_right = wcs.wcs_world2pix(
                np.array([[bounds_deg[1], bounds_deg[3]]]), 0
            )[0]
            wcs_shape = tuple(np.round(np.abs(upper_right - lower_left)).astype(int))

        # Set the reference pixel to the center of the projection
        off = wcs.wcs_world2pix(crval.reshape((1, 2)), 0)[0]
        wcs.wcs.crpix = 0.5 * np.array(wcs_shape, dtype=np.float64) + 0.5 + off

        return wcs, wcs_shape

    def set_wcs(self):
        if self._done_wcs:
            return

        log = Logger.get()
        msg = f"PixelsWCS: set_wcs coord={self.coord_frame}, "
        msg += f"proj={self.projection}, center={self.center}, bounds={self.bounds}"
        msg += f", dims={self.dimensions}, res={self.resolution}"
        log.verbose(msg)

        center_deg = None
        if len(self.center) > 0:
            if self.center_offset is None:
                center_deg = (
                    self.center[0].to_value(u.degree),
                    self.center[1].to_value(u.degree),
                )
            else:
                center_deg = (0.0, 0.0)
        bounds_deg = None
        if len(self.bounds) > 0:
            bounds_deg = tuple([x.to_value(u.degree) for x in self.bounds])
        res_deg = None
        if len(self.resolution) > 0:
            res_deg = tuple([x.to_value(u.degree) for x in self.resolution])
        if len(self.dimensions) > 0:
            dims = tuple(self.dimensions)
        else:
            dims = None

        self.wcs, self.wcs_shape = self.create_wcs(
            coord=self.coord_frame,
            proj=self.projection,
            center_deg=center_deg,
            bounds_deg=bounds_deg,
            res_deg=res_deg,
            dims=dims,
        )

        self.pix_lon = self.wcs_shape[0]
        self.pix_lat = self.wcs_shape[1]
        self._n_pix = self.pix_lon * self.pix_lat
        self._n_pix_submap = self._n_pix // self.submaps
        if self._n_pix_submap * self.submaps < self._n_pix:
            self._n_pix_submap += 1
        self._local_submaps = np.zeros(self.submaps, dtype=np.uint8)
        self._done_wcs = True
        return

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        env = Environment.get()
        log = Logger.get()

        if self.detector_pointing is None:
            raise RuntimeError("The detector_pointing trait must be set")

        if not self.use_astropy:
            raise NotImplementedError("Only astropy conversion is currently supported")

        if self.fits_header is not None:
            # with open(self.fits_header, "rb") as f:
            #     header = af.Header.fromfile(f)
            raise NotImplementedError(
                "Initialization from a FITS header not yet finished"
            )

        if self.coord_frame == "AZEL":
            is_azimuth = True
        else:
            is_azimuth = False

        if self.auto_bounds and not self._done_auto:
            # Pass through the boresight pointing for every observation and build
            # the maximum extent of the detector field of view.
            lonmax = -2 * np.pi * u.radian
            lonmin = 2 * np.pi * u.radian
            latmax = (-np.pi / 2) * u.radian
            latmin = (np.pi / 2) * u.radian
            for ob in data.obs:
                # The scan range is computed collectively among the group.
                lnmin, lnmax, ltmin, ltmax = scan_range_lonlat(
                    ob,
                    self.detector_pointing.boresight,
                    flags=self.detector_pointing.shared_flags,
                    flag_mask=self.detector_pointing.shared_flag_mask,
                    field_of_view=None,
                    is_azimuth=is_azimuth,
                    center_offset=self.center_offset,
                )
                lonmin = min(lonmin, lnmin)
                lonmax = max(lonmax, lnmax)
                latmin = min(latmin, ltmin)
                latmax = max(latmax, ltmax)
            if data.comm.comm_world is not None:
                lonlatmin = np.zeros(2, dtype=np.float64)
                lonlatmax = np.zeros(2, dtype=np.float64)
                lonlatmin[0] = lonmin.to_value(u.radian)
                lonlatmin[1] = latmin.to_value(u.radian)
                lonlatmax[0] = lonmax.to_value(u.radian)
                lonlatmax[1] = latmax.to_value(u.radian)
                all_lonlatmin = np.zeros(2, dtype=np.float64)
                all_lonlatmax = np.zeros(2, dtype=np.float64)
                data.comm.comm_world.Allreduce(lonlatmin, all_lonlatmin, op=MPI.MIN)
                data.comm.comm_world.Allreduce(lonlatmax, all_lonlatmax, op=MPI.MAX)
                lonmin = all_lonlatmin[0] * u.radian
                latmin = all_lonlatmin[1] * u.radian
                lonmax = all_lonlatmax[0] * u.radian
                latmax = all_lonlatmax[1] * u.radian
            self.bounds = (
                lonmin.to(u.degree),
                lonmax.to(u.degree),
                latmin.to(u.degree),
                latmax.to(u.degree),
            )
            log.verbose(f"PixelsWCS: auto_bounds set to {self.bounds}")
            self._done_auto = True

        # Compute the projection if needed
        self.set_wcs()

        # Expand detector pointing
        quats_name = self.detector_pointing.quats

        view = self.view
        if view is None:
            # Use the same data view as detector pointing
            view = self.detector_pointing.view

        # Once this supports accelerator, pass that instead of False
        self.detector_pointing.apply(data, detectors=detectors, use_accel=False)

        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                detectors, flagmask=self.detector_pointing.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            # Check that our view is fully covered by detector pointing.  If the
            # detector_pointing view is None, then it has all samples.  If our own
            # view was None, then it would have been set to the detector_pointing
            # view above.
            if (view is not None) and (self.detector_pointing.view is not None):
                if ob.intervals[view] != ob.intervals[self.detector_pointing.view]:
                    # We need to check intersection
                    intervals = ob.intervals[self.view]
                    detector_intervals = ob.intervals[self.detector_pointing.view]
                    intersection = detector_intervals & intervals
                    if intersection != intervals:
                        msg = (
                            f"view {self.view} is not fully covered by valid "
                            "detector pointing"
                        )
                        raise RuntimeError(msg)

            # Create (or re-use) output data for the pixels, weights and optionally the
            # detector quaternions.

            if self.single_precision:
                exists = ob.detdata.ensure(
                    self.pixels, sample_shape=(), dtype=np.int32, detectors=dets
                )
            else:
                exists = ob.detdata.ensure(
                    self.pixels, sample_shape=(), dtype=np.int64, detectors=dets
                )

            view_slices = [slice(x.first, x.last, 1) for x in ob.intervals[view]]

            # Do we already have pointing for all requested detectors?
            if exists:
                # Yes...
                if self.create_dist is not None:
                    # but the caller wants the pixel distribution
                    restore_dev = False
                    if ob.detdata[self.pixels].accel_in_use():
                        # The data is on the accelerator- copy back to host for
                        # this calculation.  This could eventually be a kernel.
                        ob.detdata[self.pixels].accel_update_host()
                        restore_dev = True
                    for det in dets:
                        for vslice in view_slices:
                            good = ob.detdata[self.pixels][det, vslice] >= 0
                            self._local_submaps[
                                ob.detdata[self.pixels][det, vslice][good]
                                // self._n_pix_submap
                            ] = 1
                    if restore_dev:
                        ob.detdata[self.pixels].accel_update_device()

                if data.comm.group_rank == 0:
                    msg = (
                        f"Group {data.comm.group}, ob {ob.name}, WCS pixels "
                        f"already computed for {dets}"
                    )
                    log.verbose(msg)
                continue

            # Focalplane for this observation
            focalplane = ob.telescope.focalplane

            # Get the flags if needed.  Use the same flags as
            # detector pointing.
            flags = None
            if self.detector_pointing.shared_flags is not None:
                flags = ob.shared[self.detector_pointing.shared_flags].data
                flags &= self.detector_pointing.shared_flag_mask

            center_lonlat = None
            if self.center_offset is not None:
                center_lonlat = np.radians(ob.shared[self.center_offset].data)

            # Process all detectors
            for det in dets:
                for vslice in view_slices:
                    # Timestream of detector quaternions
                    quats = ob.detdata[quats_name][det][vslice]
                    view_samples = len(quats)

                    if center_lonlat is None:
                        center_offset = None
                    else:
                        center_offset = center_lonlat[vslice]

                    rel_lon, rel_lat = center_offset_lonlat(
                        quats,
                        center_offset=center_offset,
                        degrees=True,
                        is_azimuth=is_azimuth,
                    )

                    world_in = np.column_stack([rel_lon, rel_lat])

                    rdpix = self.wcs.wcs_world2pix(world_in, 0)
                    rdpix = np.array(np.around(rdpix), dtype=np.int64)

                    ob.detdata[self.pixels][det, vslice] = (
                        rdpix[:, 0] * self.pix_lat + rdpix[:, 1]
                    )
                    bad_pointing = ob.detdata[self.pixels][det, vslice] >= self._n_pix
                    if flags is not None:
                        bad_pointing = np.logical_or(bad_pointing, flags[vslice] != 0)
                    (ob.detdata[self.pixels][det, vslice])[bad_pointing] = -1

                    if self.create_dist is not None:
                        good = ob.detdata[self.pixels][det][vslice] >= 0
                        self._local_submaps[
                            (ob.detdata[self.pixels][det, vslice])[good]
                            // self._n_pix_submap
                        ] = 1

    def _finalize(self, data, **kwargs):
        if self.create_dist is not None:
            if self.single_precision:
                submaps = np.arange(self.submaps, dtype=np.int32)[
                    self._local_submaps == 1
                ]
            else:
                submaps = np.arange(self.submaps, dtype=np.int64)[
                    self._local_submaps == 1
                ]

            data[self.create_dist] = PixelDistribution(
                n_pix=self._n_pix,
                n_submap=self.submaps,
                local_submaps=submaps,
                comm=data.comm.comm_world,
            )
            # Store a copy of the WCS information in the distribution object
            data[self.create_dist].wcs = self.wcs.deepcopy()
            data[self.create_dist].wcs_shape = tuple(self.wcs_shape)
            # Reset the local submaps
            self._local_submaps[:] = 0
        return

    def _requires(self):
        req = self.detector_pointing.requires()
        if "detdata" not in req:
            req["detdata"] = list()
        req["detdata"].append(self.pixels)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = self.detector_pointing.provides()
        prov["detdata"].append(self.pixels)
        if self.create_dist is not None:
            prov["global"].append(self.create_dist)
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

_done_auto = False instance-attribute

_done_wcs = False instance-attribute

auto_bounds = Bool(True, help='If True, set the bounding box based on boresight and field of view') class-attribute instance-attribute

bounds = Tuple(tuple(), help='The (lon_min, lon_max, lat_min, lat_max) values (Quantities)') class-attribute instance-attribute

center = Tuple(tuple(), help='The center Lon/Lat coordinates (Quantities) of the projection') class-attribute instance-attribute

center_offset = Unicode(None, allow_none=True, help='Optional name of shared field with lon, lat offset in degrees') class-attribute instance-attribute

coord_frame = Unicode('EQU', help='Supported values are AZEL, EQU, GAL, ECL') class-attribute instance-attribute

create_dist = Unicode(None, allow_none=True, help='Create the submap distribution for all detectors and store in the Data key specified') class-attribute instance-attribute

detector_pointing = Instance(klass=Operator, allow_none=True, help='Operator that translates boresight pointing into detector frame') class-attribute instance-attribute

dimensions = Tuple((1000, 1000), help='The Lon/Lat pixel dimensions of the projection') class-attribute instance-attribute

fits_header = Unicode(None, allow_none=True, help='FITS file containing header to use with pre-existing WCS parameters') class-attribute instance-attribute

pixels = Unicode('pixels', help='Observation detdata key for output pixel indices') class-attribute instance-attribute

projection = Unicode('CAR', help='Supported values are CAR, CEA, MER, ZEA, TAN, SFL') class-attribute instance-attribute

resolution = Tuple(tuple(), help='The Lon/Lat projection resolution (Quantities) along the 2 axes') class-attribute instance-attribute

single_precision = Bool(False, help='If True, use 32bit int in output') class-attribute instance-attribute

submaps = Int(1, help='Number of submaps to use') class-attribute instance-attribute

use_astropy = Bool(True, help='If True, use astropy for world to pix conversion') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/pixels_wcs.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Track whether we need to recompute autobounds
    self._done_auto = False
    # Track whether we need to recompute the WCS projection
    self._done_wcs = False

_check_detector_pointing(proposal)

Source code in toast/ops/pixels_wcs.py
@traitlets.validate("detector_pointing")
def _check_detector_pointing(self, proposal):
    detpointing = proposal["value"]
    if detpointing is not None:
        if not isinstance(detpointing, Operator):
            raise traitlets.TraitError(
                "detector_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in [
            "view",
            "boresight",
            "shared_flags",
            "shared_flag_mask",
            "quats",
            "coord_in",
            "coord_out",
        ]:
            if not detpointing.has_trait(trt):
                msg = f"detector_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return detpointing

_check_wcs_projection(proposal)

Source code in toast/ops/pixels_wcs.py
@traitlets.validate("wcs_projection")
def _check_wcs_projection(self, proposal):
    check = proposal["value"]
    if check not in ["CAR", "CEA", "MER", "ZEA", "TAN", "SFL"]:
        raise traitlets.TraitError("Invalid WCS projection name")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/pixels_wcs.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    env = Environment.get()
    log = Logger.get()

    if self.detector_pointing is None:
        raise RuntimeError("The detector_pointing trait must be set")

    if not self.use_astropy:
        raise NotImplementedError("Only astropy conversion is currently supported")

    if self.fits_header is not None:
        # with open(self.fits_header, "rb") as f:
        #     header = af.Header.fromfile(f)
        raise NotImplementedError(
            "Initialization from a FITS header not yet finished"
        )

    if self.coord_frame == "AZEL":
        is_azimuth = True
    else:
        is_azimuth = False

    if self.auto_bounds and not self._done_auto:
        # Pass through the boresight pointing for every observation and build
        # the maximum extent of the detector field of view.
        lonmax = -2 * np.pi * u.radian
        lonmin = 2 * np.pi * u.radian
        latmax = (-np.pi / 2) * u.radian
        latmin = (np.pi / 2) * u.radian
        for ob in data.obs:
            # The scan range is computed collectively among the group.
            lnmin, lnmax, ltmin, ltmax = scan_range_lonlat(
                ob,
                self.detector_pointing.boresight,
                flags=self.detector_pointing.shared_flags,
                flag_mask=self.detector_pointing.shared_flag_mask,
                field_of_view=None,
                is_azimuth=is_azimuth,
                center_offset=self.center_offset,
            )
            lonmin = min(lonmin, lnmin)
            lonmax = max(lonmax, lnmax)
            latmin = min(latmin, ltmin)
            latmax = max(latmax, ltmax)
        if data.comm.comm_world is not None:
            lonlatmin = np.zeros(2, dtype=np.float64)
            lonlatmax = np.zeros(2, dtype=np.float64)
            lonlatmin[0] = lonmin.to_value(u.radian)
            lonlatmin[1] = latmin.to_value(u.radian)
            lonlatmax[0] = lonmax.to_value(u.radian)
            lonlatmax[1] = latmax.to_value(u.radian)
            all_lonlatmin = np.zeros(2, dtype=np.float64)
            all_lonlatmax = np.zeros(2, dtype=np.float64)
            data.comm.comm_world.Allreduce(lonlatmin, all_lonlatmin, op=MPI.MIN)
            data.comm.comm_world.Allreduce(lonlatmax, all_lonlatmax, op=MPI.MAX)
            lonmin = all_lonlatmin[0] * u.radian
            latmin = all_lonlatmin[1] * u.radian
            lonmax = all_lonlatmax[0] * u.radian
            latmax = all_lonlatmax[1] * u.radian
        self.bounds = (
            lonmin.to(u.degree),
            lonmax.to(u.degree),
            latmin.to(u.degree),
            latmax.to(u.degree),
        )
        log.verbose(f"PixelsWCS: auto_bounds set to {self.bounds}")
        self._done_auto = True

    # Compute the projection if needed
    self.set_wcs()

    # Expand detector pointing
    quats_name = self.detector_pointing.quats

    view = self.view
    if view is None:
        # Use the same data view as detector pointing
        view = self.detector_pointing.view

    # Once this supports accelerator, pass that instead of False
    self.detector_pointing.apply(data, detectors=detectors, use_accel=False)

    for ob in data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(
            detectors, flagmask=self.detector_pointing.det_mask
        )
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        # Check that our view is fully covered by detector pointing.  If the
        # detector_pointing view is None, then it has all samples.  If our own
        # view was None, then it would have been set to the detector_pointing
        # view above.
        if (view is not None) and (self.detector_pointing.view is not None):
            if ob.intervals[view] != ob.intervals[self.detector_pointing.view]:
                # We need to check intersection
                intervals = ob.intervals[self.view]
                detector_intervals = ob.intervals[self.detector_pointing.view]
                intersection = detector_intervals & intervals
                if intersection != intervals:
                    msg = (
                        f"view {self.view} is not fully covered by valid "
                        "detector pointing"
                    )
                    raise RuntimeError(msg)

        # Create (or re-use) output data for the pixels, weights and optionally the
        # detector quaternions.

        if self.single_precision:
            exists = ob.detdata.ensure(
                self.pixels, sample_shape=(), dtype=np.int32, detectors=dets
            )
        else:
            exists = ob.detdata.ensure(
                self.pixels, sample_shape=(), dtype=np.int64, detectors=dets
            )

        view_slices = [slice(x.first, x.last, 1) for x in ob.intervals[view]]

        # Do we already have pointing for all requested detectors?
        if exists:
            # Yes...
            if self.create_dist is not None:
                # but the caller wants the pixel distribution
                restore_dev = False
                if ob.detdata[self.pixels].accel_in_use():
                    # The data is on the accelerator- copy back to host for
                    # this calculation.  This could eventually be a kernel.
                    ob.detdata[self.pixels].accel_update_host()
                    restore_dev = True
                for det in dets:
                    for vslice in view_slices:
                        good = ob.detdata[self.pixels][det, vslice] >= 0
                        self._local_submaps[
                            ob.detdata[self.pixels][det, vslice][good]
                            // self._n_pix_submap
                        ] = 1
                if restore_dev:
                    ob.detdata[self.pixels].accel_update_device()

            if data.comm.group_rank == 0:
                msg = (
                    f"Group {data.comm.group}, ob {ob.name}, WCS pixels "
                    f"already computed for {dets}"
                )
                log.verbose(msg)
            continue

        # Focalplane for this observation
        focalplane = ob.telescope.focalplane

        # Get the flags if needed.  Use the same flags as
        # detector pointing.
        flags = None
        if self.detector_pointing.shared_flags is not None:
            flags = ob.shared[self.detector_pointing.shared_flags].data
            flags &= self.detector_pointing.shared_flag_mask

        center_lonlat = None
        if self.center_offset is not None:
            center_lonlat = np.radians(ob.shared[self.center_offset].data)

        # Process all detectors
        for det in dets:
            for vslice in view_slices:
                # Timestream of detector quaternions
                quats = ob.detdata[quats_name][det][vslice]
                view_samples = len(quats)

                if center_lonlat is None:
                    center_offset = None
                else:
                    center_offset = center_lonlat[vslice]

                rel_lon, rel_lat = center_offset_lonlat(
                    quats,
                    center_offset=center_offset,
                    degrees=True,
                    is_azimuth=is_azimuth,
                )

                world_in = np.column_stack([rel_lon, rel_lat])

                rdpix = self.wcs.wcs_world2pix(world_in, 0)
                rdpix = np.array(np.around(rdpix), dtype=np.int64)

                ob.detdata[self.pixels][det, vslice] = (
                    rdpix[:, 0] * self.pix_lat + rdpix[:, 1]
                )
                bad_pointing = ob.detdata[self.pixels][det, vslice] >= self._n_pix
                if flags is not None:
                    bad_pointing = np.logical_or(bad_pointing, flags[vslice] != 0)
                (ob.detdata[self.pixels][det, vslice])[bad_pointing] = -1

                if self.create_dist is not None:
                    good = ob.detdata[self.pixels][det][vslice] >= 0
                    self._local_submaps[
                        (ob.detdata[self.pixels][det, vslice])[good]
                        // self._n_pix_submap
                    ] = 1

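Note the sample-flagging convention in the loop above: samples that fall outside the projection, or that are cut by the shared flags, are assigned pixel index -1. Downstream code can therefore mask invalid samples directly (a sketch, assuming ob and det as in the loop and the default pixels trait value):

pix = ob.detdata["pixels"][det]
good = pix >= 0  # True only where the pointing was valid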
_finalize(data, **kwargs)

Source code in toast/ops/pixels_wcs.py
def _finalize(self, data, **kwargs):
    if self.create_dist is not None:
        if self.single_precision:
            submaps = np.arange(self.submaps, dtype=np.int32)[
                self._local_submaps == 1
            ]
        else:
            submaps = np.arange(self.submaps, dtype=np.int64)[
                self._local_submaps == 1
            ]

        data[self.create_dist] = PixelDistribution(
            n_pix=self._n_pix,
            n_submap=self.submaps,
            local_submaps=submaps,
            comm=data.comm.comm_world,
        )
        # Store a copy of the WCS information in the distribution object
        data[self.create_dist].wcs = self.wcs.deepcopy()
        data[self.create_dist].wcs_shape = tuple(self.wcs_shape)
        # Reset the local submaps
        self._local_submaps[:] = 0
    return

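Once the operator has run with create_dist set, the distribution and the WCS copy attached here can be retrieved from the Data object. A sketch, assuming create_dist="pixel_dist" as in the construction example above:

dist = data["pixel_dist"]
print(dist.wcs_shape)          # (lon, lat) image shape stored by _finalize
header = dist.wcs.to_header()  # astropy WCS copy, e.g. for writing FITS maps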
_provides()

Source code in toast/ops/pixels_wcs.py
def _provides(self):
    prov = self.detector_pointing.provides()
    prov["detdata"].append(self.pixels)
    if self.create_dist is not None:
        prov["global"].append(self.create_dist)
    return prov

_requires()

Source code in toast/ops/pixels_wcs.py
def _requires(self):
    req = self.detector_pointing.requires()
    if "detdata" not in req:
        req["detdata"] = list()
    req["detdata"].append(self.pixels)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_reset_auto_bounds(change)

Source code in toast/ops/pixels_wcs.py
@traitlets.observe("auto_bounds")
def _reset_auto_bounds(self, change):
    # Track whether we need to recompute the bounds.
    old_val = change["old"]
    new_val = change["new"]
    if new_val != old_val:
        self._done_auto = False
        self._done_wcs = False

_reset_auto_center(change)

Source code in toast/ops/pixels_wcs.py
@traitlets.observe("center_offset")
def _reset_auto_center(self, change):
    old_val = change["old"]
    new_val = change["new"]
    # Track whether we need to recompute the projection
    if new_val != old_val:
        self._done_wcs = False
        self._done_auto = False

_reset_wcs(change)

Source code in toast/ops/pixels_wcs.py
@traitlets.observe("projection", "center", "bounds", "dimensions", "resolution")
def _reset_wcs(self, change):
    # (Re-)initialize the WCS projection when one of these traits change.
    old_val = change["old"]
    new_val = change["new"]
    if old_val != new_val:
        self._done_wcs = False
        self._done_auto = False

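Together these observers implement a simple cache-invalidation scheme: changing any projection-related trait clears the _done_wcs and _done_auto flags, and the next call to _exec rebuilds the bounds and projection through set_wcs. A sketch, reusing the pixels operator from the earlier construction example:

pixels.projection = "TAN"  # fires _reset_wcs, so _done_wcs is now False
pixels.apply(data)         # _exec recomputes the WCS before projecting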
create_wcs(coord='EQU', proj='CAR', center_deg=None, bounds_deg=None, res_deg=None, dims=None) classmethod

Create a WCS object given projection parameters.

Either the center_deg or bounds_deg parameters must be specified, but not both.

When determining the pixel density in the projection, exactly two parameters from the set of bounds_deg, res_deg and dims must be specified.

Parameters:

    coord (str): The coordinate frame name. Default: 'EQU'
    proj (str): The projection type. Default: 'CAR'
    center_deg (tuple): The (lon, lat) projection center in degrees. Default: None
    bounds_deg (tuple): The (lon_min, lon_max, lat_min, lat_max) values in degrees. Default: None
    res_deg (tuple): The (lon, lat) resolution in degrees. Default: None
    dims (tuple): The (lon, lat) projection size in pixels. Default: None

Returns:

    (WCS, shape): The instantiated WCS object and final shape.

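A minimal call sketch with hypothetical center and resolution values:

from toast.ops import PixelsWCS

wcs, shape = PixelsWCS.create_wcs(
    coord="EQU",
    proj="CAR",
    center_deg=(120.0, -30.0),
    res_deg=(0.01, 0.01),
    dims=(500, 500),
)
print(shape)          # (500, 500)
print(wcs.wcs.ctype)  # ['RA---CAR', 'DEC--CAR']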
Source code in toast/ops/pixels_wcs.py
@classmethod
def create_wcs(
    cls,
    coord="EQU",
    proj="CAR",
    center_deg=None,
    bounds_deg=None,
    res_deg=None,
    dims=None,
):
    """Create a WCS object given projection parameters.

    Either the `center_deg` or `bounds_deg` parameters must be specified,
    but not both.

    When determining the pixel density in the projection, exactly two
    parameters from the set of `bounds_deg`, `res_deg` and `dims` must be
    specified.

    Args:
        coord (str):  The coordinate frame name.
        proj (str):  The projection type.
        center_deg (tuple):  The (lon, lat) projection center in degrees.
        bounds_deg (tuple):  The (lon_min, lon_max, lat_min, lat_max)
            values in degrees.
        res_deg (tuple):  The (lon, lat) resolution in degrees.
        dims (tuple):  The (lon, lat) projection size in pixels.

    Returns:
        (WCS, shape): The instantiated WCS object and final shape.

    """
    log = Logger.get()

    # Compute projection center
    if center_deg is not None:
        # We are specifying the center.  Bounds should not be set and we should
        # have both resolution and dimensions
        if bounds_deg is not None:
            msg = f"PixelsWCS: only one of center and bounds should be set."
            log.error(msg)
            raise RuntimeError(msg)
        if res_deg is None or dims is None:
            msg = f"PixelsWCS: when center is set, both resolution and dimensions"
            msg += f" are required."
            log.error(msg)
            raise RuntimeError(msg)
        crval = np.array(center_deg, dtype=np.float64)
    else:
        # Not using center, bounds is required
        if bounds_deg is None:
            msg = f"PixelsWCS: when center is not specified, bounds required."
            log.error(msg)
            raise RuntimeError(msg)
        mid_lon = 0.5 * (bounds_deg[1] + bounds_deg[0])
        mid_lat = 0.5 * (bounds_deg[3] + bounds_deg[2])
        crval = np.array([mid_lon, mid_lat], dtype=np.float64)
        # Either resolution or dimensions should be specified
        if res_deg is not None:
            # Using resolution
            if dims is not None:
                msg = f"PixelsWCS: when using bounds, only one of resolution or"
                msg += f" dimensions must be specified."
                log.error(msg)
                raise RuntimeError(msg)
        else:
            # Using dimensions; resolution was not given, so dims must be
            if dims is None:
                msg = f"PixelsWCS: when using bounds, one of resolution or"
                msg += f" dimensions is required."
                log.error(msg)
                raise RuntimeError(msg)

    # Create the WCS object.
    # CTYPE1 = Longitude
    # CTYPE2 = Latitude
    wcs = WCS(naxis=2)

    if coord == "AZEL":
        # For local Azimuth and Elevation coordinate frame, we
        # use the generic longitude and latitude string.
        coordstr = ("TLON", "TLAT")
    elif coord == "EQU":
        coordstr = ("RA--", "DEC-")
    elif coord == "GAL":
        coordstr = ("GLON", "GLAT")
    elif coord == "ECL":
        coordstr = ("ELON", "ELAT")
    else:
        msg = f"Unsupported coordinate frame '{coord}'"
        raise RuntimeError(msg)

    if proj == "CAR":
        wcs.wcs.ctype = [f"{coordstr[0]}-CAR", f"{coordstr[1]}-CAR"]
        wcs.wcs.crval = crval
    elif proj == "CEA":
        wcs.wcs.ctype = [f"{coordstr[0]}-CEA", f"{coordstr[1]}-CEA"]
        wcs.wcs.crval = crval
        lam = np.cos(np.deg2rad(crval[1])) ** 2
        wcs.wcs.set_pv([(2, 1, lam)])
    elif proj == "MER":
        wcs.wcs.ctype = [f"{coordstr[0]}-MER", f"{coordstr[1]}-MER"]
        wcs.wcs.crval = crval
    elif proj == "ZEA":
        wcs.wcs.ctype = [f"{coordstr[0]}-ZEA", f"{coordstr[1]}-ZEA"]
        wcs.wcs.crval = crval
    elif proj == "TAN":
        wcs.wcs.ctype = [f"{coordstr[0]}-TAN", f"{coordstr[1]}-TAN"]
        wcs.wcs.crval = crval
    elif proj == "SFL":
        wcs.wcs.ctype = [f"{coordstr[0]}-SFL", f"{coordstr[1]}-SFL"]
        wcs.wcs.crval = crval
    else:
        msg = f"Invalid WCS projection name '{proj}'"
        raise ValueError(msg)

    # Compute resolution.  Note that we negate the longitudinal
    # coordinate so that the resulting projections match expectations
    # for plotting, etc.
    if center_deg is not None:
        wcs.wcs.cdelt = np.array([-res_deg[0], res_deg[1]])
    else:
        if res_deg is not None:
            wcs.wcs.cdelt = np.array([-res_deg[0], res_deg[1]])
        else:
            # Compute CDELT from the bounding box and image size.
            wcs.wcs.cdelt = np.array(
                [
                    -(bounds_deg[1] - bounds_deg[0]) / dims[0],
                    (bounds_deg[3] - bounds_deg[2]) / dims[1],
                ]
            )

    # Compute shape of the projection
    if dims is not None:
        wcs_shape = tuple(dims)
    else:
        # Compute from the bounding box corners
        lower_left = wcs.wcs_world2pix(
            np.array([[bounds_deg[0], bounds_deg[2]]]), 0
        )[0]
        upper_right = wcs.wcs_world2pix(
            np.array([[bounds_deg[1], bounds_deg[3]]]), 0
        )[0]
        wcs_shape = tuple(np.round(np.abs(upper_right - lower_left)).astype(int))

    # Set the reference pixel to the center of the projection
    off = wcs.wcs_world2pix(crval.reshape((1, 2)), 0)[0]
    wcs.wcs.crpix = 0.5 * np.array(wcs_shape, dtype=np.float64) + 0.5 + off

    return wcs, wcs_shape

set_wcs()

Source code in toast/ops/pixels_wcs.py
def set_wcs(self):
    if self._done_wcs:
        return

    log = Logger.get()
    msg = f"PixelsWCS: set_wcs coord={self.coord_frame}, "
    msg += f"proj={self.projection}, center={self.center}, bounds={self.bounds}"
    msg += f", dims={self.dimensions}, res={self.resolution}"
    log.verbose(msg)

    center_deg = None
    if len(self.center) > 0:
        if self.center_offset is None:
            center_deg = (
                self.center[0].to_value(u.degree),
                self.center[1].to_value(u.degree),
            )
        else:
            center_deg = (0.0, 0.0)
    bounds_deg = None
    if len(self.bounds) > 0:
        bounds_deg = tuple([x.to_value(u.degree) for x in self.bounds])
    res_deg = None
    if len(self.resolution) > 0:
        res_deg = tuple([x.to_value(u.degree) for x in self.resolution])
    if len(self.dimensions) > 0:
        dims = tuple(self.dimensions)
    else:
        dims = None

    self.wcs, self.wcs_shape = self.create_wcs(
        coord=self.coord_frame,
        proj=self.projection,
        center_deg=center_deg,
        bounds_deg=bounds_deg,
        res_deg=res_deg,
        dims=dims,
    )

    self.pix_lon = self.wcs_shape[0]
    self.pix_lat = self.wcs_shape[1]
    self._n_pix = self.pix_lon * self.pix_lat
    self._n_pix_submap = self._n_pix // self.submaps
    if self._n_pix_submap * self.submaps < self._n_pix:
        self._n_pix_submap += 1
    self._local_submaps = np.zeros(self.submaps, dtype=np.uint8)
    self._done_wcs = True
    return

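set_wcs is idempotent: once _done_wcs is set it returns immediately, so repeated calls are cheap. A sketch of the derived attributes it populates, reusing the operator from the construction example:

pixels.set_wcs()
print(pixels.pix_lon, pixels.pix_lat)  # projection dimensions in pixels
print(pixels._n_pix_submap)            # pixels per submap, rounded up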
toast.ops.StokesWeights

Bases: Operator

Operator which generates I/Q/U pointing weights.

Given the individual detector pointing, this computes the pointing weights assuming that the detector is a linear polarizer followed by a total power measurement. By definition, the detector coordinate frame has the X-axis aligned with the polarization sensitive direction. An optional dictionary of pointing weight calibration factors may be specified for each observation.

If the hwp_angle field is specified, then an ideal HWP Mueller matrix is inserted in the optics chain before the linear polarizer. In this case, the fp_gamma key name must be specified and each detector must have a value in the focalplane table.

The timestream model without a HWP in COSMO convention is:

.. math:: d = cal \left[I + \frac{1 - \epsilon}{1 + \epsilon} \left[Q \cos\left(2\alpha\right) + U \sin\left(2\alpha\right) \right] \right]

When a HWP is present, we have:

.. math:: d = cal \left[I + \frac{1 - \epsilon}{1 + \epsilon} \left[Q \cos\left(2(\alpha - 2\omega) \right) - U \sin\left(2(\alpha - 2\omega) \right) \right] \right]

The detector orientation angle "alpha" in COSMO convention is measured in a right-handed sense from the local meridian and the HWP angle "omega" is also measured from the local meridian. The omega value can be described in terms of alpha, a fixed per-detector offset gamma, and the time varying HWP angle measured from the focalplane coordinate frame X-axis:

.. math:: \omega = \alpha + \gamma_{HWP}(t) - \gamma_{DET}

See documentation for a full treatment of this math.

By default, this operator uses the "COSMO" convention for Q/U. If the "IAU" trait is set to True, then resulting weights will differ by the sign of the U Stokes weight.

If the view trait is not specified, then this operator will use the same data view as the detector pointing operator when computing the pointing matrix pixels and weights.

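A numpy sketch of the no-HWP weight expression in the COSMO convention, evaluating the per-sample Stokes weights for hypothetical orientation angles alpha and cross-polar leakage epsilon (the operator's kernels compute the same quantities):

import numpy as np

alpha = np.radians([0.0, 45.0, 90.0])  # hypothetical detector angles
eps = 0.0                              # hypothetical cross-polar response
eta = (1.0 - eps) / (1.0 + eps)

w_I = np.ones_like(alpha)
w_Q = eta * np.cos(2 * alpha)
w_U = eta * np.sin(2 * alpha)  # sign of this weight flips under IAU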
Source code in toast/ops/stokes_weights/stokes_weights.py
@trait_docs
class StokesWeights(Operator):
    """Operator which generates I/Q/U pointing weights.

    Given the individual detector pointing, this computes the pointing weights
    assuming that the detector is a linear polarizer followed by a total
    power measurement.  By definition, the detector coordinate frame has the X-axis
    aligned with the polarization sensitive direction.  An optional dictionary of
    pointing weight calibration factors may be specified for each observation.

    If the hwp_angle field is specified, then an ideal HWP Mueller matrix is inserted
    in the optics chain before the linear polarizer.  In this case, the fp_gamma key
    name must be specified and each detector must have a value in the focalplane
    table.

    The timestream model without a HWP in COSMO convention is:

    .. math::
        d = cal \\left[I + \\frac{1 - \\epsilon}{1 + \\epsilon} \\left[Q \\cos\\left(2\\alpha\\right) + U \\sin\\left(2\\alpha\\right) \\right] \\right]

    When a HWP is present, we have:

    .. math::
        d = cal \\left[I + \\frac{1 - \\epsilon}{1 + \\epsilon} \\left[Q \\cos\\left(2(\\alpha - 2\\omega) \\right) - U \\sin\\left(2(\\alpha - 2\\omega) \\right) \\right] \\right]

    The detector orientation angle "alpha" in COSMO convention is measured in a
    right-handed sense from the local meridian and the HWP angle "omega" is also
    measured from the local meridian.  The omega value can be described in terms of
    alpha, a fixed per-detector offset gamma, and the time varying HWP angle measured
    from the focalplane coordinate frame X-axis:

    .. math::
        \\omega = \\alpha + {\\gamma}_{HWP}(t) - {\\gamma}_{DET}

    See documentation for a full treatment of this math.

    By default, this operator uses the "COSMO" convention for Q/U.  If the "IAU" trait
    is set to True, then resulting weights will differ by the sign of the U Stokes
    weight.

    If the view trait is not specified, then this operator will use the same data
    view as the detector pointing operator when computing the pointing matrix pixels
    and weights.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    detector_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="Operator that translates boresight pointing into detector frame",
    )

    mode = Unicode("I", help="The Stokes weights to generate (I or IQU)")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    hwp_angle = Unicode(
        None, allow_none=True, help="Observation shared key for HWP angle"
    )

    fp_gamma = Unicode(
        "gamma", allow_none=True, help="Focalplane key for detector gamma offset angle"
    )

    weights = Unicode(
        defaults.weights, help="Observation detdata key for output weights"
    )

    single_precision = Bool(False, help="If True, use 32bit float in output")

    cal = Unicode(
        None,
        allow_none=True,
        help="The observation key with a dictionary of pointing weight "
        "calibration for each det",
    )

    IAU = Bool(False, help="If True, use the IAU convention rather than COSMO")

    @traitlets.validate("detector_pointing")
    def _check_detector_pointing(self, proposal):
        detpointing = proposal["value"]
        if detpointing is not None:
            if not isinstance(detpointing, Operator):
                raise traitlets.TraitError(
                    "detector_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in [
                "view",
                "boresight",
                "shared_flags",
                "shared_flag_mask",
                "det_mask",
                "quats",
                "coord_in",
                "coord_out",
            ]:
                if not detpointing.has_trait(trt):
                    msg = f"detector_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return detpointing

    @traitlets.validate("mode")
    def _check_mode(self, proposal):
        check = proposal["value"]
        if check not in ["I", "IQU"]:
            raise traitlets.TraitError("Invalid mode (must be 'I' or 'IQU')")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        env = Environment.get()
        log = Logger.get()

        self._nnz = len(self.mode)

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        if self.detector_pointing is None:
            raise RuntimeError("The detector_pointing trait must be set")

        if ("QU" in self.mode) and self.hwp_angle is not None:
            if self.fp_gamma is None:
                raise RuntimeError("If using HWP, you must specify the fp_gamma key")

        # Expand detector pointing
        quats_name = self.detector_pointing.quats

        view = self.view
        if view is None:
            # Use the same data view as detector pointing
            view = self.detector_pointing.view

        # Expand detector pointing
        self.detector_pointing.apply(data, detectors=detectors, use_accel=use_accel)

        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                detectors, flagmask=self.detector_pointing.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            # Check that our view is fully covered by detector pointing.  If the
            # detector_pointing view is None, then it has all samples.  If our own
            # view was None, then it would have been set to the detector_pointing
            # view above.
            if (view is not None) and (self.detector_pointing.view is not None):
                if ob.intervals[view] != ob.intervals[self.detector_pointing.view]:
                    # We need to check intersection
                    intervals = ob.intervals[self.view]
                    detector_intervals = ob.intervals[self.detector_pointing.view]
                    intersection = detector_intervals & intervals
                    if intersection != intervals:
                        msg = (
                            f"view {self.view} is not fully covered by valid "
                            "detector pointing"
                        )
                        raise RuntimeError(msg)

            # Create (or re-use) output data for the weights
            if self.single_precision:
                exists = ob.detdata.ensure(
                    self.weights,
                    sample_shape=(self._nnz,),
                    dtype=np.float32,
                    detectors=dets,
                    accel=use_accel,
                )
            else:
                exists = ob.detdata.ensure(
                    self.weights,
                    sample_shape=(self._nnz,),
                    dtype=np.float64,
                    detectors=dets,
                    accel=use_accel,
                )

            quat_indx = ob.detdata[quats_name].indices(dets)
            weight_indx = ob.detdata[self.weights].indices(dets)

            # Do we already have pointing for all requested detectors?
            if exists:
                # Yes
                if data.comm.group_rank == 0:
                    msg = (
                        f"Group {data.comm.group}, ob {ob.name}, Stokes weights "
                        f"already computed for {dets}"
                    )
                    log.verbose(msg)
                continue

            # FIXME:  temporary hack until instrument classes are also pre-staged
            # to GPU
            focalplane = ob.telescope.focalplane
            det_epsilon = np.zeros(len(dets), dtype=np.float64)

            # Get the cross polar response from the focalplane
            if "pol_leakage" in focalplane.detector_data.colnames:
                for idet, d in enumerate(dets):
                    det_epsilon[idet] = focalplane[d]["pol_leakage"]

            # Get the per-detector calibration
            if self.cal is None:
                cal = np.array([1.0 for x in dets], np.float64)
            else:
                cal = np.array([ob[self.cal][x] for x in dets], np.float64)

            if self.mode == "IQU":
                det_gamma = np.zeros(len(dets), dtype=np.float64)
                if self.hwp_angle is None or self.hwp_angle not in ob.shared:
                    hwp_data = np.zeros(1, dtype=np.float64)
                else:
                    hwp_data = ob.shared[self.hwp_angle].data
                    for idet, d in enumerate(dets):
                        det_gamma[idet] = focalplane[d]["gamma"].to_value(u.rad)
                stokes_weights_IQU(
                    quat_indx,
                    ob.detdata[quats_name].data,
                    weight_indx,
                    ob.detdata[self.weights].data,
                    hwp_data,
                    ob.intervals[self.view].data,
                    det_epsilon,
                    det_gamma,
                    cal,
                    bool(self.IAU),
                    impl=implementation,
                    use_accel=use_accel,
                )
            else:
                stokes_weights_I(
                    weight_indx,
                    ob.detdata[self.weights].data,
                    ob.intervals[self.view].data,
                    cal,
                    impl=implementation,
                    use_accel=use_accel,
                )
        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = self.detector_pointing.requires()
        if "detdata" not in req:
            req["detdata"] = list()
        req["detdata"].append(self.weights)
        if self.cal is not None:
            req["meta"].append(self.cal)
        if self.hwp_angle is not None:
            req["shared"].append(self.hwp_angle)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = self.detector_pointing.provides()
        prov["detdata"].append(self.weights)
        return prov

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        if (self.detector_pointing is not None) and (
            self.detector_pointing.supports_accel()
        ):
            return True
        else:
            return False
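
As a usage illustration, the following minimal sketch wires StokesWeights to a
detector pointing operator and applies it to a toast.Data container.  It is
hypothetical: `data` is assumed to hold observations with boresight pointing and
an "hwp_angle" shared field, and PointingDetectorSimple is just one possible
detector pointing operator.

import toast.ops

# Hypothetical setup: `data` is a toast.Data container whose observations
# provide boresight quaternions and an "hwp_angle" shared field.
detpointing = toast.ops.PointingDetectorSimple()

stokes = toast.ops.StokesWeights(
    mode="IQU",              # use "I" for intensity-only weights
    detector_pointing=detpointing,
    hwp_angle="hwp_angle",   # leave None if there is no HWP in the optics chain
    fp_gamma="gamma",        # focalplane column with the per-detector gamma offset
)
stokes.apply(data)

# Each detector now has an (nsample, 3) array of I/Q/U weights:
#   ob.detdata[stokes.weights][det]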

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

IAU = Bool(False, help='If True, use the IAU convention rather than COSMO') class-attribute instance-attribute

cal = Unicode(None, allow_none=True, help='The observation key with a dictionary of pointing weight calibration for each det') class-attribute instance-attribute

detector_pointing = Instance(klass=Operator, allow_none=True, help='Operator that translates boresight pointing into detector frame') class-attribute instance-attribute

fp_gamma = Unicode('gamma', allow_none=True, help='Focalplane key for detector gamma offset angle') class-attribute instance-attribute

hwp_angle = Unicode(None, allow_none=True, help='Observation shared key for HWP angle') class-attribute instance-attribute

mode = Unicode('I', help='The Stokes weights to generate (I or IQU)') class-attribute instance-attribute

single_precision = Bool(False, help='If True, use 32bit float in output') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

weights = Unicode(defaults.weights, help='Observation detdata key for output weights') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/stokes_weights/stokes_weights.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_detector_pointing(proposal)

Source code in toast/ops/stokes_weights/stokes_weights.py
@traitlets.validate("detector_pointing")
def _check_detector_pointing(self, proposal):
    detpointing = proposal["value"]
    if detpointing is not None:
        if not isinstance(detpointing, Operator):
            raise traitlets.TraitError(
                "detector_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in [
            "view",
            "boresight",
            "shared_flags",
            "shared_flag_mask",
            "det_mask",
            "quats",
            "coord_in",
            "coord_out",
        ]:
            if not detpointing.has_trait(trt):
                msg = f"detector_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return detpointing

_check_mode(proposal)

Source code in toast/ops/stokes_weights/stokes_weights.py
@traitlets.validate("mode")
def _check_mode(self, proposal):
    check = proposal["value"]
    if check not in ["I", "IQU"]:
        raise traitlets.TraitError("Invalid mode (must be 'I' or 'IQU')")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/stokes_weights/stokes_weights.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    env = Environment.get()
    log = Logger.get()

    self._nnz = len(self.mode)

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    if self.detector_pointing is None:
        raise RuntimeError("The detector_pointing trait must be set")

    if ("QU" in self.mode) and self.hwp_angle is not None:
        if self.fp_gamma is None:
            raise RuntimeError("If using HWP, you must specify the fp_gamma key")

    # Expand detector pointing
    quats_name = self.detector_pointing.quats

    view = self.view
    if view is None:
        # Use the same data view as detector pointing
        view = self.detector_pointing.view

    # Expand detector pointing
    self.detector_pointing.apply(data, detectors=detectors, use_accel=use_accel)

    for ob in data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(
            detectors, flagmask=self.detector_pointing.det_mask
        )
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        # Check that our view is fully covered by detector pointing.  If the
        # detector_pointing view is None, then it has all samples.  If our own
        # view was None, then it would have been set to the detector_pointing
        # view above.
        if (view is not None) and (self.detector_pointing.view is not None):
            if ob.intervals[view] != ob.intervals[self.detector_pointing.view]:
                # We need to check intersection
                intervals = ob.intervals[self.view]
                detector_intervals = ob.intervals[self.detector_pointing.view]
                intersection = detector_intervals & intervals
                if intersection != intervals:
                    msg = (
                        f"view {self.view} is not fully covered by valid "
                        "detector pointing"
                    )
                    raise RuntimeError(msg)

        # Create (or re-use) output data for the weights
        if self.single_precision:
            exists = ob.detdata.ensure(
                self.weights,
                sample_shape=(self._nnz,),
                dtype=np.float32,
                detectors=dets,
                accel=use_accel,
            )
        else:
            exists = ob.detdata.ensure(
                self.weights,
                sample_shape=(self._nnz,),
                dtype=np.float64,
                detectors=dets,
                accel=use_accel,
            )

        quat_indx = ob.detdata[quats_name].indices(dets)
        weight_indx = ob.detdata[self.weights].indices(dets)

        # Do we already have pointing for all requested detectors?
        if exists:
            # Yes
            if data.comm.group_rank == 0:
                msg = (
                    f"Group {data.comm.group}, ob {ob.name}, Stokes weights "
                    f"already computed for {dets}"
                )
                log.verbose(msg)
            continue

        # FIXME:  temporary hack until instrument classes are also pre-staged
        # to GPU
        focalplane = ob.telescope.focalplane
        det_epsilon = np.zeros(len(dets), dtype=np.float64)

        # Get the cross polar response from the focalplane
        if "pol_leakage" in focalplane.detector_data.colnames:
            for idet, d in enumerate(dets):
                det_epsilon[idet] = focalplane[d]["pol_leakage"]

        # Get the per-detector calibration
        if self.cal is None:
            cal = np.array([1.0 for x in dets], np.float64)
        else:
            cal = np.array([ob[self.cal][x] for x in dets], np.float64)

        if self.mode == "IQU":
            det_gamma = np.zeros(len(dets), dtype=np.float64)
            if self.hwp_angle is None or self.hwp_angle not in ob.shared:
                hwp_data = np.zeros(1, dtype=np.float64)
            else:
                hwp_data = ob.shared[self.hwp_angle].data
                for idet, d in enumerate(dets):
                    det_gamma[idet] = focalplane[d]["gamma"].to_value(u.rad)
            stokes_weights_IQU(
                quat_indx,
                ob.detdata[quats_name].data,
                weight_indx,
                ob.detdata[self.weights].data,
                hwp_data,
                ob.intervals[self.view].data,
                det_epsilon,
                det_gamma,
                cal,
                bool(self.IAU),
                impl=implementation,
                use_accel=use_accel,
            )
        else:
            stokes_weights_I(
                weight_indx,
                ob.detdata[self.weights].data,
                ob.intervals[self.view].data,
                cal,
                impl=implementation,
                use_accel=use_accel,
            )
    return
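
For intuition, the pure-numpy sketch below evaluates the per-sample IQU weights
implied by the timestream model in the class docstring.  It is an illustrative
reimplementation under stated assumptions, not the actual `stokes_weights_IQU`
kernel (which also handles per-detector indexing, intervals, and accelerator
buffers).

import numpy as np

def toy_iqu_weights(alpha, epsilon, cal, omega=None, iau=False):
    """Toy Stokes weights from the docstring's timestream model."""
    eta = (1.0 - epsilon) / (1.0 + epsilon)  # polarization efficiency
    if omega is None:
        # No HWP:  d = cal * [I + eta * (Q cos 2a + U sin 2a)]
        wq = eta * np.cos(2.0 * alpha)
        wu = eta * np.sin(2.0 * alpha)
    else:
        # Ideal HWP:  d = cal * [I + eta * (Q cos 2(a - 2w) - U sin 2(a - 2w))]
        wq = eta * np.cos(2.0 * (alpha - 2.0 * omega))
        wu = -eta * np.sin(2.0 * (alpha - 2.0 * omega))
    if iau:
        wu = -wu  # IAU convention flips the sign of the U weight
    return cal * np.column_stack([np.ones_like(alpha), wq, wu])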

_finalize(data, **kwargs)

Source code in toast/ops/stokes_weights/stokes_weights.py
def _finalize(self, data, **kwargs):
    return

_implementations()

Source code in toast/ops/stokes_weights/stokes_weights.py
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_provides()

Source code in toast/ops/stokes_weights/stokes_weights.py
def _provides(self):
    prov = self.detector_pointing.provides()
    prov["detdata"].append(self.weights)
    return prov

_requires()

Source code in toast/ops/stokes_weights/stokes_weights.py
def _requires(self):
    req = self.detector_pointing.requires()
    if "detdata" not in req:
        req["detdata"] = list()
    req["detdata"].append(self.weights)
    if self.cal is not None:
        req["meta"].append(self.cal)
    if self.hwp_angle is not None:
        req["shared"].append(self.hwp_angle)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_supports_accel()

Source code in toast/ops/stokes_weights/stokes_weights.py
def _supports_accel(self):
    if (self.detector_pointing is not None) and (
        self.detector_pointing.supports_accel()
    ):
        return True
    else:
        return False

Scan Strategy Characterization

toast.ops.CadenceMap

Bases: Operator

Tabulate the days on which each map pixel is visited.

Source code in toast/ops/cadence_map.py
@trait_docs
class CadenceMap(Operator):
    """Tabulate which days each pixel on the map is visited."""

    # Class traits

    pixel_dist = Unicode(
        None,
        allow_none=True,
        help="The Data key containing the submap distribution",
    )

    pixel_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a pixel pointing operator.",
    )

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional telescope flagging",
    )

    output_dir = Unicode(
        ".",
        help="Write output data products to this directory",
    )

    save_pointing = Bool(False, help="If True, do not clear pixel numbers after use")

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("pixel_pointing")
    def _check_pixel_pointing(self, proposal):
        pntg = proposal["value"]
        if pntg is not None:
            if not isinstance(pntg, Operator):
                raise traitlets.TraitError(
                    "pixel_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["pixels", "create_dist", "view"]:
                if not pntg.has_trait(trt):
                    msg = f"pixel_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return pntg

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        for trait in "pixel_pointing", "pixel_dist":
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        comm = data.comm.comm_world
        rank = data.comm.world_rank

        # We need the pixel distribution to know total number of pixels

        if self.pixel_dist not in data:
            pix_dist = BuildPixelDistribution(
                pixel_dist=self.pixel_dist,
                pixel_pointing=self.pixel_pointing,
                shared_flags=self.shared_flags,
                shared_flag_mask=self.shared_flag_mask,
                save_pointing=self.save_pointing,
            )
            log.info_rank("Caching pixel distribution", comm=data.comm.comm_world)
            pix_dist.apply(data)

        npix = data[self.pixel_dist].n_pix

        if rank == 0:
            os.makedirs(self.output_dir, exist_ok=True)

        # determine the number of modified Julian days

        tmin = 1e30
        tmax = -1e30
        for obs in data.obs:
            times = obs.shared[self.times].data
            tmin = min(tmin, times[0])
            tmax = max(tmax, times[-1])

        if comm is not None:
            tmin = comm.allreduce(tmin, MPI.MIN)
            tmax = comm.allreduce(tmax, MPI.MAX)

        MJD_start = int(to_MJD(tmin))
        MJD_stop = int(to_MJD(tmax)) + 1
        nday = MJD_stop - MJD_start

        # Flag all pixels that are observed on each MJD

        if rank == 0:
            all_hit = np.zeros([nday, npix], dtype=bool)

        buflen = 10  # Number of days to process at once
        # FIXME : We should use `buflen` also for the HDF5 dataset size
        buf = np.zeros([buflen, npix], dtype=bool)
        day_start = MJD_start
        while day_start < MJD_stop:
            day_stop = min(MJD_stop, day_start + buflen)
            if rank == 0:
                log.debug(
                    f"Processing {MJD_start} <= {day_start} - {day_stop} <= {MJD_stop}"
                )
            buf[:, :] = False
            for obs in data.obs:
                obs_data = data.select(obs_uid=obs.uid)
                dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
                times = obs.shared[self.times].data
                days = to_MJD(times).astype(int)
                if days[0] >= day_stop or days[-1] < day_start:
                    continue
                if self.shared_flags:
                    cflag = (
                        obs.shared[self.shared_flags].data & self.shared_flag_mask
                    ) != 0
                for day in range(day_start, day_stop):
                    # Find samples that were collected on target day ...
                    good = days == day
                    if not np.any(good):
                        continue
                    if self.shared_flags:
                        # ... and are not flagged ...
                        good[cflag] = False
                    for det in dets:
                        if self.det_flags:
                            # ... even by detector flags
                            flag = obs.detdata[self.det_flags][det] & self.det_flag_mask
                            mask = np.logical_and(good, flag == 0)
                        else:
                            mask = good
                        # Compute pixel numbers.  Will do nothing if they already exist.
                        self.pixel_pointing.apply(obs_data, detectors=[det])
                        # Flag the hit pixels
                        pixels = obs.detdata[self.pixel_pointing.pixels][det]
                        mask[pixels < 0] = False
                        buf[day - day_start][pixels[mask]] = True
            if comm is not None:
                comm.Allreduce(MPI.IN_PLACE, buf, op=MPI.LOR)
            if rank == 0:
                for i in range(day_start, day_stop):
                    all_hit[i - MJD_start] = buf[i - day_start]
            day_start = day_stop

        if rank == 0:
            fname = os.path.join(self.output_dir, f"{self.name}.h5")
            with h5py.File(fname, "w") as f:
                dset = f.create_dataset("cadence", data=all_hit)
                dset.attrs["MJDSTART"] = MJD_start
                dset.attrs["MJDSTOP"] = MJD_stop
                dset.attrs["NESTED"] = self.pixel_pointing.nest
            log.info(f"Wrote cadence map to {fname}.")

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = self.pixel_pointing.requires()
        req["shared"].append(self.times)
        req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        return {}
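
A minimal, hypothetical usage sketch follows; the pixelization operator, nside,
and directory names are placeholders, and the output file name derives from the
operator's `name` trait as in `_exec` above.

import h5py
import toast.ops

detpointing = toast.ops.PointingDetectorSimple()
pixels = toast.ops.PixelsHealpix(nside=64, detector_pointing=detpointing)

cadence = toast.ops.CadenceMap(
    name="cadence_map",
    pixel_pointing=pixels,
    pixel_dist="pixel_dist",
    output_dir="out",
)
cadence.apply(data)

# The result is one boolean row of hit pixels per MJD:
with h5py.File("out/cadence_map.h5", "r") as f:
    hits = f["cadence"][:]                      # shape (nday, npix)
    mjd_start = f["cadence"].attrs["MJDSTART"]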

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

output_dir = Unicode('.', help='Write output data products to this directory') class-attribute instance-attribute

pixel_dist = Unicode(None, allow_none=True, help='The Data key containing the submap distribution') class-attribute instance-attribute

pixel_pointing = Instance(klass=Operator, allow_none=True, help='This must be an instance of a pixel pointing operator.') class-attribute instance-attribute

save_pointing = Bool(False, help='If True, do not clear pixel numbers after use') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional telescope flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/cadence_map.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_flag_mask(proposal)

Source code in toast/ops/cadence_map.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/cadence_map.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_pixel_pointing(proposal)

Source code in toast/ops/cadence_map.py
@traitlets.validate("pixel_pointing")
def _check_pixel_pointing(self, proposal):
    pntg = proposal["value"]
    if pntg is not None:
        if not isinstance(pntg, Operator):
            raise traitlets.TraitError(
                "pixel_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["pixels", "create_dist", "view"]:
            if not pntg.has_trait(trt):
                msg = f"pixel_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return pntg

_check_shared_mask(proposal)

Source code in toast/ops/cadence_map.py
@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/cadence_map.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    for trait in "pixel_pointing", "pixel_dist":
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    comm = data.comm.comm_world
    rank = data.comm.world_rank

    # We need the pixel distribution to know total number of pixels

    if self.pixel_dist not in data:
        pix_dist = BuildPixelDistribution(
            pixel_dist=self.pixel_dist,
            pixel_pointing=self.pixel_pointing,
            shared_flags=self.shared_flags,
            shared_flag_mask=self.shared_flag_mask,
            save_pointing=self.save_pointing,
        )
        log.info_rank("Caching pixel distribution", comm=data.comm.comm_world)
        pix_dist.apply(data)

    npix = data[self.pixel_dist].n_pix

    if rank == 0:
        os.makedirs(self.output_dir, exist_ok=True)

    # determine the number of modified Julian days

    tmin = 1e30
    tmax = -1e30
    for obs in data.obs:
        times = obs.shared[self.times].data
        tmin = min(tmin, times[0])
        tmax = max(tmax, times[-1])

    if comm is not None:
        tmin = comm.allreduce(tmin, MPI.MIN)
        tmax = comm.allreduce(tmax, MPI.MAX)

    MJD_start = int(to_MJD(tmin))
    MJD_stop = int(to_MJD(tmax)) + 1
    nday = MJD_stop - MJD_start

    # Flag all pixels that are observed on each MJD

    if rank == 0:
        all_hit = np.zeros([nday, npix], dtype=bool)

    buflen = 10  # Number of days to process at once
    # FIXME : We should use `buflen` also for the HDF5 dataset size
    buf = np.zeros([buflen, npix], dtype=bool)
    day_start = MJD_start
    while day_start < MJD_stop:
        day_stop = min(MJD_stop, day_start + buflen)
        if rank == 0:
            log.debug(
                f"Processing {MJD_start} <= {day_start} - {day_stop} <= {MJD_stop}"
            )
        buf[:, :] = False
        for obs in data.obs:
            obs_data = data.select(obs_uid=obs.uid)
            dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
            times = obs.shared[self.times].data
            days = to_MJD(times).astype(int)
            if days[0] >= day_stop or days[-1] < day_start:
                continue
            if self.shared_flags:
                cflag = (
                    obs.shared[self.shared_flags].data & self.shared_flag_mask
                ) != 0
            for day in range(day_start, day_stop):
                # Find samples that were collected on target day ...
                good = days == day
                if not np.any(good):
                    continue
                if self.shared_flags:
                    # ... and are not flagged ...
                    good[cflag] = False
                for det in dets:
                    if self.det_flags:
                        # ... even by detector flags
                        flag = obs.detdata[self.det_flags][det] & self.det_flag_mask
                        mask = np.logical_and(good, flag == 0)
                    else:
                        mask = good
                    # Compute pixel numbers.  Will do nothing if they already exist.
                    self.pixel_pointing.apply(obs_data, detectors=[det])
                    # Flag the hit pixels
                    pixels = obs.detdata[self.pixel_pointing.pixels][det]
                    mask[pixels < 0] = False
                    buf[day - day_start][pixels[mask]] = True
        if comm is not None:
            comm.Allreduce(MPI.IN_PLACE, buf, op=MPI.LOR)
        if rank == 0:
            for i in range(day_start, day_stop):
                all_hit[i - MJD_start] = buf[i - day_start]
        day_start = day_stop

    if rank == 0:
        fname = os.path.join(self.output_dir, f"{self.name}.h5")
        with h5py.File(fname, "w") as f:
            dset = f.create_dataset("cadence", data=all_hit)
            dset.attrs["MJDSTART"] = MJD_start
            dset.attrs["MJDSTOP"] = MJD_stop
            dset.attrs["NESTED"] = self.pixel_pointing.nest
        log.info(f"Wrote cadence map to {fname}.")

_finalize(data, **kwargs)

Source code in toast/ops/cadence_map.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/cadence_map.py
def _provides(self):
    return {}

_requires()

Source code in toast/ops/cadence_map.py
def _requires(self):
    req = self.pixel_pointing.requires()
    req["shared"].append(self.times)
    req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

toast.ops.CrossLinking

Bases: Operator

Evaluate an ACT-style crosslinking map

Source code in toast/ops/crosslinking.py
@trait_docs
class CrossLinking(Operator):
    """Evaluate an ACT-style crosslinking map"""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    pixel_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a pixel pointing operator.",
    )

    pixel_dist = Unicode(
        "pixel_dist",
        help="The Data key where the PixelDist object should be stored",
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_data_units = Unit(
        defaults.det_data_units, help="Output units if creating detector data"
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional telescope flagging",
    )

    output_dir = Unicode(
        ".",
        help="Write output data products to this directory",
    )

    sync_type = Unicode(
        "alltoallv", help="Communication algorithm: 'allreduce' or 'alltoallv'"
    )

    save_pointing = Bool(False, help="If True, do not clear pixel numbers after use")

    # FIXME: these should be made into traits and also placed in _provides().

    signal = "dummy_signal"
    weights = "crosslinking_weights"
    crosslinking_map = "crosslinking_map"
    noise_model = "uniform_noise_weights"

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("pixel_pointing")
    def _check_pixel_pointing(self, proposal):
        pntg = proposal["value"]
        if pntg is not None:
            if not isinstance(pntg, Operator):
                raise traitlets.TraitError(
                    "pixel_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["pixels", "create_dist", "view"]:
                if not pntg.has_trait(trt):
                    msg = f"pixel_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return pntg

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _get_weights(self, obs_data, det):
        """Evaluate the special pointing matrix"""

        obs = obs_data.obs[0]
        exists_signal = obs.detdata.ensure(
            self.signal, detectors=[det], create_units=self.det_data_units
        )
        exists_weights = obs.detdata.ensure(
            self.weights, sample_shape=(3,), detectors=[det]
        )

        signal = obs.detdata[self.signal][det]
        signal[:] = 1
        weights = obs.detdata[self.weights][det]
        # Compute the detector quaternions
        self.pixel_pointing.detector_pointing.apply(obs_data, detectors=[det])
        quat = obs.detdata[self.pixel_pointing.detector_pointing.quats][det]
        # measure the scan direction wrt the local meridian for each sample
        theta, phi, _ = qa.to_iso_angles(quat)
        theta = np.pi / 2 - theta
        # scan direction across the reference sample
        dphi = np.roll(phi, -1) - np.roll(phi, 1)
        dtheta = np.roll(theta, -1) - np.roll(theta, 1)
        # except first and last sample
        for dx, x in (dphi, phi), (dtheta, theta):
            dx[0] = x[1] - x[0]
            dx[-1] = x[-1] - x[-2]
        # scale dphi to on-sky
        dphi *= np.cos(theta)
        # Avoid overflows
        tiny = np.abs(dphi) < 1e-30
        if np.any(tiny):
            ang = np.zeros(signal.size)
            ang[tiny] = np.sign(dtheta) * np.sign(dphi) * np.pi / 2
            not_tiny = np.logical_not(tiny)
            ang[not_tiny] = np.arctan(dtheta[not_tiny] / dphi[not_tiny])
        else:
            ang = np.arctan(dtheta / dphi)

        weights[:] = np.vstack(
            [np.ones(signal.size), np.cos(2 * ang), np.sin(2 * ang)]
        ).T

        return

    def _purge_weights(self, obs):
        """Discard special pointing matrix and dummy signal"""
        del obs.detdata[self.signal]
        del obs.detdata[self.weights]
        return

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        for trait in "pixel_pointing", "pixel_dist":
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        if data.comm.world_rank == 0:
            os.makedirs(self.output_dir, exist_ok=True)

        # Establish uniform noise weights
        noise_model = UniformNoise()
        for obs in data.obs:
            obs[self.noise_model] = noise_model

        # To accumulate, we need the pixel distribution.

        if self.pixel_dist not in data:
            pix_dist = BuildPixelDistribution(
                pixel_dist=self.pixel_dist,
                pixel_pointing=self.pixel_pointing,
                shared_flags=self.shared_flags,
                shared_flag_mask=self.shared_flag_mask,
                save_pointing=self.save_pointing,
            )
            log.info_rank("Caching pixel distribution", comm=data.comm.comm_world)
            pix_dist.apply(data)

        # Accumulation operator

        build_zmap = BuildNoiseWeighted(
            pixel_dist=self.pixel_dist,
            zmap=self.crosslinking_map,
            view=self.pixel_pointing.view,
            pixels=self.pixel_pointing.pixels,
            weights=self.weights,
            noise_model=self.noise_model,
            det_data=self.signal,
            det_data_units=self.det_data_units,
            det_flags=self.det_flags,
            det_flag_mask=self.det_flag_mask,
            shared_flags=self.shared_flags,
            shared_flag_mask=self.shared_flag_mask,
            sync_type=self.sync_type,
        )

        for obs in data.obs:
            obs_data = data.select(obs_uid=obs.uid)
            dets = obs.select_local_detectors(detectors, flagmask=self.det_flag_mask)
            for det in dets:
                # Pointing weights
                self._get_weights(obs_data, det)
                # Pixel numbers
                self.pixel_pointing.apply(obs_data, detectors=[det])
                # Accumulate
                build_zmap.exec(obs_data, detectors=[det])

        build_zmap.finalize(data)

        # Write out the results

        fname = os.path.join(self.output_dir, f"{self.name}.fits")
        write_healpix_fits(
            data[self.crosslinking_map], fname, nest=self.pixel_pointing.nest
        )
        log.info_rank(f"Wrote crosslinking to {fname}", comm=data.comm.comm_world)
        data[self.crosslinking_map].clear()
        del data[self.crosslinking_map]

        for obs in data.obs:
            self._purge_weights(obs)

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = self.pixel_pointing.detector_pointing.requires()
        return req

    def _provides(self):
        return {}
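
As with CadenceMap, a hypothetical configuration sketch; the pixelization
operator and output directory are placeholders, and the FITS file name follows
from the `name` trait as in `_exec`.

import toast.ops

detpointing = toast.ops.PointingDetectorSimple()
pixels = toast.ops.PixelsHealpix(nside=64, detector_pointing=detpointing)

crosslinking = toast.ops.CrossLinking(
    name="crosslinking",
    pixel_pointing=pixels,
    output_dir="out",
)
crosslinking.apply(data)
# Writes out/crosslinking.fits: a noise-weighted (1, cos 2psi, sin 2psi) map,
# where psi is the scan direction angle at each sample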

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

crosslinking_map = 'crosslinking_map' class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Output units if creating detector data') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

noise_model = 'uniform_noise_weights' class-attribute instance-attribute

output_dir = Unicode('.', help='Write output data products to this directory') class-attribute instance-attribute

pixel_dist = Unicode('pixel_dist', help='The Data key where the PixelDist object should be stored') class-attribute instance-attribute

pixel_pointing = Instance(klass=Operator, allow_none=True, help='This must be an instance of a pixel pointing operator.') class-attribute instance-attribute

save_pointing = Bool(False, help='If True, do not clear pixel numbers after use') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional telescope flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

signal = 'dummy_signal' class-attribute instance-attribute

sync_type = Unicode('alltoallv', help="Communication algorithm: 'allreduce' or 'alltoallv'") class-attribute instance-attribute

weights = 'crosslinking_weights' class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/crosslinking.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_flag_mask(proposal)

Source code in toast/ops/crosslinking.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/crosslinking.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_pixel_pointing(proposal)

Source code in toast/ops/crosslinking.py
@traitlets.validate("pixel_pointing")
def _check_pixel_pointing(self, proposal):
    pntg = proposal["value"]
    if pntg is not None:
        if not isinstance(pntg, Operator):
            raise traitlets.TraitError(
                "pixel_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["pixels", "create_dist", "view"]:
            if not pntg.has_trait(trt):
                msg = f"pixel_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return pntg

_check_shared_mask(proposal)

Source code in toast/ops/crosslinking.py
@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/crosslinking.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    for trait in "pixel_pointing", "pixel_dist":
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    if data.comm.world_rank == 0:
        os.makedirs(self.output_dir, exist_ok=True)

    # Establish uniform noise weights
    noise_model = UniformNoise()
    for obs in data.obs:
        obs[self.noise_model] = noise_model

    # To accumulate, we need the pixel distribution.

    if self.pixel_dist not in data:
        pix_dist = BuildPixelDistribution(
            pixel_dist=self.pixel_dist,
            pixel_pointing=self.pixel_pointing,
            shared_flags=self.shared_flags,
            shared_flag_mask=self.shared_flag_mask,
            save_pointing=self.save_pointing,
        )
        log.info_rank("Caching pixel distribution", comm=data.comm.comm_world)
        pix_dist.apply(data)

    # Accumulation operator

    build_zmap = BuildNoiseWeighted(
        pixel_dist=self.pixel_dist,
        zmap=self.crosslinking_map,
        view=self.pixel_pointing.view,
        pixels=self.pixel_pointing.pixels,
        weights=self.weights,
        noise_model=self.noise_model,
        det_data=self.signal,
        det_data_units=self.det_data_units,
        det_flags=self.det_flags,
        det_flag_mask=self.det_flag_mask,
        shared_flags=self.shared_flags,
        shared_flag_mask=self.shared_flag_mask,
        sync_type=self.sync_type,
    )

    for obs in data.obs:
        obs_data = data.select(obs_uid=obs.uid)
        dets = obs.select_local_detectors(detectors, flagmask=self.det_flag_mask)
        for det in dets:
            # Pointing weights
            self._get_weights(obs_data, det)
            # Pixel numbers
            self.pixel_pointing.apply(obs_data, detectors=[det])
            # Accumulate
            build_zmap.exec(obs_data, detectors=[det])

    build_zmap.finalize(data)

    # Write out the results

    fname = os.path.join(self.output_dir, f"{self.name}.fits")
    write_healpix_fits(
        data[self.crosslinking_map], fname, nest=self.pixel_pointing.nest
    )
    log.info_rank(f"Wrote crosslinking to {fname}", comm=data.comm.comm_world)
    data[self.crosslinking_map].clear()
    del data[self.crosslinking_map]

    for obs in data.obs:
        self._purge_weights(obs)

    return

_finalize(data, **kwargs)

Source code in toast/ops/crosslinking.py
def _finalize(self, data, **kwargs):
    return

_get_weights(obs_data, det)

Evaluate the special pointing matrix

Source code in toast/ops/crosslinking.py
@function_timer
def _get_weights(self, obs_data, det):
    """Evaluate the special pointing matrix"""

    obs = obs_data.obs[0]
    exists_signal = obs.detdata.ensure(
        self.signal, detectors=[det], create_units=self.det_data_units
    )
    exists_weights = obs.detdata.ensure(
        self.weights, sample_shape=(3,), detectors=[det]
    )

    signal = obs.detdata[self.signal][det]
    signal[:] = 1
    weights = obs.detdata[self.weights][det]
    # Compute the detector quaternions
    self.pixel_pointing.detector_pointing.apply(obs_data, detectors=[det])
    quat = obs.detdata[self.pixel_pointing.detector_pointing.quats][det]
    # measure the scan direction wrt the local meridian for each sample
    theta, phi, _ = qa.to_iso_angles(quat)
    theta = np.pi / 2 - theta
    # scan direction across the reference sample
    dphi = np.roll(phi, -1) - np.roll(phi, 1)
    dtheta = np.roll(theta, -1) - np.roll(theta, 1)
    # except first and last sample
    for dx, x in (dphi, phi), (dtheta, theta):
        dx[0] = x[1] - x[0]
        dx[-1] = x[-1] - x[-2]
    # scale dphi to on-sky
    dphi *= np.cos(theta)
    # Avoid overflows
    tiny = np.abs(dphi) < 1e-30
    if np.any(tiny):
        ang = np.zeros(signal.size)
        ang[tiny] = np.sign(dtheta) * np.sign(dphi) * np.pi / 2
        not_tiny = np.logical_not(tiny)
        ang[not_tiny] = np.arctan(dtheta[not_tiny] / dphi[not_tiny])
    else:
        ang = np.arctan(dtheta / dphi)

    weights[:] = np.vstack(
        [np.ones(signal.size), np.cos(2 * ang), np.sin(2 * ang)]
    ).T

    return
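
The core of `_get_weights` is the scan direction angle: centered differences of
the pointing give the direction of motion on the sky, and the pseudo-Stokes
weights are (1, cos 2psi, sin 2psi) of that angle.  The following self-contained
numpy sketch reproduces the same geometry on a toy pointing track; it uses
np.gradient and arctan2 as simplifications of the explicit roll-based
differences and the tiny-dphi guard above.

import numpy as np

# Toy constant-elevation scan: phi is azimuth, theta is ISO colatitude
nsamp = 100
phi = np.linspace(0.0, 0.1, nsamp)
theta = np.full(nsamp, np.pi / 3)

lat = np.pi / 2 - theta                 # latitude, as in _get_weights
dphi = np.gradient(phi) * np.cos(lat)   # scale the azimuth step to on-sky arc
dlat = np.gradient(lat)
ang = np.arctan2(dlat, dphi)            # scan angle (pi ambiguity drops out of 2*ang)

weights = np.column_stack([np.ones(nsamp), np.cos(2 * ang), np.sin(2 * ang)])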

_provides()

Source code in toast/ops/crosslinking.py
def _provides(self):
    return {}

_purge_weights(obs)

Discard special pointing matrix and dummy signal

Source code in toast/ops/crosslinking.py
def _purge_weights(self, obs):
    """Discard special pointing matrix and dummy signal"""
    del obs.detdata[self.signal]
    del obs.detdata[self.weights]
    return

_requires()

Source code in toast/ops/crosslinking.py
def _requires(self):
    req = self.pixel_pointing.detector_pointing.requires()
    return req

Noise Estimation

toast.ops.NoiseEstim

Bases: Operator

Noise estimation operator

Source code in toast/ops/noise_estimation.py
@trait_docs
class NoiseEstim(Operator):
    """Noise estimation operator.

    This operator estimates the noise auto-spectra of individual detectors
    (and optionally cross-spectra of detector pairs) from the time-ordered
    data, optionally writes the PSDs to FITS files, and can store the
    result as a noise model in each observation.
    """

    API = Int(0, help="Internal interface version for this operator")

    times = Unicode(
        defaults.times,
        help="Observation shared key for timestamps",
    )

    detector_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="Operator that translates boresight pointing into detector frame.  "
        "Only relevant if `maskfile` and/or `mapfile` are set",
    )

    pixel_dist = Unicode(
        "pixel_dist",
        help="The Data key where the PixelDistribution object is located.  "
        "Only relevant if `maskfile` and/or `mapfile` are set",
    )

    pixel_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="An instance of a pixel pointing operator.  "
        "Only relevant if `maskfile` and/or `mapfile` are set",
    )

    stokes_weights = Instance(
        klass=Operator,
        allow_none=True,
        help="An instance of a Stokes weights operator.  "
        "Only relevant if `mapfile` is set",
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key to apply filtering to",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    mask_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for processing mask flags",
    )

    mask_flag_mask = Int(
        defaults.det_mask_processing, help="Bit mask for raising processing mask flags"
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional shared flagging",
    )

    out_model = Unicode(
        None, allow_none=True, help="Create a new noise model with this name"
    )

    output_dir = Unicode(
        None,
        allow_none=True,
        help="If specified, write output data products to this directory",
    )

    maskfile = Unicode(
        None,
        allow_none=True,
        help="Optional HEALPix processing mask",
    )

    mapfile = Unicode(
        None,
        allow_none=True,
        help="Optional HEALPix map to sample and subtract from the signal",
    )

    pol = Bool(True, help="Sample also the polarized part of the map")

    save_cov = Bool(False, help="Save also the sample covariance")

    symmetric = Bool(
        False,
        help="If True, treat positive and negative lags as equivalent "
        "in the cross correlator",
    )

    nbin_psd = Int(1000, allow_none=True, help="Number of logarithmic bins for the output PSD")

    lagmax = Int(
        10000,
        help="Maximum lag to consider for the covariance function. "
        "Will be truncated to the length of the longest view.",
    )

    stationary_period = Quantity(
        86400 * u.s,
        help="Break the observation into several estimation periods of this length",
    )

    nosingle = Bool(
        False, help="Do not evaluate individual PSDs.  Overridden by `pairs`"
    )

    nocross = Bool(True, help="Do not evaluate cross-PSDs.  Overridden by `pairs`")

    nsum = Int(1, help="Downsampling factor for decimated data")

    naverage = Int(100, help="Smoothing kernel width for downsampled data")

    view = Unicode(
        None, allow_none=True, help="Only measure the covariance within each view"
    )

    pairs = List(
        [],
        help="Detector pairs to estimate noise for.  Overrides `nosingle` and `nocross`",
    )

    focalplane_key = Unicode(
        None, allow_none=True, help="When set, PSDs are measured over averaged TODs"
    )

    remove_common_mode = Bool(False, help="Remove common mode signal before estimation")

    @traitlets.validate("detector_pointing")
    def _check_detector_pointing(self, proposal):
        detpointing = proposal["value"]
        if detpointing is not None:
            if not isinstance(detpointing, Operator):
                raise traitlets.TraitError(
                    "detector_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in [
                "view",
                "boresight",
                "det_mask",
                "shared_flags",
                "shared_flag_mask",
                "quats",
                "coord_in",
                "coord_out",
            ]:
                if not detpointing.has_trait(trt):
                    msg = f"detector_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return detpointing

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("nbin_psd")
    def _check_nbin_psd(self, proposal):
        check = proposal["value"]
        if check is not None and check <= 1:
            raise traitlets.TraitError("Number of PSD bins should be greater than one")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        return

    @function_timer
    def _redistribute(self, obs):
        log = Logger.get()
        timer = Timer()
        timer.start()
        if (len(self.pairs) > 0 or (not self.nocross)) and (
            obs.comm_col is not None and obs.comm_col.size > 1
        ):
            self.redistribute = True
            # Redistribute the data so each process has all detectors
            # for some sample range
            # Duplicate just the fields of the observation we will use
            dup_shared = [self.times]
            if self.shared_flags is not None:
                dup_shared.append(self.shared_flags)
            dup_detdata = [self.det_data]
            if self.det_flags is not None:
                dup_detdata.append(self.det_flags)
            dup_intervals = list()
            if self.view is not None:
                dup_intervals.append(self.view)
            temp_obs = obs.duplicate(
                times=self.times,
                meta=list(),
                shared=dup_shared,
                detdata=dup_detdata,
                intervals=dup_intervals,
            )
            log.debug_rank(
                f"{obs.comm.group:4} : Duplicated observation in",
                comm=temp_obs.comm.comm_group,
                timer=timer,
            )
            # Redistribute this temporary observation to be distributed by samples
            global_intervals = temp_obs.redistribute(
                1,
                times=self.times,
                override_sample_sets=None,
                return_global_intervals=True,
            )
            if self.view is not None:
                global_intervals = global_intervals[self.view]
            log.debug_rank(
                f"{obs.comm.group:4} : Redistributed observation in",
                comm=temp_obs.comm.comm_group,
                timer=timer,
            )
        else:
            self.redistribute = False
            temp_obs = obs
            global_intervals = []
            if self.view is not None:
                for ival in obs.intervals[self.view]:
                    global_intervals.append((ival.start, ival.stop))
        if self.view is None:
            global_intervals = [(None, None)]

        return temp_obs, global_intervals

    @function_timer
    def _re_redistribute(self, obs, temp_obs):
        log = Logger.get()
        timer = Timer()
        timer.start()
        if self.redistribute:
            # Redistribute data back
            temp_obs.redistribute(
                obs.dist.process_rows,
                times=self.times,
                override_sample_sets=obs.dist.sample_sets,
            )
            log.debug_rank(
                f"{temp_obs.comm.group:4} : Re-redistributed observation in",
                comm=temp_obs.comm.comm_group,
                timer=timer,
            )
            # Copy data to original observation
            obs.detdata[self.det_data][:] = temp_obs.detdata[self.det_data][:]
            log.debug_rank(
                f"{temp_obs.comm.group:4} : Copied observation data in",
                comm=temp_obs.comm.comm_group,
                timer=timer,
            )
            if self.out_model is not None:
                obs[self.out_model] = temp_obs[self.out_model]
            self.redistribute = False
        return

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        if detectors is not None:
            msg = "NoiseEstim cannot be run with subsets of detectors"
            raise RuntimeError(msg)

        log = Logger.get()

        if self.focalplane_key is not None:
            if len(self.pairs) > 0:
                msg = "focalplane_key is not compatible with pairs"
                raise RuntimeError(msg)
            if self.remove_common_mode:
                # Measure and subtract the common mode signal across the focalplane.
                Copy(detdata=[(self.det_data, "temp_signal")]).apply(data)
                CommonModeFilter(
                    det_data="temp_signal",
                    det_mask=self.det_mask,
                    det_flags=self.det_flags,
                    det_flag_mask=self.det_flag_mask,
                    focalplane_key=self.focalplane_key,
                ).apply(data)
                Combine(
                    op="subtract",
                    first=self.det_data,
                    second="temp_signal",
                    output=self.det_data,
                ).apply(data)
                Delete(detdata="temp_signal")

        if self.mapfile is not None:
            if self.pol:
                weights = self.stokes_weights
            else:
                weights = None
            scan_map = ScanHealpixMap(
                file=self.mapfile,
                det_data=self.det_data,
                det_mask=self.det_mask,
                subtract=True,
                pixel_dist=self.pixel_dist,
                pixel_pointing=self.pixel_pointing,
                stokes_weights=weights,
            )
            scan_map.apply(data, detectors=detectors)

        if self.maskfile is not None:
            scan_mask = ScanHealpixMask(
                file=self.maskfile,
                det_mask=self.det_mask,
                det_flags=self.mask_flags,
                det_flags_value=self.mask_flag_mask,
                pixel_dist=self.pixel_dist,
                pixel_pointing=self.pixel_pointing,
            )
            scan_mask.apply(data, detectors=detectors)

        for orig_obs in data.obs:
            # Optionally redistribute data, but only if we are computing
            # cross spectra.
            obs, global_intervals = self._redistribute(orig_obs)

            # Get the set of all local detectors we are considering for this obs.
            local_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
            good_dets = set(local_dets)

            if self.focalplane_key is not None:
                # Pick just one detector to represent each key value
                fp = obs.telescope.focalplane
                det_names = []
                key2det = {}
                det2key = {}
                for det in local_dets:
                    key = fp[det][self.focalplane_key]
                    if key not in key2det:
                        det_names.append(det)
                        key2det[key] = det
                        det2key[det] = key
                pairs = []
                for det1 in key2det.values():
                    for det2 in key2det.values():
                        if det1 == det2 and self.nosingle:
                            continue
                        if det1 != det2 and self.nocross:
                            continue
                        pairs.append([det1, det2])
            else:
                det2key = None
                det_names = obs.all_detectors
                ndet = len(det_names)
                if len(self.pairs) > 0:
                    pairs = self.pairs
                else:
                    # Construct a list of detector pairs
                    pairs = []
                    for idet1 in range(ndet):
                        det1 = det_names[idet1]
                        for idet2 in range(idet1, ndet):
                            det2 = det_names[idet2]
                            if det1 == det2 and self.nosingle:
                                continue
                            if det1 != det2 and self.nocross:
                                continue
                            pairs.append([det1, det2])

            if self.symmetric:
                # Remove duplicate entries in pair list
                unordered_pairs = set()
                for pair in pairs:
                    unordered_pairs.add(tuple(sorted(pair)))
                pairs = list(unordered_pairs)

            times = np.array(obs.shared[self.times])
            nsample = times.size

            shared_flags = np.zeros(times.size, dtype=bool)
            if self.shared_flags is not None:
                shared_flags[:] = (
                    obs.shared[self.shared_flags].data & self.shared_flag_mask
                ) != 0

            fsample = obs.telescope.focalplane.sample_rate.to_value(u.Hz)

            fileroot = f"{self.name}_{obs.name}"

            # Re-use this flag array
            flags = np.zeros(times.size, dtype=bool)

            noise_dets = list()
            noise_freqs = dict()
            noise_psds = dict()
            noise_indices = dict()

            for det1, det2 in pairs:
                if det1 not in det_names or det2 not in det_names:
                    # User-specified pair is invalid
                    continue
                if det1 not in good_dets or (
                    det2 is not None and det2 not in good_dets
                ):
                    # One of our detectors is cut.  Store a zero PSD.
                    nse_freqs = np.array(
                        [
                            0.0,
                            1.0e-5,
                            fsample / 4,
                            fsample / 2,
                        ],
                        dtype=np.float64,
                    )
                    nse_psd = np.zeros_like(nse_freqs)
                else:
                    signal1 = obs.detdata[self.det_data][det1]

                    flags[:] = shared_flags
                    if self.det_flags is not None:
                        flags[:] |= (
                            obs.detdata[self.det_flags][det1] & self.det_flag_mask
                        ) != 0

                    signal2 = None
                    if det1 != det2:
                        signal2 = obs.detdata[self.det_data][det2]
                        if self.det_flags is not None:
                            flags[:] |= (
                                obs.detdata[self.det_flags][det2] & self.det_flag_mask
                            ) != 0

                    if det2key is None:
                        det1_name = det1
                        det2_name = det2
                    else:
                        det1_name = det2key[det1]
                        det2_name = det2key[det2]

                    nse_freqs, nse_psd = self.process_noise_estimate(
                        obs,
                        global_intervals,
                        signal1,
                        signal2,
                        flags,
                        times,
                        fsample,
                        fileroot,
                        det1_name,
                        det2_name,
                        self.lagmax,
                    )

                det_units = obs.detdata[self.det_data].units
                if det_units == u.dimensionless_unscaled:
                    msg = f"Observation {obs.name}, detector data '{self.det_data}'"
                    msg += f" has no units.  Assuming Kelvin."
                    log.warning(msg)
                    det_units = u.K
                psd_unit = det_units**2 * u.second
                noise_dets.append(det1)
                noise_freqs[det1] = nse_freqs[1:] * u.Hz
                noise_psds[det1] = nse_psd[1:] * psd_unit
                noise_indices[det1] = obs.telescope.focalplane[det1]["uid"]

            if self.out_model is not None:
                # Create a noise model for our local detectors.
                obs[self.out_model] = Noise(
                    detectors=noise_dets,
                    freqs=noise_freqs,
                    psds=noise_psds,
                    indices=noise_indices,
                )

            # Redistribute the observation, replacing the input TOD with the filtered
            # one and redistributing the noise model.
            self._re_redistribute(orig_obs, obs)

            # Delete temporary obs
            del obs

    @function_timer
    def decimate(self, signal, flags):
        """Downsample previously highpass-filtered signal"""
        return signal[:: self.nsum].copy(), flags[:: self.nsum].copy()

    @function_timer
    def log_bin(self, freq, nbin=100, fmin=None, fmax=None):
        if np.any(freq == 0):
            msg = "Logarithmic binning should not include zero frequency"
            raise RuntimeError(msg)

        if fmin is None:
            fmin = np.amin(freq)
        if fmax is None:
            fmax = np.amax(freq)

        bins = np.logspace(
            np.log(fmin), np.log(fmax), num=nbin + 1, endpoint=True, base=np.e
        )
        bins[-1] *= 1.01  # Widen the last bin to avoid a single-entry bin

        locs = np.digitize(freq, bins).astype(np.int32)
        hits = np.zeros(nbin + 2, dtype=np.int32)
        for loc in locs:
            hits[loc] += 1
        return locs, hits

    @function_timer
    def bin_psds(self, my_psds, fmin=None, fmax=None):
        my_binned_psds = []
        my_times = []
        binfreq0 = None

        for i in range(len(my_psds)):
            t0, _, freq, psd = my_psds[i]

            good = freq != 0

            if self.nbin_psd is not None:
                locs, hits = self.log_bin(
                    freq[good], nbin=self.nbin_psd, fmin=fmin, fmax=fmax
                )
                binfreq = np.zeros(hits.size)
                for loc, f in zip(locs, freq[good]):
                    binfreq[loc] += f
                binfreq = binfreq[hits != 0] / hits[hits != 0]
            else:
                binfreq = freq
                hits = np.ones(len(binfreq))

            if binfreq0 is None:
                binfreq0 = binfreq
            else:
                if np.any(binfreq != binfreq0):
                    msg = "Binned PSD frequencies change"
                    raise RuntimeError(msg)

            if self.nbin_psd is not None:
                binpsd = np.zeros(hits.size)
                for loc, p in zip(locs, psd[good]):
                    binpsd[loc] += p
                binpsd = binpsd[hits != 0] / hits[hits != 0]
            else:
                binpsd = psd

            my_times.append(t0)
            my_binned_psds.append(binpsd)
        return my_binned_psds, my_times, binfreq0

    @function_timer
    def discard_outliers(self, binfreq, all_psds, all_times, all_cov):
        log = Logger.get()

        all_psds = copy.deepcopy(all_psds)
        all_times = copy.deepcopy(all_times)
        if self.save_cov:
            all_cov = copy.deepcopy(all_cov)

        nrow, ncol = np.shape(all_psds)

        # Discard empty PSDs

        i = 1
        nempty = 0
        while i < nrow:
            p = all_psds[i]
            if np.all(p == 0) or np.any(np.isnan(p)):
                del all_psds[i]
                del all_times[i]
                if self.save_cov:
                    del all_cov[i]
                nrow -= 1
                nempty += 1
            else:
                i += 1

        if nempty > 0:
            log.debug(f"Discarded {nempty} empty or NaN psds")

        # Throw away outlier PSDs by comparing the PSDs in specific bins

        if nrow < 10:
            nbad = 0
        else:
            all_good = np.isfinite(np.sum(all_psds, 1))
            for col in range(ncol - 1):
                if binfreq[col] < 0.001:
                    continue

                # Local outliers

                psdvalues = np.array([x[col] for x in all_psds])
                smooth_values = scipy.signal.medfilt(psdvalues, 11)
                good = np.ones(psdvalues.size, dtype=bool)
                good[psdvalues == 0] = False

                for i in range(10):
                    # Local test
                    diff = np.zeros(psdvalues.size)
                    diff[good] = np.log(psdvalues[good]) - np.log(smooth_values[good])
                    sdev = np.std(diff[good])
                    good[np.abs(diff) > 5 * sdev] = False
                    # Global test
                    diff = np.zeros(psdvalues.size)
                    diff[good] = np.log(psdvalues[good]) - np.mean(
                        np.log(psdvalues[good])
                    )
                    sdev = np.std(diff[good])
                    good[np.abs(diff) > 5 * sdev] = False

                all_good[np.logical_not(good)] = False

            bad = np.logical_not(all_good)
            nbad = np.sum(bad)
            if nbad > 0:
                for ii in np.argwhere(bad).ravel()[::-1]:
                    del all_psds[ii]
                    del all_times[ii]
                    if self.save_cov:
                        del all_cov[ii]

            if nbad > 0:
                log.debug(f"Masked extra {nbad} psds due to outliers.")
        return all_psds, all_times, nempty + nbad, all_cov

    @function_timer
    def save_psds(
        self, binfreq, all_psds, all_times, det1, det2, fsample, rootname, all_cov
    ):
        log = Logger.get()
        timer = Timer()
        timer.start()
        os.makedirs(self.output_dir, exist_ok=True)
        if det1 == det2:
            fn_out = os.path.join(self.output_dir, f"{rootname}_{det1}.fits")
        else:
            fn_out = os.path.join(self.output_dir, f"{rootname}_{det1}_{det2}.fits")
        all_psds = np.vstack([binfreq, all_psds])
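
        # Resulting FITS layout: a primary HDU, an "OBT" table of period
        # start times (sampling rate in the header), and a "PSD" table
        # whose first row holds the bin frequencies.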

        hdulist = [pf.PrimaryHDU()]

        cols = []
        cols.append(pf.Column(name="OBT", format="D", array=all_times))
        coldefs = pf.ColDefs(cols)
        hdu1 = pf.BinTableHDU.from_columns(coldefs)
        hdu1.header["RATE"] = fsample, "Sampling rate"
        hdulist.append(hdu1)

        cols = []
        cols.append(pf.Column(name="PSD", format=f"{binfreq.size}E", array=all_psds))
        coldefs = pf.ColDefs(cols)
        hdu2 = pf.BinTableHDU.from_columns(coldefs)
        hdu2.header["EXTNAME"] = str(det1), "Detector"
        hdu2.header["DET1"] = str(det1), "Detector1"
        hdu2.header["DET2"] = str(det2), "Detector2"
        hdulist.append(hdu2)

        if self.save_cov:
            all_cov = np.array(all_cov)
            cols = []
            nrow, ncol, nsamp = np.shape(all_cov)
            cols.append(
                pf.Column(
                    name="HITS",
                    format=f"{nsamp}J",
                    array=np.ascontiguousarray(all_cov[:, 0, :]),
                )
            )
            cols.append(
                pf.Column(
                    name="COV",
                    format=f"{nsamp}E",
                    array=np.ascontiguousarray(all_cov[:, 1, :]),
                )
            )
            coldefs = pf.ColDefs(cols)
            hdu3 = pf.BinTableHDU.from_columns(coldefs)
            hdu3.header["EXTNAME"] = str(det1), "Detector"
            hdu3.header["DET1"] = str(det1), "Detector1"
            hdu3.header["DET2"] = str(det2), "Detector2"
            hdulist.append(hdu3)

        hdulist = pf.HDUList(hdulist)

        with open(fn_out, "wb") as fits_out:
            hdulist.writeto(fits_out, overwrite=True)

        log.debug(f"Detector {det1} vs. {det2} PSDs stored in {fn_out}")

        return

    @function_timer
    def process_downsampled_noise_estimate(
        self,
        obs,
        global_intervals,
        timestamps,
        fsample,
        signal1,
        signal2,
        flags,
        my_psds1,
        my_cov1,
        lagmax,
    ):
        # Get the process grid row communicator, used to communicate overlaps
        comm = obs.comm_row

        # Get another PSD for a down-sampled TOD to measure the
        # low frequency power

        timestamps_decim = timestamps[:: self.nsum]
        # decimate() downsamples the previously high-pass filtered
        # signal and its flags by the `nsum` factor
        signal1_decim, flags_decim = self.decimate(signal1, flags)
        if signal2 is None:
            signal2_decim = None
        else:
            signal2_decim, flags_decim = self.decimate(signal2, flags)

        stationary_period = self.stationary_period.to_value(u.s)
        lagmax = min(lagmax, timestamps_decim.size)

        # We apply a prewhitening filter to the signal.  To accommodate the
        # quality flags, the filter is a moving average that only accounts
        # for the unflagged samples
        naverage = lagmax

        # Extend the local arrays to remove boundary effects from filtering
        if comm is None or comm.size == 1:
            extended_times = timestamps_decim
            extended_flags = flags_decim
            extended_signal1 = signal1_decim
            extended_signal2 = signal2_decim
        else:
            (
                extended_times,
                extended_flags,
                extended_signal1,
                extended_signal2,
            ) = communicate_overlap(
                timestamps_decim,
                signal1_decim,
                signal2_decim,
                flags_decim,
                lagmax,
                naverage,
                comm,
                obs.comm.group,
            )
        # High pass filter the signal to avoid aliasing
        extended_signal1 = highpass_flagged_signal(
            extended_signal1,
            extended_flags == 0,
            naverage,
        )
        if signal2 is not None:
            extended_signal2 = highpass_flagged_signal(
                extended_signal2,
                extended_flags == 0,
                naverage,
            )
        # Crop the filtering margin but keep up to lagmax samples
        half_average = naverage // 2 + 1
        if comm is not None and comm.rank > 0:
            extended_times = extended_times[half_average:]
            extended_flags = extended_flags[half_average:]
            extended_signal1 = extended_signal1[half_average:]
            if extended_signal2 is not None:
                extended_signal2 = extended_signal2[half_average:]
        if comm is not None and comm.rank < comm.size - 1:
            extended_times = extended_times[:-half_average]
            extended_flags = extended_flags[:-half_average]
            extended_signal1 = extended_signal1[:-half_average]
            if extended_signal2 is not None:
                extended_signal2 = extended_signal2[:-half_average]

        if signal2 is None:
            result = autocov_psd(
                timestamps_decim,
                extended_times,
                global_intervals,
                extended_signal1,
                extended_flags,
                lagmax,
                naverage,
                stationary_period,
                fsample / self.nsum,
                comm=comm,
                return_cov=self.save_cov,
            )
        else:
            result = crosscov_psd(
                timestamps_decim,
                extended_times,
                global_intervals,
                extended_signal1,
                extended_signal2,
                extended_flags,
                lagmax,
                naverage,
                stationary_period,
                fsample / self.nsum,
                comm=comm,
                return_cov=self.save_cov,
                symmetric=self.symmetric,
            )
        if self.save_cov:
            my_psds2, my_cov2 = result
        else:
            my_psds2, my_cov2 = result, None

        # Ensure the two sets of PSDs are of equal length

        my_new_psds1 = []
        my_new_psds2 = []
        if self.save_cov:
            my_new_cov1 = []
            my_new_cov2 = []
        i = 0
        while i < min(len(my_psds1), len(my_psds2)):
            t1 = my_psds1[i][0]
            t2 = my_psds2[i][0]
            if np.isclose(t1, t2):
                my_new_psds1.append(my_psds1[i])
                my_new_psds2.append(my_psds2[i])
                if self.save_cov:
                    my_new_cov1.append(my_cov1[i])
                    my_new_cov2.append(my_cov2[i])
                i += 1
            else:
                if t1 < t2:
                    del my_psds1[i]
                    if self._cov:
                        del my_cov1[i]
                else:
                    del my_psds2[i]
                    if self._cov:
                        del my_cov1[i]
        my_psds1 = my_new_psds1
        my_psds2 = my_new_psds2
        if self.save_cov:
            my_cov1 = my_new_cov1
            my_cov2 = my_new_cov2

        if len(my_psds1) != len(my_psds2):
            while my_psds1[-1][0] > my_psds2[-1][0]:
                del my_psds1[-1]
                if self.save_cov:
                    del my_cov1[-1]
            while my_psds1[-1][0] < my_psds2[-1][0]:
                del my_psds2[-1]
                if self.save_cov:
                    del my_cov2[-1]
        return my_psds1, my_cov1, my_psds2, my_cov2

    @function_timer
    def process_noise_estimate(
        self,
        obs,
        global_intervals,
        signal1,
        signal2,
        flags,
        timestamps,
        fsample,
        fileroot,
        det1,
        det2,
        lagmax,
    ):
        """Measure the sample (cross) covariance in the signal-subtracted
        TOD and Fourier-transform it for noise PSD.
        """

        log = Logger.get()

        # Get the process grid row communicator, used to communicate overlaps
        comm = obs.comm_row

        # We apply a prewhitening filter to the signal.  To accommodate the
        # quality flags, the filter is a moving average that only accounts
        # for the unflagged samples
        naverage = lagmax

        # Extend the local arrays to remove boundary effects from filtering
        if comm is None or comm.size == 1:
            extended_times = timestamps
            extended_flags = flags
            extended_signal1 = signal1
            extended_signal2 = signal2
        else:
            (
                extended_times,
                extended_flags,
                extended_signal1,
                extended_signal2,
            ) = communicate_overlap(
                timestamps,
                signal1,
                signal2,
                flags,
                lagmax,
                naverage,
                comm,
                obs.comm.group,
            )
        # High pass filter the signal to avoid aliasing
        extended_signal1 = highpass_flagged_signal(
            extended_signal1,
            extended_flags == 0,
            naverage,
        )
        if signal2 is not None:
            extended_signal2 = highpass_flagged_signal(
                extended_signal2,
                extended_flags == 0,
                naverage,
            )
        # Crop the filtering margin but keep up to lagmax samples
        half_average = naverage // 2 + 1
        if comm is not None and comm.rank > 0:
            extended_times = extended_times[half_average:]
            extended_flags = extended_flags[half_average:]
            extended_signal1 = extended_signal1[half_average:]
            if extended_signal2 is not None:
                extended_signal2 = extended_signal2[half_average:]
        if comm is not None and comm.rank < comm.size - 1:
            extended_times = extended_times[:-half_average]
            extended_flags = extended_flags[:-half_average]
            extended_signal1 = extended_signal1[:-half_average]
            if extended_signal2 is not None:
                extended_signal2 = extended_signal2[:-half_average]

        # Compute the autocovariance function and the matching
        # PSD for each stationary interval

        timer = Timer()
        timer.start()
        stationary_period = self.stationary_period.to_value(u.s)
        if signal2 is None:
            result = autocov_psd(
                timestamps,
                extended_times,
                global_intervals,
                extended_signal1,
                extended_flags,
                lagmax,
                naverage,
                stationary_period,
                fsample,
                comm=comm,
                return_cov=self.save_cov,
            )
        else:
            result = crosscov_psd(
                timestamps,
                extended_times,
                global_intervals,
                extended_signal1,
                extended_signal2,
                extended_flags,
                lagmax,
                naverage,
                stationary_period,
                fsample,
                comm=comm,
                return_cov=self.save_cov,
                symmetric=self.symmetric,
            )
        if self.save_cov:
            my_psds1, my_cov1 = result
        else:
            my_psds1, my_cov1 = result, None

        if self.nsum > 1:
            (
                my_psds1,
                my_cov1,
                my_psds2,
                my_cov2,
            ) = self.process_downsampled_noise_estimate(
                obs,
                global_intervals,
                extended_times,
                fsample,
                extended_signal1,
                extended_signal2,
                extended_flags,
                my_psds1,
                my_cov1,
                lagmax,
            )

        log.debug_rank(
            f"Compute Correlators and PSDs for {det1} / {det2}",
            comm=comm,
            rank=0,
            timer=timer,
        )

        # Now bin the PSDs
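        # The lowest usable frequency is set by the stationary period and
        # the highest by the Nyquist frequency.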

        fmin = 1 / stationary_period
        fmax = fsample / 2

        my_binned_psds1, my_times1, binfreq10 = self.bin_psds(my_psds1, fmin, fmax)
        if self.nsum > 1:
            my_binned_psds2, _, binfreq20 = self.bin_psds(my_psds2, fmin, fmax)

        log.debug_rank(
            f"Bin PSDs for {det1} / {det2}",
            comm=comm,
            rank=0,
            timer=timer,
        )

        # concatenate

        if self.save_cov:
            my_cov = my_cov1  # Only store the fully sampled covariance

        if binfreq10 is None:
            my_times = []
            my_binned_psds = []
            binfreq0 = None
        else:
            my_times = my_times1
            if self.nsum > 1:
                # frequencies that are usable in the down-sampled PSD
                fcut = fsample / 2 / self.naverage / 100
                ind1 = binfreq10 > fcut
                ind2 = binfreq20 <= fcut
                binfreq0 = np.hstack([binfreq20[ind2], binfreq10[ind1]])
                my_binned_psds = []
                for psd1, psd2 in zip(my_binned_psds1, my_binned_psds2):
                    my_binned_psds.append(np.hstack([psd2[ind2], psd1[ind1]]))
            else:
                binfreq0 = binfreq10
                my_binned_psds = my_binned_psds1

        # Collect and write the PSDs.  Start by finding the first process
        # with a valid PSD, which fixes the frequency binning

        have_bins = binfreq0 is not None
        have_bins_all = None
        if comm is None:
            have_bins_all = [have_bins]
        else:
            have_bins_all = comm.allgather(have_bins)
        root = 0
        if np.any(have_bins_all):
            while not have_bins_all[root]:
                root += 1
        else:
            msg = "None of the processes have valid PSDs"
            raise RuntimeError(msg)
        binfreq = None
        if comm is None:
            binfreq = binfreq0
        else:
            binfreq = comm.bcast(binfreq0, root=root)
        if binfreq0 is not None and np.any(binfreq != binfreq0):
            msg = (
                f"{obs.comm.world_rank:4} : Binned PSD frequencies change. "
                f"len(binfreq0) = {binfreq0.size}, "
                f"len(binfreq) = {binfreq.size}, binfreq0={binfreq0}, "
                f"binfreq = {binfreq}. len(my_psds) = {len(my_psds1)}"
            )
            raise RuntimeError(msg)

        if len(my_times) != len(my_binned_psds):
            msg = (
                f"ERROR: Process {obs.comm.world_rank} has len(my_times) = "
                f"{len(my_times)}, "
                f"len(my_binned_psds) = {len(my_binned_psds)}"
            )
            raise RuntimeError(msg)

        all_times = None
        all_psds = None
        if comm is None:
            all_times = [my_times]
            all_psds = [my_binned_psds]
        else:
            all_times = comm.gather(my_times, root=0)
            all_psds = comm.gather(my_binned_psds, root=0)
        all_cov = None
        if self.save_cov:
            if comm is None:
                all_cov = [my_cov]
            else:
                all_cov = comm.gather(my_cov, root=0)

        log.debug_rank(
            f"Collect PSDs for {det1} / {det2}",
            comm=comm,
            rank=0,
            timer=timer,
        )

        final_freqs = None
        final_psd = None
        if obs.comm_row_rank == 0:
            if len(all_times) != len(all_psds):
                msg = (
                    f"ERROR: Process {obs.comm.world_rank} has len(all_times) = "
                    f"{len(all_times)},"
                    f" len(all_psds) = {len(all_psds)} before deglitch"
                )
                raise RuntimeError(msg)

            # De-glitch the binned PSDs and write them to file
            i = 0
            while i < len(all_times):
                if len(all_times[i]) == 0:
                    del all_times[i]
                    del all_psds[i]
                    if self.save_cov:
                        del all_cov[i]
                else:
                    i += 1

            if len(all_times) != len(all_psds):
                msg = (
                    f"ERROR: Process {obs.comm.world_rank} has len(all_times) = "
                    f"{len(all_times)}, "
                    f"len(all_psds) = {len(all_psds)} AFTER deglitch"
                )
                raise RuntimeError(msg)

            all_times = list(np.hstack(all_times))
            all_psds = list(np.hstack(all_psds))
            if self.save_cov:
                all_cov = list(np.hstack(all_cov))

            good_psds, good_times, nbad, good_cov = self.discard_outliers(
                binfreq, all_psds, all_times, all_cov
            )
            log.debug_rank("Discard outliers", timer=timer)

            if self.output_dir is not None:
                self.save_psds(
                    binfreq, all_psds, all_times, det1, det2, fsample, fileroot, all_cov
                )
                if nbad > 0:
                    self.save_psds(
                        binfreq,
                        good_psds,
                        good_times,
                        det1,
                        det2,
                        fsample,
                        fileroot + "_good",
                        good_cov,
                    )

            final_freqs = binfreq
            final_psd = np.mean(np.array(good_psds), axis=0)
            log.debug_rank(f"Write PSDs for {det1} / {det2}", timer=timer)

        return final_freqs, final_psd

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        if self.detector_pointing is not None:
            req.update(self.detector_pointing.requires())
        return req

    def _provides(self):
        prov = {
            "meta": list(),
            "shared": list(),
            "detdata": list(),
        }
        if self.out_model is not None:
            prov["meta"].append(self.out_model)
        return prov
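
A minimal usage sketch, assuming a populated toast Data instance named
data; the trait values are purely illustrative and this example is not
part of the library source:

import astropy.units as u

from toast.ops import NoiseEstim

# Estimate per-detector noise PSDs and store them as a new noise model.
noise_estim = NoiseEstim(
    name="noise_estim",
    out_model="noise_model",       # meta key created in each observation
    lagmax=10000,                  # maximum correlation lag in samples
    nbin_psd=300,                  # logarithmic PSD bins
    stationary_period=3600 * u.s,  # hourly estimation periods
    output_dir="noise_psds",       # optionally write FITS products here
)
noise_estim.apply(data)

The resulting model is then available as obs["noise_model"] in each
observation, with the PSDs averaged over the estimation periods.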

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key to apply filtering to') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

detector_pointing = Instance(klass=Operator, allow_none=True, help='Operator that translates boresight pointing into detector frame. Only relevant if `maskfile` and/or `mapfile` are set') class-attribute instance-attribute

focalplane_key = Unicode(None, allow_none=True, help='When set, PSDs are measured over averaged TODs') class-attribute instance-attribute

lagmax = Int(10000, help='Maximum lag to consider for the covariance function. Will be truncated to the length of the longest view.') class-attribute instance-attribute

mapfile = Unicode(None, allow_none=True, help='Optional HEALPix map to sample and subtract from the signal') class-attribute instance-attribute

mask_flag_mask = Int(defaults.det_mask_processing, help='Bit mask for raising processing mask flags') class-attribute instance-attribute

mask_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for processing mask flags') class-attribute instance-attribute

maskfile = Unicode(None, allow_none=True, help='Optional HEALPix processing mask') class-attribute instance-attribute

naverage = Int(100, help='Smoothing kernel width for downsampled data') class-attribute instance-attribute

nbin_psd = Int(1000, allow_none=True, help='Number of logarithmic bins for the output PSD') class-attribute instance-attribute

nocross = Bool(True, help='Do not evaluate cross-PSDs. Overridden by `pairs`') class-attribute instance-attribute

nosingle = Bool(False, help='Do not evaluate individual PSDs. Overridden by `pairs`') class-attribute instance-attribute

nsum = Int(1, help='Downsampling factor for decimated data') class-attribute instance-attribute

out_model = Unicode(None, allow_none=True, help='Create a new noise model with this name') class-attribute instance-attribute

output_dir = Unicode(None, allow_none=True, help='If specified, write output data products to this directory') class-attribute instance-attribute

pairs = List([], help='Detector pairs to estimate noise for. Overrides `nosingle` and `nocross`') class-attribute instance-attribute

pixel_dist = Unicode('pixel_dist', help='The Data key where the PixelDistribution object is located. Only relevant if `maskfile` and/or `mapfile` are set') class-attribute instance-attribute

pixel_pointing = Instance(klass=Operator, allow_none=True, help='An instance of a pixel pointing operator. Only relevant if `maskfile` and/or `mapfile` are set') class-attribute instance-attribute

pol = Bool(True, help='Sample also the polarized part of the map') class-attribute instance-attribute

remove_common_mode = Bool(False, help='Remove common mode signal before estimation') class-attribute instance-attribute

save_cov = Bool(False, help='Save also the sample covariance') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

stationary_period = Quantity(86400 * u.s, help='Break the observation into several estimation periods of this length') class-attribute instance-attribute

stokes_weights = Instance(klass=Operator, allow_none=True, help='An instance of a Stokes weights operator. Only relevant if `mapfile` is set') class-attribute instance-attribute

symmetric = Bool(False, help='If True, treat positive and negative lags as equivalent in the cross correlator') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Only measure the covariance within each view') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/noise_estimation.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    return

_check_det_flag_mask(proposal)

Source code in toast/ops/noise_estimation.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/noise_estimation.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_detector_pointing(proposal)

Source code in toast/ops/noise_estimation.py
@traitlets.validate("detector_pointing")
def _check_detector_pointing(self, proposal):
    detpointing = proposal["value"]
    if detpointing is not None:
        if not isinstance(detpointing, Operator):
            raise traitlets.TraitError(
                "detector_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in [
            "view",
            "boresight",
            "det_mask",
            "shared_flags",
            "shared_flag_mask",
            "quats",
            "coord_in",
            "coord_out",
        ]:
            if not detpointing.has_trait(trt):
                msg = f"detector_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return detpointing

_check_nbin_psd(proposal)

Source code in toast/ops/noise_estimation.py
@traitlets.validate("nbin_psd")
def _check_nbin_psd(self, proposal):
    check = proposal["value"]
    if check is not None and check <= 1:
        raise traitlets.TraitError("Number of PSD bins should be greater than one")
    return check

_check_shared_mask(proposal)

Source code in toast/ops/noise_estimation.py
@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/noise_estimation.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    if detectors is not None:
        msg = "NoiseEstim cannot be run with subsets of detectors"
        raise RuntimeError(msg)

    log = Logger.get()

    if self.focalplane_key is not None:
        if len(self.pairs) > 0:
            msg = "focalplane_key is not compatible with pairs"
            raise RuntimeError(msg)
        if self.remove_common_mode:
            # Measure and subtract the common mode signal across the focalplane.
            Copy(detdata=[(self.det_data, "temp_signal")]).apply(data)
            CommonModeFilter(
                det_data="temp_signal",
                det_mask=self.det_mask,
                det_flags=self.det_flags,
                det_flag_mask=self.det_flag_mask,
                focalplane_key=self.focalplane_key,
            ).apply(data)
            Combine(
                op="subtract",
                first=self.det_data,
                second="temp_signal",
                output=self.det_data,
            ).apply(data)
            Delete(detdata="temp_signal")

    if self.mapfile is not None:
        if self.pol:
            weights = self.stokes_weights
        else:
            weights = None
        scan_map = ScanHealpixMap(
            file=self.mapfile,
            det_data=self.det_data,
            det_mask=self.det_mask,
            subtract=True,
            pixel_dist=self.pixel_dist,
            pixel_pointing=self.pixel_pointing,
            stokes_weights=weights,
        )
        scan_map.apply(data, detectors=detectors)

    if self.maskfile is not None:
        scan_mask = ScanHealpixMask(
            file=self.maskfile,
            det_mask=self.det_mask,
            det_flags=self.mask_flags,
            det_flags_value=self.mask_flag_mask,
            pixel_dist=self.pixel_dist,
            pixel_pointing=self.pixel_pointing,
        )
        scan_mask.apply(data, detectors=detectors)

    for orig_obs in data.obs:
        # Optionally redistribute data, but only if we are computing
        # cross spectra.
        obs, global_intervals = self._redistribute(orig_obs)

        # Get the set of all local detectors we are considering for this obs.
        local_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
        good_dets = set(local_dets)

        if self.focalplane_key is not None:
            # Pick just one detector to represent each key value
            fp = obs.telescope.focalplane
            det_names = []
            key2det = {}
            det2key = {}
            for det in local_dets:
                key = fp[det][self.focalplane_key]
                if key not in key2det:
                    det_names.append(det)
                    key2det[key] = det
                    det2key[det] = key
            pairs = []
            for det1 in key2det.values():
                for det2 in key2det.values():
                    if det1 == det2 and self.nosingle:
                        continue
                    if det1 != det2 and self.nocross:
                        continue
                    pairs.append([det1, det2])
        else:
            det2key = None
            det_names = obs.all_detectors
            ndet = len(det_names)
            if len(self.pairs) > 0:
                pairs = self.pairs
            else:
                # Construct a list of detector pairs
                pairs = []
                for idet1 in range(ndet):
                    det1 = det_names[idet1]
                    for idet2 in range(idet1, ndet):
                        det2 = det_names[idet2]
                        if det1 == det2 and self.nosingle:
                            continue
                        if det1 != det2 and self.nocross:
                            continue
                        pairs.append([det1, det2])

        if self.symmetric:
            # Remove duplicate entries in pair list
            unordered_pairs = set()
            for pair in pairs:
                unordered_pairs.add(tuple(sorted(pair)))
            pairs = list(unordered_pairs)

        times = np.array(obs.shared[self.times])
        nsample = times.size

        shared_flags = np.zeros(times.size, dtype=bool)
        if self.shared_flags is not None:
            shared_flags[:] = (
                obs.shared[self.shared_flags].data & self.shared_flag_mask
            ) != 0

        fsample = obs.telescope.focalplane.sample_rate.to_value(u.Hz)

        fileroot = f"{self.name}_{obs.name}"

        # Re-use this flag array
        flags = np.zeros(times.size, dtype=bool)

        noise_dets = list()
        noise_freqs = dict()
        noise_psds = dict()
        noise_indices = dict()

        for det1, det2 in pairs:
            if det1 not in det_names or det2 not in det_names:
                # User-specified pair is invalid
                continue
            if det1 not in good_dets or (
                det2 is not None and det2 not in good_dets
            ):
                # One of our detectors is cut.  Store a zero PSD.
                nse_freqs = np.array(
                    [
                        0.0,
                        1.0e-5,
                        fsample / 4,
                        fsample / 2,
                    ],
                    dtype=np.float64,
                )
                nse_psd = np.zeros_like(nse_freqs)
            else:
                signal1 = obs.detdata[self.det_data][det1]

                flags[:] = shared_flags
                if self.det_flags is not None:
                    flags[:] |= (
                        obs.detdata[self.det_flags][det1] & self.det_flag_mask
                    ) != 0

                signal2 = None
                if det1 != det2:
                    signal2 = obs.detdata[self.det_data][det2]
                    if self.det_flags is not None:
                        flags[:] |= (
                            obs.detdata[self.det_flags][det2] & self.det_flag_mask
                        ) != 0

                if det2key is None:
                    det1_name = det1
                    det2_name = det2
                else:
                    det1_name = det2key[det1]
                    det2_name = det2key[det2]

                nse_freqs, nse_psd = self.process_noise_estimate(
                    obs,
                    global_intervals,
                    signal1,
                    signal2,
                    flags,
                    times,
                    fsample,
                    fileroot,
                    det1_name,
                    det2_name,
                    self.lagmax,
                )

            det_units = obs.detdata[self.det_data].units
            if det_units == u.dimensionless_unscaled:
                msg = f"Observation {obs.name}, detector data '{self.det_data}'"
                msg += f" has no units.  Assuming Kelvin."
                log.warning(msg)
                det_units = u.K
            psd_unit = det_units**2 * u.second
            noise_dets.append(det1)
            noise_freqs[det1] = nse_freqs[1:] * u.Hz
            noise_psds[det1] = nse_psd[1:] * psd_unit
            noise_indices[det1] = obs.telescope.focalplane[det1]["uid"]

        if self.out_model is not None:
            # Create a noise model for our local detectors.
            obs[self.out_model] = Noise(
                detectors=noise_dets,
                freqs=noise_freqs,
                psds=noise_psds,
                indices=noise_indices,
            )

        # Redistribute the observation, replacing the input TOD with the filtered
        # one and redistributing the noise model.
        self._re_redistribute(orig_obs, obs)

        # Delete temporary obs
        del obs
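
For reference, here is a minimal usage sketch of this operator. It assumes the operator class is toast.ops.NoiseEstim and that a populated data object already exists; the key names and trait values are illustrative, not required by the API:

import toast.ops

# Hypothetical pipeline fragment: estimate per-detector noise PSDs and
# store the result as a new noise model key in each observation.
noise_est = toast.ops.NoiseEstim(
    name="noise_est",          # used as the root of output file names
    det_data="signal",         # detector data key to analyze
    out_model="noise_estim",   # observation key for the resulting Noise object
    lagmax=1024,               # maximum lag of the (cross) covariance
    nbin_psd=64,               # number of logarithmic PSD bins
    output_dir="psds",         # optionally write PSDs to FITS files here
)
noise_est.apply(data)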

_finalize(data, **kwargs)

Source code in toast/ops/noise_estimation.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/noise_estimation.py
def _provides(self):
    prov = {
        "meta": list(),
        "shared": list(),
        "detdata": list(),
    }
    if self.out_model is not None:
        prov["meta"].append(self.out_model)
    return prov

_re_redistribute(obs, temp_obs)

Source code in toast/ops/noise_estimation.py
@function_timer
def _re_redistribute(self, obs, temp_obs):
    log = Logger.get()
    timer = Timer()
    timer.start()
    if self.redistribute:
        # Redistribute data back
        temp_obs.redistribute(
            obs.dist.process_rows,
            times=self.times,
            override_sample_sets=obs.dist.sample_sets,
        )
        log.debug_rank(
            f"{temp_obs.comm.group:4} : Re-redistributed observation in",
            comm=temp_obs.comm.comm_group,
            timer=timer,
        )
        # Copy data to original observation
        obs.detdata[self.det_data][:] = temp_obs.detdata[self.det_data][:]
        log.debug_rank(
            f"{temp_obs.comm.group:4} : Copied observation data in",
            comm=temp_obs.comm.comm_group,
            timer=timer,
        )
        if self.out_model is not None:
            obs[self.out_model] = temp_obs[self.out_model]
        self.redistribute = False
    return

_redistribute(obs)

Source code in toast/ops/noise_estimation.py
@function_timer
def _redistribute(self, obs):
    log = Logger.get()
    timer = Timer()
    timer.start()
    if (len(self.pairs) > 0 or (not self.nocross)) and (
        obs.comm_col is not None and obs.comm_col.size > 1
    ):
        self.redistribute = True
        # Redistribute the data so each process has all detectors
        # for some sample range
        # Duplicate just the fields of the observation we will use
        dup_shared = [self.times]
        if self.shared_flags is not None:
            dup_shared.append(self.shared_flags)
        dup_detdata = [self.det_data]
        if self.det_flags is not None:
            dup_detdata.append(self.det_flags)
        dup_intervals = list()
        if self.view is not None:
            dup_intervals.append(self.view)
        temp_obs = obs.duplicate(
            times=self.times,
            meta=list(),
            shared=dup_shared,
            detdata=dup_detdata,
            intervals=dup_intervals,
        )
        log.debug_rank(
            f"{obs.comm.group:4} : Duplicated observation in",
            comm=temp_obs.comm.comm_group,
            timer=timer,
        )
        # Redistribute this temporary observation to be distributed by samples
        global_intervals = temp_obs.redistribute(
            1,
            times=self.times,
            override_sample_sets=None,
            return_global_intervals=True,
        )
        if self.view is not None:
            global_intervals = global_intervals[self.view]
        log.debug_rank(
            f"{obs.comm.group:4} : Redistributed observation in",
            comm=temp_obs.comm.comm_group,
            timer=timer,
        )
    else:
        self.redistribute = False
        temp_obs = obs
        global_intervals = []
        if self.view is not None:
            for ival in obs.intervals[self.view]:
                global_intervals.append((ival.start, ival.stop))
    if self.view is None:
        global_intervals = [(None, None)]

    return temp_obs, global_intervals

_requires()

Source code in toast/ops/noise_estimation.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    if self.detector_pointing is not None:
        req.update(self.detector_pointing.requires())
    return req

bin_psds(my_psds, fmin=None, fmax=None)

Source code in toast/ops/noise_estimation.py
@function_timer
def bin_psds(self, my_psds, fmin=None, fmax=None):
    my_binned_psds = []
    my_times = []
    binfreq0 = None

    for i in range(len(my_psds)):
        t0, _, freq, psd = my_psds[i]

        good = freq != 0

        if self.nbin_psd is not None:
            locs, hits = self.log_bin(
                freq[good], nbin=self.nbin_psd, fmin=fmin, fmax=fmax
            )
            binfreq = np.zeros(hits.size)
            for loc, f in zip(locs, freq[good]):
                binfreq[loc] += f
            binfreq = binfreq[hits != 0] / hits[hits != 0]
        else:
            binfreq = freq
            hits = np.ones(len(binfreq))

        if binfreq0 is None:
            binfreq0 = binfreq
        else:
            if np.any(binfreq != binfreq0):
                msg = "Binned PSD frequencies change"
                raise RuntimeError(msg)

        if self.nbin_psd is not None:
            binpsd = np.zeros(hits.size)
            for loc, p in zip(locs, psd[good]):
                binpsd[loc] += p
            binpsd = binpsd[hits != 0] / hits[hits != 0]
        else:
            binpsd = psd

        my_times.append(t0)
        my_binned_psds.append(binpsd)
    return my_binned_psds, my_times, binfreq0

decimate(signal, flags)

Downsample previously highpass-filtered signal

Source code in toast/ops/noise_estimation.py
@function_timer
def decimate(self, signal, flags):
    """Downsample previously highpass-filtered signal"""
    return signal[:: self.nsum].copy(), flags[:: self.nsum].copy()

discard_outliers(binfreq, all_psds, all_times, all_cov)

Source code in toast/ops/noise_estimation.py
@function_timer
def discard_outliers(self, binfreq, all_psds, all_times, all_cov):
    log = Logger.get()

    all_psds = copy.deepcopy(all_psds)
    all_times = copy.deepcopy(all_times)
    if self.save_cov:
        all_cov = copy.deepcopy(all_cov)

    nrow, ncol = np.shape(all_psds)

    # Discard empty PSDs

    i = 1
    nempty = 0
    while i < nrow:
        p = all_psds[i]
        if np.all(p == 0) or np.any(np.isnan(p)):
            del all_psds[i]
            del all_times[i]
            if self.save_cov:
                del all_cov[i]
            nrow -= 1
            nempty += 1
        else:
            i += 1

    if nempty > 0:
        log.debug(f"Discarded {nempty} empty or NaN psds")

    # Throw away outlier PSDs by comparing the PSDs in specific bins

    if nrow < 10:
        nbad = 0
    else:
        all_good = np.isfinite(np.sum(all_psds, 1))
        for col in range(ncol - 1):
            if binfreq[col] < 0.001:
                continue

            # Local outliers

            psdvalues = np.array([x[col] for x in all_psds])
            smooth_values = scipy.signal.medfilt(psdvalues, 11)
            good = np.ones(psdvalues.size, dtype=bool)
            good[psdvalues == 0] = False

            for i in range(10):
                # Local test
                diff = np.zeros(psdvalues.size)
                diff[good] = np.log(psdvalues[good]) - np.log(smooth_values[good])
                sdev = np.std(diff[good])
                good[np.abs(diff) > 5 * sdev] = False
                # Global test
                diff = np.zeros(psdvalues.size)
                diff[good] = np.log(psdvalues[good]) - np.mean(
                    np.log(psdvalues[good])
                )
                sdev = np.std(diff[good])
                good[np.abs(diff) > 5 * sdev] = False

            all_good[np.logical_not(good)] = False

        bad = np.logical_not(all_good)
        nbad = np.sum(bad)
        if nbad > 0:
            for ii in np.argwhere(bad).ravel()[::-1]:
                del all_psds[ii]
                del all_times[ii]
                if self.save_cov:
                    del all_cov[ii]

        if nbad > 0:
            log.debug(f"Masked extra {nbad} psds due to outliers.")
    return all_psds, all_times, nempty + nbad, all_cov

log_bin(freq, nbin=100, fmin=None, fmax=None)

Source code in toast/ops/noise_estimation.py
@function_timer
def log_bin(self, freq, nbin=100, fmin=None, fmax=None):
    if np.any(freq == 0):
        msg = "Logarithmic binning should not include zero frequency"
        raise Exception(msg)

    if fmin is None:
        fmin = np.amin(freq)
    if fmax is None:
        fmax = np.amax(freq)

    bins = np.logspace(
        np.log(fmin), np.log(fmax), num=nbin + 1, endpoint=True, base=np.e
    )
    bins[-1] *= 1.01  # Widen the last bin so the top frequency is not left alone in the overflow bin

    locs = np.digitize(freq, bins).astype(np.int32)
    hits = np.zeros(nbin + 2, dtype=np.int32)
    for loc in locs:
        hits[loc] += 1
    return locs, hits
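
To make the binning scheme concrete, here is a standalone numpy sketch that reproduces the digitize/hits logic above and averages a toy PSD into the resulting logarithmic bins. It is independent of the class and uses only values defined locally:

import numpy as np

def log_bin_average(freq, psd, nbin=30):
    # Logarithmically spaced bin edges, as in log_bin() above
    bins = np.logspace(
        np.log(np.amin(freq)), np.log(np.amax(freq)), num=nbin + 1, base=np.e
    )
    bins[-1] *= 1.01  # widen the last bin so the top frequency is included
    locs = np.digitize(freq, bins)
    hits = np.bincount(locs, minlength=nbin + 2)
    # Average the frequency and PSD values falling in each occupied bin
    binfreq = np.bincount(locs, weights=freq, minlength=nbin + 2)
    binpsd = np.bincount(locs, weights=psd, minlength=nbin + 2)
    good = hits > 0
    return binfreq[good] / hits[good], binpsd[good] / hits[good]

freq = np.linspace(0.001, 50.0, 10000)
psd = 1.0 + 0.1 / freq  # toy 1/f-like spectrum
bf, bp = log_bin_average(freq, psd)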

process_downsampled_noise_estimate(obs, global_intervals, timestamps, fsample, signal1, signal2, flags, my_psds1, my_cov1, lagmax)

Source code in toast/ops/noise_estimation.py
@function_timer
def process_downsampled_noise_estimate(
    self,
    obs,
    global_intervals,
    timestamps,
    fsample,
    signal1,
    signal2,
    flags,
    my_psds1,
    my_cov1,
    lagmax,
):
    # Get the process grid row communicator, used to communicate overlaps
    comm = obs.comm_row

    # Get another PSD for a down-sampled TOD to measure the
    # low frequency power

    timestamps_decim = timestamps[:: self.nsum]
    # decimate() downsamples the previously highpass-filtered
    # signal and flags by keeping every nsum'th sample
    signal1_decim, flags_decim = self.decimate(signal1, flags)
    if signal2 is None:
        signal2_decim = None
    else:
        signal2_decim, flags_decim = self.decimate(signal2, flags)

    stationary_period = self.stationary_period.to_value(u.s)
    lagmax = min(lagmax, timestamps_decim.size)

    # We apply a prewhitening filter to the signal.  To accommodate the
    # quality flags, the filter is a moving average that only accounts
    # for the unflagged samples
    naverage = lagmax

    # Extend the local arrays to remove boundary effects from filtering
    if comm is None or comm.size == 1:
        extended_times = timestamps_decim
        extended_flags = flags_decim
        extended_signal1 = signal1_decim
        extended_signal2 = signal2_decim
    else:
        (
            extended_times,
            extended_flags,
            extended_signal1,
            extended_signal2,
        ) = communicate_overlap(
            timestamps_decim,
            signal1_decim,
            signal2_decim,
            flags_decim,
            lagmax,
            naverage,
            comm,
            obs.comm.group,
        )
    # High pass filter the signal to avoid aliasing
    extended_signal1 = highpass_flagged_signal(
        extended_signal1,
        extended_flags == 0,
        naverage,
    )
    if signal2 is not None:
        extended_signal2 = highpass_flagged_signal(
            extended_signal2,
            extended_flags == 0,
            naverage,
        )
    # Crop the filtering margin but keep up to lagmax samples
    half_average = naverage // 2 + 1
    if comm is not None and comm.rank > 0:
        extended_times = extended_times[half_average:]
        extended_flags = extended_flags[half_average:]
        extended_signal1 = extended_signal1[half_average:]
        if extended_signal2 is not None:
            extended_signal2 = extended_signal2[half_average:]
    if comm is not None and comm.rank < comm.size - 1:
        extended_times = extended_times[:-half_average]
        extended_flags = extended_flags[:-half_average]
        extended_signal1 = extended_signal1[:-half_average]
        if extended_signal2 is not None:
            extended_signal2 = extended_signal2[:-half_average]

    if signal2 is None:
        result = autocov_psd(
            timestamps_decim,
            extended_times,
            global_intervals,
            extended_signal1,
            extended_flags,
            lagmax,
            naverage,
            stationary_period,
            fsample / self.nsum,
            comm=comm,
            return_cov=self.save_cov,
        )
    else:
        result = crosscov_psd(
            timestamps_decim,
            extended_times,
            global_intervals,
            extended_signal1,
            extended_signal2,
            extended_flags,
            lagmax,
            naverage,
            stationary_period,
            fsample / self.nsum,
            comm=comm,
            return_cov=self.save_cov,
            symmetric=self.symmetric,
        )
    if self.save_cov:
        my_psds2, my_cov2 = result
    else:
        my_psds2, my_cov2 = result, None

    # Ensure the two sets of PSDs are of equal length

    my_new_psds1 = []
    my_new_psds2 = []
    if self.save_cov:
        my_new_cov1 = []
        my_new_cov2 = []
    i = 0
    while i < min(len(my_psds1), len(my_psds2)):
        t1 = my_psds1[i][0]
        t2 = my_psds2[i][0]
        if np.isclose(t1, t2):
            my_new_psds1.append(my_psds1[i])
            my_new_psds2.append(my_psds2[i])
            if self.save_cov:
                my_new_cov1.append(my_cov1[i])
                my_new_cov2.append(my_cov2[i])
            i += 1
        else:
            if t1 < t2:
                del my_psds1[i]
                if self.save_cov:
                    del my_cov1[i]
            else:
                del my_psds2[i]
                if self.save_cov:
                    del my_cov2[i]
    my_psds1 = my_new_psds1
    my_psds2 = my_new_psds2
    if self.save_cov:
        my_cov1 = my_new_cov1
        my_cov2 = my_new_cov2

    if len(my_psds1) != len(my_psds2):
        while my_psds1[-1][0] > my_psds2[-1][0]:
            del my_psds1[-1]
            if self.save_cov:
                del my_cov1[-1]
        while my_psds1[-1][0] < my_psds2[-1][0]:
            del my_psds2[-1]
            if self.save_cov:
                del my_cov2[-1]
    return my_psds1, my_cov1, my_psds2, my_cov2

process_noise_estimate(obs, global_intervals, signal1, signal2, flags, timestamps, fsample, fileroot, det1, det2, lagmax)

Measure the sample (cross) covariance in the signal-subtracted TOD and Fourier-transform it for noise PSD.

Source code in toast/ops/noise_estimation.py
@function_timer
def process_noise_estimate(
    self,
    obs,
    global_intervals,
    signal1,
    signal2,
    flags,
    timestamps,
    fsample,
    fileroot,
    det1,
    det2,
    lagmax,
):
    """Measure the sample (cross) covariance in the signal-subtracted
    TOD and Fourier-transform it for noise PSD.
    """

    log = Logger.get()

    # Get the process grid row communicator, used to communicate overlaps
    comm = obs.comm_row

    # We apply a prewhitening filter to the signal.  To accommodate the
    # quality flags, the filter is a moving average that only accounts
    # for the unflagged samples
    naverage = lagmax

    # Extend the local arrays to remove boundary effects from filtering
    if comm is None or comm.size == 1:
        extended_times = timestamps
        extended_flags = flags
        extended_signal1 = signal1
        extended_signal2 = signal2
    else:
        (
            extended_times,
            extended_flags,
            extended_signal1,
            extended_signal2,
        ) = communicate_overlap(
            timestamps,
            signal1,
            signal2,
            flags,
            lagmax,
            naverage,
            comm,
            obs.comm.group,
        )
    # High pass filter the signal to avoid aliasing
    extended_signal1 = highpass_flagged_signal(
        extended_signal1,
        extended_flags == 0,
        naverage,
    )
    if signal2 is not None:
        extended_signal2 = highpass_flagged_signal(
            extended_signal2,
            extended_flags == 0,
            naverage,
        )
    # Crop the filtering margin but keep up to lagmax samples
    half_average = naverage // 2 + 1
    if comm is not None and comm.rank > 0:
        extended_times = extended_times[half_average:]
        extended_flags = extended_flags[half_average:]
        extended_signal1 = extended_signal1[half_average:]
        if extended_signal2 is not None:
            extended_signal2 = extended_signal2[half_average:]
    if comm is not None and comm.rank < comm.size - 1:
        extended_times = extended_times[:-half_average]
        extended_flags = extended_flags[:-half_average]
        extended_signal1 = extended_signal1[:-half_average]
        if extended_signal2 is not None:
            extended_signal2 = extended_signal2[:-half_average]

    # Compute the autocovariance function and the matching
    # PSD for each stationary interval

    timer = Timer()
    timer.start()
    stationary_period = self.stationary_period.to_value(u.s)
    if signal2 is None:
        result = autocov_psd(
            timestamps,
            extended_times,
            global_intervals,
            extended_signal1,
            extended_flags,
            lagmax,
            naverage,
            stationary_period,
            fsample,
            comm=comm,
            return_cov=self.save_cov,
        )
    else:
        result = crosscov_psd(
            timestamps,
            extended_times,
            global_intervals,
            extended_signal1,
            extended_signal2,
            extended_flags,
            lagmax,
            naverage,
            stationary_period,
            fsample,
            comm=comm,
            return_cov=self.save_cov,
            symmetric=self.symmetric,
        )
    if self.save_cov:
        my_psds1, my_cov1 = result
    else:
        my_psds1, my_cov1 = result, None

    if self.nsum > 1:
        (
            my_psds1,
            my_cov1,
            my_psds2,
            my_cov2,
        ) = self.process_downsampled_noise_estimate(
            obs,
            global_intervals,
            extended_times,
            fsample,
            extended_signal1,
            extended_signal2,
            extended_flags,
            my_psds1,
            my_cov1,
            lagmax,
        )

    log.debug_rank(
        f"Compute Correlators and PSDs for {det1} / {det2}",
        comm=comm,
        rank=0,
        timer=timer,
    )

    # Now bin the PSDs

    fmin = 1 / stationary_period
    fmax = fsample / 2

    my_binned_psds1, my_times1, binfreq10 = self.bin_psds(my_psds1, fmin, fmax)
    if self.nsum > 1:
        my_binned_psds2, _, binfreq20 = self.bin_psds(my_psds2, fmin, fmax)

    log.debug_rank(
        f"Bin PSDs for {det1} / {det2}",
        comm=comm,
        rank=0,
        timer=timer,
    )

    # concatenate

    if self.save_cov:
        my_cov = my_cov1  # Only store the fully sampled covariance

    if binfreq10 is None:
        my_times = []
        my_binned_psds = []
        binfreq0 = None
    else:
        my_times = my_times1
        if self.nsum > 1:
            # frequencies that are usable in the down-sampled PSD
            fcut = fsample / 2 / self.naverage / 100
            ind1 = binfreq10 > fcut
            ind2 = binfreq20 <= fcut
            binfreq0 = np.hstack([binfreq20[ind2], binfreq10[ind1]])
            my_binned_psds = []
            for psd1, psd2 in zip(my_binned_psds1, my_binned_psds2):
                my_binned_psds.append(np.hstack([psd2[ind2], psd1[ind1]]))
        else:
            binfreq0 = binfreq10
            my_binned_psds = my_binned_psds1

    # Collect and write the PSDs.  Start by determining the first
    # process to have a valid PSD to determine binning

    have_bins = binfreq0 is not None
    have_bins_all = None
    if comm is None:
        have_bins_all = [have_bins]
    else:
        have_bins_all = comm.allgather(have_bins)
    root = 0
    if np.any(have_bins_all):
        while not have_bins_all[root]:
            root += 1
    else:
        msg = "None of the processes have valid PSDs"
        raise RuntimeError(msg)
    binfreq = None
    if comm is None:
        binfreq = binfreq0
    else:
        binfreq = comm.bcast(binfreq0, root=root)
    if binfreq0 is not None and np.any(binfreq != binfreq0):
        msg = (
            f"{obs.comm.world_rank:4} : Binned PSD frequencies change. "
            f"len(binfreq0) = {binfreq0.size}, "
            f"len(binfreq) = {binfreq.size}, binfreq0={binfreq0}, "
            f"binfreq = {binfreq}. len(my_psds) = {len(my_psds1)}"
        )
        raise RuntimeError(msg)

    if len(my_times) != len(my_binned_psds):
        msg = (
            f"ERROR: Process {obs.comm.world_rank} has len(my_times) = "
            f"{len(my_times)}, "
            f"len(my_binned_psds) = {len(my_binned_psds)}"
        )
        raise RuntimeError(msg)

    all_times = None
    all_psds = None
    if comm is None:
        all_times = [my_times]
        all_psds = [my_binned_psds]
    else:
        all_times = comm.gather(my_times, root=0)
        all_psds = comm.gather(my_binned_psds, root=0)
    all_cov = None
    if self.save_cov:
        if comm is None:
            all_cov = [my_cov]
        else:
            all_cov = comm.gather(my_cov, root=0)

    log.debug_rank(
        f"Collect PSDs for {det1} / {det2}",
        comm=comm,
        rank=0,
        timer=timer,
    )

    final_freqs = None
    final_psd = None
    if obs.comm_row_rank == 0:
        if len(all_times) != len(all_psds):
            msg = (
                f"ERROR: Process {obs.comm.world_rank} has len(all_times) = "
                f"{len(all_times)},"
                f" len(all_psds) = {len(all_psds)} before deglitch"
            )
            raise RuntimeError(msg)

        # De-glitch the binned PSDs and write them to file
        i = 0
        while i < len(all_times):
            if len(all_times[i]) == 0:
                del all_times[i]
                del all_psds[i]
                if self.save_cov:
                    del all_cov[i]
            else:
                i += 1

        if len(all_times) != len(all_psds):
            msg = (
                f"ERROR: Process {obs.comm.world_rank} has len(all_times) = "
                f"{len(all_times)}, "
                f"len(all_psds) = {len(all_psds)} AFTER deglitch"
            )
            raise RuntimeError(msg)

        all_times = list(np.hstack(all_times))
        all_psds = list(np.hstack(all_psds))
        if self.save_cov:
            all_cov = list(np.hstack(all_cov))

        good_psds, good_times, nbad, good_cov = self.discard_outliers(
            binfreq, all_psds, all_times, all_cov
        )
        log.debug_rank("Discard outliers", timer=timer)

        if self.output_dir is not None:
            self.save_psds(
                binfreq, all_psds, all_times, det1, det2, fsample, fileroot, all_cov
            )
            if nbad > 0:
                self.save_psds(
                    binfreq,
                    good_psds,
                    good_times,
                    det1,
                    det2,
                    fsample,
                    fileroot + "_good",
                    good_cov,
                )

        final_freqs = binfreq
        final_psd = np.mean(np.array(good_psds), axis=0)
        log.debug_rank(f"Write PSDs for {det1} / {det2}", timer=timer)

    return final_freqs, final_psd
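
The splicing of the down-sampled and full-rate binned PSDs seen above (the fcut logic) can be summarized in a short standalone sketch. The arrays here are synthetic and fcut is an arbitrary stand-in for the expression used in the source:

import numpy as np

# binfreq10 / psd1: full-rate binned PSD; binfreq20 / psd2: decimated PSD
binfreq10 = np.linspace(0.01, 50.0, 100)
binfreq20 = np.linspace(0.001, 5.0, 100)
psd1 = 1.0 / binfreq10
psd2 = 1.0 / binfreq20

fcut = 1.0  # in the operator: fsample / 2 / naverage / 100
ind1 = binfreq10 > fcut   # keep the full-rate PSD above fcut
ind2 = binfreq20 <= fcut  # keep the decimated PSD at and below fcut
binfreq0 = np.hstack([binfreq20[ind2], binfreq10[ind1]])
psd0 = np.hstack([psd2[ind2], psd1[ind1]])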

save_psds(binfreq, all_psds, all_times, det1, det2, fsample, rootname, all_cov)

Source code in toast/ops/noise_estimation.py
@function_timer
def save_psds(
    self, binfreq, all_psds, all_times, det1, det2, fsample, rootname, all_cov
):
    log = Logger.get()
    timer = Timer()
    timer.start()
    os.makedirs(self.output_dir, exist_ok=True)
    if det1 == det2:
        fn_out = os.path.join(self.output_dir, f"{rootname}_{det1}.fits")
    else:
        fn_out = os.path.join(self.output_dir, f"{rootname}_{det1}_{det2}.fits")
    all_psds = np.vstack([binfreq, all_psds])

    hdulist = [pf.PrimaryHDU()]

    cols = []
    cols.append(pf.Column(name="OBT", format="D", array=all_times))
    coldefs = pf.ColDefs(cols)
    hdu1 = pf.BinTableHDU.from_columns(coldefs)
    hdu1.header["RATE"] = fsample, "Sampling rate"
    hdulist.append(hdu1)

    cols = []
    cols.append(pf.Column(name="PSD", format=f"{binfreq.size}E", array=all_psds))
    coldefs = pf.ColDefs(cols)
    hdu2 = pf.BinTableHDU.from_columns(coldefs)
    hdu2.header["EXTNAME"] = str(det1), "Detector"
    hdu2.header["DET1"] = str(det1), "Detector1"
    hdu2.header["DET2"] = str(det2), "Detector2"
    hdulist.append(hdu2)

    if self.save_cov:
        all_cov = np.array(all_cov)
        cols = []
        nrow, ncol, nsamp = np.shape(all_cov)
        cols.append(
            pf.Column(
                name="HITS",
                format=f"{nsamp}J",
                array=np.ascontiguousarray(all_cov[:, 0, :]),
            )
        )
        cols.append(
            pf.Column(
                name="COV",
                format=f"{nsamp}E",
                array=np.ascontiguousarray(all_cov[:, 1, :]),
            )
        )
        coldefs = pf.ColDefs(cols)
        hdu3 = pf.BinTableHDU.from_columns(coldefs)
        hdu3.header["EXTNAME"] = str(det1), "Detector"
        hdu3.header["DET1"] = str(det1), "Detector1"
        hdu3.header["DET2"] = str(det2), "Detector2"
        hdulist.append(hdu3)

    hdulist = pf.HDUList(hdulist)

    with open(fn_out, "wb") as fits_out:
        hdulist.writeto(fits_out, overwrite=True)

    log.debug(f"Detector {det1} vs. {det2} PSDs stored in {fn_out}")

    return
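
The files produced here can be read back with astropy.io.fits. A minimal sketch, assuming the single-detector filename pattern used above (the actual name depends on the operator name, observation name, and detector):

from astropy.io import fits

# Hypothetical file name following save_psds(): {rootname}_{det1}.fits
with fits.open("noise_est_obs1_d00.fits") as hdulist:
    times = hdulist[1].data["OBT"]       # start time of each stationary period
    fsample = hdulist[1].header["RATE"]  # sampling rate in Hz
    block = hdulist[2].data["PSD"]       # row 0 is the bin frequency
    binfreq = block[0]
    spectra = block[1:]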

toast.ops.FitNoiseModel

Bases: Operator

Perform a least squares fit to an existing noise model.

This takes an existing estimated noise model and attempts to fit each spectrum to 1/f parameters.

If the output model is not specified, then the input is modified in place.

If the data has been filtered with a low-pass, then the high frequency spectral points are not representative of the actual white noise plateau. In this case, the minimum / maximum frequencies to consider can be specified.
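
A sketch of how this operator might be chained after a noise estimation step. The observation key names are illustrative, not fixed by the API:

import toast.ops

fit_model = toast.ops.FitNoiseModel(
    noise_model="noise_estim",  # input: estimated model stored by NoiseEstim
    out_model="noise_fit",      # output: fitted 1/f model under a new key
)
fit_model.apply(data)

# Each observation now carries an AnalyticNoise model in obs["noise_fit"]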

Source code in toast/ops/noise_model.py
@trait_docs
class FitNoiseModel(Operator):
    """Perform a least squares fit to an existing noise model.

    This takes an existing estimated noise model and attempts to fit each
    spectrum to 1/f parameters.

    If the output model is not specified, then the input is modified in place.

    If the data has been filtered with a low-pass, then the high frequency spectral
    points are not representative of the actual white noise plateau.  In this case,
    the min / max frequencies to consider can be specified.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    noise_model = Unicode(
        "noise_model", help="The observation key containing the input noise model"
    )

    out_model = Unicode(
        None, allow_none=True, help="Create a new noise model with this name"
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    bad_fit_mask = Int(
        defaults.det_mask_processing, help="Bit mask to raise for bad fits"
    )

    f_min = Quantity(1.0e-5 * u.Hz, help="Low-frequency rolloff of model in the fit")

    white_noise_min = Quantity(
        None,
        allow_none=True,
        help="The minimum frequency to consider for the white noise plateau",
    )

    white_noise_max = Quantity(
        None,
        allow_none=True,
        help="The maximum frequency to consider for the white noise plateau",
    )

    least_squares_xtol = Float(
        None,
        allow_none=True,
        help="The xtol value passed to the least_squares solver",
    )

    least_squares_ftol = Float(
        1.0e-10,
        allow_none=True,
        help="The ftol value passed to the least_squares solver",
    )

    least_squares_gtol = Float(
        None,
        allow_none=True,
        help="The gtol value passed to the least_squares solver",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        if detectors is not None:
            msg = "FitNoiseModel will fit all detectors- ignoring input detector list"
            log.warning(msg)

        if self.white_noise_max is not None:
            # Ensure that the min is also set
            if self.white_noise_min is None:
                msg = "You must set both of the min / max values or none of them"
                raise RuntimeError(msg)

        for ob in data.obs:
            in_model = ob[self.noise_model]
            # We will use the best fit parameters from each detector as
            # the starting guess for the next detector.
            params = None
            nse_rate = dict()
            nse_fmin = dict()
            nse_fknee = dict()
            nse_alpha = dict()
            nse_NET = dict()
            nse_indx = dict()

            # We are building a noise model with entries for all local detectors,
            # even ones that are flagged.
            for det in ob.local_detectors:
                freqs = in_model.freq(det)
                in_psd = in_model.psd(det)
                cur_flag = ob.local_detector_flags[det]
                nse_indx[det] = in_model.index(det)
                nse_rate[det] = 2.0 * freqs[-1]
                nse_NET[det] = 0.0 * np.sqrt(1.0 * in_psd.unit)
                nse_fmin[det] = 0.0 * u.Hz
                nse_fknee[det] = 0.0 * u.Hz
                nse_alpha[det] = 0.0
                if cur_flag & self.det_mask != 0:
                    continue
                props = self._fit_log_psd(freqs, in_psd, guess=params)
                if props["fit_result"].success:
                    # This was a good fit
                    params = props["fit_result"].x
                else:
                    params = None
                    msg = f"FitNoiseModel observation {ob.name}, det {det} failed, "
                    msg += f"using white noise with NET = {props['NET']}"
                    log.warning(msg)
                    msg = f"  Best Result = {props['fit_result']}"
                    log.verbose(msg)
                    new_flag = cur_flag | self.bad_fit_mask
                    ob.update_local_detector_flags({det: new_flag})

                nse_fmin[det] = props["fmin"]
                nse_fknee[det] = props["fknee"]
                nse_alpha[det] = props["alpha"]
                nse_NET[det] = props["NET"]

            new_model = AnalyticNoise(
                detectors=ob.local_detectors,
                rate=nse_rate,
                fmin=nse_fmin,
                fknee=nse_fknee,
                alpha=nse_alpha,
                NET=nse_NET,
                indices=nse_indx,
            )

            if self.out_model is None or self.noise_model == self.out_model:
                # We are replacing the input
                del ob[self.noise_model]
                ob[self.noise_model] = new_model
            else:
                # We are storing this in a new key
                ob[self.out_model] = new_model
        return

    def _estimate_net(self, freqs, data):
        """Estimate the NET from the high frequency PSD.

        This assumes that at high frequency the PSD has a white noise "plateau".  A simple
        parabola is fit to the last bit of the spectrum and this is used to compute the
        NET.

        Args:
            freqs (array):  The frequency values in Hz
            data (array):  The PSD in arbitrary units

        Returns:
            (float):  The estimated NET.

        """
        log = Logger.get()

        def quad_func(x, a, b, c):
            # Parabola
            return a * (x - b) ** 2 + c

        def lin_func(x, a, b, c):
            # Line
            return a * (x - b) + c

        n_psd = len(data)
        offset = int(0.8 * n_psd)
        try_quad = True
        if n_psd - offset < 10:
            # Too few points
            try_quad = False
            if n_psd < 10:
                # Crazy...
                offset = 0
            else:
                offset = n_psd - 10

        ffreq = np.log(freqs[offset:])
        fdata = np.log(data[offset:])
        if try_quad:
            try:
                params, params_cov = curve_fit(
                    quad_func, ffreq, fdata, p0=[1.0, ffreq[-1], fdata[-1]]
                )
                # It worked!
                fdata = quad_func(ffreq, params[0], params[1], params[2])
                fdata = np.exp(fdata)
                return np.sqrt(fdata[-1])
            except RuntimeError:
                pass

        params, params_cov = curve_fit(
            lin_func, ffreq, fdata, p0=[0.0, ffreq[-1], fdata[-1]]
        )
        fdata = lin_func(ffreq, params[0], params[1], params[2])
        fdata = np.exp(fdata)
        net = np.sqrt(fdata[-1])
        return net

    def _evaluate_model(self, freqs, fmin, net, fknee, alpha):
        """Evaluate the noise model

        Given the input frequencies, NET, slope alpha, f_min and f_knee,
        evaluate the PSD as:

        PSD = NET^2 * [ (f^alpha + f_knee^alpha) / (f^alpha + f_min^alpha) ]

        Args:
            freqs (array):  The input frequencies in Hz
            fmin (float):  The extreme low-frequency rolloff
            fknee (float):  The knee frequency
            alpha (float):  The slope parameter

        Returns:
            (array):  The model PSD

        """
        ktemp = np.power(fknee, alpha)
        mtemp = np.power(fmin, alpha)
        temp = np.power(freqs, alpha)
        psd = (temp + ktemp) / (temp + mtemp)
        psd *= net**2
        return psd

    def _evaluate_log_model(self, freqs, fmin, net, fknee, alpha):
        """Evaluate the natural log of the noise model

        Given the input frequencies, NET, slope alpha, f_min and f_knee,
        evaluate the ln(PSD) as:

        ln(PSD) = 2 * ln(NET) + ln(f^alpha + f_knee^alpha) - ln(f^alpha + f_min^alpha)

        Args:
            freqs (array):  The input frequencies in Hz
            fmin (float):  The extreme low-frequency rolloff
            fknee (float):  The knee frequency
            alpha (float):  The slope parameter

        Returns:
            (array):  The log of the model PSD

        """
        f_alpha = np.power(freqs, alpha)
        fknee_alpha = np.power(fknee, alpha)
        fmin_alpha = np.power(fmin, alpha)
        psd = (
            2.0 * np.log(net)
            + np.log(f_alpha + fknee_alpha)
            - np.log(f_alpha + fmin_alpha)
        )
        return psd

    def _fit_log_fun(self, x, *args, **kwargs):
        """Evaluate the weighted residual in log space.

        For the given set of parameters, this evaluates the model log PSD and computes the
        residual from the real data.  This residual is further weighted so that the better
        constrained high-frequency values have more significance.  We arbitrarily choose a
        weighting of:

            W = f_nyquist - (f_nyquist / (1 + f^2))

        Args:
            x (array):  The current model parameters
            kwargs:  The fixed information is passed in through the least squares solver.

        Returns:
            (array):  The array of residuals

        """
        freqs = kwargs["freqs"]
        logdata = kwargs["logdata"]
        fmin = kwargs["fmin"]
        net = kwargs["net"]
        fknee = x[0]
        alpha = x[1]
        current = self._evaluate_log_model(freqs, fmin, net, fknee, alpha)
        resid = current - logdata
        return resid

    def _fit_log_jac(self, x, *args, **kwargs):
        """Evaluate the partial derivatives of model.

        This returns the Jacobian containing the partial derivatives of the log-space
        model with respect to the fit parameters.

        Args:
            x (array):  The current model parameters
            kwargs:  The fixed information is passed in through the least squares solver.

        Returns:
            (array):  The Jacobian

        """
        freqs = kwargs["freqs"]
        fmin = kwargs["fmin"]
        fknee = x[0]
        alpha = x[1]
        n_freq = len(freqs)

        log_freqs = np.log(freqs)
        f_alpha = np.power(freqs, alpha)
        fknee_alpha = np.power(fknee, alpha)
        fmin_alpha = np.power(fmin, alpha)

        fkalpha = f_alpha + fknee_alpha
        fmalpha = f_alpha + fmin_alpha

        J = np.empty((n_freq, x.size), dtype=np.float64)

        # Partial derivative wrt f_knee
        J[:, 0] = alpha * np.power(fknee, alpha - 1.0) / fkalpha

        # Partial derivative wrt alpha
        J[:, 1] = (f_alpha * log_freqs + fknee_alpha * np.log(fknee)) / fkalpha - (
            f_alpha * log_freqs + fmin_alpha * np.log(fmin)
        ) / fmalpha
        return J

    def _get_err_ret(self, psd_unit):
        # Internal function to build a fake return result
        # when the fitting fails for some reason.
        eret = dict()
        eret["fit_result"] = types.SimpleNamespace()
        eret["fit_result"].success = False
        eret["NET"] = 0.0 * np.sqrt(1.0 * psd_unit)
        eret["fmin"] = 0.0 * u.Hz
        eret["fknee"] = 0.0 * u.Hz
        eret["alpha"] = 0.0
        return eret

    def _fit_log_psd(self, freqs, data, guess=None):
        """Perform a log-space fit to model PSD parameters.

        Args:
            freqs (Quantity):  The frequency values
            data (Quantity):  The estimated input PSD
            guess (array):  Optional starting point guess

        Returns:
            (dict):  Dictionary of fit parameters

        """
        log = Logger.get()
        psd_unit = data.unit
        ret = dict()

        # We cut the lowest frequency bin value, and any leading negative values,
        # since these are usually due to poor estimation.  If the user has specified
        # a maximum frequency for the white noise plateau, then we also stop our
        # fit at that point.
        raw_freqs = freqs.to_value(u.Hz)
        raw_data = data.value
        n_raw = len(raw_data)
        n_skip = 1
        while n_skip < n_raw and raw_data[n_skip] <= 0:
            n_skip += 1
        if n_skip == n_raw:
            msg = f"All {n_raw} PSD values were negative.  Giving up."
            log.warning(msg)
            ret = self._get_err_ret(psd_unit)
            return ret

        n_trim = 0
        if self.white_noise_max is not None:
            max_hz = self.white_noise_max.to_value(u.Hz)
            for f in raw_freqs:
                if f > max_hz:
                    n_trim += 1

        if n_skip + n_trim >= n_raw:
            msg = f"All {n_raw} PSD values either negative or above plateau."
            log.warning(msg)
            ret = self._get_err_ret(psd_unit)
            return ret

        input_freqs = raw_freqs[n_skip : n_raw - n_trim]
        input_data = raw_data[n_skip : n_raw - n_trim]
        # Force all points to be positive
        good = input_data > 0
        if np.count_nonzero(good) == 0:
            # All PSD values zero, must be flagged
            msg = f"All PSD values zero, skipping fit."
            log.warning(msg)
            ret = self._get_err_ret(psd_unit)
            return ret
        bad = np.logical_not(good)
        n_bad = np.count_nonzero(bad)
        if n_bad > 0:
            msg = "Some PSDs have negative values.  Consider changing "
            msg += "noise estimation parameters."
            log.warning(msg)
        good_min = np.min(input_data[good])
        input_data[bad] = 1.0e-6 * good_min
        input_log_data = np.log(input_data)

        raw_fmin = self.f_min.to_value(u.Hz)

        if self.white_noise_max is None:
            net = self._estimate_net(input_freqs, input_data)
        else:
            plateau_samples = np.logical_and(
                (input_freqs > self.white_noise_min.to_value(u.Hz)),
                (input_freqs < self.white_noise_max.to_value(u.Hz)),
            )
            net = np.sqrt(np.mean(input_data[plateau_samples]))

        midfreq = 0.5 * input_freqs[-1]

        bounds = (
            np.array([input_freqs[0], 0.1]),
            np.array([input_freqs[-1], 10.0]),
        )
        x_0 = guess
        if x_0 is None:
            x_0 = np.array([midfreq, 1.0])

        try:
            result = least_squares(
                self._fit_log_fun,
                x_0,
                jac=self._fit_log_jac,
                bounds=bounds,
                xtol=self.least_squares_xtol,
                gtol=self.least_squares_gtol,
                ftol=self.least_squares_ftol,
                max_nfev=500,
                verbose=0,
                kwargs={
                    "freqs": input_freqs,
                    "logdata": input_log_data,
                    "fmin": raw_fmin,
                    "net": net,
                },
            )
        except Exception:
            log.verbose(f"PSD fit raised exception, skipping")
            ret = self._get_err_ret(psd_unit)
            return ret

        ret["fit_result"] = result
        ret["NET"] = net * np.sqrt(1.0 * psd_unit)
        ret["fmin"] = self.f_min
        if result.success:
            ret["fknee"] = result.x[0] * u.Hz
            ret["alpha"] = result.x[1]
        else:
            ret["fknee"] = 0.0 * u.Hz
            ret["alpha"] = 1.0

        return ret

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        return dict()

    def _provides(self):
        prov = {"meta": [self.noise_model]}
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

bad_fit_mask = Int(defaults.det_mask_processing, help='Bit mask to raise for bad fits') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

f_min = Quantity(1e-05 * u.Hz, help='Low-frequency rolloff of model in the fit') class-attribute instance-attribute

least_squares_ftol = Float(1e-10, allow_none=True, help='The ftol value passed to the least_squares solver') class-attribute instance-attribute

least_squares_gtol = Float(None, allow_none=True, help='The gtol value passed to the least_squares solver') class-attribute instance-attribute

least_squares_xtol = Float(None, allow_none=True, help='The xtol value passed to the least_squares solver') class-attribute instance-attribute

noise_model = Unicode('noise_model', help='The observation key containing the input noise model') class-attribute instance-attribute

out_model = Unicode(None, allow_none=True, help='Create a new noise model with this name') class-attribute instance-attribute

white_noise_max = Quantity(None, allow_none=True, help='The maximum frequency to consider for the white noise plateau') class-attribute instance-attribute

white_noise_min = Quantity(None, allow_none=True, help='The minimum frequency to consider for the white noise plateau') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/noise_model.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_mask(proposal)

Source code in toast/ops/noise_model.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_estimate_net(freqs, data)

Estimate the NET from the high frequency PSD.

This assumes that at high frequency the PSD has a white noise "plateau". A simple parabola is fit to the last bit of the spectrum and this is used to compute the NET.

Parameters:

- freqs (array, required): The frequency values in Hz
- data (array, required): The PSD in arbitrary units

Returns:

- (float): The estimated NET.

Source code in toast/ops/noise_model.py
def _estimate_net(self, freqs, data):
    """Estimate the NET from the high frequency PSD.

    This assumes that at high frequency the PSD has a white noise "plateau".  A simple
    parabola is fit to the last bit of the spectrum and this is used to compute the
    NET.

    Args:
        freqs (array):  The frequency values in Hz
        data (array):  The PSD in arbitrary units

    Returns:
        (float):  The estimated NET.

    """
    log = Logger.get()

    def quad_func(x, a, b, c):
        # Parabola
        return a * (x - b) ** 2 + c

    def lin_func(x, a, b, c):
        # Line
        return a * (x - b) + c

    n_psd = len(data)
    offset = int(0.8 * n_psd)
    try_quad = True
    if n_psd - offset < 10:
        # Too few points
        try_quad = False
        if n_psd < 10:
            # Crazy...
            offset = 0
        else:
            offset = n_psd - 10

    ffreq = np.log(freqs[offset:])
    fdata = np.log(data[offset:])
    if try_quad:
        try:
            params, params_cov = curve_fit(
                quad_func, ffreq, fdata, p0=[1.0, ffreq[-1], fdata[-1]]
            )
            # It worked!
            fdata = quad_func(ffreq, params[0], params[1], params[2])
            fdata = np.exp(fdata)
            return np.sqrt(fdata[-1])
        except RuntimeError:
            pass

    params, params_cov = curve_fit(
        lin_func, ffreq, fdata, p0=[0.0, ffreq[-1], fdata[-1]]
    )
    fdata = lin_func(ffreq, params[0], params[1], params[2])
    fdata = np.exp(fdata)
    net = np.sqrt(fdata[-1])
    return net
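
As a sanity check of this approach, here is a hedged numpy sketch that applies the same log-log parabola fit to a toy spectrum with a known plateau; for a plateau at 4e-8 the recovered NET should be close to sqrt(4e-8) = 2e-4:

import numpy as np
from scipy.optimize import curve_fit

# Toy PSD: 1/f component plus a white plateau at 4e-8
freqs = np.linspace(0.01, 100.0, 1000)
psd = 4.0e-8 * (1.0 + 0.1 / freqs)

# Fit a parabola to the last 20% of the spectrum in log-log space,
# as _estimate_net() does, and evaluate it at the highest frequency
off = int(0.8 * len(psd))
ffreq, fdata = np.log(freqs[off:]), np.log(psd[off:])

def quad(x, a, b, c):
    return a * (x - b) ** 2 + c

params, _ = curve_fit(quad, ffreq, fdata, p0=[1.0, ffreq[-1], fdata[-1]])
net = np.sqrt(np.exp(quad(ffreq[-1], *params)))  # approximately 2e-4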

_evaluate_log_model(freqs, fmin, net, fknee, alpha)

Evaluate the natural log of the noise model

Given the input frequencies, NET, slope alpha, f_min and f_knee, evaluate the ln(PSD) as:

ln(PSD) = 2 * ln(NET) + ln(f^alpha + f_knee^alpha) - ln(f^alpha + f_min^alpha)

Parameters:

- freqs (array, required): The input frequencies in Hz
- fmin (float, required): The extreme low-frequency rolloff
- fknee (float, required): The knee frequency
- alpha (float, required): The slope parameter

Returns:

- (array): The log of the model PSD

Source code in toast/ops/noise_model.py
def _evaluate_log_model(self, freqs, fmin, net, fknee, alpha):
    """Evaluate the natural log of the noise model

    Given the input frequencies, NET, slope alpha, f_min and f_knee,
    evaluate the ln(PSD) as:

    ln(PSD) = 2 * ln(NET) + ln(f^alpha + f_knee^alpha) - ln(f^alpha + f_min^alpha)

    Args:
        freqs (array):  The input frequencies in Hz
        fmin (float):  The extreme low-frequency rolloff
        fknee (float):  The knee frequency
        alpha (float):  The slope parameter

    Returns:
        (array):  The log of the model PSD

    """
    f_alpha = np.power(freqs, alpha)
    fknee_alpha = np.power(fknee, alpha)
    fmin_alpha = np.power(fmin, alpha)
    psd = (
        2.0 * np.log(net)
        + np.log(f_alpha + fknee_alpha)
        - np.log(f_alpha + fmin_alpha)
    )
    return psd

_evaluate_model(freqs, fmin, net, fknee, alpha)

Evaluate the noise model

Given the input frequencies, NET, slope alpha, f_min and f_knee, evaluate the PSD as:

PSD = NET^2 * [ (f^alpha + f_knee^alpha) / (f^alpha + f_min^alpha) ]

Parameters:

- freqs (array, required): The input frequencies in Hz
- fmin (float, required): The extreme low-frequency rolloff
- fknee (float, required): The knee frequency
- alpha (float, required): The slope parameter

Returns:

- (array): The model PSD

Source code in toast/ops/noise_model.py
def _evaluate_model(self, freqs, fmin, net, fknee, alpha):
    """Evaluate the noise model

    Given the input frequencies, NET, slope alpha, f_min and f_knee,
    evaluate the PSD as:

    PSD = NET^2 * [ (f^alpha + f_knee^alpha) / (f^alpha + f_min^alpha) ]

    Args:
        freqs (array):  The input frequencies in Hz
        fmin (float):  The extreme low-frequency rolloff
        fknee (float):  The knee frequency
        alpha (float):  The slope parameter

    Returns:
        (array):  The model PSD

    """
    ktemp = np.power(fknee, alpha)
    mtemp = np.power(fmin, alpha)
    temp = np.power(freqs, alpha)
    psd = (temp + ktemp) / (temp + mtemp)
    psd *= net**2
    return psd

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/noise_model.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    if detectors is not None:
        msg = "FitNoiseModel will fit all detectors- ignoring input detector list"
        log.warning(msg)

    if self.white_noise_max is not None:
        # Ensure that the min is also set
        if self.white_noise_min is None:
            msg = "You must set both of the min / max values or none of them"
            raise RuntimeError(msg)

    for ob in data.obs:
        in_model = ob[self.noise_model]
        # We will use the best fit parameters from each detector as
        # the starting guess for the next detector.
        params = None
        nse_rate = dict()
        nse_fmin = dict()
        nse_fknee = dict()
        nse_alpha = dict()
        nse_NET = dict()
        nse_indx = dict()

        # We are building a noise model with entries for all local detectors,
        # even ones that are flagged.
        for det in ob.local_detectors:
            freqs = in_model.freq(det)
            in_psd = in_model.psd(det)
            cur_flag = ob.local_detector_flags[det]
            nse_indx[det] = in_model.index(det)
            nse_rate[det] = 2.0 * freqs[-1]
            nse_NET[det] = 0.0 * np.sqrt(1.0 * in_psd.unit)
            nse_fmin[det] = 0.0 * u.Hz
            nse_fknee[det] = 0.0 * u.Hz
            nse_alpha[det] = 0.0
            if cur_flag & self.det_mask != 0:
                continue
            props = self._fit_log_psd(freqs, in_psd, guess=params)
            if props["fit_result"].success:
                # This was a good fit
                params = props["fit_result"].x
            else:
                params = None
                msg = f"FitNoiseModel observation {ob.name}, det {det} failed, "
                msg += f"using white noise with NET = {props['NET']}"
                log.warning(msg)
                msg = f"  Best Result = {props['fit_result']}"
                log.verbose(msg)
                new_flag = cur_flag | self.bad_fit_mask
                ob.update_local_detector_flags({det: new_flag})

            nse_fmin[det] = props["fmin"]
            nse_fknee[det] = props["fknee"]
            nse_alpha[det] = props["alpha"]
            nse_NET[det] = props["NET"]

        new_model = AnalyticNoise(
            detectors=ob.local_detectors,
            rate=nse_rate,
            fmin=nse_fmin,
            fknee=nse_fknee,
            alpha=nse_alpha,
            NET=nse_NET,
            indices=nse_indx,
        )

        if self.out_model is None or self.noise_model == self.out_model:
            # We are replacing the input
            del ob[self.noise_model]
            ob[self.noise_model] = new_model
        else:
            # We are storing this in a new key
            ob[self.out_model] = new_model
    return
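
A hedged usage sketch: the "noise_estim" input key and the `data` object are assumptions (for example, the output of a PSD-estimation step), while the `noise_model` and `out_model` traits are those used in the code above.

    import toast.ops

    fit_model = toast.ops.FitNoiseModel(
        noise_model="noise_estim",   # input Noise object with estimated PSDs
        out_model="noise_fit",       # output AnalyticNoise with NET, fknee, alpha
    )
    fit_model.apply(data)            # data: an existing toast.Data instance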

_finalize(data, **kwargs)

Source code in toast/ops/noise_model.py
def _finalize(self, data, **kwargs):
    return

_fit_log_fun(x, *args, **kwargs)

Evaluate the weighted residual in log space.

For the given set of parameters, this evaluates the model log PSD and computes the residual from the real data. This residual is further weighted so that the better constrained high-frequency values have more significance. We arbitrarily choose a weighting of:

W = f_nyquist - (f_nyquist / (1 + f^2))

Parameters:

    Name     Type    Description                                                            Default
    x        array   The current model parameters                                           required
    kwargs           The fixed information is passed in through the least squares solver.   {}

Returns:

    Type    Description
    array   The array of residuals

Source code in toast/ops/noise_model.py
def _fit_log_fun(self, x, *args, **kwargs):
    """Evaluate the weighted residual in log space.

    For the given set of parameters, this evaluates the model log PSD and computes the
    residual from the real data.  This residual is further weighted so that the better
    constrained high-frequency values have more significance.  We arbitrarily choose a
    weighting of:

        W = f_nyquist - (f_nyquist / (1 + f^2))

    Args:
        x (array):  The current model parameters
        kwargs:  The fixed information is passed in through the least squares solver.

    Returns:
        (array):  The array of residuals

    """
    freqs = kwargs["freqs"]
    logdata = kwargs["logdata"]
    fmin = kwargs["fmin"]
    net = kwargs["net"]
    fknee = x[0]
    alpha = x[1]
    current = self._evaluate_log_model(freqs, fmin, net, fknee, alpha)
    resid = current - logdata
    return resid

_fit_log_jac(x, *args, **kwargs)

Evaluate the partial derivatives of the model.

This returns the Jacobian containing the partial derivatives of the log-space model with respect to the fit parameters.

Parameters:

    Name     Type    Description                                                            Default
    x        array   The current model parameters                                           required
    kwargs           The fixed information is passed in through the least squares solver.   {}

Returns:

    Type    Description
    array   The Jacobian

Source code in toast/ops/noise_model.py
def _fit_log_jac(self, x, *args, **kwargs):
    """Evaluate the partial derivatives of model.

    This returns the Jacobian containing the partial derivatives of the log-space
    model with respect to the fit parameters.

    Args:
        x (array):  The current model parameters
        kwargs:  The fixed information is passed in through the least squares solver.

    Returns:
        (array):  The Jacobian

    """
    freqs = kwargs["freqs"]
    fmin = kwargs["fmin"]
    fknee = x[0]
    alpha = x[1]
    n_freq = len(freqs)

    log_freqs = np.log(freqs)
    f_alpha = np.power(freqs, alpha)
    fknee_alpha = np.power(fknee, alpha)
    fmin_alpha = np.power(fmin, alpha)

    fkalpha = f_alpha + fknee_alpha
    fmalpha = f_alpha + fmin_alpha

    J = np.empty((n_freq, x.size), dtype=np.float64)

    # Partial derivative wrt f_knee
    J[:, 0] = alpha * np.power(fknee, alpha - 1.0) / fkalpha

    # Partial derivative wrt alpha
    J[:, 1] = (f_alpha * log_freqs + fknee_alpha * np.log(fknee)) / fkalpha - (
        f_alpha * log_freqs + fmin_alpha * np.log(fmin)
    ) / fmalpha
    return J
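
Written out, the two columns of J are the analytic partial derivatives of the log model defined earlier (in LaTeX):

    \frac{\partial \ln \mathrm{PSD}}{\partial f_\mathrm{knee}}
        = \frac{\alpha\, f_\mathrm{knee}^{\alpha - 1}}{f^\alpha + f_\mathrm{knee}^\alpha}

    \frac{\partial \ln \mathrm{PSD}}{\partial \alpha}
        = \frac{f^\alpha \ln f + f_\mathrm{knee}^\alpha \ln f_\mathrm{knee}}{f^\alpha + f_\mathrm{knee}^\alpha}
        - \frac{f^\alpha \ln f + f_\mathrm{min}^\alpha \ln f_\mathrm{min}}{f^\alpha + f_\mathrm{min}^\alpha}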

_fit_log_psd(freqs, data, guess=None)

Perform a log-space fit to model PSD parameters.

Parameters:

    Name    Type      Description                     Default
    freqs   Quantity  The frequency values            required
    data    Quantity  The estimated input PSD         required
    guess   array     Optional starting point guess   None

Returns:

    Type    Description
    dict    Dictionary of fit parameters

Source code in toast/ops/noise_model.py
def _fit_log_psd(self, freqs, data, guess=None):
    """Perform a log-space fit to model PSD parameters.

    Args:
        freqs (Quantity):  The frequency values
        data (Quantity):  The estimated input PSD
        guess (array):  Optional starting point guess

    Returns:
        (dict):  Dictionary of fit parameters

    """
    log = Logger.get()
    psd_unit = data.unit
    ret = dict()

    # We cut the lowest frequency bin value, and any leading negative values,
    # since these are usually due to poor estimation.  If the user has specified
    # a maximum frequency for the white noise plateau, then we also stop our
    # fit at that point.
    raw_freqs = freqs.to_value(u.Hz)
    raw_data = data.value
    n_raw = len(raw_data)
    n_skip = 1
    while n_skip < n_raw and raw_data[n_skip] <= 0:
        n_skip += 1
    if n_skip == n_raw:
        msg = f"All {n_raw} PSD values were negative.  Giving up."
        log.warning(msg)
        ret = self._get_err_ret(psd_unit)
        return ret

    n_trim = 0
    if self.white_noise_max is not None:
        max_hz = self.white_noise_max.to_value(u.Hz)
        for f in raw_freqs:
            if f > max_hz:
                n_trim += 1

    if n_skip + n_trim >= n_raw:
        msg = f"All {n_raw} PSD values either negative or above plateau."
        log.warning(msg)
        ret = self._get_err_ret(psd_unit)
        return ret

    input_freqs = raw_freqs[n_skip : n_raw - n_trim]
    input_data = raw_data[n_skip : n_raw - n_trim]
    # Force all points to be positive
    good = input_data > 0
    if np.count_nonzero(good) == 0:
        # All PSD values zero, must be flagged
        msg = f"All PSD values zero, skipping fit."
        log.warning(msg)
        ret = self._get_err_ret(psd_unit)
        return ret
    bad = np.logical_not(good)
    n_bad = np.count_nonzero(bad)
    if n_bad > 0:
        msg = "Some PSDs have negative values.  Consider changing "
        msg += "noise estimation parameters."
        log.warning(msg)
    good_min = np.min(input_data[good])
    input_data[bad] = 1.0e-6 * good_min
    input_log_data = np.log(input_data)

    raw_fmin = self.f_min.to_value(u.Hz)

    if self.white_noise_max is None:
        net = self._estimate_net(input_freqs, input_data)
    else:
        plateau_samples = np.logical_and(
            (input_freqs > self.white_noise_min.to_value(u.Hz)),
            (input_freqs < self.white_noise_max.to_value(u.Hz)),
        )
        net = np.sqrt(np.mean(input_data[plateau_samples]))

    midfreq = 0.5 * input_freqs[-1]

    bounds = (
        np.array([input_freqs[0], 0.1]),
        np.array([input_freqs[-1], 10.0]),
    )
    x_0 = guess
    if x_0 is None:
        x_0 = np.array([midfreq, 1.0])

    try:
        result = least_squares(
            self._fit_log_fun,
            x_0,
            jac=self._fit_log_jac,
            bounds=bounds,
            xtol=self.least_squares_xtol,
            gtol=self.least_squares_gtol,
            ftol=self.least_squares_ftol,
            max_nfev=500,
            verbose=0,
            kwargs={
                "freqs": input_freqs,
                "logdata": input_log_data,
                "fmin": raw_fmin,
                "net": net,
            },
        )
    except Exception:
        log.verbose(f"PSD fit raised exception, skipping")
        ret = self._get_err_ret(psd_unit)
        return ret

    ret["fit_result"] = result
    ret["NET"] = net * np.sqrt(1.0 * psd_unit)
    ret["fmin"] = self.f_min
    if result.success:
        ret["fknee"] = result.x[0] * u.Hz
        ret["alpha"] = result.x[1]
    else:
        ret["fknee"] = 0.0 * u.Hz
        ret["alpha"] = 1.0

    return ret
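
A self-contained sketch of the same bounded log-space fit on synthetic data. The helper name `log_model` and the parameter values are hypothetical; the bounds and starting point mirror the code above.

    import numpy as np
    from scipy.optimize import least_squares

    def log_model(f, fmin, net, fknee, alpha):
        # ln(PSD) = 2 ln(NET) + ln(f^a + fknee^a) - ln(f^a + fmin^a)
        fa = np.power(f, alpha)
        return 2.0 * np.log(net) + np.log(fa + fknee**alpha) - np.log(fa + fmin**alpha)

    freqs = np.linspace(0.01, 20.0, 512)
    net, fmin = 1.0e-4, 1.0e-5
    true_fknee, true_alpha = 0.3, 1.8
    logdata = log_model(freqs, fmin, net, true_fknee, true_alpha)

    result = least_squares(
        lambda x: log_model(freqs, fmin, net, x[0], x[1]) - logdata,
        x0=[0.5 * freqs[-1], 1.0],                    # midfreq, alpha guess
        bounds=([freqs[0], 0.1], [freqs[-1], 10.0]),  # same bounds as above
        max_nfev=500,
    )
    print(result.x)   # approximately [0.3, 1.8]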

_get_err_ret(psd_unit)

Source code in toast/ops/noise_model.py
def _get_err_ret(self, psd_unit):
    # Internal function to build a fake return result
    # when the fitting fails for some reason.
    eret = dict()
    eret["fit_result"] = types.SimpleNamespace()
    eret["fit_result"].success = False
    eret["NET"] = 0.0 * np.sqrt(1.0 * psd_unit)
    eret["fmin"] = 0.0 * u.Hz
    eret["fknee"] = 0.0 * u.Hz
    eret["alpha"] = 0.0
    return eret

_provides()

Source code in toast/ops/noise_model.py
def _provides(self):
    prov = {"meta": [self.noise_model]}
    return prov

_requires()

Source code in toast/ops/noise_model.py
def _requires(self):
    return dict()

toast.ops.FlagNoiseFit

Bases: Operator

Operator which flags detectors that have outlier noise properties.

Source code in toast/ops/noise_model.py
@trait_docs
class FlagNoiseFit(Operator):
    """Operator which flags detectors that have outlier noise properties."""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    noise_model = Unicode(
        "noise_model", help="The observation key containing the noise model"
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    det_data = Unicode(
        defaults.det_data,
        help="Observation detdata key for timestreams (only if RMS cut enabled)",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging (only if RMS cut used)",
    )

    outlier_flag_mask = Int(
        defaults.det_mask_processing, help="Bit mask to raise flags with"
    )

    sigma_rms = Float(
        None,
        allow_none=True,
        help="In addition to flagging based on estimated model, also apply overall TOD cut",
    )

    sigma_NET = Float(5.0, help="Flag detectors with NET values outside this range")

    sigma_fknee = Float(
        None,
        allow_none=True,
        help="Flag detectors with knee frequency values outside this range",
    )

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        if self.det_flags is None:
            raise RuntimeError("You must set det_flags before calling exec()")

        for obs in data.obs:
            dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            if self.noise_model not in obs:
                msg = f"Observation {obs.name} does not contain noise model {self.noise_model}"
                raise RuntimeError(msg)

            exists = obs.detdata.ensure(self.det_flags, dtype=np.uint8, detectors=dets)

            model = obs[self.noise_model]

            local_net = list()
            local_fknee = list()
            local_rms = list()

            # If we have an analytic noise model from a simulation or fit, then we can
            # access the properties directly.  If not, we will use the detector weight
            # as a proxy for the NET and make a crude estimate of the knee frequency.

            for det in dets:
                try:
                    NET = model.NET(det)
                except AttributeError:
                    wt = model.detector_weight(det)
                    NET = np.sqrt(1.0 / (wt * model.rate(det)))
                local_net.append(NET.to_value(u.K * np.sqrt(1.0 * u.second)))
                if self.sigma_fknee is not None:
                    try:
                        fknee = model.fknee(det)
                        local_fknee.append(fknee.to_value(u.Hz))
                    except AttributeError:
                        msg = f"Observation {obs.name}, noise model {self.noise_model} "
                        msg += "has no f_knee estimate.  Use FitNoiseModel before flagging."
                if self.sigma_rms is not None:
                    good = (obs.detdata[self.det_flags][det] & self.det_flag_mask) == 0
                    ddata = np.copy(obs.detdata[self.det_data][det, good])
                    avg = np.mean(ddata)
                    ddata -= avg
                    local_rms.append(np.std(ddata))
                    del ddata

            local_net = np.array(local_net, dtype=np.float64)
            local_fknee = np.array(local_fknee, dtype=np.float64)
            local_rms = np.array(local_rms, dtype=np.float64)
            local_names = dets

            # Send all values to one process for the trivial calculation
            all_net = None
            all_fknee = None
            all_rms = None
            all_names = None
            if obs.comm_row_rank == 0:
                # First process column.  Gather results to rank zero.
                if obs.comm_col is None:
                    all_net = local_net
                    all_fknee = local_fknee
                    all_names = local_names
                    all_rms = local_rms
                else:
                    proc_vals = obs.comm_col.gather(local_net, root=0)
                    if obs.comm_col_rank == 0:
                        all_net = np.array(list(flatten(proc_vals)))
                    proc_vals = obs.comm_col.gather(local_fknee, root=0)
                    if obs.comm_col_rank == 0:
                        all_fknee = np.array(list(flatten(proc_vals)))
                    proc_vals = obs.comm_col.gather(local_rms, root=0)
                    if obs.comm_col_rank == 0:
                        all_rms = np.array(list(flatten(proc_vals)))
                    proc_vals = obs.comm_col.gather(local_names, root=0)
                    if obs.comm_col_rank == 0:
                        all_names = list(flatten(proc_vals))

            # Iteratively cut
            all_flags = None
            if obs.comm.group_rank == 0:
                all_good = all_net > 0.0
                n_good_fit = np.count_nonzero(all_good)
                msg = f"obs {obs.name}: {n_good_fit} / {len(all_good)} "
                msg += "detectors have valid noise model"
                log.debug(msg)
                n_cut = 1
                flag_pass = 0
                while n_cut > 0:
                    n_cut = 0
                    net_mean = np.mean(all_net[all_good])
                    net_std = np.std(all_net[all_good])
                    for idet, (name, net) in enumerate(zip(all_names, all_net)):
                        if not all_good[idet]:
                            # Already cut
                            continue
                        if np.absolute(net - net_mean) > net_std * self.sigma_NET:
                            msg = f"obs {obs.name}, det {name} has NET "
                            msg += f"{net} that is > {self.sigma_NET} "
                            msg += f"x {net_std} from {net_mean}"
                            log.debug(msg)
                            all_good[idet] = False
                            n_cut += 1
                    if self.sigma_fknee is not None:
                        fknee_mean = np.mean(all_fknee[all_good])
                        fknee_std = np.std(all_fknee[all_good])
                        for idet, (name, fknee) in enumerate(zip(all_names, all_fknee)):
                            if not all_good[idet]:
                                # Already cut
                                continue
                            if (
                                np.absolute(fknee - fknee_mean)
                                > fknee_std * self.sigma_fknee
                            ):
                                msg = f"obs {obs.name}, det {name} has f_knee "
                                msg += f"{fknee} that is > {self.sigma_fknee} "
                                msg += f"x {fknee_std} from {fknee_mean}"
                                log.debug(msg)
                                all_good[idet] = False
                                n_cut += 1
                    if self.sigma_rms is not None:
                        rms_mean = np.mean(all_rms[all_good])
                        rms_std = np.std(all_rms[all_good])
                        for idet, (name, rms) in enumerate(zip(all_names, all_rms)):
                            if not all_good[idet]:
                                # Already cut
                                continue
                            if np.absolute(rms - rms_mean) > rms_std * self.sigma_rms:
                                msg = f"obs {obs.name}, det {name} has TOD RMS "
                                msg += f"{rms} that is > {self.sigma_rms} "
                                msg += f"x {rms_std} from {rms_mean}"
                                log.debug(msg)
                                all_good[idet] = False
                                n_cut += 1
                    msg = f"pass {flag_pass}, {n_cut} detectors flagged"
                    log.debug(msg)
                    flag_pass += 1
                all_flags = {
                    x: self.outlier_flag_mask
                    for i, x in enumerate(all_names)
                    if not all_good[i]
                }
                msg = f"obs {obs.name}: flagged {len(all_flags)} / {len(all_names)}"
                msg += " outlier detectors"
                log.info(msg)
            if obs.comm.comm_group is not None:
                all_flags = obs.comm.comm_group.bcast(all_flags, root=0)

            # Every process flags its local detectors
            det_check = set(dets)
            local_flags = dict(obs.local_detector_flags)
            for det, val in all_flags.items():
                if det in det_check:
                    local_flags[det] |= val
                    obs.detdata[self.det_flags][det, :] |= val
            obs.update_local_detector_flags(local_flags)

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": [
                self.noise_model,
            ],
            "shared": [],
            "detdata": [],
            "intervals": [],
        }
        return req

    def _provides(self):
        prov = {
            "meta": [],
            "shared": [],
            "detdata": [self.det_flags],
        }
        return prov
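
A hedged usage sketch chaining this operator after FitNoiseModel; the key names and the `data` object are assumptions, while the traits are those listed below.

    import toast.ops

    flagger = toast.ops.FlagNoiseFit(
        noise_model="noise_fit",   # e.g. the output of FitNoiseModel
        sigma_NET=5.0,             # cut NET outliers beyond 5 sigma
        sigma_fknee=5.0,           # optionally also cut f_knee outliers
    )
    flagger.apply(data)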

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key for timestreams (only if RMS cut enabled)') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging (only if RMS cut used)') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

noise_model = Unicode('noise_model', help='The observation key containing the noise model') class-attribute instance-attribute

outlier_flag_mask = Int(defaults.det_mask_processing, help='Bit mask to raise flags with') class-attribute instance-attribute

sigma_NET = Float(5.0, help='Flag detectors with NET values outside this range') class-attribute instance-attribute

sigma_fknee = Float(None, allow_none=True, help='Flag detectors with knee frequency values outside this range') class-attribute instance-attribute

sigma_rms = Float(None, allow_none=True, help='In addition to flagging based on estimated model, also apply overall TOD cut') class-attribute instance-attribute

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/noise_model.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    if self.det_flags is None:
        raise RuntimeError("You must set det_flags before calling exec()")

    for obs in data.obs:
        dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        if self.noise_model not in obs:
            msg = f"Observation {obs.name} does not contain noise model {self.noise_model}"
            raise RuntimeError(msg)

        exists = obs.detdata.ensure(self.det_flags, dtype=np.uint8, detectors=dets)

        model = obs[self.noise_model]

        local_net = list()
        local_fknee = list()
        local_rms = list()

        # If we have an analytic noise model from a simulation or fit, then we can
        # access the properties directly.  If not, we will use the detector weight
        # as a proxy for the NET and make a crude estimate of the knee frequency.

        for det in dets:
            try:
                NET = model.NET(det)
            except AttributeError:
                wt = model.detector_weight(det)
                NET = np.sqrt(1.0 / (wt * model.rate(det)))
            local_net.append(NET.to_value(u.K * np.sqrt(1.0 * u.second)))
            if self.sigma_fknee is not None:
                try:
                    fknee = model.fknee(det)
                    local_fknee.append(fknee.to_value(u.Hz))
                except AttributeError:
                    msg = f"Observation {obs.name}, noise model {self.noise_model} "
                    msg += "has no f_knee estimate.  Use FitNoiseModel before flagging."
            if self.sigma_rms is not None:
                good = (obs.detdata[self.det_flags][det] & self.det_flag_mask) == 0
                ddata = np.copy(obs.detdata[self.det_data][det, good])
                avg = np.mean(ddata)
                ddata -= avg
                local_rms.append(np.std(ddata))
                del ddata

        local_net = np.array(local_net, dtype=np.float64)
        local_fknee = np.array(local_fknee, dtype=np.float64)
        local_rms = np.array(local_rms, dtype=np.float64)
        local_names = dets

        # Send all values to one process for the trivial calculation
        all_net = None
        all_fknee = None
        all_rms = None
        all_names = None
        if obs.comm_row_rank == 0:
            # First process column.  Gather results to rank zero.
            if obs.comm_col is None:
                all_net = local_net
                all_fknee = local_fknee
                all_names = local_names
                all_rms = local_rms
            else:
                proc_vals = obs.comm_col.gather(local_net, root=0)
                if obs.comm_col_rank == 0:
                    all_net = np.array(list(flatten(proc_vals)))
                proc_vals = obs.comm_col.gather(local_fknee, root=0)
                if obs.comm_col_rank == 0:
                    all_fknee = np.array(list(flatten(proc_vals)))
                proc_vals = obs.comm_col.gather(local_rms, root=0)
                if obs.comm_col_rank == 0:
                    all_rms = np.array(list(flatten(proc_vals)))
                proc_vals = obs.comm_col.gather(local_names, root=0)
                if obs.comm_col_rank == 0:
                    all_names = list(flatten(proc_vals))

        # Iteratively cut
        all_flags = None
        if obs.comm.group_rank == 0:
            all_good = all_net > 0.0
            n_good_fit = np.count_nonzero(all_good)
            msg = f"obs {obs.name}: {n_good_fit} / {len(all_good)} "
            msg += "detectors have valid noise model"
            log.debug(msg)
            n_cut = 1
            flag_pass = 0
            while n_cut > 0:
                n_cut = 0
                net_mean = np.mean(all_net[all_good])
                net_std = np.std(all_net[all_good])
                for idet, (name, net) in enumerate(zip(all_names, all_net)):
                    if not all_good[idet]:
                        # Already cut
                        continue
                    if np.absolute(net - net_mean) > net_std * self.sigma_NET:
                        msg = f"obs {obs.name}, det {name} has NET "
                        msg += f"{net} that is > {self.sigma_NET} "
                        msg += f"x {net_std} from {net_mean}"
                        log.debug(msg)
                        all_good[idet] = False
                        n_cut += 1
                if self.sigma_fknee is not None:
                    fknee_mean = np.mean(all_fknee[all_good])
                    fknee_std = np.std(all_fknee[all_good])
                    for idet, (name, fknee) in enumerate(zip(all_names, all_fknee)):
                        if not all_good[idet]:
                            # Already cut
                            continue
                        if (
                            np.absolute(fknee - fknee_mean)
                            > fknee_std * self.sigma_fknee
                        ):
                            msg = f"obs {obs.name}, det {name} has f_knee "
                            msg += f"{fknee} that is > {self.sigma_fknee} "
                            msg += f"x {fknee_std} from {fknee_mean}"
                            log.debug(msg)
                            all_good[idet] = False
                            n_cut += 1
                if self.sigma_rms is not None:
                    rms_mean = np.mean(all_rms[all_good])
                    rms_std = np.std(all_rms[all_good])
                    for idet, (name, rms) in enumerate(zip(all_names, all_rms)):
                        if not all_good[idet]:
                            # Already cut
                            continue
                        if np.absolute(rms - rms_mean) > rms_std * self.sigma_rms:
                            msg = f"obs {obs.name}, det {name} has TOD RMS "
                            msg += f"{rms} that is > {self.sigma_rms} "
                            msg += f"x {rms_std} from {rms_mean}"
                            log.debug(msg)
                            all_good[idet] = False
                            n_cut += 1
                msg = f"pass {flag_pass}, {n_cut} detectors flagged"
                log.debug(msg)
                flag_pass += 1
            all_flags = {
                x: self.outlier_flag_mask
                for i, x in enumerate(all_names)
                if not all_good[i]
            }
            msg = f"obs {obs.name}: flagged {len(all_flags)} / {len(all_names)}"
            msg += " outlier detectors"
            log.info(msg)
        if obs.comm.comm_group is not None:
            all_flags = obs.comm.comm_group.bcast(all_flags, root=0)

        # Every process flags its local detectors
        det_check = set(dets)
        local_flags = dict(obs.local_detector_flags)
        for det, val in all_flags.items():
            if det in det_check:
                local_flags[det] |= val
                obs.detdata[self.det_flags][det, :] |= val
        obs.update_local_detector_flags(local_flags)

_finalize(data, **kwargs)

Source code in toast/ops/noise_model.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/noise_model.py
def _provides(self):
    prov = {
        "meta": [],
        "shared": [],
        "detdata": [self.det_flags],
    }
    return prov

_requires()

Source code in toast/ops/noise_model.py
def _requires(self):
    req = {
        "meta": [
            self.noise_model,
        ],
        "shared": [],
        "detdata": [],
        "intervals": [],
    }
    return req

toast.ops.SignalDiffNoiseModel

Bases: Operator

Evaluate a simple white noise model based on consecutive sample differences.

Source code in toast/ops/signal_diff_noise_model.py
@trait_docs
class SignalDiffNoiseModel(Operator):
    """Evaluate a simple white noise model based on consecutive sample
    differences.
    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(defaults.det_data, help="Observation detdata key to analyze")

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional shared flagging",
    )

    noise_model = Unicode(
        "noise_model", help="The observation key containing the output noise model"
    )

    view = Unicode(
        None,
        allow_none=True,
        help="Evaluate the sample differences in this view",
    )

    fmin = Quantity(1e-6 * u.Hz, help="Minimum frequency to use for noise model.")

    fknee = Quantity(1e-6 * u.Hz, help="Knee frequency to use for noise model.")

    alpha = Float(1.0, help="Slope of the 1/f noise model")

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.net_factors = []
        self.total_factors = []
        self.weights_in = []
        self.weights_out = []
        self.rates = []

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        if detectors is not None:
            msg = "You must run this operator on all detectors at once"
            log.error(msg)
            raise RuntimeError(msg)

        for ob in data.obs:
            if not ob.is_distributed_by_detector:
                msg = "Observation data must be distributed by detector, not samples"
                log.error(msg)
                raise RuntimeError(msg)
            focalplane = ob.telescope.focalplane
            fsample = focalplane.sample_rate

            shared_flags = ob.shared[self.shared_flags].data & self.shared_flag_mask
            signal_units = ob.detdata[self.det_data].units

            # Create the noise model for all detectors, even flagged ones.
            dets = []
            fmin = {}
            fknee = {}
            alpha = {}
            NET = {}
            rates = {}
            indices = {}
            for name in ob.local_detectors:
                dets.append(name)
                rates[name] = fsample
                fmin[name] = self.fmin
                fknee[name] = self.fknee
                alpha[name] = self.alpha
                NET[name] = 0.0 * signal_units / np.sqrt(fsample)
                indices[name] = focalplane[name]["uid"]

            # Set the NET for the good detectors
            for name in ob.select_local_detectors(flagmask=defaults.det_mask_invalid):
                # Estimate white noise from consecutive sample differences.
                # Neither of the samples can have flags raised.
                sig = ob.detdata[self.det_data][name]
                det_flags = ob.detdata[self.det_flags][name] & self.det_flag_mask
                good = np.logical_and(shared_flags == 0, det_flags == 0)
                sig_diff = sig[1:] - sig[:-1]
                good_diff = np.logical_and(good[1:], good[:-1])
                sigma = np.std(sig_diff[good_diff]) / np.sqrt(2) * signal_units
                net = sigma / np.sqrt(fsample)
                # Store the estimate in a noise model
                NET[name] = net

            ob[self.noise_model] = AnalyticNoise(
                rate=rates,
                fmin=fmin,
                detectors=dets,
                fknee=fknee,
                alpha=alpha,
                NET=NET,
                indices=indices,
            )

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": list(),
            "shared": list(),
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "meta": [self.noise_model],
            "shared": list(),
            "detdata": list(),
            "intervals": list(),
        }
        return prov
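
The estimator relies on the fact that for white noise with per-sample standard deviation sigma, consecutive differences have standard deviation sigma * sqrt(2). A quick NumPy check (a sketch with arbitrary values):

    import numpy as np

    rng = np.random.default_rng(0)
    fsample, sigma = 100.0, 2.5e-4             # sample rate in Hz, arbitrary units
    sig = rng.normal(0.0, sigma, 100_000)

    # std(diff) / sqrt(2) recovers sigma; NET = sigma / sqrt(fsample)
    sigma_hat = np.std(np.diff(sig)) / np.sqrt(2.0)
    net = sigma_hat / np.sqrt(fsample)
    print(sigma_hat, net)                      # sigma_hat ~ 2.5e-4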

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

alpha = Float(1.0, help='Slope of the 1/f noise model') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key to analyze') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

fknee = Quantity(1e-06 * u.Hz, help='Knee frequency to use for noise model.') class-attribute instance-attribute

fmin = Quantity(1e-06 * u.Hz, help='Minimum frequency to use for noise model.') class-attribute instance-attribute

net_factors = [] instance-attribute

noise_model = Unicode('noise_model', help='The observation key containing the output noise model') class-attribute instance-attribute

rates = [] instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

total_factors = [] instance-attribute

view = Unicode(None, allow_none=True, help='Evaluate the sample differences in this view') class-attribute instance-attribute

weights_in = [] instance-attribute

weights_out = [] instance-attribute

__init__(**kwargs)

Source code in toast/ops/signal_diff_noise_model.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.net_factors = []
    self.total_factors = []
    self.weights_in = []
    self.weights_out = []
    self.rates = []

_check_det_flag_mask(proposal)

Source code in toast/ops/signal_diff_noise_model.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/signal_diff_noise_model.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/signal_diff_noise_model.py
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/signal_diff_noise_model.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    if detectors is not None:
        msg = "You must run this operator on all detectors at once"
        log.error(msg)
        raise RuntimeError(msg)

    for ob in data.obs:
        if not ob.is_distributed_by_detector:
            msg = "Observation data must be distributed by detector, not samples"
            log.error(msg)
            raise RuntimeError(msg)
        focalplane = ob.telescope.focalplane
        fsample = focalplane.sample_rate

        shared_flags = ob.shared[self.shared_flags].data & self.shared_flag_mask
        signal_units = ob.detdata[self.det_data].units

        # Create the noise model for all detectors, even flagged ones.
        dets = []
        fmin = {}
        fknee = {}
        alpha = {}
        NET = {}
        rates = {}
        indices = {}
        for name in ob.local_detectors:
            dets.append(name)
            rates[name] = fsample
            fmin[name] = self.fmin
            fknee[name] = self.fknee
            alpha[name] = self.alpha
            NET[name] = 0.0 * signal_units / np.sqrt(fsample)
            indices[name] = focalplane[name]["uid"]

        # Set the NET for the good detectors
        for name in ob.select_local_detectors(flagmask=defaults.det_mask_invalid):
            # Estimate white noise from consecutive sample differences.
            # Neither of the samples can have flags raised.
            sig = ob.detdata[self.det_data][name]
            det_flags = ob.detdata[self.det_flags][name] & self.det_flag_mask
            good = np.logical_and(shared_flags == 0, det_flags == 0)
            sig_diff = sig[1:] - sig[:-1]
            good_diff = np.logical_and(good[1:], good[:-1])
            sigma = np.std(sig_diff[good_diff]) / np.sqrt(2) * signal_units
            net = sigma / np.sqrt(fsample)
            # Store the estimate in a noise model
            NET[name] = net

        ob[self.noise_model] = AnalyticNoise(
            rate=rates,
            fmin=fmin,
            detectors=dets,
            fknee=fknee,
            alpha=alpha,
            NET=NET,
            indices=indices,
        )

    return

_finalize(data, **kwargs)

Source code in toast/ops/signal_diff_noise_model.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/signal_diff_noise_model.py
def _provides(self):
    prov = {
        "meta": [self.noise_model],
        "shared": list(),
        "detdata": list(),
        "intervals": list(),
    }
    return prov

_requires()

Source code in toast/ops/signal_diff_noise_model.py
def _requires(self):
    req = {
        "meta": list(),
        "shared": list(),
        "detdata": [self.det_data],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

Map Making

Utilities

toast.ops.BuildPixelDistribution

Bases: Operator

Operator which builds the pixel distribution information.

This operator runs the pointing operator and builds the PixelDist instance describing how submaps are distributed among processes. This requires expanding the full detector pointing once in order to compute the distribution. This is done one detector at a time unless the save_pointing trait is set to True.

NOTE: The pointing operator must have the "pixels" and "create_dist" traits, which will be set by this operator during execution.

Output PixelDistribution objects are stored in the Data dictionary.
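
A hedged usage sketch: the pixelization and detector-pointing operator names below are the standard toast.ops ones, but the exact traits of those operators are assumptions here.

    import toast.ops

    pixels = toast.ops.PixelsHealpix(
        detector_pointing=toast.ops.PointingDetectorSimple(),
        nside=512,
    )
    build_dist = toast.ops.BuildPixelDistribution(
        pixel_dist="pixel_dist",
        pixel_pointing=pixels,
        save_pointing=False,   # expand pointing one detector at a time, then discard
    )
    build_dist.apply(data)     # data["pixel_dist"] now holds the PixelDistribution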

Source code in toast/ops/pointing.py
@trait_docs
class BuildPixelDistribution(Operator):
    """Operator which builds the pixel distribution information.

    This operator runs the pointing operator and builds the PixelDist instance
    describing how submaps are distributed among processes.  This requires expanding
    the full detector pointing once in order to compute the distribution.  This is
    done one detector at a time unless the save_pointing trait is set to True.

    NOTE:  The pointing operator must have the "pixels" and "create_dist"
    traits, which will be set by this operator during execution.

    Output PixelDistribution objects are stored in the Data dictionary.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    pixel_dist = Unicode(
        "pixel_dist",
        help="The Data key where the PixelDist object should be stored",
    )

    pixel_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a pointing operator",
    )

    save_pointing = Bool(
        False, help="If True, do not clear detector pointing matrices after use"
    )

    @traitlets.validate("pixel_pointing")
    def _check_pixel_pointing(self, proposal):
        pntg = proposal["value"]
        if pntg is not None:
            if not isinstance(pntg, Operator):
                raise traitlets.TraitError(
                    "pixel_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["pixels", "create_dist", "view"]:
                if not pntg.has_trait(trt):
                    msg = f"pixel_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return pntg

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        for trait in ("pixel_pointing",):
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        if self.pixel_dist in data:
            msg = f"pixel distribution `{self.pixel_dist}` already exists"
            raise RuntimeError(msg)

        if detectors is not None:
            msg = "A subset of detectors is specified, but the pixel distribution\n"
            msg += "does not yet exist- and creating this requires all detectors."
            raise RuntimeError(msg)

        msg = "Creating pixel distribution '{}' in Data".format(self.pixel_dist)
        if data.comm.world_rank == 0:
            log.debug(msg)

        # Turn on creation of the pixel distribution
        self.pixel_pointing.create_dist = self.pixel_dist

        # Compute the pointing matrix

        pixel_dist_pipe = None
        if self.save_pointing:
            # We are keeping the pointing, which means we need to run all detectors
            # at once so they all end up in the detdata for all observations.
            pixel_dist_pipe = Pipeline(detector_sets=["ALL"])
        else:
            # Run one detector at a time and discard.
            pixel_dist_pipe = Pipeline(detector_sets=["SINGLE"])
        pixel_dist_pipe.operators = [
            self.pixel_pointing,
        ]
        # FIXME: Disable accelerator use for now, since it is a small amount of
        # calculation for a huge data volume.
        pipe_out = pixel_dist_pipe.apply(data, detectors=detectors, use_accel=False)

        # Turn pixel distribution creation off again
        self.pixel_pointing.create_dist = None

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = self.pixel_pointing.requires()
        return req

    def _provides(self):
        prov = {
            "global": [self.pixel_dist],
            "shared": list(),
            "detdata": list(),
        }
        if self.save_pointing:
            prov["detdata"].extend([self.pixels])
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

pixel_dist = Unicode('pixel_dist', help='The Data key where the PixelDist object should be stored') class-attribute instance-attribute

pixel_pointing = Instance(klass=Operator, allow_none=True, help='This must be an instance of a pointing operator') class-attribute instance-attribute

save_pointing = Bool(False, help='If True, do not clear detector pointing matrices after use') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/pointing.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_pixel_pointing(proposal)

Source code in toast/ops/pointing.py
@traitlets.validate("pixel_pointing")
def _check_pixel_pointing(self, proposal):
    pntg = proposal["value"]
    if pntg is not None:
        if not isinstance(pntg, Operator):
            raise traitlets.TraitError(
                "pixel_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["pixels", "create_dist", "view"]:
            if not pntg.has_trait(trt):
                msg = f"pixel_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return pntg

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/pointing.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    for trait in ("pixel_pointing",):
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    if self.pixel_dist in data:
        msg = f"pixel distribution `{self.pixel_dist}` already exists"
        raise RuntimeError(msg)

    if detectors is not None:
        msg = "A subset of detectors is specified, but the pixel distribution\n"
        msg += "does not yet exist- and creating this requires all detectors."
        raise RuntimeError(msg)

    msg = "Creating pixel distribution '{}' in Data".format(self.pixel_dist)
    if data.comm.world_rank == 0:
        log.debug(msg)

    # Turn on creation of the pixel distribution
    self.pixel_pointing.create_dist = self.pixel_dist

    # Compute the pointing matrix

    pixel_dist_pipe = None
    if self.save_pointing:
        # We are keeping the pointing, which means we need to run all detectors
        # at once so they all end up in the detdata for all observations.
        pixel_dist_pipe = Pipeline(detector_sets=["ALL"])
    else:
        # Run one detector at a time and discard.
        pixel_dist_pipe = Pipeline(detector_sets=["SINGLE"])
    pixel_dist_pipe.operators = [
        self.pixel_pointing,
    ]
    # FIXME: Disable accelerator use for now, since it is a small amount of
    # calculation for a huge data volume.
    pipe_out = pixel_dist_pipe.apply(data, detectors=detectors, use_accel=False)

    # Turn pixel distribution creation off again
    self.pixel_pointing.create_dist = None

    return

_finalize(data, **kwargs)

Source code in toast/ops/pointing.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/pointing.py
def _provides(self):
    prov = {
        "global": [self.pixel_dist],
        "shared": list(),
        "detdata": list(),
    }
    if self.save_pointing:
        prov["detdata"].extend([self.pixels])
    return prov

_requires()

Source code in toast/ops/pointing.py
def _requires(self):
    req = self.pixel_pointing.requires()
    return req

toast.ops.BuildHitMap

Bases: Operator

Operator which builds a hitmap.

Given the pointing matrix for each detector, accumulate the hit map. The PixelData object containing the hit map is returned by the finalize() method.

If any samples have compromised telescope pointing, those pixel indices should have already been set to a negative value by the operator that generated the pointing matrix.

Although individual detector flags do not impact the pointing per se, they can be used with this operator in order to produce a hit map that is consistent with other pixel space products. The detector mask defaults to cutting "non-science" samples.
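
A hedged usage sketch continuing from the distribution built above. The key names are assumptions, the traits are those listed below, and this assumes the pixel indices have already been expanded into detdata (for example by running the pixelization operator beforehand).

    import toast.ops

    build_hits = toast.ops.BuildHitMap(
        pixel_dist="pixel_dist",   # submap distribution created earlier
        hits="hits",               # output Data key for the hit map
        sync_type="alltoallv",
    )
    build_hits.apply(data)
    hitmap = data["hits"]          # PixelData object with per-pixel hit counts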

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@trait_docs
class BuildHitMap(Operator):
    """Operator which builds a hitmap.

    Given the pointing matrix for each detector, accumulate the hit map.  The PixelData
    object containing the hit map is returned by the finalize() method.

    If any samples have compromised telescope pointing, those pixel indices should
    have already been set to a negative value by the operator that generated the
    pointing matrix.

    Although individual detector flags do not impact the pointing per se, they can be
    used with this operator in order to produce a hit map that is consistent with other
    pixel space products.  The detector mask defaults to cutting "non-science" samples.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    pixel_dist = Unicode(
        None,
        allow_none=True,
        help="The Data key containing the submap distribution",
    )

    hits = Unicode("hits", help="The Data key for the output hit map")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional flagging",
    )

    pixels = Unicode(defaults.pixels, help="Observation detdata key for pixel indices")

    sync_type = Unicode(
        "alltoallv", help="Communication algorithm: 'allreduce' or 'alltoallv'"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("sync_type")
    def _check_sync_type(self, proposal):
        check = proposal["value"]
        if check != "allreduce" and check != "alltoallv":
            raise traitlets.TraitError("Invalid communication algorithm")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        if self.pixel_dist is None:
            raise RuntimeError(
                "You must set the 'pixel_dist' trait before calling exec()"
            )

        if self.pixel_dist not in data:
            msg = "Data does not contain submap distribution '{}'".format(
                self.pixel_dist
            )
            raise RuntimeError(msg)

        dist = data[self.pixel_dist]
        log.verbose_rank(
            f"Building hit map with pixel_distribution {self.pixel_dist}",
            comm=data.comm.comm_world,
        )

        hits = None
        if self.hits in data:
            # We have an existing map from a previous call.  Verify
            # the distribution and nnz.
            if data[self.hits].distribution != dist:
                msg = "Existing hits '{}' has different data distribution".format(
                    self.hits
                )
                log.error(msg)
                raise RuntimeError(msg)
            if data[self.hits].n_value != 1:
                msg = "Existing hits '{}' has {} nnz, not 1".format(
                    self.hits, data[self.hits].n_value
                )
                log.error(msg)
                raise RuntimeError(msg)
            hits = data[self.hits]
        else:
            data[self.hits] = PixelData(dist, np.int64, n_value=1)
            hits = data[self.hits]

        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                selection=detectors, flagmask=self.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            # The pixels and weights view for this observation
            pix = ob.view[self.view].detdata[self.pixels]
            if self.det_flags is not None:
                flgs = ob.view[self.view].detdata[self.det_flags]
            else:
                flgs = [None for x in pix]
            if self.shared_flags is not None:
                shared_flgs = ob.view[self.view].shared[self.shared_flags]
            else:
                shared_flgs = [None for x in pix]

            for det in dets:
                # Process every data view
                for pview, fview, shared_fview in zip(pix, flgs, shared_flgs):
                    # Get local submap and pixels
                    local_sm, local_pix = dist.global_pixel_to_submap(pview[det])

                    # Samples with telescope pointing problems are already flagged in
                    # the pointing operators by setting the pixel numbers to a negative
                    # value.  Here we optionally apply detector flags to the local
                    # pixel numbers to flag more samples.

                    # Apply the flags if needed
                    if self.det_flags is not None:
                        local_pix[(fview[det] & self.det_flag_mask) != 0] = -1
                    if self.shared_flags is not None:
                        local_pix[(shared_fview & self.shared_flag_mask) != 0] = -1

                    cov_accum_diag_hits(
                        dist.n_local_submap,
                        dist.n_pix_submap,
                        1,
                        local_sm.astype(np.int64),
                        local_pix.astype(np.int64),
                        hits.raw,
                        impl=implementation,
                        use_accel=use_accel,
                    )
        return

    @function_timer
    def _finalize(self, data, **kwargs):
        if self.hits in data:
            if self.sync_type == "alltoallv":
                data[self.hits].sync_alltoallv()
            else:
                data[self.hits].sync_allreduce()
        return

    def _requires(self):
        req = {
            "global": [self.pixel_dist],
            "shared": list(),
            "detdata": [self.pixels],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {"global": [self.hits]}
        return prov

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        # NOTE: the kernels called do not follow the proper pattern yet
        return False

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

hits = Unicode('hits', help='The Data key for the output hit map') class-attribute instance-attribute

pixel_dist = Unicode(None, allow_none=True, help='The Data key containing the submap distribution') class-attribute instance-attribute

pixels = Unicode(defaults.pixels, help='Observation detdata key for pixel indices') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

sync_type = Unicode('alltoallv', help="Communication algorithm: 'allreduce' or 'alltoallv'") class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_sync_type(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("sync_type")
def _check_sync_type(self, proposal):
    check = proposal["value"]
    if check != "allreduce" and check != "alltoallv":
        raise traitlets.TraitError("Invalid communication algorithm")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    if self.pixel_dist is None:
        raise RuntimeError(
            "You must set the 'pixel_dist' trait before calling exec()"
        )

    if self.pixel_dist not in data:
        msg = "Data does not contain submap distribution '{}'".format(
            self.pixel_dist
        )
        raise RuntimeError(msg)

    dist = data[self.pixel_dist]
    log.verbose_rank(
        f"Building hit map with pixel_distribution {self.pixel_dist}",
        comm=data.comm.comm_world,
    )

    hits = None
    if self.hits in data:
        # We have an existing map from a previous call.  Verify
        # the distribution and nnz.
        if data[self.hits].distribution != dist:
            msg = "Existing hits '{}' has different data distribution".format(
                self.hits
            )
            log.error(msg)
            raise RuntimeError(msg)
        if data[self.hits].n_value != 1:
            msg = "Existing hits '{}' has {} nnz, not 1".format(
                self.hits, data[self.hits].n_value
            )
            log.error(msg)
            raise RuntimeError(msg)
        hits = data[self.hits]
    else:
        data[self.hits] = PixelData(dist, np.int64, n_value=1)
        hits = data[self.hits]

    for ob in data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(
            selection=detectors, flagmask=self.det_mask
        )
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        # The pixels and weights view for this observation
        pix = ob.view[self.view].detdata[self.pixels]
        if self.det_flags is not None:
            flgs = ob.view[self.view].detdata[self.det_flags]
        else:
            flgs = [None for x in pix]
        if self.shared_flags is not None:
            shared_flgs = ob.view[self.view].shared[self.shared_flags]
        else:
            shared_flgs = [None for x in pix]

        for det in dets:
            # Process every data view
            for pview, fview, shared_fview in zip(pix, flgs, shared_flgs):
                # Get local submap and pixels
                local_sm, local_pix = dist.global_pixel_to_submap(pview[det])

                # Samples with telescope pointing problems are already flagged in
                # the pointing operators by setting the pixel numbers to a negative
                # value.  Here we optionally apply detector flags to the local
                # pixel numbers to flag more samples.

                # Apply the flags if needed
                if self.det_flags is not None:
                    local_pix[(fview[det] & self.det_flag_mask) != 0] = -1
                if self.shared_flags is not None:
                    local_pix[(shared_fview & self.shared_flag_mask) != 0] = -1

                cov_accum_diag_hits(
                    dist.n_local_submap,
                    dist.n_pix_submap,
                    1,
                    local_sm.astype(np.int64),
                    local_pix.astype(np.int64),
                    hits.raw,
                    impl=implementation,
                    use_accel=use_accel,
                )
    return

_finalize(data, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@function_timer
def _finalize(self, data, **kwargs):
    if self.hits in data:
        if self.sync_type == "alltoallv":
            data[self.hits].sync_alltoallv()
        else:
            data[self.hits].sync_allreduce()
    return

_implementations()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_provides()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _provides(self):
    prov = {"global": [self.hits]}
    return prov

_requires()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _requires(self):
    req = {
        "global": [self.pixel_dist],
        "shared": list(),
        "detdata": [self.pixels],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_supports_accel()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _supports_accel(self):
    # NOTE: the kernels called do not follow the proper pattern yet
    return False

toast.ops.BuildInverseCovariance

Bases: Operator

Operator which builds a pixel-space diagonal inverse noise covariance.

Given the pointing matrix and noise model for each detector, accumulate the inverse noise covariance:

.. math:: N_{pp'}^{-1} = \left( P^T N_{tt'}^{-1} P \right)

The PixelData object containing this is returned by the finalize() method.

If any samples have compromised telescope pointing, those pixel indices should have already been set to a negative value by the operator that generated the pointing matrix. Individual detector flags can optionally be applied to timesamples when accumulating data. The detector mask defaults to cutting "non-science" samples.
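
The operator stores only the lower triangle of the symmetric per-pixel covariance block, so a pointing matrix with weight_nnz Stokes weights per sample yields weight_nnz * (weight_nnz + 1) / 2 stored values per pixel. A short sketch of the packing arithmetic used in the source below:

import numpy as np

weight_nnz = 3                                # e.g. I, Q, U Stokes weights
cov_nnz = weight_nnz * (weight_nnz + 1) // 2  # 6 packed triangle elements
# Inverse relation, used when reusing an existing inverse covariance object:
recovered = int((np.sqrt(1 + 8 * cov_nnz) - 1) // 2)
assert recovered == weight_nnz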

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@trait_docs
class BuildInverseCovariance(Operator):
    """Operator which builds a pixel-space diagonal inverse noise covariance.

    Given the pointing matrix and noise model for each detector, accumulate the inverse
    noise covariance:

    .. math::
        N_{pp'}^{-1} = \\left( P^T N_{tt'}^{-1} P \\right)

    The PixelData object containing this is returned by the finalize() method.

    If any samples have compromised telescope pointing, those pixel indices should
    have already been set to a negative value by the operator that generated the
    pointing matrix.  Individual detector flags can optionally be applied to
    timesamples when accumulating data.  The detector mask defaults to cutting
    "non-science" samples.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    pixel_dist = Unicode(
        None,
        allow_none=True,
        help="The Data key containing the submap distribution",
    )

    inverse_covariance = Unicode(
        "inv_covariance", help="The Data key for the output inverse covariance"
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    det_data_units = Unit(defaults.det_data_units, help="Desired timestream units")

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional flagging",
    )

    pixels = Unicode("pixels", help="Observation detdata key for pixel indices")

    weights = Unicode("weights", help="Observation detdata key for Stokes weights")

    noise_model = Unicode(
        "noise_model", help="Observation key containing the noise model"
    )

    sync_type = Unicode(
        "alltoallv", help="Communication algorithm: 'allreduce' or 'alltoallv'"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("sync_type")
    def _check_sync_type(self, proposal):
        check = proposal["value"]
        if check != "allreduce" and check != "alltoallv":
            raise traitlets.TraitError("Invalid communication algorithm")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        if self.pixel_dist is None:
            raise RuntimeError(
                "You must set the 'pixel_dist' trait before calling exec()"
            )

        if self.pixel_dist not in data:
            msg = "Data does not contain submap distribution '{}'".format(
                self.pixel_dist
            )
            raise RuntimeError(msg)

        dist = data[self.pixel_dist]
        log.verbose_rank(
            f"Building inverse covariance with pixel_distribution {self.pixel_dist}",
            comm=data.comm.comm_world,
        )

        invcov_units = 1.0 / (self.det_data_units**2)

        invcov = None
        weight_nnz = None
        cov_nnz = None

        # We will store the lower triangle of the covariance.  This operator requires
        # that all detectors in all observations have the same number of non-zeros
        # in the pointing matrix.

        if self.inverse_covariance in data:
            # We have an existing map from a previous call.  Verify
            # the distribution and units.
            if data[self.inverse_covariance].distribution != dist:
                msg = "Existing inv cov '{}' has different data distribution".format(
                    self.inverse_covariance
                )
                log.error(msg)
                raise RuntimeError(msg)
            if data[self.inverse_covariance].units != invcov_units:
                msg = "Existing inv cov '{}' has different units".format(
                    self.inverse_covariance
                )
                log.error(msg)
                raise RuntimeError(msg)
            invcov = data[self.inverse_covariance]
            cov_nnz = invcov.n_value
            weight_nnz = int((np.sqrt(1 + 8 * cov_nnz) - 1) // 2)
        else:
            # We are creating a new data object
            weight_nnz = 0
            for ob in data.obs:
                # Get the detectors we are using for this observation
                dets = ob.select_local_detectors(
                    selection=detectors, flagmask=self.det_mask
                )
                if len(dets) == 0:
                    # Nothing to do for this observation
                    continue
                if self.weights in ob.detdata:
                    if len(ob.detdata[self.weights].detector_shape) == 1:
                        cur_nnz = 1
                    else:
                        cur_nnz = ob.detdata[self.weights].detector_shape[1]
                    weight_nnz = max(weight_nnz, cur_nnz)
                else:
                    raise RuntimeError(
                        f"Stokes weights '{self.weights}' not in obs {ob.name}"
                    )
            if data.comm.comm_world is not None:
                weight_nnz = data.comm.comm_world.allreduce(weight_nnz, op=MPI.MAX)
            if weight_nnz == 0:
                msg = f"No valid detectors. Could not infer the pointing matrix "
                msg += f"dimensions from the data."
                raise RuntimeError(msg)
            cov_nnz = int(weight_nnz * (weight_nnz + 1) // 2)
            data[self.inverse_covariance] = PixelData(
                dist, np.float64, n_value=cov_nnz, units=invcov_units
            )
            invcov = data[self.inverse_covariance]

        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                selection=detectors, flagmask=self.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            # Check that the noise model exists
            if self.noise_model not in ob:
                msg = "Noise model {} does not exist in observation {}".format(
                    self.noise_model, ob.name
                )
                raise RuntimeError(msg)

            noise = ob[self.noise_model]

            # The pixels and weights view for this observation
            pix = ob.view[self.view].detdata[self.pixels]
            wts = ob.view[self.view].detdata[self.weights]
            if self.det_flags is not None:
                flgs = ob.view[self.view].detdata[self.det_flags]
            else:
                flgs = [None for x in wts]
            if self.shared_flags is not None:
                shared_flgs = ob.view[self.view].shared[self.shared_flags]
            else:
                shared_flgs = [None for x in wts]

            for det in dets:
                # Process every data view
                for pview, wview, fview, shared_fview in zip(pix, wts, flgs, shared_flgs):
                    # We require that the pointing matrix has the same number of
                    # non-zero elements for every detector and every observation.
                    # We check that here.

                    check_nnz = None
                    if len(wview.detector_shape) == 1:
                        check_nnz = 1
                    else:
                        check_nnz = wview.detector_shape[1]
                    if check_nnz != weight_nnz:
                        msg = "observation '{}', detector '{}', pointing weights '{}' has {} nnz, not {}".format(
                            ob.name, det, self.weights, check_nnz, weight_nnz
                        )
                        raise RuntimeError(msg)

                    # Get local submap and pixels
                    local_sm, local_pix = dist.global_pixel_to_submap(pview[det])

                    # Get the detector weight from the noise model.
                    detweight = noise.detector_weight(det)

                    # Samples with telescope pointing problems are already flagged in
                    # the pointing operators by setting the pixel numbers to a negative
                    # value.  Here we optionally apply detector flags to the local
                    # pixel numbers to flag more samples.

                    # Apply the flags if needed
                    if self.det_flags is not None:
                        local_pix[(fview[det] & self.det_flag_mask) != 0] = -1
                    if self.shared_flags is not None:
                        local_pix[(shared_fview & self.shared_flag_mask) != 0] = -1

                    # Accumulate
                    cov_accum_diag_invnpp(
                        dist.n_local_submap,
                        dist.n_pix_submap,
                        weight_nnz,
                        local_sm.astype(np.int64),
                        local_pix.astype(np.int64),
                        wview[det].reshape(-1),
                        detweight.to_value(invcov_units),
                        invcov.raw,
                        impl=implementation,
                        use_accel=use_accel,
                    )
        return

    @function_timer
    def _finalize(self, data, **kwargs):
        if self.inverse_covariance in data:
            if self.sync_type == "alltoallv":
                data[self.inverse_covariance].sync_alltoallv()
            else:
                data[self.inverse_covariance].sync_allreduce()
        return

    def _requires(self):
        req = {
            "global": [self.pixel_dist],
            "meta": [self.noise_model],
            "shared": list(),
            "detdata": [self.pixels, self.weights],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {"global": [self.inverse_covariance]}
        return prov

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        # NOTE: the kernels called do not follow the proper pattern yet
        return False

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Desired timestream units') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

inverse_covariance = Unicode('inv_covariance', help='The Data key for the output inverse covariance') class-attribute instance-attribute

noise_model = Unicode('noise_model', help='Observation key containing the noise model') class-attribute instance-attribute

pixel_dist = Unicode(None, allow_none=True, help='The Data key containing the submap distribution') class-attribute instance-attribute

pixels = Unicode('pixels', help='Observation detdata key for pixel indices') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

sync_type = Unicode('alltoallv', help="Communication algorithm: 'allreduce' or 'alltoallv'") class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

weights = Unicode('weights', help='Observation detdata key for Stokes weights') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_sync_type(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("sync_type")
def _check_sync_type(self, proposal):
    check = proposal["value"]
    if check != "allreduce" and check != "alltoallv":
        raise traitlets.TraitError("Invalid communication algorithm")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    if self.pixel_dist is None:
        raise RuntimeError(
            "You must set the 'pixel_dist' trait before calling exec()"
        )

    if self.pixel_dist not in data:
        msg = "Data does not contain submap distribution '{}'".format(
            self.pixel_dist
        )
        raise RuntimeError(msg)

    dist = data[self.pixel_dist]
    log.verbose_rank(
        f"Building inverse covariance with pixel_distribution {self.pixel_dist}",
        comm=data.comm.comm_world,
    )

    invcov_units = 1.0 / (self.det_data_units**2)

    invcov = None
    weight_nnz = None
    cov_nnz = None

    # We will store the lower triangle of the covariance.  This operator requires
    # that all detectors in all observations have the same number of non-zeros
    # in the pointing matrix.

    if self.inverse_covariance in data:
        # We have an existing map from a previous call.  Verify
        # the distribution and units.
        if data[self.inverse_covariance].distribution != dist:
            msg = "Existing inv cov '{}' has different data distribution".format(
                self.inverse_covariance
            )
            log.error(msg)
            raise RuntimeError(msg)
        if data[self.inverse_covariance].units != invcov_units:
            msg = "Existing inv cov '{}' has different units".format(
                self.inverse_covariance
            )
            log.error(msg)
            raise RuntimeError(msg)
        invcov = data[self.inverse_covariance]
        cov_nnz = invcov.n_value
        weight_nnz = int((np.sqrt(1 + 8 * cov_nnz) - 1) // 2)
    else:
        # We are creating a new data object
        weight_nnz = 0
        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                selection=detectors, flagmask=self.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue
            if self.weights in ob.detdata:
                if len(ob.detdata[self.weights].detector_shape) == 1:
                    cur_nnz = 1
                else:
                    cur_nnz = ob.detdata[self.weights].detector_shape[1]
                weight_nnz = max(weight_nnz, cur_nnz)
            else:
                raise RuntimeError(
                    f"Stokes weights '{self.weights}' not in obs {ob.name}"
                )
        if data.comm.comm_world is not None:
            weight_nnz = data.comm.comm_world.allreduce(weight_nnz, op=MPI.MAX)
        if weight_nnz == 0:
            msg = f"No valid detectors. Could not infer the pointing matrix "
            msg += f"dimensions from the data."
            raise RuntimeError(msg)
        cov_nnz = int(weight_nnz * (weight_nnz + 1) // 2)
        data[self.inverse_covariance] = PixelData(
            dist, np.float64, n_value=cov_nnz, units=invcov_units
        )
        invcov = data[self.inverse_covariance]

    for ob in data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(
            selection=detectors, flagmask=self.det_mask
        )
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        # Check that the noise model exists
        if self.noise_model not in ob:
            msg = "Noise model {} does not exist in observation {}".format(
                self.noise_model, ob.name
            )
            raise RuntimeError(msg)

        noise = ob[self.noise_model]

        # The pixels and weights view for this observation
        pix = ob.view[self.view].detdata[self.pixels]
        wts = ob.view[self.view].detdata[self.weights]
        if self.det_flags is not None:
            flgs = ob.view[self.view].detdata[self.det_flags]
        else:
            flgs = [None for x in wts]
        if self.shared_flags is not None:
            shared_flgs = ob.view[self.view].shared[self.shared_flags]
        else:
            shared_flgs = [None for x in wts]

        for det in dets:
            # Process every data view
            for pview, wview, fview, shared_fview in zip(pix, wts, flgs, shared_flgs):
                # We require that the pointing matrix has the same number of
                # non-zero elements for every detector and every observation.
                # We check that here.

                check_nnz = None
                if len(wview.detector_shape) == 1:
                    check_nnz = 1
                else:
                    check_nnz = wview.detector_shape[1]
                if check_nnz != weight_nnz:
                    msg = "observation '{}', detector '{}', pointing weights '{}' has {} nnz, not {}".format(
                        ob.name, det, self.weights, check_nnz, weight_nnz
                    )
                    raise RuntimeError(msg)

                # Get local submap and pixels
                local_sm, local_pix = dist.global_pixel_to_submap(pview[det])

                # Get the detector weight from the noise model.
                detweight = noise.detector_weight(det)

                # Samples with telescope pointing problems are already flagged in
                # the pointing operators by setting the pixel numbers to a negative
                # value.  Here we optionally apply detector flags to the local
                # pixel numbers to flag more samples.

                # Apply the flags if needed
                if self.det_flags is not None:
                    local_pix[(fview[det] & self.det_flag_mask) != 0] = -1
                if self.shared_flags is not None:
                    local_pix[(shared_fview & self.shared_flag_mask) != 0] = -1

                # Accumulate
                cov_accum_diag_invnpp(
                    dist.n_local_submap,
                    dist.n_pix_submap,
                    weight_nnz,
                    local_sm.astype(np.int64),
                    local_pix.astype(np.int64),
                    wview[det].reshape(-1),
                    detweight.to_value(invcov_units),
                    invcov.raw,
                    impl=implementation,
                    use_accel=use_accel,
                )
    return

_finalize(data, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@function_timer
def _finalize(self, data, **kwargs):
    if self.inverse_covariance in data:
        if self.sync_type == "alltoallv":
            data[self.inverse_covariance].sync_alltoallv()
        else:
            data[self.inverse_covariance].sync_allreduce()
    return

_implementations()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_provides()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _provides(self):
    prov = {"global": [self.inverse_covariance]}
    return prov

_requires()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _requires(self):
    req = {
        "global": [self.pixel_dist],
        "meta": [self.noise_model],
        "shared": list(),
        "detdata": [self.pixels, self.weights],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_supports_accel()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _supports_accel(self):
    # NOTE: the kernels called do not follow the proper pattern yet
    return False

toast.ops.BuildNoiseWeighted

Bases: Operator

Operator which builds a noise-weighted map.

Given the pointing matrix and noise model for each detector, accumulate the noise weighted map:

.. math:: Z_p = P^T N_{tt'}^{-1} d

This is the timestream data weighted by the diagonal time-domain noise covariance and projected into pixel space. The PixelData object containing this is returned by the finalize() method.

If any samples have compromised telescope pointing, those pixel indices should have already been set to a negative value by the operator that generated the pointing matrix. Individual detector flags can optionally be applied to timesamples when accumulating data. The detector mask defaults to cutting "non-science" samples.
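
A minimal usage sketch (assuming `data` already contains pointing products, calibrated timestreams, and a noise model under the keys named below; the key names are illustrative):

import toast.ops

build_zmap = toast.ops.BuildNoiseWeighted(
    pixel_dist="pixel_dist",    # Data key holding the submap distribution
    zmap="zmap",                # Data key for the noise-weighted output map
    noise_model="noise_model",  # Observation key with detector noise weights
)
build_zmap.apply(data)
zmap = data["zmap"]             # PixelData holding Z_p = P^T N_{tt'}^{-1} d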

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@trait_docs
class BuildNoiseWeighted(Operator):
    """Operator which builds a noise-weighted map.

    Given the pointing matrix and noise model for each detector, accumulate the noise
    weighted map:

    .. math::
        Z_p = P^T N_{tt'}^{-1} d

    This is the timestream data weighted by the diagonal time domain noise covariance
    and projected into pixel space.  The PixelData object containing this is returned
    by the finalize() method.

    If any samples have compromised telescope pointing, those pixel indices should
    have already been set to a negative value by the operator that generated the
    pointing matrix.  Individual detector flags can optionally be applied to
    timesamples when accumulating data.  The detector mask defaults to cutting
    "non-science" samples.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    pixel_dist = Unicode(
        None,
        allow_none=True,
        help="The Data key containing the submap distribution",
    )

    zmap = Unicode("zmap", help="The Data key for the output noise weighted map")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_data = Unicode(
        defaults.det_data,
        allow_none=True,
        help="Observation detdata key for the timestream data",
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    det_data_units = Unit(defaults.det_data_units, help="Desired timestream units")

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional flagging",
    )

    pixels = Unicode("pixels", help="Observation detdata key for pixel indices")

    weights = Unicode("weights", help="Observation detdata key for Stokes weights")

    noise_model = Unicode(
        "noise_model", help="Observation key containing the noise model"
    )

    sync_type = Unicode(
        "alltoallv", help="Communication algorithm: 'allreduce' or 'alltoallv'"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("sync_type")
    def _check_sync_type(self, proposal):
        check = proposal["value"]
        if check != "allreduce" and check != "alltoallv":
            raise traitlets.TraitError("Invalid communication algorithm")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        if self.pixel_dist is None:
            raise RuntimeError(
                "You must set the 'pixel_dist' trait before calling exec()"
            )

        if self.pixel_dist not in data:
            msg = "Data does not contain submap distribution '{}'".format(
                self.pixel_dist
            )
            raise RuntimeError(msg)

        # Check that the detector data is set
        if self.det_data is None:
            raise RuntimeError("You must set the det_data trait before calling exec()")

        dist = data[self.pixel_dist]
        if data.comm.world_rank == 0:
            log.verbose(
                "Building noise weighted map with pixel_distribution {}".format(
                    self.pixel_dist
                )
            )

        detwt_units = 1.0 / (self.det_data_units**2)
        zmap_units = 1.0 / self.det_data_units

        zmap = None
        weight_nnz = None

        # This operator requires that all detectors in all observations have the same
        # number of non-zeros in the pointing matrix.

        if self.zmap in data:
            # We have an existing map from a previous call.  Verify
            # the distribution and units
            if data[self.zmap].distribution != dist:
                msg = "Existing zmap '{}' has different data distribution".format(
                    self.zmap
                )
                log.error(msg)
                raise RuntimeError(msg)
            if data[self.zmap].units != zmap_units:
                msg = f"Existing zmap '{self.zmap}' has different units"
                msg += f" ({data[self.zmap].units}) != {zmap_units}"
                log.error(msg)
                raise RuntimeError(msg)
            zmap = data[self.zmap]
            weight_nnz = zmap.n_value
        else:
            weight_nnz = 0
            for ob in data.obs:
                # Get the detectors we are using for this observation
                dets = ob.select_local_detectors(
                    selection=detectors, flagmask=self.det_mask
                )
                if len(dets) == 0:
                    # Nothing to do for this observation
                    continue
                if self.weights in ob.detdata:
                    if len(ob.detdata[self.weights].detector_shape) == 1:
                        weight_nnz = 1
                    else:
                        weight_nnz = ob.detdata[self.weights].detector_shape[1]
                else:
                    raise RuntimeError(
                        f"Stokes weights '{self.weights}' not in obs {ob.name}"
                    )
            if data.comm.comm_world is not None:
                weight_nnz = data.comm.comm_world.allreduce(weight_nnz, op=MPI.MAX)
            data[self.zmap] = PixelData(
                dist, np.float64, n_value=weight_nnz, units=zmap_units
            )
            zmap = data[self.zmap]

        if use_accel:
            if not zmap.accel_exists():
                # Does not yet exist, create it
                log.verbose_rank(
                    f"Operator {self.name} zmap not yet on device, creating",
                    comm=data.comm.comm_group,
                )
                zmap.accel_create(f"{self.name}", zero_out=True)
                zmap.accel_used(True)
            elif not zmap.accel_in_use():
                # Device copy not currently in use
                log.verbose_rank(
                    f"Operator {self.name} zmap:  copy host to device",
                    comm=data.comm.comm_group,
                )
                zmap.accel_update_device()
            else:
                log.verbose_rank(
                    f"Operator {self.name} zmap:  already in use on device",
                    comm=data.comm.comm_group,
                )
        else:
            if zmap.accel_in_use():
                # Device copy in use, but we are running on host.  Update host
                log.verbose_rank(
                    f"Operator {self.name} zmap:  update host from device",
                    comm=data.comm.comm_group,
                )
                zmap.accel_update_host()

        # # DEBUGGING
        # restore_dev = False
        # prefix="HOST"
        # if zmap.accel_in_use():
        #     zmap.accel_update_host()
        #     restore_dev = True
        #     prefix="DEVICE"
        # zmap_min = np.amin(zmap.data)
        # zmap_max = np.amax(zmap.data)
        # print(f"{prefix} {self.name} dets {detectors} starting zmap output:  min={zmap_min}, max={zmap_max}", flush=True)
        # for ism, sm in enumerate(zmap.data):
        #     for ismpix, smpix in enumerate(sm):
        #         if np.count_nonzero(smpix) > 0:
        #             print(f"{prefix} {self.name} ({ism}, {ismpix}) = {smpix}", flush=True)
        # if restore_dev:
        #     zmap.accel_update_device()

        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                selection=detectors, flagmask=self.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue

            # Check that the noise model exists
            if self.noise_model not in ob:
                msg = "Noise model {} does not exist in observation {}".format(
                    self.noise_model, ob.name
                )
                raise RuntimeError(msg)

            noise = ob[self.noise_model]

            # Scale factor to get timestream data into desired units.
            data_scale = unit_conversion(
                ob.detdata[self.det_data].units, self.det_data_units
            )

            # Detector inverse variance weights
            detweights = np.array(
                [noise.detector_weight(x).to_value(detwt_units) for x in dets],
                dtype=np.float64,
            )

            # Pre-multiply the detector inverse variance weights by the
            # data scaling factor, so that this combination is applied
            # in the compiled kernel below.
            detweights *= data_scale

            pix_indx = ob.detdata[self.pixels].indices(dets)
            weight_indx = ob.detdata[self.weights].indices(dets)
            data_indx = ob.detdata[self.det_data].indices(dets)

            n_weight_dets = ob.detdata[self.weights].data.shape[0]

            if self.det_flags is not None:
                flag_indx = ob.detdata[self.det_flags].indices(dets)
                flag_data = ob.detdata[self.det_flags].data
            else:
                flag_indx = np.array([-1], dtype=np.int32)
                flag_data = np.zeros((1, 1), dtype=np.uint8)

            if self.shared_flags is not None:
                shared_flag_data = ob.shared[self.shared_flags].data
            else:
                shared_flag_data = np.zeros(1, dtype=np.uint8)

            build_noise_weighted(
                zmap.distribution.global_submap_to_local,
                zmap.data,
                pix_indx,
                ob.detdata[self.pixels].data,
                weight_indx,
                ob.detdata[self.weights].data,
                data_indx,
                ob.detdata[self.det_data].data,
                flag_indx,
                flag_data,
                detweights,
                self.det_flag_mask,
                ob.intervals[self.view].data,
                shared_flag_data,
                self.shared_flag_mask,
                impl=implementation,
                use_accel=use_accel,
            )

        # # DEBUGGING
        # restore_dev = False
        # prefix="HOST"
        # if zmap.accel_in_use():
        #     zmap.accel_update_host()
        #     restore_dev = True
        #     prefix="DEVICE"
        # zmap_min = np.amin(zmap.data)
        # zmap_max = np.amax(zmap.data)
        # print(f"{prefix} {self.name} dets {detectors} ending zmap output:  min={zmap_min}, max={zmap_max}", flush=True)
        # for ism, sm in enumerate(zmap.data):
        #     for ismpix, smpix in enumerate(sm):
        #         if np.count_nonzero(smpix) > 0:
        #             print(f"{prefix} {self.name} ({ism}, {ismpix}) = {smpix}", flush=True)
        # if restore_dev:
        #     zmap.accel_update_device()

        return

    @function_timer
    def _finalize(self, data, use_accel=None, **kwargs):
        if self.zmap in data:
            log = Logger.get()
            # We have called exec() at least once
            restore_device = False
            if data[self.zmap].accel_in_use():
                log.verbose_rank(
                    f"Operator {self.name} finalize calling zmap update self",
                    comm=data.comm.comm_group,
                )
                restore_device = True
                data[self.zmap].accel_update_host()
            if self.sync_type == "alltoallv":
                data[self.zmap].sync_alltoallv()
            else:
                data[self.zmap].sync_allreduce()

            zmap_good = data[self.zmap].data[:, :, 0] != 0.0
            zmap_min = np.zeros((data[self.zmap].n_value), dtype=np.float64)
            zmap_max = np.zeros((data[self.zmap].n_value), dtype=np.float64)
            if np.count_nonzero(zmap_good) > 0:
                zmap_min[:] = np.amin(data[self.zmap].data[zmap_good, :], axis=0)
                zmap_max[:] = np.amax(data[self.zmap].data[zmap_good, :], axis=0)
            all_zmap_min = np.zeros_like(zmap_min)
            all_zmap_max = np.zeros_like(zmap_max)
            if data.comm.comm_world is not None:
                data.comm.comm_world.Reduce(zmap_min, all_zmap_min, op=MPI.MIN, root=0)
                data.comm.comm_world.Reduce(zmap_max, all_zmap_max, op=MPI.MAX, root=0)
            else:
                # Without a communicator, the local extrema are already global
                all_zmap_min[:] = zmap_min
                all_zmap_max[:] = zmap_max
            if data.comm.world_rank == 0:
                msg = "  Noise-weighted map pixel value range:\n"
                for m in range(data[self.zmap].n_value):
                    msg += f"    map {m} {all_zmap_min[m]:1.3e} ... {all_zmap_max[m]:1.3e}\n"
                log.debug(msg)

            if restore_device:
                log.verbose_rank(
                    f"Operator {self.name} finalize calling zmap update device",
                    comm=data.comm.comm_group,
                )
                data[self.zmap].accel_update_device()
        return

    def _requires(self):
        req = {
            "global": [self.pixel_dist],
            "meta": [self.noise_model],
            "shared": list(),
            "detdata": [self.pixels, self.weights, self.det_data],
            "intervals": list(),
        }
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        prov = {
            "global": [self.zmap],
        }
        return prov

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        return True

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, allow_none=True, help='Observation detdata key for the timestream data') class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Desired timestream units') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

noise_model = Unicode('noise_model', help='Observation key containing the noise model') class-attribute instance-attribute

pixel_dist = Unicode(None, allow_none=True, help='The Data key containing the submap distribution') class-attribute instance-attribute

pixels = Unicode('pixels', help='Observation detdata key for pixel indices') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

sync_type = Unicode('alltoallv', help="Communication algorithm: 'allreduce' or 'alltoallv'") class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

weights = Unicode('weights', help='Observation detdata key for Stokes weights') class-attribute instance-attribute

zmap = Unicode('zmap', help='The Data key for the output noise weighted map') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_sync_type(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("sync_type")
def _check_sync_type(self, proposal):
    check = proposal["value"]
    if check != "allreduce" and check != "alltoallv":
        raise traitlets.TraitError("Invalid communication algorithm")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    if self.pixel_dist is None:
        raise RuntimeError(
            "You must set the 'pixel_dist' trait before calling exec()"
        )

    if self.pixel_dist not in data:
        msg = "Data does not contain submap distribution '{}'".format(
            self.pixel_dist
        )
        raise RuntimeError(msg)

    # Check that the detector data is set
    if self.det_data is None:
        raise RuntimeError("You must set the det_data trait before calling exec()")

    dist = data[self.pixel_dist]
    if data.comm.world_rank == 0:
        log.verbose(
            "Building noise weighted map with pixel_distribution {}".format(
                self.pixel_dist
            )
        )

    detwt_units = 1.0 / (self.det_data_units**2)
    zmap_units = 1.0 / self.det_data_units

    zmap = None
    weight_nnz = None

    # This operator requires that all detectors in all observations have the same
    # number of non-zeros in the pointing matrix.

    if self.zmap in data:
        # We have an existing map from a previous call.  Verify
        # the distribution and units
        if data[self.zmap].distribution != dist:
            msg = "Existing zmap '{}' has different data distribution".format(
                self.zmap
            )
            log.error(msg)
            raise RuntimeError(msg)
        if data[self.zmap].units != zmap_units:
            msg = f"Existing zmap '{self.zmap}' has different units"
            msg += f" ({data[self.zmap].units}) != {zmap_units}"
            log.error(msg)
            raise RuntimeError(msg)
        zmap = data[self.zmap]
        weight_nnz = zmap.n_value
    else:
        weight_nnz = 0
        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                selection=detectors, flagmask=self.det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue
            if self.weights in ob.detdata:
                if len(ob.detdata[self.weights].detector_shape) == 1:
                    weight_nnz = 1
                else:
                    weight_nnz = ob.detdata[self.weights].detector_shape[1]
            else:
                raise RuntimeError(
                    f"Stokes weights '{self.weights}' not in obs {ob.name}"
                )
        if data.comm.comm_world is not None:
            weight_nnz = data.comm.comm_world.allreduce(weight_nnz, op=MPI.MAX)
        data[self.zmap] = PixelData(
            dist, np.float64, n_value=weight_nnz, units=zmap_units
        )
        zmap = data[self.zmap]

    if use_accel:
        if not zmap.accel_exists():
            # Does not yet exist, create it
            log.verbose_rank(
                f"Operator {self.name} zmap not yet on device, creating",
                comm=data.comm.comm_group,
            )
            zmap.accel_create(f"{self.name}", zero_out=True)
            zmap.accel_used(True)
        elif not zmap.accel_in_use():
            # Device copy not currently in use
            log.verbose_rank(
                f"Operator {self.name} zmap:  copy host to device",
                comm=data.comm.comm_group,
            )
            zmap.accel_update_device()
        else:
            log.verbose_rank(
                f"Operator {self.name} zmap:  already in use on device",
                comm=data.comm.comm_group,
            )
    else:
        if zmap.accel_in_use():
            # Device copy in use, but we are running on host.  Update host
            log.verbose_rank(
                f"Operator {self.name} zmap:  update host from device",
                comm=data.comm.comm_group,
            )
            zmap.accel_update_host()

    for ob in data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(
            selection=detectors, flagmask=self.det_mask
        )
        if len(dets) == 0:
            # Nothing to do for this observation
            continue

        # Check that the noise model exists
        if self.noise_model not in ob:
            msg = "Noise model {} does not exist in observation {}".format(
                self.noise_model, ob.name
            )
            raise RuntimeError(msg)

        noise = ob[self.noise_model]

        # Scale factor to get timestream data into desired units.
        data_scale = unit_conversion(
            ob.detdata[self.det_data].units, self.det_data_units
        )

        # Detector inverse variance weights
        detweights = np.array(
            [noise.detector_weight(x).to_value(detwt_units) for x in dets],
            dtype=np.float64,
        )

        # Pre-multiply the detector inverse variance weights by the
        # data scaling factor, so that this combination is applied
        # in the compiled kernel below.
        detweights *= data_scale

        pix_indx = ob.detdata[self.pixels].indices(dets)
        weight_indx = ob.detdata[self.weights].indices(dets)
        data_indx = ob.detdata[self.det_data].indices(dets)

        n_weight_dets = ob.detdata[self.weights].data.shape[0]

        if self.det_flags is not None:
            flag_indx = ob.detdata[self.det_flags].indices(dets)
            flag_data = ob.detdata[self.det_flags].data
        else:
            flag_indx = np.array([-1], dtype=np.int32)
            flag_data = np.zeros((1, 1), dtype=np.uint8)

        if self.shared_flags is not None:
            shared_flag_data = ob.shared[self.shared_flags].data
        else:
            shared_flag_data = np.zeros(1, dtype=np.uint8)

        build_noise_weighted(
            zmap.distribution.global_submap_to_local,
            zmap.data,
            pix_indx,
            ob.detdata[self.pixels].data,
            weight_indx,
            ob.detdata[self.weights].data,
            data_indx,
            ob.detdata[self.det_data].data,
            flag_indx,
            flag_data,
            detweights,
            self.det_flag_mask,
            ob.intervals[self.view].data,
            shared_flag_data,
            self.shared_flag_mask,
            impl=implementation,
            use_accel=use_accel,
        )

    return

_finalize(data, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@function_timer
def _finalize(self, data, use_accel=None, **kwargs):
    if self.zmap in data:
        log = Logger.get()
        # We have called exec() at least once
        restore_device = False
        if data[self.zmap].accel_in_use():
            log.verbose_rank(
                f"Operator {self.name} finalize calling zmap update self",
                comm=data.comm.comm_group,
            )
            restore_device = True
            data[self.zmap].accel_update_host()
        if self.sync_type == "alltoallv":
            data[self.zmap].sync_alltoallv()
        else:
            data[self.zmap].sync_allreduce()

        zmap_good = data[self.zmap].data[:, :, 0] != 0.0
        zmap_min = np.zeros((data[self.zmap].n_value), dtype=np.float64)
        zmap_max = np.zeros((data[self.zmap].n_value), dtype=np.float64)
        if np.count_nonzero(zmap_good) > 0:
            zmap_min[:] = np.amin(data[self.zmap].data[zmap_good, :], axis=0)
            zmap_max[:] = np.amax(data[self.zmap].data[zmap_good, :], axis=0)
        # Reduce to rank zero and log the global extrema.  Start from the
        # local values so the serial (no MPI) case is also correct.
        all_zmap_min = zmap_min.copy()
        all_zmap_max = zmap_max.copy()
        if data.comm.comm_world is not None:
            data.comm.comm_world.Reduce(zmap_min, all_zmap_min, op=MPI.MIN, root=0)
            data.comm.comm_world.Reduce(zmap_max, all_zmap_max, op=MPI.MAX, root=0)
        if data.comm.world_rank == 0:
            msg = "  Noise-weighted map pixel value range:\n"
            for m in range(data[self.zmap].n_value):
                msg += f"    map {m} {all_zmap_min[m]:1.3e} ... {all_zmap_max[m]:1.3e}\n"
            log.debug(msg)

        if restore_device:
            log.verbose_rank(
                f"Operator {self.name} finalize calling zmap update device",
                comm=data.comm.comm_group,
            )
            data[self.zmap].accel_update_device()
    return

_implementations()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_provides()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _provides(self):
    prov = {
        "global": [self.zmap],
    }
    return prov

_requires()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _requires(self):
    req = {
        "global": [self.pixel_dist],
        "meta": [self.noise_model],
        "shared": list(),
        "detdata": [self.pixels, self.weights, self.det_data],
        "intervals": list(),
    }
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_supports_accel()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _supports_accel(self):
    return True

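Worked sketch: the operator documented above accumulates the noise-weighted map z = P^T N^-1 d, where P is the pointing matrix and N^-1 holds the diagonal detector noise weights (pre-multiplied by the unit conversion factor, as in _exec() above). The example below is an illustration rather than part of the source: it assumes the operator is exposed as toast.ops.BuildNoiseWeighted (its traits match the listing above) and that the "pixel_dist", "pixels", and "weights" products already exist in data.

import toast.ops

# Hypothetical key names; the pixel distribution and the expanded
# pointing products ("pixels", "weights" detdata) must already exist.
build_zmap = toast.ops.BuildNoiseWeighted(
    pixel_dist="pixel_dist",
    zmap="zmap",
    det_data="signal",
    noise_model="noise_model",
    sync_type="alltoallv",
)
build_zmap.apply(data)  # runs _exec() over all observations, then _finalize()
zmap = data["zmap"]     # PixelData holding z = P^T N^-1 d

Repeated apply() calls accumulate into the same zmap as long as the distribution and units match, which is exactly what the consistency checks in _exec() enforce.
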
toast.ops.CovarianceAndHits

Bases: Operator

Operator which builds the pixel-space diagonal noise covariance and hit map.

Frequently the first step in map making is to determine what pixels on the sky have been covered and build the diagonal noise covariance. During the construction of the covariance we can cut pixels that are poorly conditioned.

This operator runs the pointing operator and builds the PixelDist instance describing how submaps are distributed among processes. It builds the hit map and the inverse covariance and then inverts this with a threshold on the condition number in each pixel. The detector flag mask defaults to cutting "non-science" samples.

NOTE: The pixel pointing operator must have the "pixels" and "create_dist" traits, which will be set by this operator during execution.

Output PixelData objects are stored in the Data dictionary.

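A minimal usage sketch (the names here are illustrative, not taken from the listing below): construct the two pointing operators, then run the convenience operator to obtain the hit map, the inverted covariance, and the reciprocal condition number map in one shot.

import toast.ops

# Hypothetical setup; any operators providing the required
# "pixels"/"create_dist" and "weights"/"view" traits would work.
det_pointing = toast.ops.PointingDetectorSimple()
pixels = toast.ops.PixelsHealpix(nside=64, detector_pointing=det_pointing)
weights = toast.ops.StokesWeights(mode="IQU", detector_pointing=det_pointing)

cov_and_hits = toast.ops.CovarianceAndHits(
    pixel_dist="pixel_dist",
    pixel_pointing=pixels,
    stokes_weights=weights,
    noise_model="noise_model",
    rcond_threshold=1.0e-8,
)
cov_and_hits.apply(data)

hits = data["hits"]        # hit counts per pixel
cov = data["covariance"]   # (P^T N^-1 P)^-1, with poorly conditioned pixels cut
rcond = data["rcond"]      # reciprocal condition number per pixel
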
Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@trait_docs
class CovarianceAndHits(Operator):
    """Operator which builds the pixel-space diagonal noise covariance and hit map.

    Frequently the first step in map making is to determine what pixels on the sky
    have been covered and build the diagonal noise covariance.  During the construction
    of the covariance we can cut pixels that are poorly conditioned.

    This operator runs the pointing operator and builds the PixelDist instance
    describing how submaps are distributed among processes.  It builds the hit map
    and the inverse covariance and then inverts this with a threshold on the condition
    number in each pixel.  The detector flag mask defaults to cutting "non-science"
    samples.

    NOTE:  The pixel pointing operator must have the "pixels" and "create_dist"
    traits, which will be set by this operator during execution.

    Output PixelData objects are stored in the Data dictionary.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    pixel_dist = Unicode(
        "pixel_dist",
        help="The Data key where the PixelDist object should be stored",
    )

    covariance = Unicode(
        "covariance",
        help="The Data key where the covariance should be stored",
    )

    inverse_covariance = Unicode(
        None,
        allow_none=True,
        help="The Data key where the inverse covariance should be stored",
    )

    hits = Unicode(
        "hits",
        help="The Data key where the hit map should be stored",
    )

    rcond = Unicode(
        "rcond",
        help="The Data key where the reciprocal condition number should be stored",
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    det_data_units = Unit(defaults.det_data_units, help="Desired timestream units")

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional flagging",
    )

    pixel_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a pointing operator",
    )

    stokes_weights = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a Stokes weights operator",
    )

    noise_model = Unicode(
        "noise_model", help="Observation key containing the noise model"
    )

    rcond_threshold = Float(
        1.0e-8, help="Minimum value for inverse condition number cut."
    )

    sync_type = Unicode(
        "alltoallv", help="Communication algorithm: 'allreduce' or 'alltoallv'"
    )

    save_pointing = Bool(
        False, help="If True, do not clear detector pointing matrices after use"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("sync_type")
    def _check_sync_type(self, proposal):
        check = proposal["value"]
        if check != "allreduce" and check != "alltoallv":
            raise traitlets.TraitError("Invalid communication algorithm")
        return check

    @traitlets.validate("pixel_pointing")
    def _check_pixel_pointing(self, proposal):
        pixels = proposal["value"]
        if pixels is not None:
            if not isinstance(pixels, Operator):
                raise traitlets.TraitError(
                    "pixel_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["pixels", "create_dist", "view"]:
                if not pixels.has_trait(trt):
                    msg = f"pixel_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return pixels

    @traitlets.validate("stokes_weights")
    def _check_stokes_weights(self, proposal):
        weights = proposal["value"]
        if weights is not None:
            if not isinstance(weights, Operator):
                raise traitlets.TraitError(
                    "stokes_weights should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["weights", "view"]:
                if not weights.has_trait(trt):
                    msg = f"stokes_weights operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return weights

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        for trait in "pixel_pointing", "stokes_weights":
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        # Set pointing flags
        self.pixel_pointing.detector_pointing.det_mask = self.det_mask
        self.pixel_pointing.detector_pointing.det_flag_mask = self.det_flag_mask
        if hasattr(self.stokes_weights, "detector_pointing"):
            self.stokes_weights.detector_pointing.det_mask = self.det_mask
            self.stokes_weights.detector_pointing.det_flag_mask = self.det_flag_mask

        # Construct the pointing distribution if it does not already exist

        if self.pixel_dist not in data:
            pix_dist = BuildPixelDistribution(
                pixel_dist=self.pixel_dist,
                pixel_pointing=self.pixel_pointing,
                save_pointing=self.save_pointing,
            )
            pix_dist.apply(data)

        # Check if map domain products exist and are consistent.  The hits
        # and inverse covariance accumulation operators support multiple
        # calls to exec() to accumulate data.  But in this convenience
        # function we are explicitly accumulating in one-shot.  This means
        # that any existing data products must be set to zero.

        if self.hits in data:
            if data[self.hits].distribution == data[self.pixel_dist]:
                # Distributions are equal, just set to zero
                data[self.hits].reset()
            else:
                # Inconsistent; delete it so that it will be re-created.
                del data[self.hits]
        if self.covariance in data:
            if data[self.covariance].distribution == data[self.pixel_dist]:
                # Distribution matches, set to zero and update units
                data[self.covariance].reset()
                invcov_units = 1.0 / (self.det_data_units**2)
                data[self.covariance].update_units(invcov_units)
            else:
                del data[self.covariance]

        # Hit map operator

        build_hits = BuildHitMap(
            pixel_dist=self.pixel_dist,
            hits=self.hits,
            view=self.pixel_pointing.view,
            pixels=self.pixel_pointing.pixels,
            det_mask=self.det_mask,
            det_flags=self.det_flags,
            det_flag_mask=self.det_flag_mask,
            shared_flags=self.shared_flags,
            shared_flag_mask=self.shared_flag_mask,
            sync_type=self.sync_type,
        )

        # Inverse covariance.  Note that we save the output to our specified
        # "covariance" key, because we are going to invert it in-place.

        build_invcov = BuildInverseCovariance(
            pixel_dist=self.pixel_dist,
            inverse_covariance=self.covariance,
            view=self.pixel_pointing.view,
            pixels=self.pixel_pointing.pixels,
            weights=self.stokes_weights.weights,
            noise_model=self.noise_model,
            det_data_units=self.det_data_units,
            det_mask=self.det_mask,
            det_flags=self.det_flags,
            det_flag_mask=self.det_flag_mask,
            shared_flags=self.shared_flags,
            shared_flag_mask=self.shared_flag_mask,
            sync_type=self.sync_type,
        )

        # Build a pipeline to expand pointing and accumulate

        accum = None
        if self.save_pointing:
            # Process all detectors at once
            accum = Pipeline(detector_sets=["ALL"])
        else:
            # Process one detector at a time.
            accum = Pipeline(detector_sets=["SINGLE"])
        accum.operators = [
            self.pixel_pointing,
            self.stokes_weights,
            build_hits,
            build_invcov,
        ]

        pipe_out = accum.apply(data, detectors=detectors)

        # Optionally, store the inverse covariance
        if self.inverse_covariance is not None:
            if self.inverse_covariance in data:
                del data[self.inverse_covariance]
            data[self.inverse_covariance] = data[self.covariance].duplicate()

        # Extract the results
        hits = data[self.hits]
        cov = data[self.covariance]

        # Invert the covariance in place
        rcond = PixelData(cov.distribution, np.float64, n_value=1)
        covariance_invert(
            cov,
            self.rcond_threshold,
            rcond=rcond,
            use_alltoallv=(self.sync_type == "alltoallv"),
        )

        rcond_good = rcond.data[:, :, 0] > 0.0
        rcond_min = 0.0
        rcond_max = 0.0
        if np.count_nonzero(rcond_good) > 0:
            rcond_min = np.amin(rcond.data[rcond_good, 0])
            rcond_max = np.amax(rcond.data[rcond_good, 0])
        if data.comm.comm_world is not None:
            rcond_min = data.comm.comm_world.reduce(rcond_min, root=0, op=MPI.MIN)
            rcond_max = data.comm.comm_world.reduce(rcond_max, root=0, op=MPI.MAX)
        if data.comm.world_rank == 0:
            msg = f"  Pixel covariance condition number range = "
            msg += f"{rcond_min:1.3e} ... {rcond_max:1.3e}"
            log.debug(msg)

        # Store rcond
        data[self.rcond] = rcond

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = self.pixel_pointing.requires()
        req.update(self.stokes_weights.requires())
        req["meta"].append(self.noise_model)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        return req

    def _provides(self):
        prov = {
            "global": [self.pixel_dist, self.hits, self.covariance, self.rcond],
            "shared": list(),
            "detdata": list(),
        }
        if self.save_pointing:
            prov["detdata"].extend([self.pixels, self.weights])
        if self.inverse_covariance is not None:
            prov["global"].append(self.inverse_covariance)
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

covariance = Unicode('covariance', help='The Data key where the covariance should be stored') class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Desired timestream units') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

hits = Unicode('hits', help='The Data key where the hit map should be stored') class-attribute instance-attribute

inverse_covariance = Unicode(None, allow_none=True, help='The Data key where the inverse covariance should be stored') class-attribute instance-attribute

noise_model = Unicode('noise_model', help='Observation key containing the noise model') class-attribute instance-attribute

pixel_dist = Unicode('pixel_dist', help='The Data key where the PixelDist object should be stored') class-attribute instance-attribute

pixel_pointing = Instance(klass=Operator, allow_none=True, help='This must be an instance of a pointing operator') class-attribute instance-attribute

rcond = Unicode('rcond', help='The Data key where the reciprocal condition number should be stored') class-attribute instance-attribute

rcond_threshold = Float(1e-08, help='Minimum value for inverse condition number cut.') class-attribute instance-attribute

save_pointing = Bool(False, help='If True, do not clear detector pointing matrices after use') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

stokes_weights = Instance(klass=Operator, allow_none=True, help='This must be an instance of a Stokes weights operator') class-attribute instance-attribute

sync_type = Unicode('alltoallv', help="Communication algorithm: 'allreduce' or 'alltoallv'") class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_pixel_pointing(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("pixel_pointing")
def _check_pixel_pointing(self, proposal):
    pixels = proposal["value"]
    if pixels is not None:
        if not isinstance(pixels, Operator):
            raise traitlets.TraitError(
                "pixel_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["pixels", "create_dist", "view"]:
            if not pixels.has_trait(trt):
                msg = f"pixel_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return pixels

_check_shared_mask(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_check_stokes_weights(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("stokes_weights")
def _check_stokes_weights(self, proposal):
    weights = proposal["value"]
    if weights is not None:
        if not isinstance(weights, Operator):
            raise traitlets.TraitError(
                "stokes_weights should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["weights", "view"]:
            if not weights.has_trait(trt):
                msg = f"stokes_weights operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return weights

_check_sync_type(proposal)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@traitlets.validate("sync_type")
def _check_sync_type(self, proposal):
    check = proposal["value"]
    if check != "allreduce" and check != "alltoallv":
        raise traitlets.TraitError("Invalid communication algorithm")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    for trait in "pixel_pointing", "stokes_weights":
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    # Set pointing flags
    self.pixel_pointing.detector_pointing.det_mask = self.det_mask
    self.pixel_pointing.detector_pointing.det_flag_mask = self.det_flag_mask
    if hasattr(self.stokes_weights, "detector_pointing"):
        self.stokes_weights.detector_pointing.det_mask = self.det_mask
        self.stokes_weights.detector_pointing.det_flag_mask = self.det_flag_mask

    # Construct the pointing distribution if it does not already exist

    if self.pixel_dist not in data:
        pix_dist = BuildPixelDistribution(
            pixel_dist=self.pixel_dist,
            pixel_pointing=self.pixel_pointing,
            save_pointing=self.save_pointing,
        )
        pix_dist.apply(data)

    # Check if map domain products exist and are consistent.  The hits
    # and inverse covariance accumulation operators support multiple
    # calls to exec() to accumulate data.  But in this convenience
    # function we are explicitly accumulating in one-shot.  This means
    # that any existing data products must be set to zero.

    if self.hits in data:
        if data[self.hits].distribution == data[self.pixel_dist]:
            # Distributions are equal, just set to zero
            data[self.hits].reset()
        else:
            # Inconsistent; delete it so that it will be re-created.
            del data[self.hits]
    if self.covariance in data:
        if data[self.covariance].distribution == data[self.pixel_dist]:
            # Distribution matches, set to zero and update units
            data[self.covariance].reset()
            invcov_units = 1.0 / (self.det_data_units**2)
            data[self.covariance].update_units(invcov_units)
        else:
            del data[self.covariance]

    # Hit map operator

    build_hits = BuildHitMap(
        pixel_dist=self.pixel_dist,
        hits=self.hits,
        view=self.pixel_pointing.view,
        pixels=self.pixel_pointing.pixels,
        det_mask=self.det_mask,
        det_flags=self.det_flags,
        det_flag_mask=self.det_flag_mask,
        shared_flags=self.shared_flags,
        shared_flag_mask=self.shared_flag_mask,
        sync_type=self.sync_type,
    )

    # Inverse covariance.  Note that we save the output to our specified
    # "covariance" key, because we are going to invert it in-place.

    build_invcov = BuildInverseCovariance(
        pixel_dist=self.pixel_dist,
        inverse_covariance=self.covariance,
        view=self.pixel_pointing.view,
        pixels=self.pixel_pointing.pixels,
        weights=self.stokes_weights.weights,
        noise_model=self.noise_model,
        det_data_units=self.det_data_units,
        det_mask=self.det_mask,
        det_flags=self.det_flags,
        det_flag_mask=self.det_flag_mask,
        shared_flags=self.shared_flags,
        shared_flag_mask=self.shared_flag_mask,
        sync_type=self.sync_type,
    )

    # Build a pipeline to expand pointing and accumulate

    accum = None
    if self.save_pointing:
        # Process all detectors at once
        accum = Pipeline(detector_sets=["ALL"])
    else:
        # Process one detector at a time.
        accum = Pipeline(detector_sets=["SINGLE"])
    accum.operators = [
        self.pixel_pointing,
        self.stokes_weights,
        build_hits,
        build_invcov,
    ]

    pipe_out = accum.apply(data, detectors=detectors)

    # Optionally, store the inverse covariance
    if self.inverse_covariance is not None:
        if self.inverse_covariance in data:
            del data[self.inverse_covariance]
        data[self.inverse_covariance] = data[self.covariance].duplicate()

    # Extract the results
    hits = data[self.hits]
    cov = data[self.covariance]

    # Invert the covariance in place
    rcond = PixelData(cov.distribution, np.float64, n_value=1)
    covariance_invert(
        cov,
        self.rcond_threshold,
        rcond=rcond,
        use_alltoallv=(self.sync_type == "alltoallv"),
    )

    rcond_good = rcond.data[:, :, 0] > 0.0
    rcond_min = 0.0
    rcond_max = 0.0
    if np.count_nonzero(rcond_good) > 0:
        rcond_min = np.amin(rcond.data[rcond_good, 0])
        rcond_max = np.amax(rcond.data[rcond_good, 0])
    if data.comm.comm_world is not None:
        rcond_min = data.comm.comm_world.reduce(rcond_min, root=0, op=MPI.MIN)
        rcond_max = data.comm.comm_world.reduce(rcond_max, root=0, op=MPI.MAX)
    if data.comm.world_rank == 0:
        msg = f"  Pixel covariance condition number range = "
        msg += f"{rcond_min:1.3e} ... {rcond_max:1.3e}"
        log.debug(msg)

    # Store rcond
    data[self.rcond] = rcond

    return

_finalize(data, **kwargs)

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _provides(self):
    prov = {
        "global": [self.pixel_dist, self.hits, self.covariance, self.rcond],
        "shared": list(),
        "detdata": list(),
    }
    if self.save_pointing:
        prov["detdata"].extend([self.pixels, self.weights])
    if self.inverse_covariance is not None:
        prov["global"].append(self.inverse_covariance)
    return prov

_requires()

Source code in toast/ops/mapmaker_utils/mapmaker_utils.py
def _requires(self):
    req = self.pixel_pointing.requires()
    req.update(self.stokes_weights.requires())
    req["meta"].append(self.noise_model)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    return req

toast.ops.NoiseWeight

Bases: Operator

Apply diagonal noise weighting to detector data.

This simple operator takes the detector weight from the specified noise model and applies it to the timestream values. We ignore all detector flags in this operator, since there is no harm in multiplying the noise weight by values in invalid samples.

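In map-making notation this applies the diagonal factor of N^-1 to the timestream, turning d into N^-1 d. A minimal sketch, with illustrative key names:

import toast.ops

nw = toast.ops.NoiseWeight(
    noise_model="noise_model",
    det_data="signal",  # timestream to weight in place
)
nw.apply(data)
# Each detector timestream is now w_det * d, and the detdata units are
# updated to 1 / (input units), matching the bookkeeping in _exec() below.
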
Source code in toast/ops/noise_weight/noise_weight.py
@trait_docs
class NoiseWeight(Operator):
    """Apply diagonal noise weighting to detector data.

    This simple operator takes the detector weight from the specified noise model and
    applies it to the timestream values.  We ignore all detector flags in this operator,
    since there is no harm in multiplying the noise weight by values in invalid samples.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    noise_model = Unicode(
        "noise_model", help="The observation key containing the noise model"
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_data = Unicode(
        None, allow_none=True, help="Observation detdata key for the timestream data"
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    det_flag_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for detector sample flagging",
    )

    det_data_units = Unit(
        defaults.det_data_units, help="Output units if creating detector data"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        for ob in data.obs:
            if self.det_data not in ob.detdata:
                continue
            data_input_units = self.det_data_units
            data_invcov_units = 1.0 / data_input_units**2
            data_output_units = 1.0 / data_input_units

            dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
            if len(dets) == 0:
                # Nothing to do for this observation, but
                # update the units of the output
                ob.detdata[self.det_data].update_units(data_output_units)
                continue

            # Check that the noise model exists
            if self.noise_model not in ob:
                msg = "Noise model {} does not exist in observation {}".format(
                    self.noise_model, ob.name
                )
                raise RuntimeError(msg)

            # Compute the noise for each detector (using the correct units)
            noise = ob[self.noise_model]
            detector_weights = np.array(
                [
                    noise.detector_weight(detector).to(data_invcov_units).value
                    for detector in dets
                ],
                dtype=np.float64,
            )

            if ob.detdata[self.det_data].units != data_input_units:
                msg = f"obs {ob.name} detdata {self.det_data}"
                msg += f" does not have units of {data_input_units}"
                msg += f" before noise weighting"
                log.error(msg)
                raise RuntimeError(msg)

            # Multiply detectors by their respective noise weight
            intervals = ob.intervals[self.view].data
            det_data = ob.detdata[self.det_data].data
            det_data_indx = ob.detdata[self.det_data].indices(dets)
            noise_weight(
                det_data,
                det_data_indx,
                intervals,
                detector_weights,
                impl=implementation,
                use_accel=use_accel,
            )

            # Update the units of the output
            ob.detdata[self.det_data].update_units(data_output_units)

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": [self.noise_model],
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.view is not None:
            req["intervals"].append(self.view)
        return req

    def _provides(self):
        return {"detdata": [self.det_data]}

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        return True

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

det_data = Unicode(None, allow_none=True, help='Observation detdata key for the timestream data') class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Output units if creating detector data') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_invalid, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

noise_model = Unicode('noise_model', help='The observation key containing the noise model') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/noise_weight/noise_weight.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_mask(proposal)

Source code in toast/ops/noise_weight/noise_weight.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/noise_weight/noise_weight.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/noise_weight/noise_weight.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    for ob in data.obs:
        if self.det_data not in ob.detdata:
            continue
        data_input_units = self.det_data_units
        data_invcov_units = 1.0 / data_input_units**2
        data_output_units = 1.0 / data_input_units

        dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
        if len(dets) == 0:
            # Nothing to do for this observation, but
            # update the units of the output
            ob.detdata[self.det_data].update_units(data_output_units)
            continue

        # Check that the noise model exists
        if self.noise_model not in ob:
            msg = "Noise model {} does not exist in observation {}".format(
                self.noise_model, ob.name
            )
            raise RuntimeError(msg)

        # Compute the noise for each detector (using the correct units)
        noise = ob[self.noise_model]
        detector_weights = np.array(
            [
                noise.detector_weight(detector).to(data_invcov_units).value
                for detector in dets
            ],
            dtype=np.float64,
        )

        if ob.detdata[self.det_data].units != data_input_units:
            msg = f"obs {ob.name} detdata {self.det_data}"
            msg += f" does not have units of {data_input_units}"
            msg += f" before noise weighting"
            log.error(msg)
            raise RuntimeError(msg)

        # Multiply detectors by their respective noise weight
        intervals = ob.intervals[self.view].data
        det_data = ob.detdata[self.det_data].data
        det_data_indx = ob.detdata[self.det_data].indices(dets)
        noise_weight(
            det_data,
            det_data_indx,
            intervals,
            detector_weights,
            impl=implementation,
            use_accel=use_accel,
        )

        # Update the units of the output
        ob.detdata[self.det_data].update_units(data_output_units)

    return

_finalize(data, **kwargs)

Source code in toast/ops/noise_weight/noise_weight.py
def _finalize(self, data, **kwargs):
    return

_implementations()

Source code in toast/ops/noise_weight/noise_weight.py
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_provides()

Source code in toast/ops/noise_weight/noise_weight.py
def _provides(self):
    return {"detdata": [self.det_data]}

_requires()

Source code in toast/ops/noise_weight/noise_weight.py
def _requires(self):
    req = {
        "meta": [self.noise_model],
        "detdata": [self.det_data],
        "intervals": list(),
    }
    if self.view is not None:
        req["intervals"].append(self.view)
    return req

_supports_accel()

Source code in toast/ops/noise_weight/noise_weight.py
def _supports_accel(self):
    return True

toast.ops.BinMap

Bases: Operator

Operator which bins a map.

Given a noise model and a pointing operator, build the noise weighted map and apply the noise covariance to get the resulting binned map.

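Concretely, with pointing matrix P, diagonal noise weights N^-1, and timestream d, the binned map is m = (P^T N^-1 P)^-1 P^T N^-1 d: the noise-weighted map is accumulated internally, and the pre-computed covariance supplies the (P^T N^-1 P)^-1 factor. A minimal sketch with illustrative names, reusing pointing operators like those in the CovarianceAndHits example above:

import toast.ops

binner = toast.ops.BinMap(
    pixel_dist="pixel_dist",
    covariance="covariance",  # e.g. the output of CovarianceAndHits
    det_data="signal",
    pixel_pointing=pixels,
    stokes_weights=weights,
    noise_model="noise_model",
)
binner.apply(data)
binned = data["binned"]  # PixelData holding the binned Stokes map
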
Source code in toast/ops/mapmaker_binning.py
@trait_docs
class BinMap(Operator):
    """Operator which bins a map.

    Given a noise model and a pointing operator, build the noise weighted map and
    apply the noise covariance to get the resulting binned map.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    pixel_dist = Unicode(
        "pixel_dist",
        help="The Data key where the PixelDist object should be stored",
    )

    covariance = Unicode(
        "covariance",
        help="The Data key containing the noise covariance PixelData instance",
    )

    binned = Unicode(
        "binned",
        help="The Data key where the binned map should be stored",
    )

    noiseweighted = Unicode(
        None,
        allow_none=True,
        help="The Data key where the noiseweighted map should be stored",
    )

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key for the timestream data"
    )

    det_data_units = Unit(defaults.det_data_units, help="Desired timestream units")

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional telescope flagging",
    )

    pixel_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a pixel pointing operator",
    )

    stokes_weights = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a Stokes weights operator",
    )

    pre_process = Instance(
        klass=Operator,
        allow_none=True,
        help="Optional extra operator to run prior to binning",
    )

    noise_model = Unicode(
        defaults.noise_model, help="Observation key containing the noise model"
    )

    sync_type = Unicode(
        "alltoallv", help="Communication algorithm: 'allreduce' or 'alltoallv'"
    )

    full_pointing = Bool(
        False, help="If True, expand pointing for all detectors and save"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("sync_type")
    def _check_sync_type(self, proposal):
        check = proposal["value"]
        if check != "allreduce" and check != "alltoallv":
            raise traitlets.TraitError("Invalid communication algorithm")
        return check

    @traitlets.validate("pixel_pointing")
    def _check_pixel_pointing(self, proposal):
        pixels = proposal["value"]
        if pixels is not None:
            if not isinstance(pixels, Operator):
                raise traitlets.TraitError(
                    "pixel_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["pixels", "create_dist", "view"]:
                if not pixels.has_trait(trt):
                    msg = f"pixel_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return pixels

    @traitlets.validate("stokes_weights")
    def _check_stokes_weights(self, proposal):
        weights = proposal["value"]
        if weights is not None:
            if not isinstance(weights, Operator):
                raise traitlets.TraitError(
                    "stokes_weights should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["weights", "view"]:
                if not weights.has_trait(trt):
                    msg = f"stokes_weights operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return weights

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()
        timer = Timer()
        timer.start()

        for trait in "pixel_pointing", "stokes_weights", "det_data":
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        log.verbose_rank("  BinMap building pipeline", comm=data.comm.comm_world)

        if self.covariance not in data:
            msg = f"Data does not contain noise covariance '{self.covariance}'"
            log.error(msg)
            raise RuntimeError(msg)

        cov = data[self.covariance]

        # Check that covariance has consistent units
        if cov.units != (self.det_data_units**2).decompose():
            msg = f"Covariance '{self.covariance}' units {cov.units} do not"
            msg += f" equal det_data units ({self.det_data_units}) squared."
            log.error(msg)
            raise RuntimeError(msg)

        # Sanity check that the covariance pixel distribution agrees
        if cov.distribution != data[self.pixel_dist]:
            msg = (
                f"Pixel distribution '{self.pixel_dist}' does not match the one "
                f"used by covariance '{self.covariance}'"
            )
            log.error(msg)
            raise RuntimeError(msg)

        # Set outputs of the pointing operator

        self.pixel_pointing.create_dist = None

        # If the binned map already exists in the data, verify the distribution and
        # reset to zero.

        if self.binned in data:
            if data[self.binned].distribution != data[self.pixel_dist]:
                msg = (
                    f"Pixel distribution '{self.pixel_dist}' does not match "
                    f"existing binned map '{self.binned}'"
                )
                log.error(msg)
                raise RuntimeError(msg)
            data[self.binned].reset()
            data[self.binned].update_units(1.0 / self.det_data_units)

        # Use the same detector mask in the pointing
        self.pixel_pointing.detector_pointing.det_mask = self.det_mask
        self.pixel_pointing.detector_pointing.det_flag_mask = self.det_flag_mask
        if hasattr(self.stokes_weights, "detector_pointing"):
            self.stokes_weights.detector_pointing.det_mask = self.det_mask
            self.stokes_weights.detector_pointing.det_flag_mask = self.det_flag_mask

        # Noise weighted map.  We output this to the final binned map location,
        # since we will multiply by the covariance in-place.

        build_zmap = BuildNoiseWeighted(
            pixel_dist=self.pixel_dist,
            zmap=self.binned,
            view=self.pixel_pointing.view,
            pixels=self.pixel_pointing.pixels,
            weights=self.stokes_weights.weights,
            noise_model=self.noise_model,
            det_data=self.det_data,
            det_data_units=self.det_data_units,
            det_mask=self.det_mask,
            det_flags=self.det_flags,
            det_flag_mask=self.det_flag_mask,
            shared_flags=self.shared_flags,
            shared_flag_mask=self.shared_flag_mask,
            sync_type=self.sync_type,
        )

        # Build a pipeline to expand pointing and accumulate

        accum = None
        accum_ops = list()
        if self.pre_process is not None:
            accum_ops.append(self.pre_process)
        if self.full_pointing:
            # Process all detectors at once
            accum = Pipeline(detector_sets=["ALL"])
        else:
            # Process one detector at a time.
            accum = Pipeline(detector_sets=["SINGLE"])
        accum_ops.extend([self.pixel_pointing, self.stokes_weights, build_zmap])

        accum.operators = accum_ops

        if data.comm.world_rank == 0:
            log.verbose("  BinMap running pipeline")
        pipe_out = accum.apply(data, detectors=detectors)

        # print("Binned zmap = ", data[self.binned].data)

        # Optionally, store the noise-weighted map
        if self.noiseweighted is not None:
            data[self.noiseweighted] = data[self.binned].duplicate()

        # Extract the results
        binned_map = data[self.binned]

        # Apply the covariance in place
        if data.comm.world_rank == 0:
            log.verbose("  BinMap applying covariance")
        covariance_apply(cov, binned_map, use_alltoallv=(self.sync_type == "alltoallv"))
        # print("Binned final = ", data[self.binned].data)
        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = self.pixel_pointing.requires()
        req.update(self.stokes_weights.requires())
        if self.pre_process is not None:
            req.update(self.pre_process.requires())
        req["global"].extend([self.pixel_dist, self.covariance])
        req["meta"].extend([self.noise_model])
        req["detdata"].extend([self.det_data])
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        prov = {"global": [self.binned]}
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

binned = Unicode('binned', help='The Data key where the binned map should be stored') class-attribute instance-attribute

covariance = Unicode('covariance', help='The Data key containing the noise covariance PixelData instance') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key for the timestream data') class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Desired timestream units') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

full_pointing = Bool(False, help='If True, expand pointing for all detectors and save') class-attribute instance-attribute

noise_model = Unicode(defaults.noise_model, help='Observation key containing the noise model') class-attribute instance-attribute

noiseweighted = Unicode(None, allow_none=True, help='The Data key where the noiseweighted map should be stored') class-attribute instance-attribute

pixel_dist = Unicode('pixel_dist', help='The Data key where the PixelDist object should be stored') class-attribute instance-attribute

pixel_pointing = Instance(klass=Operator, allow_none=True, help='This must be an instance of a pixel pointing operator') class-attribute instance-attribute

pre_process = Instance(klass=Operator, allow_none=True, help='Optional extra operator to run prior to binning') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional telescope flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

stokes_weights = Instance(klass=Operator, allow_none=True, help='This must be an instance of a Stokes weights operator') class-attribute instance-attribute

sync_type = Unicode('alltoallv', help="Communication algorithm: 'allreduce' or 'alltoallv'") class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/mapmaker_binning.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_det_mask(proposal)

Source code in toast/ops/mapmaker_binning.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/mapmaker_binning.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_pixel_pointing(proposal)

Source code in toast/ops/mapmaker_binning.py
@traitlets.validate("pixel_pointing")
def _check_pixel_pointing(self, proposal):
    pixels = proposal["value"]
    if pixels is not None:
        if not isinstance(pixels, Operator):
            raise traitlets.TraitError(
                "pixel_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["pixels", "create_dist", "view"]:
            if not pixels.has_trait(trt):
                msg = f"pixel_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return pixels

_check_shared_mask(proposal)

Source code in toast/ops/mapmaker_binning.py
@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_check_stokes_weights(proposal)

Source code in toast/ops/mapmaker_binning.py
@traitlets.validate("stokes_weights")
def _check_stokes_weights(self, proposal):
    weights = proposal["value"]
    if weights is not None:
        if not isinstance(weights, Operator):
            raise traitlets.TraitError(
                "stokes_weights should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["weights", "view"]:
            if not weights.has_trait(trt):
                msg = f"stokes_weights operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return weights

_check_sync_type(proposal)

Source code in toast/ops/mapmaker_binning.py
@traitlets.validate("sync_type")
def _check_sync_type(self, proposal):
    check = proposal["value"]
    if check != "allreduce" and check != "alltoallv":
        raise traitlets.TraitError("Invalid communication algorithm")
    return check

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/mapmaker_binning.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()
    timer = Timer()
    timer.start()

    for trait in "pixel_pointing", "stokes_weights", "det_data":
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    log.verbose_rank("  BinMap building pipeline", comm=data.comm.comm_world)

    if self.covariance not in data:
        msg = f"Data does not contain noise covariance '{self.covariance}'"
        log.error(msg)
        raise RuntimeError(msg)

    cov = data[self.covariance]

    # Check that covariance has consistent units
    if cov.units != (self.det_data_units**2).decompose():
        msg = f"Covariance '{self.covariance}' units {cov.units} do not"
        msg += f" equal det_data units ({self.det_data_units}) squared."
        log.error(msg)
        raise RuntimeError(msg)

    # Sanity check that the covariance pixel distribution agrees
    if cov.distribution != data[self.pixel_dist]:
        msg = (
            f"Pixel distribution '{self.pixel_dist}' does not match the one "
            f"used by covariance '{self.covariance}'"
        )
        log.error(msg)
        raise RuntimeError(msg)

    # Set outputs of the pointing operator

    self.pixel_pointing.create_dist = None

    # If the binned map already exists in the data, verify the distribution and
    # reset to zero.

    if self.binned in data:
        if data[self.binned].distribution != data[self.pixel_dist]:
            msg = (
                f"Pixel distribution '{self.pixel_dist}' does not match "
                f"existing binned map '{self.binned}'"
            )
            log.error(msg)
            raise RuntimeError(msg)
        data[self.binned].reset()
        data[self.binned].update_units(1.0 / self.det_data_units)

    # Use the same detector mask in the pointing
    self.pixel_pointing.detector_pointing.det_mask = self.det_mask
    self.pixel_pointing.detector_pointing.det_flag_mask = self.det_flag_mask
    if hasattr(self.stokes_weights, "detector_pointing"):
        self.stokes_weights.detector_pointing.det_mask = self.det_mask
        self.stokes_weights.detector_pointing.det_flag_mask = self.det_flag_mask

    # Noise weighted map.  We output this to the final binned map location,
    # since we will multiply by the covariance in-place.

    build_zmap = BuildNoiseWeighted(
        pixel_dist=self.pixel_dist,
        zmap=self.binned,
        view=self.pixel_pointing.view,
        pixels=self.pixel_pointing.pixels,
        weights=self.stokes_weights.weights,
        noise_model=self.noise_model,
        det_data=self.det_data,
        det_data_units=self.det_data_units,
        det_mask=self.det_mask,
        det_flags=self.det_flags,
        det_flag_mask=self.det_flag_mask,
        shared_flags=self.shared_flags,
        shared_flag_mask=self.shared_flag_mask,
        sync_type=self.sync_type,
    )

    # Build a pipeline to expand pointing and accumulate

    accum = None
    accum_ops = list()
    if self.pre_process is not None:
        accum_ops.append(self.pre_process)
    if self.full_pointing:
        # Process all detectors at once
        accum = Pipeline(detector_sets=["ALL"])
    else:
        # Process one detector at a time.
        accum = Pipeline(detector_sets=["SINGLE"])
    accum_ops.extend([self.pixel_pointing, self.stokes_weights, build_zmap])

    accum.operators = accum_ops

    if data.comm.world_rank == 0:
        log.verbose("  BinMap running pipeline")
    pipe_out = accum.apply(data, detectors=detectors)

    # print("Binned zmap = ", data[self.binned].data)

    # Optionally, store the noise-weighted map
    if self.noiseweighted is not None:
        data[self.noiseweighted] = data[self.binned].duplicate()

    # Extract the results
    binned_map = data[self.binned]

    # Apply the covariance in place
    if data.comm.world_rank == 0:
        log.verbose("  BinMap applying covariance")
    covariance_apply(cov, binned_map, use_alltoallv=(self.sync_type == "alltoallv"))
    # print("Binned final = ", data[self.binned].data)
    return

_finalize(data, **kwargs)

Source code in toast/ops/mapmaker_binning.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/mapmaker_binning.py
def _provides(self):
    prov = {"global": [self.binned]}
    return prov

_requires()

Source code in toast/ops/mapmaker_binning.py
def _requires(self):
    req = self.pixel_pointing.requires()
    req.update(self.stokes_weights.requires())
    if self.pre_process is not None:
        req.update(self.pre_process.requires())
    req["global"].extend([self.pixel_dist, self.covariance])
    req["meta"].extend([self.noise_model])
    req["detdata"].extend([self.det_data])
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

Observation Matrices
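
Loosely speaking, an observation matrix M encodes the combined effect of the filtering and binning operations applied to the data, so that for an input sky map m the matching output map is m_out = M m. The ObsMat helper below loads such a matrix from disk and applies it to maps; FilterBin, further below, can optionally write one out.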

toast.ops.ObsMat

Bases: object

Observation Matrix class
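
A usage sketch, not from the source: the file name is hypothetical, and the format is the scipy.sparse .npz container implied by the load() method below.

import numpy as np
from toast.ops import ObsMat

# Hypothetical example: apply a saved observation matrix to an input sky.
obsmat = ObsMat("filterbin_obs_matrix.npz")  # hypothetical file name
nside = 64
npix = 12 * nside**2
map_in = np.random.normal(size=(3, npix))  # IQU input map
# apply() flattens the map; its total size must equal obsmat.ncol
map_out = obsmat.apply(map_in)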

Source code in toast/ops/obsmat.py
class ObsMat(object):
    """Observation Matrix class"""

    def __init__(self, filename=None):
        self.filename = filename
        self.matrix = None
        self.load()
        return

    @function_timer
    def load(self, filename=None):
        if filename is not None:
            self.filename = filename
        if self.filename is None:
            self.matrix = None
            self.nnz = 0
            self.nrow, self.ncol = 0, 0
            return
        self.matrix = scipy.sparse.load_npz(self.filename)
        self.nnz = self.matrix.nnz
        if self.nnz < 0:
            msg = f"Overflow in {self.filename}: nnz = {self.nnz}"
            raise RuntimeError(msg)
        self.nrow, self.ncol = self.matrix.shape
        return

    @function_timer
    def apply(self, map_in):
        nmap, npix = np.atleast_2d(map_in).shape
        npixtot = np.prod(map_in.shape)
        if npixtot != self.ncol:
            msg = f"Map is incompatible with the observation matrix. "
            msg += f"shape(matrix) = {self.matrix.shape}, shape(map) = {map_in.shape}"
            raise RuntimeError(msg)
        map_out = self.matrix.dot(map_in.ravel())
        if nmap != 1:
            map_out = map_out.reshape([nmap, -1])
        return map_out

    def sort_indices(self):
        self.matrix.sort_indices()

    @property
    def data(self):
        return self.matrix.data

    def __iadd__(self, other):
        if hasattr(other, "matrix"):
            self.matrix += other.matrix
        else:
            self.matrix += other
        return self

    def __imul__(self, other):
        if hasattr(other, "matrix"):
            self.matrix *= other.matrix
        else:
            self.matrix *= other
        return self

data property

filename = filename instance-attribute

matrix = None instance-attribute

__iadd__(other)

Source code in toast/ops/obsmat.py
def __iadd__(self, other):
    if hasattr(other, "matrix"):
        self.matrix += other.matrix
    else:
        self.matrix += other
    return self

__imul__(other)

Source code in toast/ops/obsmat.py
def __imul__(self, other):
    if hasattr(other, "matrix"):
        self.matrix *= other.matrix
    else:
        self.matrix *= other
    return self

__init__(filename=None)

Source code in toast/ops/obsmat.py
def __init__(self, filename=None):
    self.filename = filename
    self.matrix = None
    self.load()
    return

apply(map_in)

Source code in toast/ops/obsmat.py
@function_timer
def apply(self, map_in):
    nmap, npix = np.atleast_2d(map_in).shape
    npixtot = np.prod(map_in.shape)
    if npixtot != self.ncol:
        msg = f"Map is incompatible with the observation matrix. "
        msg += f"shape(matrix) = {self.matrix.shape}, shape(map) = {map_in.shape}"
        raise RuntimeError(msg)
    map_out = self.matrix.dot(map_in.ravel())
    if nmap != 1:
        map_out = map_out.reshape([nmap, -1])
    return map_out

load(filename=None)

Source code in toast/ops/obsmat.py
@function_timer
def load(self, filename=None):
    if filename is not None:
        self.filename = filename
    if self.filename is None:
        self.matrix = None
        self.nnz = 0
        self.nrow, self.ncol = 0, 0
        return
    self.matrix = scipy.sparse.load_npz(self.filename)
    self.nnz = self.matrix.nnz
    if self.nnz < 0:
        msg = f"Overflow in {self.filename}: nnz = {self.nnz}"
        raise RuntimeError(msg)
    self.nrow, self.ncol = self.matrix.shape
    return

sort_indices()

Source code in toast/ops/obsmat.py
def sort_indices(self):
    self.matrix.sort_indices()

toast.ops.FilterBin

Bases: Operator

FilterBin builds a template matrix and projects out compromised modes. It then bins the signal and optionally writes out the sparse observation matrix that matches the filtering operations. FilterBin supports deprojection templates.

THIS OPERATOR ASSUMES OBSERVATIONS ARE DISTRIBUTED BY DETECTOR WITHIN A GROUP
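
A configuration sketch, not from the source: the trait values are illustrative and the binning operator mirrors the BinMap example above. It assumes data is a populated toast.Data instance distributed by detector within each group.

import toast.ops

# Hypothetical example: filter, bin, and write the observation matrix.
detector_pointing = toast.ops.PointingDetectorSimple()
pixels = toast.ops.PixelsHealpix(nside=64, detector_pointing=detector_pointing)
weights = toast.ops.StokesWeights(mode="IQU", detector_pointing=detector_pointing)
binner = toast.ops.BinMap(pixel_pointing=pixels, stokes_weights=weights)

filterbin = toast.ops.FilterBin(
    binning=binner,
    ground_filter_order=5,       # Legendre polynomial in azimuth
    poly_filter_order=1,         # per-interval polynomial filter
    write_obs_matrix=True,       # also write the sparse observation matrix
    output_dir="filterbin_out",  # hypothetical output directory
)
filterbin.apply(data)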

Source code in toast/ops/filterbin.py
@trait_docs
class FilterBin(Operator):
    """FilterBin buids a template matrix and projects out
    compromised modes.  It then bins the signal and optionally
    writes out the sparse observation matrix that matches the
    filtering operations.
    FilterBin supports deprojection templates.

    THIS OPERATOR ASSUMES OBSERVATIONS ARE DISTRIBUTED BY DETECTOR
    WITHIN A GROUP

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key for the timestream data"
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    filter_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask value for flagging samples that fail filtering",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional telescope flagging",
    )

    hwp_angle = Unicode(
        defaults.hwp_angle, allow_none=True, help="Observation shared key for HWP angle"
    )

    deproject_map = Unicode(
        None,
        allow_none=True,
        help="Healpix map containing the deprojection templates: "
        "intensity map and its derivatives",
    )

    deproject_nnz = Int(
        1,
        help="Number of deprojection templates to regress.  Must be less than "
        "or equal to number of columns in `deproject_map`.",
    )

    deproject_pattern = Unicode(
        ".*",
        help="Regular expression to test detector names with.  Only matching "
        "detectors will be deprojected.  Used to identify differenced TOD.",
    )

    binning = Instance(
        klass=Operator,
        allow_none=True,
        help="Binning operator for map making.",
    )

    azimuth = Unicode(
        defaults.azimuth, allow_none=True, help="Observation shared key for Azimuth"
    )

    hwp_filter_order = Int(
        None,
        allow_none=True,
        help="Order of HWP-synchronous signal filter.",
    )

    ground_filter_order = Int(
        5,
        allow_none=True,
        help="Order of a Legendre polynomial to fit as a function of azimuth.",
    )

    split_ground_template = Bool(
        False, help="Apply a different template for left and right scans"
    )

    leftright_interval = Unicode(
        defaults.throw_leftright_interval,
        help="Intervals for left-to-right scans",
    )

    rightleft_interval = Unicode(
        defaults.throw_rightleft_interval,
        help="Intervals for right-to-left scans",
    )

    poly_filter_order = Int(1, allow_none=True, help="Polynomial order")

    poly_filter_view = Unicode(
        "throw", allow_none=True, help="Intervals for polynomial filtering"
    )

    write_obs_matrix = Bool(False, help="Write the observation matrix")

    noiseweight_obs_matrix = Bool(
        False, help="If True, observation matrix should match noise-weighted maps"
    )

    output_dir = Unicode(
        ".",
        help="Write output data products to this directory",
    )

    write_binmap = Bool(False, help="If True, write the unfiltered map")

    write_map = Bool(True, help="If True, write the filtered map")

    write_noiseweighted_binmap = Bool(
        False,
        help="If True, write the noise-weighted unfiltered map",
    )

    write_noiseweighted_map = Bool(
        False,
        help="If True, write the noise-weighted filtered map",
    )

    write_hits = Bool(True, help="If True, write the hits map")

    write_cov = Bool(True, help="If True, write the white noise covariance matrices.")

    write_invcov = Bool(
        False,
        help="If True, write the inverse white noise covariance matrices.",
    )

    write_rcond = Bool(True, help="If True, write the reciprocal condition numbers.")

    keep_final_products = Bool(
        False, help="If True, keep the map domain products in data after write"
    )

    mc_mode = Bool(False, help="If True, re-use solver flags, sparse covariances, etc")

    mc_index = Int(None, allow_none=True, help="The Monte-Carlo index")

    maskfile = Unicode(
        None,
        allow_none=True,
        help="Optional processing mask",
    )

    cache_dir = Unicode(
        None,
        allow_none=True,
        help="Cache directory for additive observation matrix products",
    )

    rcond_threshold = Float(
        1.0e-3,
        help="Minimum value for inverse pixel condition number cut.",
    )

    deproject_map_name = "deprojection_map"

    write_hdf5 = Bool(
        False, help="If True, output maps are in HDF5 rather than FITS format."
    )

    write_hdf5_serial = Bool(
        False, help="If True, force serial HDF5 write of output maps."
    )

    reset_pix_dist = Bool(
        False,
        help="Clear any existing pixel distribution.  Useful when applying"
        "repeatedly to different data objects.",
    )

    report_memory = Bool(False, help="Report memory throughout the execution")

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("binning")
    def _check_binning(self, proposal):
        bin = proposal["value"]
        if bin is not None:
            if not isinstance(bin, Operator):
                raise traitlets.TraitError("binning should be an Operator instance")
            # Check that this operator has the traits we require
            for trt in [
                "det_data",
                "pixel_dist",
                "pixel_pointing",
                "stokes_weights",
                "binned",
                "covariance",
                "det_flags",
                "det_flag_mask",
                "shared_flags",
                "shared_flag_mask",
                "noise_model",
                "full_pointing",
                "sync_type",
            ]:
                if not bin.has_trait(trt):
                    msg = "binning operator should have a '{}' trait".format(trt)
                    raise traitlets.TraitError(msg)
        return bin

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()

        timer = Timer()
        timer.start()

        memreport = MemoryCounter()
        if not self.report_memory:
            memreport.enabled = False

        memreport.prefix = "Start of mapmaking"
        memreport.apply(data)

        for trait in ("binning",):
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        # Optionally destroy existing pixel distributions (useful if calling
        # repeatedly with different data objects)

        binning = self.binning
        if self.reset_pix_dist:
            if binning.pixel_dist in data:
                del data[binning.pixel_dist]
            if binning.covariance in data:
                # Cannot trust earlier covariance
                del data[binning.covariance]

        if binning.pixel_dist not in data:
            pix_dist = BuildPixelDistribution(
                pixel_dist=binning.pixel_dist,
                pixel_pointing=binning.pixel_pointing,
                shared_flags=binning.shared_flags,
                shared_flag_mask=binning.shared_flag_mask,
            )
            pix_dist.apply(data)
            log.debug_rank(
                "Cached pixel distribution in", comm=data.comm.comm_world, timer=timer
            )

        self.npix = data[binning.pixel_dist].n_pix
        self.nnz = len(self.binning.stokes_weights.mode)

        self.npixtot = self.npix * self.nnz
        self.ncov = self.nnz * (self.nnz + 1) // 2

        if self.maskfile is not None:
            raise RuntimeError("Filtering mask not yet implemented")

        log.debug_rank(
            f"FilterBin:  Running with self.cache_dir = {self.cache_dir}",
            comm=data.comm.comm_world,
        )

        # Get the units used across the distributed data for our desired
        # input detector data
        self._det_data_units = data.detector_units(self.det_data)

        self._initialize_comm(data)

        # Filter data

        self._initialize_obs_matrix()
        log.debug_rank(
            "FilterBin: Initialized observation_matrix in",
            comm=self.comm,
            timer=timer,
        )

        self._load_deprojection_map(data)
        log.debug_rank(
            "FilterBin: Loaded deprojection map in", comm=self.comm, timer=timer
        )

        self._bin_map(data, detectors, filtered=False)
        log.debug_rank(
            "FilterBin: Binned unfiltered map in", comm=self.comm, timer=timer
        )

        log.debug_rank("FilterBin: Filtering signal", comm=self.comm)

        timer1 = Timer()
        timer1.start()
        timer2 = Timer()
        timer2.start()

        memreport.prefix = "Before filtering"
        memreport.apply(data)

        t1 = time()
        for iobs, obs in enumerate(data.obs):
            dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin: Processing observation "
                    f"{iobs} / {len(data.obs)}",
                )

            common_templates = self._build_common_templates(obs)
            if self.shared_flags is not None:
                common_flags = obs.shared[self.shared_flags].data
            else:
                common_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:   Built common templates in "
                    f"{time() - t1:.2f} s",
                )
                t1 = time()

            memreport.prefix = "After common templates"
            memreport.apply(data)

            last_good_fit = None
            template_covariance = None

            for idet, det in enumerate(dets):
                t1 = time()
                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:   Processing detector "
                        f"# {idet + 1} / {len(dets)}",
                    )

                signal = obs.detdata[self.det_data][det]
                flags = obs.detdata[self.det_flags][det]
                # `good` is essentially the diagonal noise matrix used in
                # template regression.  All good detector samples have the
                # same noise weight and rest have zero weight.
                good_fit = np.logical_and(
                    (common_flags & self.shared_flag_mask) == 0,
                    (flags & self.det_flag_mask) == 0,
                )
                good_bin = np.logical_and(
                    (common_flags & self.binning.shared_flag_mask) == 0,
                    (flags & self.binning.det_flag_mask) == 0,
                )

                if np.sum(good_fit) == 0:
                    continue

                deproject = (
                    self.deproject_map is not None
                    and self._deproject_pattern.match(det) is not None
                )

                if deproject or self.write_obs_matrix:
                    # We'll need pixel numbers
                    obs_data = data.select(obs_uid=obs.uid)
                    self.binning.pixel_pointing.apply(obs_data, detectors=[det])
                    pixels = obs.detdata[self.binning.pixel_pointing.pixels][det]
                    # and weights
                    self.binning.stokes_weights.apply(obs_data, detectors=[det])
                    weights = obs.detdata[self.binning.stokes_weights.weights][det]
                else:
                    pixels = None
                    weights = None

                det_templates = common_templates.mask(good_fit)

                if (
                    self.deproject_map is not None
                    and self._deproject_pattern.match(det) is not None
                ):
                    self._add_deprojection_templates(data, obs, pixels, det_templates)
                    # Must re-evaluate the template covariance
                    template_covariance = None

                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:   Built deprojection "
                        f"templates in {time() - t1:.2f} s. "
                        f"ntemplate = {det_templates.ntemplate}",
                    )
                    t1 = time()

                if det_templates.ntemplate == 0:
                    # No templates to fit
                    continue

                # memreport.prefix = "After detector templates"
                # memreport.apply(data)

                if template_covariance is None or np.any(last_good_fit != good_fit):
                    template_covariance = self._build_template_covariance(
                        det_templates, good_fit
                    )
                    last_good_fit = good_fit.copy()

                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:   Built template covariance "
                        f"{time() - t1:.2f} s",
                    )
                    t1 = time()

                self._regress_templates(
                    det_templates, template_covariance, signal, good_fit
                )
                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:   Regressed templates in "
                        f"{time() - t1:.2f} s",
                    )
                    t1 = time()

                self._accumulate_observation_matrix(
                    obs,
                    det,
                    pixels,
                    weights,
                    good_fit,
                    good_bin,
                    det_templates,
                    template_covariance,
                )

        log.debug_rank(
            f"{self.group:4} : FilterBin:   Filtered group data in",
            comm=self.gcomm,
            timer=timer1,
        )

        if self.comm is not None:
            self.comm.Barrier()

        log.info_rank(
            f"FilterBin:   Filtered data in",
            comm=self.comm,
            timer=timer2,
        )

        memreport.prefix = "After filtering"
        memreport.apply(data)

        # Bin filtered signal

        self._bin_map(data, detectors, filtered=True)
        log.debug_rank("FilterBin: Binned filtered map in", comm=self.comm, timer=timer)

        log.info_rank(
            f"FilterBin:   Binned data in",
            comm=self.comm,
            timer=timer2,
        )

        memreport.prefix = "After binning"
        memreport.apply(data)

        if self.write_obs_matrix:
            if not self.noiseweight_obs_matrix:
                log.debug_rank(
                    "FilterBin: De-weighting observation matrix", comm=self.comm
                )
                self._deweight_obs_matrix(data)
                log.debug_rank(
                    "FilterBin: De-weighted observation_matrix in",
                    comm=self.comm,
                    timer=timer2,
                )

            log.info_rank("FilterBin: Collecting observation matrix", comm=self.comm)
            self._collect_obs_matrix()
            log.info_rank(
                "FilterBin: Collected observation_matrix in",
                comm=self.comm,
                timer=timer2,
            )

            memreport.prefix = "After observation matrix"
            memreport.apply(data)

        return

    @function_timer
    def _add_hwp_templates(self, obs, templates):
        if self.hwp_filter_order is None:
            return

        if self.hwp_angle not in obs.shared:
            msg = (
                f"Cannot apply HWP filtering at order = {self.hwp_filter_order}: "
                f"no HWP angle found under key = '{self.hwp_angle}'"
            )
            raise RuntimeError(msg)
        hwp_angle = obs.shared[self.hwp_angle].data
        shared_flags = np.array(obs.shared[self.shared_flags])

        nfilter = 2 * self.hwp_filter_order
        if nfilter < 1:
            return

        fourier_templates = np.zeros([nfilter, hwp_angle.size])
        fourier(hwp_angle, fourier_templates, 1, self.hwp_filter_order + 1)

        templates.append(fourier_templates)

        return

    @function_timer
    def _add_ground_templates(self, obs, templates):
        if self.ground_filter_order is None:
            return

        # To avoid template degeneracies, ground filter only includes
        # polynomial orders not present in the polynomial filter

        phase = self._get_phase(obs)
        shared_flags = np.array(obs.shared[self.shared_flags])

        min_order = 0
        if self.poly_filter_order is not None:
            min_order = self.poly_filter_order + 1
        max_order = self.ground_filter_order
        nfilter = max_order - min_order + 1
        if nfilter < 1:
            return

        legendre_templates = np.zeros([nfilter, phase.size])
        legendre(phase, legendre_templates, min_order, max_order + 1)
        if not self.split_ground_template:
            legendre_filter = legendre_templates
        else:
            # Separate ground filter by scan direction.
            legendre_filter = []
            masks = []
            for name in self.leftright_interval, self.rightleft_interval:
                mask = np.zeros(phase.size, dtype=bool)
                for ival in obs.intervals[name]:
                    mask[ival.first : ival.last] = True
                masks.append(mask)
            for template in legendre_templates:
                for mask in masks:
                    temp = template.copy()
                    temp[mask] = 0
                    legendre_filter.append(temp)
            legendre_filter = np.vstack(legendre_filter)

        templates.append(legendre_filter)

        return

    @function_timer
    def _add_poly_templates(self, obs, templates):
        if self.poly_filter_order is None:
            return
        nfilter = self.poly_filter_order + 1
        intervals = obs.intervals[self.poly_filter_view]
        if self.shared_flags is None:
            shared_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)
        else:
            shared_flags = np.array(obs.shared[self.shared_flags])
        bad = (shared_flags & self.shared_flag_mask) != 0

        for ival in intervals:
            istart = ival.first
            istop = ival.last
            # Trim flagged samples from both ends
            while istart < istop and bad[istart]:
                istart += 1
            while istop - 1 > istart and bad[istop - 1]:
                istop -= 1
            if istop - istart < nfilter:
                # Not enough samples to filter, flag this interval
                shared_flags[ival.first : ival.last] |= self.filter_flag_mask
                continue
            wbin = 2 / (istop - istart)
            phase = (np.arange(istop - istart) + 0.5) * wbin - 1
            legendre_templates = np.zeros([nfilter, phase.size])
            legendre(phase, legendre_templates, 0, nfilter)
            templates.append(legendre_templates, start=istart, stop=istop)

        if self.shared_flags is not None:
            obs.shared[self.shared_flags].set(shared_flags, offset=(0,), fromrank=0)

        return

    @function_timer
    def _build_common_templates(self, obs):
        templates = SparseTemplates()

        self._add_hwp_templates(obs, templates)
        self._add_ground_templates(obs, templates)
        self._add_poly_templates(obs, templates)

        return templates

    @function_timer
    def _add_deprojection_templates(self, data, obs, pixels, templates):
        deproject_map = data[self.deproject_map_name]
        map_dist = deproject_map.distribution
        local_sm, local_pix = map_dist.global_pixel_to_submap(pixels)

        if deproject_map.dtype.char == "d":
            scan_map = scan_map_float64
        elif deproject_map.dtype.char == "f":
            scan_map = scan_map_float32
        else:
            raise RuntimeError("Deprojection supports only float32 and float64 maps")

        nsample = pixels.size
        nnz = self._deproject_nnz
        weights = np.zeros([nsample, nnz], dtype=np.float64)
        dptemplate_raw = AlignedF64.zeros(nsample)
        dptemplate = dptemplate_raw.array()
        # Normalize to the norm of the first existing template
        norm = np.dot(templates.templates[0], templates.templates[0])
        for inz in range(self._deproject_nnz):
            weights[:] = 0
            weights[:, inz] = 1
            scan_map(
                deproject_map.distribution.n_pix_submap,
                deproject_map.n_value,
                local_sm.astype(np.int64),
                local_pix.astype(np.int64),
                deproject_map.raw,
                weights.reshape(-1),
                dptemplate,
            )
            dptemplate *= np.sqrt(norm / np.dot(dptemplate, dptemplate))
            templates.append(dptemplate)
        return

    @function_timer
    def _build_template_covariance(self, templates, good):
        """Calculate (F^T N^-1_F F)^-1

        Observe that the sample noise weights in N^-1_F need not be the
        same as in binning the filtered signal.  For instance, samples
        falling on point sources may be masked here but included in the
        final map.
        """
        log = Logger.get()
        ntemplate = templates.ntemplate
        invcov = np.zeros([ntemplate, ntemplate])
        build_template_covariance(
            templates.starts,
            templates.stops,
            templates.templates,
            good.astype(np.float64),
            invcov,
        )
        try:
            rcond = 1 / np.linalg.cond(invcov)
        except np.linalg.LinAlgError:
            print(
                f"Failed condition number calculation for {ntemplate}x{ntemplate} matrix:"
            )
            print(f"{invcov}", flush=True)
            print(f"Diagonal:")
            for row in range(ntemplate):
                print(f"{row:03d} {invcov[row, row]}")
            raise
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin: Template covariance matrix "
                f"rcond = {rcond}",
            )
        if rcond > 1e-6:
            cov = np.linalg.inv(invcov)
        else:
            log.warning(
                f"{self.group:4} : FilterBin: WARNING: template covariance matrix "
                f"is poorly conditioned: "
                f"rcond = {rcond}.  Using matrix pseudoinverse.",
            )
            cov = np.linalg.pinv(invcov, rcond=1e-10, hermitian=True)

        return cov

    @function_timer
    def _regress_templates(self, templates, template_covariance, signal, good):
        """Calculate Zd = (I - F(F^T N^-1_F F)^-1 F^T N^-1_F)d

        All samples that are not flagged (zero weight in N^-1_F) have
        equal weight.
        """
        proj = templates.dot(signal * good)
        amplitudes = np.dot(template_covariance, proj)
        templates.subtract(signal, amplitudes)
        return
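
    # Dense-algebra sketch of the regression above (assumed notation): with
    # F holding the templates as columns and w = `good` as 0/1 weights,
    #   amplitudes = (F^T diag(w) F)^-1 (F^T diag(w) d)
    #   d_filtered = d - F @ amplitudes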

    @function_timer
    def _compress_pixels(self, pixels):
        if any(pixels < 0):
            msg = f"Unflagged samples have {np.sum(pixels < 0)} negative pixel numbers"
            raise RuntimeError(msg)
        if any(pixels >= self.npix):
            msg = f"Unflagged samples have {np.sum(pixels >= self.npix)} pixels >= {self.npix}"
            raise RuntimeError(msg)
        local_to_global = np.sort(list(set(pixels)))
        compressed_pixels = np.searchsorted(local_to_global, pixels)
        return compressed_pixels, local_to_global.size, local_to_global

    @function_timer
    def _add_matrix(self, local_obs_matrix, detweight):
        """Add the local (per detector) observation matrix to the full
        matrix
        """
        log = Logger.get()
        t1 = time()
        if False:
            # Use scipy sparse implementation
            self.obs_matrix += local_obs_matrix * detweight
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin: Add and construct matrix "
                    f"in {time() - t1:.2f} s",
                )
        else:
            # Use our own compiled kernel
            n = self.obs_matrix.nnz + local_obs_matrix.nnz
            data = np.zeros(n, dtype=np.float64)
            indices = np.zeros(n, dtype=np.int64)
            indptr = np.zeros(self.npixtot + 1, dtype=np.int64)
            add_matrix(
                self.obs_matrix.data,
                self.obs_matrix.indices,
                self.obs_matrix.indptr,
                local_obs_matrix.data * detweight,
                local_obs_matrix.indices,
                local_obs_matrix.indptr,
                data,
                indices,
                indptr,
            )
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin: Add matrix "
                    f"in {time() - t1:.2f} s",
                )
            t1 = time()
            n = indptr[-1]
            self.obs_matrix = scipy.sparse.csr_matrix(
                (data[:n], indices[:n], indptr),
                shape=(self.npixtot, self.npixtot),
            )
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin: construct CSR matrix "
                    f"in {time() - t1:.2f} s",
                )
        return
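
Functionally, the compiled add_matrix kernel produces the same result as the scipy expression in the disabled branch; a toy equivalence check might look like:

import numpy as np
import scipy.sparse

rng = np.random.default_rng(0)
A = scipy.sparse.random(8, 8, density=0.2, format="csr", random_state=rng)
B = scipy.sparse.random(8, 8, density=0.2, format="csr", random_state=rng)
detweight = 2.5
C = A + B * detweight  # what obs_matrix += local_obs_matrix * detweight yields
assert np.allclose(C.toarray(), A.toarray() + detweight * B.toarray())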

    def _expand_matrix(self, compressed_matrix, local_to_global):
        """Expands a dense, compressed matrix into a sparse matrix with
        global indexing
        """
        n = compressed_matrix.size
        values = np.zeros(n, dtype=np.float64)
        indices = np.zeros(n, dtype=np.int64)
        indptr = np.zeros(self.npixtot + 1, dtype=np.int64)
        expand_matrix(
            compressed_matrix,
            local_to_global,
            self.npix,
            self.nnz,
            values,
            indices,
            indptr,
        )
        nnz = indptr[-1]

        sparse_matrix = scipy.sparse.csr_matrix(
            (values[:nnz], indices[:nnz], indptr),
            shape=(self.npixtot, self.npixtot),
        )
        return sparse_matrix
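
A dense sketch of the expansion, assuming the pixel-major layout used elsewhere in this operator (global index = pixel + stokes * npix):

import numpy as np
import scipy.sparse

def expand_dense(c_matrix, local_to_global, npix, nnz):
    """Scatter a compressed (c_npixtot x c_npixtot) block to global indices."""
    c_npix = local_to_global.size
    idx = np.arange(c_npix * nnz)
    # compressed index k maps to local_to_global[k % c_npix] + (k // c_npix) * npix
    glob = local_to_global[idx % c_npix] + (idx // c_npix) * npix
    rows, cols = np.nonzero(c_matrix)
    return scipy.sparse.csr_matrix(
        (c_matrix[rows, cols], (glob[rows], glob[cols])),
        shape=(npix * nnz, npix * nnz),
    )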

    @function_timer
    def _accumulate_observation_matrix(
        self,
        obs,
        det,
        pixels,
        weights,
        good_fit,
        good_bin,
        det_templates,
        template_covariance,
    ):
        """Calculate P^T N^-1 Z P
        This part of the covariance calculation is cumulative: each observation
        and detector is computed independently and can be cached.

        Observe that `N` in this equation need not be the same used in
        template covariance in `Z`.
        """
        if not self.write_obs_matrix:
            return
        log = Logger.get()
        templates = det_templates.to_dense(good_fit.size)
        fname_cache = None
        local_obs_matrix = None
        t1 = time()
        if self.cache_dir is not None:
            cache_dir = os.path.join(self.cache_dir, obs.name)
            os.makedirs(cache_dir, exist_ok=True)
            fname_cache = os.path.join(cache_dir, det)
            try:
                mm_data = np.load(fname_cache + ".data.npy")
                mm_indices = np.load(fname_cache + ".indices.npy")
                mm_indptr = np.load(fname_cache + ".indptr.npy")
                local_obs_matrix = scipy.sparse.csr_matrix(
                    (mm_data, mm_indices, mm_indptr),
                    self.obs_matrix.shape,
                )
                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:     loaded cached matrix from "
                        f"{fname_cache}* in {time() - t1:.2f} s",
                    )
                    t1 = time()
            except Exception:
                local_obs_matrix = None
        else:
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:     cache_dir = {self.cache_dir}"
                )

        if local_obs_matrix is None:
            templates = templates.T.copy()
            good_any = np.logical_or(good_fit, good_bin)

            # Temporarily compress pixels
            if self.grank == 0:
                log.debug(f"{self.group:4} : FilterBin:     Compressing pixels")
            c_pixels, c_npix, local_to_global = self._compress_pixels(
                pixels[good_any].copy()
            )
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin: Compressed in {time() - t1:.2f} s",
                )
                t1 = time()
            c_npixtot = c_npix * self.nnz
            c_obs_matrix = np.zeros([c_npixtot, c_npixtot])
            if self.grank == 0:
                log.debug(f"{self.group:4} : FilterBin:     Accumulating")
            accumulate_observation_matrix(
                c_obs_matrix,
                c_pixels,
                weights[good_any].copy(),
                templates[good_any].copy(),
                template_covariance,
                good_fit[good_any].astype(np.uint8),
                good_bin[good_any].astype(np.uint8),
            )
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:     Accumulated in {time() - t1:.2f} s"
                )
                log.debug(
                    f"{self.group:4} : FilterBin:     Expanding local to global",
                )
                t1 = time()
            # expand to global pixel numbers
            local_obs_matrix = self._expand_matrix(c_obs_matrix, local_to_global)
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:     Expanded in {time() - t1:.2f} s"
                )
                t1 = time()

            if fname_cache is not None:
                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:     Caching to {fname_cache}*",
                    )
                np.save(fname_cache + ".data", local_obs_matrix.data)
                np.save(fname_cache + ".indices", local_obs_matrix.indices)
                np.save(fname_cache + ".indptr", local_obs_matrix.indptr)
                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:     cached in {time() - t1:.2f} s",
                    )
                    t1 = time()
            else:
                if self.grank == 0:
                    log.debug(
                        f"{self.group:4} : FilterBin:     NOT caching detector matrix",
                    )

        if self.grank == 0:
            log.debug(f"{self.group:4} : FilterBin:     Adding to global")
        detweight = obs[self.binning.noise_model].detector_weight(det)
        self._add_matrix(local_obs_matrix, detweight)
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin:     Added in {time() - t1:.2f} s",
            )
        return
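
For a single detector with nnz = 1 the accumulated quantity reduces to the dense product below; good_fit selects the samples entering the template fit (N^-1_F) and good_bin the samples entering the binning (N^-1):

import numpy as np

def local_obs_matrix_dense(pixels, weights, F, cov, good_fit, good_bin, npix):
    """Dense toy sketch of P^T N^-1 Z P for one detector, nnz = 1."""
    nsample = pixels.size
    P = np.zeros((nsample, npix))
    P[np.arange(nsample), pixels] = weights
    # Z = I - F (F^T N^-1_F F)^-1 F^T N^-1_F, with template rows stored in F
    Z = np.eye(nsample) - F.T @ cov @ (F * good_fit)
    return (P * good_bin[:, None]).T @ Z @ P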

    @function_timer
    def _get_phase(self, obs):
        if self.ground_filter_order is None:
            return None
        try:
            azmin = obs["scan_min_az"].to_value(u.radian)
            azmax = obs["scan_max_az"].to_value(u.radian)
            if self.azimuth is not None:
                az = obs.shared[self.azimuth]
            else:
                quats = obs.shared[self.boresight_azel]
                theta, phi, _ = qa.to_iso_angles(quats)
                az = 2 * np.pi - phi
        except Exception as e:
            msg = (
                f"Failed to get boresight azimuth from TOD.  "
                f"Perhaps it is not ground TOD? '{e}'"
            )
            raise RuntimeError(msg)
        phase = (np.unwrap(az) - azmin) / (azmax - azmin) * 2 - 1

        return phase
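
For instance, with a scan throw from azmin to azmax the phase runs linearly from -1 to 1 and back:

import numpy as np

azmin, azmax = np.radians(10.0), np.radians(30.0)
az = np.radians([10.0, 20.0, 30.0, 20.0, 10.0])
phase = (np.unwrap(az) - azmin) / (azmax - azmin) * 2 - 1
# phase == [-1., 0., 1., 0., -1.], ready for Legendre evaluation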

    @function_timer
    def _initialize_comm(self, data):
        """Create convenience aliases to the communicators and properties."""
        self.comm = data.comm.comm_world
        self.rank = data.comm.world_rank
        self.ntask = data.comm.world_size
        self.gcomm = data.comm.comm_group
        self.group = data.comm.group
        self.grank = data.comm.group_rank
        return

    @function_timer
    def _initialize_obs_matrix(self):
        if self.write_obs_matrix:
            self.obs_matrix = scipy.sparse.csr_matrix(
                (self.npixtot, self.npixtot), dtype=np.float64
            )
            if self.rank == 0 and self.cache_dir is not None:
                os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.obs_matrix = None
        return

    @function_timer
    def _deweight_obs_matrix(self, data):
        """Apply (P^T N^-1 P)^-1 to the cumulative part of the
        observation matrix, P^T N^-1 Z P.
        """
        if not self.write_obs_matrix:
            return
        # Apply the white noise covariance to the observation matrix
        white_noise_cov = data[self.binning.covariance]
        cc = scipy.sparse.dok_matrix((self.npixtot, self.npixtot), dtype=np.float64)
        nsubmap = white_noise_cov.distribution.n_submap
        npix_submap = white_noise_cov.distribution.n_pix_submap
        for isubmap_local, isubmap_global in enumerate(
            white_noise_cov.distribution.local_submaps
        ):
            submap = white_noise_cov.data[isubmap_local]
            offset = isubmap_global * npix_submap
            for pix_local in range(npix_submap):
                if np.all(submap[pix_local] == 0):
                    continue
                pix = pix_local + offset
                icov = 0
                for inz in range(self.nnz):
                    for jnz in range(inz, self.nnz):
                        cc[pix + inz * self.npix, pix + jnz * self.npix] = submap[
                            pix_local, icov
                        ]
                        if inz != jnz:
                            cc[pix + jnz * self.npix, pix + inz * self.npix] = submap[
                                pix_local, icov
                            ]
                        icov += 1
        cc = cc.tocsr()
        self.obs_matrix = cc.dot(self.obs_matrix)
        return
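
Each pixel's packed upper-triangle covariance expands into a symmetric nnz x nnz block; for nnz = 3 (I, Q, U) the six packed elements unpack as:

import numpy as np

packed = np.array([4.0, 0.1, 0.2, 5.0, 0.3, 6.0])  # II, IQ, IU, QQ, QU, UU
block = np.zeros((3, 3))
icov = 0
for inz in range(3):
    for jnz in range(inz, 3):
        block[inz, jnz] = block[jnz, inz] = packed[icov]
        icov += 1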

    @function_timer
    def _collect_obs_matrix(self):
        if not self.write_obs_matrix:
            return
        # Combine the observation matrix across processes
        # Reduce the observation matrices.  We use the buffer protocol
        # for better performance, even though it requires more MPI calls
        # than sending the sparse matrix objects directly
        log = Logger.get()
        timer = Timer()
        timer.start()
        nrow_tot = self.npixtot
        nslice = 128
        nrow_write = nrow_tot // nslice
        for islice, row_start in enumerate(range(0, nrow_tot, nrow_write)):
            row_stop = row_start + nrow_write
            obs_matrix_slice = self.obs_matrix[row_start:row_stop]
            nnz = obs_matrix_slice.nnz
            if self.comm is not None:
                nnz = self.comm.allreduce(nnz)
            if nnz == 0:
                log.debug_rank(
                    f"Slice {islice+1:5} / {nslice}: {row_start:12} - {row_stop:12} "
                    f"is empty.  Skipping.",
                    comm=self.comm,
                )
                continue
            log.debug_rank(
                f"Collecting slice {islice+1:5} / {nslice} : {row_start:12} - "
                f"{row_stop:12}",
                comm=self.comm,
            )

            factor = 1
            while factor < self.ntask:
                log.debug_rank(
                    f"FilterBin: Collecting {2 * factor} / {self.ntask}",
                    comm=self.comm,
                )
                if self.rank % (factor * 2) == 0:
                    # this task receives
                    receive_from = self.rank + factor
                    if receive_from < self.ntask:
                        size_recv = self.comm.recv(source=receive_from, tag=factor)
                        data_recv = np.zeros(size_recv, dtype=np.float64)
                        self.comm.Recv(
                            data_recv, source=receive_from, tag=factor + self.ntask
                        )
                        indices_recv = np.zeros(size_recv, dtype=np.int64)
                        self.comm.Recv(
                            indices_recv,
                            source=receive_from,
                            tag=factor + 2 * self.ntask,
                        )
                        indptr_recv = np.zeros(
                            obs_matrix_slice.indptr.size, dtype=np.int64
                        )
                        self.comm.Recv(
                            indptr_recv,
                            source=receive_from,
                            tag=factor + 3 * self.ntask,
                        )
                        obs_matrix_slice += scipy.sparse.csr_matrix(
                            (data_recv, indices_recv, indptr_recv),
                            obs_matrix_slice.shape,
                        )
                        del data_recv, indices_recv, indptr_recv
                elif self.rank % (factor * 2) == factor:
                    # this task sends
                    send_to = self.rank - factor
                    self.comm.send(obs_matrix_slice.data.size, dest=send_to, tag=factor)
                    self.comm.Send(
                        obs_matrix_slice.data, dest=send_to, tag=factor + self.ntask
                    )
                    self.comm.Send(
                        obs_matrix_slice.indices.astype(np.int64),
                        dest=send_to,
                        tag=factor + 2 * self.ntask,
                    )
                    self.comm.Send(
                        obs_matrix_slice.indptr.astype(np.int64),
                        dest=send_to,
                        tag=factor + 3 * self.ntask,
                    )

                if self.comm is not None:
                    self.comm.Barrier()
                log.debug_rank("FilterBin: Collected in", comm=self.comm, timer=timer)
                factor *= 2

            # Write out the observation matrix
            if self.noiseweight_obs_matrix:
                fname = os.path.join(
                    self.output_dir, f"{self.name}_noiseweighted_obs_matrix"
                )
            else:
                fname = os.path.join(self.output_dir, f"{self.name}_obs_matrix")
            fname += f".{row_start:012}.{row_stop:012}.{nrow_tot:012}"
            log.debug_rank(
                f"FilterBin: Writing observation matrix to {fname}.npz",
                comm=self.comm,
            )
            if self.rank == 0:
                if True:
                    # Write out the members of the CSR matrix separately because
                    # scipy.sparse.save_npz is so inefficient
                    np.save(f"{fname}.data", obs_matrix_slice.data)
                    np.save(f"{fname}.indices", obs_matrix_slice.indices)
                    np.save(f"{fname}.indptr", obs_matrix_slice.indptr)
                else:
                    scipy.sparse.save_npz(fname, obs_matrix_slice)
            log.info_rank(
                f"FilterBin: Wrote observation matrix to {fname} in",
                comm=self.comm,
                timer=timer,
            )
        # After writing we are done
        del self.obs_matrix
        self.obs_matrix = None
        return
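
Each slice lands on disk as three .npy files holding the CSR members. A sketch of reading one slice back, with hypothetical placeholder values for the row_start, row_stop and nrow_tot fields of the file name:

import numpy as np
import scipy.sparse

fname = "filterbin_obs_matrix.000000000000.000000786432.000100663296"
data = np.load(f"{fname}.data.npy")
indices = np.load(f"{fname}.indices.npy")
indptr = np.load(f"{fname}.indptr.npy")
nrow = indptr.size - 1  # rows in this slice
ncol = 100663296        # npixtot, the last field of the file name
obs_slice = scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrow, ncol))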

    @function_timer
    def _bin_map(self, data, detectors, filtered):
        """Bin the signal onto a map.  Optionally write out hits and
        white noise covariance matrices.
        """

        log = Logger.get()
        timer = Timer()
        timer.start()

        hits_name = f"{self.name}_hits"
        invcov_name = f"{self.name}_invcov"
        cov_name = f"{self.name}_cov"
        rcond_name = f"{self.name}_rcond"
        if filtered:
            map_name = f"{self.name}_filtered_map"
            noiseweighted_map_name = f"{self.name}_noiseweighted_filtered_map"
        else:
            map_name = f"{self.name}_unfiltered_map"
            noiseweighted_map_name = f"{self.name}_noiseweighted_unfiltered_map"

        self.binning.noiseweighted = noiseweighted_map_name
        self.binning.binned = map_name
        self.binning.det_data = self.det_data
        self.binning.det_data_units = self._det_data_units
        self.binning.covariance = cov_name

        if self.binning.covariance not in data:
            cov = CovarianceAndHits(
                pixel_dist=self.binning.pixel_dist,
                covariance=self.binning.covariance,
                inverse_covariance=invcov_name,
                hits=hits_name,
                rcond=rcond_name,
                det_mask=self.binning.det_mask,
                det_flags=self.binning.det_flags,
                det_flag_mask=self.binning.det_flag_mask,
                det_data_units=self._det_data_units,
                shared_flags=self.binning.shared_flags,
                shared_flag_mask=self.binning.shared_flag_mask,
                pixel_pointing=self.binning.pixel_pointing,
                stokes_weights=self.binning.stokes_weights,
                noise_model=self.binning.noise_model,
                rcond_threshold=self.rcond_threshold,
                sync_type=self.binning.sync_type,
                save_pointing=self.binning.full_pointing,
            )
            cov.apply(data, detectors=detectors)
            log.info_rank(f"Binned covariance and hits in", comm=self.comm, timer=timer)

        self.binning.apply(data, detectors=detectors)
        log.info_rank(f"Binned signal in", comm=self.comm, timer=timer)

        mc_root = self.name
        if self.mc_mode:
            if self.mc_root is not None:
                mc_root += f"_{self.mc_root}"
            if self.mc_index is not None:
                mc_root += f"_{self.mc_index:05d}"

        binned = not filtered  # only write hits and covariance once
        if binned:
            write_map = self.write_binmap
            write_noiseweighted_map = self.write_noiseweighted_binmap
        else:
            write_map = self.write_map
            write_noiseweighted_map = self.write_noiseweighted_map
        keep_final = self.keep_final_products
        keep_cov = self.keep_final_products or self.write_obs_matrix
        for key, write, keep, force, rootname in [
            (hits_name, self.write_hits and binned, keep_final, False, self.name),
            (rcond_name, self.write_rcond and binned, keep_final, False, self.name),
            (
                noiseweighted_map_name,
                write_noiseweighted_map,
                keep_final,
                True,
                mc_root,
            ),
            (map_name, write_map, keep_final, True, mc_root),
            (invcov_name, self.write_invcov and binned, keep_final, False, self.name),
            (cov_name, self.write_cov and binned, keep_cov, False, self.name),
        ]:
            if write:
                product = key.replace(f"{self.name}_", "")
                try:
                    if hasattr(self.binning.pixel_pointing, "wcs"):
                        # WCS pixelization
                        fname = os.path.join(
                            self.output_dir, f"{rootname}_{product}.fits"
                        )
                        if self.mc_mode and not force:
                            if os.path.isfile(fname):
                                log.info_rank(
                                    f"Skipping existing file: {fname}", comm=self.comm
                                )
                                continue
                        write_wcs_fits(data[key], fname)
                    else:
                        if self.write_hdf5:
                            # Non-standard HEALPix HDF5 output
                            fname = os.path.join(
                                self.output_dir, f"{rootname}_{product}.h5"
                            )
                            if self.mc_mode and not force:
                                if os.path.isfile(fname):
                                    log.info_rank(
                                        f"Skipping existing file: {fname}",
                                        comm=self.comm,
                                    )
                                    continue
                            write_healpix_hdf5(
                                data[key],
                                fname,
                                nest=self.binning.pixel_pointing.nest,
                                force_serial=self.write_hdf5_serial,
                            )
                        else:
                            # Standard HEALPix FITS output
                            fname = os.path.join(
                                self.output_dir, f"{rootname}_{product}.fits"
                            )
                            if self.mc_mode and not force:
                                if os.path.isfile(fname):
                                    log.info_rank(
                                        f"Skipping existing file: {fname}",
                                        comm=self.comm,
                                    )
                                    continue
                            write_healpix_fits(
                                data[key], fname, nest=self.binning.pixel_pointing.nest
                            )
                except Exception as e:
                    msg = f"ERROR: failed to write {fname} : {e}"
                    raise RuntimeError(msg)
                log.info_rank(f"Wrote {fname} in", comm=self.comm, timer=timer)
            if not keep and not self.mc_mode:
                if key in data:
                    data[key].clear()
                    del data[key]

        return

    @function_timer
    def _load_deprojection_map(self, data):
        if self.deproject_map is None:
            return None
        data[self.deproject_map_name] = PixelData(
            data[self.binning.pixel_dist],
            dtype=np.float32,
            n_value=self.deproject_nnz,
            units=self._det_data_units,
        )
        if filename_is_hdf5(self.deproject_map):
            read_healpix_hdf5(
                data[self.deproject_map_name],
                self.deproject_map,
                nest=self.binning.pixel_pointing.nest,
            )
        elif filename_is_fits(self.deproject_map):
            read_healpix_fits(
                data[self.deproject_map_name],
                self.deproject_map,
                nest=self.binning.pixel_pointing.nest,
            )
        else:
            msg = f"Cannot determine deprojection map type: {self.deproject_map}"
            raise RuntimeError(msg)
        self._deproject_pattern = re.compile(self.deproject_pattern)
        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        # This operator requires everything that its sub-operators needs.
        req = self.binning.requires()
        req["detdata"].append(self.det_data)
        return req

    def _provides(self):
        prov = dict()
        prov["global"] = [self.binning.binned]
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

azimuth = Unicode(defaults.azimuth, allow_none=True, help='Observation shared key for Azimuth') class-attribute instance-attribute

binning = Instance(klass=Operator, allow_none=True, help='Binning operator for map making.') class-attribute instance-attribute

cache_dir = Unicode(None, allow_none=True, help='Cache directory for additive observation matrix products') class-attribute instance-attribute

deproject_map = Unicode(None, allow_none=True, help='Healpix map containing the deprojection templates: intensity map and its derivatives') class-attribute instance-attribute

deproject_map_name = 'deprojection_map' class-attribute instance-attribute

deproject_nnz = Int(1, help='Number of deprojection templates to regress. Must be less than or equal to number of columns in `deproject_map`.') class-attribute instance-attribute

deproject_pattern = Unicode('.*', help='Regular expression to test detector names with. Only matching detectors will be deprojected. Used to identify differenced TOD.') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key for the timestream data') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

filter_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask value for flagging samples that fail filtering') class-attribute instance-attribute

ground_filter_order = Int(5, allow_none=True, help='Order of a Legendre polynomial to fit as a function of azimuth.') class-attribute instance-attribute

hwp_angle = Unicode(defaults.hwp_angle, allow_none=True, help='Observation shared key for HWP angle') class-attribute instance-attribute

hwp_filter_order = Int(None, allow_none=True, help='Order of HWP-synchronous signal filter.') class-attribute instance-attribute

keep_final_products = Bool(False, help='If True, keep the map domain products in data after write') class-attribute instance-attribute

leftright_interval = Unicode(defaults.throw_leftright_interval, help='Intervals for left-to-right scans') class-attribute instance-attribute

maskfile = Unicode(None, allow_none=True, help='Optional processing mask') class-attribute instance-attribute

mc_index = Int(None, allow_none=True, help='The Monte-Carlo index') class-attribute instance-attribute

mc_mode = Bool(False, help='If True, re-use solver flags, sparse covariances, etc') class-attribute instance-attribute

noiseweight_obs_matrix = Bool(False, help='If True, observation matrix should match noise-weighted maps') class-attribute instance-attribute

output_dir = Unicode('.', help='Write output data products to this directory') class-attribute instance-attribute

poly_filter_order = Int(1, allow_none=True, help='Polynomial order') class-attribute instance-attribute

poly_filter_view = Unicode('throw', allow_none=True, help='Intervals for polynomial filtering') class-attribute instance-attribute

rcond_threshold = Float(0.001, help='Minimum value for inverse pixel condition number cut.') class-attribute instance-attribute

report_memory = Bool(False, help='Report memory throughout the execution') class-attribute instance-attribute

reset_pix_dist = Bool(False, help='Clear any existing pixel distribution. Useful when applying repeatedly to different data objects.') class-attribute instance-attribute

rightleft_interval = Unicode(defaults.throw_rightleft_interval, help='Intervals for right-to-left scans') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional telescope flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

split_ground_template = Bool(False, help='Apply a different template for left and right scans') class-attribute instance-attribute

write_binmap = Bool(False, help='If True, write the unfiltered map') class-attribute instance-attribute

write_cov = Bool(True, help='If True, write the white noise covariance matrices.') class-attribute instance-attribute

write_hdf5 = Bool(False, help='If True, output maps are in HDF5 rather than FITS format.') class-attribute instance-attribute

write_hdf5_serial = Bool(False, help='If True, force serial HDF5 write of output maps.') class-attribute instance-attribute

write_hits = Bool(True, help='If True, write the hits map') class-attribute instance-attribute

write_invcov = Bool(False, help='If True, write the inverse white noise covariance matrices.') class-attribute instance-attribute

write_map = Bool(True, help='If True, write the filtered map') class-attribute instance-attribute

write_noiseweighted_binmap = Bool(False, help='If True, write the noise-weighted unfiltered map') class-attribute instance-attribute

write_noiseweighted_map = Bool(False, help='If True, write the noise-weighted filtered map') class-attribute instance-attribute

write_obs_matrix = Bool(False, help='Write the observation matrix') class-attribute instance-attribute

write_rcond = Bool(True, help='If True, write the reciprocal condition numbers.') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/filterbin.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_accumulate_observation_matrix(obs, det, pixels, weights, good_fit, good_bin, det_templates, template_covariance)

Calculate P^T N^-1 Z P. This part of the covariance calculation is cumulative: each observation and detector is computed independently and can be cached.

Observe that N in this equation need not be the same as the noise weighting used in the template covariance inside Z.

Source code in toast/ops/filterbin.py
@function_timer
def _accumulate_observation_matrix(
    self,
    obs,
    det,
    pixels,
    weights,
    good_fit,
    good_bin,
    det_templates,
    template_covariance,
):
    """Calculate P^T N^-1 Z P
    This part of the covariance calculation is cumulative: each observation
    and detector is computed independently and can be cached.

    Observe that `N` in this equation need not be the same used in
    template covariance in `Z`.
    """
    if not self.write_obs_matrix:
        return
    log = Logger.get()
    templates = det_templates.to_dense(good_fit.size)
    fname_cache = None
    local_obs_matrix = None
    t1 = time()
    if self.cache_dir is not None:
        cache_dir = os.path.join(self.cache_dir, obs.name)
        os.makedirs(cache_dir, exist_ok=True)
        fname_cache = os.path.join(cache_dir, det)
        try:
            mm_data = np.load(fname_cache + ".data.npy")
            mm_indices = np.load(fname_cache + ".indices.npy")
            mm_indptr = np.load(fname_cache + ".indptr.npy")
            local_obs_matrix = scipy.sparse.csr_matrix(
                (mm_data, mm_indices, mm_indptr),
                self.obs_matrix.shape,
            )
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:     loaded cached matrix from "
                    f"{fname_cache}* in {time() - t1:.2f} s",
                )
                t1 = time()
        except Exception:
            local_obs_matrix = None
    else:
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin:     cache_dir = {self.cache_dir}"
            )

    if local_obs_matrix is None:
        templates = templates.T.copy()
        good_any = np.logical_or(good_fit, good_bin)

        # Temporarily compress pixels
        if self.grank == 0:
            log.debug(f"{self.group:4} : FilterBin:     Compressing pixels")
        c_pixels, c_npix, local_to_global = self._compress_pixels(
            pixels[good_any].copy()
        )
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin: Compressed in {time() - t1:.2f} s",
            )
            t1 = time()
        c_npixtot = c_npix * self.nnz
        c_obs_matrix = np.zeros([c_npixtot, c_npixtot])
        if self.grank == 0:
            log.debug(f"{self.group:4} : FilterBin:     Accumulating")
        accumulate_observation_matrix(
            c_obs_matrix,
            c_pixels,
            weights[good_any].copy(),
            templates[good_any].copy(),
            template_covariance,
            good_fit[good_any].astype(np.uint8),
            good_bin[good_any].astype(np.uint8),
        )
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin:     Accumulated in {time() - t1:.2f} s"
            )
            log.debug(
                f"{self.group:4} : FilterBin:     Expanding local to global",
            )
            t1 = time()
        # expand to global pixel numbers
        local_obs_matrix = self._expand_matrix(c_obs_matrix, local_to_global)
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin:     Expanded in {time() - t1:.2f} s"
            )
            t1 = time()

        if fname_cache is not None:
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:     Caching to {fname_cache}*",
                )
            np.save(fname_cache + ".data", local_obs_matrix.data)
            np.save(fname_cache + ".indices", local_obs_matrix.indices)
            np.save(fname_cache + ".indptr", local_obs_matrix.indptr)
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:     cached in {time() - t1:.2f} s",
                )
                t1 = time()
        else:
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:     NOT caching detector matrix",
                )

    if self.grank == 0:
        log.debug(f"{self.group:4} : FilterBin:     Adding to global")
    detweight = obs[self.binning.noise_model].detector_weight(det)
    self._add_matrix(local_obs_matrix, detweight)
    if self.grank == 0:
        log.debug(
            f"{self.group:4} : FilterBin:     Added in {time() - t1:.2f} s",
        )
    return

_add_deprojection_templates(data, obs, pixels, templates)

Source code in toast/ops/filterbin.py
@function_timer
def _add_deprojection_templates(self, data, obs, pixels, templates):
    deproject_map = data[self.deproject_map_name]
    map_dist = deproject_map.distribution
    local_sm, local_pix = map_dist.global_pixel_to_submap(pixels)

    if deproject_map.dtype.char == "d":
        scan_map = scan_map_float64
    elif deproject_map.dtype.char == "f":
        scan_map = scan_map_float32
    else:
        raise RuntimeError("Deprojection supports only float32 and float64 maps")

    nsample = pixels.size
    nnz = self._deproject_nnz
    weights = np.zeros([nsample, nnz], dtype=np.float64)
    dptemplate_raw = AlignedF64.zeros(nsample)
    dptemplate = dptemplate_raw.array()
    # Normalize to the norm of the first common template (assumed here to
    # be templates.templates[0]) so deprojection templates carry comparable
    # weight in the fit
    norm = np.dot(templates.templates[0], templates.templates[0])
    for inz in range(nnz):
        weights[:] = 0
        weights[:, inz] = 1
        scan_map(
            deproject_map.distribution.n_pix_submap,
            deproject_map.n_value,
            local_sm.astype(np.int64),
            local_pix.astype(np.int64),
            deproject_map.raw,
            weights.reshape(-1),
            dptemplate,
        )
        dptemplate *= np.sqrt(norm / np.dot(dptemplate, dptemplate))
        templates.append(dptemplate)
    return

_add_ground_templates(obs, templates)

Source code in toast/ops/filterbin.py
@function_timer
def _add_ground_templates(self, obs, templates):
    if self.ground_filter_order is None:
        return

    # To avoid template degeneracies, ground filter only includes
    # polynomial orders not present in the polynomial filter

    phase = self._get_phase(obs)
    shared_flags = np.array(obs.shared[self.shared_flags])

    min_order = 0
    if self.poly_filter_order is not None:
        min_order = self.poly_filter_order + 1
    max_order = self.ground_filter_order
    nfilter = max_order - min_order + 1
    if nfilter < 1:
        return

    legendre_templates = np.zeros([nfilter, phase.size])
    legendre(phase, legendre_templates, min_order, max_order + 1)
    if not self.split_ground_template:
        legendre_filter = legendre_templates
    else:
        # Separate ground filter by scan direction.
        legendre_filter = []
        masks = []
        for name in self.leftright_interval, self.rightleft_interval:
            mask = np.zeros(phase.size, dtype=bool)
            for ival in obs.intervals[name]:
                mask[ival.first : ival.last] = True
            masks.append(mask)
        for template in legendre_templates:
            for mask in masks:
                temp = template.copy()
                temp[mask] = 0
                legendre_filter.append(temp)
        legendre_filter = np.vstack(legendre_filter)

    templates.append(legendre_filter)

    return
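
A toy illustration of the scan-direction split: zeroing each direction's samples in turn yields one template per direction, so the fit assigns them independent amplitudes:

import numpy as np

template = np.arange(6, dtype=np.float64)
leftright = np.array([True, True, True, False, False, False])
legendre_filter = []
for mask in (leftright, ~leftright):
    temp = template.copy()
    temp[mask] = 0  # zero this direction; the rest gets its own amplitude
    legendre_filter.append(temp)
legendre_filter = np.vstack(legendre_filter)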

_add_hwp_templates(obs, templates)

Source code in toast/ops/filterbin.py
@function_timer
def _add_hwp_templates(self, obs, templates):
    if self.hwp_filter_order is None:
        return

    if self.hwp_angle not in obs.shared:
        msg = (
            f"Cannot apply HWP filtering at order = {self.hwp_filter_order}: "
            f"no HWP angle found under key = '{self.hwp_angle}'"
        )
        raise RuntimeError(msg)
    hwp_angle = obs.shared[self.hwp_angle].data
    shared_flags = np.array(obs.shared[self.shared_flags])

    nfilter = 2 * self.hwp_filter_order
    if nfilter < 1:
        return

    fourier_templates = np.zeros([nfilter, hwp_angle.size])
    fourier(hwp_angle, fourier_templates, 1, self.hwp_filter_order + 1)

    templates.append(fourier_templates)

    return
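
A dense sketch of what the compiled fourier() kernel is assumed to produce: sine and cosine templates for each HWP harmonic up to the filter order, two templates per order, matching nfilter = 2 * hwp_filter_order:

import numpy as np

hwp_angle = np.linspace(0, 4 * np.pi, 1000)
order = 3
fourier_templates = np.vstack(
    [f(k * hwp_angle) for k in range(1, order + 1) for f in (np.sin, np.cos)]
)
assert fourier_templates.shape == (2 * order, hwp_angle.size)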

_add_matrix(local_obs_matrix, detweight)

Add the local (per detector) observation matrix to the full matrix

Source code in toast/ops/filterbin.py
@function_timer
def _add_matrix(self, local_obs_matrix, detweight):
    """Add the local (per detector) observation matrix to the full
    matrix
    """
    log = Logger.get()
    t1 = time()
    if False:
        # Use scipy sparse implementation
        self.obs_matrix += local_obs_matrix * detweight
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin: Add and construct matrix "
                f"in {time() - t1:.2f} s",
            )
    else:
        # Use our own compiled kernel
        n = self.obs_matrix.nnz + local_obs_matrix.nnz
        data = np.zeros(n, dtype=np.float64)
        indices = np.zeros(n, dtype=np.int64)
        indptr = np.zeros(self.npixtot + 1, dtype=np.int64)
        add_matrix(
            self.obs_matrix.data,
            self.obs_matrix.indices,
            self.obs_matrix.indptr,
            local_obs_matrix.data * detweight,
            local_obs_matrix.indices,
            local_obs_matrix.indptr,
            data,
            indices,
            indptr,
        )
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin: Add matrix "
                f"in {time() - t1:.2f} s",
            )
        t1 = time()
        n = indptr[-1]
        self.obs_matrix = scipy.sparse.csr_matrix(
            (data[:n], indices[:n], indptr),
            shape=(self.npixtot, self.npixtot),
        )
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin: construct CSR matrix "
                f"in {time() - t1:.2f} s",
            )
    return

_add_poly_templates(obs, templates)

Source code in toast/ops/filterbin.py
@function_timer
def _add_poly_templates(self, obs, templates):
    if self.poly_filter_order is None:
        return
    nfilter = self.poly_filter_order + 1
    intervals = obs.intervals[self.poly_filter_view]
    if self.shared_flags is None:
        shared_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)
    else:
        shared_flags = np.array(obs.shared[self.shared_flags])
    bad = (shared_flags & self.shared_flag_mask) != 0

    for ival in intervals:
        istart = ival.first
        istop = ival.last
        # Trim flagged samples from both ends
        while istart < istop and bad[istart]:
            istart += 1
        while istop - 1 > istart and bad[istop - 1]:
            istop -= 1
        if istop - istart < nfilter:
            # Not enough samples to filter, flag this interval
            shared_flags[ival.first : ival.last] |= self.filter_flag_mask
            continue
        wbin = 2 / (istop - istart)
        phase = (np.arange(istop - istart) + 0.5) * wbin - 1
        legendre_templates = np.zeros([nfilter, phase.size])
        legendre(phase, legendre_templates, 0, nfilter)
        templates.append(legendre_templates, start=istart, stop=istop)

    if self.shared_flags is not None:
        obs.shared[self.shared_flags].set(shared_flags, offset=(0,), fromrank=0)

    return
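
The per-interval phase places sample centers on (-1, 1); assuming the compiled legendre() kernel evaluates P_0..P_{nfilter-1}, a NumPy sketch is:

import numpy as np
from numpy.polynomial import legendre as npleg

istart, istop, nfilter = 0, 100, 3
wbin = 2 / (istop - istart)
phase = (np.arange(istop - istart) + 0.5) * wbin - 1  # bin centers in (-1, 1)
legendre_templates = np.vstack(
    [npleg.Legendre.basis(order)(phase) for order in range(nfilter)]
)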

_bin_map(data, detectors, filtered)

Bin the signal onto a map. Optionally write out hits and white noise covariance matrices.

Source code in toast/ops/filterbin.py
@function_timer
def _bin_map(self, data, detectors, filtered):
    """Bin the signal onto a map.  Optionally write out hits and
    white noise covariance matrices.
    """

    log = Logger.get()
    timer = Timer()
    timer.start()

    hits_name = f"{self.name}_hits"
    invcov_name = f"{self.name}_invcov"
    cov_name = f"{self.name}_cov"
    rcond_name = f"{self.name}_rcond"
    if filtered:
        map_name = f"{self.name}_filtered_map"
        noiseweighted_map_name = f"{self.name}_noiseweighted_filtered_map"
    else:
        map_name = f"{self.name}_unfiltered_map"
        noiseweighted_map_name = f"{self.name}_noiseweighted_unfiltered_map"

    self.binning.noiseweighted = noiseweighted_map_name
    self.binning.binned = map_name
    self.binning.det_data = self.det_data
    self.binning.det_data_units = self._det_data_units
    self.binning.covariance = cov_name

    if self.binning.covariance not in data:
        cov = CovarianceAndHits(
            pixel_dist=self.binning.pixel_dist,
            covariance=self.binning.covariance,
            inverse_covariance=invcov_name,
            hits=hits_name,
            rcond=rcond_name,
            det_mask=self.binning.det_mask,
            det_flags=self.binning.det_flags,
            det_flag_mask=self.binning.det_flag_mask,
            det_data_units=self._det_data_units,
            shared_flags=self.binning.shared_flags,
            shared_flag_mask=self.binning.shared_flag_mask,
            pixel_pointing=self.binning.pixel_pointing,
            stokes_weights=self.binning.stokes_weights,
            noise_model=self.binning.noise_model,
            rcond_threshold=self.rcond_threshold,
            sync_type=self.binning.sync_type,
            save_pointing=self.binning.full_pointing,
        )
        cov.apply(data, detectors=detectors)
        log.info_rank(f"Binned covariance and hits in", comm=self.comm, timer=timer)

    self.binning.apply(data, detectors=detectors)
    log.info_rank(f"Binned signal in", comm=self.comm, timer=timer)

    mc_root = self.name
    if self.mc_mode:
        if self.mc_root is not None:
            mc_root += f"_{self.mc_root}"
        if self.mc_index is not None:
            mc_root += f"_{self.mc_index:05d}"

    binned = not filtered  # only write hits and covariance once
    if binned:
        write_map = self.write_binmap
        write_noiseweighted_map = self.write_noiseweighted_binmap
    else:
        write_map = self.write_map
        write_noiseweighted_map = self.write_noiseweighted_map
    keep_final = self.keep_final_products
    keep_cov = self.keep_final_products or self.write_obs_matrix
    for key, write, keep, force, rootname in [
        (hits_name, self.write_hits and binned, keep_final, False, self.name),
        (rcond_name, self.write_rcond and binned, keep_final, False, self.name),
        (
            noiseweighted_map_name,
            write_noiseweighted_map,
            keep_final,
            True,
            mc_root,
        ),
        (map_name, write_map, keep_final, True, mc_root),
        (invcov_name, self.write_invcov and binned, keep_final, False, self.name),
        (cov_name, self.write_cov and binned, keep_cov, False, self.name),
    ]:
        if write:
            product = key.replace(f"{self.name}_", "")
            try:
                if hasattr(self.binning.pixel_pointing, "wcs"):
                    # WCS pixelization
                    fname = os.path.join(
                        self.output_dir, f"{rootname}_{product}.fits"
                    )
                    if self.mc_mode and not force:
                        if os.path.isfile(fname):
                            log.info_rank(
                                f"Skipping existing file: {fname}", comm=self.comm
                            )
                            continue
                    write_wcs_fits(data[key], fname)
                else:
                    if self.write_hdf5:
                        # Non-standard HEALPix HDF5 output
                        fname = os.path.join(
                            self.output_dir, f"{rootname}_{product}.h5"
                        )
                        if self.mc_mode and not force:
                            if os.path.isfile(fname):
                                log.info_rank(
                                    f"Skipping existing file: {fname}",
                                    comm=self.comm,
                                )
                                continue
                        write_healpix_hdf5(
                            data[key],
                            fname,
                            nest=self.binning.pixel_pointing.nest,
                            force_serial=self.write_hdf5_serial,
                        )
                    else:
                        # Standard HEALPix FITS output
                        fname = os.path.join(
                            self.output_dir, f"{rootname}_{product}.fits"
                        )
                        if self.mc_mode and not force:
                            if os.path.isfile(fname):
                                log.info_rank(
                                    f"Skipping existing file: {fname}",
                                    comm=self.comm,
                                )
                                continue
                        write_healpix_fits(
                            data[key], fname, nest=self.binning.pixel_pointing.nest
                        )
            except Exception as e:
                msg = f"ERROR: failed to write {fname} : {e}"
                raise RuntimeError(msg)
            log.info_rank(f"Wrote {fname} in", comm=self.comm, timer=timer)
        if not keep and not self.mc_mode:
            if key in data:
                data[key].clear()
                del data[key]

    return

_build_common_templates(obs)

Source code in toast/ops/filterbin.py
@function_timer
def _build_common_templates(self, obs):
    templates = SparseTemplates()

    self._add_hwp_templates(obs, templates)
    self._add_ground_templates(obs, templates)
    self._add_poly_templates(obs, templates)

    return templates

_build_template_covariance(templates, good)

Calculate (F^T N^-1_F F)^-1

Observe that the sample noise weights in N^-1_F need not be the same as in binning the filtered signal. For instance, samples falling on point sources may be masked here but included in the final map.

Source code in toast/ops/filterbin.py
@function_timer
def _build_template_covariance(self, templates, good):
    """Calculate (F^T N^-1_F F)^-1

    Observe that the sample noise weights in N^-1_F need not be the
    same as in binning the filtered signal.  For instance, samples
    falling on point sources may be masked here but included in the
    final map.
    """
    log = Logger.get()
    ntemplate = templates.ntemplate
    invcov = np.zeros([ntemplate, ntemplate])
    build_template_covariance(
        templates.starts,
        templates.stops,
        templates.templates,
        good.astype(np.float64),
        invcov,
    )
    try:
        rcond = 1 / np.linalg.cond(invcov)
    except np.linalg.LinAlgError:
        print(
            f"Failed condition number calculation for {ntemplate}x{ntemplate} matrix:"
        )
        print(f"{invcov}", flush=True)
        print(f"Diagonal:")
        for row in range(ntemplate):
            print(f"{row:03d} {invcov[row, row]}")
        raise
    if self.grank == 0:
        log.debug(
            f"{self.group:4} : FilterBin: Template covariance matrix "
            f"rcond = {rcond}",
        )
    if rcond > 1e-6:
        cov = np.linalg.inv(invcov)
    else:
        log.warning(
            f"{self.group:4} : FilterBin: WARNING: template covariance matrix "
            f"is poorly conditioned: "
            f"rcond = {rcond}.  Using matrix pseudoinverse.",
        )
        cov = np.linalg.pinv(invcov, rcond=1e-10, hermitian=True)

    return cov

_check_binning(proposal)

Source code in toast/ops/filterbin.py
@traitlets.validate("binning")
def _check_binning(self, proposal):
    bin = proposal["value"]
    if bin is not None:
        if not isinstance(bin, Operator):
            raise traitlets.TraitError("binning should be an Operator instance")
        # Check that this operator has the traits we require
        for trt in [
            "det_data",
            "pixel_dist",
            "pixel_pointing",
            "stokes_weights",
            "binned",
            "covariance",
            "det_flags",
            "det_flag_mask",
            "shared_flags",
            "shared_flag_mask",
            "noise_model",
            "full_pointing",
            "sync_type",
        ]:
            if not bin.has_trait(trt):
                msg = "binning operator should have a '{}' trait".format(trt)
                raise traitlets.TraitError(msg)
    return bin

_check_det_flag_mask(proposal)

Source code in toast/ops/filterbin.py
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/filterbin.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_shared_mask(proposal)

Source code in toast/ops/filterbin.py
@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_collect_obs_matrix()

Source code in toast/ops/filterbin.py
@function_timer
def _collect_obs_matrix(self):
    if not self.write_obs_matrix:
        return
    # Combine the observation matrix across processes
    # Reduce the observation matrices.  We use the buffer protocol
    # for better performance, even though it requires more MPI calls
    # than sending the sparse matrix objects directly
    log = Logger.get()
    timer = Timer()
    timer.start()
    nrow_tot = self.npixtot
    nslice = 128
    nrow_write = nrow_tot // nslice
    for islice, row_start in enumerate(range(0, nrow_tot, nrow_write)):
        row_stop = row_start + nrow_write
        obs_matrix_slice = self.obs_matrix[row_start:row_stop]
        nnz = obs_matrix_slice.nnz
        if self.comm is not None:
            nnz = self.comm.allreduce(nnz)
        if nnz == 0:
            log.debug_rank(
                f"Slice {islice+1:5} / {nslice}: {row_start:12} - {row_stop:12} "
                f"is empty.  Skipping.",
                comm=self.comm,
            )
            continue
        log.debug_rank(
            f"Collecting slice {islice+1:5} / {nslice} : {row_start:12} - "
            f"{row_stop:12}",
            comm=self.comm,
        )

        factor = 1
        while factor < self.ntask:
            log.debug_rank(
                f"FilterBin: Collecting {2 * factor} / {self.ntask}",
                comm=self.comm,
            )
            if self.rank % (factor * 2) == 0:
                # this task receives
                receive_from = self.rank + factor
                if receive_from < self.ntask:
                    size_recv = self.comm.recv(source=receive_from, tag=factor)
                    data_recv = np.zeros(size_recv, dtype=np.float64)
                    self.comm.Recv(
                        data_recv, source=receive_from, tag=factor + self.ntask
                    )
                    indices_recv = np.zeros(size_recv, dtype=np.int64)
                    self.comm.Recv(
                        indices_recv,
                        source=receive_from,
                        tag=factor + 2 * self.ntask,
                    )
                    indptr_recv = np.zeros(
                        obs_matrix_slice.indptr.size, dtype=np.int64
                    )
                    self.comm.Recv(
                        indptr_recv,
                        source=receive_from,
                        tag=factor + 3 * self.ntask,
                    )
                    obs_matrix_slice += scipy.sparse.csr_matrix(
                        (data_recv, indices_recv, indptr_recv),
                        obs_matrix_slice.shape,
                    )
                    del data_recv, indices_recv, indptr_recv
            elif self.rank % (factor * 2) == factor:
                # this task sends
                send_to = self.rank - factor
                self.comm.send(obs_matrix_slice.data.size, dest=send_to, tag=factor)
                self.comm.Send(
                    obs_matrix_slice.data, dest=send_to, tag=factor + self.ntask
                )
                self.comm.Send(
                    obs_matrix_slice.indices.astype(np.int64),
                    dest=send_to,
                    tag=factor + 2 * self.ntask,
                )
                self.comm.Send(
                    obs_matrix_slice.indptr.astype(np.int64),
                    dest=send_to,
                    tag=factor + 3 * self.ntask,
                )

            if self.comm is not None:
                self.comm.Barrier()
            log.debug_rank("FilterBin: Collected in", comm=self.comm, timer=timer)
            factor *= 2

        # Write out the observation matrix
        if self.noiseweight_obs_matrix:
            fname = os.path.join(
                self.output_dir, f"{self.name}_noiseweighted_obs_matrix"
            )
        else:
            fname = os.path.join(self.output_dir, f"{self.name}_obs_matrix")
        fname += f".{row_start:012}.{row_stop:012}.{nrow_tot:012}"
        log.debug_rank(
            f"FilterBin: Writing observation matrix to {fname}.npz",
            comm=self.comm,
        )
        if self.rank == 0:
            # Write out the members of the CSR matrix separately, rather
            # than with scipy.sparse.save_npz, which is far less efficient
            np.save(f"{fname}.data", obs_matrix_slice.data)
            np.save(f"{fname}.indices", obs_matrix_slice.indices)
            np.save(f"{fname}.indptr", obs_matrix_slice.indptr)
        log.info_rank(
            f"FilterBin: Wrote observation matrix to {fname} in",
            comm=self.comm,
            timer=timer,
        )
    # After writing we are done
    del self.obs_matrix
    self.obs_matrix = None
    return
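
The loop above is a binary-tree ("recursive doubling") reduction: at each
stage, every task whose rank is an odd multiple of factor sends its slice to
the task factor ranks below it, so after log2(ntask) stages rank 0 holds the
full sum.  A minimal sketch of the same schedule without MPI (ranks simulated
as list entries; the names here are illustrative, not part of TOAST):

import numpy as np

def tree_reduce(values):
    # Sum one contribution per simulated rank with the same pairwise
    # schedule used by _collect_obs_matrix above.
    values = list(values)
    ntask = len(values)
    factor = 1
    while factor < ntask:
        for rank in range(0, ntask, 2 * factor):
            partner = rank + factor  # "rank" receives from "partner"
            if partner < ntask:
                values[rank] = values[rank] + values[partner]
        factor *= 2
    return values[0]

# Five "ranks", each holding one matrix-valued contribution
total = tree_reduce([np.eye(2) * k for k in range(5)])
print(total)  # 10 * identity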

_compress_pixels(pixels)

Source code in toast/ops/filterbin.py
@function_timer
def _compress_pixels(self, pixels):
    if any(pixels < 0):
        msg = f"Unflagged samples have {np.sum(pixels < 0)} negative pixel numbers"
        raise RuntimeError(msg)
    if any(pixels >= self.npix):
        msg = f"Unflagged samples have {np.sum(pixels >= self.npix)} pixels >= {self.npix}"
        raise RuntimeError(msg)
    local_to_global = np.sort(list(set(pixels)))
    compressed_pixels = np.searchsorted(local_to_global, pixels)
    return compressed_pixels, local_to_global.size, local_to_global
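
For illustration: the compression ranks each observed pixel within the sorted
list of unique local pixels, and local_to_global inverts the mapping.  A
standalone sketch of the same operation:

import numpy as np

pixels = np.array([7, 3, 7, 12, 3])
local_to_global = np.unique(pixels)          # array([ 3,  7, 12])
compressed = np.searchsorted(local_to_global, pixels)
print(compressed)                            # [1 0 1 2 0]
print(local_to_global[compressed])           # recovers the original pixels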

_deweight_obs_matrix(data)

Apply (P^T N^-1 P)^-1 to the cumulative part of the observation matrix, P^T N^-1 Z P.

Source code in toast/ops/filterbin.py
@function_timer
def _deweight_obs_matrix(self, data):
    """Apply (P^T N^-1 P)^-1 to the cumulative part of the
    observation matrix, P^T N^-1 Z P.
    """
    if not self.write_obs_matrix:
        return
    # Apply the white noise covariance to the observation matrix
    white_noise_cov = data[self.binning.covariance]
    cc = scipy.sparse.dok_matrix((self.npixtot, self.npixtot), dtype=np.float64)
    nsubmap = white_noise_cov.distribution.n_submap
    npix_submap = white_noise_cov.distribution.n_pix_submap
    for isubmap_local, isubmap_global in enumerate(
        white_noise_cov.distribution.local_submaps
    ):
        submap = white_noise_cov.data[isubmap_local]
        offset = isubmap_global * npix_submap
        for pix_local in range(npix_submap):
            if np.all(submap[pix_local] == 0):
                continue
            pix = pix_local + offset
            icov = 0
            for inz in range(self.nnz):
                for jnz in range(inz, self.nnz):
                    cc[pix + inz * self.npix, pix + jnz * self.npix] = submap[
                        pix_local, icov
                    ]
                    if inz != jnz:
                        cc[pix + jnz * self.npix, pix + inz * self.npix] = submap[
                            pix_local, icov
                        ]
                    icov += 1
    cc = cc.tocsr()
    self.obs_matrix = cc.dot(self.obs_matrix)
    return
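
The nested loop above scatters each per-pixel covariance (stored as the packed
upper triangle of a symmetric nnz x nnz block) into a sparse matrix in which
Stokes component i of pixel p occupies global row p + i * npix.  A small
standalone sketch of that index mapping, with hypothetical sizes:

import numpy as np
import scipy.sparse

npix, nnz = 4, 3                  # hypothetical map size and Stokes depth
pix = 2                           # one example pixel
packed = np.array([4.0, 1.0, 0.5, 3.0, 0.2, 2.0])  # nnz*(nnz+1)/2 entries

cc = scipy.sparse.dok_matrix((npix * nnz, npix * nnz))
icov = 0
for inz in range(nnz):
    for jnz in range(inz, nnz):
        cc[pix + inz * npix, pix + jnz * npix] = packed[icov]
        if inz != jnz:  # mirror below the diagonal
            cc[pix + jnz * npix, pix + inz * npix] = packed[icov]
        icov += 1

# Rows/columns pix, pix + npix, pix + 2*npix hold the reassembled 3x3 block
print(cc.toarray()[pix::npix, pix::npix])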

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/filterbin.py
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()

    timer = Timer()
    timer.start()

    memreport = MemoryCounter()
    if not self.report_memory:
        memreport.enabled = False

    memreport.prefix = "Start of mapmaking"
    memreport.apply(data)

    for trait in ("binning",):
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    # Optionally destroy existing pixel distributions (useful if calling
    # repeatedly with different data objects)

    binning = self.binning
    if self.reset_pix_dist:
        if binning.pixel_dist in data:
            del data[binning.pixel_dist]
        if binning.covariance in data:
            # Cannot trust earlier covariance
            del data[binning.covariance]

    if binning.pixel_dist not in data:
        pix_dist = BuildPixelDistribution(
            pixel_dist=binning.pixel_dist,
            pixel_pointing=binning.pixel_pointing,
            shared_flags=binning.shared_flags,
            shared_flag_mask=binning.shared_flag_mask,
        )
        pix_dist.apply(data)
        log.debug_rank(
            "Cached pixel distribution in", comm=data.comm.comm_world, timer=timer
        )

    self.npix = data[binning.pixel_dist].n_pix
    self.nnz = len(self.binning.stokes_weights.mode)

    self.npixtot = self.npix * self.nnz
    self.ncov = self.nnz * (self.nnz + 1) // 2

    if self.maskfile is not None:
        raise RuntimeError("Filtering mask not yet implemented")

    log.debug_rank(
        f"FilterBin:  Running with self.cache_dir = {self.cache_dir}",
        comm=data.comm.comm_world,
    )

    # Get the units used across the distributed data for our desired
    # input detector data
    self._det_data_units = data.detector_units(self.det_data)

    self._initialize_comm(data)

    # Filter data

    self._initialize_obs_matrix()
    log.debug_rank(
        "FilterBin: Initialized observation_matrix in",
        comm=self.comm,
        timer=timer,
    )

    self._load_deprojection_map(data)
    log.debug_rank(
        "FilterBin: Loaded deprojection map in", comm=self.comm, timer=timer
    )

    self._bin_map(data, detectors, filtered=False)
    log.debug_rank(
        "FilterBin: Binned unfiltered map in", comm=self.comm, timer=timer
    )

    log.debug_rank("FilterBin: Filtering signal", comm=self.comm)

    timer1 = Timer()
    timer1.start()
    timer2 = Timer()
    timer2.start()

    memreport.prefix = "Before filtering"
    memreport.apply(data)

    t1 = time()
    for iobs, obs in enumerate(data.obs):
        dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin: Processing observation "
                f"{iobs} / {len(data.obs)}",
            )

        common_templates = self._build_common_templates(obs)
        if self.shared_flags is not None:
            common_flags = obs.shared[self.shared_flags].data
        else:
            common_flags = np.zeros(obs.n_local_samples, dtype=np.uint8)

        if self.grank == 0:
            log.debug(
                f"{self.group:4} : FilterBin:   Built common templates in "
                f"{time() - t1:.2f} s",
            )
            t1 = time()

        memreport.prefix = "After common templates"
        memreport.apply(data)

        last_good_fit = None
        template_covariance = None

        for idet, det in enumerate(dets):
            t1 = time()
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:   Processing detector "
                    f"# {idet + 1} / {len(dets)}",
                )

            signal = obs.detdata[self.det_data][det]
            flags = obs.detdata[self.det_flags][det]
            # `good` is essentially the diagonal noise matrix used in
            # template regression.  All good detector samples have the
            # same noise weight and the rest have zero weight.
            good_fit = np.logical_and(
                (common_flags & self.shared_flag_mask) == 0,
                (flags & self.det_flag_mask) == 0,
            )
            good_bin = np.logical_and(
                (common_flags & self.binning.shared_flag_mask) == 0,
                (flags & self.binning.det_flag_mask) == 0,
            )

            if np.sum(good_fit) == 0:
                continue

            deproject = (
                self.deproject_map is not None
                and self._deproject_pattern.match(det) is not None
            )

            if deproject or self.write_obs_matrix:
                # We'll need pixel numbers
                obs_data = data.select(obs_uid=obs.uid)
                self.binning.pixel_pointing.apply(obs_data, detectors=[det])
                pixels = obs.detdata[self.binning.pixel_pointing.pixels][det]
                # and weights
                self.binning.stokes_weights.apply(obs_data, detectors=[det])
                weights = obs.detdata[self.binning.stokes_weights.weights][det]
            else:
                pixels = None
                weights = None

            det_templates = common_templates.mask(good_fit)

            if deproject:
                self._add_deprojection_templates(data, obs, pixels, det_templates)
                # Must re-evaluate the template covariance
                template_covariance = None

            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:   Built deprojection "
                    f"templates in {time() - t1:.2f} s. "
                    f"ntemplate = {det_templates.ntemplate}",
                )
                t1 = time()

            if det_templates.ntemplate == 0:
                # No templates to fit
                continue

            # memreport.prefix = "After detector templates"
            # memreport.apply(data)

            if template_covariance is None or np.any(last_good_fit != good_fit):
                template_covariance = self._build_template_covariance(
                    det_templates, good_fit
                )
                last_good_fit = good_fit.copy()

            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:   Built template covariance "
                    f"{time() - t1:.2f} s",
                )
                t1 = time()

            self._regress_templates(
                det_templates, template_covariance, signal, good_fit
            )
            if self.grank == 0:
                log.debug(
                    f"{self.group:4} : FilterBin:   Regressed templates in "
                    f"{time() - t1:.2f} s",
                )
                t1 = time()

            self._accumulate_observation_matrix(
                obs,
                det,
                pixels,
                weights,
                good_fit,
                good_bin,
                det_templates,
                template_covariance,
            )

    log.debug_rank(
        f"{self.group:4} : FilterBin:   Filtered group data in",
        comm=self.gcomm,
        timer=timer1,
    )

    if self.comm is not None:
        self.comm.Barrier()

    log.info_rank(
        f"FilterBin:   Filtered data in",
        comm=self.comm,
        timer=timer2,
    )

    memreport.prefix = "After filtering"
    memreport.apply(data)

    # Bin filtered signal

    self._bin_map(data, detectors, filtered=True)
    log.debug_rank("FilterBin: Binned filtered map in", comm=self.comm, timer=timer)

    log.info_rank(
        f"FilterBin:   Binned data in",
        comm=self.comm,
        timer=timer2,
    )

    memreport.prefix = "After binning"
    memreport.apply(data)

    if self.write_obs_matrix:
        if not self.noiseweight_obs_matrix:
            log.debug_rank(
                "FilterBin: De-weighting observation matrix", comm=self.comm
            )
            self._deweight_obs_matrix(data)
            log.debug_rank(
                "FilterBin: De-weighted observation_matrix in",
                comm=self.comm,
                timer=timer2,
            )

        log.info_rank("FilterBin: Collecting observation matrix", comm=self.comm)
        self._collect_obs_matrix()
        log.info_rank(
            "FilterBin: Collected observation_matrix in",
            comm=self.comm,
            timer=timer2,
        )

        memreport.prefix = "After observation matrix"
        memreport.apply(data)

    return
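
A hedged usage sketch: the trait names below (name, det_data, binning,
write_obs_matrix, output_dir) appear in this source, but the binner and the
data object are placeholders that would be built with the usual TOAST
pipeline tools:

import toast.ops

# `data` is an existing toast.Data instance and `binner` a configured
# binning operator (e.g. toast.ops.BinMap) -- both are assumed here.
filterbin = toast.ops.FilterBin(
    name="filterbin",
    det_data="signal",
    binning=binner,
    write_obs_matrix=True,
    output_dir="maps",
)
filterbin.apply(data)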

_expand_matrix(compressed_matrix, local_to_global)

Expands a dense, compressed matrix into a sparse matrix with global indexing

Source code in toast/ops/filterbin.py
def _expand_matrix(self, compressed_matrix, local_to_global):
    """Expands a dense, compressed matrix into a sparse matrix with
    global indexing
    """
    n = compressed_matrix.size
    values = np.zeros(n, dtype=np.float64)
    indices = np.zeros(n, dtype=np.int64)
    indptr = np.zeros(self.npixtot + 1, dtype=np.int64)
    expand_matrix(
        compressed_matrix,
        local_to_global,
        self.npix,
        self.nnz,
        values,
        indices,
        indptr,
    )
    nnz = indptr[-1]

    sparse_matrix = scipy.sparse.csr_matrix(
        (values[:nnz], indices[:nnz], indptr),
        shape=(self.npixtot, self.npixtot),
    )
    return sparse_matrix
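
The (values, indices, indptr) triple is the standard CSR layout:
indptr[r]:indptr[r + 1] delimits the nonzeros of row r.  A tiny standalone
example of building a matrix that way:

import numpy as np
import scipy.sparse

values = np.array([1.0, 2.0, 3.0])
indices = np.array([0, 2, 1])      # column index of each nonzero
indptr = np.array([0, 2, 2, 3])    # row 0 has two nonzeros, row 1 none
m = scipy.sparse.csr_matrix((values, indices, indptr), shape=(3, 3))
print(m.toarray())
# [[1. 0. 2.]
#  [0. 0. 0.]
#  [0. 3. 0.]]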

_finalize(data, **kwargs)

Source code in toast/ops/filterbin.py
def _finalize(self, data, **kwargs):
    return

_get_phase(obs)

Source code in toast/ops/filterbin.py
@function_timer
def _get_phase(self, obs):
    if self.ground_filter_order is None:
        return None
    try:
        azmin = obs["scan_min_az"].to_value(u.radian)
        azmax = obs["scan_max_az"].to_value(u.radian)
        if self.azimuth is not None:
            az = obs.shared[self.azimuth]
        else:
            quats = obs.shared[self.boresight_azel]
            theta, phi, _ = qa.to_iso_angles(quats)
            az = 2 * np.pi - phi
    except Exception as e:
        msg = (
            f"Failed to get boresight azimuth from TOD.  "
            f"Perhaps it is not ground TOD? '{e}'"
        )
        raise RuntimeError(msg)
    phase = (np.unwrap(az) - azmin) / (azmax - azmin) * 2 - 1

    return phase
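
The returned phase rescales the unwrapped azimuth linearly so that azmin maps
to -1 and azmax to +1, a convenient domain for polynomial ground filter
templates.  For example, with a hypothetical scan range:

import numpy as np

azmin, azmax = 0.1, 0.9           # radians
az = np.array([0.1, 0.5, 0.9])
phase = (np.unwrap(az) - azmin) / (azmax - azmin) * 2 - 1
print(phase)                      # [-1.  0.  1.]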

_initialize_comm(data)

Create convenience aliases to the communicators and properties.

Source code in toast/ops/filterbin.py
@function_timer
def _initialize_comm(self, data):
    """Create convenience aliases to the communicators and properties."""
    self.comm = data.comm.comm_world
    self.rank = data.comm.world_rank
    self.ntask = data.comm.world_size
    self.gcomm = data.comm.comm_group
    self.group = data.comm.group
    self.grank = data.comm.group_rank
    return

_initialize_obs_matrix()

Source code in toast/ops/filterbin.py
@function_timer
def _initialize_obs_matrix(self):
    if self.write_obs_matrix:
        self.obs_matrix = scipy.sparse.csr_matrix(
            (self.npixtot, self.npixtot), dtype=np.float64
        )
        if self.rank == 0 and self.cache_dir is not None:
            os.makedirs(self.cache_dir, exist_ok=True)
    else:
        self.obs_matrix = None
    return

_load_deprojection_map(data)

Source code in toast/ops/filterbin.py
@function_timer
def _load_deprojection_map(self, data):
    if self.deproject_map is None:
        return None
    data[self.deproject_map_name] = PixelData(
        data[self.binning.pixel_dist],
        dtype=np.float32,
        n_value=self.deproject_nnz,
        units=self._det_data_units,
    )
    if filename_is_hdf5(self.deproject_map):
        read_healpix_hdf5(
            data[self.deproject_map_name],
            self.deproject_map,
            nest=self.binning.pixel_pointing.nest,
        )
    elif filename_is_fits(self.deproject_map):
        read_healpix_fits(
            data[self.deproject_map_name],
            self.deproject_map,
            nest=self.binning.pixel_pointing.nest,
        )
    else:
        msg = f"Cannot determine deprojection map type: {self.deproject_map}"
        raise RuntimeError(msg)
    self._deproject_pattern = re.compile(self.deproject_pattern)
    return

_provides()

Source code in toast/ops/filterbin.py
def _provides(self):
    prov = dict()
    prov["global"] = [self.binning.binned]
    return prov

_regress_templates(templates, template_covariance, signal, good)

Calculate Zd = (I - F(F^T N^-1_F F)^-1 F^T N^-1_F)d

All unflagged samples have equal weight; flagged samples have zero weight in N^-1_F.

Source code in toast/ops/filterbin.py
@function_timer
def _regress_templates(self, templates, template_covariance, signal, good):
    """Calculate Zd = (I - F(F^T N^-1_F F)^-1 F^T N^-1_F)d

    All unflagged samples have equal weight; flagged samples have zero
    weight in N^-1_F.
    """
    proj = templates.dot(signal * good)
    amplitudes = np.dot(template_covariance, proj)
    templates.subtract(signal, amplitudes)
    return
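
Equivalently, with F the (nsample, ntemplate) template matrix and N^-1_F the
diagonal weights given by `good`, the amplitudes are
a = (F^T N^-1_F F)^-1 F^T N^-1_F d and the filtered signal is d - F a.  A
dense NumPy sketch of the same operation (the templates object above is
replaced by a plain matrix):

import numpy as np

rng = np.random.default_rng(0)
nsample, ntemplate = 100, 3
F = rng.normal(size=(nsample, ntemplate))   # template matrix
d = rng.normal(size=nsample)                # signal
good = np.ones(nsample)                     # equal weight, no flags

proj = F.T @ (d * good)                         # F^T N^-1_F d
cov = np.linalg.inv(F.T @ (F * good[:, None]))  # (F^T N^-1_F F)^-1
amplitudes = cov @ proj
d_filtered = d - F @ amplitudes                 # Zd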

_requires()

Source code in toast/ops/filterbin.py
def _requires(self):
    # This operator requires everything that its sub-operators needs.
    req = self.binning.requires()
    req["detdata"].append(self.det_data)
    return req

toast.ops.combine_observation_matrix(rootname)

Combine slices of the observation matrix into a single scipy sparse matrix file

Parameters:

rootname (str) : rootname of the matrix slices.  Typically
    {filterbin.output_dir}/{filterbin_name}_obs_matrix.  (required)

Returns:

filename_matrix (str) : Name of the composed matrix file, {rootname}.npz.

Source code in toast/ops/filterbin.py
def combine_observation_matrix(rootname):
    """Combine slices of the observation matrix into a single
    scipy sparse matrix file

    Args:
        rootname (str) : rootname of the matrix slices.  Typically
            `{filterbin.output_dir}/{filterbin_name}_obs_matrix`.
    Returns:
        filename_matrix (str) : Name of the composed matrix file,
            `{rootname}.npz`.
    """

    log = Logger.get()
    timer0 = Timer()
    timer0.start()
    timer = Timer()
    timer.start()

    datafiles = sorted(glob(f"{rootname}.*.*.*.data.npy"))
    if len(datafiles) == 0:
        msg = f"No files match {rootname}.*.*.*.data.npy"
        raise RuntimeError(msg)

    all_data = []
    all_indices = []
    all_indptr = [0]

    current_row = 0
    current_offset = 0
    shape = None

    log.info(f"Combining observation matrix from {len(datafiles)} input files ...")

    for datafile in datafiles:
        parts = datafile.split(".")
        row_start = int(parts[-5])
        row_stop = int(parts[-4])
        nrow_tot = int(parts[-3])
        if shape is None:
            shape = (nrow_tot, nrow_tot)
        elif shape[0] != nrow_tot:
            raise RuntimeError("Mismatch in shape")
        if current_row != row_start:
            all_indptr.append(np.zeros(row_start - current_row) + current_offset)
            current_row = row_start
        log.info(f"Loading {datafile}")
        data = np.load(datafile)
        indices = np.load(datafile.replace(".data.", ".indices.")).astype(np.int64)
        indptr = np.load(datafile.replace(".data.", ".indptr.")).astype(np.int64)
        all_data.append(data)
        all_indices.append(indices)
        indptr += current_offset
        all_indptr.append(indptr[1:])
        current_row = row_stop
        current_offset = indptr[-1]

    log.info_rank(f"Inputs loaded in", timer=timer, comm=None)

    if current_row != nrow_tot:
        all_indptr.append(np.zeros(nrow_tot - current_row) + current_offset)

    log.info("Constructing CSR matrix ...")

    all_data = np.hstack(all_data)
    all_indices = np.hstack(all_indices)
    all_indptr = np.hstack(all_indptr)
    obs_matrix = scipy.sparse.csr_matrix((all_data, all_indices, all_indptr), shape)
    if obs_matrix.nnz < 0:
        msg = f"Overflow in csr_matrix: nnz = {obs_matrix.nnz}.\n"
        raise RuntimeError(msg)

    log.info_rank(f"Constructed in", timer=timer, comm=None)

    log.info(f"Writing {rootname}.npz ...")
    scipy.sparse.save_npz(rootname, obs_matrix)
    log.info_rank(f"Wrote in", timer=timer, comm=None)

    log.info_rank(f"All done in", timer=timer0, comm=None)

    return f"{rootname}.npz"
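
A hedged usage sketch, assuming FilterBin ran with name="filterbin" and
output_dir="maps" so the slice files follow the naming scheme written by
_collect_obs_matrix above:

from toast.ops import combine_observation_matrix

fname = combine_observation_matrix("maps/filterbin_obs_matrix")
print(fname)  # maps/filterbin_obs_matrix.npz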

toast.ops.coadd_observation_matrix(inmatrix, outmatrix, file_invcov=None, file_cov=None, nside_submap=16, rcond_limit=0.001, double_precision=False, comm=None)

Co-add noise-weighted observation matrices

Parameters:

inmatrix (iterable) : One or more noise-weighted observation matrix files.
    If a matrix is used to model several similar observations, append +N
    to the file name to indicate the multiplicity.  (required)
outmatrix (string) : Name of output file.  (required)
file_invcov (string) : Name of output inverse covariance file.  (default None)
file_cov (string) : Name of output covariance file.  (default None)
nside_submap (int) : Submap size is 12 * nside_submap ** 2.  Number of
    submaps is (nside / nside_submap) ** 2.  (default 16)
rcond_limit (float) : Reciprocal condition number limit.  (default 1e-3)
double_precision (bool) : Output in double precision.  (default False)
Source code in toast/ops/obsmat.py
def coadd_observation_matrix(
    inmatrix,
    outmatrix,
    file_invcov=None,
    file_cov=None,
    nside_submap=16,
    rcond_limit=1e-3,
    double_precision=False,
    comm=None,
):
    """Co-add noise-weighted observation matrices

    Args:
        inmatrix(iterable) : One or more noise-weighted observation
            matrix files.  If a matrix is used to model several similar
            observations, append `+N` to the file name to indicate the
            multiplicity.
        outmatrix(string) : Name of output file
        file_invcov(string) : Name of output inverse covariance file
        file_cov(string) : Name of output covariance file
        nside_submap(int) : Submap size is 12 * nside_submap ** 2.
            Number of submaps is (nside / nside_submap) ** 2
        rcond_limit(float) : Reciprocal condition number limit
        double_precision(bool) : Output in double precision
    """

    log = Logger.get()
    if comm is None:
        comm, ntask, rank = get_world()
    else:
        ntask = comm.size
        rank = comm.rank
    timer0 = Timer()
    timer1 = Timer()
    timer0.start()
    timer1.start()

    if double_precision:
        dtype = np.float64
    else:
        dtype = np.float32

    if len(inmatrix) == 1:
        # Only one file provided, try interpreting it as a text file with a list
        try:
            with open(inmatrix[0], "r") as listfile:
                infiles = listfile.readlines()
            log.info_rank(f"Loaded {inmatrix[0]} in", timer=timer1, comm=comm)
        except UnicodeDecodeError:
            # Didn't work. Assume that user supplied a single matrix file
            infiles = inmatrix
    else:
        infiles = inmatrix

    obs_matrix_sum = None
    invcov_sum = None
    nnz = None
    npix = None

    for ifile, infile_matrix in enumerate(infiles):
        infile_matrix = infile_matrix.strip()
        if "noiseweighted" not in infile_matrix:
            msg = (
                f"Observation matrix does not seem to be "
                f"noise-weighted: '{infile_matrix}'"
            )
            raise RuntimeError(msg)
        if "+" in infile_matrix:
            infile_matrix, N = infile_matrix.split("+")
            N = float(N)
        else:
            N = 1
        if not os.path.isfile(infile_matrix):
            msg = f"Matrix not found: {infile_matrix}"
            raise RuntimeError(msg)
        prefix = ""
        log.info(f"{prefix}Loading {infile_matrix}")
        obs_matrix = ObsMat(infile_matrix)
        if N != 1:
            obs_matrix *= N
        if obs_matrix_sum is None:
            obs_matrix_sum = obs_matrix
        else:
            obs_matrix_sum += obs_matrix
        log.info_rank(f"{prefix}Loaded {infile_matrix} in", timer=timer1, comm=None)

        # We'll need the white noise covariance as well
        infile_invcov = infile_matrix.replace("noiseweighted_obs_matrix.npz", "invcov")
        if os.path.isfile(infile_invcov + ".fits"):
            infile_invcov += ".fits"
        elif os.path.isfile(infile_invcov + ".h5"):
            infile_invcov += ".h5"
        else:
            msg = (
                f"Cannot find an inverse covariance matrix to go "
                f"with '{infile_matrix}'"
            )
            raise RuntimeError(msg)
        log.info(f"{prefix}Loading {infile_invcov}")
        invcov = read_healpix(
            infile_invcov, None, nest=True, dtype=float, verbose=False
        )
        if N != 1:
            invcov *= N
        if invcov_sum is None:
            invcov_sum = invcov
            nnzcov, npix = invcov.shape
            nnz = 1
            while (nnz * (nnz + 1)) // 2 != nnzcov:
                nnz += 1
            npixtot = npix * nnz
        else:
            invcov_sum += invcov
        log.info_rank(f"{prefix}Loaded {infile_invcov} in", timer=timer1, comm=None)

    # Put the inverse white noise covariance in a TOAST pixel object

    npix_submap = 12 * nside_submap**2
    nsubmap = npix // npix_submap
    local_submaps = [submap for submap in range(nsubmap) if submap % ntask == rank]
    dist = PixelDistribution(
        n_pix=npix, n_submap=nsubmap, local_submaps=local_submaps, comm=comm
    )
    dist_cov = PixelData(dist, float, n_value=nnzcov)
    for local_submap, global_submap in enumerate(local_submaps):
        pix_start = global_submap * npix_submap
        pix_stop = pix_start + npix_submap
        dist_cov.data[local_submap] = invcov_sum[:, pix_start:pix_stop].T
    del invcov_sum

    # Optionally write out the inverse white noise covariance

    if file_invcov is not None:
        log.info_rank(f"Writing {file_invcov}", comm=comm)
        if filename_is_fits(file_invcov):
            write_healpix_fits(
                dist_cov,
                file_invcov,
                nest=True,
                single_precision=not double_precision,
            )
        else:
            write_healpix_hdf5(
                dist_cov,
                file_invcov,
                nest=True,
                single_precision=not double_precision,
                force_serial=True,
            )
        log.info_rank(f"Wrote {file_invcov}", timer=timer1, comm=comm)

    # Invert the white noise covariance

    log.info_rank("Inverting white noise matrices", comm=comm)
    dist_rcond = PixelData(dist, float, n_value=1)
    covariance_invert(dist_cov, rcond_limit, rcond=dist_rcond, use_alltoallv=True)
    log.info_rank(f"Inverted white noise matrices in", timer=timer1, comm=comm)

    # Optionally write out the white noise covariance

    if file_cov is not None:
        log.info_rank(f"Writing {file_cov}", comm=comm)
        if filename_is_fits(file_cov):
            write_healpix_fits(
                dist_cov,
                file_cov,
                nest=True,
                single_precision=not double_precision,
            )
        else:
            write_healpix_hdf5(
                dist_cov,
                file_cov,
                nest=True,
                single_precision=not double_precision,
                force_serial=True,
            )
        log.info_rank(f"Wrote {file_cov} in", timer=timer1, comm=comm)

    # De-weight the observation matrix

    log.info_rank(f"De-weighting obs matrix", comm=comm)
    cc = scipy.sparse.dok_matrix((npixtot, npixtot), dtype=np.float64)
    nsubmap = dist_cov.distribution.n_submap
    npix_submap = dist_cov.distribution.n_pix_submap
    for isubmap_local, isubmap_global in enumerate(dist_cov.distribution.local_submaps):
        submap = dist_cov.data[isubmap_local]
        offset = isubmap_global * npix_submap
        for pix_local in range(npix_submap):
            if np.all(submap[pix_local] == 0):
                continue
            pix = pix_local + offset
            icov = 0
            for inz in range(nnz):
                for jnz in range(inz, nnz):
                    cc[pix + inz * npix, pix + jnz * npix] = submap[pix_local, icov]
                    if inz != jnz:
                        cc[pix + jnz * npix, pix + inz * npix] = submap[pix_local, icov]
                    icov += 1
    cc = cc.tocsr()
    obs_matrix_sum = cc.dot(obs_matrix_sum.matrix)
    log.info_rank(f"De-weighted obs matrix in", timer=timer1, comm=comm)

    # Write out the co-added and de-weighted matrix

    if not outmatrix.endswith(".npz"):
        outmatrix += ".npz"
    log.info_rank(f"Writing {outmatrix}", comm=comm)
    scipy.sparse.save_npz(outmatrix, obs_matrix_sum.astype(dtype))
    log.info_rank(f"Wrote {outmatrix} in", timer=timer1, comm=comm)

    log.info_rank(f"Co-added and de-weighted obs matrix in", timer=timer0, comm=comm)

    return outmatrix
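
A hedged usage sketch: the file names are placeholders, and the +N suffix
multiplies a matrix (and its inverse covariance) before co-addition, as
parsed above:

from toast.ops import coadd_observation_matrix

outfile = coadd_observation_matrix(
    inmatrix=[
        "season1/filterbin_noiseweighted_obs_matrix.npz",
        "season2/filterbin_noiseweighted_obs_matrix.npz+3",  # counts 3 times
    ],
    outmatrix="coadded_obs_matrix.npz",
    file_cov="coadded_cov.fits",
)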

Template Regression

toast.templates.Amplitudes

Bases: AcceleratorObject

Class for distributed template amplitudes.

In the general case, template amplitudes exist as sparse, non-unique values across all processes. This object provides methods for describing the local distribution of amplitudes and for doing global reductions and dot products.

There are 4 supported cases:

1.  If n_global == n_local, then every process has a full copy of the amplitude
    values.

2.  If n_global != n_local and both local_indices and local_ranges are None,
    then every process has a disjoint set of amplitudes.  The sum of n_local
    across all processes must equal n_global.

3.  If n_global != n_local and local_ranges is not None, then local_ranges
    specifies the contiguous global slices that are concatenated to form the
    local data.  The sum of the lengths of the slices must equal n_local.

4.  If n_global != n_local and local_indices is not None, then local_indices
    is an array of the global indices of all the local data.  The length of
    local_indices must equal n_local.  WARNING:  this case is more costly in
    terms of storage and reduction.  Avoid it if possible.

Because different process groups have different sets of observations, there are some types of templates which may only have shared amplitudes within the group communicator. If use_group is True, the group communicator is used instead of the world communicator, and n_global is interpreted as the number of amplitudes in the group. This information is needed whenever working with the full set of amplitudes (for example when doing I/O).

Parameters:

comm (toast.Comm) : The toast communicator.  (required)
n_global (int) : The number of global values across all processes.  (required)
n_local (int) : The number of values on this process.  (required)
local_indices (array) : If not None, the explicit indices of the local
    amplitudes within the global array.  (default None)
local_ranges (list) : If not None, a list of tuples with the (offset, n_amp)
    amplitude ranges stored locally.  (default None)
dtype (dtype) : The amplitude dtype.  (default float64)
use_group (bool) : If True, use the group rather than world communicator.
    (default False)
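
A minimal construction sketch, assuming `comm` is an existing toast.Comm
instance (distribution cases 1 and 3 from the list above):

from toast.templates import Amplitudes

# Case 1: every process holds all 100 amplitude values
full = Amplitudes(comm, 100, 100)

# Case 3: this process holds two contiguous global slices,
# (offset=0, n=10) and (offset=50, n=10)
ranged = Amplitudes(comm, 100, 20, local_ranges=[(0, 10), (50, 10)])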
Source code in toast/templates/amplitudes.py
class Amplitudes(AcceleratorObject):
    """Class for distributed template amplitudes.

    In the general case, template amplitudes exist as sparse, non-unique values across
    all processes.  This object provides methods for describing the local distribution
    of amplitudes and for doing global reductions and dot products.

    There are 4 supported cases:

        1.  If n_global == n_local, then every process has a full copy of the amplitude
            values.

        2.  If n_global != n_local and both local_indices and local_ranges are None,
            then every process has a disjoint set of amplitudes.  The sum of n_local
            across all processes must equal n_global.

        3.  If n_global != n_local and local_ranges is not None, then local_ranges
            specifies the contiguous global slices that are concatenated to form the
            local data.  The sum of the lengths of the slices must equal n_local.

        4.  If n_global != n_local and local_indices is not None, then local_indices
            is an array of the global indices of all the local data.  The length of
            local_indices must equal n_local.  WARNING:  this case is more costly in
            terms of storage and reduction.  Avoid it if possible.

    Because different process groups have different sets of observations, there are
    some types of templates which may only have shared amplitudes within the group
    communicator.  If use_group is True, the group communicator is used instead of the
    world communicator, and n_global is interpreted as the number of amplitudes in the
    group.  This information is needed whenever working with the full set of amplitudes
    (for example when doing I/O).

    Args:
        comm (toast.Comm):  The toast communicator.
        n_global (int):  The number of global values across all processes.
        n_local (int):  The number of values on this process.
        local_indices (array):  If not None, the explicit indices of the local
            amplitudes within the global array.
        local_ranges (list):  If not None, a list of tuples with the (offset, n_amp)
            amplitude ranges stored locally.
        dtype (dtype):  The amplitude dtype.
        use_group (bool):  If True, use the group rather than world communicator.

    """

    def __init__(
        self,
        comm,
        n_global,
        n_local,
        local_indices=None,
        local_ranges=None,
        dtype=np.float64,
        use_group=False,
    ):
        super().__init__()
        # print(
        #     f"Amplitudes({comm.world_rank}, n_global={n_global}, n_local={n_local}, lc_ind={local_indices}, lc_rng={local_ranges}, dt={dtype}, use_group={use_group}"
        # )
        self._comm = comm
        self._n_global = n_global
        self._n_local = n_local
        self._local_indices = local_indices
        self._local_ranges = local_ranges
        self._use_group = use_group
        if use_group:
            self._mpicomm = self._comm.comm_group
        else:
            self._mpicomm = self._comm.comm_world
        self._dtype = np.dtype(dtype)
        self._storage_class, self._itemsize = dtype_to_aligned(dtype)
        self._full = False
        self._global_first = None
        self._global_last = None
        if self._n_global == self._n_local:
            self._full = True
            self._global_first = 0
            self._global_last = self._n_local - 1
        else:
            if (self._local_indices is None) and (self._local_ranges is None):
                rank = 0
                if self._mpicomm is not None:
                    all_n_local = self._mpicomm.gather(self._n_local, root=0)
                    rank = self._mpicomm.rank
                    if rank == 0:
                        all_n_local = np.array(all_n_local, dtype=np.int64)
                        if np.sum(all_n_local) != self._n_global:
                            msg = "Total amplitudes on all processes does "
                            msg += "not equal n_global"
                            raise RuntimeError(msg)
                    all_n_local = self._mpicomm.bcast(all_n_local, root=0)
                else:
                    all_n_local = np.array([self._n_local], dtype=np.int64)
                self._global_first = 0
                for i in range(rank):
                    self._global_first += all_n_local[i]
                self._global_last = self._global_first + self._n_local - 1
            elif self._local_ranges is not None:
                # local data is specified by ranges
                check = 0
                last = 0
                for off, n in self._local_ranges:
                    check += n
                    if off < last:
                        msg = "local_ranges must not overlap and must be sorted"
                        raise RuntimeError(msg)
                    last = off + n
                    if last > self._n_global:
                        msg = "local_ranges extends beyond the number of global amps"
                        raise RuntimeError(msg)
                if check != self._n_local:
                    raise RuntimeError("local_ranges must sum to n_local")
                self._global_first = self._local_ranges[0][0]
                self._global_last = (
                    self._local_ranges[-1][0] + self._local_ranges[-1][1] - 1
                )
            else:
                # local data has explicit global indices
                if len(self._local_indices) != self._n_local:
                    msg = "Length of local_indices must match n_local"
                    raise RuntimeError(msg)
                self._global_first = self._local_indices[0]
                self._global_last = self._local_indices[-1]
        if self._n_local == 0:
            self._raw = None
            self.local = None
        else:
            self._raw = self._storage_class.zeros(self._n_local)
            self.local = self._raw.array()

        # Support flagging of template amplitudes.  This can be used to flag some
        # amplitudes if too many timestream samples contributing to the amplitude value
        # are bad.  We will be passing these flags to compiled code, and there
        # is no easy way to do this using numpy bool and C++ bool.  So we waste
        # a bit of memory and use a whole byte per amplitude.
        if self._n_local == 0:
            self._raw_flags = None
            self.local_flags = None
        else:
            self._raw_flags = AlignedU8.zeros(self._n_local)
            self.local_flags = self._raw_flags.array()

    def clear(self):
        """Delete the underlying memory.

        This will forcibly delete the C-allocated memory and invalidate all python
        references to this object.  DO NOT CALL THIS unless you are sure all references
        are no longer being used and you are about to delete the object.

        """
        if self.accel_exists():
            self.accel_delete()
        if hasattr(self, "local"):
            del self.local
            self.local = None
        if hasattr(self, "local_flags"):
            del self.local_flags
            self.local_flags = None
        if hasattr(self, "_raw"):
            if self._raw is not None:
                self._raw.clear()
            del self._raw
            self._raw = None
        if hasattr(self, "_raw_flags"):
            if self._raw_flags is not None:
                self._raw_flags.clear()
            del self._raw_flags
            self._raw_flags = None

    def __del__(self):
        self.clear()

    def __repr__(self):
        val = "<Amplitudes n_global={} n_local={} comm={}\n  {}\n  {}>".format(
            self.n_global, self.n_local, self.comm, self.local, self.local_flags
        )
        return val

    def __eq__(self, value):
        if isinstance(value, Amplitudes):
            return self.local == value.local
        else:
            return self.local == value

    # Arithmetic.  These assume that flagging is consistent between the pairs of
    # Amplitudes (always true when used in the mapmaking) or that the flagged values
    # have been zeroed out.

    def __iadd__(self, other):
        if self.local is None:
            return self
        if isinstance(other, Amplitudes):
            if other.local is not None:
                self.local[:] += other.local
        else:
            if other is not None:
                self.local[:] += other
        return self

    def __isub__(self, other):
        if self.local is None:
            return self
        if isinstance(other, Amplitudes):
            if other.local is not None:
                self.local[:] -= other.local
        else:
            if other is not None:
                self.local[:] -= other
        return self

    def __imul__(self, other):
        if self.local is None:
            return self
        if isinstance(other, Amplitudes):
            if other.local is not None:
                self.local[:] *= other.local
        else:
            if other is not None:
                self.local[:] *= other
        return self

    def __itruediv__(self, other):
        if self.local is None:
            return self
        if isinstance(other, Amplitudes):
            if other.local is not None:
                self.local[:] /= other.local
        else:
            if other is not None:
                self.local[:] /= other
        return self

    def __add__(self, other):
        result = self.duplicate()
        result += other
        return result

    def __sub__(self, other):
        result = self.duplicate()
        result -= other
        return result

    def __mul__(self, other):
        result = self.duplicate()
        result *= other
        return result

    def __truediv__(self, other):
        result = self.duplicate()
        result /= other
        return result

    def reset(self):
        """Set all amplitude values to zero."""
        if self.local is None:
            return
        self.local[:] = 0
        if self.accel_exists():
            self._accel_reset_local()

    def reset_flags(self):
        """Set all flag values to zero."""
        if self.local_flags is None:
            return
        self.local_flags[:] = 0
        if self.accel_exists():
            self._accel_reset_local_flags()

    def duplicate(self):
        """Return a copy of the data."""
        ret = Amplitudes(
            self._comm,
            self._n_global,
            self._n_local,
            local_indices=self._local_indices,
            local_ranges=self._local_ranges,
            dtype=self._dtype,
            use_group=self._use_group,
        )
        if self.accel_exists():
            ret.accel_create(self._accel_name)
        restore = False
        if self.accel_in_use():
            # We have no good way to copy between device buffers,
            # so do this on the host.  The duplicate() method is
            # not used inside the solver loop.
            self.accel_update_host()
            restore = True
        if self.local is not None:
            ret.local[:] = self.local
        if self.local_flags is not None:
            ret.local_flags[:] = self.local_flags
        if restore:
            self.accel_update_device()
            ret.accel_update_device()
        return ret

    @property
    def comm(self):
        """The toast communicator in use."""
        return self._comm

    @property
    def n_global(self):
        """The total number of amplitudes."""
        return self._n_global

    @property
    def n_local(self):
        """The number of locally stored amplitudes."""
        return self._n_local

    @property
    def n_local_flagged(self):
        """The number of local amplitudes that are flagged."""
        if self.local_flags is None:
            return 0
        else:
            return np.count_nonzero(self.local_flags)

    @property
    def local_indices(self):
        """The global indices of the local amplitudes, or None."""
        return self._local_indices

    @property
    def local_ranges(self):
        """The global slices covered by local amplitudes, or None."""
        return self._local_ranges

    @property
    def use_group(self):
        """Whether to use the group communicator rather than the global one."""
        return self._use_group

    def sync(self, comm_bytes=10000000):
        """Perform an Allreduce across all processes.

        Args:
            comm_bytes (int):  The maximum number of bytes to communicate in each
                call to Allreduce.

        Returns:
            None

        """
        if self._mpicomm is None:
            # Nothing to do
            return

        if not self._full and (
            self._local_indices is None and self._local_ranges is None
        ):
            # Disjoint set of amplitudes, no communication needed.
            return

        log = Logger.get()

        n_comm = int(comm_bytes / self._itemsize)
        n_total = self._n_global
        if n_comm > n_total:
            n_comm = n_total

        # Create persistent buffers for the reduction

        send_raw = self._storage_class.zeros(n_comm)
        send_buffer = send_raw.array()
        recv_raw = self._storage_class.zeros(n_comm)
        recv_buffer = recv_raw.array()

        # Buffered Allreduce

        # For each buffer, the local indices of relevant data
        local_selected = None

        # For each buffer, the indices of relevant data in the buffer
        buffer_selected = None

        comm_offset = 0
        while comm_offset < n_total:
            if comm_offset + n_comm > n_total:
                n_comm = n_total - comm_offset

            if self._full:
                # Shortcut if we have all global amplitudes locally
                send_buffer[:n_comm] = self.local[comm_offset : comm_offset + n_comm]
                bad = self.local_flags[comm_offset : comm_offset + n_comm] != 0
                send_buffer[:n_comm][bad] = 0
            else:
                # Need to compute our overlap with the global amplitude range.
                send_buffer[:] = 0
                if (
                    (self._global_last >= comm_offset)
                    and self.local is not None
                    and (self._global_first < comm_offset + n_comm)
                ):
                    # We have some overlap
                    if self._local_ranges is not None:
                        sel_start = None
                        n_sel = 0

                        # current local offset of the range
                        range_off = 0

                        # build up the corresponding buffer indices
                        buffer_selected = list()

                        for off, n in self._local_ranges:
                            if off >= comm_offset + n_comm:
                                range_off += n
                                continue
                            if off + n <= comm_offset:
                                range_off += n
                                continue
                            # This range has some overlap...

                            # This is the starting local memory offset of this range:
                            local_off = range_off

                            # Copy offset into the buffer
                            buf_off = 0

                            # The global starting index of the copy
                            start_indx = None

                            if comm_offset > off:
                                local_off += comm_offset - off
                                start_indx = comm_offset
                            else:
                                buf_off = off - comm_offset
                                start_indx = off

                            if sel_start is None:
                                # this is the first range with some overlap
                                sel_start = local_off

                            n_copy = None
                            if comm_offset + n_comm > off + n:
                                n_copy = off + n - start_indx
                            else:
                                n_copy = comm_offset + n_comm - start_indx

                            n_sel += n_copy

                            buffer_selected.append(
                                np.arange(buf_off, buf_off + n_copy, 1, dtype=np.int64)
                            )
                            send_view = send_buffer[buf_off : buf_off + n_copy]
                            send_view[:] = self.local[local_off : local_off + n_copy]
                            send_view[
                                self.local_flags[local_off : local_off + n_copy] != 0
                            ] = 0
                            range_off += n

                        local_selected = slice(sel_start, sel_start + n_sel, 1)
                        buffer_selected = np.concatenate(buffer_selected)

                    elif self._local_indices is not None:
                        local_selected = np.logical_and(
                            np.logical_and(
                                self._local_indices >= comm_offset,
                                self._local_indices < comm_offset + n_comm,
                            ),
                            self.local_flags == 0,
                        )
                        buffer_selected = (
                            self._local_indices[local_selected] - comm_offset
                        )
                        send_buffer[buffer_selected] = self.local[local_selected]
                    else:
                        raise RuntimeError(
                            "should never get here- non-full, disjoint data requires no sync"
                        )

            self._mpicomm.Allreduce(send_buffer, recv_buffer, op=MPI.SUM)

            if self._full:
                # Shortcut if we have all global amplitudes locally
                self.local[comm_offset : comm_offset + n_comm] = recv_buffer[:n_comm]
            else:
                if (
                    (self._global_last >= comm_offset)
                    and self.local is not None
                    and (self._global_first < comm_offset + n_comm)
                ):
                    self.local[local_selected] = recv_buffer[buffer_selected]

            comm_offset += n_comm

        # Cleanup
        del send_buffer
        del recv_buffer
        send_raw.clear()
        recv_raw.clear()
        del send_raw
        del recv_raw

    def dot(self, other, comm_bytes=10000000):
        """Perform a dot product with another Amplitudes object.

        The other instance must have the same data distribution.  The two objects are
        assumed to have already been synchronized, so that any amplitudes that exist
        on multiple processes have the same values.  This further assumes that any
        flagged amplitudes have been set to zero.

        Args:
            other (Amplitudes):  The other instance.
            comm_bytes (int):  The maximum number of bytes to communicate in each
                call to Allreduce.  Only used in the case of explicitly indexed
                amplitudes on each process.

        Result:
            (float):  The dot product.

        """
        if other.n_global != self.n_global:
            raise RuntimeError("Amplitudes must have the same number of values")
        if other.n_local != self.n_local:
            raise RuntimeError("Amplitudes must have the same number of local values")

        if self._mpicomm is None or self._full:
            # Only one process, or every process has the full set of values.
            return np.dot(
                np.where(self.local_flags == 0, self.local, 0),
                np.where(other.local_flags == 0, other.local, 0),
            )

        if (self._local_ranges is None) and (self._local_indices is None):
            # Every process has a unique set of amplitudes.  Reduce the local
            # dot products.
            if self.local is None:
                local_result = 0
            else:
                local_result = np.dot(
                    np.where(self.local_flags == 0, self.local, 0),
                    np.where(other.local_flags == 0, other.local, 0),
                )
            result = self._mpicomm.allreduce(local_result, op=MPI.SUM)
            return result

        # Each amplitude must only contribute once to the dot product.  Every
        # amplitude will be processed by the lowest-rank process which has
        # that amplitude.  We do this in a buffered way so that we don't need
        # to store this amplitude assignment information for the whole data at
        # once.
        n_comm = int(comm_bytes / self._itemsize)
        n_total = self._n_global
        if n_comm > n_total:
            n_comm = n_total

        local_raw = AlignedI32.zeros(n_comm)
        assigned_raw = AlignedI32.zeros(n_comm)
        local = local_raw.array()
        assigned = assigned_raw.array()

        local_result = 0

        # For each buffer, the local indices of relevant data
        local_selected = None

        # For each buffer, the indices of relevant data in the buffer
        buffer_selected = None

        comm_offset = 0
        while comm_offset < n_total:
            if comm_offset + n_comm > n_total:
                n_comm = n_total - comm_offset
            local[:] = self._mpicomm.size

            if (
                (self._global_last >= comm_offset)
                and self.local is not None
                and (self._global_first < comm_offset + n_comm)
            ):
                # We have some overlap
                if self._local_ranges is not None:
                    sel_start = None
                    n_sel = 0

                    # current local offset of the range
                    range_off = 0

                    # build up the corresponding buffer indices
                    buffer_selected = list()

                    for off, n in self._local_ranges:
                        if off >= comm_offset + n_comm:
                            range_off += n
                            continue
                        if off + n <= comm_offset:
                            range_off += n
                            continue
                        # This range has some overlap...

                        # This is the starting local memory offset of this range:
                        local_off = range_off

                        # Copy offset into the buffer
                        buf_off = 0

                        # The global starting index of the copy
                        start_indx = None

                        if comm_offset > off:
                            local_off += comm_offset - off
                            start_indx = comm_offset
                        else:
                            buf_off = off - comm_offset
                            start_indx = off

                        if sel_start is None:
                            # this is the first range with some overlap
                            sel_start = local_off

                        n_set = None
                        if comm_offset + n_comm > off + n:
                            n_set = off + n - start_indx
                        else:
                            n_set = comm_offset + n_comm - start_indx

                        n_sel += n_set

                        buffer_selected.append(
                            np.arange(buf_off, buf_off + n_set, 1, dtype=np.int64)
                        )
                        local_view = local[buf_off : buf_off + n_set]
                        local_view[:] = self._mpicomm.rank
                        local_view[
                            self.local_flags[local_off : local_off + n_set] != 0
                        ] = self._mpicomm.size
                        range_off += n

                    local_selected = slice(sel_start, sel_start + n_sel, 1)
                    buffer_selected = np.concatenate(buffer_selected)

                elif self._local_indices is not None:
                    local_selected = np.logical_and(
                        np.logical_and(
                            self._local_indices >= comm_offset,
                            self._local_indices < comm_offset + n_comm,
                        ),
                        self.local_flags == 0,
                    )
                    buffer_selected = self._local_indices[local_selected] - comm_offset
                    local[buffer_selected] = self._mpicomm.rank
                else:
                    raise RuntimeError(
                        "should never get here- non-full, disjoint data requires no sync"
                    )

            self._mpicomm.Allreduce(local, assigned, op=MPI.MIN)

            if (
                (self._global_last >= comm_offset)
                and self.local is not None
                and (self._global_first < comm_offset + n_comm)
            ):
                # Compute local dot product of just our assigned, unflagged elements
                local_result += np.dot(
                    np.where(
                        np.logical_and(
                            self.local_flags[local_selected] == 0,
                            assigned[buffer_selected] == self._mpicomm.rank,
                        ),
                        self.local[local_selected],
                        0,
                    ),
                    np.where(
                        np.logical_and(
                            other.local_flags[local_selected] == 0,
                            assigned[buffer_selected] == self._mpicomm.rank,
                        ),
                        other.local[local_selected],
                        0,
                    ),
                )

            comm_offset += n_comm

        result = self._mpicomm.allreduce(local_result, op=MPI.SUM)

        del local
        del assigned
        local_raw.clear()
        assigned_raw.clear()
        del local_raw
        del assigned_raw

        return result

    def _accel_exists(self):
        if self.local is None:
            return False
        if use_accel_omp:
            return accel_data_present(
                self._raw, name=self._accel_name
            ) and accel_data_present(self._raw_flags, name=self._accel_name)
        elif use_accel_jax:
            return accel_data_present(self.local) and accel_data_present(
                self.local_flags
            )
        else:
            return False

    def _accel_create(self, zero_out=False):
        if self.local is None:
            return
        if use_accel_omp:
            _ = accel_data_create(self._raw, name=self._accel_name, zero_out=zero_out)
            _ = accel_data_create(
                self._raw_flags, name=self._accel_name, zero_out=zero_out
            )
        elif use_accel_jax:
            self.local = accel_data_create(self.local, zero_out=zero_out)
            self.local_flags = accel_data_create(self.local_flags, zero_out=zero_out)

    def _accel_update_device(self):
        if self.local is None:
            return
        if use_accel_omp:
            _ = accel_data_update_device(self._raw, name=self._accel_name)
            _ = accel_data_update_device(self._raw_flags, name=self._accel_name)
        elif use_accel_jax:
            self.local = accel_data_update_device(self.local)
            self.local_flags = accel_data_update_device(self.local_flags)

    def _accel_update_host(self):
        if self.local is None:
            return
        if use_accel_omp:
            _ = accel_data_update_host(self._raw, name=self._accel_name)
            _ = accel_data_update_host(self._raw_flags, name=self._accel_name)
        elif use_accel_jax:
            self.local = accel_data_update_host(self.local)
            self.local_flags = accel_data_update_host(self.local_flags)

    def _accel_delete(self):
        if self.local is None:
            return
        if use_accel_omp:
            _ = accel_data_delete(self._raw, name=self._accel_name)
            _ = accel_data_delete(self._raw_flags, name=self._accel_name)
        elif use_accel_jax:
            self.local = accel_data_delete(self.local)
            self.local_flags = accel_data_delete(self.local_flags)

    def _accel_reset_local(self):
        if self.local is None:
            return
        # if not self.accel_in_use():
        #     return
        if use_accel_omp:
            accel_data_reset(self._raw, name=self._accel_name)
        elif use_accel_jax:
            accel_data_reset(self.local)

    def _accel_reset_local_flags(self):
        if self.local is None:
            return
        # if not self.accel_in_use():
        #     return
        if use_accel_omp:
            accel_data_reset(self._raw_flags, name=self._accel_name)
        elif use_accel_jax:
            accel_data_reset(self.local_flags)

    def _accel_reset(self):
        self._accel_reset_local()
        self._accel_reset_local_flags()

_comm = comm instance-attribute

_dtype = np.dtype(dtype) instance-attribute

_full = False instance-attribute

_global_first = None instance-attribute

_global_last = None instance-attribute

_local_indices = local_indices instance-attribute

_local_ranges = local_ranges instance-attribute

_mpicomm = self._comm.comm_group instance-attribute

_n_global = n_global instance-attribute

_n_local = n_local instance-attribute

_raw = None instance-attribute

_raw_flags = None instance-attribute

_use_group = use_group instance-attribute

comm property

The toast communicator in use.

local = None instance-attribute

local_flags = None instance-attribute

local_indices property

The global indices of the local amplitudes, or None.

local_ranges property

The global slices covered by local amplitudes, or None.

n_global property

The total number of amplitudes.

n_local property

The number of locally stored amplitudes.

n_local_flagged property

The number of local amplitudes that are flagged.

use_group property

Whether to use the group communicator rather than the global one.

__add__(other)

Source code in toast/templates/amplitudes.py
def __add__(self, other):
    result = self.duplicate()
    result += other
    return result

__del__()

Source code in toast/templates/amplitudes.py
def __del__(self):
    self.clear()

__eq__(value)

Source code in toast/templates/amplitudes.py
def __eq__(self, value):
    if isinstance(value, Amplitudes):
        return self.local == value.local
    else:
        return self.local == value

__iadd__(other)

Source code in toast/templates/amplitudes.py
def __iadd__(self, other):
    if self.local is None:
        return self
    if isinstance(other, Amplitudes):
        if other.local is not None:
            self.local[:] += other.local
    else:
        if other is not None:
            self.local[:] += other
    return self

__imul__(other)

Source code in toast/templates/amplitudes.py
def __imul__(self, other):
    if self.local is None:
        return self
    if isinstance(other, Amplitudes):
        if other.local is not None:
            self.local[:] *= other.local
    else:
        if other is not None:
            self.local[:] *= other
    return self

__init__(comm, n_global, n_local, local_indices=None, local_ranges=None, dtype=np.float64, use_group=False)

Source code in toast/templates/amplitudes.py
def __init__(
    self,
    comm,
    n_global,
    n_local,
    local_indices=None,
    local_ranges=None,
    dtype=np.float64,
    use_group=False,
):
    super().__init__()
    # print(
    #     f"Amplitudes({comm.world_rank}, n_global={n_global}, n_local={n_local}, lc_ind={local_indices}, lc_rng={local_ranges}, dt={dtype}, use_group={use_group}"
    # )
    self._comm = comm
    self._n_global = n_global
    self._n_local = n_local
    self._local_indices = local_indices
    self._local_ranges = local_ranges
    self._use_group = use_group
    if use_group:
        self._mpicomm = self._comm.comm_group
    else:
        self._mpicomm = self._comm.comm_world
    self._dtype = np.dtype(dtype)
    self._storage_class, self._itemsize = dtype_to_aligned(dtype)
    self._full = False
    self._global_first = None
    self._global_last = None
    if self._n_global == self._n_local:
        self._full = True
        self._global_first = 0
        self._global_last = self._n_local - 1
    else:
        if (self._local_indices is None) and (self._local_ranges is None):
            rank = 0
            if self._mpicomm is not None:
                all_n_local = self._mpicomm.gather(self._n_local, root=0)
                rank = self._mpicomm.rank
                if rank == 0:
                    all_n_local = np.array(all_n_local, dtype=np.int64)
                    if np.sum(all_n_local) != self._n_global:
                        msg = "Total amplitudes on all processes does "
                        msg += "not equal n_global"
                        raise RuntimeError(msg)
                all_n_local = self._mpicomm.bcast(all_n_local, root=0)
            else:
                all_n_local = np.array([self._n_local], dtype=np.int64)
            self._global_first = 0
            for i in range(rank):
                self._global_first += all_n_local[i]
            self._global_last = self._global_first + self._n_local - 1
        elif self._local_ranges is not None:
            # local data is specified by ranges
            check = 0
            last = 0
            for off, n in self._local_ranges:
                check += n
                if off < last:
                    msg = "local_ranges must not overlap and must be sorted"
                    raise RuntimeError(msg)
                last = off + n
                if last > self._n_global:
                    msg = "local_ranges extends beyond the number of global amps"
                    raise RuntimeError(msg)
            if check != self._n_local:
                raise RuntimeError("local_ranges must sum to n_local")
            self._global_first = self._local_ranges[0][0]
            self._global_last = (
                self._local_ranges[-1][0] + self._local_ranges[-1][1] - 1
            )
        else:
            # local data has explicit global indices
            if len(self._local_indices) != self._n_local:
                msg = "Length of local_indices must match n_local"
                raise RuntimeError(msg)
            self._global_first = self._local_indices[0]
            self._global_last = self._local_indices[-1]
    if self._n_local == 0:
        self._raw = None
        self.local = None
    else:
        self._raw = self._storage_class.zeros(self._n_local)
        self.local = self._raw.array()

    # Support flagging of template amplitudes.  This can be used to flag some
    # amplitudes if too many timestream samples contributing to the amplitude value
    # are bad.  We will be passing these flags to compiled code, and there
    # is no easy way to do this using numpy bool and C++ bool.  So we waste
    # a bit of memory and use a whole byte per amplitude.
    if self._n_local == 0:
        self._raw_flags = None
        self.local_flags = None
    else:
        self._raw_flags = AlignedU8.zeros(self._n_local)
        self.local_flags = self._raw_flags.array()
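
The constructor above supports three distribution modes: full replication (n_local == n_global), contiguous local_ranges, and explicit local_indices. A minimal sketch of each, not part of the source, assuming a serial toast.Comm() and that Amplitudes is importable from toast.templates:

import numpy as np

from toast import Comm
from toast.templates import Amplitudes

comm = Comm()  # serial communicator in this sketch (no MPI required)

# Full replication: every process stores all global amplitudes.
full = Amplitudes(comm, n_global=10, n_local=10)

# Contiguous ranges: sorted, non-overlapping (offset, n) pairs whose
# lengths must sum to n_local.
ranged = Amplitudes(comm, n_global=10, n_local=4, local_ranges=[(2, 3), (7, 1)])

# Explicit global indices: one sorted index per local amplitude.
indexed = Amplitudes(comm, n_global=10, n_local=3, local_indices=np.array([1, 4, 8]))

full.local[:] = 1.0  # ".local" is a numpy view of the process-local values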

__isub__(other)

Source code in toast/templates/amplitudes.py
def __isub__(self, other):
    if self.local is None:
        return self
    if isinstance(other, Amplitudes):
        if other.local is not None:
            self.local[:] -= other.local
    else:
        if other is not None:
            self.local[:] -= other
    return self

__itruediv__(other)

Source code in toast/templates/amplitudes.py
def __itruediv__(self, other):
    if self.local is None:
        return self
    if isinstance(other, Amplitudes):
        if other.local is not None:
            self.local[:] /= other.local
    else:
        if other is not None:
            self.local[:] /= other
    return self

__mul__(other)

Source code in toast/templates/amplitudes.py
def __mul__(self, other):
    result = self.duplicate()
    result *= other
    return result

__repr__()

Source code in toast/templates/amplitudes.py
def __repr__(self):
    val = "<Amplitudes n_global={} n_local={} comm={}\n  {}\n  {}>".format(
        self.n_global, self.n_local, self.comm, self.local, self.local_flags
    )
    return val

__sub__(other)

Source code in toast/templates/amplitudes.py
def __sub__(self, other):
    result = self.duplicate()
    result -= other
    return result

__truediv__(other)

Source code in toast/templates/amplitudes.py
def __truediv__(self, other):
    result = self.duplicate()
    result /= other
    return result

_accel_create(zero_out=False)

Source code in toast/templates/amplitudes.py
def _accel_create(self, zero_out=False):
    if self.local is None:
        return
    if use_accel_omp:
        _ = accel_data_create(self._raw, name=self._accel_name, zero_out=zero_out)
        _ = accel_data_create(
            self._raw_flags, name=self._accel_name, zero_out=zero_out
        )
    elif use_accel_jax:
        self.local = accel_data_create(self.local, zero_out=zero_out)
        self.local_flags = accel_data_create(self.local_flags, zero_out=zero_out)

_accel_delete()

Source code in toast/templates/amplitudes.py
def _accel_delete(self):
    if self.local is None:
        return
    if use_accel_omp:
        _ = accel_data_delete(self._raw, name=self._accel_name)
        _ = accel_data_delete(self._raw_flags, name=self._accel_name)
    elif use_accel_jax:
        self.local = accel_data_delete(self.local)
        self.local_flags = accel_data_delete(self.local_flags)

_accel_exists()

Source code in toast/templates/amplitudes.py
def _accel_exists(self):
    if self.local is None:
        return False
    if use_accel_omp:
        return accel_data_present(
            self._raw, name=self._accel_name
        ) and accel_data_present(self._raw_flags, name=self._accel_name)
    elif use_accel_jax:
        return accel_data_present(self.local) and accel_data_present(
            self.local_flags
        )
    else:
        return False

_accel_reset()

Source code in toast/templates/amplitudes.py
def _accel_reset(self):
    self._accel_reset_local()
    self._accel_reset_local_flags()

_accel_reset_local()

Source code in toast/templates/amplitudes.py
def _accel_reset_local(self):
    if self.local is None:
        return
    # if not self.accel_in_use():
    #     return
    if use_accel_omp:
        accel_data_reset(self._raw, name=self._accel_name)
    elif use_accel_jax:
        accel_data_reset(self.local)

_accel_reset_local_flags()

Source code in toast/templates/amplitudes.py
def _accel_reset_local_flags(self):
    if self.local is None:
        return
    # if not self.accel_in_use():
    #     return
    if use_accel_omp:
        accel_data_reset(self._raw_flags, name=self._accel_name)
    elif use_accel_jax:
        accel_data_reset(self.local_flags)

_accel_update_device()

Source code in toast/templates/amplitudes.py
def _accel_update_device(self):
    if self.local is None:
        return
    if use_accel_omp:
        _ = accel_data_update_device(self._raw, name=self._accel_name)
        _ = accel_data_update_device(self._raw_flags, name=self._accel_name)
    elif use_accel_jax:
        self.local = accel_data_update_device(self.local)
        self.local_flags = accel_data_update_device(self.local_flags)

_accel_update_host()

Source code in toast/templates/amplitudes.py
def _accel_update_host(self):
    if self.local is None:
        return
    if use_accel_omp:
        _ = accel_data_update_host(self._raw, name=self._accel_name)
        _ = accel_data_update_host(self._raw_flags, name=self._accel_name)
    elif use_accel_jax:
        self.local = accel_data_update_host(self.local)
        self.local_flags = accel_data_update_host(self.local_flags)

clear()

Delete the underlying memory.

This will forcibly delete the C-allocated memory and invalidate all python references to this object. DO NOT CALL THIS unless you are sure all references are no longer being used and you are about to delete the object.

Source code in toast/templates/amplitudes.py
def clear(self):
    """Delete the underlying memory.

    This will forcibly delete the C-allocated memory and invalidate all python
    references to this object.  DO NOT CALL THIS unless you are sure all references
    are no longer being used and you are about to delete the object.

    """
    if self.accel_exists():
        self.accel_delete()
    if hasattr(self, "local"):
        del self.local
        self.local = None
    if hasattr(self, "local_flags"):
        del self.local_flags
        self.local_flags = None
    if hasattr(self, "_raw"):
        if self._raw is not None:
            self._raw.clear()
        del self._raw
        self._raw = None
    if hasattr(self, "_raw_flags"):
        if self._raw_flags is not None:
            self._raw_flags.clear()
        del self._raw_flags
        self._raw_flags = None
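
Because the buffers are C-allocated, clear() is normally reached only indirectly through __del__. A hedged sketch of the hazard the warning above describes (amps and view are hypothetical names):

amps = Amplitudes(comm, n_global=10, n_local=10)
view = amps.local  # numpy view into the C-allocated buffer
del amps           # __del__ calls clear(); "view" now points at freed memory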

dot(other, comm_bytes=10000000)

Perform a dot product with another Amplitudes object.

The other instance must have the same data distribution. The two objects are assumed to have already been synchronized, so that any amplitudes that exist on multiple processes have the same values. This further assumes that any flagged amplitudes have been set to zero.

Parameters:

- other (Amplitudes, required): The other instance.
- comm_bytes (int, default 10000000): The maximum number of bytes to communicate in each call to Allreduce. Only used in the case of explicitly indexed amplitudes on each process.

Returns:

(float): The dot product.

Source code in toast/templates/amplitudes.py
def dot(self, other, comm_bytes=10000000):
    """Perform a dot product with another Amplitudes object.

    The other instance must have the same data distribution.  The two objects are
    assumed to have already been synchronized, so that any amplitudes that exist
    on multiple processes have the same values.  This further assumes that any
    flagged amplitudes have been set to zero.

    Args:
        other (Amplitudes):  The other instance.
        comm_bytes (int):  The maximum number of bytes to communicate in each
            call to Allreduce.  Only used in the case of explicitly indexed
            amplitudes on each process.

    Returns:
        (float):  The dot product.

    """
    if other.n_global != self.n_global:
        raise RuntimeError("Amplitudes must have the same number of values")
    if other.n_local != self.n_local:
        raise RuntimeError("Amplitudes must have the same number of local values")

    if self._mpicomm is None or self._full:
        # Only one process, or every process has the full set of values.
        return np.dot(
            np.where(self.local_flags == 0, self.local, 0),
            np.where(other.local_flags == 0, other.local, 0),
        )

    if (self._local_ranges is None) and (self._local_indices is None):
        # Every process has a unique set of amplitudes.  Reduce the local
        # dot products.
        if self.local is None:
            local_result = 0
        else:
            local_result = np.dot(
                np.where(self.local_flags == 0, self.local, 0),
                np.where(other.local_flags == 0, other.local, 0),
            )
        result = self._mpicomm.allreduce(local_result, op=MPI.SUM)
        return result

    # Each amplitude must only contribute once to the dot product.  Every
    # amplitude will be processed by the lowest-rank process which has
    # that amplitude.  We do this in a buffered way so that we don't need
    # to store this amplitude assignment information for the whole data at
    # once.
    n_comm = int(comm_bytes / self._itemsize)
    n_total = self._n_global
    if n_comm > n_total:
        n_comm = n_total

    local_raw = AlignedI32.zeros(n_comm)
    assigned_raw = AlignedI32.zeros(n_comm)
    local = local_raw.array()
    assigned = assigned_raw.array()

    local_result = 0

    # For each buffer, the local indices of relevant data
    local_selected = None

    # For each buffer, the indices of relevant data in the buffer
    buffer_selected = None

    comm_offset = 0
    while comm_offset < n_total:
        if comm_offset + n_comm > n_total:
            n_comm = n_total - comm_offset
        local[:] = self._mpicomm.size

        if (
            (self._global_last >= comm_offset)
            and self.local is not None
            and (self._global_first < comm_offset + n_comm)
        ):
            # We have some overlap
            if self._local_ranges is not None:
                sel_start = None
                n_sel = 0

                # current local offset of the range
                range_off = 0

                # build up the corresponding buffer indices
                buffer_selected = list()

                for off, n in self._local_ranges:
                    if off >= comm_offset + n_comm:
                        range_off += n
                        continue
                    if off + n <= comm_offset:
                        range_off += n
                        continue
                    # This range has some overlap...

                    # This is the starting local memory offset of this range:
                    local_off = range_off

                    # Copy offset into the buffer
                    buf_off = 0

                    # The global starting index of the copy
                    start_indx = None

                    if comm_offset > off:
                        local_off += comm_offset - off
                        start_indx = comm_offset
                    else:
                        buf_off = off - comm_offset
                        start_indx = off

                    if sel_start is None:
                        # this is the first range with some overlap
                        sel_start = local_off

                    n_set = None
                    if comm_offset + n_comm > off + n:
                        n_set = off + n - start_indx
                    else:
                        n_set = comm_offset + n_comm - start_indx

                    n_sel += n_set

                    buffer_selected.append(
                        np.arange(buf_off, buf_off + n_set, 1, dtype=np.int64)
                    )
                    local_view = local[buf_off : buf_off + n_set]
                    local_view[:] = self._mpicomm.rank
                    local_view[
                        self.local_flags[local_off : local_off + n_set] != 0
                    ] = self._mpicomm.size
                    range_off += n

                local_selected = slice(sel_start, sel_start + n_sel, 1)
                buffer_selected = np.concatenate(buffer_selected)

            elif self._local_indices is not None:
                local_selected = np.logical_and(
                    np.logical_and(
                        self._local_indices >= comm_offset,
                        self._local_indices < comm_offset + n_comm,
                    ),
                    self.local_flags == 0,
                )
                buffer_selected = self._local_indices[local_selected] - comm_offset
                local[buffer_selected] = self._mpicomm.rank
            else:
                raise RuntimeError(
                    "should never get here- non-full, disjoint data requires no sync"
                )

        self._mpicomm.Allreduce(local, assigned, op=MPI.MIN)

        if (
            (self._global_last >= comm_offset)
            and self.local is not None
            and (self._global_first < comm_offset + n_comm)
        ):
            # Compute local dot product of just our assigned, unflagged elements
            local_result += np.dot(
                np.where(
                    np.logical_and(
                        self.local_flags[local_selected] == 0,
                        assigned[buffer_selected] == self._mpicomm.rank,
                    ),
                    self.local[local_selected],
                    0,
                ),
                np.where(
                    np.logical_and(
                        other.local_flags[local_selected] == 0,
                        assigned[buffer_selected] == self._mpicomm.rank,
                    ),
                    other.local[local_selected],
                    0,
                ),
            )

        comm_offset += n_comm

    result = self._mpicomm.allreduce(local_result, op=MPI.SUM)

    del local
    del assigned
    local_raw.clear()
    assigned_raw.clear()
    del local_raw
    del assigned_raw

    return result
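
A hedged usage sketch, reusing the serial communicator from the constructor example above: synchronize both operands first (a no-op with a single process), then take the dot product.

a = Amplitudes(comm, n_global=10, n_local=10)
b = a.duplicate()
a.local[:] = 2.0
b.local[:] = 3.0
a.sync()
b.sync()
print(a.dot(b))  # 10 amplitudes x (2.0 * 3.0) = 60.0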

duplicate()

Return a copy of the data.

Source code in toast/templates/amplitudes.py
def duplicate(self):
    """Return a copy of the data."""
    ret = Amplitudes(
        self._comm,
        self._n_global,
        self._n_local,
        local_indices=self._local_indices,
        local_ranges=self._local_ranges,
        dtype=self._dtype,
        use_group=self._use_group,
    )
    if self.accel_exists():
        ret.accel_create(self._accel_name)
    restore = False
    if self.accel_in_use():
        # We have no good way to copy between device buffers,
        # so do this on the host.  The duplicate() method is
        # not used inside the solver loop.
        self.accel_update_host()
        restore = True
    if self.local is not None:
        ret.local[:] = self.local
    if self.local_flags is not None:
        ret.local_flags[:] = self.local_flags
    if restore:
        self.accel_update_device()
        ret.accel_update_device()
    return ret

reset()

Set all amplitude values to zero.

Source code in toast/templates/amplitudes.py
def reset(self):
    """Set all amplitude values to zero."""
    if self.local is None:
        return
    self.local[:] = 0
    if self.accel_exists():
        self._accel_reset_local()

reset_flags()

Set all flag values to zero.

Source code in toast/templates/amplitudes.py
def reset_flags(self):
    """Set all flag values to zero."""
    if self.local_flags is None:
        return
    self.local_flags[:] = 0
    if self.accel_exists():
        self._accel_reset_local_flags()

sync(comm_bytes=10000000)

Perform an Allreduce across all processes.

Parameters:

- comm_bytes (int, default 10000000): The maximum number of bytes to communicate in each call to Allreduce.

Returns:

None

Source code in toast/templates/amplitudes.py
def sync(self, comm_bytes=10000000):
    """Perform an Allreduce across all processes.

    Args:
        comm_bytes (int):  The maximum number of bytes to communicate in each
            call to Allreduce.

    Returns:
        None

    """
    if self._mpicomm is None:
        # Nothing to do
        return

    if not self._full and (
        self._local_indices is None and self._local_ranges is None
    ):
        # Disjoint set of amplitudes, no communication needed.
        return

    log = Logger.get()

    n_comm = int(comm_bytes / self._itemsize)
    n_total = self._n_global
    if n_comm > n_total:
        n_comm = n_total

    # Create persistent buffers for the reduction

    send_raw = self._storage_class.zeros(n_comm)
    send_buffer = send_raw.array()
    recv_raw = self._storage_class.zeros(n_comm)
    recv_buffer = recv_raw.array()

    # Buffered Allreduce

    # For each buffer, the local indices of relevant data
    local_selected = None

    # For each buffer, the indices of relevant data in the buffer
    buffer_selected = None

    comm_offset = 0
    while comm_offset < n_total:
        if comm_offset + n_comm > n_total:
            n_comm = n_total - comm_offset

        if self._full:
            # Shortcut if we have all global amplitudes locally
            send_buffer[:n_comm] = self.local[comm_offset : comm_offset + n_comm]
            bad = self.local_flags[comm_offset : comm_offset + n_comm] != 0
            send_buffer[:n_comm][bad] = 0
        else:
            # Need to compute our overlap with the global amplitude range.
            send_buffer[:] = 0
            if (
                (self._global_last >= comm_offset)
                and self.local is not None
                and (self._global_first < comm_offset + n_comm)
            ):
                # We have some overlap
                if self._local_ranges is not None:
                    sel_start = None
                    n_sel = 0

                    # current local offset of the range
                    range_off = 0

                    # build up the corresponding buffer indices
                    buffer_selected = list()

                    for off, n in self._local_ranges:
                        if off >= comm_offset + n_comm:
                            range_off += n
                            continue
                        if off + n <= comm_offset:
                            range_off += n
                            continue
                        # This range has some overlap...

                        # This is the starting local memory offset of this range:
                        local_off = range_off

                        # Copy offset into the buffer
                        buf_off = 0

                        # The global starting index of the copy
                        start_indx = None

                        if comm_offset > off:
                            local_off += comm_offset - off
                            start_indx = comm_offset
                        else:
                            buf_off = off - comm_offset
                            start_indx = off

                        if sel_start is None:
                            # this is the first range with some overlap
                            sel_start = local_off

                        n_copy = None
                        if comm_offset + n_comm > off + n:
                            n_copy = off + n - start_indx
                        else:
                            n_copy = comm_offset + n_comm - start_indx

                        n_sel += n_copy

                        buffer_selected.append(
                            np.arange(buf_off, buf_off + n_copy, 1, dtype=np.int64)
                        )
                        send_view = send_buffer[buf_off : buf_off + n_copy]
                        send_view[:] = self.local[local_off : local_off + n_copy]
                        send_view[
                            self.local_flags[local_off : local_off + n_copy] != 0
                        ] = 0
                        range_off += n

                    local_selected = slice(sel_start, sel_start + n_sel, 1)
                    buffer_selected = np.concatenate(buffer_selected)

                elif self._local_indices is not None:
                    local_selected = np.logical_and(
                        np.logical_and(
                            self._local_indices >= comm_offset,
                            self._local_indices < comm_offset + n_comm,
                        ),
                        self.local_flags == 0,
                    )
                    buffer_selected = (
                        self._local_indices[local_selected] - comm_offset
                    )
                    send_buffer[buffer_selected] = self.local[local_selected]
                else:
                    raise RuntimeError(
                        "should never get here- non-full, disjoint data requires no sync"
                    )

        self._mpicomm.Allreduce(send_buffer, recv_buffer, op=MPI.SUM)

        if self._full:
            # Shortcut if we have all global amplitudes locally
            self.local[comm_offset : comm_offset + n_comm] = recv_buffer[:n_comm]
        else:
            if (
                (self._global_last >= comm_offset)
                and self.local is not None
                and (self._global_first < comm_offset + n_comm)
            ):
                self.local[local_selected] = recv_buffer[buffer_selected]

        comm_offset += n_comm

    # Cleanup
    del send_buffer
    del recv_buffer
    send_raw.clear()
    recv_raw.clear()
    del send_raw
    del recv_raw
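
The buffering means sync() never holds more than roughly comm_bytes of send and receive space at once, regardless of n_global. A hedged sketch of tuning the buffer size; the result is identical, only the number of Allreduce rounds changes:

amps = Amplitudes(comm, n_global=1000000, n_local=1000000)
amps.local[:] = 1.0
# With float64 amplitudes, comm_bytes=1000000 gives buffers of
# 125000 amplitudes, so the reduction runs in 8 passes under MPI.
amps.sync(comm_bytes=1000000)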

toast.templates.AmplitudesMap

Bases: MutableMapping, AcceleratorObject

Helper class to provide arithmetic operations on a collection of Amplitudes.

This simply provides syntactic sugar to reduce duplicated code when working with a collection of Amplitudes in the map making.

Source code in toast/templates/amplitudes.py
class AmplitudesMap(MutableMapping, AcceleratorObject):
    """Helper class to provide arithmetic operations on a collection of Amplitudes.

    This simply provides syntactic sugar to reduce duplicated code when working with
    a collection of Amplitudes in the map making.

    """

    def __init__(self):
        super().__init__()
        self._internal = dict()

    # Mapping methods

    def __getitem__(self, key):
        return self._internal[key]

    def __delitem__(self, key):
        del self._internal[key]

    def __setitem__(self, key, value):
        if not isinstance(value, Amplitudes):
            raise RuntimeError(
                "Only Amplitudes objects may be assigned to an AmplitudesMap"
            )
        self._internal[key] = value

    def __iter__(self):
        return iter(self._internal)

    def __len__(self):
        return len(self._internal)

    def __repr__(self):
        val = "<AmplitudesMap"
        for k, v in self._internal.items():
            val += "\n  {} = {}".format(k, v)
        val += "\n>"
        return val

    # Arithmetic.  These operations are done between corresponding Amplitude keys.

    def _check_other(self, other):
        log = Logger.get()
        if sorted(self._internal.keys()) != sorted(other._internal.keys()):
            msg = "Arithmetic between AmplitudesMap objects requires identical keys"
            log.error(msg)
            raise RuntimeError(msg)
        for k, v in self._internal.items():
            if v.n_global != other[k].n_global:
                msg = "Number of global amplitudes not equal for key '{}'".format(k)
                log.error(msg)
                raise RuntimeError(msg)
            if v.n_local != other[k].n_local:
                msg = "Number of local amplitudes not equal for key '{}'".format(k)
                log.error(msg)
                raise RuntimeError(msg)

    def __eq__(self, value):
        if isinstance(value, AmplitudesMap):
            self._check_other(value)
            for k, v in self._internal.items():
                if v != value[k]:
                    return False
            return True
        else:
            for k, v in self._internal.items():
                if v != value:
                    return False
            return True

    def __iadd__(self, other):
        if isinstance(other, AmplitudesMap):
            self._check_other(other)
            for k, v in self._internal.items():
                v += other[k]
        else:
            for k, v in self._internal.items():
                v += other
        return self

    def __isub__(self, other):
        if isinstance(other, AmplitudesMap):
            self._check_other(other)
            for k, v in self._internal.items():
                v -= other[k]
        else:
            for k, v in self._internal.items():
                v -= other
        return self

    def __imul__(self, other):
        if isinstance(other, AmplitudesMap):
            self._check_other(other)
            for k, v in self._internal.items():
                v *= other[k]
        else:
            for k, v in self._internal.items():
                v *= other
        return self

    def __itruediv__(self, other):
        if isinstance(other, AmplitudesMap):
            self._check_other(other)
            for k, v in self._internal.items():
                v /= other[k]
        else:
            for k, v in self._internal.items():
                v /= other
        return self

    def __add__(self, other):
        result = self.duplicate()
        result += other
        return result

    def __sub__(self, other):
        result = self.duplicate()
        result -= other
        return result

    def __mul__(self, other):
        result = self.duplicate()
        result *= other
        return result

    def __truediv__(self, other):
        result = self.duplicate()
        result /= other
        return result

    def reset(self):
        """Set all amplitude values to zero."""
        for k, v in self._internal.items():
            v.reset()

    def reset_flags(self):
        """Set all flag values to zero."""
        for k, v in self._internal.items():
            v.reset_flags()

    def duplicate(self):
        """Return a copy of the data."""
        ret = AmplitudesMap()
        for k, v in self._internal.items():
            ret[k] = v.duplicate()
        return ret

    def dot(self, other):
        """Dot product of all corresponding Amplitudes.

        Args:
            other (AmplitudesMap):  The other instance.

        Returns:
            (float):  The dot product.

        """
        log = Logger.get()
        if not isinstance(other, AmplitudesMap):
            msg = "dot product must be with another AmplitudesMap object"
            log.error(msg)
            raise RuntimeError(msg)
        self._check_other(other)
        result = 0.0
        for k, v in self._internal.items():
            result += v.dot(other[k])
        return result

    def accel_used(self, state):
        super().accel_used(state)
        for k, v in self._internal.items():
            v.accel_used(state)

    def _accel_exists(self):
        if not accel_enabled():
            return False
        result = 0
        for k, v in self._internal.items():
            if v.accel_exists():
                result += 1
        if result == 0:
            return False
        elif result != len(self._internal):
            log = Logger.get()
            msg = f"Only some of the Amplitudes exist on device"
            log.error(msg)
            raise RuntimeError(msg)
        return True

    def _accel_create(self, zero_out=False):
        if not accel_enabled():
            return
        for k, v in self._internal.items():
            v.accel_create(f"{self._accel_name}_{k}", zero_out=zero_out)

    def _accel_update_device(self):
        if not accel_enabled():
            return
        for k, v in self._internal.items():
            v.accel_update_device()

    def _accel_update_host(self):
        if not accel_enabled():
            return
        for k, v in self._internal.items():
            v.accel_update_host()

    def _accel_delete(self):
        if not accel_enabled():
            return
        for k, v in self._internal.items():
            v.accel_delete()

    def _accel_reset(self):
        if not accel_enabled():
            return
        for k, v in self._internal.items():
            v.accel_reset()
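
A hedged sketch of this syntactic sugar in practice, again assuming a serial communicator named comm:

from toast.templates import Amplitudes, AmplitudesMap

amap = AmplitudesMap()
amap["baselines"] = Amplitudes(comm, n_global=10, n_local=10)
amap["ground"] = Amplitudes(comm, n_global=4, n_local=4)
amap["baselines"].local[:] = 1.0
amap["ground"].local[:] = 2.0

other = amap.duplicate()
other *= 3.0  # broadcast to every Amplitudes entry

# Sum of per-key dot products: 10*(1.0*3.0) + 4*(2.0*6.0) = 78.0
print(amap.dot(other))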

_internal = dict() instance-attribute

__add__(other)

Source code in toast/templates/amplitudes.py
def __add__(self, other):
    result = self.duplicate()
    result += other
    return result

__delitem__(key)

Source code in toast/templates/amplitudes.py
def __delitem__(self, key):
    del self._internal[key]

__eq__(value)

Source code in toast/templates/amplitudes.py
def __eq__(self, value):
    if isinstance(value, AmplitudesMap):
        self._check_other(value)
        for k, v in self._internal.items():
            if v != value[k]:
                return False
        return True
    else:
        for k, v in self._internal.items():
            if v != value:
                return False
        return True

__getitem__(key)

Source code in toast/templates/amplitudes.py
def __getitem__(self, key):
    return self._internal[key]

__iadd__(other)

Source code in toast/templates/amplitudes.py
def __iadd__(self, other):
    if isinstance(other, AmplitudesMap):
        self._check_other(other)
        for k, v in self._internal.items():
            v += other[k]
    else:
        for k, v in self._internal.items():
            v += other
    return self

__imul__(other)

Source code in toast/templates/amplitudes.py
def __imul__(self, other):
    if isinstance(other, AmplitudesMap):
        self._check_other(other)
        for k, v in self._internal.items():
            v *= other[k]
    else:
        for k, v in self._internal.items():
            v *= other
    return self

__init__()

Source code in toast/templates/amplitudes.py
def __init__(self):
    super().__init__()
    self._internal = dict()

__isub__(other)

Source code in toast/templates/amplitudes.py
def __isub__(self, other):
    if isinstance(other, AmplitudesMap):
        self._check_other(other)
        for k, v in self._internal.items():
            v -= other[k]
    else:
        for k, v in self._internal.items():
            v -= other
    return self

__iter__()

Source code in toast/templates/amplitudes.py
def __iter__(self):
    return iter(self._internal)

__itruediv__(other)

Source code in toast/templates/amplitudes.py
def __itruediv__(self, other):
    if isinstance(other, AmplitudesMap):
        self._check_other(other)
        for k, v in self._internal.items():
            v /= other[k]
    else:
        for k, v in self._internal.items():
            v /= other
    return self

__len__()

Source code in toast/templates/amplitudes.py
def __len__(self):
    return len(self._internal)

__mul__(other)

Source code in toast/templates/amplitudes.py
def __mul__(self, other):
    result = self.duplicate()
    result *= other
    return result

__repr__()

Source code in toast/templates/amplitudes.py
def __repr__(self):
    val = "<AmplitudesMap"
    for k, v in self._internal.items():
        val += "\n  {} = {}".format(k, v)
    val += "\n>"
    return val

__setitem__(key, value)

Source code in toast/templates/amplitudes.py
def __setitem__(self, key, value):
    if not isinstance(value, Amplitudes):
        raise RuntimeError(
            "Only Amplitudes objects may be assigned to an AmplitudesMap"
        )
    self._internal[key] = value

__sub__(other)

Source code in toast/templates/amplitudes.py
def __sub__(self, other):
    result = self.duplicate()
    result -= other
    return result

__truediv__(other)

Source code in toast/templates/amplitudes.py
def __truediv__(self, other):
    result = self.duplicate()
    result /= other
    return result

_accel_create(zero_out=False)

Source code in toast/templates/amplitudes.py
def _accel_create(self, zero_out=False):
    if not accel_enabled():
        return
    for k, v in self._internal.items():
        v.accel_create(f"{self._accel_name}_{k}", zero_out=zero_out)

_accel_delete()

Source code in toast/templates/amplitudes.py
def _accel_delete(self):
    if not accel_enabled():
        return
    for k, v in self._internal.items():
        v.accel_delete()

_accel_exists()

Source code in toast/templates/amplitudes.py
def _accel_exists(self):
    if not accel_enabled():
        return False
    result = 0
    for k, v in self._internal.items():
        if v.accel_exists():
            result += 1
    if result == 0:
        return False
    elif result != len(self._internal):
        log = Logger.get()
        msg = f"Only some of the Amplitudes exist on device"
        log.error(msg)
        raise RuntimeError(msg)
    return True

_accel_reset()

Source code in toast/templates/amplitudes.py
def _accel_reset(self):
    if not accel_enabled():
        return
    for k, v in self._internal.items():
        v.accel_reset()

_accel_update_device()

Source code in toast/templates/amplitudes.py
def _accel_update_device(self):
    if not accel_enabled():
        return
    for k, v in self._internal.items():
        v.accel_update_device()

_accel_update_host()

Source code in toast/templates/amplitudes.py
def _accel_update_host(self):
    if not accel_enabled():
        return
    for k, v in self._internal.items():
        v.accel_update_host()

_check_other(other)

Source code in toast/templates/amplitudes.py
def _check_other(self, other):
    log = Logger.get()
    if sorted(self._internal.keys()) != sorted(other._internal.keys()):
        msg = "Arithmetic between AmplitudesMap objects requires identical keys"
        log.error(msg)
        raise RuntimeError(msg)
    for k, v in self._internal.items():
        if v.n_global != other[k].n_global:
            msg = "Number of global amplitudes not equal for key '{}'".format(k)
            log.error(msg)
            raise RuntimeError(msg)
        if v.n_local != other[k].n_local:
            msg = "Number of local amplitudes not equal for key '{}'".format(k)
            log.error(msg)
            raise RuntimeError(msg)

accel_used(state)

Source code in toast/templates/amplitudes.py
def accel_used(self, state):
    super().accel_used(state)
    for k, v in self._internal.items():
        v.accel_used(state)

dot(other)

Dot product of all corresponding Amplitudes.

Parameters:

- other (AmplitudesMap, required): The other instance.

Returns:

(float): The dot product.

Source code in toast/templates/amplitudes.py
def dot(self, other):
    """Dot product of all corresponding Amplitudes.

    Args:
        other (AmplitudesMap):  The other instance.

    Returns:
        (float):  The dot product.

    """
    log = Logger.get()
    if not isinstance(other, AmplitudesMap):
        msg = "dot product must be with another AmplitudesMap object"
        log.error(msg)
        raise RuntimeError(msg)
    self._check_other(other)
    result = 0.0
    for k, v in self._internal.items():
        result += v.dot(other[k])
    return result

duplicate()

Return a copy of the data.

Source code in toast/templates/amplitudes.py
def duplicate(self):
    """Return a copy of the data."""
    ret = AmplitudesMap()
    for k, v in self._internal.items():
        ret[k] = v.duplicate()
    return ret

reset()

Set all amplitude values to zero.

Source code in toast/templates/amplitudes.py
def reset(self):
    """Set all amplitude values to zero."""
    for k, v in self._internal.items():
        v.reset()

reset_flags()

Set all flag values to zero.

Source code in toast/templates/amplitudes.py
def reset_flags(self):
    """Set all flag values to zero."""
    for k, v in self._internal.items():
        v.reset_flags()

toast.templates.Template

Bases: TraitConfig

Base class for timestream templates.

A template defines a mapping to / from timestream values to a set of template amplitudes. These amplitudes are usually quantities being solved as part of the map-making. Examples of templates might be destriping baseline offsets, azimuthally binned ground pickup, etc.

The template amplitude data may be distributed in a variety of ways. For some types of templates, every process may have their own unique set of amplitudes based on the data that they have locally. In other cases, every process may have a full local copy of all template amplitudes. There might also be cases where each process has a non-unique subset of amplitude values (similar to the way that pixel domain quantities are distributed).
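
As a hedged illustration of the setup hook shown in the source below: a derived template recomputes its amplitude count whenever the data trait is assigned. Only _initialize appears in this excerpt; the _zeros method and _n_amps attribute are assumptions of this sketch, and a real template must implement the full interface defined in template.py.

from toast.templates import Amplitudes, Template


class ConstantOffset(Template):
    """Hypothetical template: one offset amplitude per observation."""

    def _initialize(self, new_data):
        # Invoked automatically when the "data" trait is assigned.
        self._n_amps = len(new_data.obs)

    def _zeros(self):
        # Sketch only: return an amplitude container sized in
        # _initialize() above.
        return Amplitudes(self.data.comm, self._n_amps, self._n_amps)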

Source code in toast/templates/template.py
class Template(TraitConfig):
    """Base class for timestream templates.

    A template defines a mapping between timestream values and a set of template
    amplitudes.  These amplitudes are usually quantities being solved for as part of the
    map-making.  Examples of templates might be destriping baseline offsets,
    azimuthally binned ground pickup, etc.

    The template amplitude data may be distributed in a variety of ways.  For some
    types of templates, every process may have its own unique set of amplitudes based
    on the data it has locally.  In other cases, every process may have a full
    local copy of all template amplitudes.  There might also be cases where each
    process has a non-unique subset of amplitude values (similar to the way that
    pixel domain quantities are distributed).

    """

    # Note:  The TraitConfig base class defines a "name" attribute.

    data = Instance(
        klass=Data,
        allow_none=True,
        help="This must be an instance of a Data class (or None)",
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_data = Unicode(
        defaults.det_data,
        allow_none=True,
        help="Observation detdata key for the timestream data",
    )

    det_data_units = Unit(
        defaults.det_data_units, help="Desired units of detector data"
    )

    det_mask = Int(
        defaults.det_mask_invalid,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for solver flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience, help="Bit mask value for solver flags"
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("data")
    def _check_data(self, proposal):
        dat = proposal["value"]
        if dat is not None:
            if not isinstance(dat, Data):
                raise traitlets.TraitError("data should be a Data instance")
        return dat

    @traitlets.observe("data")
    def initialize(self, change):
        newdata = change["new"]
        if newdata is not None:
            self._initialize(newdata)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _initialize(self, new_data):
        # Derived classes should implement this method to do any set up (like
        # computing the number of amplitudes) whenever the data changes.
        raise NotImplementedError("Derived class must implement _initialize()")

    def _check_enabled(self, use_accel=None):
        if self.data is None:
            raise RuntimeError(
                "You must set the data trait before calling template methods"
            )
        if self.enabled:
            if use_accel and not self.supports_accel():
                msg = f"Template {self.name} does not support accelerator, "
                msg += "cannot specify use_accel=True"
                raise RuntimeError(msg)
            return True
        else:
            log = Logger.get()
            if self.data.comm.world_rank == 0:
                msg = f"Template {self.name} is disabled, skipping calls to all methods"
                log.debug(msg)
            return False

    def _detectors(self):
        # Derived classes should return the list of detectors they support.
        raise NotImplementedError("Derived class must implement _detectors()")

    def detectors(self):
        """Return a list of detectors supported by the template.

        This list will change whenever the `data` trait is set, which initializes
        the template.

        Returns:
            (list):  The detectors with local amplitudes across all observations.

        """
        if self._check_enabled():
            return self._detectors()

    def _zeros(self):
        raise NotImplementedError("Derived class must implement _zeros()")

    def zeros(self):
        """Return an Amplitudes object filled with zeros.

        This returns an Amplitudes instance with appropriate dimensions for this
        template.  This will raise an exception if called before the `data` trait
        is set.

        Returns:
            (Amplitudes):  Zero amplitudes.

        """
        if self._check_enabled():
            return self._zeros()

    def _add_to_signal(self, detector, amplitudes, **kwargs):
        raise NotImplementedError("Derived class must implement _add_to_signal()")

    @function_timer_stackskip
    def add_to_signal(self, detector, amplitudes, use_accel=None, **kwargs):
        """Accumulate the projected amplitudes to a timestream.

        This performs the operation:

        .. math::
            s += F \\cdot a

        Where `s` is the det_data signal, `F` is the template and `a` is the amplitudes.

        Args:
            detector (str):  The detector name.
            amplitudes (Amplitudes):  The Amplitude values for this template.

        Returns:
            None

        """
        if self._check_enabled(use_accel=use_accel):
            self._add_to_signal(detector, amplitudes, use_accel=use_accel, **kwargs)

    def _project_signal(self, detector, amplitudes, **kwargs):
        raise NotImplementedError("Derived class must implement _project_signal()")

    @function_timer_stackskip
    def project_signal(self, detector, amplitudes, use_accel=None, **kwargs):
        """Project a timestream into template amplitudes.

        This performs:

        .. math::
            a += F^T \\cdot s

        Where `s` is the det_data signal, `F` is the template and `a` is the amplitudes.

        Args:
            detector (str):  The detector name.
            amplitudes (Amplitudes):  The Amplitude values for this template.

        Returns:
            None

        """
        if self._check_enabled(use_accel=use_accel):
            self._project_signal(detector, amplitudes, use_accel=use_accel, **kwargs)

    def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
        # Not all Templates implement the prior
        return

    @function_timer_stackskip
    def add_prior(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
        """Apply the inverse amplitude covariance as a prior.

        This performs:

        .. math::
            a' += {C_a}^{-1} \\cdot a

        Args:
            amplitudes_in (Amplitudes):  The input Amplitude values for this template.
            amplitudes_out (Amplitudes):  The output Amplitude values for this template.

        Returns:
            None

        """
        if self._check_enabled(use_accel=use_accel):
            self._add_prior(
                amplitudes_in, amplitudes_out, use_accel=use_accel, **kwargs
            )

    def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
        raise NotImplementedError("Derived class must implement _apply_precond()")

    @function_timer_stackskip
    def apply_precond(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
        """Apply the template preconditioner.

        Formally, the preconditioner "M" is an approximation to the "design matrix"
        (the "A" matrix in "Ax = b").  This function applies the inverse preconditioner
        to the template amplitudes:

        .. math::
            a' += M^{-1} \\cdot a

        Args:
            amplitudes_in (Amplitudes):  The input Amplitude values for this template.
            amplitudes_out (Amplitudes):  The output Amplitude values for this template.

        Returns:
            None

        """
        if self._check_enabled(use_accel=use_accel):
            self._apply_precond(
                amplitudes_in, amplitudes_out, use_accel=use_accel, **kwargs
            )

    @classmethod
    def get_class_config_path(cls):
        return "/templates/{}".format(cls.__qualname__)

    def get_config_path(self):
        if self.name is None:
            return None
        return "/templates/{}".format(self.name)

    @classmethod
    def get_class_config(cls, input=None):
        """Return a dictionary of the default traits of an Template class.

        This returns a new or appended dictionary.  The class instance properties are
        contained in a dictionary found in result["templates"][cls.name].

        If the specified named location in the input config already exists then an
        exception is raised.

        Args:
            input (dict):  The optional input dictionary to update.

        Returns:
            (dict):  The created or updated dictionary.

        """
        return super().get_class_config(section="templates", input=input)

    def get_config(self, input=None):
        """Return a dictionary of the current traits of a Template *instance*.

        This returns a new or appended dictionary.  The template instance properties are
        contained in a dictionary found in result["templates"][self.name].

        If the specified named location in the input config already exists then an
        exception is raised.

        Args:
            input (dict):  The optional input dictionary to update.

        Returns:
            (dict):  The created or updated dictionary.

        """
        return super().get_config(section="templates", input=input)

    @classmethod
    def translate(cls, props):
        """Given a config dictionary, modify it to match the current API."""
        # For templates, the derived classes should implement this method as needed
        # and then call super().translate(props) to trigger this method.  Here we strip
        # the 'API' key from the config.
        props = super().translate(props)
        if "API" in props:
            del props["API"]
        return props

data = Instance(klass=Data, allow_none=True, help='This must be an instance of a Data class (or None)') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, allow_none=True, help='Observation detdata key for the timestream data') class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Desired units of detector data') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for solver flags') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for solver flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_invalid, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/templates/template.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_add_prior(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/template.py
def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
    # Not all Templates implement the prior
    return

_add_to_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/template.py
def _add_to_signal(self, detector, amplitudes, **kwargs):
    raise NotImplementedError("Derived class must implement _add_to_signal()")

_apply_precond(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/template.py
def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
    raise NotImplementedError("Derived class must implement _apply_precond()")

_check_data(proposal)

Source code in toast/templates/template.py
@traitlets.validate("data")
def _check_data(self, proposal):
    dat = proposal["value"]
    if dat is not None:
        if not isinstance(dat, Data):
            raise traitlets.TraitError("data should be a Data instance")
    return dat

_check_det_mask(proposal)

Source code in toast/templates/template.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_enabled(use_accel=None)

Source code in toast/templates/template.py
def _check_enabled(self, use_accel=None):
    if self.data is None:
        raise RuntimeError(
            "You must set the data trait before calling template methods"
        )
    if self.enabled:
        if use_accel and not self.supports_accel():
            msg = f"Template {self.name} does not support accelerator, "
            msg += "cannot specify use_accel=True"
            raise RuntimeError(msg)
        return True
    else:
        log = Logger.get()
        if self.data.comm.world_rank == 0:
            msg = f"Template {self.name} is disabled, skipping calls to all methods"
            log.debug(msg)
        return False

_check_flag_mask(proposal)

Source code in toast/templates/template.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_detectors()

Source code in toast/templates/template.py
def _detectors(self):
    # Derived classes should return the list of detectors they support.
    raise NotImplementedError("Derived class must implement _detectors()")

_initialize(new_data)

Source code in toast/templates/template.py
def _initialize(self, new_data):
    # Derived classes should implement this method to do any set up (like
    # computing the number of amplitudes) whenever the data changes.
    raise NotImplementedError("Derived class must implement _initialize()")

_project_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/template.py
def _project_signal(self, detector, amplitudes, **kwargs):
    raise NotImplementedError("Derived class must implement _project_signal()")

_zeros()

Source code in toast/templates/template.py
def _zeros(self):
    raise NotImplementedError("Derived class must implement _zeros()")

add_prior(amplitudes_in, amplitudes_out, use_accel=None, **kwargs)

Apply the inverse amplitude covariance as a prior.

This performs:

    a' += C_a^{-1} · a

Parameters:

    amplitudes_in (Amplitudes):  The input Amplitude values for this template.  (required)
    amplitudes_out (Amplitudes):  The output Amplitude values for this template.  (required)

Returns:

    None

Source code in toast/templates/template.py
@function_timer_stackskip
def add_prior(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
    """Apply the inverse amplitude covariance as a prior.

    This performs:

    .. math::
        a' += {C_a}^{-1} \\cdot a

    Args:
        amplitudes_in (Amplitudes):  The input Amplitude values for this template.
        amplitudes_out (Amplitudes):  The output Amplitude values for this template.

    Returns:
        None

    """
    if self._check_enabled(use_accel=use_accel):
        self._add_prior(
            amplitudes_in, amplitudes_out, use_accel=use_accel, **kwargs
        )
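
As a concrete (and much simplified) illustration: if the amplitude covariance C_a were diagonal, the prior would reduce to an elementwise division. A minimal numpy sketch with hypothetical values (the Offset template documented below instead applies a real-space convolution filter built from the noise PSD):

import numpy as np

c_a = np.array([0.5, 0.5, 0.25])   # hypothetical diagonal amplitude covariance
a_in = np.array([1.0, -1.0, 2.0])
a_out = np.zeros_like(a_in)
a_out += a_in / c_a                # a' += C_a^{-1} · a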

add_to_signal(detector, amplitudes, use_accel=None, **kwargs)

Accumulate the projected amplitudes to a timestream.

This performs the operation:

    s += F · a

Where s is the det_data signal, F is the template and a is the amplitudes.

Parameters:

    detector (str):  The detector name.  (required)
    amplitudes (Amplitudes):  The Amplitude values for this template.  (required)

Returns:

    None

Source code in toast/templates/template.py
@function_timer_stackskip
def add_to_signal(self, detector, amplitudes, use_accel=None, **kwargs):
    """Accumulate the projected amplitudes to a timestream.

    This performs the operation:

    .. math::
        s += F \\cdot a

    Where `s` is the det_data signal, `F` is the template and `a` is the amplitudes.

    Args:
        detector (str):  The detector name.
        amplitudes (Amplitudes):  The Amplitude values for this template.

    Returns:
        None

    """
    if self._check_enabled(use_accel=use_accel):
        self._add_to_signal(detector, amplitudes, use_accel=use_accel, **kwargs)
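
For a concrete picture of this operation, consider a step-function template such as the Offset template documented below, where each amplitude covers a contiguous span of samples. A plain-numpy sketch with hypothetical sizes:

import numpy as np

step_length = 100
signal = np.zeros(350)                  # s:  one detector's timestream
amps = np.array([1.0, -2.0, 0.5, 3.0])  # a:  one amplitude per step
for i, a in enumerate(amps):
    # Each amplitude adds a constant to its span; the last span may be short.
    signal[i * step_length : (i + 1) * step_length] += a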

apply_precond(amplitudes_in, amplitudes_out, use_accel=None, **kwargs)

Apply the template preconditioner.

Formally, the preconditioner "M" is an approximation to the "design matrix" (the "A" matrix in "Ax = b"). This function applies the inverse preconditioner to the template amplitudes:

    a' += M^{-1} · a

Parameters:

    amplitudes_in (Amplitudes):  The input Amplitude values for this template.  (required)
    amplitudes_out (Amplitudes):  The output Amplitude values for this template.  (required)

Returns:

    None

Source code in toast/templates/template.py
@function_timer_stackskip
def apply_precond(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
    """Apply the template preconditioner.

    Formally, the preconditioner "M" is an approximation to the "design matrix"
    (the "A" matrix in "Ax = b").  This function applies the inverse preconditioner
    to the template amplitudes:

    .. math::
        a' += M^{-1} \\cdot a

    Args:
        amplitudes_in (Amplitudes):  The input Amplitude values for this template.
        amplitudes_out (Amplitudes):  The output Amplitude values for this template.

    Returns:
        None

    """
    if self._check_enabled(use_accel=use_accel):
        self._apply_precond(
            amplitudes_in, amplitudes_out, use_accel=use_accel, **kwargs
        )
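
A minimal sketch of the diagonal case, where the design matrix is approximated by the per-amplitude inverse variance, so applying M^{-1} is an elementwise multiply by the variance (the Offset template below stores exactly such a variance, with zero marking flagged amplitudes):

import numpy as np

offset_var = np.array([0.5, 0.5, 0.0, 0.25])  # per-amplitude variance; 0 == flagged
a_in = np.array([1.0, 2.0, 3.0, 4.0])
a_out = np.zeros_like(a_in)
good = offset_var > 0
a_out[good] += offset_var[good] * a_in[good]  # a' += M^{-1} · a, elementwise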

detectors()

Return a list of detectors supported by the template.

This list will change whenever the data trait is set, which initializes the template.

Returns:

    (list):  The detectors with local amplitudes across all observations.

Source code in toast/templates/template.py
def detectors(self):
    """Return a list of detectors supported by the template.

    This list will change whenever the `data` trait is set, which initializes
    the template.

    Returns:
        (list):  The detectors with local amplitudes across all observations.

    """
    if self._check_enabled():
        return self._detectors()

get_class_config(input=None) classmethod

Return a dictionary of the default traits of a Template class.

This returns a new or appended dictionary. The class instance properties are contained in a dictionary found in result["templates"][cls.name].

If the specified named location in the input config already exists then an exception is raised.

Parameters:

    input (dict):  The optional input dictionary to update.  Default: None

Returns:

    (dict):  The created or updated dictionary.

Source code in toast/templates/template.py
@classmethod
def get_class_config(cls, input=None):
    """Return a dictionary of the default traits of an Template class.

    This returns a new or appended dictionary.  The class instance properties are
    contained in a dictionary found in result["templates"][cls.name].

    If the specified named location in the input config already exists then an
    exception is raised.

    Args:
        input (dict):  The optional input dictionary to update.

    Returns:
        (dict):  The created or updated dictionary.

    """
    return super().get_class_config(section="templates", input=input)

get_class_config_path() classmethod

Source code in toast/templates/template.py
@classmethod
def get_class_config_path(cls):
    return "/templates/{}".format(cls.__qualname__)

get_config(input=None)

Return a dictionary of the current traits of a Template instance.

This returns a new or appended dictionary. The template instance properties are contained in a dictionary found in result["templates"][self.name].

If the specified named location in the input config already exists then an exception is raised.

Parameters:

    input (dict):  The optional input dictionary to update.  Default: None

Returns:

    (dict):  The created or updated dictionary.

Source code in toast/templates/template.py
def get_config(self, input=None):
    """Return a dictionary of the current traits of a Template *instance*.

    This returns a new or appended dictionary.  The template instance properties are
    contained in a dictionary found in result["templates"][self.name].

    If the specified named location in the input config already exists then an
    exception is raised.

    Args:
        input (dict):  The optional input dictionary to update.

    Returns:
        (dict):  The created or updated dictionary.

    """
    return super().get_config(section="templates", input=input)
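
As a hypothetical sketch of the resulting layout for an Offset template named "baselines" (the "class" key and the value encodings here are assumptions about the TraitConfig machinery, not verbatim output):

config = {
    "templates": {
        "baselines": {
            "class": "toast.templates.Offset",  # assumed: import path of the class
            "step_time": "10000.0 s",           # illustrative value encoding only
            "use_noise_prior": "False",         # remaining traits appear likewise
        }
    }
}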

get_config_path()

Source code in toast/templates/template.py
def get_config_path(self):
    if self.name is None:
        return None
    return "/templates/{}".format(self.name)

initialize(change)

Source code in toast/templates/template.py
@traitlets.observe("data")
def initialize(self, change):
    newdata = change["new"]
    if newdata is not None:
        self._initialize(newdata)

project_signal(detector, amplitudes, use_accel=None, **kwargs)

Project a timestream into template amplitudes.

This performs:

    a += F^T · s

Where s is the det_data signal, F is the template and a is the amplitudes.

Parameters:

    detector (str):  The detector name.  (required)
    amplitudes (Amplitudes):  The Amplitude values for this template.  (required)

Returns:

    None

Source code in toast/templates/template.py
@function_timer_stackskip
def project_signal(self, detector, amplitudes, use_accel=None, **kwargs):
    """Project a timestream into template amplitudes.

    This performs:

    .. math::
        a += F^T \\cdot s

    Where `s` is the det_data signal, `F` is the template and `a` is the amplitudes.

    Args:
        detector (str):  The detector name.
        amplitudes (Amplitudes):  The Amplitude values for this template.

    Returns:
        None

    """
    if self._check_enabled(use_accel=use_accel):
        self._project_signal(detector, amplitudes, use_accel=use_accel, **kwargs)
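
This is the transpose of the accumulation shown for add_to_signal above: for the same step-function picture, projection reduces each span of samples to a single number. A plain-numpy sketch with hypothetical sizes:

import numpy as np

step_length = 100
signal = np.random.default_rng(0).normal(size=350)  # s:  one detector's timestream
amps = np.zeros(4)                                  # a:  one amplitude per step
for i in range(len(amps)):
    # F^T sums the samples covered by each amplitude.
    amps[i] += signal[i * step_length : (i + 1) * step_length].sum()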

translate(props) classmethod

Given a config dictionary, modify it to match the current API.

Source code in toast/templates/template.py
@classmethod
def translate(cls, props):
    """Given a config dictionary, modify it to match the current API."""
    # For templates, the derived classes should implement this method as needed
    # and then call super().translate(props) to trigger this method.  Here we strip
    # the 'API' key from the config.
    props = super().translate(props)
    if "API" in props:
        del props["API"]
    return props

zeros()

Return an Amplitudes object filled with zeros.

This returns an Amplitudes instance with appropriate dimensions for this template. This will raise an exception if called before the data trait is set.

Returns:

    (Amplitudes):  Zero amplitudes.

Source code in toast/templates/template.py
def zeros(self):
    """Return an Amplitudes object filled with zeros.

    This returns an Amplitudes instance with appropriate dimensions for this
    template.  This will raise an exception if called before the `data` trait
    is set.

    Returns:
        (Amplitudes):  Zero amplitudes.

    """
    if self._check_enabled():
        return self._zeros()
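
The abstract hooks above (_initialize, _detectors, _zeros, _add_to_signal, _project_signal, _apply_precond) are all a concrete template needs. A minimal sketch of a hypothetical subclass, with one amplitude per detector added to every sample, ignoring flags, views, and MPI distribution, and using an identity preconditioner (the import locations and Observation interfaces are assumptions based on this page):

import numpy as np

from toast.templates import Amplitudes, Template  # assumed import locations


class PerDetectorOffset(Template):
    """Hypothetical template: one amplitude per detector, constant in time."""

    def _initialize(self, new_data):
        # Ordered list of unique local detectors across all observations.
        self._dets = []
        for ob in new_data.obs:
            for d in ob.local_detectors:
                if d not in self._dets:
                    self._dets.append(d)

    def _detectors(self):
        return self._dets

    def _zeros(self):
        # Serial sketch:  every process owns all of its amplitudes.
        n = len(self._dets)
        return Amplitudes(self.data.comm, n, n)

    def _add_to_signal(self, detector, amplitudes, **kwargs):
        # s += F · a :  add this detector's single amplitude to every sample.
        amp = amplitudes.local[self._dets.index(detector)]
        for ob in self.data.obs:
            ob.detdata[self.det_data][detector] += amp

    def _project_signal(self, detector, amplitudes, **kwargs):
        # a += F^T · s :  accumulate the sum of this detector's samples.
        idx = self._dets.index(detector)
        for ob in self.data.obs:
            amplitudes.local[idx] += np.sum(ob.detdata[self.det_data][detector])

    def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
        # Identity preconditioner.
        amplitudes_out.local[:] = amplitudes_in.local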

toast.templates.Offset

Bases: Template

This class represents noise fluctuations as a step function.

Every process stores the amplitudes for its local data, which is disjoint from the amplitudes on other processes. We project amplitudes one detector at a time, and so we arrange our template amplitudes in "detector major" order and store offsets into this for each observation.
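
A sketch of this "detector major" bookkeeping with hypothetical sizes (two detectors, two observations each; compare the _det_start computation in the source below):

n_amp = {"D0": [5, 3], "D1": [5, 3]}   # amplitudes per observation, per detector
det_start = {}
offset = 0
for det, per_obs in n_amp.items():
    det_start[det] = offset            # first amplitude index for this detector
    offset += sum(per_obs)
# det_start == {"D0": 0, "D1": 8};  16 local amplitudes in total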

Source code in toast/templates/offset/offset.py
@trait_docs
class Offset(Template):
    """This class represents noise fluctuations as a step function.

    Every process stores the amplitudes for its local data, which is disjoint from the
    amplitudes on other processes.  We project amplitudes one detector at a time, and
    so we arrange our template amplitudes in "detector major" order and store offsets
    into this for each observation.

    """

    # Notes:  The TraitConfig base class defines a "name" attribute.  The Template
    # class (derived from TraitConfig) defines the following traits already:
    #    data             : The Data instance we are working with
    #    view             : The timestream view we are using
    #    det_data         : The detector data key with the timestreams
    #    det_data_units   : The units of the detector data
    #    det_mask         : Bitmask for per-detector flagging
    #    det_flags        : Optional detector solver flags
    #    det_flag_mask    : Bit mask for detector solver flags
    #

    step_time = Quantity(10000.0 * u.second, help="Time per baseline step")

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    noise_model = Unicode(
        None,
        allow_none=True,
        help="Observation key containing the optional noise model",
    )

    good_fraction = Float(
        0.5,
        help="Fraction of unflagged samples needed to keep a given offset amplitude",
    )

    use_noise_prior = Bool(
        False,
        help="Use detector PSDs to build the noise prior and preconditioner",
    )

    precond_width = Int(20, help="Preconditioner width in terms of offsets / baselines")

    debug_plots = Unicode(
        None,
        allow_none=True,
        help="If not None, make debugging plots in this directory",
    )

    @traitlets.validate("precond_width")
    def _check_precond_width(self, proposal):
        w = proposal["value"]
        if w < 1:
            raise traitlets.TraitError("Preconditioner width should be >= 1")
        return w

    @traitlets.validate("good_fraction")
    def _check_good_fraction(self, proposal):
        f = proposal["value"]
        if f < 0.0 or f > 1.0:
            raise traitlets.TraitError("good_fraction should be a value from 0 to 1")
        return f

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def clear(self):
        """Delete the underlying C-allocated memory."""
        if hasattr(self, "_offsetvar"):
            del self._offsetvar
        if hasattr(self, "_offsetvar_raw"):
            self._offsetvar_raw.clear()
            del self._offsetvar_raw

    def __del__(self):
        self.clear()

    @function_timer
    def _initialize(self, new_data):
        log = Logger.get()
        # This function is called whenever a new data trait is assigned to the template.
        # Clear any C-allocated buffers from previous uses.
        self.clear()

        # Compute the step boundaries for every observation and the number of
        # amplitude values on this process.  Every process only stores amplitudes
        # for its locally assigned data.

        if self.use_noise_prior and self.noise_model is None:
            raise RuntimeError("cannot use noise prior without specifying noise_model")

        # Units for inverse variance weighting
        detnoise_units = 1.0 / self.det_data_units**2

        # Use this as an "Ordered Set".  We want the unique detectors on this process,
        # but sorted in order of occurrence.
        all_dets = OrderedDict()

        # Amplitude lengths of all views for each obs
        self._obs_views = dict()

        # Sample rate for each obs.
        self._obs_rate = dict()

        # Frequency bins for the noise prior for each obs.
        self._freq = dict()

        # Good detectors to use for each observation
        self._obs_dets = dict()

        for iob, ob in enumerate(new_data.obs):
            # Compute sample rate from timestamps
            (rate, dt, dt_min, dt_max, dt_std) = rate_from_times(ob.shared[self.times])
            self._obs_rate[iob] = rate

            # The step length for this observation
            step_length = self._step_length(
                self.step_time.to_value(u.second), self._obs_rate[iob]
            )

            # Track number of offset amplitudes per view, per det.
            ob_views = list()
            for view_slice in ob.view[self.view]:
                view_len = None
                if view_slice.start < 0:
                    # This is a view of the whole obs
                    view_len = ob.n_local_samples
                else:
                    view_len = view_slice.stop - view_slice.start
                view_n_amp = view_len // step_length
                if view_n_amp * step_length < view_len:
                    view_n_amp += 1
                ob_views.append(view_n_amp)
            self._obs_views[iob] = np.array(ob_views, dtype=np.int64)

            # The noise model.
            if self.noise_model is not None:
                if self.noise_model not in ob:
                    msg = "Observation {}:  noise model {} does not exist".format(
                        ob.name, self.noise_model
                    )
                    log.error(msg)
                    raise RuntimeError(msg)

                # Determine the binning for the noise prior
                if self.use_noise_prior:
                    obstime = ob.shared[self.times][-1] - ob.shared[self.times][0]
                    tbase = self.step_time.to_value(u.second)
                    powmin = np.floor(np.log10(1 / obstime)) - 1
                    powmax = min(
                        np.ceil(np.log10(1 / tbase)) + 2, np.log10(self._obs_rate[iob])
                    )
                    self._freq[iob] = np.logspace(powmin, powmax, 1000)

            # Build up detector list
            self._obs_dets[iob] = set()
            for d in ob.select_local_detectors(flagmask=self.det_mask):
                if d not in ob.detdata[self.det_data].detectors:
                    continue
                self._obs_dets[iob].add(d)
                if d not in all_dets:
                    all_dets[d] = None

        self._all_dets = list(all_dets.keys())

        # Go through the data one local detector at a time and compute the offsets
        # into the amplitudes.

        self._det_start = dict()

        offset = 0
        for det in self._all_dets:
            self._det_start[det] = offset
            for iob, ob in enumerate(new_data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                offset += np.sum(self._obs_views[iob])

        # Now we know the total number of amplitudes.

        self._n_local = offset
        if new_data.comm.comm_world is None:
            self._n_global = self._n_local
        else:
            self._n_global = new_data.comm.comm_world.allreduce(
                self._n_local, op=MPI.SUM
            )

        # Now that we know the number of amplitudes, we go through the solver flags
        # and determine what amplitudes, if any, are poorly constrained.  These are
        # stored internally as a bool array, and used when constructing a new
        # Amplitudes object.  We also compute and store the variance of each amplitude,
        # based on the noise weight of the detector and the number of flagged samples.

        # Boolean flags
        self._amp_flags = np.zeros(self._n_local, dtype=bool)

        # Here we track the variance of the offsets based on the detector noise weights
        # and the number of unflagged / good samples per offset.
        if self._n_local == 0:
            self._offsetvar_raw = None
            self._offsetvar = None
        else:
            self._offsetvar_raw = AlignedF64.zeros(self._n_local)
            self._offsetvar = self._offsetvar_raw.array()

        offset = 0
        for det in self._all_dets:
            for iob, ob in enumerate(new_data.obs):
                if det not in self._obs_dets[iob]:
                    continue

                # "Noise weight" (time-domain inverse variance)
                detnoise = 1.0
                if self.noise_model is not None:
                    detnoise = (
                        ob[self.noise_model]
                        .detector_weight(det)
                        .to_value(detnoise_units)
                    )

                # The step length for this observation
                step_length = self._step_length(
                    self.step_time.to_value(u.second), self._obs_rate[iob]
                )

                # Loop over views
                views = ob.view[self.view]
                for ivw, vw in enumerate(views):
                    view_samples = None
                    if vw.start < 0:
                        # This is a view of the whole obs
                        view_samples = ob.n_local_samples
                    else:
                        view_samples = vw.stop - vw.start
                    n_amp_view = self._obs_views[iob][ivw]

                    # Move this loop to compiled code if it is slow...
                    # Note:  we are building the offset amplitude *variance*, which is
                    # why the "noise weight" (inverse variance) is in the denominator.
                    if detnoise <= 0:
                        # This detector is cut in the noise model
                        for amp in range(n_amp_view):
                            self._offsetvar[offset + amp] = 0.0
                            self._amp_flags[offset + amp] = True
                    else:
                        if self.det_flags is None:
                            voff = 0
                            for amp in range(n_amp_view):
                                amplen = step_length
                                if amp == n_amp_view - 1:
                                    amplen = view_samples - voff
                                self._offsetvar[offset + amp] = 1.0 / (
                                    detnoise * amplen
                                )
                                voff += step_length
                        else:
                            flags = views.detdata[self.det_flags][ivw]
                            voff = 0
                            for amp in range(n_amp_view):
                                amplen = step_length
                                if amp == n_amp_view - 1:
                                    amplen = view_samples - voff
                                n_good = amplen - np.count_nonzero(
                                    flags[det][voff : voff + amplen]
                                    & self.det_flag_mask
                                )
                                if (n_good / amplen) <= self.good_fraction:
                                    # This detector is cut or too many samples flagged
                                    self._offsetvar[offset + amp] = 0.0
                                    self._amp_flags[offset + amp] = True
                                else:
                                    # Keep this
                                    self._offsetvar[offset + amp] = 1.0 / (
                                        detnoise * n_good
                                    )
                                voff += step_length
                    offset += n_amp_view

        # Compute the amplitude noise filter and preconditioner for each detector
        # and each view.  The "noise filter" is the real-space inverse amplitude
        # covariance, which is constructed from the Fourier domain amplitude PSD.
        #
        # The preconditioner is either a diagonal one using the amplitude variance,
        # or is a banded one using the amplitude covariance plus the diagonal term.

        self._filters = dict()
        self._precond = dict()

        if self.use_noise_prior:
            offset = 0
            for det in self._all_dets:
                for iob, ob in enumerate(new_data.obs):
                    if det not in self._obs_dets[iob]:
                        continue
                    if iob not in self._filters:
                        self._filters[iob] = dict()
                        self._precond[iob] = dict()

                    offset_psd = self._get_offset_psd(
                        ob[self.noise_model],
                        self._freq[iob],
                        self.step_time.to_value(u.second),
                        det,
                    )

                    if self.debug_plots is not None:
                        set_matplotlib_backend()
                        import matplotlib.pyplot as plt

                        fname = os.path.join(
                            self.debug_plots, f"{self.name}_{det}_{ob.name}_psd.pdf"
                        )
                        psdfreq = ob[self.noise_model].freq(det).to_value(u.Hz)
                        psd = (
                            ob[self.noise_model]
                            .psd(det)
                            .to_value(self.det_data_units**2 * u.second)
                        )
                        corrpsd = self._remove_white_noise(psdfreq, psd)

                        fig = plt.figure(figsize=[12, 12])
                        ax = fig.add_subplot(2, 1, 1)
                        ax.loglog(
                            psdfreq,
                            psd,
                            color="black",
                            label="Original PSD",
                        )
                        ax.loglog(
                            psdfreq,
                            corrpsd,
                            color="red",
                            label="Correlated PSD",
                        )
                        ax.set_xlabel("Frequency [Hz]")
                        ax.set_ylabel("PSD [K$^2$ / Hz]")
                        ax.legend(loc="best")

                        ax = fig.add_subplot(2, 1, 2)
                        ax.loglog(
                            self._freq[iob],
                            offset_psd,
                            label=f"Offset PSD",
                        )
                        ax.set_xlabel("Frequency [Hz]")
                        ax.set_ylabel("PSD [K$^2$ / Hz]")
                        ax.legend(loc="best")
                        fig.savefig(fname)
                        plt.close(fig)

                    # "Noise weight" (time-domain inverse variance)
                    detnoise = (
                        ob[self.noise_model]
                        .detector_weight(det)
                        .to_value(detnoise_units)
                    )

                    # Log version of offset PSD and its inverse for interpolation
                    logfreq = np.log(self._freq[iob])
                    logpsd = np.log(offset_psd)
                    logfilter = np.log(1.0 / offset_psd)

                    # Compute the list of filters and preconditioners (one per view)
                    # For this detector.

                    self._filters[iob][det] = list()
                    self._precond[iob][det] = list()

                    if self.debug_plots is not None:
                        ffilter = os.path.join(
                            self.debug_plots, f"{self.name}_{det}_{ob.name}_filters.pdf"
                        )
                        fprec = os.path.join(
                            self.debug_plots, f"{self.name}_{det}_{ob.name}_prec.pdf"
                        )
                        figfilter = plt.figure(figsize=[12, 8])
                        axfilter = figfilter.add_subplot(1, 1, 1)
                        figprec = plt.figure(figsize=[12, 8])
                        axprec = figprec.add_subplot(1, 1, 1)

                    # Loop over views
                    views = ob.view[self.view]
                    for ivw, vw in enumerate(views):
                        view_samples = None
                        if vw.start < 0:
                            # This is a view of the whole obs
                            view_samples = ob.n_local_samples
                        else:
                            view_samples = vw.stop - vw.start
                        n_amp_view = self._obs_views[iob][ivw]
                        offsetvar_slice = self._offsetvar[offset : offset + n_amp_view]

                        filterlen = 2
                        while filterlen < 2 * n_amp_view:
                            filterlen *= 2
                        filterfreq = np.fft.rfftfreq(
                            filterlen, self.step_time.to_value(u.second)
                        )

                        # Recall that the "noise filter" is the inverse amplitude
                        # covariance, which is why we are using 1/PSD.  Also note that
                        # the truncate function shifts the filter to be symmetric about
                        # the center, which is needed for use with scipy.signal.convolve
                        # If we move this application back to compiled FFT based
                        # methods, we should instead keep this filter in the fourier
                        # domain.

                        noisefilter = self._truncate(
                            np.fft.irfft(
                                self._interpolate_psd(filterfreq, logfreq, logfilter)
                            )
                        )

                        self._filters[iob][det].append(noisefilter)

                        if self.debug_plots is not None:
                            axfilter.plot(
                                np.arange(len(noisefilter)),
                                noisefilter,
                                label=f"Noise filter {ivw}",
                            )

                        # Build the preconditioner
                        lower = None
                        preconditioner = None

                        if self.precond_width == 1:
                            # We are using a Toeplitz preconditioner.  The first row
                            # of the matrix is the inverse FFT of the offset PSD,
                            # with an added zero-lag component from the detector
                            # weight.  NOTE:  the truncate function shifts the real
                            # space filter to the center of the vector.
                            preconditioner = self._truncate(
                                np.fft.irfft(
                                    self._interpolate_psd(filterfreq, logfreq, logpsd)
                                )
                            )
                            icenter = preconditioner.size // 2
                            if detnoise != 0:
                                preconditioner[icenter] += 1.0 / detnoise
                            if self.debug_plots is not None:
                                axprec.plot(
                                    np.arange(len(preconditioner)),
                                    preconditioner,
                                    label=f"Toeplitz preconditioner {ivw}",
                                )
                        else:
                            # We are using a banded matrix for the preconditioner.
                            # This contains a Toeplitz component from the inverse
                            # offset variance in the LHS, and another diagonal term
                            # from the individual offset variance.
                            #
                            # NOTE:  Instead of directly solving x = M^{-1} b, we do
                            # not invert "M" and solve M x = b using the Cholesky
                            # decomposition of M (*not* M^{-1}).
                            icenter = noisefilter.size // 2
                            wband = min(self.precond_width, icenter)
                            precond_width = max(
                                wband, min(self.precond_width, n_amp_view)
                            )
                            preconditioner = np.zeros(
                                [precond_width, n_amp_view], dtype=np.float64
                            )
                            if detnoise != 0:
                                preconditioner[0, :] = 1.0 / offsetvar_slice
                            preconditioner[:wband, :] += np.repeat(
                                noisefilter[icenter : icenter + wband, np.newaxis],
                                n_amp_view,
                                1,
                            )
                            lower = True
                            preconditioner = scipy.linalg.cholesky_banded(
                                preconditioner,
                                overwrite_ab=True,
                                lower=lower,
                                check_finite=True,
                            )
                            if self.debug_plots is not None:
                                axprec.plot(
                                    np.arange(len(preconditioner)),
                                    preconditioner,
                                    label=f"Banded preconditioner {ivw}",
                                )
                        self._precond[iob][det].append((preconditioner, lower))
                        offset += n_amp_view

                    if self.debug_plots is not None:
                        axfilter.set_xlabel("Sample Lag")
                        axfilter.set_ylabel("Amplitude")
                        axfilter.legend(loc="best")
                        figfilter.savefig(ffilter)
                        axprec.set_xlabel("Sample Lag")
                        axprec.set_ylabel("Amplitude")
                        axprec.legend(loc="best")
                        figprec.savefig(fprec)
                        plt.close(figfilter)
                        plt.close(figprec)

        log.verbose(f"Offset variance = {self._offsetvar}")
        return

    # Helper functions for noise / preconditioner calculations

    def _interpolate_psd(self, x, lfreq, lpsd):
        # Interpolate the log-log PSD at the requested frequencies.  Frequencies
        # below this threshold are clamped to avoid taking the log of zero.
        thresh = 1.0e-6
        lowf = x < thresh
        good = np.logical_not(lowf)

        logx = np.empty_like(x)
        logx[lowf] = np.log(thresh)
        logx[good] = np.log(x[good])
        logresult = np.interp(logx, lfreq, lpsd)
        result = np.exp(logresult)
        return result

    def _truncate(self, noisefilter, lim=1e-4):
        # Find where the filter magnitude drops below `lim` times the zero-lag
        # value, then shift the filter to be symmetric about the center and cut.
        icenter = noisefilter.size // 2
        ind = np.abs(noisefilter[:icenter]) > np.abs(noisefilter[0]) * lim
        icut = np.argwhere(ind)[-1][0]
        if icut % 2 == 0:
            icut += 1
        noisefilter = np.roll(noisefilter, icenter)
        noisefilter = noisefilter[icenter - icut : icenter + icut + 1]
        return noisefilter

    def _remove_white_noise(self, freq, psd):
        """Remove the white noise component of the PSD."""
        corrpsd = psd.copy()
        n_corrpsd = len(corrpsd)
        plat_off = int(0.8 * n_corrpsd)
        if n_corrpsd - plat_off < 10:
            if n_corrpsd < 10:
                # Crazy spectrum...
                plat_off = 0
            else:
                plat_off = n_corrpsd - 10

        cfreq = np.log(freq[plat_off:])
        cdata = np.log(corrpsd[plat_off:])

        def lin_func(x, a, b, c):
            # Line
            return a * (x - b) + c

        params, params_cov = scipy.optimize.curve_fit(
            lin_func, cfreq, cdata, p0=[0.0, cfreq[-1], cdata[-1]]
        )

        cdata = lin_func(cfreq, params[0], params[1], params[2])
        cdata = np.exp(cdata)
        plat = cdata[-1]

        # Given the range between the white noise plateau and the maximum
        # values of the PSD, we set a minimum value for any spectral bins
        # that are small or negative.
        corrmax = np.amax(corrpsd)
        corrthresh = 1.0e-10 * corrmax - plat
        corrpsd -= plat
        corrpsd[corrpsd < corrthresh] = corrthresh
        return corrpsd

    def _get_offset_psd(self, noise, freq, step_time, det):
        """Compute the PSD of the baseline offsets."""
        psdfreq = noise.freq(det).to_value(u.Hz)
        psd = noise.psd(det).to_value(self.det_data_units**2 * u.second)
        rate = noise.rate(det).to_value(u.Hz)

        # Remove the white noise component from the PSD
        psd = self._remove_white_noise(psdfreq, psd)

        # Log PSD for interpolation
        logfreq = np.log(psdfreq)
        logpsd = np.log(psd)

        # The calculation of `offset_psd` is based on Keihänen, E. et al:
        # "Making CMB temperature and polarization maps with Madam",
        # A&A 510:A57, 2010, with a small algebra correction.

        m_max = 5
        tbase = step_time
        fbase = 1.0 / tbase

        def g(f, m):
            # The frequencies are constructed without the zero frequency,
            # so we do not need to handle it here.
            # result = np.sin(np.pi * f * tbase) ** 2 / (np.pi * (f * tbase + m)) ** 2
            x = np.pi * (f * tbase + m)
            bad = np.abs(x) < 1.0e-30
            good = np.logical_not(bad)
            result = np.empty_like(x)
            result[bad] = 1.0
            result[good] = np.sin(x[good]) ** 2 / x[good] ** 2
            return result
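
        # In terms of g, the sums below evaluate (a sketch of the formula)
        #
        #     P_a(f) = f_b * sum_{m = -(m_max - 1)}^{m_max - 1} P(f + m * f_b) * g(f, m)
        #
        # where g(f, m) = sin^2(pi * f * t_b) / (pi * (f * t_b + m))^2, t_b is
        # the baseline step time, f_b = 1 / t_b, and P is the detector PSD with
        # its white noise plateau removed.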

        offset_psd = np.zeros_like(freq)

        # The m = 0 term
        offset_psd = self._interpolate_psd(freq, logfreq, logpsd) * g(freq, 0)

        # The remaining terms
        for m in range(1, m_max):
            # Positive m
            offset_psd[:] += self._interpolate_psd(
                freq + m * fbase, logfreq, logpsd
            ) * g(freq, m)
            # Negative m
            offset_psd[:] += self._interpolate_psd(
                freq - m * fbase, logfreq, logpsd
            ) * g(freq, -m)

        offset_psd *= fbase
        return offset_psd

    def _detectors(self):
        return self._all_dets

    def _zeros(self):
        z = Amplitudes(self.data.comm, self._n_global, self._n_local)
        if z.local_flags is not None:
            z.local_flags[:] = np.where(self._amp_flags, 1, 0)
        return z

    def _step_length(self, stime, rate):
        return int(stime * rate + 0.5)

    @function_timer
    def _add_to_signal(self, detector, amplitudes, use_accel=None, **kwargs):
        log = Logger.get()

        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        amp_offset = self._det_start[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            det_indx = ob.detdata[self.det_data].indices([detector])
            # The step length for this observation
            step_length = self._step_length(
                self.step_time.to_value(u.second), self._obs_rate[iob]
            )

            # The number of amplitudes in each view
            n_amp_views = self._obs_views[iob]

            # # DEBUGGING
            # restore_dev = False
            # prefix = "HOST"
            # if amplitudes.accel_in_use():
            #     amplitudes.accel_update_host()
            #     restore_dev = True
            #     prefix = "DEVICE"
            # print(
            #     f"{prefix} Add to signal input:  {amp_offset}, {n_amp_views}, {amplitudes.local}",
            #     flush=True,
            # )
            # if restore_dev:
            #     amplitudes.accel_update_device()

            # # DEBUGGING
            # restore_dev = False
            # prefix = "HOST"
            # if ob.detdata[self.det_data].accel_in_use():
            #     ob.detdata[self.det_data].accel_update_host()
            #     restore_dev = True
            #     prefix = "DEVICE"
            # tod_min = np.amin(ob.detdata[self.det_data])
            # tod_max = np.amax(ob.detdata[self.det_data])
            # print(
            #     f"{prefix} Add to signal starting TOD output:  {ob.detdata[self.det_data]}, min={tod_min}, max={tod_max}",
            #     flush=True,
            # )
            # if (np.absolute(tod_min) < 1.0e-15) and (np.absolute(tod_max) < 1.0e-15):
            #     ob.detdata[self.det_data][:] = 0
            # if restore_dev:
            #     ob.detdata[self.det_data].accel_update_device()

            offset_add_to_signal(
                step_length,
                amp_offset,
                n_amp_views,
                amplitudes.local,
                amplitudes.local_flags,
                det_indx[0],
                ob.detdata[self.det_data].data,
                ob.intervals[self.view].data,
                impl=implementation,
                use_accel=use_accel,
            )

            # # DEBUGGING
            # restore_dev = False
            # prefix = "HOST"
            # if ob.detdata[self.det_data].accel_in_use():
            #     ob.detdata[self.det_data].accel_update_host()
            #     restore_dev = True
            #     prefix = "DEVICE"
            # print(
            #     f"{prefix} Add to signal output:  {ob.detdata[self.det_data]}",
            #     flush=True,
            # )
            # if restore_dev:
            #     ob.detdata[self.det_data].accel_update_device()

            amp_offset += np.sum(n_amp_views)

    @function_timer
    def _project_signal(self, detector, amplitudes, use_accel=None, **kwargs):
        log = Logger.get()

        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        amp_offset = self._det_start[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            det_indx = ob.detdata[self.det_data].indices([detector])
            if self.det_flags is not None:
                flag_indx = ob.detdata[self.det_flags].indices([detector])
                flag_data = ob.detdata[self.det_flags].data
            else:
                flag_indx = np.array([-1], dtype=np.int32)
                flag_data = np.zeros(1, dtype=np.uint8)
            # The step length for this observation
            step_length = self._step_length(
                self.step_time.to_value(u.second), self._obs_rate[iob]
            )

            # The number of amplitudes in each view
            n_amp_views = self._obs_views[iob]

            # # DEBUGGING
            # restore_dev = False
            # prefix="HOST"
            # if ob.detdata[self.det_data].accel_in_use():
            #     ob.detdata[self.det_data].accel_update_host()
            #     restore_dev = True
            #     prefix="DEVICE"
            # print(f"{prefix} Project signal input:  {ob.detdata[self.det_data]}", flush=True)
            # if restore_dev:
            #     ob.detdata[self.det_data].accel_update_device()

            offset_project_signal(
                det_indx[0],
                ob.detdata[self.det_data].data,
                flag_indx[0],
                flag_data,
                self.det_flag_mask,
                step_length,
                amp_offset,
                n_amp_views,
                amplitudes.local,
                amplitudes.local_flags,
                ob.intervals[self.view].data,
                impl=implementation,
                use_accel=use_accel,
            )

            # restore_dev = False
            # prefix="HOST"
            # if amplitudes.accel_in_use():
            #     amplitudes.accel_update_host()
            #     restore_dev = True
            #     prefix="DEVICE"
            # print(f"{prefix} Project signal output:  {amp_offset}, {n_amp_views}, {amplitudes.local}", flush=True)
            # if restore_dev:
            #     amplitudes.accel_update_device()

            amp_offset += np.sum(n_amp_views)

    @function_timer
    def _add_prior(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
        if not self.use_noise_prior:
            # Not using the noise prior term, nothing to accumulate to output.
            return
        if use_accel:
            raise NotImplementedError(
                "offset template add_prior on accelerator not implemented"
            )
        if self.debug_plots is not None:
            set_matplotlib_backend()
            import matplotlib.pyplot as plt

        for det in self._all_dets:
            offset = self._det_start[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
                    n_amp_view = self._obs_views[iob][ivw]
                    amp_slice = slice(offset, offset + n_amp_view, 1)
                    amps_in = amplitudes_in.local[amp_slice]
                    amp_flags_in = amplitudes_in.local_flags[amp_slice]
                    amps_out = amplitudes_out.local[amp_slice]
                    if det in self._filters[iob]:
                        # There is some contribution from this detector
                        amps_out[:] += scipy.signal.convolve(
                            amps_in,
                            self._filters[iob][det][ivw],
                            mode="same",
                            method="direct",
                        )

                        if self.debug_plots is not None:
                            # Find the first unused file name in the sequence
                            iter = -1
                            while iter < 0 or os.path.isfile(fname):
                                iter += 1
                                fname = os.path.join(
                                    self.debug_plots,
                                    f"{self.name}_{det}_{ob.name}_prior_{ivw}_{iter}.pdf",
                                )
                            fig = plt.figure(figsize=[12, 8])
                            ax = fig.add_subplot(1, 1, 1)
                            ax.plot(
                                np.arange(len(amps_in)),
                                amps_in,
                                color="black",
                                label="Input Amplitudes",
                            )
                            ax.plot(
                                np.arange(len(amps_in)),
                                amps_out,
                                color="red",
                                label="Output Amplitudes",
                            )

                        amps_out[amp_flags_in != 0] = 0.0

                        if self.debug_plots is not None:
                            ax.plot(
                                np.arange(len(amps_in)),
                                amps_out,
                                color="green",
                                label="Output Amplitudes (flagged)",
                            )
                            ax.set_xlabel("Amplitude Index")
                            ax.set_ylabel("Value")
                            ax.legend(loc="best")
                            fig.savefig(fname)
                            plt.close(fig)

                    else:
                        amps_out[:] = 0.0
                    offset += n_amp_view

    @function_timer
    def _apply_precond(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
        if self.use_noise_prior:
            if use_accel:
                raise NotImplementedError(
                    "offset template precond on accelerator not implemented"
                )
            # Our design matrix includes a term with the inverse offset covariance.
            # This means that our preconditioner should include this term as well.
            for det in self._all_dets:
                offset = self._det_start[det]
                for iob, ob in enumerate(self.data.obs):
                    if det not in self._obs_dets[iob]:
                        continue
                    # Loop over views
                    views = ob.view[self.view]
                    for ivw, vw in enumerate(views):
                        view_samples = None
                        if vw.start < 0:
                            # This is a view of the whole obs
                            view_samples = ob.n_local_samples
                        else:
                            view_samples = vw.stop - vw.start

                        n_amp_view = self._obs_views[iob][ivw]
                        amp_slice = slice(offset, offset + n_amp_view, 1)

                        amps_in = amplitudes_in.local[amp_slice]
                        amp_flags_in = amplitudes_in.local_flags[amp_slice]
                        amps_out = None
                        if det in self._precond[iob]:
                            # We have a contribution from this detector
                            if self.precond_width <= 1:
                                # We are using a Toeplitz preconditioner.
                                # scipy.signal.convolve will use either `convolve` or
                                # `fftconvolve` depending on the size of the inputs
                                amps_out = scipy.signal.convolve(
                                    amps_in,
                                    self._precond[iob][det][ivw][0],
                                    mode="same",
                                )
                            else:
                                # Use pre-computed Cholesky decomposition.  Note that this
                                # is the decomposition of the actual preconditioner (not
                                # its inverse), since we are solving Mx=b.
                                amps_out = scipy.linalg.cho_solve_banded(
                                    self._precond[iob][det][ivw],
                                    amps_in,
                                    overwrite_b=False,
                                    check_finite=True,
                                )
                            amps_out[amp_flags_in != 0] = 0.0
                        else:
                            # This detector is cut
                            amps_out = np.zeros_like(amps_in)
                        amplitudes_out.local[amp_slice] = amps_out
                        offset += n_amp_view
        else:
            # Since we do not have a noise filter term in our LHS, our diagonal
            # preconditioner is just the application of offset variance.

            # Kernel selection
            implementation, use_accel = self.select_kernels(use_accel=use_accel)

            offset_apply_diag_precond(
                self._offsetvar,
                amplitudes_in.local,
                amplitudes_in.local_flags,
                amplitudes_out.local,
                impl=implementation,
                use_accel=use_accel,
            )
        return

    def _implementations(self):
        return [
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        ]

    def _supports_accel(self):
        return True

    @function_timer
    def write(self, amplitudes, out):
        """Write out amplitude values.

        This stores the amplitudes to files for debugging / plotting.  Since the
        Offset amplitudes are unique on each process, we open one file per process
        group and each process in the group communicates their amplitudes to one
        writer.

        Since this function is used mainly for debugging, we are a bit wasteful
        and duplicate the amplitudes in order to make things easier.

        Args:
            amplitudes (Amplitudes):  The amplitude data.
            out (str):  The output file root.

        Returns:
            None

        """

        # Copy of the amplitudes, organized by observation and detector
        obs_det_amps = dict()

        for det in self._all_dets:
            amp_offset = self._det_start[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                if not ob.is_distributed_by_detector:
                    raise NotImplementedError(
                        "Only observations distributed by detector are supported"
                    )
                # The step length for this observation
                step_length = self._step_length(
                    self.step_time.to_value(u.second), self._obs_rate[iob]
                )

                if ob.name not in obs_det_amps:
                    # First time with this observation, store info about
                    # the offset spans
                    obs_det_amps[ob.name] = dict()
                    props = dict()
                    props["step_length"] = step_length
                    amp_first = list()
                    amp_last = list()
                    amp_start = list()
                    amp_stop = list()
                    for ivw, vw in enumerate(ob.intervals[self.view]):
                        n_amp_view = self._obs_views[iob][ivw]
                        for istep in range(n_amp_view):
                            istart = vw.first + istep * step_length
                            amp_first.append(istart)
                            amp_start.append(ob.shared[self.times].data[istart])
                            if istep == n_amp_view - 1:
                                istop = vw.last
                            else:
                                istop = vw.first + (istep + 1) * step_length
                            amp_last.append(istop)
                            amp_stop.append(ob.shared[self.times].data[istop - 1])
                    props["amp_first"] = np.array(amp_first, dtype=np.int64)
                    props["amp_last"] = np.array(amp_last, dtype=np.int64)
                    props["amp_start"] = np.array(amp_start, dtype=np.float64)
                    props["amp_stop"] = np.array(amp_stop, dtype=np.float64)
                    obs_det_amps[ob.name]["bounds"] = props

                # Loop over views and extract per-detector amplitudes and flags
                det_amps = list()
                det_flags = list()
                views = ob.view[self.view]
                for ivw, vw in enumerate(views):
                    n_amp_view = self._obs_views[iob][ivw]
                    amp_slice = slice(amp_offset, amp_offset + n_amp_view, 1)
                    det_amps.append(amplitudes.local[amp_slice])
                    det_flags.append(amplitudes.local_flags[amp_slice])
                    amp_offset += n_amp_view
                det_amps = np.concatenate(det_amps, dtype=np.float64)
                det_flags = np.concatenate(det_flags, dtype=np.uint8)
                obs_det_amps[ob.name][det] = {
                    "amps": det_amps,
                    "flags": det_flags,
                }

        # Each group writes out its amplitudes.

        # NOTE:  If/when we want to support arbitrary data distributions when
        # writing, we would need to take the data from each process and align
        # them in time rather than just extracting detector data and writing
        # to the datasets.

        for iob, ob in enumerate(self.data.obs):
            obs_local_amps = obs_det_amps[ob.name]
            if self.data.comm.group_size == 1:
                all_obs_amps = [obs_local_amps]
            else:
                all_obs_amps = self.data.comm.comm_group.gather(obs_local_amps, root=0)

            if self.data.comm.group_rank == 0:
                out_file = f"{out}_{ob.name}.h5"
                det_names = set()
                for pdata in all_obs_amps:
                    for k in pdata.keys():
                        if k != "bounds":
                            det_names.add(k)
                det_names = list(sorted(det_names))
                n_det = len(det_names)
                amp_first = all_obs_amps[0]["bounds"]["amp_first"]
                amp_last = all_obs_amps[0]["bounds"]["amp_last"]
                amp_start = all_obs_amps[0]["bounds"]["amp_start"]
                amp_stop = all_obs_amps[0]["bounds"]["amp_stop"]
                n_amp = len(amp_first)
                det_to_row = {y: x for x, y in enumerate(det_names)}
                with h5py.File(out_file, "w") as hf:
                    hf.attrs["step_length"] = all_obs_amps[0]["bounds"]["step_length"]
                    hf.attrs["detectors"] = json.dumps(det_names)
                    hamp_first = hf.create_dataset("amp_first", data=amp_first)
                    hamp_last = hf.create_dataset("amp_last", data=amp_last)
                    hamp_start = hf.create_dataset("amp_start", data=amp_start)
                    hamp_stop = hf.create_dataset("amp_stop", data=amp_stop)
                    hamps = hf.create_dataset(
                        "amplitudes",
                        (n_det, n_amp),
                        dtype=np.float64,
                    )
                    hflags = hf.create_dataset(
                        "flags",
                        (n_det, n_amp),
                        dtype=np.uint8,
                    )
                    for pdata in all_obs_amps:
                        for k, v in pdata.items():
                            if k == "bounds":
                                continue
                            row = det_to_row[k]
                            hslice = (slice(row, row + 1, 1), slice(0, n_amp, 1))
                            dslice = (slice(0, n_amp, 1),)
                            hamps.write_direct(v["amps"], dslice, hslice)
                            hflags.write_direct(v["flags"], dslice, hslice)

debug_plots = Unicode(None, allow_none=True, help='If not None, make debugging plots in this directory') class-attribute instance-attribute

good_fraction = Float(0.5, help='Fraction of unflagged samples needed to keep a given offset amplitude') class-attribute instance-attribute

noise_model = Unicode(None, allow_none=True, help='Observation key containing the optional noise model') class-attribute instance-attribute

precond_width = Int(20, help='Preconditioner width in terms of offsets / baselines') class-attribute instance-attribute

step_time = Quantity(10000.0 * u.second, help='Time per baseline step') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

use_noise_prior = Bool(False, help='Use detector PSDs to build the noise prior and preconditioner') class-attribute instance-attribute

__del__()

Source code in toast/templates/offset/offset.py, lines 107-108
def __del__(self):
    self.clear()

__init__(**kwargs)

Source code in toast/templates/offset/offset.py, lines 96-97
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_add_prior(amplitudes_in, amplitudes_out, use_accel=None, **kwargs)

Source code in toast/templates/offset/offset.py, lines 814-888
@function_timer
def _add_prior(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
    if not self.use_noise_prior:
        # Not using the noise prior term, nothing to accumulate to output.
        return
    if use_accel:
        raise NotImplementedError(
            "offset template add_prior on accelerator not implemented"
        )
    if self.debug_plots is not None:
        set_matplotlib_backend()
        import matplotlib.pyplot as plt

    for det in self._all_dets:
        offset = self._det_start[det]
        for iob, ob in enumerate(self.data.obs):
            if det not in self._obs_dets[iob]:
                continue
            for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
                n_amp_view = self._obs_views[iob][ivw]
                amp_slice = slice(offset, offset + n_amp_view, 1)
                amps_in = amplitudes_in.local[amp_slice]
                amp_flags_in = amplitudes_in.local_flags[amp_slice]
                amps_out = amplitudes_out.local[amp_slice]
                if det in self._filters[iob]:
                    # There is some contribution from this detector
                    amps_out[:] += scipy.signal.convolve(
                        amps_in,
                        self._filters[iob][det][ivw],
                        mode="same",
                        method="direct",
                    )

                    if self.debug_plots is not None:
                        # Find the first unused file name in the sequence
                        iter = -1
                        while iter < 0 or os.path.isfile(fname):
                            iter += 1
                            fname = os.path.join(
                                self.debug_plots,
                                f"{self.name}_{det}_{ob.name}_prior_{ivw}_{iter}.pdf",
                            )
                        fig = plt.figure(figsize=[12, 8])
                        ax = fig.add_subplot(1, 1, 1)
                        ax.plot(
                            np.arange(len(amps_in)),
                            amps_in,
                            color="black",
                            label="Input Amplitudes",
                        )
                        ax.plot(
                            np.arange(len(amps_in)),
                            amps_out,
                            color="red",
                            label="Output Amplitudes",
                        )

                    amps_out[amp_flags_in != 0] = 0.0

                    if self.debug_plots is not None:
                        ax.plot(
                            np.arange(len(amps_in)),
                            amps_out,
                            color="green",
                            label="Output Amplitudes (flagged)",
                        )
                        ax.set_xlabel("Amplitude Index")
                        ax.set_ylabel("Value")
                        ax.legend(loc="best")
                        fig.savefig(fname)
                        plt.close(fig)

                else:
                    amps_out[:] = 0.0
                offset += n_amp_view
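
Outside of the template machinery, the prior accumulation reduces to a centered convolution of the input amplitudes with the real-space inverse-covariance filter. A toy sketch (the five-tap filter below is invented; in the template it comes from self._filters[iob][det][ivw]):

import numpy as np
import scipy.signal

rng = np.random.default_rng(0)
amps_in = rng.normal(size=100)

# Symmetric stand-in for the truncated inverse-covariance filter
noisefilter = np.array([-0.2, -0.5, 1.5, -0.5, -0.2])

# The prior term added to the output amplitudes: C_a^{-1} applied to amps_in
prior = scipy.signal.convolve(amps_in, noisefilter, mode="same", method="direct")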

_add_to_signal(detector, amplitudes, use_accel=None, **kwargs)

Source code in toast/templates/offset/offset.py, lines 660-743
@function_timer
def _add_to_signal(self, detector, amplitudes, use_accel=None, **kwargs):
    log = Logger.get()

    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    amp_offset = self._det_start[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        det_indx = ob.detdata[self.det_data].indices([detector])
        # The step length for this observation
        step_length = self._step_length(
            self.step_time.to_value(u.second), self._obs_rate[iob]
        )

        # The number of amplitudes in each view
        n_amp_views = self._obs_views[iob]

        # # DEBUGGING
        # restore_dev = False
        # prefix = "HOST"
        # if amplitudes.accel_in_use():
        #     amplitudes.accel_update_host()
        #     restore_dev = True
        #     prefix = "DEVICE"
        # print(
        #     f"{prefix} Add to signal input:  {amp_offset}, {n_amp_views}, {amplitudes.local}",
        #     flush=True,
        # )
        # if restore_dev:
        #     amplitudes.accel_update_device()

        # # DEBUGGING
        # restore_dev = False
        # prefix = "HOST"
        # if ob.detdata[self.det_data].accel_in_use():
        #     ob.detdata[self.det_data].accel_update_host()
        #     restore_dev = True
        #     prefix = "DEVICE"
        # tod_min = np.amin(ob.detdata[self.det_data])
        # tod_max = np.amax(ob.detdata[self.det_data])
        # print(
        #     f"{prefix} Add to signal starting TOD output:  {ob.detdata[self.det_data]}, min={tod_min}, max={tod_max}",
        #     flush=True,
        # )
        # if (np.absolute(tod_min) < 1.0e-15) and (np.absolute(tod_max) < 1.0e-15):
        #     ob.detdata[self.det_data][:] = 0
        # if restore_dev:
        #     ob.detdata[self.det_data].accel_update_device()

        offset_add_to_signal(
            step_length,
            amp_offset,
            n_amp_views,
            amplitudes.local,
            amplitudes.local_flags,
            det_indx[0],
            ob.detdata[self.det_data].data,
            ob.intervals[self.view].data,
            impl=implementation,
            use_accel=use_accel,
        )

        # # DEBUGGING
        # restore_dev = False
        # prefix = "HOST"
        # if ob.detdata[self.det_data].accel_in_use():
        #     ob.detdata[self.det_data].accel_update_host()
        #     restore_dev = True
        #     prefix = "DEVICE"
        # print(
        #     f"{prefix} Add to signal output:  {ob.detdata[self.det_data]}",
        #     flush=True,
        # )
        # if restore_dev:
        #     ob.detdata[self.det_data].accel_update_device()

        amp_offset += np.sum(n_amp_views)
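
offset_add_to_signal is a compiled (and optionally accelerated) kernel. For a single contiguous view it amounts to adding each baseline amplitude to its span of step_length samples; a rough pure-numpy equivalent, ignoring amplitude flags and multiple views:

import numpy as np


def add_offsets_to_tod(tod, amps, step_length):
    """Add each baseline amplitude to its span of samples (single view)."""
    for i, amp in enumerate(amps):
        first = i * step_length
        last = min(first + step_length, len(tod))  # final baseline may be short
        tod[first:last] += amp


tod = np.zeros(10)
add_offsets_to_tod(tod, np.array([1.0, 2.0, 3.0]), step_length=4)
# tod is now [1, 1, 1, 1, 2, 2, 2, 2, 3, 3]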

_apply_precond(amplitudes_in, amplitudes_out, use_accel=None, **kwargs)

Source code in toast/templates/offset/offset.py, lines 890-962
@function_timer
def _apply_precond(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
    if self.use_noise_prior:
        if use_accel:
            raise NotImplementedError(
                "offset template precond on accelerator not implemented"
            )
        # Our design matrix includes a term with the inverse offset covariance.
        # This means that our preconditioner should include this term as well.
        for det in self._all_dets:
            offset = self._det_start[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                # Loop over views
                views = ob.view[self.view]
                for ivw, vw in enumerate(views):
                    view_samples = None
                    if vw.start < 0:
                        # This is a view of the whole obs
                        view_samples = ob.n_local_samples
                    else:
                        view_samples = vw.stop - vw.start

                    n_amp_view = self._obs_views[iob][ivw]
                    amp_slice = slice(offset, offset + n_amp_view, 1)

                    amps_in = amplitudes_in.local[amp_slice]
                    amp_flags_in = amplitudes_in.local_flags[amp_slice]
                    amps_out = None
                    if det in self._precond[iob]:
                        # We have a contribution from this detector
                        if self.precond_width <= 1:
                            # We are using a Toeplitz preconditioner.
                            # scipy.signal.convolve will use either `convolve` or
                            # `fftconvolve` depending on the size of the inputs
                            amps_out = scipy.signal.convolve(
                                amps_in,
                                self._precond[iob][det][ivw][0],
                                mode="same",
                            )
                        else:
                            # Use pre-computed Cholesky decomposition.  Note that this
                            # is the decomposition of the actual preconditioner (not
                            # its inverse), since we are solving Mx=b.
                            amps_out = scipy.linalg.cho_solve_banded(
                                self._precond[iob][det][ivw],
                                amps_in,
                                overwrite_b=False,
                                check_finite=True,
                            )
                        amps_out[amp_flags_in != 0] = 0.0
                    else:
                        # This detector is cut
                        amps_out = np.zeros_like(amps_in)
                    amplitudes_out.local[amp_slice] = amps_out
                    offset += n_amp_view
    else:
        # Since we do not have a noise filter term in our LHS, our diagonal
        # preconditioner is just the application of offset variance.

        # Kernel selection
        implementation, use_accel = self.select_kernels(use_accel=use_accel)

        offset_apply_diag_precond(
            self._offsetvar,
            amplitudes_in.local,
            amplitudes_in.local_flags,
            amplitudes_out.local,
            impl=implementation,
            use_accel=use_accel,
        )
    return
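
A self-contained sketch of the banded path: build a small symmetric banded matrix in lower-diagonal storage, factor it once with scipy.linalg.cholesky_banded (done in _initialize), then solve M x = b with cho_solve_banded on every iteration, as above. The matrix values here are invented for illustration:

import numpy as np
import scipy.linalg

n_amp, width = 8, 3

# Lower banded storage: row 0 is the diagonal, rows 1..width-1 the sub-diagonals.
ab = np.zeros((width, n_amp))
ab[0, :] = 4.0  # diagonal: 1/offset variance plus the filter zero-lag term
ab[1, :-1] = -1.0  # first sub-diagonal from the noise filter
ab[2, :-2] = -0.25  # second sub-diagonal

lower = True
cb = scipy.linalg.cholesky_banded(ab, lower=lower)  # precomputed once

b = np.arange(1.0, n_amp + 1.0)
x = scipy.linalg.cho_solve_banded((cb, lower), b)  # solves M x = b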

_check_good_fraction(proposal)

Source code in toast/templates/offset/offset.py, lines 89-94
@traitlets.validate("good_fraction")
def _check_good_fraction(self, proposal):
    f = proposal["value"]
    if f < 0.0 or f > 1.0:
        raise traitlets.TraitError("good_fraction should be a value from 0 to 1")
    return f

_check_precond_width(proposal)

Source code in toast/templates/offset/offset.py, lines 82-87
@traitlets.validate("precond_width")
def _check_precond_width(self, proposal):
    w = proposal["value"]
    if w < 1:
        raise traitlets.TraitError("Preconditioner width should be >= 1")
    return w

_detectors()

Source code in toast/templates/offset/offset.py, lines 648-649
def _detectors(self):
    return self._all_dets

_get_offset_psd(noise, freq, step_time, det)

Compute the PSD of the baseline offsets.

Source code in toast/templates/offset/offset.py, lines 596-646
def _get_offset_psd(self, noise, freq, step_time, det):
    """Compute the PSD of the baseline offsets."""
    psdfreq = noise.freq(det).to_value(u.Hz)
    psd = noise.psd(det).to_value(self.det_data_units**2 * u.second)
    rate = noise.rate(det).to_value(u.Hz)

    # Remove the white noise component from the PSD
    psd = self._remove_white_noise(psdfreq, psd)

    # Log PSD for interpolation
    logfreq = np.log(psdfreq)
    logpsd = np.log(psd)

    # The calculation of `offset_psd` is based on Keihänen, E. et al:
    # "Making CMB temperature and polarization maps with Madam",
    # A&A 510:A57, 2010, with a small algebra correction.

    m_max = 5
    tbase = step_time
    fbase = 1.0 / tbase

    def g(f, m):
        # The frequencies are constructed without the zero frequency,
        # so we do not need to handle it here.
        # result = np.sin(np.pi * f * tbase) ** 2 / (np.pi * (f * tbase + m)) ** 2
        x = np.pi * (f * tbase + m)
        bad = np.abs(x) < 1.0e-30
        good = np.logical_not(bad)
        result = np.empty_like(x)
        result[bad] = 1.0
        result[good] = np.sin(x[good]) ** 2 / x[good] ** 2
        return result

    offset_psd = np.zeros_like(freq)

    # The m = 0 term
    offset_psd = self._interpolate_psd(freq, logfreq, logpsd) * g(freq, 0)

    # The remaining terms
    for m in range(1, m_max):
        # Positive m
        offset_psd[:] += self._interpolate_psd(
            freq + m * fbase, logfreq, logpsd
        ) * g(freq, m)
        # Negative m
        offset_psd[:] += self._interpolate_psd(
            freq - m * fbase, logfreq, logpsd
        ) * g(freq, -m)

    offset_psd *= fbase
    return offset_psd
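
In LaTeX form, the sum implemented above is the baseline-offset PSD from Keihänen et al. (2010), with $t_b$ the baseline length, $f_b = 1/t_b$, $P(f)$ the correlated-noise PSD, and the sum truncated at $|m| < m_{\max} = 5$:

P_{\mathrm{off}}(f) \approx f_b \sum_{m=-(m_{\max}-1)}^{m_{\max}-1} P(f + m f_b)\,
    \frac{\sin^2\!\big(\pi (f t_b + m)\big)}{\big(\pi (f t_b + m)\big)^2}

For integer $m$, $\sin^2(\pi(f t_b + m)) = \sin^2(\pi f t_b)$, which recovers the commented-out closed form inside g().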

_implementations()

Source code in toast/templates/offset/offset.py, lines 964-970
def _implementations(self):
    return [
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    ]

_initialize(new_data)

Source code in toast/templates/offset/offset.py, lines 110-533
@function_timer
def _initialize(self, new_data):
    log = Logger.get()
    # This function is called whenever a new data trait is assigned to the template.
    # Clear any C-allocated buffers from previous uses.
    self.clear()

    # Compute the step boundaries for every observation and the number of
    # amplitude values on this process.  Every process only stores amplitudes
    # for its locally assigned data.

    if self.use_noise_prior and self.noise_model is None:
        raise RuntimeError("cannot use noise prior without specifying noise_model")

    # Units for inverse variance weighting
    detnoise_units = 1.0 / self.det_data_units**2

    # Use this as an "Ordered Set".  We want the unique detectors on this process,
    # but sorted in order of occurrence.
    all_dets = OrderedDict()

    # Amplitude lengths of all views for each obs
    self._obs_views = dict()

    # Sample rate for each obs.
    self._obs_rate = dict()

    # Frequency bins for the noise prior for each obs.
    self._freq = dict()

    # Good detectors to use for each observation
    self._obs_dets = dict()

    for iob, ob in enumerate(new_data.obs):
        # Compute sample rate from timestamps
        (rate, dt, dt_min, dt_max, dt_std) = rate_from_times(ob.shared[self.times])
        self._obs_rate[iob] = rate

        # The step length for this observation
        step_length = self._step_length(
            self.step_time.to_value(u.second), self._obs_rate[iob]
        )

        # Track number of offset amplitudes per view, per det.
        ob_views = list()
        for view_slice in ob.view[self.view]:
            view_len = None
            if view_slice.start < 0:
                # This is a view of the whole obs
                view_len = ob.n_local_samples
            else:
                view_len = view_slice.stop - view_slice.start
            view_n_amp = view_len // step_length
            if view_n_amp * step_length < view_len:
                view_n_amp += 1
            ob_views.append(view_n_amp)
        self._obs_views[iob] = np.array(ob_views, dtype=np.int64)

        # The noise model.
        if self.noise_model is not None:
            if self.noise_model not in ob:
                msg = "Observation {}:  noise model {} does not exist".format(
                    ob.name, self.noise_model
                )
                log.error(msg)
                raise RuntimeError(msg)

            # Determine the binning for the noise prior
            if self.use_noise_prior:
                obstime = ob.shared[self.times][-1] - ob.shared[self.times][0]
                tbase = self.step_time.to_value(u.second)
                powmin = np.floor(np.log10(1 / obstime)) - 1
                powmax = min(
                    np.ceil(np.log10(1 / tbase)) + 2, np.log10(self._obs_rate[iob])
                )
                self._freq[iob] = np.logspace(powmin, powmax, 1000)

        # Build up detector list
        self._obs_dets[iob] = set()
        for d in ob.select_local_detectors(flagmask=self.det_mask):
            if d not in ob.detdata[self.det_data].detectors:
                continue
            self._obs_dets[iob].add(d)
            if d not in all_dets:
                all_dets[d] = None

    self._all_dets = list(all_dets.keys())

    # Go through the data one local detector at a time and compute the offsets
    # into the amplitudes.

    self._det_start = dict()

    offset = 0
    for det in self._all_dets:
        self._det_start[det] = offset
        for iob, ob in enumerate(new_data.obs):
            if det not in self._obs_dets[iob]:
                continue
            offset += np.sum(self._obs_views[iob])

    # Now we know the total number of amplitudes.

    self._n_local = offset
    if new_data.comm.comm_world is None:
        self._n_global = self._n_local
    else:
        self._n_global = new_data.comm.comm_world.allreduce(
            self._n_local, op=MPI.SUM
        )

    # Now that we know the number of amplitudes, we go through the solver flags
    # and determine what amplitudes, if any, are poorly constrained.  These are
    # stored internally as a bool array, and used when constructing a new
    # Amplitudes object.  We also compute and store the variance of each amplitude,
    # based on the noise weight of the detector and the number of flagged samples.

    # Boolean flags
    self._amp_flags = np.zeros(self._n_local, dtype=bool)

    # Here we track the variance of the offsets based on the detector noise weights
    # and the number of unflagged / good samples per offset.
    if self._n_local == 0:
        self._offsetvar_raw = None
        self._offsetvar = None
    else:
        self._offsetvar_raw = AlignedF64.zeros(self._n_local)
        self._offsetvar = self._offsetvar_raw.array()

    offset = 0
    for det in self._all_dets:
        for iob, ob in enumerate(new_data.obs):
            if det not in self._obs_dets[iob]:
                continue

            # "Noise weight" (time-domain inverse variance)
            detnoise = 1.0
            if self.noise_model is not None:
                detnoise = (
                    ob[self.noise_model]
                    .detector_weight(det)
                    .to_value(detnoise_units)
                )

            # The step length for this observation
            step_length = self._step_length(
                self.step_time.to_value(u.second), self._obs_rate[iob]
            )

            # Loop over views
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                view_samples = None
                if vw.start < 0:
                    # This is a view of the whole obs
                    view_samples = ob.n_local_samples
                else:
                    view_samples = vw.stop - vw.start
                n_amp_view = self._obs_views[iob][ivw]

                # Move this loop to compiled code if it is slow...
                # Note:  we are building the offset amplitude *variance*, which is
                # why the "noise weight" (inverse variance) is in the denominator.
                if detnoise <= 0:
                    # This detector is cut in the noise model
                    for amp in range(n_amp_view):
                        self._offsetvar[offset + amp] = 0.0
                        self._amp_flags[offset + amp] = True
                else:
                    if self.det_flags is None:
                        voff = 0
                        for amp in range(n_amp_view):
                            amplen = step_length
                            if amp == n_amp_view - 1:
                                amplen = view_samples - voff
                            self._offsetvar[offset + amp] = 1.0 / (
                                detnoise * amplen
                            )
                            voff += step_length
                    else:
                        flags = views.detdata[self.det_flags][ivw]
                        voff = 0
                        for amp in range(n_amp_view):
                            amplen = step_length
                            if amp == n_amp_view - 1:
                                amplen = view_samples - voff
                            n_good = amplen - np.count_nonzero(
                                flags[det][voff : voff + amplen]
                                & self.det_flag_mask
                            )
                            if (n_good / amplen) <= self.good_fraction:
                                # This detector is cut or too many samples flagged
                                self._offsetvar[offset + amp] = 0.0
                                self._amp_flags[offset + amp] = True
                            else:
                                # Keep this
                                self._offsetvar[offset + amp] = 1.0 / (
                                    detnoise * n_good
                                )
                            voff += step_length
                offset += n_amp_view

    # Compute the amplitude noise filter and preconditioner for each detector
    # and each view.  The "noise filter" is the real-space inverse amplitude
    # covariance, which is constructed from the Fourier domain amplitude PSD.
    #
    # The preconditioner is either a diagonal one using the amplitude variance,
    # or is a banded one using the amplitude covariance plus the diagonal term.

    self._filters = dict()
    self._precond = dict()

    if self.use_noise_prior:
        offset = 0
        for det in self._all_dets:
            for iob, ob in enumerate(new_data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                if iob not in self._filters:
                    self._filters[iob] = dict()
                    self._precond[iob] = dict()

                offset_psd = self._get_offset_psd(
                    ob[self.noise_model],
                    self._freq[iob],
                    self.step_time.to_value(u.second),
                    det,
                )

                if self.debug_plots is not None:
                    set_matplotlib_backend()
                    import matplotlib.pyplot as plt

                    fname = os.path.join(
                        self.debug_plots, f"{self.name}_{det}_{ob.name}_psd.pdf"
                    )
                    psdfreq = ob[self.noise_model].freq(det).to_value(u.Hz)
                    psd = (
                        ob[self.noise_model]
                        .psd(det)
                        .to_value(self.det_data_units**2 * u.second)
                    )
                    corrpsd = self._remove_white_noise(psdfreq, psd)

                    fig = plt.figure(figsize=[12, 12])
                    ax = fig.add_subplot(2, 1, 1)
                    ax.loglog(
                        psdfreq,
                        psd,
                        color="black",
                        label="Original PSD",
                    )
                    ax.loglog(
                        psdfreq,
                        corrpsd,
                        color="red",
                        label="Correlated PSD",
                    )
                    ax.set_xlabel("Frequency [Hz]")
                    ax.set_ylabel("PSD [K$^2$ / Hz]")
                    ax.legend(loc="best")

                    ax = fig.add_subplot(2, 1, 2)
                    ax.loglog(
                        self._freq[iob],
                        offset_psd,
                        label=f"Offset PSD",
                    )
                    ax.set_xlabel("Frequency [Hz]")
                    ax.set_ylabel("PSD [K$^2$ / Hz]")
                    ax.legend(loc="best")
                    fig.savefig(fname)
                    plt.close(fig)

                # "Noise weight" (time-domain inverse variance)
                detnoise = (
                    ob[self.noise_model]
                    .detector_weight(det)
                    .to_value(detnoise_units)
                )

                # Log version of offset PSD and its inverse for interpolation
                logfreq = np.log(self._freq[iob])
                logpsd = np.log(offset_psd)
                logfilter = np.log(1.0 / offset_psd)

                # Compute the list of filters and preconditioners (one per view)
                # for this detector.

                self._filters[iob][det] = list()
                self._precond[iob][det] = list()

                if self.debug_plots is not None:
                    ffilter = os.path.join(
                        self.debug_plots, f"{self.name}_{det}_{ob.name}_filters.pdf"
                    )
                    fprec = os.path.join(
                        self.debug_plots, f"{self.name}_{det}_{ob.name}_prec.pdf"
                    )
                    figfilter = plt.figure(figsize=[12, 8])
                    axfilter = figfilter.add_subplot(1, 1, 1)
                    figprec = plt.figure(figsize=[12, 8])
                    axprec = figprec.add_subplot(1, 1, 1)

                # Loop over views
                views = ob.view[self.view]
                for ivw, vw in enumerate(views):
                    view_samples = None
                    if vw.start < 0:
                        # This is a view of the whole obs
                        view_samples = ob.n_local_samples
                    else:
                        view_samples = vw.stop - vw.start
                    n_amp_view = self._obs_views[iob][ivw]
                    offsetvar_slice = self._offsetvar[offset : offset + n_amp_view]

                    filterlen = 2
                    while filterlen < 2 * n_amp_view:
                        filterlen *= 2
                    filterfreq = np.fft.rfftfreq(
                        filterlen, self.step_time.to_value(u.second)
                    )

                    # Recall that the "noise filter" is the inverse amplitude
                    # covariance, which is why we are using 1/PSD.  Also note that
                    # the truncate function shifts the filter to be symmetric about
                    # the center, which is needed for use with scipy.signal.convolve.
                    # If we move this application back to compiled FFT-based
                    # methods, we should instead keep this filter in the Fourier
                    # domain.

                    noisefilter = self._truncate(
                        np.fft.irfft(
                            self._interpolate_psd(filterfreq, logfreq, logfilter)
                        )
                    )

                    self._filters[iob][det].append(noisefilter)

                    if self.debug_plots is not None:
                        axfilter.plot(
                            np.arange(len(noisefilter)),
                            noisefilter,
                            label=f"Noise filter {ivw}",
                        )

                    # Build the preconditioner
                    lower = None
                    preconditioner = None

                    if self.precond_width == 1:
                        # We are using a Toeplitz preconditioner.  The first row
                        # of the matrix is the inverse FFT of the offset PSD,
                        # with an added zero-lag component from the detector
                        # weight.  NOTE:  the truncate function shifts the real
                        # space filter to the center of the vector.
                        preconditioner = self._truncate(
                            np.fft.irfft(
                                self._interpolate_psd(filterfreq, logfreq, logpsd)
                            )
                        )
                        icenter = preconditioner.size // 2
                        if detnoise != 0:
                            preconditioner[icenter] += 1.0 / detnoise
                        if self.debug_plots is not None:
                            axprec.plot(
                                np.arange(len(preconditioner)),
                                preconditioner,
                                label=f"Toeplitz preconditioner {ivw}",
                            )
                    else:
                        # We are using a banded matrix for the preconditioner.
                        # This contains a Toeplitz component from the inverse
                        # offset variance in the LHS, and another diagonal term
                        # from the individual offset variance.
                        #
                        # NOTE:  Instead of directly solving x = M^{-1} b, we do
                        # not invert "M" and solve M x = b using the Cholesky
                        # decomposition of M (*not* M^{-1}).
                        icenter = noisefilter.size // 2
                        wband = min(self.precond_width, icenter)
                        precond_width = max(
                            wband, min(self.precond_width, n_amp_view)
                        )
                        preconditioner = np.zeros(
                            [precond_width, n_amp_view], dtype=np.float64
                        )
                        if detnoise != 0:
                            preconditioner[0, :] = 1.0 / offsetvar_slice
                        preconditioner[:wband, :] += np.repeat(
                            noisefilter[icenter : icenter + wband, np.newaxis],
                            n_amp_view,
                            1,
                        )
                        lower = True
                        preconditioner = scipy.linalg.cholesky_banded(
                            preconditioner,
                            overwrite_ab=True,
                            lower=lower,
                            check_finite=True,
                        )
                        if self.debug_plots is not None:
                            axprec.plot(
                                np.arange(len(preconditioner)),
                                preconditioner,
                                label=f"Banded preconditioner {ivw}",
                            )
                    self._precond[iob][det].append((preconditioner, lower))
                    offset += n_amp_view

                if self.debug_plots is not None:
                    axfilter.set_xlabel("Sample Lag")
                    axfilter.set_ylabel("Amplitude")
                    axfilter.legend(loc="best")
                    figfilter.savefig(ffilter)
                    axprec.set_xlabel("Sample Lag")
                    axprec.set_ylabel("Amplitude")
                    axprec.legend(loc="best")
                    figprec.savefig(fprec)
                    plt.close(figfilter)
                    plt.close(figprec)

    log.verbose(f"Offset variance = {self._offsetvar}")
    return
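
The per-view amplitude count computed above is a ceiling division of the view length by the step length; for instance:

view_len, step_length = 1000, 300

view_n_amp = view_len // step_length
if view_n_amp * step_length < view_len:
    view_n_amp += 1
# view_n_amp == 4: three full 300-sample baselines plus a ragged 100-sample one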

_interpolate_psd(x, lfreq, lpsd)

Source code in toast/templates/offset/offset.py, lines 537-548
def _interpolate_psd(self, x, lfreq, lpsd):
    # Threshold for zero frequency
    thresh = 1.0e-6
    lowf = x < thresh
    good = np.logical_not(lowf)

    logx = np.empty_like(x)
    logx[lowf] = np.log(thresh)
    logx[good] = np.log(x[good])
    logresult = np.interp(logx, lfreq, lpsd)
    result = np.exp(logresult)
    return result
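
A quick standalone sketch of the same log-log interpolation on a made-up 1/f spectrum (the clip plays the role of the zero-frequency threshold above):

import numpy as np

freq = np.logspace(-4, 2, 50)  # tabulated frequencies [Hz]
psd = 1.0 / freq  # toy 1/f spectrum
lfreq, lpsd = np.log(freq), np.log(psd)

x = np.array([0.0, 1.0e-3, 1.0])  # includes an exact zero frequency
result = np.exp(np.interp(np.log(np.clip(x, 1.0e-6, None)), lfreq, lpsd))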

_project_signal(detector, amplitudes, use_accel=None, **kwargs)

Source code in toast/templates/offset/offset.py, lines 745-812
@function_timer
def _project_signal(self, detector, amplitudes, use_accel=None, **kwargs):
    log = Logger.get()

    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return

    # Kernel selection
    implementation, use_accel = self.select_kernels(use_accel=use_accel)

    amp_offset = self._det_start[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        det_indx = ob.detdata[self.det_data].indices([detector])
        if self.det_flags is not None:
            flag_indx = ob.detdata[self.det_flags].indices([detector])
            flag_data = ob.detdata[self.det_flags].data
        else:
            flag_indx = np.array([-1], dtype=np.int32)
            flag_data = np.zeros(1, dtype=np.uint8)
        # The step length for this observation
        step_length = self._step_length(
            self.step_time.to_value(u.second), self._obs_rate[iob]
        )

        # The number of amplitudes in each view
        n_amp_views = self._obs_views[iob]

        # # DEBUGGING
        # restore_dev = False
        # prefix="HOST"
        # if ob.detdata[self.det_data].accel_in_use():
        #     ob.detdata[self.det_data].accel_update_host()
        #     restore_dev = True
        #     prefix="DEVICE"
        # print(f"{prefix} Project signal input:  {ob.detdata[self.det_data]}", flush=True)
        # if restore_dev:
        #     ob.detdata[self.det_data].accel_update_device()

        offset_project_signal(
            det_indx[0],
            ob.detdata[self.det_data].data,
            flag_indx[0],
            flag_data,
            self.det_flag_mask,
            step_length,
            amp_offset,
            n_amp_views,
            amplitudes.local,
            amplitudes.local_flags,
            ob.intervals[self.view].data,
            impl=implementation,
            use_accel=use_accel,
        )

        # restore_dev = False
        # prefix="HOST"
        # if amplitudes.accel_in_use():
        #     amplitudes.accel_update_host()
        #     restore_dev = True
        #     prefix="DEVICE"
        # print(f"{prefix} Project signal output:  {amp_offset}, {n_amp_views}, {amplitudes.local}", flush=True)
        # if restore_dev:
        #     amplitudes.accel_update_device()

        amp_offset += np.sum(n_amp_views)
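
The kernel call above computes Z^T d for the offset template: each amplitude accumulates the sum of the unflagged samples in its baseline step. A pure-Python sketch for a single contiguous view (the real offset_project_signal is a compiled kernel that also handles multiple views, flag indices, and accelerator buffers):

import numpy as np

def project_offsets(tod, flags, flag_mask, step_length, amps):
    # Each baseline amplitude accumulates the sum of its unflagged samples
    for iamp in range(len(amps)):
        first = iamp * step_length
        last = min(first + step_length, len(tod))
        good = (flags[first:last] & flag_mask) == 0
        amps[iamp] += np.sum(tod[first:last][good])

tod = np.arange(10, dtype=np.float64)
flags = np.zeros(10, dtype=np.uint8)
amps = np.zeros(3)
project_offsets(tod, flags, 1, 4, amps)  # amps -> [6.0, 22.0, 17.0]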

_remove_white_noise(freq, psd)

Remove the white noise component of the PSD.

Source code in toast/templates/offset/offset.py
def _remove_white_noise(self, freq, psd):
    """Remove the white noise component of the PSD."""
    corrpsd = psd.copy()
    n_corrpsd = len(corrpsd)
    plat_off = int(0.8 * n_corrpsd)
    if n_corrpsd - plat_off < 10:
        if n_corrpsd < 10:
            # Crazy spectrum...
            plat_off = 0
        else:
            plat_off = n_corrpsd - 10

    cfreq = np.log(freq[plat_off:])
    cdata = np.log(corrpsd[plat_off:])

    def lin_func(x, a, b, c):
        # Line
        return a * (x - b) + c

    params, params_cov = scipy.optimize.curve_fit(
        lin_func, cfreq, cdata, p0=[0.0, cfreq[-1], cdata[-1]]
    )

    cdata = lin_func(cfreq, params[0], params[1], params[2])
    cdata = np.exp(cdata)
    plat = cdata[-1]

    # Given the range between the white noise plateau and the maximum
    # values of the PSD, we set a minimum value for any spectral bins
    # that are small or negative.
    corrmax = np.amax(corrpsd)
    corrthresh = 1.0e-10 * corrmax - plat
    corrpsd -= plat
    corrpsd[corrpsd < corrthresh] = corrthresh
    return corrpsd
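
The fit above is a straight line in log-log space over roughly the last 20% of frequency bins; the fitted value at the highest frequency is taken as the white-noise plateau, which is subtracted before flooring small or negative bins. A compact equivalent of the plateau estimate using np.polyfit (a sketch, not the library routine):

import numpy as np

def white_noise_plateau(freq, psd, frac=0.8):
    # Fit log(psd) vs log(freq) over the top (1 - frac) of the bins
    # and evaluate the fitted line at the highest frequency.
    off = int(frac * len(psd))
    slope, intercept = np.polyfit(np.log(freq[off:]), np.log(psd[off:]), 1)
    return np.exp(slope * np.log(freq[-1]) + intercept)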

_step_length(stime, rate)

Source code in toast/templates/offset/offset.py
def _step_length(self, stime, rate):
    return int(stime * rate + 0.5)
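
For example, with step_time = 1.0 s and a 37.5 Hz sampling rate this gives int(1.0 * 37.5 + 0.5) = 38 samples per baseline; adding 0.5 before truncation rounds to the nearest integer rather than always rounding down.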

_supports_accel()

Source code in toast/templates/offset/offset.py
def _supports_accel(self):
    return True

_truncate(noisefilter, lim=0.0001)

Source code in toast/templates/offset/offset.py
def _truncate(self, noisefilter, lim=1e-4):
    icenter = noisefilter.size // 2
    ind = np.abs(noisefilter[:icenter]) > np.abs(noisefilter[0]) * lim
    icut = np.argwhere(ind)[-1][0]
    if icut % 2 == 0:
        icut += 1
    noisefilter = np.roll(noisefilter, icenter)
    noisefilter = noisefilter[icenter - icut : icenter + icut + 1]
    return noisefilter
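
The filter is assumed to be stored FFT-style, with the peak at index 0 and the symmetric tail wrapped around to the end; the function finds the last lag whose magnitude exceeds lim times the peak, rolls the peak to the center, and keeps 2 * icut + 1 samples. A small demonstration with illustrative values:

import numpy as np

nf = np.array([1.0, 0.4, 0.05, 1e-8, 1e-8, 1e-8, 1e-8, 0.05, 0.4])
icenter = nf.size // 2                  # 4
ind = np.abs(nf[:icenter]) > np.abs(nf[0]) * 1e-4
icut = np.argwhere(ind)[-1][0]          # 2, the last significant lag
if icut % 2 == 0:
    icut += 1                           # force an odd half-width: 3
trimmed = np.roll(nf, icenter)[icenter - icut : icenter + icut + 1]
# trimmed has 7 samples with the peak centered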

_zeros()

Source code in toast/templates/offset/offset.py
def _zeros(self):
    z = Amplitudes(self.data.comm, self._n_global, self._n_local)
    if z.local_flags is not None:
        z.local_flags[:] = np.where(self._amp_flags, 1, 0)
    return z

clear()

Delete the underlying C-allocated memory.

Source code in toast/templates/offset/offset.py
def clear(self):
    """Delete the underlying C-allocated memory."""
    if hasattr(self, "_offsetvar"):
        del self._offsetvar
    if hasattr(self, "_offsetvar_raw"):
        self._offsetvar_raw.clear()
        del self._offsetvar_raw

write(amplitudes, out)

Write out amplitude values.

This stores the amplitudes to files for debugging / plotting. Since the Offset amplitudes are unique on each process, we open one file per process group and each process in the group communicates its amplitudes to one writer.

Since this function is used mainly for debugging, we are a bit wasteful and duplicate the amplitudes in order to make things easier.

Parameters:

    amplitudes (Amplitudes): The amplitude data. (required)
    out (str): The output file root. (required)

Returns:

    None

Source code in toast/templates/offset/offset.py
@function_timer
def write(self, amplitudes, out):
    """Write out amplitude values.

    This stores the amplitudes to files for debugging / plotting.  Since the
    Offset amplitudes are unique on each process, we open one file per process
    group and each process in the group communicates its amplitudes to one
    writer.

    Since this function is used mainly for debugging, we are a bit wasteful
    and duplicate the amplitudes in order to make things easier.

    Args:
        amplitudes (Amplitudes):  The amplitude data.
        out (str):  The output file root.

    Returns:
        None

    """

    # Copy of the amplitudes, organized by observation and detector
    obs_det_amps = dict()

    for det in self._all_dets:
        amp_offset = self._det_start[det]
        for iob, ob in enumerate(self.data.obs):
            if det not in self._obs_dets[iob]:
                continue
            if not ob.is_distributed_by_detector:
                raise NotImplementedError(
                    "Only observations distributed by detector are supported"
                )
            # The step length for this observation
            step_length = self._step_length(
                self.step_time.to_value(u.second), self._obs_rate[iob]
            )

            if ob.name not in obs_det_amps:
                # First time with this observation, store info about
                # the offset spans
                obs_det_amps[ob.name] = dict()
                props = dict()
                props["step_length"] = step_length
                amp_first = list()
                amp_last = list()
                amp_start = list()
                amp_stop = list()
                for ivw, vw in enumerate(ob.intervals[self.view]):
                    n_amp_view = self._obs_views[iob][ivw]
                    for istep in range(n_amp_view):
                        istart = vw.first + istep * step_length
                        amp_first.append(istart)
                        amp_start.append(ob.shared[self.times].data[istart])
                        if istep == n_amp_view - 1:
                            istop = vw.last
                        else:
                            istop = vw.first + (istep + 1) * step_length
                        amp_last.append(istop)
                        amp_stop.append(ob.shared[self.times].data[istop - 1])
                props["amp_first"] = np.array(amp_first, dtype=np.int64)
                props["amp_last"] = np.array(amp_last, dtype=np.int64)
                props["amp_start"] = np.array(amp_start, dtype=np.float64)
                props["amp_stop"] = np.array(amp_stop, dtype=np.float64)
                obs_det_amps[ob.name]["bounds"] = props

            # Loop over views and extract per-detector amplitudes and flags
            det_amps = list()
            det_flags = list()
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                n_amp_view = self._obs_views[iob][ivw]
                amp_slice = slice(amp_offset, amp_offset + n_amp_view, 1)
                det_amps.append(amplitudes.local[amp_slice])
                det_flags.append(amplitudes.local_flags[amp_slice])
                amp_offset += n_amp_view
            det_amps = np.concatenate(det_amps, dtype=np.float64)
            det_flags = np.concatenate(det_flags, dtype=np.uint8)
            obs_det_amps[ob.name][det] = {
                "amps": det_amps,
                "flags": det_flags,
            }

    # Each group writes out its amplitudes.

    # NOTE:  If/when we want to support arbitrary data distributions when
    # writing, we would need to take the data from each process and align
    # them in time rather than just extracting detector data and writing
    # to the datasets.

    for iob, ob in enumerate(self.data.obs):
        obs_local_amps = obs_det_amps[ob.name]
        if self.data.comm.group_size == 1:
            all_obs_amps = [obs_local_amps]
        else:
            all_obs_amps = self.data.comm.comm_group.gather(obs_local_amps, root=0)

        if self.data.comm.group_rank == 0:
            out_file = f"{out}_{ob.name}.h5"
            det_names = set()
            for pdata in all_obs_amps:
                for k in pdata.keys():
                    if k != "bounds":
                        det_names.add(k)
            det_names = list(sorted(det_names))
            n_det = len(det_names)
            amp_first = all_obs_amps[0]["bounds"]["amp_first"]
            amp_last = all_obs_amps[0]["bounds"]["amp_last"]
            amp_start = all_obs_amps[0]["bounds"]["amp_start"]
            amp_stop = all_obs_amps[0]["bounds"]["amp_stop"]
            n_amp = len(amp_first)
            det_to_row = {y: x for x, y in enumerate(det_names)}
            with h5py.File(out_file, "w") as hf:
                hf.attrs["step_length"] = all_obs_amps[0]["bounds"]["step_length"]
                hf.attrs["detectors"] = json.dumps(det_names)
                hamp_first = hf.create_dataset("amp_first", data=amp_first)
                hamp_last = hf.create_dataset("amp_last", data=amp_last)
                hamp_start = hf.create_dataset("amp_start", data=amp_start)
                hamp_stop = hf.create_dataset("amp_stop", data=amp_stop)
                hamps = hf.create_dataset(
                    "amplitudes",
                    (n_det, n_amp),
                    dtype=np.float64,
                )
                hflags = hf.create_dataset(
                    "flags",
                    (n_det, n_amp),
                    dtype=np.uint8,
                )
                for pdata in all_obs_amps:
                    for k, v in pdata.items():
                        if k == "bounds":
                            continue
                        row = det_to_row[k]
                        hslice = (slice(row, row + 1, 1), slice(0, n_amp, 1))
                        dslice = (slice(0, n_amp, 1),)
                        hamps.write_direct(v["amps"], dslice, hslice)
                        hflags.write_direct(v["flags"], dslice, hslice)
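
The resulting file is straightforward to read back for plotting. A minimal sketch using the dataset and attribute names created above (the file path is illustrative):

import json
import h5py

with h5py.File("template_offset_obs1.h5", "r") as hf:  # illustrative path
    dets = json.loads(hf.attrs["detectors"])
    amp_start = hf["amp_start"][:]   # start time of each baseline
    amps = hf["amplitudes"][:]       # shape (n_det, n_amp)
    flags = hf["flags"][:]
    row = dets.index(dets[0])        # pick the first detector
    good = flags[row] == 0           # unflagged amplitudes for that detector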

toast.templates.Periodic

Bases: Template

This template represents amplitudes which are periodic in time.

The template amplitudes are modeled as a value for each detector of each observation in a "bin" of the specified data values. The min / max values of the periodic data are computed for each observation, and the binning between these min / max values is set either by the bins trait or by specifying the increment in the value to use for each bin.

Although the data values used do not have to be strictly periodic, this template works best if the values are varying in a regular way such that each bin has approximately the same number of hits.

The periodic quantity to consider can be either a shared or detdata field.
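
As a concrete example of the binning: if the periodic quantity spans 0.0 to 10.0 within an observation, then bins = 10 gives an increment of 1.0 per bin, while setting increment = 0.25 instead gives int((10.0 - 0.0) / 0.25) = 40 bins. Each sample is assigned to bin int((value - min) / increment), with values at the top edge clamped into the last bin (see _view_flags_and_index below).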

Source code in toast/templates/periodic.py
@trait_docs
class Periodic(Template):
    """This template represents amplitudes which are periodic in time.

    The template amplitudes are modeled as a value for each detector of each
    observation in a "bin" of the specified data values.  The min / max values
    of the periodic data are computed for each observation, and the binning
    between these min / max values is set by either the bins trait or by
    specifying the increment in the value to use for each bin.

    Although the data values used do not have to be strictly periodic, this
    template works best if the values are varying in a regular way such that
    each bin has approximately the same number of hits.

    The periodic quantity to consider can be either a shared or detdata field.

    """

    # Notes:  The TraitConfig base class defines a "name" attribute.  The Template
    # class (derived from TraitConfig) defines the following traits already:
    #    data             : The Data instance we are working with
    #    view             : The timestream view we are using
    #    det_data         : The detector data key with the timestreams
    #    det_data_units   : The units of the detector data
    #    det_mask         : Bitmask for per-detector flagging
    #    det_flags        : Optional detector solver flags
    #    det_flag_mask    : Bit mask for detector solver flags
    #

    is_detdata_key = Bool(
        False,
        help="If True, the periodic data and flags are detector fields, not shared",
    )

    key = Unicode(
        None, allow_none=True, help="Observation data key for the periodic quantity"
    )

    flags = Unicode(
        None,
        allow_none=True,
        help="Observation data key for flags to use",
    )

    flag_mask = Int(0, help="Bit mask value for flags")

    bins = Int(
        10,
        allow_none=True,
        help="Number of bins between min / max values of data key",
    )

    increment = Float(
        None,
        allow_none=True,
        help="The increment of the data key for each bin",
    )

    minimum_bin_hits = Int(3, help="Minimum number of samples per amplitude bin")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _initialize(self, new_data):
        log = Logger.get()
        if self.key is None:
            msg = "You must set key before initializing"
            raise RuntimeError(msg)

        if self.bins is not None and self.increment is not None:
            msg = "Only one of bins and increment can be specified"
            raise RuntimeError(msg)

        # Use this as an "Ordered Set".  We want the unique detectors on this process,
        # but sorted in order of occurrence.
        all_dets = OrderedDict()

        # Good detectors to use for each observation
        self._obs_dets = dict()

        # Find the binning for each observation and the total detectors on this
        # process.
        self._obs_min = list()
        self._obs_max = list()
        self._obs_incr = list()
        self._obs_nbins = list()
        total_bins = 0
        for iob, ob in enumerate(new_data.obs):
            if self.is_detdata_key:
                if self.key not in ob.detdata:
                    continue
            else:
                if self.key not in ob.shared:
                    continue
            omin = None
            omax = None
            for vw in ob.intervals[self.view].data:
                vw_slc = slice(vw.first, vw.last, 1)
                good = slice(None)
                if self.is_detdata_key:
                    vw_data = ob.detdata[self.key].data[vw_slc]
                    if self.flags is not None:
                        # We have some flags
                        bad = ob.detdata[self.flags].data[vw_slc] & self.flag_mask
                        good = np.logical_not(bad)
                else:
                    vw_data = ob.shared[self.key].data[vw_slc]
                    if self.flags is not None:
                        # We have some flags
                        bad = ob.shared[self.flags].data[vw_slc] & self.flag_mask
                        good = np.logical_not(bad)
                vmin = np.amin(vw_data[good])
                vmax = np.amax(vw_data[good])
                if omin is None:
                    omin = vmin
                    omax = vmax
                else:
                    omin = min(omin, vmin)
                    omax = max(omax, vmax)

            if omin == omax:
                msg = f"Periodic data {self.key} is constant for observation "
                msg += f"{ob.name}"
                raise RuntimeError(msg)
            self._obs_min.append(omin)
            self._obs_max.append(omax)
            if self.bins is not None:
                obins = int(self.bins)
                oincr = (omax - omin) / obins
            else:
                oincr = float(self.increment)
                obins = int((omax - omin) / oincr)
            if obins == 0 and ob.comm.group_rank == 0:
                msg = f"Template {self.name}, obs {ob.name} has zero amplitude bins"
                log.warning(msg)
            total_bins += obins
            self._obs_nbins.append(obins)
            self._obs_incr.append(oincr)

            # Build up detector list
            self._obs_dets[iob] = set()
            for d in ob.select_local_detectors(flagmask=self.det_mask):
                if d not in ob.detdata[self.det_data].detectors:
                    continue
                self._obs_dets[iob].add(d)
                if d not in all_dets:
                    all_dets[d] = None

        self._all_dets = list(all_dets.keys())

        if total_bins == 0:
            msg = f"Template {self.name} process group {new_data.comm.group}"
            msg += f" has zero amplitude bins- change the binning size."
            raise RuntimeError(msg)

        # During application of the template, we will be looping over detectors
        # in the outer loop.  So we pack the amplitudes by detector and then by
        # observation.  Compute the per-detector offsets into the amplitudes.

        self._det_offset = dict()

        offset = 0
        for det in self._all_dets:
            self._det_offset[det] = offset
            for iob, ob in enumerate(new_data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                if self.is_detdata_key:
                    if self.key not in ob.detdata:
                        continue
                else:
                    if self.key not in ob.shared:
                        continue
                offset += self._obs_nbins[iob]

        # Now we know the total number of local amplitudes.

        if offset == 0:
            # This means that no observations included the shared key
            # we are using.
            msg = f"Data has no observations with key '{self.key}'."
            msg += "  You should disable this template."
            log.error(msg)
            raise RuntimeError(msg)

        self._n_local = offset
        if new_data.comm.comm_world is None:
            self._n_global = self._n_local
        else:
            self._n_global = new_data.comm.comm_world.allreduce(
                self._n_local, op=MPI.SUM
            )

        # Go through all the data and compute the number of hits per amplitude
        # bin and the flagging of bins.

        # Boolean flags
        if self._n_local == 0:
            self._amp_flags = None
        else:
            self._amp_flags = np.zeros(self._n_local, dtype=bool)

        # Hits
        if self._n_local == 0:
            self._amp_hits = None
        else:
            self._amp_hits = np.zeros(self._n_local, dtype=np.int32)

        self._obs_bin_hits = list()
        for det in self._all_dets:
            amp_offset = self._det_offset[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                if self.is_detdata_key:
                    if self.key not in ob.detdata:
                        continue
                else:
                    if self.key not in ob.shared:
                        continue
                nbins = self._obs_nbins[iob]
                det_indx = ob.detdata[self.det_data].indices([det])[0]
                amp_hits = self._amp_hits[amp_offset : amp_offset + nbins]
                amp_flags = self._amp_flags[amp_offset : amp_offset + nbins]
                if self.det_flags is not None:
                    flag_indx = ob.detdata[self.det_flags].indices([det])[0]
                else:
                    flag_indx = None
                for vw in ob.intervals[self.view].data:
                    vw_slc = slice(vw.first, vw.last, 1)
                    if self.is_detdata_key:
                        vw_data = ob.detdata[self.key].data[vw_slc]
                    else:
                        vw_data = ob.shared[self.key].data[vw_slc]
                    good, amp_indx = self._view_flags_and_index(
                        det_indx,
                        iob,
                        ob,
                        vw,
                        flag_indx=flag_indx,
                        det_flags=True,
                    )
                    np.add.at(
                        amp_hits,
                        amp_indx,
                        np.ones(len(vw_data[good]), dtype=np.int32),
                    )
                    flag_thresh = amp_hits < self.minimum_bin_hits
                    amp_flags[flag_thresh] = True
                amp_offset += nbins
        return

    def _detectors(self):
        return self._all_dets

    def _zeros(self):
        z = Amplitudes(self.data.comm, self._n_global, self._n_local)
        if z.local_flags is not None:
            z.local_flags[:] = np.where(self._amp_flags, 1, 0)
        return z

    def _view_flags_and_index(
        self, det_indx, ob_indx, ob, view, flag_indx=None, det_flags=False
    ):
        """Get the flags and amplitude indices for one detector and view."""
        vw_slc = slice(view.first, view.last, 1)
        vw_len = view.last - view.first
        incr = self._obs_incr[ob_indx]
        # Determine good samples
        if self.is_detdata_key:
            vw_data = ob.detdata[self.key].data[vw_slc]
            if self.flags is not None:
                # We have some flags
                bad = ob.detdata[self.flags].data[vw_slc] & self.flag_mask
            else:
                bad = np.zeros(vw_len, dtype=np.uint8)
        else:
            vw_data = ob.shared[self.key].data[vw_slc]
            if self.flags is not None:
                # We have some flags
                bad = ob.shared[self.flags].data[vw_slc] & self.flag_mask
            else:
                bad = np.zeros(vw_len, dtype=np.uint8)
        if det_flags and self.det_flags is not None:
            # We have some det flags
            bad |= ob.detdata[self.det_flags][flag_indx, vw_slc] & self.det_flag_mask
        good = np.logical_not(bad)

        # Find the amplitude index for every good sample
        amp_indx = np.array(
            ((vw_data[good] - self._obs_min[ob_indx]) / incr),
            dtype=np.int32,
        )
        overflow = amp_indx >= self._obs_nbins[ob_indx]
        amp_indx[overflow] = self._obs_nbins[ob_indx] - 1

        return good, amp_indx

    def _add_to_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return

        amp_offset = self._det_offset[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            if self.is_detdata_key:
                if self.key not in ob.detdata:
                    continue
            else:
                if self.key not in ob.shared:
                    continue
            nbins = self._obs_nbins[iob]
            det_indx = ob.detdata[self.det_data].indices([detector])[0]
            amps = amplitudes.local[amp_offset : amp_offset + nbins]
            for vw in ob.intervals[self.view].data:
                vw_slc = slice(vw.first, vw.last, 1)
                good, amp_indx = self._view_flags_and_index(
                    det_indx,
                    iob,
                    ob,
                    vw,
                    det_flags=False,
                )
                # Accumulate to timestream
                ob.detdata[self.det_data][det_indx, vw_slc][good] += amps[amp_indx]
            amp_offset += nbins

    def _project_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return

        amp_offset = self._det_offset[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            if self.is_detdata_key:
                if self.key not in ob.detdata:
                    continue
            else:
                if self.key not in ob.shared:
                    continue
            nbins = self._obs_nbins[iob]
            det_indx = ob.detdata[self.det_data].indices([detector])[0]
            amps = amplitudes.local[amp_offset : amp_offset + nbins]
            if self.det_flags is not None:
                flag_indx = ob.detdata[self.det_flags].indices([detector])[0]
            else:
                flag_indx = None
            for vw in ob.intervals[self.view].data:
                vw_slc = slice(vw.first, vw.last, 1)
                good, amp_indx = self._view_flags_and_index(
                    det_indx,
                    iob,
                    ob,
                    vw,
                    flag_indx=flag_indx,
                    det_flags=True,
                )
                # Accumulate to amplitudes
                np.add.at(
                    amps,
                    amp_indx,
                    ob.detdata[self.det_data][det_indx, vw_slc][good],
                )
            amp_offset += nbins

    def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
        # No prior for this template, nothing to accumulate to output.
        return

    def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
        # Apply weights based on the number of samples hitting each
        # amplitude bin.
        for det in self._all_dets:
            amp_offset = self._det_offset[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                if self.is_detdata_key:
                    if self.key not in ob.detdata:
                        continue
                else:
                    if self.key not in ob.shared:
                        continue
                nbins = self._obs_nbins[iob]
                amps_in = amplitudes_in.local[amp_offset : amp_offset + nbins]
                amps_out = amplitudes_out.local[amp_offset : amp_offset + nbins]
                amp_flags = amplitudes_in.local_flags[amp_offset : amp_offset + nbins]
                amp_hits = self._amp_hits[amp_offset : amp_offset + nbins]
                amp_good = amp_flags == 0

                amps_out[amp_good] = amps_in[amp_good] * amp_hits[amp_good]

                amp_offset += nbins

    @function_timer
    def write(self, amplitudes, out):
        """Write out amplitude values.

        This stores the amplitudes to a file for debugging / plotting.

        Args:
            amplitudes (Amplitudes):  The amplitude data.
            out (str):  The output file.

        Returns:
            None

        """
        # By definition, when solving for something that is periodic for
        # a given detector in a single observation, we (should) have many
        # fewer template amplitudes than timestream samples.  Because of
        # this we assume we can make several copies for extracting the
        # amplitudes and gathering them for writing.

        obs_det_amps = dict()

        for det in self._all_dets:
            amp_offset = self._det_offset[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                if self.is_detdata_key:
                    if self.key not in ob.detdata:
                        continue
                else:
                    if self.key not in ob.shared:
                        continue
                if ob.name not in obs_det_amps:
                    obs_det_amps[ob.name] = dict()
                nbins = self._obs_nbins[iob]
                amp_data = amplitudes.local[amp_offset : amp_offset + nbins]
                amp_hits = self._amp_hits[amp_offset : amp_offset + nbins]
                amp_flags = self._amp_flags[amp_offset : amp_offset + nbins]
                obs_det_amps[ob.name][det] = {
                    "amps": amp_data,
                    "hits": amp_hits,
                    "flags": amp_flags,
                    "min": self._obs_min[iob],
                    "max": self._obs_max[iob],
                    "incr": self._obs_incr[iob],
                }
                amp_offset += nbins

        if self.data.comm.world_size == 1:
            all_obs_dets_amps = [obs_det_amps]
        else:
            all_obs_dets_amps = self.data.comm.comm_world.gather(obs_det_amps, root=0)

        if self.data.comm.world_rank == 0:
            obs_det_amps = dict()
            for pdata in all_obs_dets_amps:
                for obname in pdata.keys():
                    if obname not in obs_det_amps:
                        obs_det_amps[obname] = dict()
                    obs_det_amps[obname].update(pdata[obname])
            del all_obs_dets_amps
            with h5py.File(out, "w") as hf:
                for obname, obamps in obs_det_amps.items():
                    n_det = len(obamps)
                    det_list = list(sorted(obamps.keys()))
                    det_indx = {y: x for x, y in enumerate(det_list)}
                    indx_to_det = {det_indx[x]: x for x in det_list}
                    n_amp = len(obamps[det_list[0]]["amps"])
                    amp_min = [obamps[x]["min"] for x in det_list]
                    amp_max = [obamps[x]["max"] for x in det_list]
                    amp_incr = [obamps[x]["incr"] for x in det_list]

                    # Create datasets for this observation
                    hg = hf.create_group(obname)
                    hg.attrs["detectors"] = json.dumps(det_list)
                    hg.attrs["min"] = json.dumps(amp_min)
                    hg.attrs["max"] = json.dumps(amp_max)
                    hg.attrs["incr"] = json.dumps(amp_incr)
                    hamps = hg.create_dataset(
                        "amplitudes",
                        (n_det, n_amp),
                        dtype=np.float64,
                    )
                    hhits = hg.create_dataset(
                        "hits",
                        (n_det, n_amp),
                        dtype=np.int32,
                    )
                    hflags = hg.create_dataset(
                        "flags",
                        (n_det, n_amp),
                        dtype=np.uint8,
                    )

                    # Write data
                    for idet in range(n_det):
                        det = indx_to_det[idet]
                        dprops = obamps[det]
                        hslice = (slice(idet, idet + 1, 1), slice(0, n_amp, 1))
                        dslice = (slice(0, n_amp, 1),)
                        hamps.write_direct(dprops["amps"], dslice, hslice)
                        hhits.write_direct(dprops["hits"], dslice, hslice)
                        hflags.write_direct(dprops["flags"], dslice, hslice)

bins = Int(10, allow_none=True, help='Number of bins between min / max values of data key') class-attribute instance-attribute

flag_mask = Int(0, help='Bit mask value for flags') class-attribute instance-attribute

flags = Unicode(None, allow_none=True, help='Observation data key for flags to use') class-attribute instance-attribute

increment = Float(None, allow_none=True, help='The increment of the data key for each bin') class-attribute instance-attribute

is_detdata_key = Bool(False, help='If True, the periodic data and flags are detector fields, not shared') class-attribute instance-attribute

key = Unicode(None, allow_none=True, help='Observation data key for the periodic quantity') class-attribute instance-attribute

minimum_bin_hits = Int(3, help='Minimum number of samples per amplitude bin') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/templates/periodic.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_add_prior(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/periodic.py
def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
    # No prior for this template, nothing to accumulate to output.
    return

_add_to_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/periodic.py
def _add_to_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return

    amp_offset = self._det_offset[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        if self.is_detdata_key:
            if self.key not in ob.detdata:
                continue
        else:
            if self.key not in ob.shared:
                continue
        nbins = self._obs_nbins[iob]
        det_indx = ob.detdata[self.det_data].indices([detector])[0]
        amps = amplitudes.local[amp_offset : amp_offset + nbins]
        for vw in ob.intervals[self.view].data:
            vw_slc = slice(vw.first, vw.last, 1)
            good, amp_indx = self._view_flags_and_index(
                det_indx,
                iob,
                ob,
                vw,
                det_flags=False,
            )
            # Accumulate to timestream
            ob.detdata[self.det_data][det_indx, vw_slc][good] += amps[amp_indx]
        amp_offset += nbins

_apply_precond(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/periodic.py
def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
    # Apply weights based on the number of samples hitting each
    # amplitude bin.
    for det in self._all_dets:
        amp_offset = self._det_offset[det]
        for iob, ob in enumerate(self.data.obs):
            if det not in self._obs_dets[iob]:
                continue
            if self.is_detdata_key:
                if self.key not in ob.detdata:
                    continue
            else:
                if self.key not in ob.shared:
                    continue
            nbins = self._obs_nbins[iob]
            amps_in = amplitudes_in.local[amp_offset : amp_offset + nbins]
            amps_out = amplitudes_out.local[amp_offset : amp_offset + nbins]
            amp_flags = amplitudes_in.local_flags[amp_offset : amp_offset + nbins]
            amp_hits = self._amp_hits[amp_offset : amp_offset + nbins]
            amp_good = amp_flags == 0

            amps_out[amp_good] = amps_in[amp_good] * amp_hits[amp_good]

            amp_offset += nbins
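
For a binning template the matrix Z^T Z is diagonal, with each entry equal to the number of samples hitting that bin, so scaling the unflagged amplitudes by their hit counts applies exactly that diagonal weight.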

_detectors()

Source code in toast/templates/periodic.py
def _detectors(self):
    return self._all_dets

_initialize(new_data)

Source code in toast/templates/periodic.py
def _initialize(self, new_data):
    log = Logger.get()
    if self.key is None:
        msg = "You must set key before initializing"
        raise RuntimeError(msg)

    if self.bins is not None and self.increment is not None:
        msg = "Only one of bins and increment can be specified"
        raise RuntimeError(msg)

    # Use this as an "Ordered Set".  We want the unique detectors on this process,
    # but sorted in order of occurrence.
    all_dets = OrderedDict()

    # Good detectors to use for each observation
    self._obs_dets = dict()

    # Find the binning for each observation and the total detectors on this
    # process.
    self._obs_min = list()
    self._obs_max = list()
    self._obs_incr = list()
    self._obs_nbins = list()
    total_bins = 0
    for iob, ob in enumerate(new_data.obs):
        if self.is_detdata_key:
            if self.key not in ob.detdata:
                continue
        else:
            if self.key not in ob.shared:
                continue
        omin = None
        omax = None
        for vw in ob.intervals[self.view].data:
            vw_slc = slice(vw.first, vw.last, 1)
            good = slice(None)
            if self.is_detdata_key:
                vw_data = ob.detdata[self.key].data[vw_slc]
                if self.flags is not None:
                    # We have some flags
                    bad = ob.detdata[self.flags].data[vw_slc] & self.flag_mask
                    good = np.logical_not(bad)
            else:
                vw_data = ob.shared[self.key].data[vw_slc]
                if self.flags is not None:
                    # We have some flags
                    bad = ob.shared[self.flags].data[vw_slc] & self.flag_mask
                    good = np.logical_not(bad)
            vmin = np.amin(vw_data[good])
            vmax = np.amax(vw_data[good])
            if omin is None:
                omin = vmin
                omax = vmax
            else:
                omin = min(omin, vmin)
                omax = max(omax, vmax)

        if omin == omax:
            msg = f"Periodic data {self.key} is constant for observation "
            msg += f"{ob.name}"
            raise RuntimeError(msg)
        self._obs_min.append(omin)
        self._obs_max.append(omax)
        if self.bins is not None:
            obins = int(self.bins)
            oincr = (omax - omin) / obins
        else:
            oincr = float(self.increment)
            obins = int((omax - omin) / oincr)
        if obins == 0 and ob.comm.group_rank == 0:
            msg = f"Template {self.name}, obs {ob.name} has zero amplitude bins"
            log.warning(msg)
        total_bins += obins
        self._obs_nbins.append(obins)
        self._obs_incr.append(oincr)

        # Build up detector list
        self._obs_dets[iob] = set()
        for d in ob.select_local_detectors(flagmask=self.det_mask):
            if d not in ob.detdata[self.det_data].detectors:
                continue
            self._obs_dets[iob].add(d)
            if d not in all_dets:
                all_dets[d] = None

    self._all_dets = list(all_dets.keys())

    if total_bins == 0:
        msg = f"Template {self.name} process group {new_data.comm.group}"
        msg += f" has zero amplitude bins- change the binning size."
        raise RuntimeError(msg)

    # During application of the template, we will be looping over detectors
    # in the outer loop.  So we pack the amplitudes by detector and then by
    # observation.  Compute the per-detector offsets into the amplitudes.

    self._det_offset = dict()

    offset = 0
    for det in self._all_dets:
        self._det_offset[det] = offset
        for iob, ob in enumerate(new_data.obs):
            if det not in self._obs_dets[iob]:
                continue
            if self.is_detdata_key:
                if self.key not in ob.detdata:
                    continue
            else:
                if self.key not in ob.shared:
                    continue
            offset += self._obs_nbins[iob]

    # Now we know the total number of local amplitudes.

    if offset == 0:
        # This means that no observations included the shared key
        # we are using.
        msg = f"Data has no observations with key '{self.key}'."
        msg += "  You should disable this template."
        log.error(msg)
        raise RuntimeError(msg)

    self._n_local = offset
    if new_data.comm.comm_world is None:
        self._n_global = self._n_local
    else:
        self._n_global = new_data.comm.comm_world.allreduce(
            self._n_local, op=MPI.SUM
        )

    # Go through all the data and compute the number of hits per amplitude
    # bin and the flagging of bins.

    # Boolean flags
    if self._n_local == 0:
        self._amp_flags = None
    else:
        self._amp_flags = np.zeros(self._n_local, dtype=bool)

    # Hits
    if self._n_local == 0:
        self._amp_hits = None
    else:
        self._amp_hits = np.zeros(self._n_local, dtype=np.int32)

    self._obs_bin_hits = list()
    for det in self._all_dets:
        amp_offset = self._det_offset[det]
        for iob, ob in enumerate(self.data.obs):
            if det not in self._obs_dets[iob]:
                continue
            if self.is_detdata_key:
                if self.key not in ob.detdata:
                    continue
            else:
                if self.key not in ob.shared:
                    continue
            nbins = self._obs_nbins[iob]
            det_indx = ob.detdata[self.det_data].indices([det])[0]
            amp_hits = self._amp_hits[amp_offset : amp_offset + nbins]
            amp_flags = self._amp_flags[amp_offset : amp_offset + nbins]
            if self.det_flags is not None:
                flag_indx = ob.detdata[self.det_flags].indices([det])[0]
            else:
                flag_indx = None
            for vw in ob.intervals[self.view].data:
                vw_slc = slice(vw.first, vw.last, 1)
                if self.is_detdata_key:
                    vw_data = ob.detdata[self.key].data[vw_slc]
                else:
                    vw_data = ob.shared[self.key].data[vw_slc]
                good, amp_indx = self._view_flags_and_index(
                    det_indx,
                    iob,
                    ob,
                    vw,
                    flag_indx=flag_indx,
                    det_flags=True,
                )
                np.add.at(
                    amp_hits,
                    amp_indx,
                    np.ones(len(vw_data[good]), dtype=np.int32),
                )
                flag_thresh = amp_hits < self.minimum_bin_hits
                amp_flags[flag_thresh] = True
            amp_offset += nbins
    return

_project_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/periodic.py
def _project_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return

    amp_offset = self._det_offset[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        if self.is_detdata_key:
            if self.key not in ob.detdata:
                continue
        else:
            if self.key not in ob.shared:
                continue
        nbins = self._obs_nbins[iob]
        det_indx = ob.detdata[self.det_data].indices([detector])[0]
        amps = amplitudes.local[amp_offset : amp_offset + nbins]
        if self.det_flags is not None:
            flag_indx = ob.detdata[self.det_flags].indices([detector])[0]
        else:
            flag_indx = None
        for vw in ob.intervals[self.view].data:
            vw_slc = slice(vw.first, vw.last, 1)
            good, amp_indx = self._view_flags_and_index(
                det_indx,
                iob,
                ob,
                vw,
                flag_indx=flag_indx,
                det_flags=True,
            )
            # Accumulate to amplitudes
            np.add.at(
                amps,
                amp_indx,
                ob.detdata[self.det_data][det_indx, vw_slc][good],
            )
        amp_offset += nbins
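
Note the asymmetry with _add_to_signal above: gathering amplitudes with fancy indexing (amps[amp_indx]) is safe even when indices repeat, but the projection must scatter-add, so it uses np.add.at, which accumulates correctly when several samples fall in the same bin (a plain amps[amp_indx] += ... would silently drop the duplicates).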

_view_flags_and_index(det_indx, ob_indx, ob, view, flag_indx=None, det_flags=False)

Get the flags and amplitude indices for one detector and view.

Source code in toast/templates/periodic.py
def _view_flags_and_index(
    self, det_indx, ob_indx, ob, view, flag_indx=None, det_flags=False
):
    """Get the flags and amplitude indices for one detector and view."""
    vw_slc = slice(view.first, view.last, 1)
    vw_len = view.last - view.first
    incr = self._obs_incr[ob_indx]
    # Determine good samples
    if self.is_detdata_key:
        vw_data = ob.detdata[self.key].data[vw_slc]
        if self.flags is not None:
            # We have some flags
            bad = ob.detdata[self.flags].data[vw_slc] & self.flag_mask
        else:
            bad = np.zeros(vw_len, dtype=np.uint8)
    else:
        vw_data = ob.shared[self.key].data[vw_slc]
        if self.flags is not None:
            # We have some flags
            bad = ob.shared[self.flags].data[vw_slc] & self.flag_mask
        else:
            bad = np.zeros(vw_len, dtype=np.uint8)
    if det_flags and self.det_flags is not None:
        # We have some det flags
        bad |= ob.detdata[self.det_flags][flag_indx, vw_slc] & self.det_flag_mask
    good = np.logical_not(bad)

    # Find the amplitude index for every good sample
    amp_indx = np.array(
        ((vw_data[good] - self._obs_min[ob_indx]) / incr),
        dtype=np.int32,
    )
    overflow = amp_indx >= self._obs_nbins[ob_indx]
    amp_indx[overflow] = self._obs_nbins[ob_indx] - 1

    return good, amp_indx
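
The index computation at the end is the heart of the template: each good sample maps to bin int((value - min) / incr), clamped into the last bin at the top edge. A standalone illustration with made-up values:

import numpy as np

omin, n_bins = 0.0, 10
incr = (10.0 - omin) / n_bins                   # 1.0 per bin
values = np.array([0.2, 4.9, 9.999, 10.0])
amp_indx = ((values - omin) / incr).astype(np.int32)
amp_indx[amp_indx >= n_bins] = n_bins - 1       # clamp the top edge
# amp_indx -> [0, 4, 9, 9]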

_zeros()

Source code in toast/templates/periodic.py
def _zeros(self):
    z = Amplitudes(self.data.comm, self._n_global, self._n_local)
    if z.local_flags is not None:
        z.local_flags[:] = np.where(self._amp_flags, 1, 0)
    return z

write(amplitudes, out)

Write out amplitude values.

This stores the amplitudes to a file for debugging / plotting.

Parameters:

    amplitudes (Amplitudes): The amplitude data. (required)
    out (str): The output file. (required)

Returns:

    None

Source code in toast/templates/periodic.py
@function_timer
def write(self, amplitudes, out):
    """Write out amplitude values.

    This stores the amplitudes to a file for debugging / plotting.

    Args:
        amplitudes (Amplitudes):  The amplitude data.
        out (str):  The output file.

    Returns:
        None

    """
    # By definition, when solving for something that is periodic for
    # a given detector in a single observation, we (should) have many
    # fewer template amplitudes than timestream samples.  Because of
    # this we assume we can make several copies for extracting the
    # amplitudes and gathering them for writing.

    obs_det_amps = dict()

    for det in self._all_dets:
        amp_offset = self._det_offset[det]
        for iob, ob in enumerate(self.data.obs):
            if det not in self._obs_dets[iob]:
                continue
            if self.is_detdata_key:
                if self.key not in ob.detdata:
                    continue
            else:
                if self.key not in ob.shared:
                    continue
            if ob.name not in obs_det_amps:
                obs_det_amps[ob.name] = dict()
            nbins = self._obs_nbins[iob]
            amp_data = amplitudes.local[amp_offset : amp_offset + nbins]
            amp_hits = self._amp_hits[amp_offset : amp_offset + nbins]
            amp_flags = self._amp_flags[amp_offset : amp_offset + nbins]
            obs_det_amps[ob.name][det] = {
                "amps": amp_data,
                "hits": amp_hits,
                "flags": amp_flags,
                "min": self._obs_min[iob],
                "max": self._obs_max[iob],
                "incr": self._obs_incr[iob],
            }
            amp_offset += nbins

    if self.data.comm.world_size == 1:
        all_obs_dets_amps = [obs_det_amps]
    else:
        all_obs_dets_amps = self.data.comm.comm_world.gather(obs_det_amps, root=0)

    if self.data.comm.world_rank == 0:
        obs_det_amps = dict()
        for pdata in all_obs_dets_amps:
            for obname in pdata.keys():
                if obname not in obs_det_amps:
                    obs_det_amps[obname] = dict()
                obs_det_amps[obname].update(pdata[obname])
        del all_obs_dets_amps
        with h5py.File(out, "w") as hf:
            for obname, obamps in obs_det_amps.items():
                n_det = len(obamps)
                det_list = list(sorted(obamps.keys()))
                det_indx = {y: x for x, y in enumerate(det_list)}
                indx_to_det = {det_indx[x]: x for x in det_list}
                n_amp = len(obamps[det_list[0]]["amps"])
                amp_min = [obamps[x]["min"] for x in det_list]
                amp_max = [obamps[x]["max"] for x in det_list]
                amp_incr = [obamps[x]["incr"] for x in det_list]

                # Create datasets for this observation
                hg = hf.create_group(obname)
                hg.attrs["detectors"] = json.dumps(det_list)
                hg.attrs["min"] = json.dumps(amp_min)
                hg.attrs["max"] = json.dumps(amp_max)
                hg.attrs["incr"] = json.dumps(amp_incr)
                hamps = hg.create_dataset(
                    "amplitudes",
                    (n_det, n_amp),
                    dtype=np.float64,
                )
                hhits = hg.create_dataset(
                    "hits",
                    (n_det, n_amp),
                    dtype=np.int32,
                )
                hflags = hg.create_dataset(
                    "flags",
                    (n_det, n_amp),
                    dtype=np.uint8,
                )

                # Write data
                for idet in range(n_det):
                    det = indx_to_det[idet]
                    dprops = obamps[det]
                    hslice = (slice(idet, idet + 1, 1), slice(0, n_amp, 1))
                    dslice = (slice(0, n_amp, 1),)
                    hamps.write_direct(dprops["amps"], dslice, hslice)
                    hhits.write_direct(dprops["hits"], dslice, hslice)
                    hflags.write_direct(dprops["flags"], dslice, hslice)
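
Unlike Offset.write, which writes one file per process group, this gathers amplitudes over the whole world communicator and writes a single file with one HDF5 group per observation; detector rows within each group follow the sorted detector list stored in the group's "detectors" attribute.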

toast.templates.SubHarmonic

Bases: Template

This class represents sub-harmonic noise fluctuations.

Sub-harmonic means that the characteristic frequency of the noise modes is lower than 1/T where T is the length of the interval being fitted.

Every process stores the amplitudes for its local data, which is disjoint from the amplitudes on other processes. We project amplitudes one detector at a time, and so we arrange our template amplitudes in "detector major" order and store offsets into this for each observation.
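
The polynomial templates built in _initialize below are evaluated on r in [-1, 1] across each view, with order 0 constant and order 1 linear; the higher orders (the recursion is truncated in this excerpt) appear to follow the standard Legendre recurrence. A sketch of that construction, assuming the Bonnet form n P_n = (2n - 1) r P_{n-1} - (n - 1) P_{n-2}:

import numpy as np

def legendre_templates(view_len, norder):
    # Evaluate P_0 .. P_{norder-1} on a uniform grid over [-1, 1]
    r = np.linspace(-1.0, 1.0, view_len)
    templates = np.zeros((norder, view_len), dtype=np.float64)
    templates[0] = 1.0
    if norder > 1:
        templates[1] = r
    for n in range(2, norder):
        templates[n] = (
            (2 * n - 1) * r * templates[n - 1] - (n - 1) * templates[n - 2]
        ) / n
    return templates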

Source code in toast/templates/subharmonic.py
@trait_docs
class SubHarmonic(Template):
    """This class represents sub-harmonic noise fluctuations.

    Sub-harmonic means that the characteristic frequency of the noise
    modes is lower than 1/T where T is the length of the interval
    being fitted.

    Every process stores the amplitudes for its local data, which is disjoint from the
    amplitudes on other processes.  We project amplitudes one detector at a time, and
    so we arrange our template amplitudes in "detector major" order and store offsets
    into this for each observation.

    """

    # Notes:  The TraitConfig base class defines a "name" attribute.  The Template
    # class (derived from TraitConfig) defines the following traits already:
    #    data             : The Data instance we are working with
    #    view             : The timestream view we are using
    #    det_data         : The detector data key with the timestreams
    #    det_data_units   : The units of the detector data
    #    det_mask         : Bitmask for per-detector flagging
    #    det_flags        : Optional detector solver flags
    #    det_flag_mask    : Bit mask for detector solver flags
    #

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    order = Int(1, help="The filter order")

    noise_model = Unicode(
        None,
        allow_none=True,
        help="Observation key containing the optional noise model",
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _initialize(self, new_data):
        # Use this as an "Ordered Set".  We want the unique detectors on this process,
        # but sorted in order of occurrence.
        all_dets = OrderedDict()

        # Good detectors to use for each observation
        self._obs_dets = dict()

        # Build up detector list
        for iob, ob in enumerate(new_data.obs):
            self._obs_dets[iob] = set()
            for d in ob.select_local_detectors(flagmask=self.det_mask):
                if d not in ob.detdata[self.det_data].detectors:
                    continue
                self._obs_dets[iob].add(d)
                if d not in all_dets:
                    all_dets[d] = None

        self._all_dets = list(all_dets.keys())

        # The inverse variance units
        invvar_units = 1.0 / (self.det_data_units**2)

        # Go through the data one local detector at a time and compute the offsets into
        # the amplitudes.

        # The starting amplitude for each detector within the local amplitude data.
        self._det_start = dict()

        offset = 0
        for det in self._all_dets:
            self._det_start[det] = offset
            for iob, ob in enumerate(new_data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                # We have one set of amplitudes for each detector in each view
                offset += len(ob.view[self.view]) * (self.order + 1)

        # Now we know the total number of amplitudes.

        self._n_local = offset
        if new_data.comm.comm_world is None:
            self._n_global = self._n_local
        else:
            self._n_global = new_data.comm.comm_world.allreduce(
                self._n_local, op=MPI.SUM
            )

        # The templates for each view of each obs
        self._templates = dict()

        # The preconditioner for each obs / view / detector
        self._precond = dict()

        # We are not constructing any data objects that are in the same order as the
        # amplitudes (we are just building dictionaries for lookups).  In this case,
        # it is easier to just build these by looping in observation order rather than
        # detector order.

        for iob, ob in enumerate(new_data.obs):
            # Build the templates and preconditioners for every view.
            self._templates[iob] = list()
            self._precond[iob] = dict()
            norder = self.order + 1

            noise = None
            if self.noise_model in ob:
                noise = ob[self.noise_model]

            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                view_len = None
                if vw.start is None:
                    # This is a view of the whole obs
                    view_len = ob.n_local_samples
                else:
                    view_len = vw.stop - vw.start

                templates = np.zeros((norder, view_len), dtype=np.float64)
                r = np.linspace(-1.0, 1.0, view_len)
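                # The sub-harmonic basis: Legendre polynomials P_0..P_order
                # on [-1, 1], generated with Bonnet's recursion below.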
                for order in range(norder):
                    if order == 0:
                        templates[order] = 1.0
                    elif order == 1:
                        templates[order] = r
                    else:
                        templates[order] = (
                            (2 * order - 1) * r * templates[order - 1]
                            - (order - 1) * templates[order - 2]
                        ) / order
                self._templates[iob].append(templates)

                self._precond[iob][ivw] = dict()
                for det in ob.local_detectors:
                    if det not in self._obs_dets[iob]:
                        continue
                    detweight = 1.0
                    if noise is not None:
                        detweight = noise.detector_weight(det).to_value(invvar_units)

                    good = slice(0, view_len, 1)
                    if self.det_flags is not None:
                        flags = views.detdata[self.det_flags][ivw][det]
                        good = (flags & self.det_flag_mask) == 0

                    prec = np.zeros((norder, norder), dtype=np.float64)
                    for row in range(norder):
                        for col in range(row, norder):
                            prec[row, col] = np.dot(
                                templates[row][good], templates[col][good]
                            )
                            prec[row, col] *= detweight
                            if row != col:
                                prec[col, row] = prec[row, col]
                    self._precond[iob][ivw][det] = np.linalg.inv(prec)

    def _detectors(self):
        return self._all_dets

    def _zeros(self):
        z = Amplitudes(self.data.comm, self._n_global, self._n_local)
        # No explicit flagging of amplitudes in this template...
        # z.local_flags[:] = np.where(self._amp_flags, 1, 0)
        return z

    def _add_to_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        norder = self.order + 1
        offset = self._det_start[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
                amp_view = amplitudes.local[offset : offset + norder]
                for order in range(norder):
                    vw[detector] += self._templates[iob][ivw][order] * amp_view[order]
                offset += norder

    def _project_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        norder = self.order + 1
        offset = self._det_start[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
                amp_view = amplitudes.local[offset : offset + norder]
                for order, template in enumerate(self._templates[iob][ivw]):
                    amp_view[order] = np.dot(vw[detector], template)
                offset += norder

    def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
        # No prior for this template, nothing to accumulate to output.
        return

    def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
        norder = self.order + 1
        for det in self._all_dets:
            offset = self._det_start[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                views = ob.view[self.view]
                for ivw, vw in enumerate(views):
                    amps_in = amplitudes_in.local[offset : offset + norder]
                    amps_out = amplitudes_out.local[offset : offset + norder]
                    amps_out[:] = np.dot(self._precond[iob][ivw][det], amps_in)
                    offset += norder

noise_model = Unicode(None, allow_none=True, help='Observation key containing the optional noise model') class-attribute instance-attribute

order = Int(1, help='The filter order') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/templates/subharmonic.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_add_prior(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/subharmonic.py
def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
    # No prior for this template, nothing to accumulate to output.
    return

_add_to_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/subharmonic.py
def _add_to_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    norder = self.order + 1
    offset = self._det_start[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
            amp_view = amplitudes.local[offset : offset + norder]
            for order in range(norder):
                vw[detector] += self._templates[iob][ivw][order] * amp_view[order]
            offset += norder

_apply_precond(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/subharmonic.py
def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
    norder = self.order + 1
    for det in self._all_dets:
        offset = self._det_start[det]
        for iob, ob in enumerate(self.data.obs):
            if det not in self._obs_dets[iob]:
                continue
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                amps_in = amplitudes_in.local[offset : offset + norder]
                amps_out = amplitudes_out.local[offset : offset + norder]
                amps_out[:] = np.dot(self._precond[iob][ivw][det], amps_in)
                offset += norder

_detectors()

Source code in toast/templates/subharmonic.py
def _detectors(self):
    return self._all_dets

_initialize(new_data)

Source code in toast/templates/subharmonic.py
def _initialize(self, new_data):
    # Use this as an "Ordered Set".  We want the unique detectors on this process,
    # but sorted in order of occurrence.
    all_dets = OrderedDict()

    # Good detectors to use for each observation
    self._obs_dets = dict()

    # Build up detector list
    for iob, ob in enumerate(new_data.obs):
        self._obs_dets[iob] = set()
        for d in ob.select_local_detectors(flagmask=self.det_mask):
            if d not in ob.detdata[self.det_data].detectors:
                continue
            self._obs_dets[iob].add(d)
            if d not in all_dets:
                all_dets[d] = None

    self._all_dets = list(all_dets.keys())

    # The inverse variance units
    invvar_units = 1.0 / (self.det_data_units**2)

    # Go through the data one local detector at a time and compute the offsets into
    # the amplitudes.

    # The starting amplitude for each detector within the local amplitude data.
    self._det_start = dict()

    offset = 0
    for det in self._all_dets:
        self._det_start[det] = offset
        for iob, ob in enumerate(new_data.obs):
            if det not in self._obs_dets[iob]:
                continue
            # We have one set of amplitudes for each detector in each view
            offset += len(ob.view[self.view]) * (self.order + 1)

    # Now we know the total number of amplitudes.

    self._n_local = offset
    if new_data.comm.comm_world is None:
        self._n_global = self._n_local
    else:
        self._n_global = new_data.comm.comm_world.allreduce(
            self._n_local, op=MPI.SUM
        )

    # The templates for each view of each obs
    self._templates = dict()

    # The preconditioner for each obs / view / detector
    self._precond = dict()

    # We are not constructing any data objects that are in the same order as the
    # amplitudes (we are just building dictionaries for lookups).  In this case,
    # it is easier to just build these by looping in observation order rather than
    # detector order.

    for iob, ob in enumerate(new_data.obs):
        # Build the templates and preconditioners for every view.
        self._templates[iob] = list()
        self._precond[iob] = dict()
        norder = self.order + 1

        noise = None
        if self.noise_model in ob:
            noise = ob[self.noise_model]

        views = ob.view[self.view]
        for ivw, vw in enumerate(views):
            view_len = None
            if vw.start is None:
                # This is a view of the whole obs
                view_len = ob.n_local_samples
            else:
                view_len = vw.stop - vw.start

            templates = np.zeros((norder, view_len), dtype=np.float64)
            r = np.linspace(-1.0, 1.0, view_len)
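            # The sub-harmonic basis: Legendre polynomials P_0..P_order
            # on [-1, 1], generated with Bonnet's recursion below.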
            for order in range(norder):
                if order == 0:
                    templates[order] = 1.0
                elif order == 1:
                    templates[order] = r
                else:
                    templates[order] = (
                        (2 * order - 1) * r * templates[order - 1]
                        - (order - 1) * templates[order - 2]
                    ) / order
            self._templates[iob].append(templates)

            self._precond[iob][ivw] = dict()
            for det in ob.local_detectors:
                if det not in self._obs_dets[iob]:
                    continue
                detweight = 1.0
                if noise is not None:
                    detweight = noise.detector_weight(det).to_value(invvar_units)

                good = slice(0, view_len, 1)
                if self.det_flags is not None:
                    flags = views.detdata[self.det_flags][ivw][det]
                    good = (flags & self.det_flag_mask) == 0

                prec = np.zeros((norder, norder), dtype=np.float64)
                for row in range(norder):
                    for col in range(row, norder):
                        prec[row, col] = np.dot(
                            templates[row][good], templates[col][good]
                        )
                        prec[row, col] *= detweight
                        if row != col:
                            prec[col, row] = prec[row, col]
                self._precond[iob][ivw][det] = np.linalg.inv(prec)

_project_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/subharmonic.py
def _project_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    norder = self.order + 1
    offset = self._det_start[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
            amp_view = amplitudes.local[offset : offset + norder]
            for order, template in enumerate(self._templates[iob][ivw]):
                amp_view[order] = np.dot(vw[detector], template)
            offset += norder

_zeros()

Source code in toast/templates/subharmonic.py
def _zeros(self):
    z = Amplitudes(self.data.comm, self._n_global, self._n_local)
    # No explicit flagging of amplitudes in this template...
    # z.local_flags[:] = np.where(self._amp_flags, 1, 0)
    return z

toast.templates.Fourier2D

Bases: Template

This class models 2D Fourier modes across the focalplane.

Since the modes are shared across detectors, our amplitudes are organized by observation and views within each observation. Each detector projection will traverse all the local amplitudes.
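
To make the amplitude bookkeeping concrete, here is a small sketch (not from the library source) of the mode counting performed in _initialize below: one constant mode, two optional sub-harmonic modes, and (2 * order)^2 cross terms between the sine/cosine harmonics in the two focalplane angles.

# Sketch only: reproduce the Fourier2D mode count from _initialize.
def fourier2d_nmode(order, fit_subharmonics=True):
    nmode = (2 * order) ** 2 + 1  # cross terms plus the constant mode
    if fit_subharmonics:
        nmode += 2  # two linear sub-harmonic modes
    return nmode

# With the defaults (order=1, fit_subharmonics=True) there are 7 modes,
# so a view of N local samples owns N * 7 template amplitudes.
print(fourier2d_nmode(1))  # 7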

Source code in toast/templates/fourier2d.py
@trait_docs
class Fourier2D(Template):
    """This class models 2D Fourier modes across the focalplane.

    Since the modes are shared across detectors, our amplitudes are organized by
    observation and views within each observation.  Each detector projection
    will traverse all the local amplitudes.

    """

    # Notes:  The TraitConfig base class defines a "name" attribute.  The Template
    # class (derived from TraitConfig) defines the following traits already:
    #    data             : The Data instance we are working with
    #    view             : The timestream view we are using
    #    det_data         : The detector data key with the timestreams
    #    det_data_units   : The units of the detector data
    #    det_mask         : Bitmask for per-detector flagging
    #    det_flags        : Optional detector solver flags
    #    det_flag_mask    : Bit mask for detector solver flags
    #

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    correlation_length = Quantity(10.0 * u.second, help="Correlation length in time")

    correlation_amplitude = Float(10.0, help="Scale factor of the filter")

    order = Int(1, help="The filter order")

    fit_subharmonics = Bool(True, help="If True, fit subharmonics")

    noise_model = Unicode(
        None,
        allow_none=True,
        help="Observation key containing the optional noise model",
    )

    debug_plots = Unicode(
        None,
        allow_none=True,
        help="If not None, make debugging plots in this directory",
    )

    @traitlets.validate("order")
    def _check_order(self, proposal):
        od = proposal["value"]
        if od < 1:
            raise traitlets.TraitError("Filter order should be >= 1")
        return od

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def clear(self):
        """Delete the underlying C-allocated memory."""
        if hasattr(self, "_norms"):
            del self._norms
        if hasattr(self, "_norms_raw"):
            self._norms_raw.clear()
            del self._norms_raw

    def __del__(self):
        self.clear()

    def _initialize(self, new_data):
        zaxis = np.array([0.0, 0.0, 1.0])

        # This function is called whenever a new data trait is assigned to the template.
        # Clear any C-allocated buffers from previous uses.
        self.clear()

        self._norder = self.order + 1
        self._nmode = (2 * self.order) ** 2 + 1
        if self.fit_subharmonics:
            self._nmode += 2

        # The inverse variance units
        invvar_units = 1.0 / (self.det_data_units**2)

        # Every process determines their local amplitude ranges.

        # The local ranges of amplitudes (in terms of global indices)
        self._local_ranges = list()

        # Starting local amplitude for each view within each obs
        self._obs_view_local_offset = dict()

        # Starting global amplitude for each view within each obs
        self._obs_view_global_offset = dict()

        # Number of amplitudes in each local view for each obs
        self._obs_view_namp = dict()

        # This is the total number of amplitudes for each observation, across all
        # views.
        self._obs_total_namp = dict()

        # Use this as an "Ordered Set".  We want the unique detectors on this process,
        # but sorted in order of occurrence.
        all_dets = OrderedDict()

        # Good detectors to use for each observation
        self._obs_dets = dict()

        local_offset = 0
        global_offset = 0

        for iob, ob in enumerate(new_data.obs):
            self._obs_view_namp[iob] = list()
            self._obs_view_local_offset[iob] = list()
            self._obs_view_global_offset[iob] = list()

            # Build up detector list
            self._obs_dets[iob] = set()
            for d in ob.select_local_detectors(flagmask=self.det_mask):
                if d not in ob.detdata[self.det_data].detectors:
                    continue
                self._obs_dets[iob].add(d)
                if d not in all_dets:
                    all_dets[d] = None

            obs_n_amp = 0

            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                # First obs sample of this view
                obs_start = ob.local_index_offset

                view_len = None
                if vw.start is None:
                    # This is a view of the whole obs
                    view_len = ob.n_local_samples
                else:
                    view_len = vw.stop - vw.start
                    obs_start += vw.start

                obs_offset = obs_start * self._nmode

                self._obs_view_local_offset[iob].append(local_offset)
                self._obs_view_global_offset[iob].append(global_offset + obs_offset)

                view_n_amp = view_len * self._nmode
                obs_n_amp += view_n_amp
                self._obs_view_namp[iob].append(view_n_amp)

                self._local_ranges.append((global_offset + obs_offset, view_n_amp))

                local_offset += view_n_amp

            # To get the total number of amplitudes in this observation, we must
            # accumulate across the grid row communicator.
            if ob.comm_row is not None:
                obs_n_amp = ob.comm_row.allreduce(obs_n_amp)
            self._obs_total_namp[iob] = obs_n_amp
            global_offset += obs_n_amp

        self._all_dets = list(all_dets.keys())

        # The global number of amplitudes for our process group and our local process.
        # Since different groups have different observations, their amplitude values
        # are completely disjoint.  We create Amplitudes with the `use_group` option
        # and so only have to consider the full set if we are doing things like I/O
        # (i.e. nothing needed by this class).

        self._n_global = np.sum(
            [self._obs_total_namp[x] for x, y in enumerate(new_data.obs)]
        )

        self._n_local = np.sum([x[1] for x in self._local_ranges])

        # Allocate norms.  This data is the same size as a set of amplitudes,
        # so we allocate it in C memory.

        if self._n_local == 0:
            self._norms_raw = None
            self._norms = None
        else:
            self._norms_raw = AlignedF64.zeros(self._n_local)
            self._norms = self._norms_raw.array()

        def evaluate_template(theta, phi, radius):
            """Helper function to get the template values for a detector."""
            values = np.zeros(self._nmode)
            values[0] = 1
            offset = 1
            if self.fit_subharmonics:
                values[1:3] = theta / radius, phi / radius
                offset += 2
            if self.order > 0:
                rinv = np.pi / radius
                orders = np.arange(self.order) + 1
                thetavec = np.zeros(self.order * 2)
                phivec = np.zeros(self.order * 2)
                thetavec[::2] = np.cos(orders * theta * rinv)
                thetavec[1::2] = np.sin(orders * theta * rinv)
                phivec[::2] = np.cos(orders * phi * rinv)
                phivec[1::2] = np.sin(orders * phi * rinv)
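                # The 2D modes are all products of the theta and phi
                # harmonics: (2 * order)^2 terms in total.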
                values[offset:] = np.outer(thetavec, phivec).ravel()
            return values

        # The detector templates for each observation
        self._templates = dict()

        # The noise filter for each observation
        self._filters = dict()

        for iob, ob in enumerate(new_data.obs):
            # Focalplane for this observation
            fp = ob.telescope.focalplane

            # Focalplane radius
            radius = 0.5 * fp.field_of_view.to_value(u.radian)

            noise = None
            if self.noise_model in ob:
                noise = ob[self.noise_model]

            self._templates[iob] = list()
            self._filters[iob] = list()

            obs_local_namp = 0

            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                view_len = None
                if vw.start is None:
                    # This is a view of the whole obs
                    view_len = ob.n_local_samples
                else:
                    view_len = vw.stop - vw.start

                # Build the filter for this view

                corr_len = self.correlation_length.to_value(u.second)
                times = views.shared[self.times][ivw]
                corr = (
                    np.exp((times[0] - times) / corr_len) * self.correlation_amplitude
                )
                ihalf = times.size // 2
                if times.size % 2 == 0:
                    corr[ihalf:] = corr[ihalf - 1 :: -1]
                else:
                    corr[ihalf + 1 :] = corr[ihalf - 1 :: -1]

                fcorr = np.fft.rfft(corr)
                fcorr_orig = np.copy(fcorr)
                too_small = fcorr < (1.0e-6 * self.correlation_amplitude)
                fcorr[too_small] = 1.0e-6 * self.correlation_amplitude
                invcorr = np.fft.irfft(1 / fcorr)
                self._filters[iob].append(invcorr)

                if self.debug_plots is not None and ob.comm.group_rank == 0:
                    os.makedirs(self.debug_plots, exist_ok=True)
                    set_matplotlib_backend(backend="pdf")

                    import matplotlib.pyplot as plt

                    figdpi = 100
                    plotfile = os.path.join(
                        self.debug_plots, f"f2d_{ob.name}_{ivw}_filter.pdf"
                    )
                    fig = plt.figure(dpi=figdpi, figsize=(8, 12))
                    xdata = np.arange(len(corr))
                    fxdata = np.arange(len(fcorr))
                    ax = fig.add_subplot(3, 1, 1)
                    ax.plot(xdata, corr, label=f"Input Corr")
                    ax.set_yscale("log")
                    ax.set_xlabel("Sample")
                    ax.set_ylabel("Amplitude")
                    ax.legend(loc="best")
                    ax = fig.add_subplot(3, 1, 2)
                    ax.plot(fxdata, fcorr_orig, label=f"Fcorr original")
                    ax.plot(fxdata, fcorr, label=f"Fcorr cut")
                    ax.set_yscale("log")
                    ax.set_xlabel("Frequency")
                    ax.set_ylabel("Amplitude")
                    ax.legend(loc="best")
                    ax = fig.add_subplot(3, 1, 3)
                    ax.plot(xdata, invcorr, label=f"Inverse Corr")
                    ax.set_xlabel("Sample")
                    ax.set_ylabel("Amplitude")
                    ax.legend(loc="best")
                    plt.savefig(plotfile, dpi=figdpi, bbox_inches="tight", format="pdf")
                    plt.close()

                # Now compute templates and norm for this view

                view_templates = dict()

                good = np.empty(view_len, dtype=np.float64)
                norm_slice = slice(
                    self._obs_view_local_offset[iob][ivw],
                    self._obs_view_local_offset[iob][ivw]
                    + self._obs_view_namp[iob][ivw],
                    1,
                )
                norms_view = self._norms[norm_slice].reshape((-1, self._nmode))

                for det in ob.local_detectors:
                    if det not in self._obs_dets[iob]:
                        continue
                    detweight = 1.0
                    if noise is not None:
                        detweight = noise.detector_weight(det).to_value(invvar_units)
                    det_quat = fp[det]["quat"]
                    x, y, z = qa.rotate(det_quat, zaxis)
                    theta, phi = np.arcsin([x, y])
                    view_templates[det] = evaluate_template(theta, phi, radius)

                    good[:] = 1.0
                    if self.det_flags is not None:
                        flags = views.detdata[self.det_flags][ivw][det]
                        good[(flags & self.det_flag_mask) != 0] = 0
                    norms_view += np.outer(good, view_templates[det] ** 2 * detweight)

                obs_local_namp += self._obs_view_namp[iob][ivw]
                self._templates[iob].append(view_templates)

            # Reduce norm values across the process grid column
            norm_slice = slice(
                self._obs_view_local_offset[iob][0],
                self._obs_view_local_offset[iob][0] + obs_local_namp,
                1,
            )
            norms_view = self._norms[norm_slice]
            if ob.comm_col is not None:
                temp = np.array(norms_view)
                ob.comm_col.Allreduce(temp, norms_view, op=MPI.SUM)
                del temp

            # Invert norms
            good = norms_view != 0
            norms_view[good] = 1.0 / norms_view[good]

        # Set the filter scale by the prescribed correlation strength
        # and the number of modes at each angular scale
        self._filter_scale = np.zeros(self._nmode)
        self._filter_scale[0] = 1
        offset = 1
        if self.fit_subharmonics:
            self._filter_scale[1:3] = 2
            offset += 2
        self._filter_scale[offset:] = 4
        self._filter_scale *= self.correlation_amplitude

    def _detectors(self):
        return self._all_dets

    def _zeros(self):
        # Return amplitudes distributed over the group communicator and using our
        # local ranges.
        z = Amplitudes(
            self.data.comm,
            self._n_global,
            self._n_local,
            local_ranges=self._local_ranges,
        )
        # Amplitude flags are not used by this template; if some samples are
        # flagged across all detectors, they simply will not contribute to the projection.
        # z.local_flags[:] = np.where(self._amp_flags, 1, 0)
        return z

    def _add_to_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                amp_slice = slice(
                    self._obs_view_local_offset[iob][ivw],
                    self._obs_view_local_offset[iob][ivw]
                    + self._obs_view_namp[iob][ivw],
                    1,
                )
                views.detdata[self.det_data][ivw][detector] += np.sum(
                    amplitudes.local[amp_slice].reshape((-1, self._nmode))
                    * self._templates[iob][ivw][detector],
                    1,
                )

    def _project_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                amp_slice = slice(
                    self._obs_view_local_offset[iob][ivw],
                    self._obs_view_local_offset[iob][ivw]
                    + self._obs_view_namp[iob][ivw],
                    1,
                )
                amp_view = amplitudes.local[amp_slice].reshape((-1, self._nmode))
                amp_view[:] += np.outer(
                    views.detdata[self.det_data][ivw][detector],
                    self._templates[iob][ivw][detector],
                )

    def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
        for iob, ob in enumerate(self.data.obs):
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                amp_slice = slice(
                    self._obs_view_local_offset[iob][ivw],
                    self._obs_view_local_offset[iob][ivw]
                    + self._obs_view_namp[iob][ivw],
                    1,
                )
                in_view = amplitudes_in.local[amp_slice].reshape((-1, self._nmode))
                out_view = amplitudes_out.local[amp_slice].reshape((-1, self._nmode))
                for mode in range(self._nmode):
                    scale = self._filter_scale[mode]
                    out_view[:, mode] += scipy.signal.convolve(
                        in_view[:, mode],
                        self._filters[iob][ivw] * scale,
                        mode="same",
                    )

    def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
        amplitudes_out.local[:] = amplitudes_in.local
        amplitudes_out.local *= self._norms

correlation_amplitude = Float(10.0, help='Scale factor of the filter') class-attribute instance-attribute

correlation_length = Quantity(10.0 * u.second, help='Correlation length in time') class-attribute instance-attribute

debug_plots = Unicode(None, allow_none=True, help='If not None, make debugging plots in this directory') class-attribute instance-attribute

fit_subharmonics = Bool(True, help='If True, fit subharmonics') class-attribute instance-attribute

noise_model = Unicode(None, allow_none=True, help='Observation key containing the optional noise model') class-attribute instance-attribute

order = Int(1, help='The filter order') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

__del__()

Source code in toast/templates/fourier2d.py
def __del__(self):
    self.clear()

__init__(**kwargs)

Source code in toast/templates/fourier2d.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_add_prior(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/fourier2d.py
def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
    for iob, ob in enumerate(self.data.obs):
        views = ob.view[self.view]
        for ivw, vw in enumerate(views):
            amp_slice = slice(
                self._obs_view_local_offset[iob][ivw],
                self._obs_view_local_offset[iob][ivw]
                + self._obs_view_namp[iob][ivw],
                1,
            )
            in_view = amplitudes_in.local[amp_slice].reshape((-1, self._nmode))
            out_view = amplitudes_out.local[amp_slice].reshape((-1, self._nmode))
            for mode in range(self._nmode):
                scale = self._filter_scale[mode]
                out_view[:, mode] += scipy.signal.convolve(
                    in_view[:, mode],
                    self._filters[iob][ivw] * scale,
                    mode="same",
                )

_add_to_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/fourier2d.py
def _add_to_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        views = ob.view[self.view]
        for ivw, vw in enumerate(views):
            amp_slice = slice(
                self._obs_view_local_offset[iob][ivw],
                self._obs_view_local_offset[iob][ivw]
                + self._obs_view_namp[iob][ivw],
                1,
            )
            views.detdata[self.det_data][ivw][detector] += np.sum(
                amplitudes.local[amp_slice].reshape((-1, self._nmode))
                * self._templates[iob][ivw][detector],
                1,
            )

_apply_precond(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/fourier2d.py
def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
    amplitudes_out.local[:] = amplitudes_in.local
    amplitudes_out.local *= self._norms

_check_order(proposal)

Source code in toast/templates/fourier2d.py
@traitlets.validate("order")
def _check_order(self, proposal):
    od = proposal["value"]
    if od < 1:
        raise traitlets.TraitError("Filter order should be >= 1")
    return od

_detectors()

Source code in toast/templates/fourier2d.py
def _detectors(self):
    return self._all_dets

_initialize(new_data)

Source code in toast/templates/fourier2d.py
def _initialize(self, new_data):
    zaxis = np.array([0.0, 0.0, 1.0])

    # This function is called whenever a new data trait is assigned to the template.
    # Clear any C-allocated buffers from previous uses.
    self.clear()

    self._norder = self.order + 1
    self._nmode = (2 * self.order) ** 2 + 1
    if self.fit_subharmonics:
        self._nmode += 2

    # The inverse variance units
    invvar_units = 1.0 / (self.det_data_units**2)

    # Every process determines their local amplitude ranges.

    # The local ranges of amplitudes (in terms of global indices)
    self._local_ranges = list()

    # Starting local amplitude for each view within each obs
    self._obs_view_local_offset = dict()

    # Starting global amplitude for each view within each obs
    self._obs_view_global_offset = dict()

    # Number of amplitudes in each local view for each obs
    self._obs_view_namp = dict()

    # This is the total number of amplitudes for each observation, across all
    # views.
    self._obs_total_namp = dict()

    # Use this as an "Ordered Set".  We want the unique detectors on this process,
    # but sorted in order of occurrence.
    all_dets = OrderedDict()

    # Good detectors to use for each observation
    self._obs_dets = dict()

    local_offset = 0
    global_offset = 0

    for iob, ob in enumerate(new_data.obs):
        self._obs_view_namp[iob] = list()
        self._obs_view_local_offset[iob] = list()
        self._obs_view_global_offset[iob] = list()

        # Build up detector list
        self._obs_dets[iob] = set()
        for d in ob.select_local_detectors(flagmask=self.det_mask):
            if d not in ob.detdata[self.det_data].detectors:
                continue
            self._obs_dets[iob].add(d)
            if d not in all_dets:
                all_dets[d] = None

        obs_n_amp = 0

        views = ob.view[self.view]
        for ivw, vw in enumerate(views):
            # First obs sample of this view
            obs_start = ob.local_index_offset

            view_len = None
            if vw.start is None:
                # This is a view of the whole obs
                view_len = ob.n_local_samples
            else:
                view_len = vw.stop - vw.start
                obs_start += vw.start

            obs_offset = obs_start * self._nmode

            self._obs_view_local_offset[iob].append(local_offset)
            self._obs_view_global_offset[iob].append(global_offset + obs_offset)

            view_n_amp = view_len * self._nmode
            obs_n_amp += view_n_amp
            self._obs_view_namp[iob].append(view_n_amp)

            self._local_ranges.append((global_offset + obs_offset, view_n_amp))

            local_offset += view_n_amp

        # To get the total number of amplitudes in this observation, we must
        # accumulate across the grid row communicator.
        if ob.comm_row is not None:
            obs_n_amp = ob.comm_row.allreduce(obs_n_amp)
        self._obs_total_namp[iob] = obs_n_amp
        global_offset += obs_n_amp

    self._all_dets = list(all_dets.keys())

    # The global number of amplitudes for our process group and our local process.
    # Since different groups have different observations, their amplitude values
    # are completely disjoint.  We create Amplitudes with the `use_group` option
    # and so only have to consider the full set if we are doing things like I/O
    # (i.e. nothing needed by this class).

    self._n_global = np.sum(
        [self._obs_total_namp[x] for x, y in enumerate(new_data.obs)]
    )

    self._n_local = np.sum([x[1] for x in self._local_ranges])

    # Allocate norms.  This data is the same size as a set of amplitudes,
    # so we allocate it in C memory.

    if self._n_local == 0:
        self._norms_raw = None
        self._norms = None
    else:
        self._norms_raw = AlignedF64.zeros(self._n_local)
        self._norms = self._norms_raw.array()

    def evaluate_template(theta, phi, radius):
        """Helper function to get the template values for a detector."""
        values = np.zeros(self._nmode)
        values[0] = 1
        offset = 1
        if self.fit_subharmonics:
            values[1:3] = theta / radius, phi / radius
            offset += 2
        if self.order > 0:
            rinv = np.pi / radius
            orders = np.arange(self.order) + 1
            thetavec = np.zeros(self.order * 2)
            phivec = np.zeros(self.order * 2)
            thetavec[::2] = np.cos(orders * theta * rinv)
            thetavec[1::2] = np.sin(orders * theta * rinv)
            phivec[::2] = np.cos(orders * phi * rinv)
            phivec[1::2] = np.sin(orders * phi * rinv)
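            # The 2D modes are all products of the theta and phi
            # harmonics: (2 * order)^2 terms in total.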
            values[offset:] = np.outer(thetavec, phivec).ravel()
        return values

    # The detector templates for each observation
    self._templates = dict()

    # The noise filter for each observation
    self._filters = dict()

    for iob, ob in enumerate(new_data.obs):
        # Focalplane for this observation
        fp = ob.telescope.focalplane

        # Focalplane radius
        radius = 0.5 * fp.field_of_view.to_value(u.radian)

        noise = None
        if self.noise_model in ob:
            noise = ob[self.noise_model]

        self._templates[iob] = list()
        self._filters[iob] = list()

        obs_local_namp = 0

        views = ob.view[self.view]
        for ivw, vw in enumerate(views):
            view_len = None
            if vw.start is None:
                # This is a view of the whole obs
                view_len = ob.n_local_samples
            else:
                view_len = vw.stop - vw.start

            # Build the filter for this view

            corr_len = self.correlation_length.to_value(u.second)
            times = views.shared[self.times][ivw]
            corr = (
                np.exp((times[0] - times) / corr_len) * self.correlation_amplitude
            )
            ihalf = times.size // 2
            if times.size % 2 == 0:
                corr[ihalf:] = corr[ihalf - 1 :: -1]
            else:
                corr[ihalf + 1 :] = corr[ihalf - 1 :: -1]

            fcorr = np.fft.rfft(corr)
            fcorr_orig = np.copy(fcorr)
            too_small = fcorr < (1.0e-6 * self.correlation_amplitude)
            fcorr[too_small] = 1.0e-6 * self.correlation_amplitude
            invcorr = np.fft.irfft(1 / fcorr)
            self._filters[iob].append(invcorr)

            if self.debug_plots is not None and ob.comm.group_rank == 0:
                os.makedirs(self.debug_plots, exist_ok=True)
                set_matplotlib_backend(backend="pdf")

                import matplotlib.pyplot as plt

                figdpi = 100
                plotfile = os.path.join(
                    self.debug_plots, f"f2d_{ob.name}_{ivw}_filter.pdf"
                )
                fig = plt.figure(dpi=figdpi, figsize=(8, 12))
                xdata = np.arange(len(corr))
                fxdata = np.arange(len(fcorr))
                ax = fig.add_subplot(3, 1, 1)
                ax.plot(xdata, corr, label=f"Input Corr")
                ax.set_yscale("log")
                ax.set_xlabel("Sample")
                ax.set_ylabel("Amplitude")
                ax.legend(loc="best")
                ax = fig.add_subplot(3, 1, 2)
                ax.plot(fxdata, fcorr_orig, label=f"Fcorr original")
                ax.plot(fxdata, fcorr, label=f"Fcorr cut")
                ax.set_yscale("log")
                ax.set_xlabel("Frequency")
                ax.set_ylabel("Amplitude")
                ax.legend(loc="best")
                ax = fig.add_subplot(3, 1, 3)
                ax.plot(xdata, invcorr, label=f"Inverse Corr")
                ax.set_xlabel("Sample")
                ax.set_ylabel("Amplitude")
                ax.legend(loc="best")
                plt.savefig(plotfile, dpi=figdpi, bbox_inches="tight", format="pdf")
                plt.close()

            # Now compute templates and norm for this view

            view_templates = dict()

            good = np.empty(view_len, dtype=np.float64)
            norm_slice = slice(
                self._obs_view_local_offset[iob][ivw],
                self._obs_view_local_offset[iob][ivw]
                + self._obs_view_namp[iob][ivw],
                1,
            )
            norms_view = self._norms[norm_slice].reshape((-1, self._nmode))

            for det in ob.local_detectors:
                if det not in self._obs_dets[iob]:
                    continue
                detweight = 1.0
                if noise is not None:
                    detweight = noise.detector_weight(det).to_value(invvar_units)
                det_quat = fp[det]["quat"]
                x, y, z = qa.rotate(det_quat, zaxis)
                theta, phi = np.arcsin([x, y])
                view_templates[det] = evaluate_template(theta, phi, radius)

                good[:] = 1.0
                if self.det_flags is not None:
                    flags = views.detdata[self.det_flags][ivw][det]
                    good[(flags & self.det_flag_mask) != 0] = 0
                norms_view += np.outer(good, view_templates[det] ** 2 * detweight)

            obs_local_namp += self._obs_view_namp[iob][ivw]
            self._templates[iob].append(view_templates)

        # Reduce norm values across the process grid column
        norm_slice = slice(
            self._obs_view_local_offset[iob][0],
            self._obs_view_local_offset[iob][0] + obs_local_namp,
            1,
        )
        norms_view = self._norms[norm_slice]
        if ob.comm_col is not None:
            temp = np.array(norms_view)
            ob.comm_col.Allreduce(temp, norms_view, op=MPI.SUM)
            del temp

        # Invert norms
        good = norms_view != 0
        norms_view[good] = 1.0 / norms_view[good]

    # Set the filter scale by the prescribed correlation strength
    # and the number of modes at each angular scale
    self._filter_scale = np.zeros(self._nmode)
    self._filter_scale[0] = 1
    offset = 1
    if self.fit_subharmonics:
        self._filter_scale[1:3] = 2
        offset += 2
    self._filter_scale[offset:] = 4
    self._filter_scale *= self.correlation_amplitude

_project_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/fourier2d.py
def _project_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        views = ob.view[self.view]
        for ivw, vw in enumerate(views):
            amp_slice = slice(
                self._obs_view_local_offset[iob][ivw],
                self._obs_view_local_offset[iob][ivw]
                + self._obs_view_namp[iob][ivw],
                1,
            )
            amp_view = amplitudes.local[amp_slice].reshape((-1, self._nmode))
            amp_view[:] += np.outer(
                views.detdata[self.det_data][ivw][detector],
                self._templates[iob][ivw][detector],
            )

_zeros()

Source code in toast/templates/fourier2d.py
def _zeros(self):
    # Return amplitudes distributed over the group communicator and using our
    # local ranges.
    z = Amplitudes(
        self.data.comm,
        self._n_global,
        self._n_local,
        local_ranges=self._local_ranges,
    )
    # Amplitude flags are not used by this template; if some samples are
    # flagged across all detectors, they simply will not contribute to the projection.
    # z.local_flags[:] = np.where(self._amp_flags, 1, 0)
    return z

clear()

Delete the underlying C-allocated memory.

Source code in toast/templates/fourier2d.py
def clear(self):
    """Delete the underlying C-allocated memory."""
    if hasattr(self, "_norms"):
        del self._norms
    if hasattr(self, "_norms_raw"):
        self._norms_raw.clear()
        del self._norms_raw

toast.templates.Hwpss

Bases: Template

This template represents the HWP synchronous signal.
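
As an illustration of what the template fits, here is a sketch under the assumption that coefficients are grouped per harmonic, matching the n_coeff = 4 * harmonics count and the time_drift=True covariance used in _initialize below; the exact coefficient packing in the library may differ.

# Sketch only: evaluate a HWPSS model with 4 coefficients per harmonic.
# The per-harmonic coefficient ordering here is an assumption.
import numpy as np

def hwpss_model(chi, t, coeff):
    # chi: HWP angle samples; t: relative times; coeff: 4 * n_harmonics values
    n_harmonics = len(coeff) // 4
    model = np.zeros_like(t, dtype=np.float64)
    for k in range(1, n_harmonics + 1):
        c0, c1, s0, s1 = coeff[4 * (k - 1) : 4 * k]
        # Each harmonic has cos and sin terms, each with a constant
        # amplitude and a linear time-drift component.
        model += (c0 + c1 * t) * np.cos(k * chi)
        model += (s0 + s1 * t) * np.sin(k * chi)
    return model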

Source code in toast/templates/hwpss.py
@trait_docs
class Hwpss(Template):
    """This template represents the HWP synchronous signal."""

    # Notes:  The TraitConfig base class defines a "name" attribute.  The Template
    # class (derived from TraitConfig) defines the following traits already:
    #    data             : The Data instance we are working with
    #    view             : The timestream view we are using
    #    det_data         : The detector data key with the timestreams
    #    det_data_units   : The units of the detector data
    #    det_mask         : Bitmask for per-detector flagging
    #    det_flags        : Optional detector solver flags
    #    det_flag_mask    : Bit mask for detector solver flags
    #

    hwp_angle = Unicode(
        defaults.hwp_angle, allow_none=True, help="Observation shared key for HWP angle"
    )

    hwp_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for HWP flags",
    )

    hwp_flag_mask = Int(
        defaults.shared_mask_invalid,
        help="Bit mask to use when considering valid HWP angle values.",
    )

    harmonics = Int(9, help="Number of harmonics to consider in the expansion")

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    debug_plots = Unicode(
        None,
        allow_none=True,
        help="If not None, make debugging plots in this directory",
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _initialize(self, new_data):
        log = Logger.get()

        # Four coefficients per harmonic: the sin and cos terms, each with a
        # constant amplitude and a time-drift component.
        self._n_coeff = 4 * self.harmonics

        # Use this as an "Ordered Set".  We want the unique detectors on this process,
        # but sorted in order of occurrence.
        all_dets = OrderedDict()

        # Good detectors to use for each observation
        self._obs_dets = dict()

        # Build up detector list
        for iob, ob in enumerate(new_data.obs):
            if self.hwp_angle not in ob.shared:
                continue
            self._obs_dets[iob] = set()
            for d in ob.select_local_detectors(flagmask=self.det_mask):
                if d not in ob.detdata[self.det_data].detectors:
                    continue
                self._obs_dets[iob].add(d)
                if d not in all_dets:
                    all_dets[d] = None
        self._all_dets = list(all_dets.keys())

        # During application of the template, we will be looping over detectors
        # in the outer loop.  So we pack the amplitudes by detector and then by
        # observation.  For each observation, we precompute the quantities that
        # are common to all detectors.

        self._det_offset = dict()
        self._obs_reltime = dict()
        self._obs_sincos = dict()
        self._obs_cov = dict()
        self._obs_outview = dict()

        offset = 0
        for det in self._all_dets:
            self._det_offset[det] = offset
            for iob, ob in enumerate(new_data.obs):
                if self.hwp_angle not in ob.shared:
                    continue
                if det not in self._obs_dets[iob]:
                    continue
                if iob not in self._obs_reltime:
                    # First time we are considering this observation
                    if self.view is not None:
                        self._obs_outview[iob] = ~ob.intervals[self.view]
                    times = np.array(ob.shared[self.times].data, copy=True)
                    time_offset = times[0]
                    times -= time_offset
                    self._obs_reltime[iob] = times.astype(np.float32)
                    if self.hwp_flags is None:
                        flags = np.zeros(len(times), dtype=np.uint8)
                    else:
                        flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
                    self._obs_sincos[iob] = hwpss_sincos_buffer(
                        ob.shared[self.hwp_angle].data,
                        flags,
                        self.harmonics,
                        comm=ob.comm.comm_group,
                    )
                    self._obs_cov[iob] = hwpss_compute_coeff_covariance(
                        self._obs_sincos[iob],
                        flags,
                        comm=ob.comm.comm_group,
                        times=self._obs_reltime[iob],
                        time_drift=True,
                    )
                offset += self._n_coeff

        # Now we know the total number of local amplitudes.
        if offset == 0:
            # This means that no observations included a HWP angle
            msg = f"Data has no observations with HWP angle '{self.hwp_angle}'."
            msg += "  You should disable this template."
            log.error(msg)
            raise RuntimeError(msg)

        self._n_local = offset
        if new_data.comm.comm_world is None:
            self._n_global = self._n_local
        else:
            self._n_global = new_data.comm.comm_world.allreduce(
                self._n_local, op=MPI.SUM
            )

        # Boolean flags
        if self._n_local == 0:
            self._amp_flags = None
        else:
            self._amp_flags = np.zeros(self._n_local, dtype=bool)

        for det in self._all_dets:
            amp_offset = self._det_offset[det]
            for iob, ob in enumerate(new_data.obs):
                if self.hwp_angle not in ob.shared:
                    continue
                if det not in self._obs_dets[iob]:
                    continue
                if self._obs_cov[iob] is None:
                    # This observation has poorly conditioned covariance
                    self._amp_flags[amp_offset : amp_offset + self._n_coeff] = True
                amp_offset += self._n_coeff

    def _detectors(self):
        return self._all_dets

    def _zeros(self):
        z = Amplitudes(self.data.comm, self._n_global, self._n_local)
        if z.local_flags is not None:
            z.local_flags[:] = np.where(self._amp_flags, 1, 0)
        return z

    def _add_to_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        amp_offset = self._det_offset[detector]
        for iob, ob in enumerate(self.data.obs):
            if self.hwp_angle not in ob.shared:
                continue
            if detector not in self._obs_dets[iob]:
                continue
            if self.hwp_flags is None:
                flags = np.zeros(ob.n_local_samples, dtype=np.uint8)
            else:
                flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
            if self.view is not None:
                # Flag samples outside the valid intervals
                for vw in self._obs_outview[iob]:
                    vw_slc = slice(vw.first, vw.last, 1)
                    flags[vw_slc] = 1
            coeff = amplitudes.local[amp_offset : amp_offset + self._n_coeff]
            model = hwpss_build_model(
                self._obs_sincos[iob],
                flags,
                coeff,
                times=self._obs_reltime[iob],
                time_drift=True,
            )
            good = flags == 0
            # Accumulate to timestream
            ob.detdata[self.det_data][detector][good] += model[good]
            amp_offset += self._n_coeff

    def _project_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        amp_offset = self._det_offset[detector]
        for iob, ob in enumerate(self.data.obs):
            if self.hwp_angle not in ob.shared:
                continue
            if detector not in self._obs_dets[iob]:
                continue
            if self.hwp_flags is None:
                flags = np.zeros(ob.n_local_samples, dtype=np.uint8)
            else:
                flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
            if self.view is not None:
                # Flag samples outside the valid intervals
                for vw in self._obs_outview[iob]:
                    vw_slc = slice(vw.first, vw.last, 1)
                    flags[vw_slc] = 1
            if self.det_flags is not None:
                flags |= ob.detdata[self.det_flags][detector] & self.det_flag_mask
            if self._obs_cov[iob] is None:
                # Flagged
                amplitudes.local[amp_offset : amp_offset + self._n_coeff] = 0
            else:
                coeff = hwpss_compute_coeff(
                    self._obs_sincos[iob],
                    ob.detdata[self.det_data][detector],
                    flags,
                    self._obs_cov[iob][0],
                    self._obs_cov[iob][1],
                    times=self._obs_reltime[iob],
                    time_drift=True,
                )
                amplitudes.local[amp_offset : amp_offset + self._n_coeff] = coeff
            amp_offset += self._n_coeff

    def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
        # No prior for this template, nothing to accumulate to output.
        return

    def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
        # Just the identity matrix
        amplitudes_out.local[:] = amplitudes_in.local
        return

    @function_timer
    def write(self, amplitudes, out):
        """Write out amplitude values.

        This stores the amplitudes to a file for debugging / plotting.
        WARNING: currently this only works for data distributed by
        detector.

        Args:
            amplitudes (Amplitudes):  The amplitude data.
            out (str):  The output file.

        Returns:
            None

        """
        obs_det_amps = dict()
        obs_reltime = dict()
        obs_hwpang = dict()

        for det in self._all_dets:
            amp_offset = self._det_offset[det]
            for iob, ob in enumerate(self.data.obs):
                if self.hwp_angle not in ob.shared:
                    continue
                if det not in self._obs_dets[iob]:
                    continue
                if ob.name not in obs_det_amps:
                    obs_det_amps[ob.name] = dict()
                if ob.comm.group_rank == 0:
                    if ob.name not in obs_reltime:
                        obs_reltime[ob.name] = np.array(self._obs_reltime[iob])
                        if self.hwp_flags is not None:
                            flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
                            # Set flagged samples to a negative time value, to
                            # communicate the flags to downstream code.
                            obs_reltime[ob.name][flags != 0] = -1.0
                        obs_hwpang[ob.name] = np.array(
                            ob.shared[self.hwp_angle].data, dtype=np.float32
                        )
                obs_det_amps[ob.name][det] = amplitudes.local[
                    amp_offset : amp_offset + self._n_coeff
                ]
                amp_offset += self._n_coeff

        if self.data.comm.world_size == 1:
            all_obs_dets_amps = [obs_det_amps]
            all_obs_reltime = [obs_reltime]
            all_obs_hwpang = [obs_hwpang]
        else:
            all_obs_dets_amps = self.data.comm.comm_world.gather(obs_det_amps, root=0)
            all_obs_reltime = self.data.comm.comm_world.gather(obs_reltime, root=0)
            all_obs_hwpang = self.data.comm.comm_world.gather(obs_hwpang, root=0)

        if self.data.comm.world_rank == 0:
            obs_det_amps = dict()
            for pdata in all_obs_dets_amps:
                for obname in pdata.keys():
                    if obname not in obs_det_amps:
                        obs_det_amps[obname] = dict()
                    obs_det_amps[obname].update(pdata[obname])
            del all_obs_dets_amps

            obs_reltime = dict()
            for pdata in all_obs_reltime:
                if len(pdata) == 0:
                    continue
                for obname in pdata.keys():
                    if obname not in obs_reltime:
                        obs_reltime[obname] = pdata[obname]
            del all_obs_reltime

            obs_hwpang = dict()
            for pdata in all_obs_hwpang:
                if len(pdata) == 0:
                    continue
                for obname in pdata.keys():
                    if obname not in obs_hwpang:
                        obs_hwpang[obname] = pdata[obname]
            del all_obs_hwpang

            with h5py.File(out, "w") as hf:
                for obname, obamps in obs_det_amps.items():
                    n_det = len(obamps)
                    det_list = list(sorted(obamps.keys()))
                    det_indx = {y: x for x, y in enumerate(det_list)}
                    indx_to_det = {det_indx[x]: x for x in det_list}
                    n_amp = len(obamps[det_list[0]])
                    n_samp = len(obs_reltime[obname])

                    # Create datasets for this observation
                    hg = hf.create_group(obname)
                    hg.attrs["detectors"] = json.dumps(det_list)
                    hamps = hg.create_dataset(
                        "amplitudes",
                        (n_det, n_amp),
                        dtype=np.float64,
                    )
                    htime = hg.create_dataset(
                        "reltime",
                        (n_samp,),
                        dtype=np.float32,
                    )
                    hang = hg.create_dataset(
                        "hwpangle",
                        (n_samp,),
                        dtype=np.float32,
                    )

                    # Write data
                    samp_slice = (slice(0, n_samp, 1),)
                    htime.write_direct(obs_reltime[obname], samp_slice, samp_slice)
                    hang.write_direct(obs_hwpang[obname], samp_slice, samp_slice)
                    for idet in range(n_det):
                        det = indx_to_det[idet]
                        hslice = (slice(idet, idet + 1, 1), slice(0, n_amp, 1))
                        dslice = (slice(0, n_amp, 1),)
                        hamps.write_direct(obamps[det], dslice, hslice)

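For reference, the coefficient packing above (4 * harmonics amplitudes per detector and observation) corresponds to the following model of the HWP synchronous signal.  This is a sketch with illustrative symbols that do not appear in the source; the exact ordering of the coefficients is internal to the hwpss_* helper functions:

    s(t) = \sum_{k=1}^{H} \left[ (a_k + b_k t) \sin(k\theta(t)) + (c_k + d_k t) \cos(k\theta(t)) \right]

where H is the harmonics trait, \theta(t) is the HWP angle, and t is the relative time computed in _initialize().  The drift terms b_k and d_k allow the harmonic amplitudes to vary linearly across the observation.
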
debug_plots = Unicode(None, allow_none=True, help='If not None, make debugging plots in this directory') class-attribute instance-attribute

harmonics = Int(9, help='Number of harmonics to consider in the expansion') class-attribute instance-attribute

hwp_angle = Unicode(defaults.hwp_angle, allow_none=True, help='Observation shared key for HWP angle') class-attribute instance-attribute

hwp_flag_mask = Int(defaults.shared_mask_invalid, help='Bit mask to use when considering valid HWP angle values.') class-attribute instance-attribute

hwp_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for HWP flags') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

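As a quick orientation, a template instance might be constructed as follows.  This is a minimal sketch: the import path mirrors the toast.templates.* headings on this page, and every trait value shown is an assumption rather than a default taken from the source.

from toast.templates import Hwpss

# Hypothetical construction: fit 5 harmonics of the HWP angle stored in
# the "hwp_angle" shared key, with debugging plots disabled.
hwpss = Hwpss(
    name="hwpss",
    hwp_angle="hwp_angle",
    harmonics=5,
    debug_plots=None,
)
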
__init__(**kwargs)

Source code in toast/templates/hwpss.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_add_prior(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/hwpss.py
def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
    # No prior for this template, nothing to accumulate to output.
    return

_add_to_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/hwpss.py
def _add_to_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    amp_offset = self._det_offset[detector]
    for iob, ob in enumerate(self.data.obs):
        if self.hwp_angle not in ob.shared:
            continue
        if detector not in self._obs_dets[iob]:
            continue
        if self.hwp_flags is None:
            flags = np.zeros(ob.n_local_samples, dtype=np.uint8)
        else:
            flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
        if self.view is not None:
            # Flag samples outside the valid intervals
            for vw in self._obs_outview[iob]:
                vw_slc = slice(vw.first, vw.last, 1)
                flags[vw_slc] = 1
        coeff = amplitudes.local[amp_offset : amp_offset + self._n_coeff]
        model = hwpss_build_model(
            self._obs_sincos[iob],
            flags,
            coeff,
            times=self._obs_reltime[iob],
            time_drift=True,
        )
        good = flags == 0
        # Accumulate to timestream
        ob.detdata[self.det_data][detector][good] += model[good]
        amp_offset += self._n_coeff

_apply_precond(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/hwpss.py
def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
    # Just the identity matrix
    amplitudes_out.local[:] = amplitudes_in.local
    return

_detectors()

Source code in toast/templates/hwpss.py
def _detectors(self):
    return self._all_dets

_initialize(new_data)

Source code in toast/templates/hwpss.py
def _initialize(self, new_data):
    log = Logger.get()

    # Each harmonic contributes four coefficients:  the sin and cos terms,
    # plus a time-drift version of each.
    self._n_coeff = 4 * self.harmonics

    # Use this as an "Ordered Set".  We want the unique detectors on this process,
    # but sorted in order of occurrence.
    all_dets = OrderedDict()

    # Good detectors to use for each observation
    self._obs_dets = dict()

    # Build up detector list
    for iob, ob in enumerate(new_data.obs):
        if self.hwp_angle not in ob.shared:
            continue
        self._obs_dets[iob] = set()
        for d in ob.select_local_detectors(flagmask=self.det_mask):
            if d not in ob.detdata[self.det_data].detectors:
                continue
            self._obs_dets[iob].add(d)
            if d not in all_dets:
                all_dets[d] = None
    self._all_dets = list(all_dets.keys())

    # During application of the template, we will be looping over detectors
    # in the outer loop.  So we pack the amplitudes by detector and then by
    # observation.  For each observation, we precompute the quantities that
    # are common to all detectors.

    self._det_offset = dict()
    self._obs_reltime = dict()
    self._obs_sincos = dict()
    self._obs_cov = dict()
    self._obs_outview = dict()

    offset = 0
    for det in self._all_dets:
        self._det_offset[det] = offset
        for iob, ob in enumerate(new_data.obs):
            if self.hwp_angle not in ob.shared:
                continue
            if det not in self._obs_dets[iob]:
                continue
            if iob not in self._obs_reltime:
                # First time we are considering this observation
                if self.view is not None:
                    self._obs_outview[iob] = ~ob.intervals[self.view]
                times = np.array(ob.shared[self.times].data, copy=True)
                time_offset = times[0]
                times -= time_offset
                self._obs_reltime[iob] = times.astype(np.float32)
                if self.hwp_flags is None:
                    flags = np.zeros(len(times), dtype=np.uint8)
                else:
                    flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
                self._obs_sincos[iob] = hwpss_sincos_buffer(
                    ob.shared[self.hwp_angle].data,
                    flags,
                    self.harmonics,
                    comm=ob.comm.comm_group,
                )
                self._obs_cov[iob] = hwpss_compute_coeff_covariance(
                    self._obs_sincos[iob],
                    flags,
                    comm=ob.comm.comm_group,
                    times=self._obs_reltime[iob],
                    time_drift=True,
                )
            offset += self._n_coeff

    # Now we know the total number of local amplitudes.
    if offset == 0:
        # This means that no observations included a HWP angle
        msg = f"Data has no observations with HWP angle '{self.hwp_angle}'."
        msg += "  You should disable this template."
        log.error(msg)
        raise RuntimeError(msg)

    self._n_local = offset
    if new_data.comm.comm_world is None:
        self._n_global = self._n_local
    else:
        self._n_global = new_data.comm.comm_world.allreduce(
            self._n_local, op=MPI.SUM
        )

    # Boolean flags
    if self._n_local == 0:
        self._amp_flags = None
    else:
        self._amp_flags = np.zeros(self._n_local, dtype=bool)

    for det in self._all_dets:
        amp_offset = self._det_offset[det]
        for iob, ob in enumerate(new_data.obs):
            if self.hwp_angle not in ob.shared:
                continue
            if det not in self._obs_dets[iob]:
                continue
            if self._obs_cov[iob] is None:
                # This observation has poorly conditioned covariance
                self._amp_flags[amp_offset : amp_offset + self._n_coeff] = True
            amp_offset += self._n_coeff

_project_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/hwpss.py
def _project_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    amp_offset = self._det_offset[detector]
    for iob, ob in enumerate(self.data.obs):
        if self.hwp_angle not in ob.shared:
            continue
        if detector not in self._obs_dets[iob]:
            continue
        if self.hwp_flags is None:
            flags = np.zeros(ob.n_local_samples, dtype=np.uint8)
        else:
            flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
        if self.view is not None:
            # Flag samples outside the valid intervals
            for vw in self._obs_outview[iob]:
                vw_slc = slice(vw.first, vw.last, 1)
                flags[vw_slc] = 1
        if self.det_flags is not None:
            flags |= ob.detdata[self.det_flags][detector] & self.det_flag_mask
        if self._obs_cov[iob] is None:
            # Flagged
            amplitudes.local[amp_offset : amp_offset + self._n_coeff] = 0
        else:
            coeff = hwpss_compute_coeff(
                self._obs_sincos[iob],
                ob.detdata[self.det_data][detector],
                flags,
                self._obs_cov[iob][0],
                self._obs_cov[iob][1],
                times=self._obs_reltime[iob],
                time_drift=True,
            )
            amplitudes.local[amp_offset : amp_offset + self._n_coeff] = coeff
        amp_offset += self._n_coeff

_zeros()

Source code in toast/templates/hwpss.py
def _zeros(self):
    z = Amplitudes(self.data.comm, self._n_global, self._n_local)
    if z.local_flags is not None:
        z.local_flags[:] = np.where(self._amp_flags, 1, 0)
    return z

write(amplitudes, out)

Write out amplitude values.

This stores the amplitudes to a file for debugging / plotting. WARNING: currently this only works for data distributed by detector.

Parameters:

    amplitudes (Amplitudes):  The amplitude data.  Required.
    out (str):  The output file.  Required.

Returns:

    None

Source code in toast/templates/hwpss.py
@function_timer
def write(self, amplitudes, out):
    """Write out amplitude values.

    This stores the amplitudes to a file for debugging / plotting.
    WARNING: currently this only works for data distributed by
    detector.

    Args:
        amplitudes (Amplitudes):  The amplitude data.
        out (str):  The output file.

    Returns:
        None

    """
    obs_det_amps = dict()
    obs_reltime = dict()
    obs_hwpang = dict()

    for det in self._all_dets:
        amp_offset = self._det_offset[det]
        for iob, ob in enumerate(self.data.obs):
            if self.hwp_angle not in ob.shared:
                continue
            if det not in self._obs_dets[iob]:
                continue
            if ob.name not in obs_det_amps:
                obs_det_amps[ob.name] = dict()
            if ob.comm.group_rank == 0:
                if ob.name not in obs_reltime:
                    obs_reltime[ob.name] = np.array(self._obs_reltime[iob])
                    if self.hwp_flags is not None:
                        flags = ob.shared[self.hwp_flags].data & self.hwp_flag_mask
                        # Set flagged samples to a negative time value, to
                        # communicate the flags to downstream code.
                        obs_reltime[ob.name][flags != 0] = -1.0
                    obs_hwpang[ob.name] = np.array(
                        ob.shared[self.hwp_angle].data, dtype=np.float32
                    )
            obs_det_amps[ob.name][det] = amplitudes.local[
                amp_offset : amp_offset + self._n_coeff
            ]
            amp_offset += self._n_coeff

    if self.data.comm.world_size == 1:
        all_obs_dets_amps = [obs_det_amps]
        all_obs_reltime = [obs_reltime]
        all_obs_hwpang = [obs_hwpang]
    else:
        all_obs_dets_amps = self.data.comm.comm_world.gather(obs_det_amps, root=0)
        all_obs_reltime = self.data.comm.comm_world.gather(obs_reltime, root=0)
        all_obs_hwpang = self.data.comm.comm_world.gather(obs_hwpang, root=0)

    if self.data.comm.world_rank == 0:
        obs_det_amps = dict()
        for pdata in all_obs_dets_amps:
            for obname in pdata.keys():
                if obname not in obs_det_amps:
                    obs_det_amps[obname] = dict()
                obs_det_amps[obname].update(pdata[obname])
        del all_obs_dets_amps

        obs_reltime = dict()
        for pdata in all_obs_reltime:
            if len(pdata) == 0:
                continue
            for obname in pdata.keys():
                if obname not in obs_reltime:
                    obs_reltime[obname] = pdata[obname]
        del all_obs_reltime

        obs_hwpang = dict()
        for pdata in all_obs_hwpang:
            if len(pdata) == 0:
                continue
            for obname in pdata.keys():
                if obname not in obs_hwpang:
                    obs_hwpang[obname] = pdata[obname]
        del all_obs_hwpang

        with h5py.File(out, "w") as hf:
            for obname, obamps in obs_det_amps.items():
                n_det = len(obamps)
                det_list = list(sorted(obamps.keys()))
                det_indx = {y: x for x, y in enumerate(det_list)}
                indx_to_det = {det_indx[x]: x for x in det_list}
                n_amp = len(obamps[det_list[0]])
                n_samp = len(obs_reltime[obname])

                # Create datasets for this observation
                hg = hf.create_group(obname)
                hg.attrs["detectors"] = json.dumps(det_list)
                hamps = hg.create_dataset(
                    "amplitudes",
                    (n_det, n_amp),
                    dtype=np.float64,
                )
                htime = hg.create_dataset(
                    "reltime",
                    (n_samp,),
                    dtype=np.float32,
                )
                hang = hg.create_dataset(
                    "hwpangle",
                    (n_samp,),
                    dtype=np.float32,
                )

                # Write data
                samp_slice = (slice(0, n_samp, 1),)
                htime.write_direct(obs_reltime[obname], samp_slice, samp_slice)
                hang.write_direct(obs_hwpang[obname], samp_slice, samp_slice)
                for idet in range(n_det):
                    det = indx_to_det[idet]
                    hslice = (slice(idet, idet + 1, 1), slice(0, n_amp, 1))
                    dslice = (slice(0, n_amp, 1),)
                    hamps.write_direct(obamps[det], dslice, hslice)

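The HDF5 file produced by write() can be inspected with a short reader like the one below.  This is a sketch, not part of the toast API; the file name "hwpss_amps.h5" is an assumption, while the group layout, the "detectors" attribute, and the dataset names follow the code above.

import json

import h5py

with h5py.File("hwpss_amps.h5", "r") as hf:
    for obname, hg in hf.items():
        # Detector names were stored as a JSON list in a group attribute
        det_list = json.loads(hg.attrs["detectors"])
        amps = hg["amplitudes"][:]  # shape (n_det, n_amp), float64
        reltime = hg["reltime"][:]  # relative times; negative values are flagged
        hwpang = hg["hwpangle"][:]  # HWP angle samples, float32
        print(f"{obname}: {len(det_list)} detectors, {amps.shape[1]} amplitudes")
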
toast.templates.GainTemplate

Bases: Template

This class fits and mitigates gain fluctuations in the data. The fluctuations are modeled as a linear combination of Legendre polynomials (up to a given order, commonly n < 5) weighted by the so-called gain amplitudes. The gain template is obtained by estimating the polynomial amplitudes, assuming a signal estimate provided by a template map (a coarse estimate of the underlying signal).

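Concretely, for each view the fitted gain fluctuation and its effect on a detector timestream can be written as (a sketch consistent with _get_polynomials() and _add_to_signal() below; the symbols are illustrative):

    \delta g(x) = \sum_{i=0}^{\mathrm{order}} a_i P_i(x), \qquad d(t) \leftarrow d(t) + T(t)\,\delta g(x(t))

where the P_i are Legendre polynomials evaluated on x \in [-1, 1] spanning the view, the a_i are the gain amplitudes, T(t) is the signal estimate stored under template_name, and d(t) is the detector data.
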
Source code in toast/templates/gaintemplate.py
@trait_docs
class GainTemplate(Template):
    """This class aims at fitting and mitigating gain fluctuations in the data.
    The fluctuations  are modeled as a linear combination of Legendre polynomials (up
    to a given order, commonly `n<5` ) weighted by the so called _gain amplitudes_.
    The gain template is therefore obtained by estimating the polynomial amplitudes by
    assuming a _signal estimate_ provided by a template map (encoding a coarse estimate
    of the underlying signal.)

    """

    # Notes:  The TraitConfig base class defines a "name" attribute.  The Template
    # class (derived from TraitConfig) defines the following traits already:
    #    data             : The Data instance we are working with
    #    view             : The timestream view we are using
    #    det_data         : The detector data key with the timestreams
    #    det_data_units   : The units of the detector data
    #    det_mask         : Bitmask for per-detector flagging
    #    det_flags        : Optional detector solver flags
    #    det_flag_mask    : Bit mask for detector solver flags
    #

    order = Int(1, help="The order of Legendre polynomials to fit the gain amplitudes")

    template_name = Unicode(
        None,
        allow_none=True,
        help="detdata key encoding the signal estimate to fit the gain amplitudes",
    )

    noise_model = Unicode(
        None,
        allow_none=True,
        help="Observation key containing the noise model",
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _get_polynomials(self, N):
        norder = self.order + 1
        L = np.zeros((N, norder), dtype=np.float64)
        x = np.linspace(-1.0, 1.0, num=N, endpoint=True)
        for i in range(norder):
            L[:, i] = scipy.special.legendre(i)(x)
        return L

    def _initialize(self, new_data):
        self.norder = self.order + 1
        # Use this as an "Ordered Set".  We want the unique detectors on this process,
        # but sorted in order of occurrence.
        all_dets = OrderedDict()

        # Good detectors to use for each observation
        self._obs_dets = dict()

        # Build up detector list
        for iob, ob in enumerate(new_data.obs):
            self._obs_dets[iob] = set()
            for d in ob.select_local_detectors(flagmask=self.det_mask):
                if d not in ob.detdata[self.det_data].detectors:
                    continue
                self._obs_dets[iob].add(d)
                if d not in all_dets:
                    all_dets[d] = None

        self._all_dets = list(all_dets.keys())

        # The inverse variance units
        invvar_units = 1.0 / (self.det_data_units**2)

        # Go through the data one local detector at a time and compute the offsets into
        # the amplitudes.

        # The starting amplitude for each detector within the local amplitude data.
        self._det_start = dict()

        offset = 0
        for det in self._all_dets:
            self._det_start[det] = offset
            for iob, ob in enumerate(new_data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                # We have one set of amplitudes for each detector in each view
                offset += len(ob.view[self.view]) * (self.order + 1)

        # Now we know the total number of amplitudes.

        self._n_local = offset
        if new_data.comm.comm_world is None:
            self._n_global = self._n_local
        else:
            self._n_global = new_data.comm.comm_world.allreduce(
                self._n_local, op=MPI.SUM
            )

        # The preconditioner for each obs / view / detector
        self._precond = dict()
        self._templates = dict()

        # Build the preconditioner.  It is easier to build these by looping
        # in observation order rather than detector order.

        for iob, ob in enumerate(new_data.obs):
            # Build the templates and preconditioners for every view.
            self._templates[iob] = list()
            self._precond[iob] = dict()
            norder = self.order + 1

            noise = None
            if self.noise_model in ob:
                noise = ob[self.noise_model]
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                view_len = None
                if vw.start is None:
                    # This is a view of the whole obs
                    view_len = ob.n_local_samples
                else:
                    view_len = vw.stop - vw.start
                # get legendre polynomials
                L = self._get_polynomials(view_len)
                # store them in the template dictionary
                self._templates[iob].append(L)

                self._precond[iob][ivw] = dict()
                for det in ob.local_detectors:
                    if det not in self._obs_dets[iob]:
                        continue
                    detweight = 1.0
                    if noise is not None:
                        detweight = noise.detector_weight(det).to_value(invvar_units)

                    good = slice(0, view_len, 1)
                    if self.det_flags is not None:
                        flags = views.detdata[self.det_flags][ivw][det]
                        good = (flags & self.det_flag_mask) == 0

                    prec = np.zeros((norder, norder), dtype=np.float64)
                    T = ob.detdata[self.template_name][det]

                    LT = L.T.copy()
                    for row in LT:
                        row *= T * np.sqrt(detweight)
                    M = LT.dot(LT.T)
                    self._precond[iob][ivw][det] = np.linalg.inv(M)

    def _detectors(self):
        return self._all_dets

    def _zeros(self):
        z = Amplitudes(self.data.comm, self._n_global, self._n_local)
        # No explicit flagging of amplitudes in this template...
        # z.local_flags[:] = np.where(self._amp_flags, 1, 0)
        return z

    def _add_to_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        norder = self.order + 1
        offset = self._det_start[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
                legendre_poly = self._templates[iob][ivw]
                poly_amps = amplitudes.local[offset : offset + norder]
                delta_gain = legendre_poly.dot(poly_amps)
                signal_estimate = ob.detdata[self.template_name][detector]
                gain_fluctuation = signal_estimate * delta_gain
                vw[detector] += gain_fluctuation

    def _project_signal(self, detector, amplitudes, **kwargs):
        if detector not in self._all_dets:
            # This must have been cut by per-detector flags during initialization
            return
        norder = self.order + 1
        offset = self._det_start[detector]
        for iob, ob in enumerate(self.data.obs):
            if detector not in self._obs_dets[iob]:
                continue
            for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
                legendre_poly = self._templates[iob][ivw]
                signal_estimate = ob.detdata[self.template_name][detector]
                if self.det_flags is not None:
                    flagview = ob.view[self.view].detdata[self.det_flags][ivw]
                    mask = (flagview[detector] & self.det_flag_mask) == 0
                else:
                    mask = 1
                LT = legendre_poly.T.copy()
                for row in LT:
                    row *= signal_estimate
                poly_amps = amplitudes.local[offset : offset + norder]
                poly_amps += np.dot(LT, vw[detector] * mask)

    def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
        # No prior for this template, nothing to accumulate to output.
        return

    def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
        norder = self.order + 1
        for det in self._all_dets:
            offset = self._det_start[det]
            for iob, ob in enumerate(self.data.obs):
                if det not in self._obs_dets[iob]:
                    continue
                views = ob.view[self.view]
                for ivw, vw in enumerate(views):
                    amps_in = amplitudes_in.local[offset : offset + norder]
                    amps_out = amplitudes_out.local[offset : offset + norder]
                    amps_out[:] = np.dot(self._precond[iob][ivw][det], amps_in)
                    offset += norder

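The per-detector preconditioner assembled in _initialize() is the inverse of the template normal matrix.  In matrix form (a sketch matching the loop above; symbols are illustrative):

    M = L^\top \mathrm{diag}(w\,T^2)\, L, \qquad P = M^{-1}

where L holds the Legendre polynomials over the view, T is the signal estimate, and w is the detector noise weight from the noise model (1 if no noise model is present).
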
noise_model = Unicode(None, allow_none=True, help='Observation key containing the noise model') class-attribute instance-attribute

order = Int(1, help='The order of Legendre polynomials to fit the gain amplitudes') class-attribute instance-attribute

template_name = Unicode(None, allow_none=True, help='detdata key encoding the signal estimate to fit the gain amplitudes') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/templates/gaintemplate.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_add_prior(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/gaintemplate.py
def _add_prior(self, amplitudes_in, amplitudes_out, **kwargs):
    # No prior for this template, nothing to accumulate to output.
    return

_add_to_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/gaintemplate.py
def _add_to_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    norder = self.order + 1
    offset = self._det_start[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
            legendre_poly = self._templates[iob][ivw]
            poly_amps = amplitudes.local[offset : offset + norder]
            delta_gain = legendre_poly.dot(poly_amps)
            signal_estimate = ob.detdata[self.template_name][detector]
            gain_fluctuation = signal_estimate * delta_gain
            vw[detector] += gain_fluctuation

_apply_precond(amplitudes_in, amplitudes_out, **kwargs)

Source code in toast/templates/gaintemplate.py
def _apply_precond(self, amplitudes_in, amplitudes_out, **kwargs):
    norder = self.order + 1
    for det in self._all_dets:
        offset = self._det_start[det]
        for iob, ob in enumerate(self.data.obs):
            if det not in self._obs_dets[iob]:
                continue
            views = ob.view[self.view]
            for ivw, vw in enumerate(views):
                amps_in = amplitudes_in.local[offset : offset + norder]
                amps_out = amplitudes_out.local[offset : offset + norder]
                amps_out[:] = np.dot(self._precond[iob][ivw][det], amps_in)
                offset += norder

_detectors()

Source code in toast/templates/gaintemplate.py
def _detectors(self):
    return self._all_dets

_get_polynomials(N)

Source code in toast/templates/gaintemplate.py
def _get_polynomials(self, N):
    norder = self.order + 1
    L = np.zeros((N, norder), dtype=np.float64)
    x = np.linspace(-1.0, 1.0, num=N, endpoint=True)
    for i in range(norder):
        L[:, i] = scipy.special.legendre(i)(x)
    return L

_initialize(new_data)

Source code in toast/templates/gaintemplate.py
def _initialize(self, new_data):
    self.norder = self.order + 1
    # Use this as an "Ordered Set".  We want the unique detectors on this process,
    # but sorted in order of occurrence.
    all_dets = OrderedDict()

    # Good detectors to use for each observation
    self._obs_dets = dict()

    # Build up detector list
    for iob, ob in enumerate(new_data.obs):
        self._obs_dets[iob] = set()
        for d in ob.select_local_detectors(flagmask=self.det_mask):
            if d not in ob.detdata[self.det_data].detectors:
                continue
            self._obs_dets[iob].add(d)
            if d not in all_dets:
                all_dets[d] = None

    self._all_dets = list(all_dets.keys())

    # The inverse variance units
    invvar_units = 1.0 / (self.det_data_units**2)

    # Go through the data one local detector at a time and compute the offsets into
    # the amplitudes.

    # The starting amplitude for each detector within the local amplitude data.
    self._det_start = dict()

    offset = 0
    for det in self._all_dets:
        self._det_start[det] = offset
        for iob, ob in enumerate(new_data.obs):
            if det not in self._obs_dets[iob]:
                continue
            # We have one set of amplitudes for each detector in each view
            offset += len(ob.view[self.view]) * (self.order + 1)

    # Now we know the total number of amplitudes.

    self._n_local = offset
    if new_data.comm.comm_world is None:
        self._n_global = self._n_local
    else:
        self._n_global = new_data.comm.comm_world.allreduce(
            self._n_local, op=MPI.SUM
        )

    # The preconditioner for each obs / view / detector
    self._precond = dict()
    self._templates = dict()

    # Build the preconditioner.  It is easier to build these by looping
    # in observation order rather than detector order.

    for iob, ob in enumerate(new_data.obs):
        # Build the templates and preconditioners for every view.
        self._templates[iob] = list()
        self._precond[iob] = dict()
        norder = self.order + 1

        noise = None
        if self.noise_model in ob:
            noise = ob[self.noise_model]
        views = ob.view[self.view]
        for ivw, vw in enumerate(views):
            view_len = None
            if vw.start is None:
                # This is a view of the whole obs
                view_len = ob.n_local_samples
            else:
                view_len = vw.stop - vw.start
            # get legendre polynomials
            L = self._get_polynomials(view_len)
            # store them in the template dictionary
            self._templates[iob].append(L)

            self._precond[iob][ivw] = dict()
            for det in ob.local_detectors:
                if det not in self._obs_dets[iob]:
                    continue
                detweight = 1.0
                if noise is not None:
                    detweight = noise.detector_weight(det).to_value(invvar_units)

                good = slice(0, view_len, 1)
                if self.det_flags is not None:
                    flags = views.detdata[self.det_flags][ivw][det]
                    good = (flags & self.det_flag_mask) == 0

                prec = np.zeros((norder, norder), dtype=np.float64)
                T = ob.detdata[self.template_name][det]

                LT = L.T.copy()
                for row in LT:
                    row *= T * np.sqrt(detweight)
                M = LT.dot(LT.T)
                self._precond[iob][ivw][det] = np.linalg.inv(M)

_project_signal(detector, amplitudes, **kwargs)

Source code in toast/templates/gaintemplate.py
def _project_signal(self, detector, amplitudes, **kwargs):
    if detector not in self._all_dets:
        # This must have been cut by per-detector flags during initialization
        return
    norder = self.order + 1
    offset = self._det_start[detector]
    for iob, ob in enumerate(self.data.obs):
        if detector not in self._obs_dets[iob]:
            continue
        for ivw, vw in enumerate(ob.view[self.view].detdata[self.det_data]):
            legendre_poly = self._templates[iob][ivw]
            signal_estimate = ob.detdata[self.template_name][detector]
            if self.det_flags is not None:
                flagview = ob.view[self.view].detdata[self.det_flags][ivw]
                mask = (flagview[detector] & self.det_flag_mask) == 0
            else:
                mask = 1
            LT = legendre_poly.T.copy()
            for row in LT:
                row *= signal_estimate
            poly_amps = amplitudes.local[offset : offset + norder]
            poly_amps += np.dot(LT, vw[detector] * mask)

_zeros()

Source code in toast/templates/gaintemplate.py
def _zeros(self):
    z = Amplitudes(self.data.comm, self._n_global, self._n_local)
    # No explicit flagging of amplitudes in this template...
    # z.local_flags[:] = np.where(self._amp_flags, 1, 0)
    return z

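Putting the pieces together, a GainTemplate might be constructed as follows.  This is a minimal sketch: the import path mirrors the headings on this page, and the trait values (the polynomial order, the "sky_estimate" detdata key, and the "noise_model" observation key) are assumptions for illustration.

from toast.templates import GainTemplate

# Hypothetical construction: fit third-order Legendre gain drifts against
# a signal estimate stored in the "sky_estimate" detdata key, weighting
# detectors with an existing noise model.
gain = GainTemplate(
    name="gain",
    order=3,
    template_name="sky_estimate",
    noise_model="noise_model",
)
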
toast.ops.TemplateMatrix

Bases: Operator

Operator for projecting or accumulating template amplitudes.

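A common pattern, suggested by the duplicate() method below, is to use one instance of this operator to accumulate amplitudes (the transpose) and a shallow copy of it to project them back into timestreams.  This is a sketch; the template instance and the "hwpss_amps" / "signal" keys are assumptions.

from toast.ops import TemplateMatrix
from toast.templates import Hwpss

# Projection operator: amplitudes -> timestreams
tmatrix = TemplateMatrix(
    templates=[Hwpss(name="hwpss")],
    amplitudes="hwpss_amps",
    det_data="signal",
)

# Shallow copy sharing the same template list, applied as the transpose:
# timestreams -> amplitudes
tmatrix_t = tmatrix.duplicate()
tmatrix_t.transpose = True
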
Source code in toast/ops/mapmaker_templates.py
@trait_docs
class TemplateMatrix(Operator):
    """Operator for projecting or accumulating template amplitudes."""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    templates = List([], help="This should be a list of Template instances")

    amplitudes = Unicode(None, allow_none=True, help="Data key for template amplitudes")

    transpose = Bool(False, help="If True, apply the transpose.")

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_data = Unicode(
        None, allow_none=True, help="Observation detdata key for the timestream data"
    )

    det_data_units = Unit(
        defaults.det_data_units, help="Output units if creating detector data"
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Flag mask should be a positive integer")
        return check

    @traitlets.validate("templates")
    def _check_templates(self, proposal):
        temps = proposal["value"]
        for tp in temps:
            if not isinstance(tp, Template):
                raise traitlets.TraitError(
                    "templates must be a list of Template instances or None"
                )
        return temps

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._initialized = False

    def reset(self):
        """Reset templates to allow re-initialization on a new Data object."""
        self._initialized = False

    def duplicate(self):
        """Make a shallow copy which contains the same list of templates.

        This is useful when we want to use both a template matrix and its transpose
        in the same pipeline.

        Returns:
            (TemplateMatrix):  A new instance with the same templates.

        """
        ret = TemplateMatrix(
            API=self.API,
            templates=self.templates,
            amplitudes=self.amplitudes,
            transpose=self.transpose,
            view=self.view,
            det_data=self.det_data,
            det_data_units=self.det_data_units,
            det_mask=self.det_mask,
            det_flags=self.det_flags,
            det_flag_mask=self.det_flag_mask,
        )
        ret._initialized = self._initialized
        return ret

    def apply_precond(self, amps_in, amps_out, use_accel=None, **kwargs):
        """Apply the preconditioner from all templates to the amplitudes.

        This can only be called after the operator has been used at least once so that
        the templates are initialized.

        Args:
            amps_in (AmplitudesMap):  The input amplitudes.
            amps_out (AmplitudesMap):  The output amplitudes, modified in place.

        Returns:
            None

        """
        if not self._initialized:
            raise RuntimeError(
                "You must call exec() once before applying preconditioners"
            )
        for tmpl in self.templates:
            if tmpl.enabled:
                tmpl.apply_precond(
                    amps_in[tmpl.name],
                    amps_out[tmpl.name],
                    use_accel=use_accel,
                    **kwargs,
                )

    def add_prior(self, amps_in, amps_out, use_accel=None, **kwargs):
        """Apply the noise prior from all templates to the amplitudes.

        This can only be called after the operator has been used at least once so that
        the templates are initialized.

        Args:
            amps_in (AmplitudesMap):  The input amplitudes.
            amps_out (AmplitudesMap):  The output amplitudes, modified in place.

        Returns:
            None

        """
        if not self._initialized:
            raise RuntimeError(
                "You must call exec() once before applying the noise prior"
            )
        for tmpl in self.templates:
            if tmpl.enabled:
                tmpl.add_prior(
                    amps_in[tmpl.name],
                    amps_out[tmpl.name],
                    use_accel=use_accel,
                    **kwargs,
                )

    @property
    def n_enabled_templates(self):
        n_enabled_templates = 0
        for template in self.templates:
            if template.enabled:
                n_enabled_templates += 1
        return n_enabled_templates

    def reset_templates(self):
        """Mark templates to be re-initialized on next call to exec()."""
        self._initialized = False

    @function_timer
    def initialize(self, data, use_accel=False):
        if not self._initialized:
            if use_accel:
                # fail when a user tries to run the initialization pipeline on GPU
                raise RuntimeError(
                    "You cannot currently initialize templates on device (please disable accel for this operator/pipeline)."
                )
            for tmpl in self.templates:
                if not tmpl.enabled:
                    continue
                if tmpl.view is None:
                    tmpl.view = self.view
                tmpl.det_data_units = self.det_data_units
                tmpl.det_mask = self.det_mask
                tmpl.det_flags = self.det_flags
                tmpl.det_flag_mask = self.det_flag_mask
                # This next line will trigger calculation of the number
                # of amplitudes within each template.
                tmpl.data = data
            self._initialized = True

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()

        # Check that the detector data is set
        if self.det_data is None:
            raise RuntimeError("You must set the det_data trait before calling exec()")

        # Check that amplitudes is set
        if self.amplitudes is None:
            raise RuntimeError(
                "You must set the amplitudes trait before calling exec()"
            )

        if len(self.templates) == 0:
            log.debug_rank(
                "No templates in TemplateMatrix, nothing to do",
                comm=data.comm.comm_world,
            )
            return

        # Check that accelerator switch makes sense for this operator
        if use_accel is None:
            use_accel = False
        if use_accel and not self.supports_accel():
            msg = "Template matrix called with use_accel=True, "
            msg += "but does not support accelerators"
            raise RuntimeError(msg)

        # Ensure we have initialized templates with the full set of detectors.
        if not self._initialized:
            raise RuntimeError("You must call initialize() before calling exec()")

        # Set the data we are using for this execution
        for tmpl in self.templates:
            if tmpl.enabled:
                tmpl.det_data = self.det_data

        # We loop over detectors.  Internally, each template loops over observations
        # and ignores observations where the detector does not exist.
        all_dets = data.all_local_detectors(selection=detectors, flagmask=self.det_mask)

        if self.transpose:
            # Check that the incoming detector data in all observations has the correct
            # units.
            input_units = 1.0 / self.det_data_units
            for ob in data.obs:
                if self.det_data not in ob.detdata:
                    continue
                if ob.detdata[self.det_data].units != input_units:
                    msg = f"obs {ob.name} detdata {self.det_data}"
                    msg += f" does not have units of {input_units}"
                    msg += f" before template matrix projection"
                    log.error(msg)
                    raise RuntimeError(msg)

            if self.amplitudes not in data:
                # The output template amplitudes do not yet exist.
                # Create these with all zero values.
                data[self.amplitudes] = AmplitudesMap()
                for tmpl in self.templates:
                    if tmpl.enabled:
                        data[self.amplitudes][tmpl.name] = tmpl.zeros()
                if use_accel:
                    # We are running on the accelerator, so our output data must exist
                    # on the device and will be used there.
                    data[self.amplitudes].accel_create(self.name, zero_out=True)
                    data[self.amplitudes].accel_used(True)
            elif use_accel and not data[self.amplitudes].accel_exists():
                # The output template amplitudes exist on host, but are not yet
                # staged to the accelerator.
                data[self.amplitudes].accel_create(self.name)
                data[self.amplitudes].accel_update_device()

            for d in all_dets:
                for tmpl in self.templates:
                    if tmpl.enabled:
                        log.verbose(
                            f"TemplateMatrix {d} project_signal {tmpl.name} (use_accel={use_accel})"
                        )
                        tmpl.project_signal(
                            d,
                            data[self.amplitudes][tmpl.name],
                            use_accel=use_accel,
                            **kwargs,
                        )
        else:
            if self.amplitudes not in data:
                msg = f"Template amplitudes '{self.amplitudes}' do not exist in data"
                log.error(msg)
                raise RuntimeError(msg)

            # Ensure that our output detector data exists in each observation
            for ob in data.obs:
                # Get the detectors we are using for this observation
                dets = ob.select_local_detectors(
                    selection=detectors,
                    flagmask=self.det_mask,
                )
                if len(dets) == 0:
                    # Nothing to do for this observation
                    continue
                exists = ob.detdata.ensure(
                    self.det_data,
                    detectors=dets,
                    accel=use_accel,
                    create_units=self.det_data_units,
                )
                if exists:
                    # We need to clear our detector TOD before projecting amplitudes
                    # into timestreams.  Note:  in the accelerator case, the reset call
                    # will clear all detectors, not just the current list.  This is
                    # wasteful if det_data has a very large buffer that has been
                    # restricted to be used for a smaller number of detectors.  We
                    # should deal with that corner case eventually.  If the data was
                    # created, then it was already zeroed out.
                    ob.detdata[self.det_data].reset(dets=dets)

                ob.detdata[self.det_data].update_units(self.det_data_units)

            for d in all_dets:
                for tmpl in self.templates:
                    if tmpl.enabled:
                        log.verbose(
                            f"TemplateMatrix {d} add to signal {tmpl.name} (use_accel={use_accel})"
                        )
                        tmpl.add_to_signal(
                            d,
                            data[self.amplitudes][tmpl.name],
                            use_accel=use_accel,
                            **kwargs,
                        )
        return

    def _finalize(self, data, use_accel=None, **kwargs):
        if self.transpose:
            # move amplitudes to host as sync is CPU only
            if use_accel:
                data[self.amplitudes].accel_update_host()
            # Synchronize the result
            for tmpl in self.templates:
                if tmpl.enabled:
                    data[self.amplitudes][tmpl.name].sync()
            # move amplitudes back to GPU as it is NOT finalize's job to move data to host
            if use_accel:
                data[self.amplitudes].accel_update_device()
        # Note:  the internal initialization flag is not reset here.  Call
        # reset() or reset_templates() before processing completely new data sets.
        return

    def _requires(self):
        req = dict()
        req["detdata"] = [self.det_data]
        if self.view is not None:
            req["intervals"].append(self.view)
        if self.transpose:
            if self.det_flags is not None:
                req["detdata"].append(self.det_flags)
        else:
            req["global"] = [self.amplitudes]
        return req

    def _provides(self):
        prov = dict()
        if self.transpose:
            prov["global"] = [self.amplitudes]
        else:
            prov["detdata"] = [self.det_data]
        return prov

    def _implementations(self):
        """
        Find implementations supported by all the templates
        """
        implementations = {
            ImplementationType.DEFAULT,
            ImplementationType.COMPILED,
            ImplementationType.NUMPY,
            ImplementationType.JAX,
        }
        for tmpl in self.templates:
            implementations.intersection_update(tmpl.implementations())
        return list(implementations)

    def _supports_accel(self):
        """
        Returns True if all the templates are GPU compatible.
        """
        for tmpl in self.templates:
            if not tmpl.supports_accel():
                log = Logger.get()
                msg = f"{self} does not support accel because of '{tmpl.name}'"
                log.debug(msg)
                return False
        return True

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

_initialized = False instance-attribute

amplitudes = Unicode(None, allow_none=True, help='Data key for template amplitudes') class-attribute instance-attribute

det_data = Unicode(None, allow_none=True, help='Observation detdata key for the timestream data') class-attribute instance-attribute

det_data_units = Unit(defaults.det_data_units, help='Output units if creating detector data') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

n_enabled_templates property

templates = List([], help='This should be a list of Template instances') class-attribute instance-attribute

transpose = Bool(False, help='If True, apply the transpose.') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute
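
A usage illustration (a minimal sketch under assumptions, not part of the generated reference): the snippet below builds a TemplateMatrix around a single Offset template and projects detector data onto template amplitudes. The Data object data, the "signal" detdata key, and the Offset trait values are assumptions. Note that in transpose mode the unit check in _exec() expects the timestream to already be noise-weighted (units of 1 / det_data_units), which is how the solver applies this operator.

import astropy.units as u

import toast.ops
from toast.templates import Offset

# One enabled template: offset (baseline) amplitudes with 1 second steps.
tmatrix = toast.ops.TemplateMatrix(
    templates=[Offset(step_time=1.0 * u.second)],
    det_data="signal",
    amplitudes="template_amps",
    transpose=True,
)

# Templates must be initialized (on the host) before exec() / apply().
tmatrix.initialize(data)

# Applies M^T:  projects the (noise-weighted) detector data onto template
# amplitudes, creating data["template_amps"] as an AmplitudesMap.
tmatrix.apply(data)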

__init__(**kwargs)

Source code in toast/ops/mapmaker_templates.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self._initialized = False

_check_det_mask(proposal)

Source code in toast/ops/mapmaker_templates.py
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_flag_mask(proposal)

Source code in toast/ops/mapmaker_templates.py
@traitlets.validate("det_flag_mask")
def _check_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Flag mask should be a positive integer")
    return check

_check_templates(proposal)

Source code in toast/ops/mapmaker_templates.py
@traitlets.validate("templates")
def _check_templates(self, proposal):
    temps = proposal["value"]
    for tp in temps:
        if not isinstance(tp, Template):
            raise traitlets.TraitError(
                "templates must be a list of Template instances or None"
            )
    return temps

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()

    # Check that the detector data is set
    if self.det_data is None:
        raise RuntimeError("You must set the det_data trait before calling exec()")

    # Check that amplitudes is set
    if self.amplitudes is None:
        raise RuntimeError(
            "You must set the amplitudes trait before calling exec()"
        )

    if len(self.templates) == 0:
        log.debug_rank(
            "No templates in TemplateMatrix, nothing to do",
            comm=data.comm.comm_world,
        )
        return

    # Check that accelerator switch makes sense for this operator
    if use_accel is None:
        use_accel = False
    if use_accel and not self.supports_accel():
        msg = "Template matrix called with use_accel=True, "
        msg += "but does not support accelerators"
        raise RuntimeError(msg)

    # Ensure we have initialized templates with the full set of detectors.
    if not self._initialized:
        raise RuntimeError("You must call initialize() before calling exec()")

    # Set the data we are using for this execution
    for tmpl in self.templates:
        if tmpl.enabled:
            tmpl.det_data = self.det_data

    # We loop over detectors.  Internally, each template loops over observations
    # and ignores observations where the detector does not exist.
    all_dets = data.all_local_detectors(selection=detectors, flagmask=self.det_mask)

    if self.transpose:
        # Check that the incoming detector data in all observations has the correct
        # units.
        input_units = 1.0 / self.det_data_units
        for ob in data.obs:
            if self.det_data not in ob.detdata:
                continue
            if ob.detdata[self.det_data].units != input_units:
                msg = f"obs {ob.name} detdata {self.det_data}"
                msg += f" does not have units of {input_units}"
                msg += f" before template matrix projection"
                log.error(msg)
                raise RuntimeError(msg)

        if self.amplitudes not in data:
            # The output template amplitudes do not yet exist.
            # Create these with all zero values.
            data[self.amplitudes] = AmplitudesMap()
            for tmpl in self.templates:
                if tmpl.enabled:
                    data[self.amplitudes][tmpl.name] = tmpl.zeros()
            if use_accel:
                # We are running on the accelerator, so our output data must exist
                # on the device and will be used there.
                data[self.amplitudes].accel_create(self.name, zero_out=True)
                data[self.amplitudes].accel_used(True)
        elif use_accel and not data[self.amplitudes].accel_exists():
            # The output template amplitudes exist on host, but are not yet
            # staged to the accelerator.
            data[self.amplitudes].accel_create(self.name)
            data[self.amplitudes].accel_update_device()

        for d in all_dets:
            for tmpl in self.templates:
                if tmpl.enabled:
                    log.verbose(
                        f"TemplateMatrix {d} project_signal {tmpl.name} (use_accel={use_accel})"
                    )
                    tmpl.project_signal(
                        d,
                        data[self.amplitudes][tmpl.name],
                        use_accel=use_accel,
                        **kwargs,
                    )
    else:
        if self.amplitudes not in data:
            msg = f"Template amplitudes '{self.amplitudes}' do not exist in data"
            log.error(msg)
            raise RuntimeError(msg)

        # Ensure that our output detector data exists in each observation
        for ob in data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                selection=detectors,
                flagmask=self.det_mask,
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue
            exists = ob.detdata.ensure(
                self.det_data,
                detectors=dets,
                accel=use_accel,
                create_units=self.det_data_units,
            )
            if exists:
                # We need to clear our detector TOD before projecting amplitudes
                # into timestreams.  Note:  in the accelerator case, the reset call
                # will clear all detectors, not just the current list.  This is
                # wasteful if det_data has a very large buffer that has been
                # restricted to be used for a smaller number of detectors.  We
                # should deal with that corner case eventually.  If the data was
                # created, then it was already zeroed out.
                ob.detdata[self.det_data].reset(dets=dets)

            ob.detdata[self.det_data].update_units(self.det_data_units)

        for d in all_dets:
            for tmpl in self.templates:
                if tmpl.enabled:
                    log.verbose(
                        f"TemplateMatrix {d} add to signal {tmpl.name} (use_accel={use_accel})"
                    )
                    tmpl.add_to_signal(
                        d,
                        data[self.amplitudes][tmpl.name],
                        use_accel=use_accel,
                        **kwargs,
                    )
    return

_finalize(data, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker_templates.py
def _finalize(self, data, use_accel=None, **kwargs):
    if self.transpose:
        # move amplitudes to host as sync is CPU only
        if use_accel:
            data[self.amplitudes].accel_update_host()
        # Synchronize the result
        for tmpl in self.templates:
            if tmpl.enabled:
                data[self.amplitudes][tmpl.name].sync()
        # move amplitudes back to GPU as it is NOT finalize's job to move data to host
        if use_accel:
            data[self.amplitudes].accel_update_device()
    # Note:  the internal initialization flag is not reset here.  Call
    # reset() or reset_templates() before processing completely new data sets.
    return

_implementations()

Find implementations supported by all the templates

Source code in toast/ops/mapmaker_templates.py
def _implementations(self):
    """
    Find implementations supported by all the templates
    """
    implementations = {
        ImplementationType.DEFAULT,
        ImplementationType.COMPILED,
        ImplementationType.NUMPY,
        ImplementationType.JAX,
    }
    for tmpl in self.templates:
        implementations.intersection_update(tmpl.implementations())
    return list(implementations)

_provides()

Source code in toast/ops/mapmaker_templates.py
def _provides(self):
    prov = dict()
    if self.transpose:
        prov["global"] = [self.amplitudes]
    else:
        prov["detdata"] = [self.det_data]
    return prov

_requires()

Source code in toast/ops/mapmaker_templates.py
def _requires(self):
    req = dict()
    req["detdata"] = [self.det_data]
    if self.view is not None:
        req["intervals"].append(self.view)
    if self.transpose:
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
    else:
        req["global"] = [self.amplitudes]
    return req

_supports_accel()

Returns True if all the templates are GPU compatible.

Source code in toast/ops/mapmaker_templates.py
def _supports_accel(self):
    """
    Returns True if all the templates are GPU compatible.
    """
    for tmpl in self.templates:
        if not tmpl.supports_accel():
            log = Logger.get()
            msg = f"{self} does not support accel because of '{tmpl.name}'"
            log.debug(msg)
            return False
    return True

add_prior(amps_in, amps_out, use_accel=None, **kwargs)

Apply the noise prior from all templates to the amplitudes.

This can only be called after the operator has been used at least once so that the templates are initialized.

Parameters:

- amps_in (AmplitudesMap): The input amplitudes. Required.
- amps_out (AmplitudesMap): The output amplitudes, modified in place. Required.

Returns:

- None

Source code in toast/ops/mapmaker_templates.py
def add_prior(self, amps_in, amps_out, use_accel=None, **kwargs):
    """Apply the noise prior from all templates to the amplitudes.

    This can only be called after the operator has been used at least once so that
    the templates are initialized.

    Args:
        amps_in (AmplitudesMap):  The input amplitudes.
        amps_out (AmplitudesMap):  The output amplitudes, modified in place.

    Returns:
        None

    """
    if not self._initialized:
        raise RuntimeError(
            "You must call exec() once before applying the noise prior"
        )
    for tmpl in self.templates:
        if tmpl.enabled:
            tmpl.add_prior(
                amps_in[tmpl.name],
                amps_out[tmpl.name],
                use_accel=use_accel,
                **kwargs,
            )

apply_precond(amps_in, amps_out, use_accel=None, **kwargs)

Apply the preconditioner from all templates to the amplitudes.

This can only be called after the operator has been used at least once so that the templates are initialized.

Parameters:

- amps_in (AmplitudesMap): The input amplitudes. Required.
- amps_out (AmplitudesMap): The output amplitudes, modified in place. Required.

Returns:

- None

Source code in toast/ops/mapmaker_templates.py
def apply_precond(self, amps_in, amps_out, use_accel=None, **kwargs):
    """Apply the preconditioner from all templates to the amplitudes.

    This can only be called after the operator has been used at least once so that
    the templates are initialized.

    Args:
        amps_in (AmplitudesMap):  The input amplitudes.
        amps_out (AmplitudesMap):  The output amplitudes, modified in place.

    Returns:
        None

    """
    if not self._initialized:
        raise RuntimeError(
            "You must call exec() once before applying preconditioners"
        )
    for tmpl in self.templates:
        if tmpl.enabled:
            tmpl.apply_precond(
                amps_in[tmpl.name],
                amps_out[tmpl.name],
                use_accel=use_accel,
                **kwargs,
            )
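
Both add_prior() and apply_precond() operate key-by-key on AmplitudesMap objects. A minimal sketch of the calls a solver iteration makes (hypothetical variable names; in practice the built-in PCG solver constructs the amplitude objects and drives these calls):

# amps_in and amps_out are AmplitudesMap objects with one entry per
# enabled template; tmatrix has already been exec'd at least once.

# Overwrite amps_out with the preconditioner applied to amps_in,
# template by template.
tmatrix.apply_precond(amps_in, amps_out)

# Accumulate the noise prior evaluated on amps_in into amps_out (only
# templates that define a prior contribute); amps_out is modified in place.
tmatrix.add_prior(amps_in, amps_out)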

duplicate()

Make a shallow copy which contains the same list of templates.

This is useful when we want to use both a template matrix and its transpose in the same pipeline.

Returns:

- TemplateMatrix: A new instance with the same templates.

Source code in toast/ops/mapmaker_templates.py
def duplicate(self):
    """Make a shallow copy which contains the same list of templates.

    This is useful when we want to use both a template matrix and its transpose
    in the same pipeline.

    Returns:
        (TemplateMatrix):  A new instance with the same templates.

    """
    ret = TemplateMatrix(
        API=self.API,
        templates=self.templates,
        amplitudes=self.amplitudes,
        transpose=self.transpose,
        view=self.view,
        det_data=self.det_data,
        det_data_units=self.det_data_units,
        det_mask=self.det_mask,
        det_flags=self.det_flags,
        det_flag_mask=self.det_flag_mask,
    )
    ret._initialized = self._initialized
    return ret
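
A sketch of the pattern duplicate() enables: the same Template instances serve as both the projection (transpose) and synthesis (forward) operators in one pipeline, so their initialization state and amplitude layout stay consistent. The operator wiring here is illustrative; offsets is a hypothetical list of Template instances.

# Project data onto amplitudes with M^T ...
project = toast.ops.TemplateMatrix(
    templates=offsets,
    amplitudes="template_amps",
    det_data="signal",
    transpose=True,
)

# ... and synthesize timestreams with M, via a shallow copy that shares
# the same template instances and initialization state.
synthesize = project.duplicate()
synthesize.transpose = False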

initialize(data, use_accel=False)

Source code in toast/ops/mapmaker_templates.py
@function_timer
def initialize(self, data, use_accel=False):
    if not self._initialized:
        if use_accel:
            # fail when a user tries to run the initialization pipeline on GPU
            raise RuntimeError(
                "You cannot currently initialize templates on device (please disable accel for this operator/pipeline)."
            )
        for tmpl in self.templates:
            if not tmpl.enabled:
                continue
            if tmpl.view is None:
                tmpl.view = self.view
            tmpl.det_data_units = self.det_data_units
            tmpl.det_mask = self.det_mask
            tmpl.det_flags = self.det_flags
            tmpl.det_flag_mask = self.det_flag_mask
            # This next line will trigger calculation of the number
            # of amplitudes within each template.
            tmpl.data = data
        self._initialized = True

reset()

Reset templates to allow re-initialization on a new Data object.

Source code in toast/ops/mapmaker_templates.py
def reset(self):
    """Reset templates to allow re-initialization on a new Data object."""
    self._initialized = False

reset_templates()

Mark templates to be re-initialized on next call to exec().

Source code in toast/ops/mapmaker_templates.py
def reset_templates(self):
    """Mark templates to be re-initialized on next call to exec()."""
    self._initialized = False

toast.ops.SolveAmplitudes

Bases: Operator

Solve for template amplitudes.

This operator solves for a maximum likelihood set of template amplitudes that model the timestream contributions from noise, systematics, etc:

.. math:: \left[ M^T N^{-1} Z M + M_p \right] a = M^T N^{-1} Z d

Where a are the solved amplitudes and d is the input data. N is the diagonal time domain noise covariance. M is a matrix of templates that project from the amplitudes into the time domain, and the Z operator is given by:

.. math:: Z = I - P (P^T N^{-1} P)^{-1} P^T N^{-1}

or in terms of the binning operation:

.. math:: Z = I - P B

Where P is the pointing matrix. This operator takes one operator for the template matrix M and one operator for the binning, B. It then uses a conjugate gradient solver to solve for the amplitudes.
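
As a wiring illustration (a hedged sketch, not a canonical recipe): the sub-operators below follow the trait requirements checked by the _check_binning() validator in the source below, while the "signal" and "noise_model" keys, the trait values, and the existing Data object data are assumptions.

import astropy.units as u

import toast.ops
from toast.templates import Offset

detpointing = toast.ops.PointingDetectorSimple()
binner = toast.ops.BinMap(
    pixel_pointing=toast.ops.PixelsHealpix(
        nside=256, detector_pointing=detpointing
    ),
    stokes_weights=toast.ops.StokesWeights(
        mode="IQU", detector_pointing=detpointing
    ),
    noise_model="noise_model",  # assumed key from a prior noise estimate
)

solver = toast.ops.SolveAmplitudes(
    name="solver",
    det_data="signal",
    binning=binner,
    template_matrix=toast.ops.TemplateMatrix(
        templates=[Offset(step_time=1.0 * u.second)]
    ),
)
solver.apply(data)

# With the amplitudes trait left at None, the solved AmplitudesMap is
# stored under the derived key "<name>_solve_amplitudes" (see _setup()).
amps = data["solver_solve_amplitudes"]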

Source code in toast/ops/mapmaker_templates.py
@trait_docs
class SolveAmplitudes(Operator):
    """Solve for template amplitudes.

    This operator solves for a maximum likelihood set of template amplitudes
    that model the timestream contributions from noise, systematics, etc:

    .. math::
        \left[ M^T N^{-1} Z M + M_p \right] a = M^T N^{-1} Z d

    Where `a` are the solved amplitudes and `d` is the input data.  `N` is the
    diagonal time domain noise covariance.  `M` is a matrix of templates that
    project from the amplitudes into the time domain, and the `Z` operator is given
    by:

    .. math::
        Z = I - P (P^T N^{-1} P)^{-1} P^T N^{-1}

    or in terms of the binning operation:

    .. math::
        Z = I - P B

    Where `P` is the pointing matrix.  This operator takes one operator for the
    template matrix `M` and one operator for the binning, `B`.  It then
    uses a conjugate gradient solver to solve for the amplitudes.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key for the timestream data"
    )

    amplitudes = Unicode(None, allow_none=True, help="Data key for output amplitudes")

    convergence = Float(1.0e-12, help="Relative convergence limit")

    iter_min = Int(3, help="Minimum number of iterations")

    iter_max = Int(100, help="Maximum number of iterations")

    solve_rcond_threshold = Float(
        1.0e-8,
        help="When solving, minimum value for inverse pixel condition number cut.",
    )

    map_rcond_threshold = Float(
        1.0e-8,
        help="For final map, minimum value for inverse pixel condition number cut.",
    )

    mask = Unicode(
        None,
        allow_none=True,
        help="Data key for pixel mask to use in solving.  "
        "First bit of pixel values is tested",
    )

    binning = Instance(
        klass=Operator,
        allow_none=True,
        help="Binning operator used for solving template amplitudes",
    )

    template_matrix = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a template matrix operator",
    )

    keep_solver_products = Bool(
        False, help="If True, keep the map domain solver products in data"
    )

    write_solver_products = Bool(False, help="If True, write out solver products")

    write_hdf5 = Bool(
        False, help="If True, outputs are in HDF5 rather than FITS format."
    )

    write_hdf5_serial = Bool(
        False, help="If True, force serial HDF5 write of output maps."
    )

    output_dir = Unicode(
        ".",
        help="Write output data products to this directory",
    )

    mc_mode = Bool(False, help="If True, re-use solver flags, sparse covariances, etc")

    mc_index = Int(None, allow_none=True, help="The Monte-Carlo index")

    reset_pix_dist = Bool(
        False,
        help="Clear any existing pixel distribution.  Useful when applying "
        "repeatedly to different data objects.",
    )

    report_memory = Bool(False, help="Report memory throughout the execution")

    @traitlets.validate("binning")
    def _check_binning(self, proposal):
        bin = proposal["value"]
        if bin is not None:
            if not isinstance(bin, Operator):
                raise traitlets.TraitError("binning should be an Operator instance")
            # Check that this operator has the traits we require
            for trt in [
                "det_data",
                "pixel_dist",
                "pixel_pointing",
                "stokes_weights",
                "binned",
                "covariance",
                "det_flags",
                "det_mask",
                "det_flag_mask",
                "shared_flags",
                "shared_flag_mask",
                "noise_model",
                "full_pointing",
                "sync_type",
            ]:
                if not bin.has_trait(trt):
                    msg = "binning operator should have a '{}' trait".format(trt)
                    raise traitlets.TraitError(msg)
        return bin

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _write_del(self, prod_key):
        """Write and optionally delete map object"""

        # FIXME:  This I/O technique assumes "known" types of pixel representations.
        # Instead, we should associate read / write functions to a particular pixel
        # class.

        is_pix_wcs = hasattr(self.binning.pixel_pointing, "wcs")
        is_hpix_nest = None
        if not is_pix_wcs:
            is_hpix_nest = self.binning.pixel_pointing.nest

        if self.write_solver_products:
            if is_pix_wcs:
                fname = os.path.join(self.output_dir, f"{prod_key}.fits")
                write_wcs_fits(self._data[prod_key], fname)
            else:
                if self.write_hdf5:
                    # Non-standard HDF5 output
                    fname = os.path.join(self.output_dir, f"{prod_key}.h5")
                    write_healpix_hdf5(
                        self._data[prod_key],
                        fname,
                        nest=is_hpix_nest,
                        single_precision=True,
                        force_serial=self.write_hdf5_serial,
                    )
                else:
                    # Standard FITS output
                    fname = os.path.join(self.output_dir, f"{prod_key}.fits")
                    write_healpix_fits(
                        self._data[prod_key],
                        fname,
                        nest=is_hpix_nest,
                        report_memory=self.report_memory,
                    )

        if not self.mc_mode and not self.keep_solver_products:
            if prod_key in self._data:
                self._data[prod_key].clear()
                del self._data[prod_key]

                self._memreport.prefix = f"After writing/deleting {prod_key}"
                self._memreport.apply(self._data, use_accel=self._use_accel)

        return

    @function_timer
    def _setup(self, data, detectors, use_accel):
        """Set up convenience members used in the _exec() method"""

        self._log = Logger.get()
        self._timer = Timer()
        self._log_prefix = "SolveAmplitudes"

        self._data = data
        self._detectors = detectors
        self._use_accel = use_accel
        self._memreport = MemoryCounter()
        if not self.report_memory:
            self._memreport.enabled = False

        # The global communicator we are using (or None)
        self._comm = data.comm.comm_world
        self._rank = data.comm.world_rank

        # Get the units used across the distributed data for our desired
        # input detector data
        self._det_data_units = data.detector_units(self.det_data)

        # We use the input binning operator to define the flags that the user has
        # specified.  We will save the name / bit mask for these and restore them later.
        # Then we will use the binning operator with our solver flags.  These input
        # flags are combined to the first bit (== 1) of the solver flags.

        self._save_det_flags = self.binning.det_flags
        self._save_det_mask = self.binning.det_mask
        self._save_det_flag_mask = self.binning.det_flag_mask
        self._save_shared_flags = self.binning.shared_flags
        self._save_shared_flag_mask = self.binning.shared_flag_mask
        self._save_binned = self.binning.binned
        self._save_covariance = self.binning.covariance

        self._save_tmpl_flags = self.template_matrix.det_flags
        self._save_tmpl_mask = self.template_matrix.det_mask
        self._save_tmpl_det_mask = self.template_matrix.det_flag_mask

        # Use the same data view as the pointing operator in binning
        self._solve_view = self.binning.pixel_pointing.view

        # Output data products, prefixed with the name of the operator and optionally
        # the MC index.

        if self.mc_mode and self.mc_index is not None:
            self._mc_root = "{self.name}_{self.mc_index:05d}"
        else:
            self._mc_root = self.name

        self.solver_flags = f"{self.name}_solve_flags"
        self.solver_hits_name = f"{self.name}_solve_hits"
        self.solver_cov_name = f"{self.name}_solve_cov"
        self.solver_rcond_name = f"{self.name}_solve_rcond"
        self.solver_rcond_mask_name = f"{self.name}_solve_rcond_mask"
        self.solver_rhs = f"{self._mc_root}_solve_rhs"
        self.solver_bin = f"{self._mc_root}_solve_bin"

        if self.amplitudes is None:
            self.amplitudes = f"{self._mc_root}_solve_amplitudes"

        return

    @function_timer
    def _prepare_pixels(self):
        """Optionally destroy existing pixel distributions (useful if calling
        repeatedly with different data objects)
        """

        if self.reset_pix_dist:
            if self.binning.pixel_dist in self._data:
                del self._data[self.binning.pixel_dist]

        self._memreport.prefix = "After resetting pixel distribution"
        self._memreport.apply(self._data)

        # The pointing matrix used for the solve.  The per-detector flags
        # are normally reset when the binner is run, but here we set them
        # explicitly since we will use these pointing matrix operators for
        # setting up the solver flags below.
        solve_pixels = self.binning.pixel_pointing
        solve_weights = self.binning.stokes_weights
        solve_pixels.detector_pointing.det_mask = self._save_det_mask
        solve_pixels.detector_pointing.det_flag_mask = self._save_det_flag_mask
        if hasattr(solve_weights, "detector_pointing"):
            solve_weights.detector_pointing.det_mask = self._save_det_mask
            solve_weights.detector_pointing.det_flag_mask = self._save_det_flag_mask

        # Set up a pipeline to scan processing and condition number masks
        self._scanner = ScanMask(
            det_flags=self.solver_flags,
            det_mask=self._save_det_mask,
            det_flag_mask=self._save_det_flag_mask,
            pixels=solve_pixels.pixels,
            view=self._solve_view,
        )
        if self.binning.full_pointing:
            # We are caching the pointing anyway- run with all detectors
            scan_pipe = Pipeline(
                detector_sets=["ALL"], operators=[solve_pixels, self._scanner]
            )
        else:
            # Pipeline over detectors
            scan_pipe = Pipeline(
                detector_sets=["SINGLE"], operators=[solve_pixels, self._scanner]
            )

        return solve_pixels, solve_weights, scan_pipe

    @function_timer
    def _prepare_flagging_ob(self, ob):
        """Process a single observation, used by _prepare_flagging

        Copies and masks existing flags
        """

        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(self._detectors, flagmask=self._save_det_mask)
        if len(dets) == 0:
            # Nothing to do for this observation
            return

        if self.mc_mode:
            # Shortcut, just verify that our flags exist
            if self.solver_flags not in ob.detdata:
                msg = f"In MC mode, solver flags missing for observation {ob.name}"
                self._log.error(msg)
                raise RuntimeError(msg)
            det_check = set(ob.detdata[self.solver_flags].detectors)
            for d in dets:
                if d not in det_check:
                    msg = "In MC mode, solver flags missing for "
                    msg + f"observation {ob.name}, det {d}"
                    self._log.error(msg)
                    raise RuntimeError(msg)
            return

        # Create the new solver flags
        exists = ob.detdata.ensure(self.solver_flags, dtype=np.uint8, detectors=dets)

        # The data views
        views = ob.view[self._solve_view]
        # For each view...
        for vw in range(len(views)):
            view_samples = None
            if views[vw].start is None:
                # There is one view of the whole obs
                view_samples = ob.n_local_samples
            else:
                view_samples = views[vw].stop - views[vw].start
            starting_flags = np.zeros(view_samples, dtype=np.uint8)
            if self._save_shared_flags is not None:
                starting_flags[:] = np.where(
                    (
                        views.shared[self._save_shared_flags][vw]
                        & self._save_shared_flag_mask
                    )
                    > 0,
                    1,
                    0,
                )
            for d in dets:
                views.detdata[self.solver_flags][vw][d, :] = starting_flags
                if self._save_det_flags is not None:
                    views.detdata[self.solver_flags][vw][d, :] |= np.where(
                        (
                            views.detdata[self._save_det_flags][vw][d]
                            & self._save_det_flag_mask
                        )
                        > 0,
                        1,
                        0,
                    ).astype(views.detdata[self.solver_flags][vw].dtype)

        return

    @function_timer
    def _prepare_flagging(self, scan_pipe):
        """Flagging.  We create a new set of data flags for the solver that includes:
        - one bit for a bitwise OR of all detector / shared flags
        - one bit for any pixel mask, projected to TOD
        - one bit for any poorly conditioned pixels, projected to TOD
        """

        if self.mc_mode:
            msg = f"{self._log_prefix} begin verifying flags for solver"
        else:
            msg = f"{self._log_prefix} begin building flags for solver"
        self._log.info_rank(msg, comm=self._comm)

        for ob in self._data.obs:
            self._prepare_flagging_ob(ob)

        if self.mc_mode:
            # Shortcut, just verified that our flags exist
            self._log.info_rank(
                f"{self._log_prefix} MC mode, reusing flags for solver", comm=comm
            )
            return

        # Now scan any input mask to this same flag field.  We use the second
        # bit (== 2) for these mask flags.  For the input mask bit we check the
        # first bit of the pixel values.  This is noted in the help string for
        # the mask trait.  Note that we explicitly expand the pointing once
        # here and do not save it.  Even if we are eventually saving the
        # pointing, we want to do that later when building the covariance and
        # the pixel distribution.

        if self.mask is not None:
            # We have a mask.  Scan it.
            self._scanner.det_flags_value = 2
            self._scanner.mask_key = self.mask
            scan_pipe.apply(self._data, detectors=self._detectors)

        self._log.info_rank(
            f"{self._log_prefix}  finished flag building in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After building flags"
        self._memreport.apply(self._data)

        return

    def _count_cut_data(self):
        """Collect and report statistics about cut data"""
        local_total = 0
        local_cut = 0
        for ob in self._data.obs:
            # Get the detectors we are using for this observation
            dets = ob.select_local_detectors(
                self._detectors, flagmask=self._save_det_mask
            )
            if len(dets) == 0:
                # Nothing to do for this observation
                continue
            for vw in ob.view[self._solve_view].detdata[self.solver_flags]:
                for d in dets:
                    local_total += len(vw[d])
                    local_cut += np.count_nonzero(vw[d])

        if self._comm is None:
            total = local_total
            cut = local_cut
        else:
            total = self._comm.allreduce(local_total, op=MPI.SUM)
            cut = self._comm.allreduce(local_cut, op=MPI.SUM)

        frac = 100.0 * (cut / total)
        msg = f"Solver flags cut {cut } / {total} = {frac:0.2f}% of samples"
        self._log.info_rank(f"{self._log_prefix} {msg}", comm=self._comm)

        return

    @function_timer
    def _get_pixel_covariance(self, solve_pixels, solve_weights):
        """Construct the noise covariance, hits, and condition number map for
        the solver.
        """

        if self.mc_mode:
            # Shortcut, verify that our covariance and other products exist.
            if self.binning.pixel_dist not in self._data:
                msg = f"MC mode, pixel distribution "
                msg += f"'{self.binning.pixel_dist}' does not exist"
                self._log.error(msg)
                raise RuntimeError(msg)
            if self.solver_cov_name not in self._data:
                msg = f"MC mode, covariance '{self.solver_cov_name}' does not exist"
                self._log.error(msg)
                raise RuntimeError(msg)
            self._log.info_rank(
                f"{self._log_prefix} MC mode, reusing covariance for solver",
                comm=self._comm,
            )
            return

        self._log.info_rank(
            f"{self._log_prefix} begin build of solver covariance",
            comm=self._comm,
        )

        solver_cov = CovarianceAndHits(
            pixel_dist=self.binning.pixel_dist,
            covariance=self.solver_cov_name,
            hits=self.solver_hits_name,
            rcond=self.solver_rcond_name,
            det_data_units=self._det_data_units,
            det_mask=self._save_det_mask,
            det_flags=self.solver_flags,
            det_flag_mask=255,
            pixel_pointing=solve_pixels,
            stokes_weights=solve_weights,
            noise_model=self.binning.noise_model,
            rcond_threshold=self.solve_rcond_threshold,
            sync_type=self.binning.sync_type,
            save_pointing=self.binning.full_pointing,
        )

        solver_cov.apply(self._data, detectors=self._detectors)

        self._memreport.prefix = "After constructing covariance and hits"
        self._memreport.apply(self._data)

        return

    @function_timer
    def _get_rcond_mask(self, scan_pipe):
        """Construct the noise covariance, hits, and condition number mask for
        the solver.
        """

        if self.mc_mode:
            # The flags are already cached
            return

        self._log.info_rank(
            f"{self._log_prefix} begin build of rcond flags",
            comm=self._comm,
        )

        # Translate the rcond map into a mask
        self._data[self.solver_rcond_mask_name] = PixelData(
            self._data[self.binning.pixel_dist], dtype=np.uint8, n_value=1
        )
        rcond = self._data[self.solver_rcond_name].data
        rcond_mask = self._data[self.solver_rcond_mask_name].data
        bad = rcond < self.solve_rcond_threshold
        n_bad = np.count_nonzero(bad)
        n_good = rcond.size - n_bad
        rcond_mask[bad] = 1

        # No more need for the rcond map
        self._write_del(self.solver_rcond_name)

        self._memreport.prefix = "After constructing rcond mask"
        self._memreport.apply(self._data)

        # Re-use our mask scanning pipeline, setting third bit (== 4)
        self._scanner.det_flags_value = 4
        self._scanner.mask_key = self.solver_rcond_mask_name
        scan_pipe.apply(self._data, detectors=self._detectors)

        self._log.info_rank(
            f"{self._log_prefix}  finished build of solver covariance in",
            comm=self._comm,
            timer=self._timer,
        )

        self._count_cut_data()  # Report statistics

        return

    @function_timer
    def _get_rhs(self):
        """Compute the RHS.  Overwrite inputs, either the original or the copy"""

        self._log.info_rank(
            f"{self._log_prefix} begin RHS calculation", comm=self._comm
        )

        # Initialize the template matrix
        self.template_matrix.det_data = self.det_data
        self.template_matrix.det_data_units = self._det_data_units
        self.template_matrix.det_flags = self.solver_flags
        self.template_matrix.det_mask = self._save_det_mask
        self.template_matrix.det_flag_mask = 255
        self.template_matrix.view = self.binning.pixel_pointing.view
        self.template_matrix.initialize(self._data)

        # Set our binning operator to use only our new solver flags
        self.binning.shared_flag_mask = 0
        self.binning.det_flags = self.solver_flags
        self.binning.det_flag_mask = 255
        self.binning.det_data_units = self._det_data_units

        # Set the binning operator to output to temporary map.  This will be
        # overwritten on each iteration of the solver.
        self.binning.binned = self.solver_bin
        self.binning.covariance = self.solver_cov_name

        self.template_matrix.amplitudes = self.solver_rhs

        rhs_calc = SolverRHS(
            name=f"{self.name}_rhs",
            det_data=self.det_data,
            det_data_units=self._det_data_units,
            binning=self.binning,
            template_matrix=self.template_matrix,
        )
        rhs_calc.apply(self._data, detectors=self._detectors)

        self._log.info_rank(
            f"{self._log_prefix}  finished RHS calculation in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After constructing RHS"
        self._memreport.apply(self._data)

        return

    @function_timer
    def _solve_amplitudes(self):
        """Solve the destriping equation"""

        # Set up the LHS operator.

        self._log.info_rank(
            f"{self._log_prefix} begin PCG solver",
            comm=self._comm,
        )

        lhs_calc = SolverLHS(
            name=f"{self.name}_lhs",
            det_data_units=self._det_data_units,
            binning=self.binning,
            template_matrix=self.template_matrix,
        )

        # If we eventually want to support an input starting guess of the
        # amplitudes, we would need to ensure that data[amplitude_key] is set
        # at this point...

        # Solve for amplitudes.
        solve(
            self._data,
            self._detectors,
            lhs_calc,
            self.solver_rhs,
            self.amplitudes,
            convergence=self.convergence,
            n_iter_min=self.iter_min,
            n_iter_max=self.iter_max,
        )

        self._log.info_rank(
            f"{self._log_prefix}  finished solver in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After solving for amplitudes"
        self._memreport.apply(self._data)

        return

    @function_timer
    def _cleanup(self):
        """Clean up convenience members for _exec()"""

        # Restore flag names and masks to binning operator, in case it is being used
        # for the final map making or for other external operations.

        self.binning.det_flags = self._save_det_flags
        self.binning.det_mask = self._save_det_mask
        self.binning.det_flag_mask = self._save_det_flag_mask
        self.binning.shared_flags = self._save_shared_flags
        self.binning.shared_flag_mask = self._save_shared_flag_mask
        self.binning.binned = self._save_binned
        self.binning.covariance = self._save_covariance

        self.template_matrix.det_flags = self._save_tmpl_flags
        self.template_matrix.det_flag_mask = self._save_tmpl_det_mask
        self.template_matrix.det_mask = self._save_tmpl_mask
        # FIXME: this reset does not seem needed
        # if not self.mc_mode:
        #    self.template_matrix.reset_templates()

        del self._solve_view

        # Delete members used by the _exec() method
        del self._log
        del self._timer
        del self._log_prefix

        del self._data
        del self._detectors
        del self._use_accel
        del self._memreport

        del self._comm
        del self._rank

        del self._det_data_units

        del self._mc_root

        del self._scanner

        return

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        # Check if we have any templates
        if (
            self.template_matrix is None
            or self.template_matrix.n_enabled_templates == 0
        ):
            return

        self._setup(data, detectors, use_accel)

        self._memreport.prefix = "Start of amplitude solve"
        self._memreport.apply(self._data)

        solve_pixels, solve_weights, scan_pipe = self._prepare_pixels()

        self._timer.start()

        self._prepare_flagging(scan_pipe)

        self._get_pixel_covariance(solve_pixels, solve_weights)
        self._write_del(self.solver_hits_name)

        self._get_rcond_mask(scan_pipe)
        self._write_del(self.solver_rcond_mask_name)

        self._get_rhs()
        self._solve_amplitudes()

        self._write_del(self.solver_cov_name)
        self._write_del(self.solver_bin)

        if not self.mc_mode and not self.keep_solver_products:
            if self.solver_rhs in self._data:
                self._data[self.solver_rhs].clear()
                del self._data[self.solver_rhs]
            for ob in self._data.obs:
                del ob.detdata[self.solver_flags]

        self._memreport.prefix = "End of amplitude solve"
        self._memreport.apply(self._data)

        self._cleanup()

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        # This operator requires everything that its sub-operators need.
        req = self.binning.requires()
        if self.template_matrix is not None:
            req.update(self.template_matrix.requires())
        req["detdata"].append(self.det_data)
        return req

    def _provides(self):
        prov = dict()
        prov["global"] = [self.amplitudes]
        if self.keep_solver_products:
            prov["global"].extend(
                [
                    self.solver_hits_name,
                    self.solver_cov_name,
                    self.solver_rcond_name,
                    self.solver_rcond_mask_name,
                    self.solver_rhs,
                    self.solver_bin,
                ]
            )
            prov["detdata"] = [self.solver_flags]
        return prov
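
For quick reference, _setup() derives all solver product keys from the operator name, and _prepare_flagging() fixes the bit convention of the solver flags. A small sketch of inspecting these after a run with keep_solver_products=True (operator name "solver" and Data object data assumed):

# Keys derived in _setup() for an operator named "solver":
flags_key = "solver_solve_flags"  # per-sample solver flags in ob.detdata
hits_key = "solver_solve_hits"    # solver hit map
cov_key = "solver_solve_cov"      # solver pixel covariance

# Solver flag bits, per _prepare_flagging():
#   1 == bitwise OR of the input shared / detector flags
#   2 == input pixel mask, scanned into the timestream
#   4 == poorly conditioned pixels, scanned into the timestream
for ob in data.obs:
    for det in ob.local_detectors:
        flags = ob.detdata[flags_key][det]
        rcond_cut = (flags & 4) != 0  # samples cut by the rcond mask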

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

amplitudes = Unicode(None, allow_none=True, help='Data key for output amplitudes') class-attribute instance-attribute

binning = Instance(klass=Operator, allow_none=True, help='Binning operator used for solving template amplitudes') class-attribute instance-attribute

convergence = Float(1e-12, help='Relative convergence limit') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key for the timestream data') class-attribute instance-attribute

iter_max = Int(100, help='Maximum number of iterations') class-attribute instance-attribute

iter_min = Int(3, help='Minimum number of iterations') class-attribute instance-attribute

keep_solver_products = Bool(False, help='If True, keep the map domain solver products in data') class-attribute instance-attribute

map_rcond_threshold = Float(1e-08, help='For final map, minimum value for inverse pixel condition number cut.') class-attribute instance-attribute

mask = Unicode(None, allow_none=True, help='Data key for pixel mask to use in solving. First bit of pixel values is tested') class-attribute instance-attribute

mc_index = Int(None, allow_none=True, help='The Monte-Carlo index') class-attribute instance-attribute

mc_mode = Bool(False, help='If True, re-use solver flags, sparse covariances, etc') class-attribute instance-attribute

output_dir = Unicode('.', help='Write output data products to this directory') class-attribute instance-attribute

report_memory = Bool(False, help='Report memory throughout the execution') class-attribute instance-attribute

reset_pix_dist = Bool(False, help='Clear any existing pixel distribution. Useful when applying repeatedly to different data objects.') class-attribute instance-attribute

solve_rcond_threshold = Float(1e-08, help='When solving, minimum value for inverse pixel condition number cut.') class-attribute instance-attribute

template_matrix = Instance(klass=Operator, allow_none=True, help='This must be an instance of a template matrix operator') class-attribute instance-attribute

write_hdf5 = Bool(False, help='If True, outputs are in HDF5 rather than FITS format.') class-attribute instance-attribute

write_hdf5_serial = Bool(False, help='If True, force serial HDF5 write of output maps.') class-attribute instance-attribute

write_solver_products = Bool(False, help='If True, write out solver products') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/mapmaker_templates.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_check_binning(proposal)

Source code in toast/ops/mapmaker_templates.py
@traitlets.validate("binning")
def _check_binning(self, proposal):
    bin = proposal["value"]
    if bin is not None:
        if not isinstance(bin, Operator):
            raise traitlets.TraitError("binning should be an Operator instance")
        # Check that this operator has the traits we require
        for trt in [
            "det_data",
            "pixel_dist",
            "pixel_pointing",
            "stokes_weights",
            "binned",
            "covariance",
            "det_flags",
            "det_mask",
            "det_flag_mask",
            "shared_flags",
            "shared_flag_mask",
            "noise_model",
            "full_pointing",
            "sync_type",
        ]:
            if not bin.has_trait(trt):
                msg = "binning operator should have a '{}' trait".format(trt)
                raise traitlets.TraitError(msg)
    return bin

_cleanup()

Clean up convenience members for _exec()

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _cleanup(self):
    """Clean up convenience members for _exec()"""

    # Restore flag names and masks to binning operator, in case it is being used
    # for the final map making or for other external operations.

    self.binning.det_flags = self._save_det_flags
    self.binning.det_mask = self._save_det_mask
    self.binning.det_flag_mask = self._save_det_flag_mask
    self.binning.shared_flags = self._save_shared_flags
    self.binning.shared_flag_mask = self._save_shared_flag_mask
    self.binning.binned = self._save_binned
    self.binning.covariance = self._save_covariance

    self.template_matrix.det_flags = self._save_tmpl_flags
    self.template_matrix.det_flag_mask = self._save_tmpl_det_mask
    self.template_matrix.det_mask = self._save_tmpl_mask
    # FIXME: this reset does not seem needed
    # if not self.mc_mode:
    #    self.template_matrix.reset_templates()

    del self._solve_view

    # Delete members used by the _exec() method
    del self._log
    del self._timer
    del self._log_prefix

    del self._data
    del self._detectors
    del self._use_accel
    del self._memreport

    del self._comm
    del self._rank

    del self._det_data_units

    del self._mc_root

    del self._scanner

    return

_count_cut_data()

Collect and report statistics about cut data

Source code in toast/ops/mapmaker_templates.py
def _count_cut_data(self):
    """Collect and report statistics about cut data"""
    local_total = 0
    local_cut = 0
    for ob in self._data.obs:
        # Get the detectors we are using for this observation
        dets = ob.select_local_detectors(
            self._detectors, flagmask=self._save_det_mask
        )
        if len(dets) == 0:
            # Nothing to do for this observation
            continue
        for vw in ob.view[self._solve_view].detdata[self.solver_flags]:
            for d in dets:
                local_total += len(vw[d])
                local_cut += np.count_nonzero(vw[d])

    if self._comm is None:
        total = local_total
        cut = local_cut
    else:
        total = self._comm.allreduce(local_total, op=MPI.SUM)
        cut = self._comm.allreduce(local_cut, op=MPI.SUM)

    frac = 100.0 * (cut / total)
    msg = f"Solver flags cut {cut } / {total} = {frac:0.2f}% of samples"
    self._log.info_rank(f"{self._log_prefix} {msg}", comm=self._comm)

    return

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    # Check if we have any templates
    if (
        self.template_matrix is None
        or self.template_matrix.n_enabled_templates == 0
    ):
        return

    self._setup(data, detectors, use_accel)

    self._memreport.prefix = "Start of amplitude solve"
    self._memreport.apply(self._data)

    solve_pixels, solve_weights, scan_pipe = self._prepare_pixels()

    self._timer.start()

    self._prepare_flagging(scan_pipe)

    self._get_pixel_covariance(solve_pixels, solve_weights)
    self._write_del(self.solver_hits_name)

    self._get_rcond_mask(scan_pipe)
    self._write_del(self.solver_rcond_mask_name)

    self._get_rhs()
    self._solve_amplitudes()

    self._write_del(self.solver_cov_name)
    self._write_del(self.solver_bin)

    if not self.mc_mode and not self.keep_solver_products:
        if self.solver_rhs in self._data:
            self._data[self.solver_rhs].clear()
            del self._data[self.solver_rhs]
        for ob in self._data.obs:
            del ob.detdata[self.solver_flags]

    self._memreport.prefix = "End of amplitude solve"
    self._memreport.apply(self._data)

    self._cleanup()

    return

_finalize(data, **kwargs)

Source code in toast/ops/mapmaker_templates.py
def _finalize(self, data, **kwargs):
    return

_get_pixel_covariance(solve_pixels, solve_weights)

Construct the noise covariance, hits, and condition number map for the solver.

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _get_pixel_covariance(self, solve_pixels, solve_weights):
    """Construct the noise covariance, hits, and condition number map for
    the solver.
    """

    if self.mc_mode:
        # Shortcut, verify that our covariance and other products exist.
        if self.binning.pixel_dist not in self._data:
            msg = f"MC mode, pixel distribution "
            msg += f"'{self.binning.pixel_dist}' does not exist"
            self._log.error(msg)
            raise RuntimeError(msg)
        if self.solver_cov_name not in self._data:
            msg = f"MC mode, covariance '{self.solver_cov_name}' does not exist"
            self._log.error(msg)
            raise RuntimeError(msg)
        self._log.info_rank(
            f"{self._log_prefix} MC mode, reusing covariance for solver",
            comm=self._comm,
        )
        return

    self._log.info_rank(
        f"{self._log_prefix} begin build of solver covariance",
        comm=self._comm,
    )

    solver_cov = CovarianceAndHits(
        pixel_dist=self.binning.pixel_dist,
        covariance=self.solver_cov_name,
        hits=self.solver_hits_name,
        rcond=self.solver_rcond_name,
        det_data_units=self._det_data_units,
        det_mask=self._save_det_mask,
        det_flags=self.solver_flags,
        det_flag_mask=255,
        pixel_pointing=solve_pixels,
        stokes_weights=solve_weights,
        noise_model=self.binning.noise_model,
        rcond_threshold=self.solve_rcond_threshold,
        sync_type=self.binning.sync_type,
        save_pointing=self.binning.full_pointing,
    )

    solver_cov.apply(self._data, detectors=self._detectors)

    self._memreport.prefix = "After constructing covariance and hits"
    self._memreport.apply(self._data)

    return

_get_rcond_mask(scan_pipe)

Translate the pixel condition number map into a sample flag mask for the solver.

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _get_rcond_mask(self, scan_pipe):
    """Construct the noise covariance, hits, and condition number mask for
    the solver.
    """

    if self.mc_mode:
        # The flags are already cached
        return

    self._log.info_rank(
        f"{self._log_prefix} begin build of rcond flags",
        comm=self._comm,
    )

    # Translate the rcond map into a mask
    self._data[self.solver_rcond_mask_name] = PixelData(
        self._data[self.binning.pixel_dist], dtype=np.uint8, n_value=1
    )
    rcond = self._data[self.solver_rcond_name].data
    rcond_mask = self._data[self.solver_rcond_mask_name].data
    bad = rcond < self.solve_rcond_threshold
    n_bad = np.count_nonzero(bad)
    n_good = rcond.size - n_bad
    rcond_mask[bad] = 1

    # No more need for the rcond map
    self._write_del(self.solver_rcond_name)

    self._memreport.prefix = "After constructing rcond mask"
    self._memreport.apply(self._data)

    # Re-use our mask scanning pipeline, setting third bit (== 4)
    self._scanner.det_flags_value = 4
    self._scanner.mask_key = self.solver_rcond_mask_name
    scan_pipe.apply(self._data, detectors=self._detectors)

    self._log.info_rank(
        f"{self._log_prefix}  finished build of solver covariance in",
        comm=self._comm,
        timer=self._timer,
    )

    self._count_cut_data()  # Report statistics

    return

_get_rhs()

Compute the RHS. Overwrite inputs, either the original or the copy

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _get_rhs(self):
    """Compute the RHS.  Overwrite inputs, either the original or the copy"""

    self._log.info_rank(
        f"{self._log_prefix} begin RHS calculation", comm=self._comm
    )

    # Initialize the template matrix
    self.template_matrix.det_data = self.det_data
    self.template_matrix.det_data_units = self._det_data_units
    self.template_matrix.det_flags = self.solver_flags
    self.template_matrix.det_mask = self._save_det_mask
    self.template_matrix.det_flag_mask = 255
    self.template_matrix.view = self.binning.pixel_pointing.view
    self.template_matrix.initialize(self._data)

    # Set our binning operator to use only our new solver flags
    self.binning.shared_flag_mask = 0
    self.binning.det_flags = self.solver_flags
    self.binning.det_flag_mask = 255
    self.binning.det_data_units = self._det_data_units

    # Set the binning operator to output to temporary map.  This will be
    # overwritten on each iteration of the solver.
    self.binning.binned = self.solver_bin
    self.binning.covariance = self.solver_cov_name

    self.template_matrix.amplitudes = self.solver_rhs

    rhs_calc = SolverRHS(
        name=f"{self.name}_rhs",
        det_data=self.det_data,
        det_data_units=self._det_data_units,
        binning=self.binning,
        template_matrix=self.template_matrix,
    )
    rhs_calc.apply(self._data, detectors=self._detectors)

    self._log.info_rank(
        f"{self._log_prefix}  finished RHS calculation in",
        comm=self._comm,
        timer=self._timer,
    )

    self._memreport.prefix = "After constructing RHS"
    self._memreport.apply(self._data)

    return

_prepare_flagging(scan_pipe)

Flagging.  We create a new set of data flags for the solver that includes:

- one bit for a bitwise OR of all detector / shared flags
- one bit for any pixel mask, projected to TOD
- one bit for any poorly conditioned pixels, projected to TOD

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _prepare_flagging(self, scan_pipe):
    """Flagging.  We create a new set of data flags for the solver that includes:
    - one bit for a bitwise OR of all detector / shared flags
    - one bit for any pixel mask, projected to TOD
    - one bit for any poorly conditioned pixels, projected to TOD
    """

    if self.mc_mode:
        msg = f"{self._log_prefix} begin verifying flags for solver"
    else:
        msg = f"{self._log_prefix} begin building flags for solver"
    self._log.info_rank(msg, comm=self._comm)

    for ob in self._data.obs:
        self._prepare_flagging_ob(ob)

    if self.mc_mode:
        # Shortcut, just verified that our flags exist
        self._log.info_rank(
            f"{self._log_prefix} MC mode, reusing flags for solver", comm=comm
        )
        return

    # Now scan any input mask to this same flag field.  We use the second
    # bit (== 2) for these mask flags.  For the input mask bit we check the
    # first bit of the pixel values.  This is noted in the help string for
    # the mask trait.  Note that we explicitly expand the pointing once
    # here and do not save it.  Even if we are eventually saving the
    # pointing, we want to do that later when building the covariance and
    # the pixel distribution.

    if self.mask is not None:
        # We have a mask.  Scan it.
        self._scanner.det_flags_value = 2
        self._scanner.mask_key = self.mask
        scan_pipe.apply(self._data, detectors=self._detectors)

    self._log.info_rank(
        f"{self._log_prefix}  finished flag building in",
        comm=self._comm,
        timer=self._timer,
    )

    self._memreport.prefix = "After building flags"
    self._memreport.apply(self._data)

    return

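For orientation: these three cuts are packed into a single uint8 solver flag field, using bit value 1 for the OR of the original detector and shared flags, bit value 2 for the input pixel mask, and bit value 4 for poorly conditioned pixels. A minimal illustrative sketch (not part of the TOAST source) of how the bits combine under the det_flag_mask=255 cut used by the solver:

import numpy as np

FLAG_INPUT = 1  # bitwise OR of the original detector / shared flags
FLAG_MASK = 2   # samples falling in pixels cut by the input mask
FLAG_RCOND = 4  # samples falling in poorly conditioned pixels

solver_flags = np.zeros(8, dtype=np.uint8)
solver_flags[2] |= FLAG_INPUT
solver_flags[5] |= FLAG_MASK | FLAG_RCOND

# The solver operators run with det_flag_mask=255, so any set bit cuts the sample
cut = (solver_flags & 255) != 0
print(cut.tolist())  # [False, False, True, False, False, True, False, False]
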
_prepare_flagging_ob(ob)

Process a single observation, used by _prepare_flagging

Copies and masks existing flags

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _prepare_flagging_ob(self, ob):
    """Process a single observation, used by _prepare_flagging

    Copies and masks existing flags
    """

    # Get the detectors we are using for this observation
    dets = ob.select_local_detectors(self._detectors, flagmask=self._save_det_mask)
    if len(dets) == 0:
        # Nothing to do for this observation
        return

    if self.mc_mode:
        # Shortcut, just verify that our flags exist
        if self.solver_flags not in ob.detdata:
            msg = f"In MC mode, solver flags missing for observation {ob.name}"
            self._log.error(msg)
            raise RuntimeError(msg)
        det_check = set(ob.detdata[self.solver_flags].detectors)
        for d in dets:
            if d not in det_check:
                msg = "In MC mode, solver flags missing for "
                msg + f"observation {ob.name}, det {d}"
                self._log.error(msg)
                raise RuntimeError(msg)
        return

    # Create the new solver flags
    exists = ob.detdata.ensure(self.solver_flags, dtype=np.uint8, detectors=dets)

    # The data views
    views = ob.view[self._solve_view]
    # For each view...
    for vw in range(len(views)):
        view_samples = None
        if views[vw].start is None:
            # There is one view of the whole obs
            view_samples = ob.n_local_samples
        else:
            view_samples = views[vw].stop - views[vw].start
        starting_flags = np.zeros(view_samples, dtype=np.uint8)
        if self._save_shared_flags is not None:
            starting_flags[:] = np.where(
                (
                    views.shared[self._save_shared_flags][vw]
                    & self._save_shared_flag_mask
                )
                > 0,
                1,
                0,
            )
        for d in dets:
            views.detdata[self.solver_flags][vw][d, :] = starting_flags
            if self._save_det_flags is not None:
                views.detdata[self.solver_flags][vw][d, :] |= np.where(
                    (
                        views.detdata[self._save_det_flags][vw][d]
                        & self._save_det_flag_mask
                    )
                    > 0,
                    1,
                    0,
                ).astype(views.detdata[self.solver_flags][vw].dtype)

    return

_prepare_pixels()

Optionally reset the pixel distribution, then set up the solver pointing operators and the mask scanning pipeline.

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _prepare_pixels(self):
    """Optionally destroy existing pixel distributions (useful if calling
    repeatedly with different data objects)
    """

    if self.reset_pix_dist:
        if self.binning.pixel_dist in self._data:
            del self._data[self.binning.pixel_dist]

    self._memreport.prefix = "After resetting pixel distribution"
    self._memreport.apply(self._data)

    # The pointing matrix used for the solve.  The per-detector flags
    # are normally reset when the binner is run, but here we set them
    # explicitly since we will use these pointing matrix operators for
    # setting up the solver flags below.
    solve_pixels = self.binning.pixel_pointing
    solve_weights = self.binning.stokes_weights
    solve_pixels.detector_pointing.det_mask = self._save_det_mask
    solve_pixels.detector_pointing.det_flag_mask = self._save_det_flag_mask
    if hasattr(solve_weights, "detector_pointing"):
        solve_weights.detector_pointing.det_mask = self._save_det_mask
        solve_weights.detector_pointing.det_flag_mask = self._save_det_flag_mask

    # Set up a pipeline to scan processing and condition number masks
    self._scanner = ScanMask(
        det_flags=self.solver_flags,
        det_mask=self._save_det_mask,
        det_flag_mask=self._save_det_flag_mask,
        pixels=solve_pixels.pixels,
        view=self._solve_view,
    )
    if self.binning.full_pointing:
        # We are caching the pointing anyway- run with all detectors
        scan_pipe = Pipeline(
            detector_sets=["ALL"], operators=[solve_pixels, self._scanner]
        )
    else:
        # Pipeline over detectors
        scan_pipe = Pipeline(
            detector_sets=["SINGLE"], operators=[solve_pixels, self._scanner]
        )

    return solve_pixels, solve_weights, scan_pipe

_provides()

Source code in toast/ops/mapmaker_templates.py
def _provides(self):
    prov = dict()
    prov["global"] = [self.amplitudes]
    if self.keep_solver_products:
        prov["global"].extend(
            [
                self.solver_hits_name,
                self.solver_cov_name,
                self.solver_rcond_name,
                self.solver_rcond_mask_name,
                self.solver_rhs,
                self.solver_bin,
            ]
        )
        prov["detdata"] = [self.solver_flags]
    return prov

_requires()

Source code in toast/ops/mapmaker_templates.py
def _requires(self):
    # This operator requires everything that its sub-operators need.
    req = self.binning.requires()
    if self.template_matrix is not None:
        req.update(self.template_matrix.requires())
    req["detdata"].append(self.det_data)
    return req

_setup(data, detectors, use_accel)

Set up convenience members used in the _exec() method

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _setup(self, data, detectors, use_accel):
    """Set up convenience members used in the _exec() method"""

    self._log = Logger.get()
    self._timer = Timer()
    self._log_prefix = "SolveAmplitudes"

    self._data = data
    self._detectors = detectors
    self._use_accel = use_accel
    self._memreport = MemoryCounter()
    if not self.report_memory:
        self._memreport.enabled = False

    # The global communicator we are using (or None)
    self._comm = data.comm.comm_world
    self._rank = data.comm.world_rank

    # Get the units used across the distributed data for our desired
    # input detector data
    self._det_data_units = data.detector_units(self.det_data)

    # We use the input binning operator to define the flags that the user has
    # specified.  We will save the name / bit mask for these and restore them later.
    # Then we will use the binning operator with our solver flags.  These input
    # flags are combined to the first bit (== 1) of the solver flags.

    self._save_det_flags = self.binning.det_flags
    self._save_det_mask = self.binning.det_mask
    self._save_det_flag_mask = self.binning.det_flag_mask
    self._save_shared_flags = self.binning.shared_flags
    self._save_shared_flag_mask = self.binning.shared_flag_mask
    self._save_binned = self.binning.binned
    self._save_covariance = self.binning.covariance

    self._save_tmpl_flags = self.template_matrix.det_flags
    self._save_tmpl_mask = self.template_matrix.det_mask
    self._save_tmpl_det_mask = self.template_matrix.det_flag_mask

    # Use the same data view as the pointing operator in binning
    self._solve_view = self.binning.pixel_pointing.view

    # Output data products, prefixed with the name of the operator and optionally
    # the MC index.

    if self.mc_mode and self.mc_index is not None:
        self._mc_root = "{self.name}_{self.mc_index:05d}"
    else:
        self._mc_root = self.name

    self.solver_flags = f"{self.name}_solve_flags"
    self.solver_hits_name = f"{self.name}_solve_hits"
    self.solver_cov_name = f"{self.name}_solve_cov"
    self.solver_rcond_name = f"{self.name}_solve_rcond"
    self.solver_rcond_mask_name = f"{self.name}_solve_rcond_mask"
    self.solver_rhs = f"{self._mc_root}_solve_rhs"
    self.solver_bin = f"{self._mc_root}_solve_bin"

    if self.amplitudes is None:
        self.amplitudes = f"{self._mc_root}_solve_amplitudes"

    return

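The data keys created above follow a simple scheme: products reused across realizations are prefixed with the operator name, while per-realization products (the RHS and the binned solver map) are prefixed with the MC root. An illustrative naming sketch for a hypothetical operator named "mapmaker" running in MC mode with mc_index=7:

name = "mapmaker"
mc_index = 7
mc_root = f"{name}_{mc_index:05d}"  # only differs from name in MC mode

print(f"{name}_solve_flags")   # mapmaker_solve_flags
print(f"{name}_solve_cov")     # mapmaker_solve_cov
print(f"{mc_root}_solve_rhs")  # mapmaker_00007_solve_rhs
print(f"{mc_root}_solve_bin")  # mapmaker_00007_solve_bin
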
_solve_amplitudes()

Solve the destriping equation

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _solve_amplitudes(self):
    """Solve the destriping equation"""

    # Set up the LHS operator.

    self._log.info_rank(
        f"{self._log_prefix} begin PCG solver",
        comm=self._comm,
    )

    lhs_calc = SolverLHS(
        name=f"{self.name}_lhs",
        det_data_units=self._det_data_units,
        binning=self.binning,
        template_matrix=self.template_matrix,
    )

    # If we eventually want to support an input starting guess of the
    # amplitudes, we would need to ensure that data[amplitude_key] is set
    # at this point...

    # Solve for amplitudes.
    solve(
        self._data,
        self._detectors,
        lhs_calc,
        self.solver_rhs,
        self.amplitudes,
        convergence=self.convergence,
        n_iter_min=self.iter_min,
        n_iter_max=self.iter_max,
    )

    self._log.info_rank(
        f"{self._log_prefix}  finished solver in",
        comm=self._comm,
        timer=self._timer,
    )

    self._memreport.prefix = "After solving for amplitudes"
    self._memreport.apply(self._data)

    return

_write_del(prod_key)

Write and optionally delete map object

Source code in toast/ops/mapmaker_templates.py
@function_timer
def _write_del(self, prod_key):
    """Write and optionally delete map object"""

    # FIXME:  This I/O technique assumes "known" types of pixel representations.
    # Instead, we should associate read / write functions to a particular pixel
    # class.

    is_pix_wcs = hasattr(self.binning.pixel_pointing, "wcs")
    is_hpix_nest = None
    if not is_pix_wcs:
        is_hpix_nest = self.binning.pixel_pointing.nest

    if self.write_solver_products:
        if is_pix_wcs:
            fname = os.path.join(self.output_dir, f"{prod_key}.fits")
            write_wcs_fits(self._data[prod_key], fname)
        else:
            if self.write_hdf5:
                # Non-standard HDF5 output
                fname = os.path.join(self.output_dir, f"{prod_key}.h5")
                write_healpix_hdf5(
                    self._data[prod_key],
                    fname,
                    nest=is_hpix_nest,
                    single_precision=True,
                    force_serial=self.write_hdf5_serial,
                )
            else:
                # Standard FITS output
                fname = os.path.join(self.output_dir, f"{prod_key}.fits")
                write_healpix_fits(
                    self._data[prod_key],
                    fname,
                    nest=is_hpix_nest,
                    report_memory=self.report_memory,
                )

    if not self.mc_mode and not self.keep_solver_products:
        if prod_key in self._data:
            self._data[prod_key].clear()
            del self._data[prod_key]

            self._memreport.prefix = f"After writing/deleting {prod_key}"
            self._memreport.apply(self._data, use_accel=self._use_accel)

    return

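The output file name is derived directly from the product key and the format flags above. A small sketch with hypothetical values:

import os

output_dir = "out"                # hypothetical output directory
prod_key = "mapmaker_solve_hits"  # hypothetical solver product key
write_hdf5 = False

ext = "h5" if write_hdf5 else "fits"
print(os.path.join(output_dir, f"{prod_key}.{ext}"))  # out/mapmaker_solve_hits.fits
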
High Level Tools

toast.ops.Calibrate

Bases: Operator

Operator for calibrating timestreams using solved templates.

This operator first solves for a maximum likelihood set of template amplitudes that model the timestream contributions from noise, systematics, etc:

.. math:: \left[ M^T N^{-1} Z M + M_p \right] a = M^T N^{-1} Z d

Where a are the solved amplitudes and d is the input data. N is the diagonal time domain noise covariance. M is a matrix of templates that project from the amplitudes into the time domain, and the Z operator is given by:

.. math:: Z = I - P (P^T N^{-1} P)^{-1} P^T N^{-1}

or in terms of the binning operation:

.. math:: Z = I - P B

Where P is the pointing matrix. This operator takes one operator for the template matrix M and one operator for the binning, B. It then uses a conjugate gradient solver to solve for the amplitudes.

After solving for the template amplitudes, they are projected into the time domain and the input data is element-wise divided by this.

If the result trait is not set, then the input is overwritten.

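A hedged usage sketch: the operator classes below exist in toast.ops and toast.templates, but the parameter values are illustrative and the populated data container is assumed to exist already:

import toast.ops
from toast.templates import Offset

det_pointing = toast.ops.PointingDetectorSimple()
pixels = toast.ops.PixelsHealpix(nside=64, detector_pointing=det_pointing)
weights = toast.ops.StokesWeights(mode="IQU", detector_pointing=det_pointing)

binner = toast.ops.BinMap(pixel_pointing=pixels, stokes_weights=weights)
templates = toast.ops.TemplateMatrix(templates=[Offset()])

calib = toast.ops.Calibrate(
    name="calib",
    binning=binner,
    template_matrix=templates,
    result="calibrated",  # write output here instead of overwriting det_data
)
calib.apply(data)  # data: an existing toast.Data container
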
Source code in toast/ops/mapmaker.py
@trait_docs
class Calibrate(Operator):
    """Operator for calibrating timestreams using solved templates.

    This operator first solves for a maximum likelihood set of template amplitudes
    that model the timestream contributions from noise, systematics, etc:

    .. math::
        \left[ M^T N^{-1} Z M + M_p \right] a = M^T N^{-1} Z d

    Where `a` are the solved amplitudes and `d` is the input data.  `N` is the
    diagonal time domain noise covariance.  `M` is a matrix of templates that
    project from the amplitudes into the time domain, and the `Z` operator is given
    by:

    .. math::
        Z = I - P (P^T N^{-1} P)^{-1} P^T N^{-1}

    or in terms of the binning operation:

    .. math::
        Z = I - P B

    Where `P` is the pointing matrix.  This operator takes one operator for the
    template matrix `M` and one operator for the binning, `B`.  It then
    uses a conjugate gradient solver to solve for the amplitudes.

    After solving for the template amplitudes, they are projected into the time
    domain and the input data is element-wise divided by this.

    If the result trait is not set, then the input is overwritten.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key for the timestream data"
    )

    result = Unicode(
        None, allow_none=True, help="Observation detdata key for the output"
    )

    convergence = Float(1.0e-12, help="Relative convergence limit")

    iter_min = Int(3, help="Minimum number of iterations")

    iter_max = Int(100, help="Maximum number of iterations")

    solve_rcond_threshold = Float(
        1.0e-8,
        help="When solving, minimum value for inverse pixel condition number cut.",
    )

    mask = Unicode(
        None,
        allow_none=True,
        help="Data key for pixel mask to use in solving.  First bit of pixel values is tested",
    )

    binning = Instance(
        klass=Operator,
        allow_none=True,
        help="Binning operator used for solving template amplitudes",
    )

    template_matrix = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a template matrix operator",
    )

    keep_solver_products = Bool(
        False, help="If True, keep the map domain solver products in data"
    )

    mc_mode = Bool(False, help="If True, re-use solver flags, sparse covariances, etc")

    mc_index = Int(None, allow_none=True, help="The Monte-Carlo index")

    mc_root = Unicode(None, allow_none=True, help="Root name for Monte Carlo products")

    reset_pix_dist = Bool(
        False,
        help="Clear any existing pixel distribution.  Useful when applying "
        "repeatedly to different data objects.",
    )

    report_memory = Bool(False, help="Report memory throughout the execution")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):
        log = Logger.get()
        timer = Timer()
        log_prefix = "Calibrate"

        memreport = MemoryCounter()
        if not self.report_memory:
            memreport.enabled = False

        memreport.prefix = "Start of calibration"
        memreport.apply(data, use_accel=use_accel)

        # The global communicator we are using (or None)
        comm = data.comm.comm_world
        rank = data.comm.world_rank

        timer.start()

        # Solve for template amplitudes
        amplitudes_solve = SolveAmplitudes(
            name=self.name,
            det_data=self.det_data,
            convergence=self.convergence,
            iter_min=self.iter_min,
            iter_max=self.iter_max,
            solve_rcond_threshold=self.solve_rcond_threshold,
            mask=self.mask,
            binning=self.binning,
            template_matrix=self.template_matrix,
            keep_solver_products=self.keep_solver_products,
            mc_mode=self.mc_mode,
            mc_index=self.mc_index,
            reset_pix_dist=self.reset_pix_dist,
            report_memory=self.report_memory,
        )
        amplitudes_solve.apply(data, detectors=detectors, use_accel=use_accel)

        log.info_rank(
            f"{log_prefix}  finished template amplitude solve in",
            comm=comm,
            timer=timer,
        )

        # Apply (divide) solved amplitudes.

        log.info_rank(
            f"{log_prefix} begin apply template amplitudes",
            comm=comm,
        )

        out_calib = self.det_data
        if self.result is not None:
            # We are writing out calibrated timestreams to a new set of detector
            # data rather than overwriting the inputs.  Here we create these output
            # timestreams if they do not exist.  We do this by copying the inputs,
            # since the application of the amplitudes below will zero these
            out_calib = self.result
            Copy(detdata=[(self.det_data, self.result)]).apply(
                data, use_accel=use_accel
            )

        amplitudes_apply = ApplyAmplitudes(
            op="divide",
            det_data=self.det_data,
            amplitudes=amplitudes_solve.amplitudes,
            template_matrix=self.template_matrix,
            output=out_calib,
        )
        amplitudes_apply.apply(data, detectors=detectors, use_accel=use_accel)

        log.info_rank(
            f"{log_prefix}  finished apply template amplitudes in",
            comm=comm,
            timer=timer,
        )

        memreport.prefix = "After calibration"
        memreport.apply(data, use_accel=use_accel)

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        # This operator requires everything that its sub-operators need.
        req = self.binning.requires()
        if self.template_matrix is not None:
            req.update(self.template_matrix.requires())
        req["detdata"].append(self.det_data)
        return req

    def _provides(self):
        prov = dict()
        prov["global"] = [self.binning.binned]
        if self.result is not None:
            prov["detdata"] = [self.result]
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

binning = Instance(klass=Operator, allow_none=True, help='Binning operator used for solving template amplitudes') class-attribute instance-attribute

convergence = Float(1e-12, help='Relative convergence limit') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key for the timestream data') class-attribute instance-attribute

iter_max = Int(100, help='Maximum number of iterations') class-attribute instance-attribute

iter_min = Int(3, help='Minimum number of iterations') class-attribute instance-attribute

keep_solver_products = Bool(False, help='If True, keep the map domain solver products in data') class-attribute instance-attribute

mask = Unicode(None, allow_none=True, help='Data key for pixel mask to use in solving. First bit of pixel values is tested') class-attribute instance-attribute

mc_index = Int(None, allow_none=True, help='The Monte-Carlo index') class-attribute instance-attribute

mc_mode = Bool(False, help='If True, re-use solver flags, sparse covariances, etc') class-attribute instance-attribute

mc_root = Unicode(None, allow_none=True, help='Root name for Monte Carlo products') class-attribute instance-attribute

report_memory = Bool(False, help='Report memory throughout the execution') class-attribute instance-attribute

reset_pix_dist = Bool(False, help='Clear any existing pixel distribution. Useful when applying repeatedly to different data objects.') class-attribute instance-attribute

result = Unicode(None, allow_none=True, help='Observation detdata key for the output') class-attribute instance-attribute

solve_rcond_threshold = Float(1e-08, help='When solving, minimum value for inverse pixel condition number cut.') class-attribute instance-attribute

template_matrix = Instance(klass=Operator, allow_none=True, help='This must be an instance of a template matrix operator') class-attribute instance-attribute

__init__(**kwargs)

Source code in toast/ops/mapmaker.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):
    log = Logger.get()
    timer = Timer()
    log_prefix = "Calibrate"

    memreport = MemoryCounter()
    if not self.report_memory:
        memreport.enabled = False

    memreport.prefix = "Start of calibration"
    memreport.apply(data, use_accel=use_accel)

    # The global communicator we are using (or None)
    comm = data.comm.comm_world
    rank = data.comm.world_rank

    timer.start()

    # Solve for template amplitudes
    amplitudes_solve = SolveAmplitudes(
        name=self.name,
        det_data=self.det_data,
        convergence=self.convergence,
        iter_min=self.iter_min,
        iter_max=self.iter_max,
        solve_rcond_threshold=self.solve_rcond_threshold,
        mask=self.mask,
        binning=self.binning,
        template_matrix=self.template_matrix,
        keep_solver_products=self.keep_solver_products,
        mc_mode=self.mc_mode,
        mc_index=self.mc_index,
        reset_pix_dist=self.reset_pix_dist,
        report_memory=self.report_memory,
    )
    amplitudes_solve.apply(data, detectors=detectors, use_accel=use_accel)

    log.info_rank(
        f"{log_prefix}  finished template amplitude solve in",
        comm=comm,
        timer=timer,
    )

    # Apply (divide) solved amplitudes.

    log.info_rank(
        f"{log_prefix} begin apply template amplitudes",
        comm=comm,
    )

    out_calib = self.det_data
    if self.result is not None:
        # We are writing out calibrated timestreams to a new set of detector
        # data rather than overwriting the inputs.  Here we create these output
        # timestreams if they do not exist.  We do this by copying the inputs,
        # since the application of the amplitudes below will zero these
        out_calib = self.result
        Copy(detdata=[(self.det_data, self.result)]).apply(
            data, use_accel=use_accel
        )

    amplitudes_apply = ApplyAmplitudes(
        op="divide",
        det_data=self.det_data,
        amplitudes=amplitudes_solve.amplitudes,
        template_matrix=self.template_matrix,
        output=out_calib,
    )
    amplitudes_apply.apply(data, detectors=detectors, use_accel=use_accel)

    log.info_rank(
        f"{log_prefix}  finished apply template amplitudes in",
        comm=comm,
        timer=timer,
    )

    memreport.prefix = "After calibration"
    memreport.apply(data, use_accel=use_accel)

    return

_finalize(data, **kwargs)

Source code in toast/ops/mapmaker.py
def _finalize(self, data, **kwargs):
    return

_provides()

Source code in toast/ops/mapmaker.py
def _provides(self):
    prov = dict()
    prov["global"] = [self.binning.binned]
    if self.result is not None:
        prov["detdata"] = [self.result]
    return prov

_requires()

Source code in toast/ops/mapmaker.py
def _requires(self):
    # This operator requires everything that its sub-operators need.
    req = self.binning.requires()
    if self.template_matrix is not None:
        req.update(self.template_matrix.requires())
    req["detdata"].append(self.det_data)
    return req

toast.ops.MapMaker

Bases: Operator

Operator for making maps.

This operator first solves for a maximum likelihood set of template amplitudes that model the timestream contributions from noise, systematics, etc:

.. math:: \left[ M^T N^{-1} Z M + M_p \right] a = M^T N^{-1} Z d

Where a are the solved amplitudes and d is the input data. N is the diagonal time domain noise covariance. M is a matrix of templates that project from the amplitudes into the time domain, and the Z operator is given by:

.. math:: Z = I - P (P^T N^{-1} P)^{-1} P^T N^{-1}

or in terms of the binning operation:

.. math:: Z = I - P B

Where P is the pointing matrix. This operator takes one operator for the template matrix M and one operator for the binning, B. It then uses a conjugate gradient solver to solve for the amplitudes.

After solving for the template amplitudes, a final map of the signal estimate is computed using a simple binning:

.. math:: MAP = ({P'}^T N^{-1} P')^{-1} {P'}^T N^{-1} (y - M a)

Where the "prime" indicates that this final map might be computed using a different pointing matrix than the one used to solve for the template amplitudes.

The template-subtracted detector timestreams are saved either in the input det_data key of each observation, or (if overwrite == False) in an obs.detdata key based on the name of this class instance.

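A hedged usage sketch, along the same lines as the Calibrate example above; binner and templates are assumed to be configured binning and template matrix operators, and the remaining trait values are illustrative:

mapper = toast.ops.MapMaker(
    name="mapmaker",
    binning=binner,
    template_matrix=templates,
    output_dir="out",
    write_hits=True,
    write_map=True,
    write_binmap=False,  # skip writing the map binned before template subtraction
)
mapper.apply(data)
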
Source code in toast/ops/mapmaker.py
@trait_docs
class MapMaker(Operator):
    """Operator for making maps.

    This operator first solves for a maximum likelihood set of template amplitudes
    that model the timestream contributions from noise, systematics, etc:

    .. math::
        \left[ M^T N^{-1} Z M + M_p \right] a = M^T N^{-1} Z d

    Where `a` are the solved amplitudes and `d` is the input data.  `N` is the
    diagonal time domain noise covariance.  `M` is a matrix of templates that
    project from the amplitudes into the time domain, and the `Z` operator is given
    by:

    .. math::
        Z = I - P (P^T N^{-1} P)^{-1} P^T N^{-1}

    or in terms of the binning operation:

    .. math::
        Z = I - P B

    Where `P` is the pointing matrix.  This operator takes one operator for the
    template matrix `M` and one operator for the binning, `B`.  It then
    uses a conjugate gradient solver to solve for the amplitudes.

    After solving for the template amplitudes, a final map of the signal estimate is
    computed using a simple binning:

    .. math::
        MAP = ({P'}^T N^{-1} P')^{-1} {P'}^T N^{-1} (y - M a)

    Where the "prime" indicates that this final map might be computed using a different
    pointing matrix than the one used to solve for the template amplitudes.

    The template-subtracted detector timestreams are saved either in the input
    `det_data` key of each observation, or (if overwrite == False) in an obs.detdata
    key based on the name of this class instance.

    """

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key for the timestream data"
    )

    convergence = Float(1.0e-12, help="Relative convergence limit")

    iter_min = Int(3, help="Minimum number of iterations")

    iter_max = Int(100, help="Maximum number of iterations")

    solve_rcond_threshold = Float(
        1.0e-8,
        help="When solving, minimum value for inverse pixel condition number cut.",
    )

    map_rcond_threshold = Float(
        1.0e-8,
        help="For final map, minimum value for inverse pixel condition number cut.",
    )

    mask = Unicode(
        None,
        allow_none=True,
        help="Data key for pixel mask to use in solving.  First bit of pixel values is tested",
    )

    binning = Instance(
        klass=Operator,
        allow_none=True,
        help="Binning operator used for solving template amplitudes",
    )

    template_matrix = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a template matrix operator",
    )

    map_binning = Instance(
        klass=Operator,
        allow_none=True,
        help="Binning operator for final map making.  Default is same as solver",
    )

    write_binmap = Bool(
        True, help="If True, write the projected map *before* template subtraction"
    )

    write_map = Bool(True, help="If True, write the projected map")

    write_hdf5 = Bool(
        False, help="If True, outputs are in HDF5 rather than FITS format."
    )

    write_hdf5_serial = Bool(
        False, help="If True, force serial HDF5 write of output maps."
    )

    write_noiseweighted_map = Bool(
        False,
        help="If True, write the noise-weighted map",
    )

    write_hits = Bool(True, help="If True, write the hits map")

    write_cov = Bool(True, help="If True, write the white noise covariance matrices.")

    write_invcov = Bool(
        False,
        help="If True, write the inverse white noise covariance matrices.",
    )

    write_rcond = Bool(True, help="If True, write the reciprocal condition numbers.")

    write_solver_products = Bool(
        False, help="If True, write out equivalent solver products."
    )

    keep_solver_products = Bool(
        False, help="If True, keep the map domain solver products in data"
    )

    keep_final_products = Bool(
        False, help="If True, keep the map domain products in data after write"
    )

    mc_mode = Bool(False, help="If True, re-use solver flags, sparse covariances, etc")

    mc_index = Int(None, allow_none=True, help="The Monte-Carlo index")

    save_cleaned = Bool(
        False, help="If True, save the template-subtracted detector timestreams"
    )

    overwrite_cleaned = Bool(
        False, help="If True and save_cleaned is True, overwrite the input data"
    )

    reset_pix_dist = Bool(
        False,
        help="Clear any existing pixel distribution.  Useful when applying "
        "repeatedly to different data objects.",
    )

    output_dir = Unicode(
        ".",
        help="Write output data products to this directory",
    )

    report_memory = Bool(False, help="Report memory throughout the execution")

    @traitlets.validate("map_binning")
    def _check_map_binning(self, proposal):
        bin = proposal["value"]
        if bin is not None:
            if not isinstance(bin, Operator):
                raise traitlets.TraitError("map_binning should be an Operator instance")
            # Check that this operator has the traits we expect
            for trt in [
                "det_data",
                "pixel_dist",
                "pixel_pointing",
                "stokes_weights",
                "binned",
                "covariance",
                "det_mask",
                "det_flags",
                "det_flag_mask",
                "shared_flags",
                "shared_flag_mask",
                "noise_model",
                "full_pointing",
                "sync_type",
            ]:
                if not bin.has_trait(trt):
                    msg = "map_binning operator should have a '{}' trait".format(trt)
                    raise traitlets.TraitError(msg)
        return bin

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @function_timer
    def _write_del(self, prod_key, prod_write, force, rootname):
        """Write data object to file and delete it from cache"""
        log = Logger.get()

        # FIXME:  This I/O technique assumes "known" types of pixel representations.
        # Instead, we should associate read / write functions to a particular pixel
        # class.

        if self.map_binning is not None and self.map_binning.enabled:
            map_binning = self.map_binning
        else:
            map_binning = self.binning

        if hasattr(map_binning.pixel_pointing, "wcs"):
            is_pix_wcs = True
        else:
            is_pix_wcs = False
            is_hpix_nest = map_binning.pixel_pointing.nest

        wtimer = Timer()
        wtimer.start()
        product = prod_key.replace(f"{self.name}_", "")
        if prod_write:
            if is_pix_wcs:
                fname = os.path.join(self.output_dir, f"{rootname}_{product}.fits")
                if self.mc_mode and not force and os.path.isfile(fname):
                    log.info_rank(f"Skipping existing file: {fname}", comm=self._comm)
                else:
                    write_wcs_fits(self._data[prod_key], fname)
            else:
                if self.write_hdf5:
                    # Non-standard HDF5 output
                    fname = os.path.join(self.output_dir, f"{rootname}_{product}.h5")
                    if self.mc_mode and not force and os.path.isfile(fname):
                        log.info_rank(
                            f"Skipping existing file: {fname}", comm=self._comm
                        )
                    else:
                        write_healpix_hdf5(
                            self._data[prod_key],
                            fname,
                            nest=is_hpix_nest,
                            single_precision=True,
                            force_serial=self.write_hdf5_serial,
                        )
                else:
                    # Standard FITS output
                    fname = os.path.join(self.output_dir, f"{rootname}_{product}.fits")
                    if self.mc_mode and not force and os.path.isfile(fname):
                        log.info_rank(
                            f"Skipping existing file: {fname}", comm=self._comm
                        )
                    else:
                        write_healpix_fits(
                            self._data[prod_key],
                            fname,
                            nest=is_hpix_nest,
                            report_memory=self.report_memory,
                        )
            log.info_rank(f"Wrote {fname} in", comm=self._comm, timer=wtimer)

        if not self.keep_final_products and not self.mc_mode:
            if prod_key in self._data:
                self._data[prod_key].clear()
                del self._data[prod_key]

        self._memreport.prefix = f"After writing/deleting {prod_key}"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        return

    @function_timer
    def _setup(self, data, detectors, use_accel):
        """Set up convenience members used in the _exec() method"""

        self._log = Logger.get()
        self._timer = Timer()
        self._log_prefix = "MapMaker"

        self._mc_root = self.name
        if self.mc_mode:
            if self.mc_root is not None:
                self._mc_root += f"_{self.mc_root}"
            if self.mc_index is not None:
                self._mc_root += f"_{self.mc_index:05d}"

        self._data = data
        self._detectors = detectors
        self._use_accel = use_accel
        self._memreport = MemoryCounter()
        if not self.report_memory:
            self._memreport.enabled = False

        # The global communicator we are using (or None)

        self._comm = data.comm.comm_world
        self._rank = data.comm.world_rank

        # Data names of outputs

        self.hits_name = f"{self.name}_hits"
        self.cov_name = f"{self.name}_cov"
        self.invcov_name = f"{self.name}_invcov"
        self.rcond_name = f"{self.name}_rcond"
        self.det_flag_name = f"{self.name}_flags"

        self.clean_name = f"{self.name}_cleaned"
        self.binmap_name = f"{self.name}_binmap"
        self.map_name = f"{self.name}_map"
        self.noiseweighted_map_name = f"{self.name}_noiseweighted_map"

        self._timer.start()

        return

    @function_timer
    def _fit_templates(self):
        """Solve for template amplitudes"""

        amplitudes_solve = SolveAmplitudes(
            name=self.name,
            det_data=self.det_data,
            convergence=self.convergence,
            iter_min=self.iter_min,
            iter_max=self.iter_max,
            solve_rcond_threshold=self.solve_rcond_threshold,
            mask=self.mask,
            binning=self.binning,
            template_matrix=self.template_matrix,
            keep_solver_products=self.keep_solver_products,
            write_solver_products=self.write_solver_products,
            write_hdf5=self.write_hdf5,
            write_hdf5_serial=self.write_hdf5_serial,
            output_dir=self.output_dir,
            mc_mode=self.mc_mode,
            mc_index=self.mc_index,
            reset_pix_dist=self.reset_pix_dist,
            report_memory=self.report_memory,
        )
        amplitudes_solve.apply(
            self._data, detectors=self._detectors, use_accel=self._use_accel
        )
        template_amplitudes = amplitudes_solve.amplitudes

        self._log.info_rank(
            f"{self._log_prefix}  finished template amplitude solve in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After solving amplitudes"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        return template_amplitudes

    @function_timer
    def _prepare_binning(self):
        """Set up the final map binning"""

        # Map binning operator
        if self.map_binning is not None and self.map_binning.enabled:
            map_binning = self.map_binning
        else:
            # Use the same binning used in the solver.
            map_binning = self.binning
        map_binning.pre_process = None
        map_binning.covariance = self.cov_name

        # Pixel distribution
        if self.reset_pix_dist:
            # Purge any stale products from previous runs
            for name in [
                self.hits_name,
                self.cov_name,
                self.invcov_name,
                self.rcond_name,
                self.clean_name,
                self.binmap_name,
                self.map_name,
                self.noiseweighted_map_name,
                map_binning.pixel_dist,
                map_binning.covariance,
            ]:
                if name in self._data:
                    del self._data[name]

        if map_binning.pixel_dist not in self._data:
            self._log.info_rank(
                f"{self._log_prefix} Caching pixel distribution",
                comm=self._comm,
            )
            pix_dist = BuildPixelDistribution(
                pixel_dist=map_binning.pixel_dist,
                pixel_pointing=map_binning.pixel_pointing,
                save_pointing=map_binning.full_pointing,
            )
            pix_dist.apply(self._data, use_accel=self._use_accel)
            self._log.info_rank(
                f"{self._log_prefix}  finished build of pixel distribution in",
                comm=self._comm,
                timer=self._timer,
            )

            self._memreport.prefix = "After pixel distribution"
            self._memreport.apply(self._data, use_accel=self._use_accel)

        return map_binning

    @function_timer
    def _build_pixel_covariance(self, map_binning):
        """Accumulate hits and pixel covariance"""

        if map_binning.covariance in self._data and self.mc_mode:
            # Covariance is already cached
            return

        # Construct the noise covariance, hits, and condition number
        # mask for the final binned map.

        self._log.info_rank(
            f"{self._log_prefix} begin build of final binning covariance",
            comm=self._comm,
        )

        final_cov = CovarianceAndHits(
            pixel_dist=map_binning.pixel_dist,
            covariance=map_binning.covariance,
            inverse_covariance=self.invcov_name,
            hits=self.hits_name,
            rcond=self.rcond_name,
            det_mask=map_binning.det_mask,
            det_flags=map_binning.det_flags,
            det_flag_mask=map_binning.det_flag_mask,
            det_data_units=map_binning.det_data_units,
            shared_flags=map_binning.shared_flags,
            shared_flag_mask=map_binning.shared_flag_mask,
            pixel_pointing=map_binning.pixel_pointing,
            stokes_weights=map_binning.stokes_weights,
            noise_model=map_binning.noise_model,
            rcond_threshold=self.map_rcond_threshold,
            sync_type=map_binning.sync_type,
            save_pointing=map_binning.full_pointing,
        )

        final_cov.apply(
            self._data, detectors=self._detectors, use_accel=self._use_accel
        )

        self._log.info_rank(
            f"{self._log_prefix}  finished build of final covariance in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After constructing final covariance and hits"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        # These data products are not needed later so they can be
        # written out and purged

        self._write_del(self.hits_name, self.write_hits, False, self.name)
        self._write_del(self.rcond_name, self.write_rcond, False, self.name)
        self._write_del(self.invcov_name, self.write_invcov, False, self.name)

        return

    @function_timer
    def _bin_and_write_raw_signal(self, map_binning):
        """Optionally bin and save an undestriped map"""

        if not self.write_binmap:
            return

        map_binning.det_data = self.det_data
        map_binning.binned = self.binmap_name
        map_binning.noiseweighted = None
        self._log.info_rank(
            f"{self._log_prefix} begin map binning",
            comm=self._comm,
        )
        map_binning.apply(
            self._data, detectors=self._detectors, use_accel=self._use_accel
        )
        self._log.info_rank(
            f"{self._log_prefix}  finished binning in",
            comm=self._comm,
            timer=self._timer,
        )
        self._write_del(self.binmap_name, self.write_binmap, True, self._mc_root)

        self._memreport.prefix = "After binning final map"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        return

    @function_timer
    def _clean_signal(self, template_amplitudes):
        if (
            self.template_matrix is None
            or self.template_matrix.n_enabled_templates == 0
        ):
            # No templates to subtract, bin the input signal
            out_cleaned = self.det_data
        else:
            # Apply (subtract) solved amplitudes.

            self._log.info_rank(
                f"{self._log_prefix} begin apply template amplitudes",
                comm=self._comm,
            )

            out_cleaned = self.clean_name
            if self.save_cleaned and self.overwrite_cleaned:
                # Modify data in place
                out_cleaned = None

            amplitudes_apply = ApplyAmplitudes(
                op="subtract",
                det_data=self.det_data,
                amplitudes=template_amplitudes,
                template_matrix=self.template_matrix,
                output=out_cleaned,
            )
            amplitudes_apply.apply(
                self._data, detectors=self._detectors, use_accel=self._use_accel
            )

            if not self.keep_solver_products:
                del self._data[template_amplitudes]

            self._log.info_rank(
                f"{self._log_prefix}  finished apply template amplitudes in",
                comm=self._comm,
                timer=self._timer,
            )

            self._memreport.prefix = "After subtracting templates"
            self._memreport.apply(self._data, use_accel=self._use_accel)

        return out_cleaned

    @function_timer
    def _bin_cleaned_signal(self, map_binning, out_cleaned):
        """Bin and save a map of the destriped signal"""

        self._log.info_rank(
            f"{self._log_prefix} begin final map binning",
            comm=self._comm,
        )

        if out_cleaned is None:
            map_binning.det_data = self.det_data
        else:
            map_binning.det_data = out_cleaned
        if self.write_noiseweighted_map or self.keep_final_products:
            map_binning.noiseweighted = self.noiseweighted_map_name
        map_binning.binned = self.map_name

        # Do the final binning
        map_binning.apply(
            self._data, detectors=self._detectors, use_accel=self._use_accel
        )

        self._log.info_rank(
            f"{self._log_prefix}  finished final binning in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After binning final map"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        return

    @function_timer
    def _purge_cleaned_tod(self):
        """If the cleaned TOD is not being returned, purge it"""

        if self.save_cleaned:
            return

        del_tod = Delete(detdata=[self.clean_name])
        del_tod.apply(self._data, use_accel=self._use_accel)

        self._memreport.prefix = "After purging cleaned TOD"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        return

    @function_timer
    def _write_maps(self):
        """Write and delete the outputs"""

        self._write_del(
            self.noiseweighted_map_name,
            self.write_noiseweighted_map,
            True,
            self._mc_root,
        )
        self._write_del(self.map_name, self.write_map, True, self._mc_root)
        self._write_del(self.cov_name, self.write_cov, False, self.name)

        self._log.info_rank(
            f"{self._log_prefix}  finished output write in",
            comm=self._comm,
            timer=self._timer,
        )

        return

    @function_timer
    def _closeout(self):
        """Explicitly delete members used by the _exec() method"""

        del self._log
        del self._timer
        del self._log_prefix
        del self._mc_root
        del self._data
        del self._detectors
        del self._use_accel
        del self._memreport
        del self._comm
        del self._rank

        return

    @function_timer
    def _exec(self, data, detectors=None, use_accel=None, **kwargs):

        # First confirm that there is at least one valid detector

        if self.map_binning is not None and self.map_binning.enabled:
            map_binning = self.map_binning
        else:
            # Use the same binning used in the solver.
            map_binning = self.binning
        all_local_dets = data.all_local_detectors(
            selection=detectors, flagmask=map_binning.det_mask
        )
        ndet = len(all_local_dets)
        if data.comm.comm_world is not None:
            ndet = data.comm.comm_world.allreduce(ndet, op=MPI.SUM)
        if ndet == 0:
            # No valid detectors, no mapmaking
            return

        # Destripe data and make maps

        self._setup(data, detectors, use_accel)

        self._memreport.prefix = "Start of mapmaking"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        template_amplitudes = self._fit_templates()

        map_binning = self._prepare_binning()

        self._build_pixel_covariance(map_binning)

        self._bin_and_write_raw_signal(map_binning)

        out_cleaned = self._clean_signal(template_amplitudes)

        if self.write_noiseweighted_map or self.write_map or self.keep_final_products:
            self._bin_cleaned_signal(map_binning, out_cleaned)

        self._purge_cleaned_tod()  # Potentially frees memory for writing maps

        self._write_maps()

        self._memreport.prefix = "End of mapmaking"
        self._memreport.apply(self._data, use_accel=self._use_accel)

        self._closeout()

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        # This operator requires everything that its sub-operators need.
        req = self.binning.requires()
        if self.template_matrix is not None:
            req.update(self.template_matrix.requires())
        if self.map_binning is not None:
            req.update(self.map_binning.requires())
        req["detdata"].append(self.det_data)
        return req

    def _provides(self):
        prov = dict()
        if self.map_binning is not None:
            prov["global"] = [self.map_binning.binned]
        else:
            prov["global"] = [self.binning.binned]
        return prov

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

binning = Instance(klass=Operator, allow_none=True, help='Binning operator used for solving template amplitudes') class-attribute instance-attribute

convergence = Float(1e-12, help='Relative convergence limit') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key for the timestream data') class-attribute instance-attribute

iter_max = Int(100, help='Maximum number of iterations') class-attribute instance-attribute

iter_min = Int(3, help='Minimum number of iterations') class-attribute instance-attribute

keep_final_products = Bool(False, help='If True, keep the map domain products in data after write') class-attribute instance-attribute

keep_solver_products = Bool(False, help='If True, keep the map domain solver products in data') class-attribute instance-attribute

map_binning = Instance(klass=Operator, allow_none=True, help='Binning operator for final map making. Default is same as solver') class-attribute instance-attribute

map_rcond_threshold = Float(1e-08, help='For final map, minimum value for inverse pixel condition number cut.') class-attribute instance-attribute

mask = Unicode(None, allow_none=True, help='Data key for pixel mask to use in solving. First bit of pixel values is tested') class-attribute instance-attribute

mc_index = Int(None, allow_none=True, help='The Monte-Carlo index') class-attribute instance-attribute

mc_mode = Bool(False, help='If True, re-use solver flags, sparse covariances, etc') class-attribute instance-attribute

output_dir = Unicode('.', help='Write output data products to this directory') class-attribute instance-attribute

overwrite_cleaned = Bool(False, help='If True and save_cleaned is True, overwrite the input data') class-attribute instance-attribute

report_memory = Bool(False, help='Report memory throughout the execution') class-attribute instance-attribute

reset_pix_dist = Bool(False, help='Clear any existing pixel distribution. Useful when applying repeatedly to different data objects.') class-attribute instance-attribute

save_cleaned = Bool(False, help='If True, save the template-subtracted detector timestreams') class-attribute instance-attribute

solve_rcond_threshold = Float(1e-08, help='When solving, minimum value for inverse pixel condition number cut.') class-attribute instance-attribute

template_matrix = Instance(klass=Operator, allow_none=True, help='This must be an instance of a template matrix operator') class-attribute instance-attribute

write_binmap = Bool(True, help='If True, write the projected map *before* template subtraction') class-attribute instance-attribute

write_cov = Bool(True, help='If True, write the white noise covariance matrices.') class-attribute instance-attribute

write_hdf5 = Bool(False, help='If True, outputs are in HDF5 rather than FITS format.') class-attribute instance-attribute

write_hdf5_serial = Bool(False, help='If True, force serial HDF5 write of output maps.') class-attribute instance-attribute

write_hits = Bool(True, help='If True, write the hits map') class-attribute instance-attribute

write_invcov = Bool(False, help='If True, write the inverse white noise covariance matrices.') class-attribute instance-attribute

write_map = Bool(True, help='If True, write the projected map') class-attribute instance-attribute

write_noiseweighted_map = Bool(False, help='If True, write the noise-weighted map') class-attribute instance-attribute

write_rcond = Bool(True, help='If True, write the reciprocal condition numbers.') class-attribute instance-attribute

write_solver_products = Bool(False, help='If True, write out equivalent solver products.') class-attribute instance-attribute
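
A minimal configuration sketch using these traits. The pointing, binning, and template operators below (PointingDetectorSimple, PixelsHealpix, StokesWeights, BinMap, TemplateMatrix, Offset) are the standard TOAST operators assumed to be available; trait values are illustrative, and `data` is an existing toast.Data instance:

import toast.ops
from astropy import units as u
from toast.templates import Offset

# Shared pointing model for both the solver and the final binning
det_pointing = toast.ops.PointingDetectorSimple()
pixels = toast.ops.PixelsHealpix(detector_pointing=det_pointing, nside=256)
weights = toast.ops.StokesWeights(detector_pointing=det_pointing, mode="IQU")

mapmaker = toast.ops.MapMaker(
    name="mapmaker",
    det_data="signal",
    binning=toast.ops.BinMap(pixel_pointing=pixels, stokes_weights=weights),
    template_matrix=toast.ops.TemplateMatrix(
        templates=[Offset(step_time=1.0 * u.second)]
    ),
    output_dir="out",
)
mapmaker.apply(data)  # writes mapmaker_hits, mapmaker_map, ... to output_dir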

__init__(**kwargs)

Source code in toast/ops/mapmaker.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

_bin_and_write_raw_signal(map_binning)

Optionally bin and save an undestriped map

Source code in toast/ops/mapmaker.py
@function_timer
def _bin_and_write_raw_signal(self, map_binning):
    """Optionally bin and save an undestriped map"""

    if not self.write_binmap:
        return

    map_binning.det_data = self.det_data
    map_binning.binned = self.binmap_name
    map_binning.noiseweighted = None
    self._log.info_rank(
        f"{self._log_prefix} begin map binning",
        comm=self._comm,
    )
    map_binning.apply(
        self._data, detectors=self._detectors, use_accel=self._use_accel
    )
    self._log.info_rank(
        f"{self._log_prefix}  finished binning in",
        comm=self._comm,
        timer=self._timer,
    )
    self._write_del(self.binmap_name, self.write_binmap, True, self._mc_root)

    self._memreport.prefix = "After binning final map"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    return

_bin_cleaned_signal(map_binning, out_cleaned)

Bin and save a map of the destriped signal

Source code in toast/ops/mapmaker.py
@function_timer
def _bin_cleaned_signal(self, map_binning, out_cleaned):
    """Bin and save a map of the destriped signal"""

    self._log.info_rank(
        f"{self._log_prefix} begin final map binning",
        comm=self._comm,
    )

    if out_cleaned is None:
        map_binning.det_data = self.det_data
    else:
        map_binning.det_data = out_cleaned
    if self.write_noiseweighted_map or self.keep_final_products:
        map_binning.noiseweighted = self.noiseweighted_map_name
    map_binning.binned = self.map_name

    # Do the final binning
    map_binning.apply(
        self._data, detectors=self._detectors, use_accel=self._use_accel
    )

    self._log.info_rank(
        f"{self._log_prefix}  finished final binning in",
        comm=self._comm,
        timer=self._timer,
    )

    self._memreport.prefix = "After binning final map"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    return

_build_pixel_covariance(map_binning)

Accumulate hits and pixel covariance

Source code in toast/ops/mapmaker.py
@function_timer
def _build_pixel_covariance(self, map_binning):
    """Accumulate hits and pixel covariance"""

    if map_binning.covariance in self._data and self.mc_mode:
        # Covariance is already cached
        return

    # Construct the noise covariance, hits, and condition number
    # mask for the final binned map.

    self._log.info_rank(
        f"{self._log_prefix} begin build of final binning covariance",
        comm=self._comm,
    )

    final_cov = CovarianceAndHits(
        pixel_dist=map_binning.pixel_dist,
        covariance=map_binning.covariance,
        inverse_covariance=self.invcov_name,
        hits=self.hits_name,
        rcond=self.rcond_name,
        det_mask=map_binning.det_mask,
        det_flags=map_binning.det_flags,
        det_flag_mask=map_binning.det_flag_mask,
        det_data_units=map_binning.det_data_units,
        shared_flags=map_binning.shared_flags,
        shared_flag_mask=map_binning.shared_flag_mask,
        pixel_pointing=map_binning.pixel_pointing,
        stokes_weights=map_binning.stokes_weights,
        noise_model=map_binning.noise_model,
        rcond_threshold=self.map_rcond_threshold,
        sync_type=map_binning.sync_type,
        save_pointing=map_binning.full_pointing,
    )

    final_cov.apply(
        self._data, detectors=self._detectors, use_accel=self._use_accel
    )

    self._log.info_rank(
        f"{self._log_prefix}  finished build of final covariance in",
        comm=self._comm,
        timer=self._timer,
    )

    self._memreport.prefix = "After constructing final covariance and hits"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    # These data products are not needed later so they can be
    # written out and purged

    self._write_del(self.hits_name, self.write_hits, False, self.name)
    self._write_del(self.rcond_name, self.write_rcond, False, self.name)
    self._write_del(self.invcov_name, self.write_invcov, False, self.name)

    return
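
The early return above is what makes `mc_mode` useful in Monte Carlo loops: once the covariance key is present in the data object, later realizations skip the accumulation entirely. A sketch of the intended pattern, continuing the example above (the signal-replacement step is schematic):

mapmaker.mc_mode = True
n_realizations = 10
for mc in range(n_realizations):
    mapmaker.mc_index = mc
    # ... replace the detector signal with a new simulated realization ...
    mapmaker.apply(data)  # covariance, hits, and rcond products are reused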

_check_map_binning(proposal)

Source code in toast/ops/mapmaker.py
@traitlets.validate("map_binning")
def _check_map_binning(self, proposal):
    bin = proposal["value"]
    if bin is not None:
        if not isinstance(bin, Operator):
            raise traitlets.TraitError("map_binning should be an Operator instance")
        # Check that this operator has the traits we expect
        for trt in [
            "det_data",
            "pixel_dist",
            "pixel_pointing",
            "stokes_weights",
            "binned",
            "covariance",
            "det_mask",
            "det_flags",
            "det_flag_mask",
            "shared_flags",
            "shared_flag_mask",
            "noise_model",
            "full_pointing",
            "sync_type",
        ]:
            if not bin.has_trait(trt):
                msg = "map_binning operator should have a '{}' trait".format(trt)
                raise traitlets.TraitError(msg)
    return bin
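
The validator only checks the type and the presence of the listed traits; any Operator exposing them is accepted. A sketch of the failure mode:

import traitlets

try:
    mapmaker.map_binning = "something else"  # not an Operator
except traitlets.TraitError as err:
    print(err)  # map_binning should be an Operator instance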

_clean_signal(template_amplitudes)

Source code in toast/ops/mapmaker.py
@function_timer
def _clean_signal(self, template_amplitudes):
    if (
        self.template_matrix is None
        or self.template_matrix.n_enabled_templates == 0
    ):
        # No templates to subtract, bin the input signal
        out_cleaned = self.det_data
    else:
        # Apply (subtract) solved amplitudes.

        self._log.info_rank(
            f"{self._log_prefix} begin apply template amplitudes",
            comm=self._comm,
        )

        out_cleaned = self.clean_name
        if self.save_cleaned and self.overwrite_cleaned:
            # Modify data in place
            out_cleaned = None

        amplitudes_apply = ApplyAmplitudes(
            op="subtract",
            det_data=self.det_data,
            amplitudes=template_amplitudes,
            template_matrix=self.template_matrix,
            output=out_cleaned,
        )
        amplitudes_apply.apply(
            self._data, detectors=self._detectors, use_accel=self._use_accel
        )

        if not self.keep_solver_products:
            del self._data[template_amplitudes]

        self._log.info_rank(
            f"{self._log_prefix}  finished apply template amplitudes in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After subtracting templates"
        self._memreport.apply(self._data, use_accel=self._use_accel)

    return out_cleaned
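
How `save_cleaned` and `overwrite_cleaned` steer the destination of the template-subtracted timestream, summarized as a sketch:

# save_cleaned=False:
#     subtraction goes to the "<name>_cleaned" detdata key, which is
#     purged again after the final binning
# save_cleaned=True, overwrite_cleaned=False:
#     subtraction goes to "<name>_cleaned" and is kept
# save_cleaned=True, overwrite_cleaned=True:
#     out_cleaned is None and det_data is modified in place
mapmaker.save_cleaned = True
mapmaker.overwrite_cleaned = True  # destructive: the input signal is replaced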

_closeout()

Explicitly delete members used by the _exec() method

Source code in toast/ops/mapmaker.py
@function_timer
def _closeout(self):
    """Explicitly delete members used by the _exec() method"""

    del self._log
    del self._timer
    del self._log_prefix
    del self._mc_root
    del self._data
    del self._detectors
    del self._use_accel
    del self._memreport
    del self._comm
    del self._rank

    return

_exec(data, detectors=None, use_accel=None, **kwargs)

Source code in toast/ops/mapmaker.py
@function_timer
def _exec(self, data, detectors=None, use_accel=None, **kwargs):

    # First confirm that there is at least one valid detector

    if self.map_binning is not None and self.map_binning.enabled:
        map_binning = self.map_binning
    else:
        # Use the same binning used in the solver.
        map_binning = self.binning
    all_local_dets = data.all_local_detectors(
        selection=detectors, flagmask=map_binning.det_mask
    )
    ndet = len(all_local_dets)
    if data.comm.comm_world is not None:
        ndet = data.comm.comm_world.allreduce(ndet, op=MPI.SUM)
    if ndet == 0:
        # No valid detectors, no mapmaking
        return

    # Destripe data and make maps

    self._setup(data, detectors, use_accel)

    self._memreport.prefix = "Start of mapmaking"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    template_amplitudes = self._fit_templates()

    map_binning = self._prepare_binning()

    self._build_pixel_covariance(map_binning)

    self._bin_and_write_raw_signal(map_binning)

    out_cleaned = self._clean_signal(template_amplitudes)

    if self.write_noiseweighted_map or self.write_map or self.keep_final_products:
        self._bin_cleaned_signal(map_binning, out_cleaned)

    self._purge_cleaned_tod()  # Potentially frees memory for writing maps

    self._write_maps()

    self._memreport.prefix = "End of mapmaking"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    self._closeout()

    return
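
Because of the detector-count check at the top, `_exec()` is a silent no-op when the selection leaves no unflagged detectors on any process. Restricting a run to a detector subset uses the standard `apply()` argument (detector names are illustrative):

mapmaker.apply(data, detectors=["D0A-150", "D0B-150"])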

_finalize(data, **kwargs)

Source code in toast/ops/mapmaker.py
def _finalize(self, data, **kwargs):
    return

_fit_templates()

Solve for template amplitudes

Source code in toast/ops/mapmaker.py
@function_timer
def _fit_templates(self):
    """Solve for template amplitudes"""

    amplitudes_solve = SolveAmplitudes(
        name=self.name,
        det_data=self.det_data,
        convergence=self.convergence,
        iter_min=self.iter_min,
        iter_max=self.iter_max,
        solve_rcond_threshold=self.solve_rcond_threshold,
        mask=self.mask,
        binning=self.binning,
        template_matrix=self.template_matrix,
        keep_solver_products=self.keep_solver_products,
        write_solver_products=self.write_solver_products,
        write_hdf5=self.write_hdf5,
        write_hdf5_serial=self.write_hdf5_serial,
        output_dir=self.output_dir,
        mc_mode=self.mc_mode,
        mc_index=self.mc_index,
        reset_pix_dist=self.reset_pix_dist,
        report_memory=self.report_memory,
    )
    amplitudes_solve.apply(
        self._data, detectors=self._detectors, use_accel=self._use_accel
    )
    template_amplitudes = amplitudes_solve.amplitudes

    self._log.info_rank(
        f"{self._log_prefix}  finished template amplitude solve in",
        comm=self._comm,
        timer=self._timer,
    )

    self._memreport.prefix = "After solving amplitudes"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    return template_amplitudes
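
All of the solver traits (`convergence`, `iter_min`, `iter_max`, `solve_rcond_threshold`, `mask`) are forwarded verbatim to SolveAmplitudes here, so the conjugate-gradient solve is tuned on the mapmaker itself:

mapmaker.convergence = 1e-10  # tighter relative convergence limit
mapmaker.iter_min = 10
mapmaker.iter_max = 500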

_prepare_binning()

Set up the final map binning

Source code in toast/ops/mapmaker.py
@function_timer
def _prepare_binning(self):
    """Set up the final map binning"""

    # Map binning operator
    if self.map_binning is not None and self.map_binning.enabled:
        map_binning = self.map_binning
    else:
        # Use the same binning used in the solver.
        map_binning = self.binning
    map_binning.pre_process = None
    map_binning.covariance = self.cov_name

    # Pixel distribution
    if self.reset_pix_dist:
        # Purge any stale products from previous runs
        for name in [
            self.hits_name,
            self.cov_name,
            self.invcov_name,
            self.rcond_name,
            self.clean_name,
            self.binmap_name,
            self.map_name,
            self.noiseweighted_map_name,
            map_binning.pixel_dist,
            map_binning.covariance,
        ]:
            if name in self._data:
                del self._data[name]

    if map_binning.pixel_dist not in self._data:
        self._log.info_rank(
            f"{self._log_prefix} Caching pixel distribution",
            comm=self._comm,
        )
        pix_dist = BuildPixelDistribution(
            pixel_dist=map_binning.pixel_dist,
            pixel_pointing=map_binning.pixel_pointing,
            save_pointing=map_binning.full_pointing,
        )
        pix_dist.apply(self._data, use_accel=self._use_accel)
        self._log.info_rank(
            f"{self._log_prefix}  finished build of pixel distribution in",
            comm=self._comm,
            timer=self._timer,
        )

        self._memreport.prefix = "After pixel distribution"
        self._memreport.apply(self._data, use_accel=self._use_accel)

    return map_binning
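
The `reset_pix_dist` trait exists for re-applying one configured mapmaker to several independent data objects, where a pixel distribution cached during an earlier run would be stale. A sketch:

mapmaker.reset_pix_dist = True
for split in data_splits:  # hypothetical list of independent toast.Data objects
    mapmaker.apply(split)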

_provides()

Source code in toast/ops/mapmaker.py
def _provides(self):
    prov = dict()
    if self.map_binning is not None:
        prov["global"] = [self.map_binning.binned]
    else:
        prov["global"] = [self.binning.binned]
    return prov

_purge_cleaned_tod()

If the cleaned TOD is not being returned, purge it

Source code in toast/ops/mapmaker.py
@function_timer
def _purge_cleaned_tod(self):
    """If the cleaned TOD is not being returned, purge it"""

    if self.save_cleaned:
        return

    del_tod = Delete(detdata=[self.clean_name])
    del_tod.apply(self._data, use_accel=self._use_accel)

    self._memreport.prefix = "After purging cleaned TOD"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    return

_requires()

Source code in toast/ops/mapmaker.py
def _requires(self):
    # This operator requires everything that its sub-operators need.
    req = self.binning.requires()
    if self.template_matrix is not None:
        req.update(self.template_matrix.requires())
    if self.map_binning is not None:
        req.update(self.map_binning.requires())
    req["detdata"].append(self.det_data)
    return req

_setup(data, detectors, use_accel)

Set up convenience members used in the _exec() method

Source code in toast/ops/mapmaker.py
@function_timer
def _setup(self, data, detectors, use_accel):
    """Set up convenience members used in the _exec() method"""

    self._log = Logger.get()
    self._timer = Timer()
    self._log_prefix = "MapMaker"

    self._mc_root = self.name
    if self.mc_mode:
        if self.mc_root is not None:
            self._mc_root += f"_{self.mc_root}"
        if self.mc_index is not None:
            self._mc_root += f"_{self.mc_index:05d}"

    self._data = data
    self._detectors = detectors
    self._use_accel = use_accel
    self._memreport = MemoryCounter()
    if not self.report_memory:
        self._memreport.enabled = False

    # The global communicator we are using (or None)

    self._comm = data.comm.comm_world
    self._rank = data.comm.world_rank

    # Data names of outputs

    self.hits_name = f"{self.name}_hits"
    self.cov_name = f"{self.name}_cov"
    self.invcov_name = f"{self.name}_invcov"
    self.rcond_name = f"{self.name}_rcond"
    self.det_flag_name = f"{self.name}_flags"

    self.clean_name = f"{self.name}_cleaned"
    self.binmap_name = f"{self.name}_binmap"
    self.map_name = f"{self.name}_map"
    self.noiseweighted_map_name = f"{self.name}_noiseweighted_map"

    self._timer.start()

    return
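
The key names derived here also determine the output file names: `_write_del()` strips the `{name}_` prefix and joins the remainder to the root name in `output_dir`. With `name="mapmaker"` and `output_dir="out"`, for example:

# data["mapmaker_hits"]   -> out/mapmaker_hits.fits
# data["mapmaker_map"]    -> out/mapmaker_map.fits (or .h5 when write_hdf5=True)
# data["mapmaker_binmap"] -> out/mapmaker_binmap.fits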

_write_del(prod_key, prod_write, force, rootname)

Write data object to file and delete it from cache

Source code in toast/ops/mapmaker.py
@function_timer
def _write_del(self, prod_key, prod_write, force, rootname):
    """Write data object to file and delete it from cache"""
    log = Logger.get()

    # FIXME:  This I/O technique assumes "known" types of pixel representations.
    # Instead, we should associate read / write functions to a particular pixel
    # class.

    if self.map_binning is not None and self.map_binning.enabled:
        map_binning = self.map_binning
    else:
        map_binning = self.binning

    if hasattr(map_binning.pixel_pointing, "wcs"):
        is_pix_wcs = True
    else:
        is_pix_wcs = False
        is_hpix_nest = map_binning.pixel_pointing.nest

    wtimer = Timer()
    wtimer.start()
    product = prod_key.replace(f"{self.name}_", "")
    if prod_write:
        if is_pix_wcs:
            fname = os.path.join(self.output_dir, f"{rootname}_{product}.fits")
            if self.mc_mode and not force and os.path.isfile(fname):
                log.info_rank(f"Skipping existing file: {fname}", comm=self._comm)
            else:
                write_wcs_fits(self._data[prod_key], fname)
        else:
            if self.write_hdf5:
                # Non-standard HDF5 output
                fname = os.path.join(self.output_dir, f"{rootname}_{product}.h5")
                if self.mc_mode and not force and os.path.isfile(fname):
                    log.info_rank(
                        f"Skipping existing file: {fname}", comm=self._comm
                    )
                else:
                    write_healpix_hdf5(
                        self._data[prod_key],
                        fname,
                        nest=is_hpix_nest,
                        single_precision=True,
                        force_serial=self.write_hdf5_serial,
                    )
            else:
                # Standard FITS output
                fname = os.path.join(self.output_dir, f"{rootname}_{product}.fits")
                if self.mc_mode and not force and os.path.isfile(fname):
                    log.info_rank(
                        f"Skipping existing file: {fname}", comm=self._comm
                    )
                else:
                    write_healpix_fits(
                        self._data[prod_key],
                        fname,
                        nest=is_hpix_nest,
                        report_memory=self.report_memory,
                    )
        log.info_rank(f"Wrote {fname} in", comm=self._comm, timer=wtimer)

    if not self.keep_final_products and not self.mc_mode:
        if prod_key in self._data:
            self._data[prod_key].clear()
            del self._data[prod_key]

    self._memreport.prefix = f"After writing/deleting {prod_key}"
    self._memreport.apply(self._data, use_accel=self._use_accel)

    return
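
Format selection in `_write_del()` is global: one flag switches every HEALPix product between FITS and HDF5, and in `mc_mode` existing non-forced files are skipped, so interrupted Monte Carlo runs can resume without rewriting products. A sketch:

mapmaker.write_hdf5 = True         # HEALPix outputs become .h5 files
mapmaker.write_hdf5_serial = True  # force serial writes, e.g. without parallel HDF5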

_write_maps()

Write and delete the outputs

Source code in toast/ops/mapmaker.py
@function_timer
def _write_maps(self):
    """Write and delete the outputs"""

    self._write_del(
        self.noiseweighted_map_name,
        self.write_noiseweighted_map,
        True,
        self._mc_root,
    )
    self._write_del(self.map_name, self.write_map, True, self._mc_root)
    self._write_del(self.cov_name, self.write_cov, False, self.name)

    self._log.info_rank(
        f"{self._log_prefix}  finished output write in",
        comm=self._comm,
        timer=self._timer,
    )

    return

External Tools

toast.ops.Madam

Bases: Operator

Operator which passes data to libmadam for map-making.
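
A minimal configuration sketch, reusing the pointing operators from the MapMaker example above; `data` is an existing toast.Data instance and the `params` keys are standard libmadam parameters (consult the libmadam documentation for the full set):

madam = toast.ops.Madam(
    params={
        "base_first": 60.0,   # destriping baseline length in seconds
        "path_output": "out",
    },
    det_data="signal",
    pixel_pointing=pixels,
    stokes_weights=weights,
)
madam.apply(data)  # nside_map is filled in from pixel_pointing.nside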

Source code in toast/ops/madam.py
@trait_docs
class Madam(Operator):
    """Operator which passes data to libmadam for map-making."""

    # Class traits

    API = Int(0, help="Internal interface version for this operator")

    params = Dict({}, help="Parameters to pass to madam")

    paramfile = Unicode(
        None, allow_none=True, help="Read madam parameters from this file"
    )

    times = Unicode(defaults.times, help="Observation shared key for timestamps")

    det_data = Unicode(
        defaults.det_data, help="Observation detdata key for the timestream data"
    )

    det_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for per-detector flagging",
    )

    det_flags = Unicode(
        defaults.det_flags,
        allow_none=True,
        help="Observation detdata key for flags to use",
    )

    det_flag_mask = Int(
        defaults.det_mask_nonscience,
        help="Bit mask value for detector sample flagging",
    )

    shared_flags = Unicode(
        defaults.shared_flags,
        allow_none=True,
        help="Observation shared key for telescope flags to use",
    )

    shared_flag_mask = Int(
        defaults.shared_mask_nonscience,
        help="Bit mask value for optional shared flagging",
    )

    pixel_pointing = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a pixel pointing operator",
    )

    stokes_weights = Instance(
        klass=Operator,
        allow_none=True,
        help="This must be an instance of a Stokes weights operator",
    )

    view = Unicode(
        None, allow_none=True, help="Use this view of the data in all observations"
    )

    det_out = Unicode(
        None,
        allow_none=True,
        help="Observation detdata key for output destriped timestreams",
    )

    noise_model = Unicode(
        "noise_model", help="Observation key containing the noise model"
    )

    purge_det_data = Bool(
        False,
        help="If True, clear all observation detector data after copying to madam buffers",
    )

    restore_det_data = Bool(
        False,
        help="If True, restore detector data to observations on completion",
    )

    mcmode = Bool(
        False,
        help="If true, Madam will store auxiliary information such as pixel matrices and noise filter.",
    )

    copy_groups = Int(
        1,
        help="The processes on each node are split into this number of groups to copy data in turns",
    )

    translate_timestamps = Bool(
        False, help="Translate timestamps to enforce monotonicity."
    )

    noise_scale = Unicode(
        "noise_scale",
        help="Observation key with optional scaling factor for noise PSDs",
    )

    mem_report = Bool(
        False, help="Print system memory use while staging / unstaging data."
    )

    @traitlets.validate("det_mask")
    def _check_det_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det mask should be a positive integer")
        return check

    @traitlets.validate("shared_flag_mask")
    def _check_shared_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Shared flag mask should be a positive integer")
        return check

    @traitlets.validate("det_flag_mask")
    def _check_det_flag_mask(self, proposal):
        check = proposal["value"]
        if check < 0:
            raise traitlets.TraitError("Det flag mask should be a positive integer")
        return check

    @traitlets.validate("restore_det_data")
    def _check_restore_det_data(self, proposal):
        check = proposal["value"]
        if check and not self.purge_det_data:
            raise traitlets.TraitError(
                "Cannot set restore_det_data since purge_det_data is False"
            )
        if check and self.det_out is not None:
            raise traitlets.TraitError(
                "Cannot set restore_det_data since det_out is not None"
            )
        return check

    @traitlets.validate("det_out")
    def _check_det_out(self, proposal):
        check = proposal["value"]
        if check is not None and self.restore_det_data:
            raise traitlets.TraitError(
                "If det_out is not None, restore_det_data should be False"
            )
        return check

    @traitlets.validate("pixel_pointing")
    def _check_pixel_pointing(self, proposal):
        pixels = proposal["value"]
        if pixels is not None:
            if not isinstance(pixels, Operator):
                raise traitlets.TraitError(
                    "pixel_pointing should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["pixels", "create_dist", "view"]:
                if not pixels.has_trait(trt):
                    msg = f"pixel_pointing operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return pixels

    @traitlets.validate("stokes_weights")
    def _check_stokes_weights(self, proposal):
        weights = proposal["value"]
        if weights is not None:
            if not isinstance(weights, Operator):
                raise traitlets.TraitError(
                    "stokes_weights should be an Operator instance"
                )
            # Check that this operator has the traits we expect
            for trt in ["weights", "view"]:
                if not weights.has_trait(trt):
                    msg = f"stokes_weights operator should have a '{trt}' trait"
                    raise traitlets.TraitError(msg)
        return weights

    @traitlets.validate("params")
    def _check_params(self, proposal):
        check = proposal["value"]
        if "info" not in check:
            # The user did not specify the info level; set it from the toast log level
            env = Environment.get()
            level = env.log_level()
            if level == "DEBUG":
                check["info"] = 2
            elif level == "VERBOSE":
                check["info"] = 3
            else:
                check["info"] = 1
        return check

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._cached = False
        self._logprefix = "Madam:"

    def clear(self):
        """Delete the underlying memory.

        This will forcibly delete the C-allocated memory and invalidate all python
        references to the buffers.

        """
        if self._cached:
            madam.clear_caches()
            self._cached = False
        for atr in ["timestamps", "signal", "pixels", "pixweights"]:
            atrname = "_madam_{}".format(atr)
            rawname = "{}_raw".format(atrname)
            if hasattr(self, atrname):
                delattr(self, atrname)
                raw = getattr(self, rawname)
                if raw is not None:
                    raw.clear()
                setattr(self, rawname, None)
                setattr(self, atrname, None)

    def __del__(self):
        self.clear()

    @function_timer
    def _exec(self, data, detectors=None, **kwargs):
        log = Logger.get()
        timer = Timer()
        timer.start()

        if not available():
            raise RuntimeError("Madam is either not installed or MPI is disabled")

        if len(data.obs) == 0:
            raise RuntimeError(
                "Madam requires every supplied data object to "
                "contain at least one observation"
            )

        for trait in "det_data", "pixel_pointing", "stokes_weights":
            if getattr(self, trait) is None:
                msg = f"You must set the '{trait}' trait before calling exec()"
                raise RuntimeError(msg)

        # Combine parameters from an external file and other parameters passed in

        params = dict()
        repeat_keys = ["detset", "detset_nopol", "survey"]

        if self.paramfile is not None:
            if data.comm.world_rank == 0:
                line_pat = re.compile(r"(\S+)\s+=\s+(\S+)")
                comment_pat = re.compile(r"^\s*\#.*")
                with open(self.paramfile, "r") as f:
                    for line in f:
                        if comment_pat.match(line) is None:
                            line_mat = line_pat.match(line)
                            if line_mat is not None:
                                k = line_mat.group(1)
                                v = line_mat.group(2)
                                if k in repeat_keys:
                                    if k not in params:
                                        params[k] = [v]
                                    else:
                                        params[k].append(v)
                                else:
                                    params[k] = v
            if data.comm.world_comm is not None:
                params = data.comm.world_comm.bcast(params, root=0)
            for k, v in self.params.items():
                if k in repeat_keys:
                    if k not in params:
                        params[k] = [v]
                    else:
                        params[k].append(v)
                else:
                    params[k] = v

        if self.params is not None:
            params.update(self.params)

        if "fsample" not in params:
            params["fsample"] = data.obs[0].telescope.focalplane.sample_rate.to_value(
                u.Hz
            )

        # Set madam parameters that depend on our traits
        if self.mcmode:
            params["mcmode"] = True
        else:
            params["mcmode"] = False

        if self.det_out is not None:
            params["write_tod"] = True
        else:
            params["write_tod"] = False

        # Check input parameters and compute the sizes of Madam data objects
        if data.comm.world_rank == 0:
            msg = "{} Computing data sizes".format(self._logprefix)
            log.info(msg)
        (
            all_dets,
            nsamp,
            nnz,
            nnz_full,
            nnz_stride,
            interval_starts,
            psd_freqs,
        ) = self._prepare(params, data, detectors)

        log.info_rank(
            f"{self._logprefix} Parsed parameters in",
            comm=data.comm.comm_world,
            timer=timer,
        )

        if data.comm.world_rank == 0:
            msg = "{} Copying toast data to buffers".format(self._logprefix)
            log.info(msg)
        psdinfo, signal_dtype = self._stage_data(
            params,
            data,
            all_dets,
            nsamp,
            nnz,
            nnz_full,
            nnz_stride,
            interval_starts,
            psd_freqs,
        )

        log.info_rank(
            f"{self._logprefix} Staged data in",
            comm=data.comm.comm_world,
            timer=timer,
        )

        if data.comm.world_rank == 0:
            msg = "{} Destriping data".format(self._logprefix)
            log.info(msg)
        self._destripe(params, data, all_dets, interval_starts, psdinfo)

        log.info_rank(
            f"{self._logprefix} Destriped data in",
            comm=data.comm.comm_world,
            timer=timer,
        )

        if data.comm.world_rank == 0:
            msg = "{} Copying buffers back to toast data".format(self._logprefix)
            log.info(msg)
        self._unstage_data(
            params,
            data,
            all_dets,
            nsamp,
            nnz,
            nnz_full,
            interval_starts,
            signal_dtype,
        )

        log.info_rank(
            f"{self._logprefix} Unstaged data in",
            comm=data.comm.comm_world,
            timer=timer,
        )

        return

    def _finalize(self, data, **kwargs):
        return

    def _requires(self):
        req = {
            "meta": [self.noise_model],
            "shared": [
                self.times,
            ],
            "detdata": [self.det_data],
            "intervals": list(),
        }
        if self.view is not None:
            req["intervals"].append(self.view)
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        prov = {"detdata": list()}
        if self.det_out is not None:
            prov["detdata"].append(self.det_out)
        return prov

    @function_timer
    def _prepare(self, params, data, detectors):
        """Examine the data and determine quantities needed to set up Madam buffers"""
        log = Logger.get()
        timer = Timer()
        timer.start()

        params["nside_map"] = self.pixel_pointing.nside

        # Madam requires a fixed set of detectors and pointing matrix non-zeros.
        # Here we find the superset of local detectors used, and also the number
        # of pointing matrix elements.

        nsamp = 0

        # Madam uses monolithic data buffers and specifies contiguous data intervals
        # in that buffer.  The starting sample index is used to mark the transition
        # between data intervals.
        interval_starts = list()

        # This quantity is only used for printing the fraction of samples in valid
        # ranges specified by the View.  Only samples actually in the view are copied
        # to Madam buffers.
        nsamp_valid = 0

        all_dets = set()
        nnz_full = None
        psd_freqs = None

        for ob in data.obs:
            # Get the detectors we are using for this observation
            local_dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
            # if ob.comm.comm_group is not None:
            #     pdets = ob.comm.comm_group.gather(local_dets, root=0)
            #     obs_dets = None
            #     if ob.comm.group_rank == 0:
            #         obs_dets = set()
            #         for plocal in pdets:
            #             for d in plocal:
            #                 obs_dets.add(d)
            #     obs_dets = ob.comm.comm_group.bcast(obs_dets, root=0)
            # else:
            #     obs_dets = set(local_dets)
            # all_dets.update(obs_dets)
            all_dets.update(set(local_dets))

            # Check that the timestamps exist.
            if self.times not in ob.shared:
                msg = (
                    "Shared timestamps '{}' does not exist in observation '{}'".format(
                        self.times, ob.name
                    )
                )
                raise RuntimeError(msg)

            # Check that the detector data and pointing exist in the observation
            if self.det_data not in ob.detdata:
                msg = "Detector data '{}' does not exist in observation '{}'".format(
                    self.det_data, ob.name
                )
                raise RuntimeError(msg)

            # Check that the noise model exists, and that the PSD frequencies are the
            # same across all observations (required by Madam).
            if self.noise_model not in ob:
                msg = "Noise model '{}' not in observation '{}'".format(
                    self.noise_model, ob.name
                )
                raise RuntimeError(msg)
            if psd_freqs is None:
                psd_freqs = np.array(
                    ob[self.noise_model].freq(ob.local_detectors[0]).to_value(u.Hz),
                    dtype=np.float64,
                )
            else:
                check_freqs = (
                    ob[self.noise_model].freq(ob.local_detectors[0]).to_value(u.Hz)
                )
                if not np.allclose(psd_freqs, check_freqs):
                    raise RuntimeError(
                        "All PSDs passed to Madam must have the same frequency binning."
                    )

            # Are we using a view of the data?  If so, we will only be copying data in
            # those valid intervals.
            if self.view is not None:
                if self.view not in ob.intervals:
                    msg = "View '{}' does not exist in observation {}".format(
                        self.view, ob.name
                    )
                    raise RuntimeError(msg)
                # Go through all the intervals that will be used for our data view
                # and accumulate the number of samples.
                for intvw in ob.intervals[self.view]:
                    interval_starts.append(nsamp_valid)
                    nsamp_valid += intvw.last - intvw.first
            else:
                interval_starts.append(nsamp_valid)
                nsamp_valid += ob.n_local_samples
            nsamp += ob.n_local_samples

        if data.comm.world_rank == 0:
            log.info(
                "{}{:.2f} % of samples are included in valid intervals.".format(
                    self._logprefix, nsamp_valid * 100.0 / nsamp
                )
            )

        nsamp = nsamp_valid

        interval_starts = np.array(interval_starts, dtype=np.int64)
        all_dets = list(sorted(all_dets))
        ndet = len(all_dets)

        nnz_full = len(self.stokes_weights.mode)
        nnz_stride = None
        if "temperature_only" in params and params["temperature_only"] in [
            "T",
            "True",
            "TRUE",
            "true",
            True,
        ]:
            # User has requested a temperature-only map.
            if nnz_full not in [1, 3]:
                raise RuntimeError(
                    "Madam cannot make a temperature map with nnz == {}".format(
                        nnz_full
                    )
                )
            nnz = 1
            nnz_stride = nnz_full
        else:
            nnz = nnz_full
            nnz_stride = 1

        if data.comm.world_rank == 0 and "path_output" in params:
            os.makedirs(params["path_output"], exist_ok=True)

        # Inspect the valid intervals across all observations to
        # determine the number of samples per detector

        data.comm.comm_world.barrier()
        timer.stop()
        if data.comm.world_rank == 0:
            msg = "{}  Compute data dimensions: {:0.1f} s".format(
                self._logprefix, timer.seconds()
            )
            log.debug(msg)

        return (
            all_dets,
            nsamp,
            nnz,
            nnz_full,
            nnz_stride,
            interval_starts,
            psd_freqs,
        )

    @function_timer
    def _stage_data(
        self,
        params,
        data,
        all_dets,
        nsamp,
        nnz,
        nnz_full,
        nnz_stride,
        interval_starts,
        psd_freqs,
    ):
        """Create madam-compatible buffers.

        Collect the data into Madam buffers.  If we are purging TOAST data to save
        memory, then optionally limit the number of processes that are copying at once.

        """
        log = Logger.get()
        timer = Timer()

        nodecomm = data.comm.comm_group_node

        # Determine how many processes per node should copy at once.
        n_copy_groups = 1
        if self.purge_det_data:
            # We will be purging some data, so see if we should reduce the number of
            # processes copying in parallel (if we are not purging data, there
            # is no benefit to staggering the copy).
            if self.copy_groups > 0:
                n_copy_groups = min(self.copy_groups, nodecomm.size)

        if not self._cached:
            # Only do this if we have not cached the data yet.
            log_time_memory(
                data,
                prefix=self._logprefix,
                mem_msg="Before staging",
                full_mem=self.mem_report,
            )

        # Copy timestamps and PSDs all at once, since they are never purged.

        psds = dict()

        timer.start()

        if not self._cached:
            timestamp_storage, _ = dtype_to_aligned(madam.TIMESTAMP_TYPE)
            self._madam_timestamps_raw = timestamp_storage.zeros(nsamp)
            self._madam_timestamps = self._madam_timestamps_raw.array()

            interval = 0
            time_offset = 0.0

            for ob in data.obs:
                for vw in ob.view[self.view].shared[self.times]:
                    offset = interval_starts[interval]
                    slc = slice(offset, offset + len(vw), 1)
                    self._madam_timestamps[slc] = vw
                    if self.translate_timestamps:
                        off = self._madam_timestamps[offset] - time_offset
                        self._madam_timestamps[slc] -= off
                        time_offset = self._madam_timestamps[slc][-1] + 1.0
                    interval += 1

                # Get the noise object for this observation and create new
                # entries in the dictionary when the PSD actually changes.  The detector
                # weights are obtained from the noise model.

                nse = ob[self.noise_model]
                nse_scale = 1.0
                if self.noise_scale is not None:
                    if self.noise_scale in ob:
                        nse_scale = float(ob[self.noise_scale])

                local_dets = set(ob.select_local_detectors(flagmask=self.det_mask))
                for det in all_dets:
                    if det not in local_dets:
                        continue
                    psd = nse.psd(det).to_value(u.K**2 * u.second) * nse_scale**2
                    detw = nse.detector_weight(det).to_value(1.0 / u.K**2)
                    if det not in psds:
                        psds[det] = [(0.0, psd, detw)]
                    else:
                        if not np.allclose(psds[det][-1][1], psd):
                            psds[det] += [(ob.shared[self.times][0], psd, detw)]

            log_time_memory(
                data,
                timer=timer,
                timer_msg="Copy timestamps and PSDs",
                prefix=self._logprefix,
                mem_msg="After timestamp staging",
                full_mem=self.mem_report,
            )

        # Copy the signal.  We always need to do this, even if we are running MCs.

        signal_dtype = data.obs[0].detdata[self.det_data].dtype

        if self._cached:
            # We have previously created the madam buffers.  We just need to fill
            # them from the toast data.  Since both already exist we just copy the
            # contents.
            stage_local(
                data,
                nsamp,
                self.view,
                all_dets,
                self.det_data,
                self._madam_signal,
                interval_starts,
                1,
                1,
                self.det_mask,
                None,
                None,
                None,
                self.det_flag_mask,
                do_purge=False,
            )
        else:
            # Signal buffers do not yet exist
            if self.purge_det_data:
                # Allocate in a staggered way.
                self._madam_signal_raw, self._madam_signal = stage_in_turns(
                    data,
                    nodecomm,
                    n_copy_groups,
                    nsamp,
                    self.view,
                    all_dets,
                    self.det_data,
                    madam.SIGNAL_TYPE,
                    interval_starts,
                    1,
                    1,
                    self.det_mask,
                    None,
                    None,
                    None,
                    self.det_flag_mask,
                )
            else:
                # Allocate and copy all at once.
                storage, _ = dtype_to_aligned(madam.SIGNAL_TYPE)
                self._madam_signal_raw = storage.zeros(nsamp * len(all_dets))
                self._madam_signal = self._madam_signal_raw.array()

                stage_local(
                    data,
                    nsamp,
                    self.view,
                    all_dets,
                    self.det_data,
                    self._madam_signal,
                    interval_starts,
                    1,
                    1,
                    self.det_mask,
                    None,
                    None,
                    None,
                    self.det_flag_mask,
                    do_purge=False,
                )

        log_time_memory(
            data,
            timer=timer,
            timer_msg="Copy signal",
            prefix=self._logprefix,
            mem_msg="After signal staging",
            full_mem=self.mem_report,
        )

        # Copy the pointing

        nested_pointing = self.pixel_pointing.nest
        if not nested_pointing:
            # Any existing pixel numbers are in the wrong ordering
            Delete(detdata=[self.pixel_pointing.pixels]).apply(data)
            self.pixel_pointing.nest = True

        if not self._cached:
            # We do not have the pointing yet.
            self._madam_pixels_raw, self._madam_pixels = stage_in_turns(
                data,
                nodecomm,
                n_copy_groups,
                nsamp,
                self.view,
                all_dets,
                self.pixel_pointing.pixels,
                madam.PIXEL_TYPE,
                interval_starts,
                1,
                1,
                self.det_mask,
                self.shared_flags,
                self.shared_flag_mask,
                self.det_flags,
                self.det_flag_mask,
                operator=self.pixel_pointing,
            )

            self._madam_pixweights_raw, self._madam_pixweights = stage_in_turns(
                data,
                nodecomm,
                n_copy_groups,
                nsamp,
                self.view,
                all_dets,
                self.stokes_weights.weights,
                madam.WEIGHT_TYPE,
                interval_starts,
                nnz,
                nnz_stride,
                self.det_mask,
                None,
                None,
                None,
                self.det_flag_mask,
                operator=self.stokes_weights,
            )

            log_time_memory(
                data,
                timer=timer,
                timer_msg="Copy pointing",
                prefix=self._logprefix,
                mem_msg="After pointing staging",
                full_mem=self.mem_report,
            )

        if not nested_pointing:
            # Any existing pixel numbers are in the wrong ordering
            Delete(detdata=[self.pixel_pointing.pixels]).apply(data)
            self.pixel_pointing.nest = False

        psdinfo = None

        if not self._cached:
            # Detector weights.  Madam assumes a single noise weight for each detector
            # that is constant.  We set this based on the first observation or else use
            # uniform weighting.

            ndet = len(all_dets)
            detweights = np.ones(ndet, dtype=np.float64)

            if len(psds) > 0:
                npsdbin = len(psd_freqs)
                npsd = np.zeros(ndet, dtype=np.int64)
                psdstarts = []
                psdvals = []
                for idet, det in enumerate(all_dets):
                    if det not in psds:
                        raise RuntimeError("Every detector must have at least one PSD")
                    psdlist = psds[det]
                    npsd[idet] = len(psdlist)
                    for psdstart, psd, detw in psdlist:
                        psdstarts.append(psdstart)
                        psdvals.append(psd)
                    detweights[idet] = psdlist[0][2]
                npsdtot = np.sum(npsd)
                psdstarts = np.array(psdstarts, dtype=np.float64)
                psdvals = np.hstack(psdvals).astype(madam.PSD_TYPE)
                npsdval = psdvals.size
            else:
                # Uniform weighting
                npsd = np.ones(ndet, dtype=np.int64)
                npsdtot = np.sum(npsd)
                psdstarts = np.zeros(npsdtot)
                npsdbin = 10
                fsample = 10.0
                psd_freqs = np.arange(npsdbin) * fsample / npsdbin
                npsdval = npsdbin * npsdtot
                psdvals = np.ones(npsdval)

            psdinfo = (detweights, npsd, psdstarts, psd_freqs, psdvals)

            log_time_memory(
                data,
                timer=timer,
                timer_msg="Collect PSD info",
                prefix=self._logprefix,
            )
        timer.stop()

        return psdinfo, signal_dtype

    @function_timer
    def _unstage_data(
        self,
        params,
        data,
        all_dets,
        nsamp,
        nnz,
        nnz_full,
        interval_starts,
        signal_dtype,
    ):
        """
        Restore data to TOAST observations.

        Optionally copy the signal and pointing back to TOAST if we previously
        purged it to save memory.  Also copy the destriped timestreams if desired.

        """
        log = Logger.get()
        timer = Timer()

        nodecomm = data.comm.comm_group_node

        # Determine how many processes per node should copy at once.
        n_copy_groups = 1
        if self.purge_det_data:
            # We MAY be restoring some data.  See if we should reduce the number of
            # processes copying in parallel (if we are not purging data, there
            # is no benefit to staggering the copy).
            if self.copy_groups > 0:
                n_copy_groups = min(self.copy_groups, nodecomm.size)

        log_time_memory(
            data,
            prefix=self._logprefix,
            mem_msg="Before un-staging",
            full_mem=self.mem_report,
        )

        # Copy the signal

        timer.start()

        out_name = self.det_data
        if self.det_out is not None:
            out_name = self.det_out

        if self.det_out is not None or (self.purge_det_data and self.restore_det_data):
            # We are copying some kind of signal back
            if not self.mcmode:
                # We are not running multiple realizations, so delete as we copy.
                restore_in_turns(
                    data,
                    nodecomm,
                    n_copy_groups,
                    nsamp,
                    self.view,
                    all_dets,
                    out_name,
                    signal_dtype,
                    self._madam_signal,
                    self._madam_signal_raw,
                    interval_starts,
                    1,
                    self.det_mask,
                )
                del self._madam_signal
                del self._madam_signal_raw
            else:
                # We want to re-use the signal buffer, just copy.
                restore_local(
                    data,
                    nsamp,
                    self.view,
                    all_dets,
                    out_name,
                    signal_dtype,
                    self._madam_signal,
                    interval_starts,
                    1,
                    self.det_mask,
                )

            log_time_memory(
                data,
                timer=timer,
                timer_msg="Copy signal",
                prefix=self._logprefix,
                mem_msg="After restoring signal",
                full_mem=self.mem_report,
            )

        # Copy the pointing

        if not self.mcmode:
            # We can clear the cached pointing
            del self._madam_pixels
            del self._madam_pixels_raw
            del self._madam_pixweights
            del self._madam_pixweights_raw
        return

    @function_timer
    def _destripe(self, params, data, dets, interval_starts, psdinfo):
        """Destripe the buffered data"""
        log_time_memory(
            data,
            prefix=self._logprefix,
            mem_msg="Just before libmadam.destripe",
            full_mem=self.mem_report,
        )

        if self._cached:
            # destripe
            outpath = ""
            if "path_output" in params:
                outpath = params["path_output"]
            outpath = outpath.encode("ascii")
            madam.destripe_with_cache(
                data.comm.comm_world,
                self._madam_timestamps,
                self._madam_pixels,
                self._madam_pixweights,
                self._madam_signal,
                outpath,
            )
        else:
            (detweights, npsd, psdstarts, psd_freqs, psdvals) = psdinfo

            # destripe
            madam.destripe(
                data.comm.comm_world,
                params,
                dets,
                detweights,
                self._madam_timestamps,
                self._madam_pixels,
                self._madam_pixweights,
                self._madam_signal,
                interval_starts,
                npsd,
                psdstarts,
                psd_freqs,
                psdvals,
            )
            if self.mcmode:
                self._cached = True
        return

    def _requires(self):
        req = self.pixel_pointing.requires()
        req.update(self.stokes_weights.requires())
        req["meta"].extend([self.noise_model])
        req["detdata"].extend([self.det_data])
        if self.shared_flags is not None:
            req["shared"].append(self.shared_flags)
        if self.det_flags is not None:
            req["detdata"].append(self.det_flags)
        return req

    def _provides(self):
        return dict()

API = Int(0, help='Internal interface version for this operator') class-attribute instance-attribute

_cached = False instance-attribute

_logprefix = 'Madam:' instance-attribute

copy_groups = Int(1, help='The processes on each node are split into this number of groups to copy data in turns') class-attribute instance-attribute

det_data = Unicode(defaults.det_data, help='Observation detdata key for the timestream data') class-attribute instance-attribute

det_flag_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for detector sample flagging') class-attribute instance-attribute

det_flags = Unicode(defaults.det_flags, allow_none=True, help='Observation detdata key for flags to use') class-attribute instance-attribute

det_mask = Int(defaults.det_mask_nonscience, help='Bit mask value for per-detector flagging') class-attribute instance-attribute

det_out = Unicode(None, allow_none=True, help='Observation detdata key for output destriped timestreams') class-attribute instance-attribute

mcmode = Bool(False, help='If true, Madam will store auxiliary information such as pixel matrices and the noise filter.') class-attribute instance-attribute

mem_report = Bool(False, help='Print system memory use while staging / unstaging data.') class-attribute instance-attribute

noise_model = Unicode('noise_model', help='Observation key containing the noise model') class-attribute instance-attribute

noise_scale = Unicode('noise_scale', help='Observation key with optional scaling factor for noise PSDs') class-attribute instance-attribute

paramfile = Unicode(None, allow_none=True, help='Read madam parameters from this file') class-attribute instance-attribute

params = Dict({}, help='Parameters to pass to madam') class-attribute instance-attribute

pixel_pointing = Instance(klass=Operator, allow_none=True, help='This must be an instance of a pixel pointing operator') class-attribute instance-attribute

purge_det_data = Bool(False, help='If True, clear all observation detector data after copying to madam buffers') class-attribute instance-attribute

restore_det_data = Bool(False, help='If True, restore detector data to observations on completion') class-attribute instance-attribute

shared_flag_mask = Int(defaults.shared_mask_nonscience, help='Bit mask value for optional shared flagging') class-attribute instance-attribute

shared_flags = Unicode(defaults.shared_flags, allow_none=True, help='Observation shared key for telescope flags to use') class-attribute instance-attribute

stokes_weights = Instance(klass=Operator, allow_none=True, help='This must be an instance of a Stokes weights operator') class-attribute instance-attribute

times = Unicode(defaults.times, help='Observation shared key for timestamps') class-attribute instance-attribute

translate_timestamps = Bool(False, help='Translate timestamps to enforce monotonicity.') class-attribute instance-attribute

view = Unicode(None, allow_none=True, help='Use this view of the data in all observations') class-attribute instance-attribute
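
For convenience, here is a minimal construction sketch.  It assumes the PointingDetectorSimple, PixelsHealpix and StokesWeights operators from toast.ops and a pre-existing toast.Data instance named 'data'; any operators satisfying the pixel_pointing and stokes_weights trait checks below would work equally well.

from toast import ops

# Hypothetical pointing chain for illustration only.
det_pointing = ops.PointingDetectorSimple()
pixels = ops.PixelsHealpix(detector_pointing=det_pointing, nside=512)
weights = ops.StokesWeights(detector_pointing=det_pointing, mode="IQU")

madam = ops.Madam(
    pixel_pointing=pixels,
    stokes_weights=weights,
    purge_det_data=True,    # free detector data after staging
    restore_det_data=True,  # copy it back when Madam finishes
    params={"path_output": "madam_out", "write_map": True},
)
madam.apply(data)  # stage, destripe, and unstage in one call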

__del__()

Source code in toast/ops/madam.py, lines 328-329
def __del__(self):
    self.clear()

__init__(**kwargs)

Source code in toast/ops/madam.py, lines 302-305
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self._cached = False
    self._logprefix = "Madam:"

_check_det_flag_mask(proposal)

Source code in toast/ops/madam.py, lines 228-233
@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det flag mask should be a positive integer")
    return check

_check_det_mask(proposal)

Source code in toast/ops/madam.py, lines 214-219
@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Det mask should be a positive integer")
    return check

_check_det_out(proposal)

Source code in toast/ops/madam.py, lines 248-255
@traitlets.validate("det_out")
def _check_det_out(self, proposal):
    check = proposal["value"]
    if check is not None and self.restore_det_data:
        raise traitlets.TraitError(
            "If det_out is not None, restore_det_data should be False"
        )
    return check

_check_params(proposal)

Source code in toast/ops/madam.py, lines 287-300
@traitlets.validate("params")
def _check_params(self, proposal):
    check = proposal["value"]
    if "info" not in check:
        # The user did not specify the info level, so set it from the toast log level
        env = Environment.get()
        level = env.log_level()
        if level == "DEBUG":
            check["info"] = 2
        elif level == "VERBOSE":
            check["info"] = 3
        else:
            check["info"] = 1
    return check

_check_pixel_pointing(proposal)

Source code in toast/ops/madam.py, lines 257-270
@traitlets.validate("pixel_pointing")
def _check_pixel_pointing(self, proposal):
    pixels = proposal["value"]
    if pixels is not None:
        if not isinstance(pixels, Operator):
            raise traitlets.TraitError(
                "pixel_pointing should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["pixels", "create_dist", "view"]:
            if not pixels.has_trait(trt):
                msg = f"pixel_pointing operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return pixels

_check_restore_det_data(proposal)

Source code in toast/ops/madam.py, lines 235-246
@traitlets.validate("restore_det_data")
def _check_restore_det_data(self, proposal):
    check = proposal["value"]
    if check and not self.purge_det_data:
        raise traitlets.TraitError(
            "Cannot set restore_det_data since purge_det_data is False"
        )
    if check and self.det_out is not None:
        raise traitlets.TraitError(
            "Cannot set restore_det_data since det_out is not None"
        )
    return check

_check_shared_flag_mask(proposal)

Source code in toast/ops/madam.py, lines 221-226
@traitlets.validate("shared_flag_mask")
def _check_shared_flag_mask(self, proposal):
    check = proposal["value"]
    if check < 0:
        raise traitlets.TraitError("Shared flag mask should be a positive integer")
    return check

_check_stokes_weights(proposal)

Source code in toast/ops/madam.py, lines 272-285
@traitlets.validate("stokes_weights")
def _check_stokes_weights(self, proposal):
    weights = proposal["value"]
    if weights is not None:
        if not isinstance(weights, Operator):
            raise traitlets.TraitError(
                "stokes_weights should be an Operator instance"
            )
        # Check that this operator has the traits we expect
        for trt in ["weights", "view"]:
            if not weights.has_trait(trt):
                msg = f"stokes_weights operator should have a '{trt}' trait"
                raise traitlets.TraitError(msg)
    return weights

_destripe(params, data, dets, interval_starts, psdinfo)

Destripe the buffered data

Source code in toast/ops/madam.py, lines 1060-1105
@function_timer
def _destripe(self, params, data, dets, interval_starts, psdinfo):
    """Destripe the buffered data"""
    log_time_memory(
        data,
        prefix=self._logprefix,
        mem_msg="Just before libmadam.destripe",
        full_mem=self.mem_report,
    )

    if self._cached:
        # destripe
        outpath = ""
        if "path_output" in params:
            outpath = params["path_output"]
        outpath = outpath.encode("ascii")
        madam.destripe_with_cache(
            data.comm.comm_world,
            self._madam_timestamps,
            self._madam_pixels,
            self._madam_pixweights,
            self._madam_signal,
            outpath,
        )
    else:
        (detweights, npsd, psdstarts, psd_freqs, psdvals) = psdinfo

        # destripe
        madam.destripe(
            data.comm.comm_world,
            params,
            dets,
            detweights,
            self._madam_timestamps,
            self._madam_pixels,
            self._madam_pixweights,
            self._madam_signal,
            interval_starts,
            npsd,
            psdstarts,
            psd_freqs,
            psdvals,
        )
        if self.mcmode:
            self._cached = True
    return

_exec(data, detectors=None, **kwargs)

Source code in toast/ops/madam.py, lines 331-476
@function_timer
def _exec(self, data, detectors=None, **kwargs):
    log = Logger.get()
    timer = Timer()
    timer.start()

    if not available():
        raise RuntimeError("Madam is either not installed or MPI is disabled")

    if len(data.obs) == 0:
        raise RuntimeError(
            "Madam requires every supplied data object to "
            "contain at least one observation"
        )

    for trait in "det_data", "pixel_pointing", "stokes_weights":
        if getattr(self, trait) is None:
            msg = f"You must set the '{trait}' trait before calling exec()"
            raise RuntimeError(msg)

    # Combine parameters from an external file and other parameters passed in

    params = dict()
    repeat_keys = ["detset", "detset_nopol", "survey"]

    if self.paramfile is not None:
        if data.comm.world_rank == 0:
            line_pat = re.compile(r"(\S+)\s+=\s+(\S+)")
            comment_pat = re.compile(r"^\s*\#.*")
            with open(self.paramfile, "r") as f:
                for line in f:
                    if comment_pat.match(line) is None:
                        line_mat = line_pat.match(line)
                        if line_mat is not None:
                            k = line_mat.group(1)
                            v = line_mat.group(2)
                            if k in repeat_keys:
                                if k not in params:
                                    params[k] = [v]
                                else:
                                    params[k].append(v)
                            else:
                                params[k] = v
        if data.comm.world_comm is not None:
            params = data.comm.world_comm.bcast(params, root=0)
        for k, v in self.params.items():
            if k in repeat_keys:
                if k not in params:
                    params[k] = [v]
                else:
                    params[k].append(v)
            else:
                params[k] = v

    if self.params is not None:
        params.update(self.params)

    if "fsample" not in params:
        params["fsample"] = data.obs[0].telescope.focalplane.sample_rate.to_value(
            u.Hz
        )

    # Set madam parameters that depend on our traits
    if self.mcmode:
        params["mcmode"] = True
    else:
        params["mcmode"] = False

    if self.det_out is not None:
        params["write_tod"] = True
    else:
        params["write_tod"] = False

    # Check input parameters and compute the sizes of Madam data objects
    if data.comm.world_rank == 0:
        msg = "{} Computing data sizes".format(self._logprefix)
        log.info(msg)
    (
        all_dets,
        nsamp,
        nnz,
        nnz_full,
        nnz_stride,
        interval_starts,
        psd_freqs,
    ) = self._prepare(params, data, detectors)

    log.info_rank(
        f"{self._logprefix} Parsed parameters in",
        comm=data.comm.comm_world,
        timer=timer,
    )

    if data.comm.world_rank == 0:
        msg = "{} Copying toast data to buffers".format(self._logprefix)
        log.info(msg)
    psdinfo, signal_dtype = self._stage_data(
        params,
        data,
        all_dets,
        nsamp,
        nnz,
        nnz_full,
        nnz_stride,
        interval_starts,
        psd_freqs,
    )

    log.info_rank(
        f"{self._logprefix} Staged data in",
        comm=data.comm.comm_world,
        timer=timer,
    )

    if data.comm.world_rank == 0:
        msg = "{} Destriping data".format(self._logprefix)
        log.info(msg)
    self._destripe(params, data, all_dets, interval_starts, psdinfo)

    log.info_rank(
        f"{self._logprefix} Destriped data in",
        comm=data.comm.comm_world,
        timer=timer,
    )

    if data.comm.world_rank == 0:
        msg = "{} Copying buffers back to toast data".format(self._logprefix)
        log.info(msg)
    self._unstage_data(
        params,
        data,
        all_dets,
        nsamp,
        nnz,
        nnz_full,
        interval_starts,
        signal_dtype,
    )

    log.info_rank(
        f"{self._logprefix} Unstaged data in",
        comm=data.comm.comm_world,
        timer=timer,
    )

    return
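
Given the parser above, a Madam parameter file is plain text with one "key = value" pair per line (whitespace is required on both sides of the '='), '#' comment lines are skipped, and the repeat keys detset, detset_nopol and survey may occur multiple times and accumulate into lists.  A hypothetical file, with illustrative keys and values only:

# madam.par (hypothetical)
temperature_only = F
base_first = 60.0
nside_map = 512
survey = season1
survey = season2

Note that the value pattern (\S+) captures a single token, so any value containing internal whitespace is truncated at the first space.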

_finalize(data, **kwargs)

Source code in toast/ops/madam.py, lines 478-479
def _finalize(self, data, **kwargs):
    return

_prepare(params, data, detectors)

Examine the data and determine quantities needed to set up Madam buffers

Source code in toast/ops/madam.py, lines 504-662
@function_timer
def _prepare(self, params, data, detectors):
    """Examine the data and determine quantities needed to set up Madam buffers"""
    log = Logger.get()
    timer = Timer()
    timer.start()

    params["nside_map"] = self.pixel_pointing.nside

    # Madam requires a fixed set of detectors and pointing matrix non-zeros.
    # Here we find the superset of local detectors used, and also the number
    # of pointing matrix elements.

    nsamp = 0

    # Madam uses monolithic data buffers and specifies contiguous data intervals
    # in that buffer.  The starting sample index is used to mark the transition
    # between data intervals.
    interval_starts = list()

    # This quantity is only used for printing the fraction of samples in valid
    # ranges specified by the View.  Only samples actually in the view are copied
    # to Madam buffers.
    nsamp_valid = 0

    all_dets = set()
    nnz_full = None
    psd_freqs = None

    for ob in data.obs:
        # Get the detectors we are using for this observation
        local_dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
        # if ob.comm.comm_group is not None:
        #     pdets = ob.comm.comm_group.gather(local_dets, root=0)
        #     obs_dets = None
        #     if ob.comm.group_rank == 0:
        #         obs_dets = set()
        #         for plocal in pdets:
        #             for d in plocal:
        #                 obs_dets.add(d)
        #     obs_dets = ob.comm.comm_group.bcast(obs_dets, root=0)
        # else:
        #     obs_dets = set(local_dets)
        # all_dets.update(obs_dets)
        all_dets.update(set(local_dets))

        # Check that the timestamps exist.
        if self.times not in ob.shared:
            msg = (
                "Shared timestamps '{}' does not exist in observation '{}'".format(
                    self.times, ob.name
                )
            )
            raise RuntimeError(msg)

        # Check that the detector data and pointing exists in the observation
        if self.det_data not in ob.detdata:
            msg = "Detector data '{}' does not exist in observation '{}'".format(
                self.det_data, ob.name
            )
            raise RuntimeError(msg)

        # Check that the noise model exists, and that the PSD frequencies are the
        # same across all observations (required by Madam).
        if self.noise_model not in ob:
            msg = "Noise model '{}' not in observation '{}'".format(
                self.noise_model, ob.name
            )
            raise RuntimeError(msg)
        if psd_freqs is None:
            psd_freqs = np.array(
                ob[self.noise_model].freq(ob.local_detectors[0]).to_value(u.Hz),
                dtype=np.float64,
            )
        else:
            check_freqs = (
                ob[self.noise_model].freq(ob.local_detectors[0]).to_value(u.Hz)
            )
            if not np.allclose(psd_freqs, check_freqs):
                raise RuntimeError(
                    "All PSDs passed to Madam must have the same frequency binning."
                )

        # Are we using a view of the data?  If so, we will only be copying data in
        # those valid intervals.
        if self.view is not None:
            if self.view not in ob.intervals:
                msg = "View '{}' does not exist in observation {}".format(
                    self.view, ob.name
                )
                raise RuntimeError(msg)
            # Go through all the intervals that will be used for our data view
            # and accumulate the number of samples.
            for intvw in ob.intervals[self.view]:
                interval_starts.append(nsamp_valid)
                nsamp_valid += intvw.last - intvw.first
        else:
            interval_starts.append(nsamp_valid)
            nsamp_valid += ob.n_local_samples
        nsamp += ob.n_local_samples

    if data.comm.world_rank == 0:
        log.info(
            "{}{:.2f} % of samples are included in valid intervals.".format(
                self._logprefix, nsamp_valid * 100.0 / nsamp
            )
        )

    nsamp = nsamp_valid

    interval_starts = np.array(interval_starts, dtype=np.int64)
    all_dets = list(sorted(all_dets))
    ndet = len(all_dets)

    nnz_full = len(self.stokes_weights.mode)
    nnz_stride = None
    if "temperature_only" in params and params["temperature_only"] in [
        "T",
        "True",
        "TRUE",
        "true",
        True,
    ]:
        # User has requested a temperature-only map.
        if nnz_full not in [1, 3]:
            raise RuntimeError(
                "Madam cannot make a temperature map with nnz == {}".format(
                    nnz_full
                )
            )
        nnz = 1
        nnz_stride = nnz_full
    else:
        nnz = nnz_full
        nnz_stride = 1

    if data.comm.world_rank == 0 and "path_output" in params:
        os.makedirs(params["path_output"], exist_ok=True)

    # Inspect the valid intervals across all observations to
    # determine the number of samples per detector

    data.comm.comm_world.barrier()
    timer.stop()
    if data.comm.world_rank == 0:
        msg = "{}  Compute data dimensions: {:0.1f} s".format(
            self._logprefix, timer.seconds()
        )
        log.debug(msg)

    return (
        all_dets,
        nsamp,
        nnz,
        nnz_full,
        nnz_stride,
        interval_starts,
        psd_freqs,
    )
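
The nnz / nnz_stride values returned here control how the Stokes weights are later staged: for a temperature-only run with full IQU weights, nnz == 1 and nnz_stride == 3 pick out only the intensity weight of each sample.  A hypothetical numpy illustration of that striding (not from the source):

import numpy as np

nnz_full, nnz, nnz_stride = 3, 1, 3
weights_iqu = np.arange(12.0).reshape(4, nnz_full)  # 4 samples x (I, Q, U)
# Every nnz_stride-th element of the flattened weights, nnz per sample.
staged = weights_iqu.reshape(-1)[::nnz_stride]
print(staged)  # [0. 3. 6. 9.] -- the intensity column only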

_provides()

Source code in toast/ops/madam.py, lines 1118-1119
def _provides(self):
    return dict()

_requires()

Source code in toast/ops/madam.py, lines 1107-1116
def _requires(self):
    req = self.pixel_pointing.requires()
    req.update(self.stokes_weights.requires())
    req["meta"].extend([self.noise_model])
    req["detdata"].extend([self.det_data])
    if self.shared_flags is not None:
        req["shared"].append(self.shared_flags)
    if self.det_flags is not None:
        req["detdata"].append(self.det_flags)
    return req

_stage_data(params, data, all_dets, nsamp, nnz, nnz_full, nnz_stride, interval_starts, psd_freqs)

Create madam-compatible buffers.

Collect the data into Madam buffers. If we are purging TOAST data to save memory, then optionally limit the number of processes that are copying at once.

Source code in toast/ops/madam.py, lines 664-955
@function_timer
def _stage_data(
    self,
    params,
    data,
    all_dets,
    nsamp,
    nnz,
    nnz_full,
    nnz_stride,
    interval_starts,
    psd_freqs,
):
    """Create madam-compatible buffers.

    Collect the data into Madam buffers.  If we are purging TOAST data to save
    memory, then optionally limit the number of processes that are copying at once.

    """
    log = Logger.get()
    timer = Timer()

    nodecomm = data.comm.comm_group_node

    # Determine how many processes per node should copy at once.
    n_copy_groups = 1
    if self.purge_det_data:
        # We will be purging some data.  See if we should reduce the number of
        # processes copying in parallel (if we are not purging data, there
        # is no benefit to staggering the copy).
        if self.copy_groups > 0:
            n_copy_groups = min(self.copy_groups, nodecomm.size)

    if not self._cached:
        # Only do this if we have not cached the data yet.
        log_time_memory(
            data,
            prefix=self._logprefix,
            mem_msg="Before staging",
            full_mem=self.mem_report,
        )

    # Copy timestamps and PSDs all at once, since they are never purged.

    psds = dict()

    timer.start()

    if not self._cached:
        timestamp_storage, _ = dtype_to_aligned(madam.TIMESTAMP_TYPE)
        self._madam_timestamps_raw = timestamp_storage.zeros(nsamp)
        self._madam_timestamps = self._madam_timestamps_raw.array()

        interval = 0
        time_offset = 0.0

        for ob in data.obs:
            for vw in ob.view[self.view].shared[self.times]:
                offset = interval_starts[interval]
                slc = slice(offset, offset + len(vw), 1)
                self._madam_timestamps[slc] = vw
                if self.translate_timestamps:
                    off = self._madam_timestamps[offset] - time_offset
                    self._madam_timestamps[slc] -= off
                    time_offset = self._madam_timestamps[slc][-1] + 1.0
                interval += 1

            # Get the noise object for this observation and create new
            # entries in the dictionary when the PSD actually changes.  The detector
            # weights are obtained from the noise model.

            nse = ob[self.noise_model]
            nse_scale = 1.0
            if self.noise_scale is not None:
                if self.noise_scale in ob:
                    nse_scale = float(ob[self.noise_scale])

            local_dets = set(ob.select_local_detectors(flagmask=self.det_mask))
            for det in all_dets:
                if det not in local_dets:
                    continue
                psd = nse.psd(det).to_value(u.K**2 * u.second) * nse_scale**2
                detw = nse.detector_weight(det).to_value(1.0 / u.K**2)
                if det not in psds:
                    psds[det] = [(0.0, psd, detw)]
                else:
                    if not np.allclose(psds[det][-1][1], psd):
                        psds[det] += [(ob.shared[self.times][0], psd, detw)]

        log_time_memory(
            data,
            timer=timer,
            timer_msg="Copy timestamps and PSDs",
            prefix=self._logprefix,
            mem_msg="After timestamp staging",
            full_mem=self.mem_report,
        )

    # Copy the signal.  We always need to do this, even if we are running MCs.

    signal_dtype = data.obs[0].detdata[self.det_data].dtype

    if self._cached:
        # We have previously created the madam buffers.  We just need to fill
        # them from the toast data.  Since both already exist we just copy the
        # contents.
        stage_local(
            data,
            nsamp,
            self.view,
            all_dets,
            self.det_data,
            self._madam_signal,
            interval_starts,
            1,
            1,
            self.det_mask,
            None,
            None,
            None,
            self.det_flag_mask,
            do_purge=False,
        )
    else:
        # Signal buffers do not yet exist
        if self.purge_det_data:
            # Allocate in a staggered way.
            self._madam_signal_raw, self._madam_signal = stage_in_turns(
                data,
                nodecomm,
                n_copy_groups,
                nsamp,
                self.view,
                all_dets,
                self.det_data,
                madam.SIGNAL_TYPE,
                interval_starts,
                1,
                1,
                self.det_mask,
                None,
                None,
                None,
                self.det_flag_mask,
            )
        else:
            # Allocate and copy all at once.
            storage, _ = dtype_to_aligned(madam.SIGNAL_TYPE)
            self._madam_signal_raw = storage.zeros(nsamp * len(all_dets))
            self._madam_signal = self._madam_signal_raw.array()

            stage_local(
                data,
                nsamp,
                self.view,
                all_dets,
                self.det_data,
                self._madam_signal,
                interval_starts,
                1,
                1,
                self.det_mask,
                None,
                None,
                None,
                self.det_flag_mask,
                do_purge=False,
            )

    log_time_memory(
        data,
        timer=timer,
        timer_msg="Copy signal",
        prefix=self._logprefix,
        mem_msg="After signal staging",
        full_mem=self.mem_report,
    )

    # Copy the pointing

    nested_pointing = self.pixel_pointing.nest
    if not nested_pointing:
        # Any existing pixel numbers are in the wrong ordering
        Delete(detdata=[self.pixel_pointing.pixels]).apply(data)
        self.pixel_pointing.nest = True

    if not self._cached:
        # We do not have the pointing yet.
        self._madam_pixels_raw, self._madam_pixels = stage_in_turns(
            data,
            nodecomm,
            n_copy_groups,
            nsamp,
            self.view,
            all_dets,
            self.pixel_pointing.pixels,
            madam.PIXEL_TYPE,
            interval_starts,
            1,
            1,
            self.det_mask,
            self.shared_flags,
            self.shared_flag_mask,
            self.det_flags,
            self.det_flag_mask,
            operator=self.pixel_pointing,
        )

        self._madam_pixweights_raw, self._madam_pixweights = stage_in_turns(
            data,
            nodecomm,
            n_copy_groups,
            nsamp,
            self.view,
            all_dets,
            self.stokes_weights.weights,
            madam.WEIGHT_TYPE,
            interval_starts,
            nnz,
            nnz_stride,
            self.det_mask,
            None,
            None,
            None,
            self.det_flag_mask,
            operator=self.stokes_weights,
        )

        log_time_memory(
            data,
            timer=timer,
            timer_msg="Copy pointing",
            prefix=self._logprefix,
            mem_msg="After pointing staging",
            full_mem=self.mem_report,
        )

    if not nested_pointing:
        # Any existing pixel numbers are in the wrong ordering
        Delete(detdata=[self.pixel_pointing.pixels]).apply(data)
        self.pixel_pointing.nest = False

    psdinfo = None

    if not self._cached:
        # Detector weights.  Madam assumes a single noise weight for each detector
        # that is constant.  We set this based on the first observation or else use
        # uniform weighting.

        ndet = len(all_dets)
        detweights = np.ones(ndet, dtype=np.float64)

        if len(psds) > 0:
            npsdbin = len(psd_freqs)
            npsd = np.zeros(ndet, dtype=np.int64)
            psdstarts = []
            psdvals = []
            for idet, det in enumerate(all_dets):
                if det not in psds:
                    raise RuntimeError("Every detector must have at least one PSD")
                psdlist = psds[det]
                npsd[idet] = len(psdlist)
                for psdstart, psd, detw in psdlist:
                    psdstarts.append(psdstart)
                    psdvals.append(psd)
                detweights[idet] = psdlist[0][2]
            npsdtot = np.sum(npsd)
            psdstarts = np.array(psdstarts, dtype=np.float64)
            psdvals = np.hstack(psdvals).astype(madam.PSD_TYPE)
            npsdval = psdvals.size
        else:
            # Uniform weighting
            npsd = np.ones(ndet, dtype=np.int64)
            npsdtot = np.sum(npsd)
            psdstarts = np.zeros(npsdtot)
            npsdbin = 10
            fsample = 10.0
            psd_freqs = np.arange(npsdbin) * fsample / npsdbin
            npsdval = npsdbin * npsdtot
            psdvals = np.ones(npsdval)

        psdinfo = (detweights, npsd, psdstarts, psd_freqs, psdvals)

        log_time_memory(
            data,
            timer=timer,
            timer_msg="Collect PSD info",
            prefix=self._logprefix,
        )
    timer.stop()

    return psdinfo, signal_dtype
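
For clarity, a hypothetical illustration (not from the source) of how the per-detector PSD lists collected above flatten into the returned psdinfo tuple: npsd records how many PSD epochs each detector contributes, while psdstarts and psdvals concatenate those epochs over detectors in sorted order.

import numpy as np

psds = {
    "d00": [(0.0, np.ones(4), 2.0)],                      # one PSD epoch
    "d01": [(0.0, np.ones(4), 1.5), (100.0, 2.0 * np.ones(4), 1.5)],
}
all_dets = sorted(psds)
npsd = np.array([len(psds[d]) for d in all_dets])          # [1 2]
psdstarts = np.array([s for d in all_dets for s, _, _ in psds[d]])
psdvals = np.hstack([p for d in all_dets for _, p, _ in psds[d]])
detweights = np.array([psds[d][0][2] for d in all_dets])   # first-epoch weights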

_unstage_data(params, data, all_dets, nsamp, nnz, nnz_full, interval_starts, signal_dtype)

Restore data to TOAST observations.

Optionally copy the signal and pointing back to TOAST if we previously purged it to save memory. Also copy the destriped timestreams if desired.

Source code in toast/ops/madam.py, lines 957-1058
@function_timer
def _unstage_data(
    self,
    params,
    data,
    all_dets,
    nsamp,
    nnz,
    nnz_full,
    interval_starts,
    signal_dtype,
):
    """
    Restore data to TOAST observations.

    Optionally copy the signal and pointing back to TOAST if we previously
    purged it to save memory.  Also copy the destriped timestreams if desired.

    """
    log = Logger.get()
    timer = Timer()

    nodecomm = data.comm.comm_group_node

    # Determine how many processes per node should copy at once.
    n_copy_groups = 1
    if self.purge_det_data:
        # We MAY be restoring some data.  See if we should reduce the number of
        # processes copying in parallel (if we are not purging data, there
        # is no benefit to staggering the copy).
        if self.copy_groups > 0:
            n_copy_groups = min(self.copy_groups, nodecomm.size)

    log_time_memory(
        data,
        prefix=self._logprefix,
        mem_msg="Before un-staging",
        full_mem=self.mem_report,
    )

    # Copy the signal

    timer.start()

    out_name = self.det_data
    if self.det_out is not None:
        out_name = self.det_out

    if self.det_out is not None or (self.purge_det_data and self.restore_det_data):
        # We are copying some kind of signal back
        if not self.mcmode:
            # We are not running multiple realizations, so delete as we copy.
            restore_in_turns(
                data,
                nodecomm,
                n_copy_groups,
                nsamp,
                self.view,
                all_dets,
                out_name,
                signal_dtype,
                self._madam_signal,
                self._madam_signal_raw,
                interval_starts,
                1,
                self.det_mask,
            )
            del self._madam_signal
            del self._madam_signal_raw
        else:
            # We want to re-use the signal buffer, just copy.
            restore_local(
                data,
                nsamp,
                self.view,
                all_dets,
                out_name,
                signal_dtype,
                self._madam_signal,
                interval_starts,
                1,
                self.det_mask,
            )

        log_time_memory(
            data,
            timer=timer,
            timer_msg="Copy signal",
            prefix=self._logprefix,
            mem_msg="After restoring signal",
            full_mem=self.mem_report,
        )

    # Copy the pointing

    if not self.mcmode:
        # We can clear the cached pointing
        del self._madam_pixels
        del self._madam_pixels_raw
        del self._madam_pixweights
        del self._madam_pixweights_raw
    return

clear()

Delete the underlying memory.

This will forcibly delete the C-allocated memory and invalidate all python references to the buffers.

Source code in toast/ops/madam.py, lines 307-326
def clear(self):
    """Delete the underlying memory.

    This will forcibly delete the C-allocated memory and invalidate all python
    references to the buffers.

    """
    if self._cached:
        madam.clear_caches()
        self._cached = False
    for atr in ["timestamps", "signal", "pixels", "pixweights"]:
        atrname = "_madam_{}".format(atr)
        rawname = "{}_raw".format(atrname)
        if hasattr(self, atrname):
            delattr(self, atrname)
            raw = getattr(self, rawname)
            if raw is not None:
                raw.clear()
            setattr(self, rawname, None)
            setattr(self, atrname, None)
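
As a usage note, mcmode and clear() are designed to work together across Monte Carlo realizations: the first exec() stages and caches the buffers, subsequent calls reuse them, and clear() releases the memory at the end.  A sketch, assuming 'madam' is configured as shown earlier and 'sim' is some hypothetical signal-simulation operator:

madam.mcmode = True
for mc in range(10):
    sim.apply(data)    # refill the input signal for this realization
    madam.apply(data)  # first pass stages and caches, later passes reuse
madam.clear()          # forcibly free the C-allocated Madam buffers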

toast.ops.madam_params_from_mapmaker(mapmaker)

Utility function that configures Madam to match the TOAST mapmaker

Source code in toast/ops/madam.py, lines 44-105
def madam_params_from_mapmaker(mapmaker):
    """Utility function that configures Madam to match the TOAST mapmaker"""

    if not isinstance(mapmaker, MapMaker):
        raise RuntimeError("Need an instance of MapMaker to configure from")

    destripe_pixels = mapmaker.binning.pixel_pointing
    map_pixels = mapmaker.map_binning.pixel_pointing

    params = {
        "nside_cross": destripe_pixels.nside,
        "nside_map": map_pixels.nside,
        "nside_submap": map_pixels.nside_submap,
        "path_output": mapmaker.output_dir,
        "write_hits": mapmaker.write_hits,
        "write_matrix": mapmaker.write_invcov,
        "write_wcov": mapmaker.write_cov,
        "write_mask": mapmaker.write_rcond,
        "write_binmap": mapmaker.write_binmap,
        "write_map": mapmaker.write_map,
        "info": 3,
        "iter_max": mapmaker.iter_max,
        "pixlim_cross": mapmaker.solve_rcond_threshold,
        "pixlim_map": mapmaker.map_rcond_threshold,
        "cglimit": mapmaker.convergence,
    }
    sync_type = mapmaker.map_binning.sync_type
    if sync_type == "allreduce":
        params["allreduce"] = True
    elif sync_type == "alltoallv":
        params["concatenate_messages"] = True
        params["reassign_submaps"] = True
    else:
        msg = f"Unknown sync_type: {sync_type}"
        raise RuntimeError(msg)

    # Destriping parameters

    for template in mapmaker.template_matrix.templates:
        if isinstance(template, Offset):
            baselines = template
            break
    else:
        baselines = None

    if baselines is None or not baselines.enabled:
        params["kfirst"] = False
        if params["write_map"]:
            params.update({"write_binmap": True, "write_map": False})
    else:
        params.update(
            {
                "kfilter": baselines.use_noise_prior,
                "kfirst": True,
                "base_first": baselines.step_time.to_value(u.s),
                "precond_width_min": baselines.precond_width,
                "precond_width_max": baselines.precond_width,
                "good_baseline_fraction": baselines.good_fraction,
            }
        )

    return params
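
A short usage sketch, with names assumed for illustration: given a fully configured toast.ops.MapMaker, the derived parameters can seed a Madam instance that produces comparable maps.

from toast import ops

params = ops.madam_params_from_mapmaker(mapmaker)
madam = ops.Madam(
    pixel_pointing=mapmaker.map_binning.pixel_pointing,
    stokes_weights=mapmaker.map_binning.stokes_weights,
    params=params,
)
madam.apply(data)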