Skip to content

Utilities

korvax.util

autocorrelate

autocorrelate(x, /, max_size=None, axis=-1)

Compute the autocorrelation of the input array along the specified axis.

Parameters:

Name Type Description Default
x Float[ArrayLike, ...]

Input array.

required
max_size int | None

Maximum size of the autocorrelation lags. If None, uses the full size.

None
axis int

Axis along which to compute the autocorrelation. Default: -1.

-1

Returns:

Type Description
Float[Array, ...]

Autocorrelated array.

Source code in src/korvax/util.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def autocorrelate(
    x: Float[ArrayLike, "..."],
    /,
    max_size: int | None = None,
    axis: int = -1,
) -> Float[Array, "..."]:
    """Compute the autocorrelation of the input array along the specified axis.

    The result is obtained through the FFT: the signal is zero-padded, its power
    spectrum is formed, and the inverse transform is truncated to the first
    `max_size` non-negative lags.

    Args:
        x: Input array.
        max_size: Maximum size of the autocorrelation lags. If `None`, uses the full size.
            Values larger than the size of `x` along `axis` are clamped to that size.
        axis: Axis along which to compute the autocorrelation. Default: `-1`.

    Returns:
        Autocorrelated array.
    """
    x = jnp.asarray(x)
    x = x.swapaxes(-1, axis)
    n_samples = x.shape[-1]
    if max_size is None:
        max_size = n_samples
    else:
        # Only `n_samples` lags are meaningful; anything beyond that would be
        # zero padding or wrapped-around negative lags of the circular result.
        max_size = min(max_size, n_samples)

    # A linear (non-circular) autocorrelation needs at least 2 * n_samples - 1
    # points, otherwise the largest lags alias onto each other. Plain integer
    # arithmetic avoids log2(0) for a single sample and float rounding for
    # long signals. The lower bound of 2 keeps the transform non-degenerate.
    min_fft = max(2 * n_samples - 1, 2)
    n_fft = 1 << (min_fft - 1).bit_length()  # next power of two >= min_fft

    X_f = jnp.fft.rfft(x, n=n_fft, axis=-1)
    S_f = jnp.conj(X_f) * X_f
    acf = jnp.fft.irfft(S_f, n=n_fft, axis=-1)
    return acf[..., :max_size].swapaxes(-1, axis)

frame

frame(x, /, frame_length, hop_length)

Slice a JAX array into overlapping frames.

Parameters:

Name Type Description Default
x Float[Array, '*channels n_samples']

Input array.

required
frame_length int

Length of each frame.

required
hop_length int

Number of samples between adjacent frame starts.

required

Returns:

Type Description
Float[Array, '*channels {frame_length} n_frames=1+(n_samples-{frame_length})//{hop_length}']

Array with the last axis sliced into overlapping frames.

Source code in src/korvax/util.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def frame(
    x: Float[Array, "*channels n_samples"],
    /,
    frame_length: int,
    hop_length: int,
) -> Float[
    Array,
    "*channels {frame_length} n_frames=1+(n_samples-{frame_length})//{hop_length}",
]:
    """Cut the last axis of a JAX array into overlapping frames.

    Frame `k` starts at sample `k * hop_length` and spans `frame_length`
    samples; the frames are stacked along a new trailing axis.

    Args:
        x: Input array.
        frame_length: Length of each frame.
        hop_length: Number of samples between adjacent frame starts.

    Returns:
        Array with the last axis sliced into overlapping frames.
    """
    total = x.shape[-1]
    count = (total - frame_length) // hop_length + 1
    starts = hop_length * jnp.arange(count)

    def take_frame(start):
        # One window of `frame_length` samples along the last axis of `x`.
        return lax.dynamic_slice_in_dim(x, start, frame_length, -1)

    # Map over the start offsets only; `x` is shared by every frame.
    return jax.vmap(take_frame, out_axes=-1)(starts)

overlap_and_add

overlap_and_add(x, hop_length)

Construct a signal from overlapping frames with overlap-and-add.

Parameters:

Name Type Description Default
x Float[Array, '*channels frame_length n_frames']

Input array containing overlapping frames.

required
hop_length int

Number of samples between adjacent frame starts.

required

Returns:

Type Description
Float[Array, '*channels n_samples']

Constructed time-domain signal.

Source code in src/korvax/util.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def overlap_and_add(
    x: Float[Array, "*channels frame_length n_frames"],
    hop_length: int,
) -> Float[Array, "*channels n_samples"]:
    """Rebuild a signal from overlapping frames using overlap-and-add.

    Args:
        x: Input array containing overlapping frames.
        hop_length: Number of samples between adjacent frame starts.

    Returns:
        Constructed time-domain signal.
    """
    # The input keeps frames on the last axis; the JAX routine is handed the
    # transposed layout, with frames on the second-to-last axis.
    frames_major = jnp.swapaxes(x, -2, -1)
    return jax._src.scipy.signal._overlap_and_add(frames_major, hop_length)

pad_center

pad_center(x, /, size, pad_kwargs=dict())

Pad the input array on both sides to center it in a new array of given size.

Parameters:

Name Type Description Default
x Float[Array, '*channels n_samples']

Input array.

required
size int

Desired size of the last axis after padding.

required
pad_kwargs dict[str, Any]

Additional keyword arguments forwarded to jax.numpy.pad.

dict()

Returns:

Type Description
Float[Array, '*channels {size}']

Array with the last axis center-padded to the desired size.

Source code in src/korvax/util.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def pad_center(
    x: Float[Array, "*channels n_samples"],
    /,
    size: int,
    pad_kwargs: dict[str, Any] | None = None,
) -> Float[Array, "*channels {size}"]:
    """Pad the input array on both sides to center it in a new array of given size.

    When the required padding is odd, the extra sample goes on the right.

    Args:
        x: Input array.
        size: Desired size of the last axis after padding. Must be at least the
            current size of the last axis.
        pad_kwargs: Additional keyword arguments forwarded to [`jax.numpy.pad`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.pad.html).
            `None` (the default) forwards no extra arguments.

    Returns:
        Array with the last axis center-padded to the desired size.

    Raises:
        ValueError: If `size` is smaller than the last axis of `x`.
    """
    n_samples = x.shape[-1]

    # Fail with a clear message instead of handing a negative pad width to jnp.pad.
    if size < n_samples:
        raise ValueError(
            f"Target size ({size}) must be at least the input size ({n_samples})."
        )

    total = int(size - n_samples)
    lpad = total // 2

    lengths = [(0, 0)] * x.ndim
    lengths[-1] = (lpad, total - lpad)

    # A `None` sentinel replaces the shared mutable `dict()` default.
    return jnp.pad(x, lengths, **(pad_kwargs or {}))

fix_length

fix_length(x, /, size, **pad_kwargs)

Fix the length of the input array to a given size by either trimming or padding.

Parameters:

Name Type Description Default
x Float[Array, '*channels n_samples']

Input array.

required
size int

Desired size of the last axis after fixing length.

required
**pad_kwargs Any

Additional keyword arguments forwarded to jax.numpy.pad.

{}

Returns:

Type Description
Float[Array, '*channels {size}']

Array with the last axis fixed to the desired size.

Source code in src/korvax/util.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def fix_length(
    x: Float[Array, "*channels n_samples"], /, size: int, **pad_kwargs: Any
) -> Float[Array, "*channels {size}"]:
    """Force the last axis of the input array to a given size.

    Arrays that are too long are truncated from the end; arrays that are too
    short are padded at the end.

    Args:
        x: Input array.
        size: Desired size of the last axis after fixing length.
        **pad_kwargs: Additional keyword arguments forwarded to [`jax.numpy.pad`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.pad.html).

    Returns:
        Array with the last axis fixed to the desired size.
    """
    current = x.shape[-1]

    # Already long enough: keep the leading `size` samples.
    if current >= size:
        return x[..., :size]

    # Too short: pad only the trailing side of the last axis.
    pad_widths = [(0, 0) for _ in range(x.ndim - 1)]
    pad_widths.append((0, size - current))
    return jnp.pad(x, pad_widths, **pad_kwargs)

get_window

get_window(window, Nx=None, fftbins=True, dtype=None)

Return the passed array, or the output of scipy.signal.get_window as a JAX array.

Parameters:

Name Type Description Default
window _WindowSpec

Window specification.

required
Nx int | None

Length of the returned window.

None
fftbins bool

If True, return a periodic window for FFT analysis. If False, return a symmetric window for filter design. Default: True.

True
dtype DTypeLike | None

Desired data type of the returned array. If None, uses the default JAX floating point type, which might be float32 or float64 depending on jax_enable_x64.

None

Returns:

Type Description
Float[Array, ' {Nx}']

The window as a JAX array.

Source code in src/korvax/util.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def get_window(
    window: _WindowSpec,
    Nx: int | None = None,
    fftbins: bool = True,
    dtype: DTypeLike | None = None,
) -> Float[Array, " {Nx}"]:
    """Turn a window specification into a JAX array.

    An array input is converted with `jnp.asarray` and returned. Any other
    specification is passed to [`scipy.signal.get_window`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html)
    and the result is converted to a JAX array.

    Args:
        window: Window specification.
        Nx: Length of the returned window. Required when `window` is not an
            array; when `window` is an array and `Nx` is given, the lengths
            are asserted to match.
        fftbins: If `True`, return a periodic window for FFT analysis.
            If `False`, return a symmetric window for filter design. Default: `True`.
        dtype: Desired data type of the returned array. If `None`, uses the default JAX
            floating point type, which might be `float32` or `float64` depending on `jax_enable_x64`.

    Returns:
        The window as a JAX array.
    """
    if not is_array(window):
        # Named or parametrised window: SciPy needs an explicit length.
        assert Nx is not None, "Nx must be specified if window is not an array."
        generated = scipy.signal.get_window(window, Nx, fftbins=fftbins)
        return jnp.asarray(generated, dtype=dtype)

    # Precomputed window: convert and check the length if one was requested.
    win = jnp.asarray(window, dtype=dtype)
    if Nx is not None:
        assert len(win) == Nx
    return win

is_array

is_array(x)

Check if the input is a JAX or NumPy array.

Parameters:

Name Type Description Default
x Any

Input value to check.

required

Returns:

Type Description
TypeGuard[Array | ndarray]

True if the input is a JAX or NumPy array, False otherwise.

Source code in src/korvax/util.py
167
168
169
170
171
172
173
174
175
176
def is_array(x: Any) -> TypeGuard[Array | np.ndarray]:
    """Check if the input is a JAX or NumPy array.

    Args:
        x: Input value to check.

    Returns:
        True if the input is a JAX or NumPy array, False otherwise.
    """
    return isinstance(x, (jax.Array, np.ndarray))

feps

feps(x)

Get the machine epsilon for the data type of the input array.

Parameters:

Name Type Description Default
x Inexact[ArrayLike, ...]

Input array.

required

Returns:

Type Description
float

Machine epsilon as a float.

Source code in src/korvax/util.py
179
180
181
182
183
184
185
186
187
188
def feps(x: Inexact[ArrayLike, "..."]) -> float:
    """Return the machine epsilon of the input's data type.

    The data type is resolved with `jnp.result_type`.

    Args:
        x: Input array.

    Returns:
        Machine epsilon as a float.
    """
    dtype = jnp.result_type(x)
    info = jnp.finfo(dtype)
    return float(info.eps)

normalize

normalize(x, /, ord=None, axis=None, threshold=None)

Normalize an array by its norm along the specified axis.

Parameters:

Name Type Description Default
x Inexact[Array, '*dims']

Input array.

required
ord float | str | None

Order of the norm. See jax.numpy.linalg.norm for options.

None
axis int | tuple[int, ...] | None

Axis or axes along which to compute the norm. If None, normalizes over all axes.

None
threshold float | None

Minimum norm value below which normalization is skipped. If None, uses machine epsilon.

None

Returns:

Type Description
Float[Array, '*dims']

Normalized array.

Source code in src/korvax/util.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def normalize(
    x: Inexact[Array, "*dims"],
    /,
    ord: float | str | None = None,
    axis: int | tuple[int, ...] | None = None,
    threshold: float | None = None,
) -> Inexact[Array, "*dims"]:
    """Normalize an array by its norm along the specified axis.

    The norm is computed from the magnitudes of `x`, and the original values
    are divided by it, so the sign (or complex phase) of every element is kept.

    Args:
        x: Input array.
        ord: Order of the norm. See `jax.numpy.linalg.norm` for options.
        axis: Axis or axes along which to compute the norm. If None, normalizes
            over all axes.
        threshold: Minimum norm value below which normalization is skipped.
            If None, uses machine epsilon.

    Returns:
        Normalized array.
    """
    if threshold is None:
        threshold = feps(x)

    # Use magnitudes for the norm only; do not overwrite `x`, otherwise the
    # returned values would lose their sign / phase.
    norm = jnp.linalg.norm(jnp.abs(x), ord=ord, axis=axis, keepdims=True)
    # Slices whose norm is below the threshold are left unscaled (divide by 1).
    norm = jnp.where(norm < threshold, 1.0, norm)
    return x / norm  # pyright: ignore[reportOperatorIssue]

expand_to

expand_to(x, /, ndim, axes)

Expand the dimensions of an array to a given number of dimensions by adding singleton dimensions at specified axes.

Parameters:

Name Type Description Default
x Shaped[ArrayLike, '*']

Input array.

required
ndim int

Desired number of dimensions after expansion.

required
axes int | tuple[int, ...]

Axes at which to add singleton dimensions.

required

Returns:

Type Description
Shaped[Array, '*expanded_shape']

Expanded array.

Source code in src/korvax/util.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def expand_to(
    x: Shaped[ArrayLike, "*"], /, ndim: int, axes: int | tuple[int, ...]
) -> Shaped[Array, "*expanded_shape"]:
    """Reshape an array to `ndim` dimensions, placing its existing dimensions at the given axes.

    Every axis of the result that is not listed in `axes` has size 1. The
    i-th dimension of `x` is placed at position `axes[i]` of the result.

    Args:
        x: Input array.
        ndim: Desired number of dimensions after expansion.
        axes: Axes of the result that receive the dimensions of `x`; a single
            int is treated as a one-element sequence.

    Returns:
        Expanded array.
    """
    arr = jnp.asarray(x)
    targets = (axes,) if isinstance(axes, int) else axes

    # Start from all-singleton dimensions and fill in the sizes of `arr`.
    new_shape = [1] * ndim
    for src_dim, dst_axis in enumerate(targets):
        new_shape[dst_axis] = arr.shape[src_dim]

    return arr.reshape(new_shape)

parabolic_peak_shifts

parabolic_peak_shifts(x, /, axis)

Compute subpixel peak positions using parabolic interpolation.

Parameters:

Name Type Description Default
x Float[Array, '*dims']

Input array containing peaks.

required
axis int

Axis along which to compute peak shifts.

required

Returns:

Type Description
Float[Array, '*dims']

Array of fractional shifts for each position, where each shift indicates the subpixel offset from the integer position to the interpolated peak.

Source code in src/korvax/util.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def parabolic_peak_shifts(
    x: Float[Array, "*dims"], /, axis: int
) -> Float[Array, "*dims"]:
    """Estimate subpixel peak offsets with parabolic interpolation.

    For each interior position along `axis`, a shift is computed from the
    value at that position and its two neighbours. The first and last
    positions along `axis` get a shift of zero.

    Args:
        x: Input array containing peaks.
        axis: Axis along which to compute peak shifts.

    Returns:
        Array of fractional shifts for each position, where each shift indicates
        the subpixel offset from the integer position to the interpolated peak.
    """
    moved = x.swapaxes(-1, axis)
    prev_vals = moved[..., :-2]
    mid_vals = moved[..., 1:-1]
    next_vals = moved[..., 2:]

    # Second difference and halved central difference of each three-point window.
    curvature = next_vals + prev_vals - 2 * mid_vals
    slope = (next_vals - prev_vals) / 2

    # Machine epsilon is added to the denominator before dividing.
    raw_shifts = -slope / (curvature + feps(moved))
    # Where |slope| >= |curvature| the shift is set to zero instead.
    no_shift = jnp.abs(slope) >= jnp.abs(curvature)
    interior = jnp.where(no_shift, 0.0, raw_shifts)

    # Restore the original length by zero-padding both ends of the last axis.
    edge_pad = [(0, 0)] * (interior.ndim - 1) + [(1, 1)]
    padded = jnp.pad(interior, edge_pad)

    return padded.swapaxes(-1, axis)

localmin

localmin(x, /, axis)

Identify local minima in an array along the specified axis.

Parameters:

Name Type Description Default
x Float[Array, '*dims']

Input array.

required
axis int

Axis along which to find local minima.

required

Returns:

Type Description
Bool[Array, '*dims']

Boolean array where True indicates a local minimum.

Source code in src/korvax/util.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def localmin(x: Float[Array, "*dims"], /, axis: int) -> Bool[Array, "*dims"]:
    """Mark the local minima of an array along the specified axis.

    A position is a local minimum when its value is strictly less than the
    previous value and less than or equal to the next value along `axis`.
    The first and last positions along `axis` are never marked.

    Args:
        x: Input array.
        axis: Axis along which to find local minima.

    Returns:
        Boolean array where True indicates a local minimum.
    """
    moved = jnp.swapaxes(x, -1, axis)
    before = moved[..., :-2]
    here = moved[..., 1:-1]
    after = moved[..., 2:]

    interior = (here < before) & (here <= after)

    # The two end points have only one neighbour; pad them in as False.
    widths = [(0, 0)] * (interior.ndim - 1) + [(1, 1)]
    padded = jnp.pad(interior, widths, constant_values=False)

    return jnp.swapaxes(padded, -1, axis)