Skip to content

Transforms

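The forward and inverse transforms follow the librosa convention: time-domain signals have the shape (*batch, samples), and time-frequency representations have the shape (*batch, bins, frames). The exceptions are cqt and vqt, which currently accept a single 1-D signal of shape (samples,) and return an array of shape (bins, frames).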

Forward Transforms

Functions that take in time-domain signals and output time-frequency representations.

korvax.stft

stft(
    x,
    /,
    n_fft=2048,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    pad_kwargs=dict(),
)

Compute the short-time Fourier transform (STFT) of a time-domain signal.

Parameters:

Name Type Description Default
x Float[ArrayLike, '*channels n_samples']

Input signal.

required
n_fft int

FFT size (number of samples per frame).

2048
hop_length int | None

Hop (step) length between adjacent frames. If None, defaults to win_length // 4.

None
win_length int | None

Length of the analysis window. If None, defaults to n_fft. Ignored if window is an array.

None
window _WindowSpec

Either a 1d array containing the window to apply to each frame, or a window specification (see get_window).

'hann'
center bool

If True, pad the input so that frames are centered on their timestamps.

True
pad_kwargs dict[str, Any]

Additional keyword arguments forwarded to pad_center.

dict()

Returns:

Type Description
Complex[Array, '*channels {n_fft}//2+1 n_frames']

STFT coefficients.

Source code in src/korvax/transforms/fourier.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def stft(
    x: Float[ArrayLike, "*channels n_samples"],
    /,
    n_fft: int = 2048,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: _WindowSpec = "hann",
    center: bool = True,
    pad_kwargs: dict[str, Any] | None = None,
) -> Complex[Array, "*channels {n_fft}//2+1 n_frames"]:
    """Compute the short-time Fourier transform (STFT) of a time-domain signal.

    Args:
        x: Input signal.
        n_fft: FFT size (number of samples per frame).
        hop_length: Hop (step) length between adjacent frames. If None, defaults to
            `win_length // 4`.
        win_length: Length of the analysis window. If None, defaults to `n_fft`.
            Ignored if `window` is an array. A window shorter than `n_fft` is
            padded up to `n_fft` with [pad_center][korvax.util.pad_center].
        window: Either a 1d array containing the window to apply to each frame,
            or a window specification (see [get_window][korvax.util.get_window]).
        center: If True, pad the input so that frames are centered on their timestamps.
            The signal is padded by a total of `n_fft` samples.
        pad_kwargs: Dictionary of additional keyword arguments forwarded to
            [pad_center][korvax.util.pad_center] when `center` is True. None (the
            default) forwards no extra arguments.

    Returns:
        STFT coefficients, with `n_fft // 2 + 1` frequency bins along axis -2.
    """
    # Use a fresh dict per call rather than a shared mutable default argument.
    if pad_kwargs is None:
        pad_kwargs = {}

    if win_length is None:
        win_length = n_fft

    if hop_length is None:
        hop_length = win_length // 4

    x = jnp.asarray(x)

    if center:
        x = util.pad_center(x, size=x.shape[-1] + n_fft, pad_kwargs=pad_kwargs)

    frames = util.frame(x, frame_length=n_fft, hop_length=hop_length)

    fft_window = util.get_window(
        window,
        win_length,
        fftbins=True,
        dtype=frames.dtype,
    )

    # Frames are n_fft samples long; a shorter window must be padded to match.
    if len(fft_window) < n_fft:
        fft_window = util.pad_center(fft_window, n_fft)

    # Reshape the 1d window so it broadcasts along the per-frame sample axis (-2).
    fft_window = util.expand_to(fft_window, frames.ndim, -2)

    return jnp.fft.rfft(frames * fft_window, n=n_fft, axis=-2)

korvax.spectrogram

spectrogram(
    x,
    /,
    n_fft=2048,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    power=2.0,
    pad_kwargs=dict(),
)

Compute the magnitude spectrogram of a time-domain signal.

Parameters:

Name Type Description Default
x Float[ArrayLike, '*channels n_samples']

Input signal.

required
n_fft int

FFT size (number of samples per frame).

2048
hop_length int | None

Hop (step) length between adjacent frames. If None, defaults to win_length // 4.

None
win_length int | None

Length of the analysis window. If None, defaults to n_fft. Ignored if window is an array.

None
window _WindowSpec

Either a 1d array containing the window to apply to each frame, or a window specification (see get_window).

'hann'
center bool

If True, pad the input so that frames are centered on their timestamps.

True
power float | int | None

Exponent for the magnitude spectrogram. If 2.0, returns power spectrogram. If None, returns complex STFT coefficients.

2.0
pad_kwargs dict[str, Any]

Additional keyword arguments forwarded to pad_center.

dict()

Returns:

Type Description
Inexact[Array, '*channels {n_fft}//2+1 n_frames']

Magnitude spectrogram raised to `power`, or the complex STFT coefficients when power is None.

Source code in src/korvax/transforms/fourier.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def spectrogram(
    x: Float[ArrayLike, "*channels n_samples"],
    /,
    n_fft: int = 2048,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: _WindowSpec = "hann",
    center: bool = True,
    power: float | int | None = 2.0,
    pad_kwargs: dict[str, Any] = dict(),
) -> Inexact[Array, "*channels {n_fft}//2+1 n_frames"]:
    """Compute the magnitude spectrogram of a time-domain signal.

    The signal is first transformed with [stft][korvax.stft]. The result is then
    turned into `|STFT| ** power`.

    Args:
        x: Input signal.
        n_fft: FFT size (number of samples per frame).
        hop_length: Hop (step) length between adjacent frames. If None, defaults to
            `win_length // 4`.
        win_length: Length of the analysis window. If None, defaults to `n_fft`.
            Ignored if `window` is an array.
        window: Either a 1d array containing the window to apply to each frame,
            or a window specification (see [get_window][korvax.util.get_window]).
        center: If True, pad the input so that frames are centered on their timestamps.
        power: Exponent applied to the STFT magnitude. 2.0 gives a power
            spectrogram and 1.0 a magnitude spectrogram. None skips this step and
            returns the complex STFT coefficients unchanged.
        pad_kwargs: Additional keyword arguments forwarded to [pad_center][korvax.util.pad_center].

    Returns:
        Real-valued `|STFT| ** power`, or complex STFT coefficients if `power` is None.
    """
    coeffs = stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_kwargs=pad_kwargs,
    )

    if power is None:
        return coeffs

    # |z|**2 computed as z * conj(z); the product is real up to a zero imaginary part.
    squared_mag = (coeffs * jnp.conj(coeffs)).real
    if power == 2:
        return squared_mag
    return squared_mag ** (power / 2)

korvax.mel_spectrogram

mel_spectrogram(
    x,
    /,
    sr,
    n_fft,
    n_mels=128,
    fmin=0.0,
    fmax=None,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    power=2.0,
    pad_kwargs=dict(),
)

Compute a mel-scaled spectrogram from a time-domain signal.

Parameters:

Name Type Description Default
x Float[ArrayLike, '*channels n_samples']

Input signal.

required
sr float

Sample rate of the audio signal.

required
n_fft int

FFT size (number of samples per frame).

required
n_mels int

Number of mel bands to generate.

128
fmin float

Minimum frequency (Hz).

0.0
fmax float | None

Maximum frequency (Hz). If None, defaults to sr / 2.

None
hop_length int | None

Hop (step) length between adjacent frames. If None, defaults to win_length // 4.

None
win_length int | None

Length of the analysis window. If None, defaults to n_fft. Ignored if window is an array.

None
window _WindowSpec

Either a 1d array containing the window to apply to each frame, or a window specification (see get_window).

'hann'
center bool

If True, pad the input so that frames are centered on their timestamps.

True
power float | int

Exponent for the magnitude spectrogram. If 2.0, returns power spectrogram.

2.0
pad_kwargs dict[str, Any]

Additional keyword arguments forwarded to pad_center.

dict()

Returns:

Type Description
Float[Array, '*channels {n_mels} n_frames']

Mel-scale spectrogram.

Source code in src/korvax/transforms/mel.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def mel_spectrogram(
    x: Float[ArrayLike, "*channels n_samples"],
    /,
    sr: float,
    n_fft: int,
    n_mels: int = 128,
    fmin: float = 0.0,
    fmax: float | None = None,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: _WindowSpec = "hann",
    center: bool = True,
    power: float | int = 2.0,
    pad_kwargs: dict[str, Any] = dict(),
) -> Float[Array, "*channels {n_mels} n_frames"]:
    """Compute a mel-scaled spectrogram from a time-domain signal.

    This chains [spectrogram][korvax.spectrogram] (linear-frequency
    `|STFT| ** power`) with [to_mel_scale][korvax.to_mel_scale].

    Args:
        x: Input signal.
        sr: Sample rate of the audio signal.
        n_fft: FFT size (number of samples per frame).
        n_mels: Number of mel bands to generate.
        fmin: Minimum frequency (Hz).
        fmax: Maximum frequency (Hz). If None, defaults to `sr / 2`.
        hop_length: Hop (step) length between adjacent frames. If None, defaults to
            `win_length // 4`.
        win_length: Length of the analysis window. If None, defaults to `n_fft`.
            Ignored if `window` is an array.
        window: Either a 1d array containing the window to apply to each frame,
            or a window specification (see [get_window][korvax.util.get_window]).
        center: If True, pad the input so that frames are centered on their timestamps.
        power: Exponent for the magnitude spectrogram. 2.0 gives a power spectrogram.
        pad_kwargs: Additional keyword arguments forwarded to [pad_center][korvax.util.pad_center].

    Returns:
        Mel-scale spectrogram.
    """
    linear_spec = spectrogram(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        power=power,
        pad_kwargs=pad_kwargs,
    )

    mel_spec = to_mel_scale(
        linear_spec,
        sr=sr,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
    )
    return mel_spec

korvax.mfcc

mfcc(
    x,
    /,
    sr,
    n_fft,
    n_mfcc=20,
    norm="ortho",
    mag_scale="db",
    lifter=0.0,
    n_mels=128,
    fmin=0.0,
    fmax=None,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    power=2.0,
    pad_kwargs=dict(),
)

Compute mel-frequency cepstral coefficients (MFCCs) from a time-domain signal.

Parameters:

Name Type Description Default
x Float[ArrayLike, '*channels n_samples']

Input signal.

required
sr float

Sample rate of the audio signal.

required
n_fft int

FFT size (number of samples per frame).

required
n_mfcc int

Number of MFCCs to return.

20
norm Literal['backward', 'ortho'] | None

Normalization mode for DCT.

'ortho'
mag_scale Literal['linear', 'log', 'db']

Magnitude scaling to apply before DCT. Options are "linear" (no scaling), "log" (natural logarithm), or "db" (decibels).

'db'
lifter float

If greater than 0, apply liftering (cepstral filtering) with the specified coefficient.

0.0
n_mels int

Number of mel bands to generate.

128
fmin float

Minimum frequency (Hz).

0.0
fmax float | None

Maximum frequency (Hz). If None, defaults to sr / 2.

None
hop_length int | None

Hop (step) length between adjacent frames. If None, defaults to win_length // 4.

None
win_length int | None

Length of the analysis window. If None, defaults to n_fft. Ignored if window is an array.

None
window _WindowSpec

Either a 1d array containing the window to apply to each frame, or a window specification (see get_window).

'hann'
center bool

If True, pad the input so that frames are centered on their timestamps.

True
power float | int

Exponent for the magnitude spectrogram. If 2.0, returns power spectrogram.

2.0
pad_kwargs dict[str, Any]

Additional keyword arguments forwarded to pad_center.

dict()

Returns:

Type Description
Float[Array, '*channels {n_mfcc} n_frames']

Mel-frequency cepstral coefficients.

Source code in src/korvax/transforms/mel.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def mfcc(
    x: Float[ArrayLike, "*channels n_samples"],
    /,
    sr: float,
    n_fft: int,
    n_mfcc: int = 20,
    norm: Literal["backward", "ortho"] | None = "ortho",
    mag_scale: Literal["linear", "log", "db"] = "db",
    lifter: float = 0.0,
    n_mels: int = 128,
    fmin: float = 0.0,
    fmax: float | None = None,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: _WindowSpec = "hann",
    center: bool = True,
    power: float | int = 2.0,
    pad_kwargs: dict[str, Any] = dict(),
) -> Float[Array, "*channels {n_mfcc} n_frames"]:
    """Compute mel-frequency cepstral coefficients (MFCCs) from a time-domain signal.

    The pipeline has two steps:

    1. [mel_spectrogram][korvax.mel_spectrogram] builds the mel-band representation.
    2. [cepstral_coefficients][korvax.cepstral_coefficients] rescales it with
       `mag_scale`, applies a DCT along the band axis, keeps the first `n_mfcc`
       coefficients, and optionally lifters them.

    Args:
        x: Input signal.
        sr: Sample rate of the audio signal.
        n_fft: FFT size (number of samples per frame).
        n_mfcc: Number of MFCCs to return.
        norm: Normalization mode for DCT.
        mag_scale: Magnitude scaling to apply before DCT. Options are "linear" (no scaling),
            "log" (natural logarithm), or "db" (decibels).
        lifter: If greater than 0, apply liftering (cepstral filtering) with the specified
            coefficient.
        n_mels: Number of mel bands to generate.
        fmin: Minimum frequency (Hz).
        fmax: Maximum frequency (Hz). If None, defaults to `sr / 2`.
        hop_length: Hop (step) length between adjacent frames. If None, defaults to
            `win_length // 4`.
        win_length: Length of the analysis window. If None, defaults to `n_fft`.
            Ignored if `window` is an array.
        window: Either a 1d array containing the window to apply to each frame,
            or a window specification (see [get_window][korvax.util.get_window]).
        center: If True, pad the input so that frames are centered on their timestamps.
        power: Exponent for the magnitude spectrogram. 2.0 gives a power spectrogram.
        pad_kwargs: Additional keyword arguments forwarded to [pad_center][korvax.util.pad_center].

    Returns:
        Mel-frequency cepstral coefficients.
    """
    mel_spec = mel_spectrogram(
        x,
        sr=sr,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        power=power,
        pad_kwargs=pad_kwargs,
    )

    coefficients = cepstral_coefficients(
        mel_spec,
        n_cc=n_mfcc,
        norm=norm,
        mag_scale=mag_scale,
        lifter=lifter,
    )
    return coefficients

korvax.cqt

cqt(
    x,
    /,
    sr,
    hop_length=512,
    fmin=32.7,
    fmax=None,
    n_bins=84,
    bins_per_octave=12,
    filter_scale=1.0,
    norm_kernels=1,
    power=2.0,
    window="hann",
    center=True,
    normalization_type="librosa",
    pad_kwargs=dict(),
)

Compute the Constant-Q Transform (CQT) of a time-domain signal.

The CQT is a time-frequency representation with logarithmically-spaced frequency bins, making it well-suited for music analysis. This is a convenience wrapper that calls vqt with gamma=0.

Source code in src/korvax/transforms/_cqt.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def cqt(
    x: Float[Array, " n_samples"],
    /,
    sr: float,
    hop_length=512,
    fmin: float = 32.70,
    fmax: float | None = None,
    n_bins: int = 84,
    bins_per_octave: int = 12,
    filter_scale: float | int = 1.0,
    norm_kernels: float | int = 1,
    power: int | float | None = 2.0,
    window: str | float | tuple = "hann",
    center: bool = True,
    normalization_type: Literal["librosa", "convolutional", "wrap"] = "librosa",
    pad_kwargs=dict(),
) -> Inexact[Array, " n_bins n_frames"]:
    """Compute the Constant-Q Transform (CQT) of a time-domain signal.

    The CQT is a time-frequency representation with logarithmically-spaced frequency bins,
    making it well-suited for music analysis. This is a convenience wrapper around
    [vqt][korvax.transforms.vqt] with the bandwidth offset `gamma` fixed to 0.

    Args:
        x: Input signal.
        sr: Sample rate of the input signal.
        hop_length: Hop (step) length between adjacent frames.
        fmin: Minimum frequency (Hz).
        fmax: Maximum frequency (Hz). If None, determined by `n_bins`.
        n_bins: Number of frequency bins. Ignored if `fmax` is provided.
        bins_per_octave: Number of bins per octave.
        filter_scale: Scale factor for filter bandwidths.
        norm_kernels: Normalization mode for the filter kernels (p-norm to use).
        power: Exponent for the magnitude spectrogram. If None, complex CQT
            coefficients are returned.
        window: Window specification (see [get_window][korvax.util.get_window]).
        center: If True, pad the input so that frames are centered on their timestamps.
        normalization_type: Type of normalization to apply ("librosa", "convolutional", or "wrap").
        pad_kwargs: Additional keyword arguments forwarded to [pad_center][korvax.util.pad_center].

    Returns:
        CQT coefficients.
    """
    # Every argument is forwarded unchanged; only gamma is pinned to produce a CQT.
    return vqt(
        x,
        sr=sr,
        fmin=fmin,
        fmax=fmax,
        n_bins=n_bins,
        bins_per_octave=bins_per_octave,
        gamma=0.0,
        hop_length=hop_length,
        filter_scale=filter_scale,
        norm_kernels=norm_kernels,
        window=window,
        center=center,
        power=power,
        normalization_type=normalization_type,
        pad_kwargs=pad_kwargs,
    )

korvax.vqt

vqt(
    x,
    /,
    sr,
    hop_length=512,
    fmin=32.7,
    fmax=None,
    n_bins=84,
    gamma=0.0,
    bins_per_octave=12,
    filter_scale=1.0,
    norm_kernels=1,
    power=2.0,
    window="hann",
    center=True,
    normalization_type="librosa",
    pad_kwargs=dict(),
)

Compute the Variable-Q Transform (VQT) of a time-domain signal.

The VQT is a generalization of the Constant-Q Transform (CQT) that allows for variable bandwidth via the gamma parameter. When gamma=0, this is equivalent to the CQT.

Parameters:

Name Type Description Default
x Float[Array, ' n_samples']

Input signal.

required
sr float

Sample rate of the input signal.

required
hop_length

Hop (step) length between adjacent frames.

512
fmin float

Minimum frequency (Hz).

32.7
fmax float | None

Maximum frequency (Hz). If None, determined by n_bins.

None
n_bins int

Number of frequency bins. Ignored if fmax is provided.

84
gamma float

Bandwidth offset parameter. When gamma=0, this reduces to CQT.

0.0
bins_per_octave int

Number of bins per octave.

12
filter_scale float | int

Scale factor for filter bandwidths.

1.0
norm_kernels float | int

Normalization mode for the filter kernels (p-norm to use).

1
power int | float | None

Exponent for the magnitude spectrogram. If 2.0, returns power spectrogram. If None, returns complex VQT coefficients.

2.0
window str | float | tuple

Window specification (see get_window).

'hann'
center bool

If True, pad the input so that frames are centered on their timestamps.

True
normalization_type Literal['librosa', 'convolutional', 'wrap']

Type of normalization to apply ("librosa", "convolutional", or "wrap").

'librosa'
pad_kwargs

Additional keyword arguments forwarded to pad_center.

dict()

Returns:

Type Description
Float[Array, ' n_bins n_frames']

VQT coefficients: real-valued magnitudes raised to `power`, or complex coefficients when power is None.

Source code in src/korvax/transforms/_cqt.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def vqt(
    x: Float[Array, " n_samples"],
    /,
    sr: float,
    hop_length: int = 512,
    fmin: float = 32.70,
    fmax: float | None = None,
    n_bins: int = 84,
    gamma: float = 0.0,
    bins_per_octave: int = 12,
    filter_scale: float | int = 1.0,
    norm_kernels: float | int = 1,
    power: int | float | None = 2.0,
    window: str | float | tuple = "hann",
    center: bool = True,
    normalization_type: Literal["librosa", "convolutional", "wrap"] = "librosa",
    pad_kwargs: dict | None = None,
) -> Inexact[Array, " n_bins n_frames"]:
    """Compute the Variable-Q Transform (VQT) of a time-domain signal.

    The VQT is a generalization of the Constant-Q Transform (CQT) that allows for variable
    bandwidth via the gamma parameter. When gamma=0, this is equivalent to the CQT.

    Args:
        x: Input signal.
        sr: Sample rate of the input signal.
        hop_length: Hop (step) length between adjacent frames.
        fmin: Minimum frequency (Hz).
        fmax: Maximum frequency (Hz). If None, determined by `n_bins`.
        n_bins: Number of frequency bins. Ignored if `fmax` is provided.
        gamma: Bandwidth offset parameter. When gamma=0, this reduces to CQT.
        bins_per_octave: Number of bins per octave.
        filter_scale: Scale factor for filter bandwidths.
        norm_kernels: Normalization mode for the filter kernels (p-norm to use).
        power: Exponent for the magnitude spectrogram. If 2.0, returns power spectrogram.
            If None, returns complex VQT coefficients.
        window: Window specification (see [get_window][korvax.util.get_window]).
        center: If True, pad the input so that frames are centered on their timestamps.
        normalization_type: Output scaling to apply:
            - "librosa": scale each bin by the square root of its filter length.
            - "convolutional": no scaling.
            - "wrap": scale by 2.
        pad_kwargs: Dictionary of additional keyword arguments forwarded to
            [pad_center][korvax.util.pad_center]. None (the default) forwards none.

    Returns:
        VQT coefficients: real-valued when `power` is set, complex when it is None.

    Raises:
        ValueError: If `normalization_type` is not one of the supported options.
    """
    # Fail fast: an unknown option previously left `norm_factor` unbound.
    if normalization_type not in ("librosa", "convolutional", "wrap"):
        raise ValueError(
            "normalization_type must be one of 'librosa', 'convolutional' or 'wrap', "
            f"got {normalization_type!r}"
        )

    # Use a fresh dict per call rather than a shared mutable default argument.
    if pad_kwargs is None:
        pad_kwargs = {}

    with jax.ensure_compile_time_eval():
        Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)
        vqt_kernels, lengths, _ = create_vqt_kernels(
            Q=Q,
            sr=sr,
            fmin=fmin,
            n_bins=n_bins,
            bins_per_octave=bins_per_octave,
            norm=norm_kernels,
            window=window,
            fmax=fmax,
            gamma=gamma,
            dtype=x.dtype,
        )

        # Take the bin count from the kernels, since n_bins is documented as
        # ignored when fmax is given.
        n_bins, fft_len = vqt_kernels.shape
        # Real and imaginary kernel parts become 2 * n_bins real output channels with
        # one input channel: shape (2 * n_bins, 1, fft_len).
        vqt_kernels = jnp.concat(
            [vqt_kernels.real[:, None, :], vqt_kernels.imag[:, None, :]], axis=0
        )

        if normalization_type == "librosa":
            # The same per-bin factor applies to both the real and the imaginary rows.
            norm_factor = jnp.tile(jnp.sqrt(lengths)[:, None], (2, 1))
        elif normalization_type == "convolutional":
            norm_factor = 1
        else:  # "wrap"
            norm_factor = 2

    if center:
        x = util.pad_center(x, len(x) + fft_len, **pad_kwargs)

    out = lax.conv_general_dilated(
        lhs=x[None, None, :],
        rhs=vqt_kernels,
        window_strides=(hop_length,),
        padding="VALID",
    ).squeeze(axis=0)

    out = out * norm_factor

    if power is None:
        # Recombine the two real channel halves into complex coefficients.
        return out[:n_bins, :] - 1j * out[n_bins:, :]

    elif power == 2:
        return out[:n_bins, :] ** 2 + out[n_bins:, :] ** 2

    return (out[:n_bins, :] ** 2 + out[n_bins:, :] ** 2) ** (power / 2)

Inverse Transforms

Functions that take in time-frequency representations and output time-domain signals.

korvax.istft

istft(
    x,
    /,
    n_fft=None,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    length=None,
)

Compute the inverse short-time Fourier transform (ISTFT).

Parameters:

Name Type Description Default
x Complex[ArrayLike, '*channels n_freqs n_frames']

STFT coefficients.

required
n_fft int | None

FFT size (number of samples per frame).

None
hop_length int | None

Hop (step) length between adjacent frames. If None, defaults to win_length // 4.

None
win_length int | None

Length of the analysis window. If None, defaults to n_fft. Ignored if window is an array.

None
window _WindowSpec

Either a 1d array containing the window to apply to each frame, or a window specification (see get_window).

'hann'
center bool

If True, frames are assumed to be centered in time. If False, they are assumed to be left-aligned in time.

True
length int | None

If provided, the output will be trimmed or zero-padded to exactly this length.

None

Returns:

Type Description
Float[Array, '*channels n_samples']

Reconstructed time-domain signal.

Source code in src/korvax/transforms/fourier.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def istft(
    x: Complex[ArrayLike, "*channels n_freqs n_frames"],
    /,
    n_fft: int | None = None,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: _WindowSpec = "hann",
    center: bool = True,
    length: int | None = None,
) -> Float[Array, "*channels n_samples"]:
    """Compute the inverse short-time Fourier transform (ISTFT).

    Each frame is inverted with `irfft`, multiplied by the synthesis window, and
    overlap-added. The result is then divided by the overlap-added squared window,
    except where that sum is numerically zero.

    Args:
        x: STFT coefficients.
        n_fft: FFT size (number of samples per frame). If None, inferred as
            `2 * (n_freqs - 1)`.
        hop_length: Hop (step) length between adjacent frames. If None, defaults to
            `win_length // 4`.
        win_length: Length of the analysis window. If None, defaults to `n_fft`.
            Ignored if `window` is an array.
        window: Either a 1d array containing the window to apply to each frame,
            or a window specification (see [get_window][korvax.util.get_window]).
        center: If `True`, frames are assumed to be centered in time. If `False`, they
            are assumed to be left-aligned in time.
        length: If provided, the output will be trimmed or zero-padded to exactly this
            length.

    Returns:
        Reconstructed time-domain signal.
    """
    x = jnp.asarray(x)

    if n_fft is None:
        n_fft = (x.shape[-2] - 1) * 2

    if win_length is None:
        win_length = n_fft

    if hop_length is None:
        hop_length = win_length // 4

    # With a target length, only frames that can reach it are needed.
    if length:
        if center:
            padded_length = length + 2 * (n_fft // 2)
        else:
            padded_length = length
        n_frames = min(x.shape[-1], int(math.ceil(padded_length / hop_length)))
    else:
        n_frames = x.shape[-1]

    x = x[..., :n_frames]
    x = jnp.fft.irfft(x, n=n_fft, axis=-2)

    expected_length = n_fft + hop_length * (n_frames - 1)
    if length:
        expected_length = length
    elif center:
        expected_length -= n_fft

    with jax.ensure_compile_time_eval():
        ifft_window = util.get_window(
            window,
            win_length,
            fftbins=True,
            dtype=x.dtype,
        )

        ifft_window = util.pad_center(ifft_window, n_fft)

        # Shape the window to broadcast along the per-frame sample axis (-2).
        win_dims = [1] * x.ndim
        win_dims[-2] = len(ifft_window)
        ifft_window = ifft_window.reshape(*win_dims)

        # Build the envelope from the peak-normalised window so the near-zero
        # threshold below does not depend on the window's overall scale.
        win_peak = jnp.max(jnp.abs(ifft_window))
        win_sumsq = (ifft_window / win_peak) ** 2
        win_sumsq = jnp.broadcast_to(win_sumsq, win_dims[:-1] + [x.shape[-1]])
        win_sumsq = util.overlap_and_add(win_sumsq, hop_length=hop_length)
        if center:
            win_sumsq = win_sumsq[..., n_fft // 2 :]
        win_sumsq = util.fix_length(win_sumsq, size=expected_length)
        # The overlap-added signal carries the un-normalised analysis and synthesis
        # windows. Restore the peak**2 factor here; without it the output would be
        # scaled by peak**2 whenever the window peak is not exactly 1 (for example
        # odd-length periodic windows or user-supplied arrays). Near-zero samples
        # are left undivided.
        win_sumsq = jnp.where(
            win_sumsq < jnp.finfo(win_sumsq.dtype).eps,
            1.0,
            win_sumsq * win_peak**2,
        )

    x *= ifft_window

    x = util.overlap_and_add(x, hop_length=hop_length)
    if center:
        x = x[..., n_fft // 2 :]

    x = util.fix_length(x, size=expected_length)

    return x / win_sumsq

korvax.griffin_lim

griffin_lim(
    S,
    /,
    key=None,
    n_iter=32,
    n_fft=None,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    length=None,
    momentum=0.99,
    pad_kwargs=dict(),
)

Reconstruct a time-domain signal from a magnitude spectrogram using the Griffin-Lim algorithm.

Parameters:

Name Type Description Default
S Float[ArrayLike, '*channels n_freqs n_frames']

Magnitude spectrogram.

required
key PRNGKeyArray | None

JAX PRNG key for random phase initialization. If None, uses zero phase initialization.

None
n_iter int

Number of Griffin-Lim iterations to perform.

32
n_fft int | None

FFT size (number of samples per frame). If None, inferred from spectrogram shape.

None
hop_length int | None

Hop (step) length between adjacent frames. If None, defaults to win_length // 4.

None
win_length int | None

Length of the analysis window. If None, defaults to n_fft. Ignored if window is an array.

None
window _WindowSpec

Either a 1d array containing the window to apply to each frame, or a window specification (see get_window).

'hann'
center bool

If True, frames are assumed to be centered in time. If False, they are assumed to be left-aligned in time.

True
length int | None

If provided, the output will be trimmed or zero-padded to exactly this length.

None
momentum float

Momentum parameter for fast Griffin-Lim (typically between 0 and 1).

0.99
pad_kwargs dict[str, Any]

Additional keyword arguments forwarded to pad_center.

dict()

Returns:

Type Description
Float[Array, '*channels n_samples']

Reconstructed time-domain signal.

Source code in src/korvax/transforms/fourier.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
def griffin_lim(
    S: Float[ArrayLike, "*channels n_freqs n_frames"],
    /,
    key: PRNGKeyArray | None = None,
    n_iter: int = 32,
    n_fft: int | None = None,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: _WindowSpec = "hann",
    center: bool = True,
    length: int | None = None,
    momentum: float = 0.99,
    pad_kwargs: dict[str, Any] = dict(),
) -> Float[Array, "*channels n_samples"]:
    """Reconstruct a time-domain signal from a magnitude spectrogram using the Griffin-Lim algorithm.

    Each iteration inverts the current complex estimate with [istft][korvax.istft]
    and re-analyses it with [stft][korvax.stft]. It then applies a momentum
    correction against the previous re-analysis, keeps only the phase, and
    re-imposes the target magnitude `S`. The loop runs under `jax.lax.scan`.

    Args:
        S: Magnitude spectrogram.
        key: JAX PRNG key for random phase initialization. If None, the initial
            estimate is `S` itself (zero phase).
        n_iter: Number of Griffin-Lim iterations to perform.
        n_fft: FFT size (number of samples per frame). If None, inferred from spectrogram
            shape.
        hop_length: Hop (step) length between adjacent frames. If None, defaults to
            `win_length // 4`.
        win_length: Length of the analysis window. If None, defaults to `n_fft`.
            Ignored if `window` is an array.
        window: Either a 1d array containing the window to apply to each frame,
            or a window specification (see [get_window][korvax.util.get_window]).
        center: If True, frames are assumed to be centered in time. If False, they
            are assumed to be left-aligned in time.
        length: If provided, the final output will be trimmed or zero-padded to exactly
            this length.
        momentum: Momentum parameter for fast Griffin-Lim (typically between 0 and 1).
        pad_kwargs: Additional keyword arguments forwarded to [pad_center][korvax.util.pad_center].

    Returns:
        Reconstructed time-domain signal.
    """
    mag = jnp.asarray(S)

    if n_fft is None:
        n_fft = (mag.shape[-2] - 1) * 2

    if key is not None:
        phase = jax.random.uniform(
            key, mag.shape, minval=0.0, maxval=2 * jnp.pi, dtype=mag.dtype
        )
        estimate = (jnp.cos(phase) + 1j * jnp.sin(phase)) * mag
    else:
        estimate = mag.astype(jnp.result_type(mag.dtype, 1j))

    # Framing parameters shared by every istft/stft call below.
    framing = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
    )
    decay = momentum / (1 + momentum)

    def _iterate(state, _):
        previous, current = state
        signal = istft(current, **framing)
        projected = stft(signal, pad_kwargs=pad_kwargs, **framing)
        accelerated = projected - decay * previous
        unit_phase = accelerated / (jnp.abs(accelerated) + util.feps(accelerated))
        return (projected, unit_phase * mag), None

    (_, final_estimate), _ = jax.lax.scan(
        _iterate, init=(jnp.zeros_like(estimate), estimate), length=n_iter
    )

    return istft(final_estimate, length=length, **framing)

Frequency Transforms

These functions take in frequency-domain representations and output modified frequency-domain representations. They are used in the above time-to-frequency transforms, but can also be used standalone.

korvax.cepstral_coefficients

cepstral_coefficients(
    S, /, n_cc=20, norm="ortho", mag_scale="db", lifter=0.0
)

Compute cepstral coefficients from a spectrogram via discrete cosine transform.

Parameters:

Name Type Description Default
S Float[Array, '*channels n_freqs n_frames']

Input spectrogram.

required
n_cc int

Number of cepstral coefficients to return.

20
norm Literal['backward', 'ortho'] | None

Normalization mode for DCT.

'ortho'
mag_scale Literal['linear', 'log', 'db']

Magnitude scaling to apply before DCT. Options are "linear" (no scaling), "log" (natural logarithm), or "db" (decibels).

'db'
lifter float

If greater than 0, apply liftering (cepstral filtering) with the specified coefficient.

0.0

Returns:

Type Description
Float[Array, '*channels {n_cc} n_frames']

Cepstral coefficients.

Source code in src/korvax/transforms/mel.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def cepstral_coefficients(
    S: Float[Array, "*channels n_freqs n_frames"],
    /,
    n_cc: int = 20,
    norm: Literal["backward", "ortho"] | None = "ortho",
    mag_scale: Literal["linear", "log", "db"] = "db",
    lifter: float = 0.0,
) -> Float[Array, "*channels {n_cc} n_frames"]:
    """Compute cepstral coefficients from a spectrogram via discrete cosine transform.

    The (optionally log/dB-scaled) spectrogram is transformed with
    ``jax.scipy.fft.dct`` along the frequency axis (``axis=-2``) and truncated to
    the first ``n_cc`` coefficients.

    Args:
        S: Input spectrogram.
        n_cc: Number of cepstral coefficients to return. If ``n_cc`` exceeds the
            number of frequency bins, all available coefficients are returned.
        norm: Normalization mode for DCT.
        mag_scale: Magnitude scaling to apply before DCT. Options are "linear" (no scaling),
            "log" (natural logarithm of ``S + 1e-6``), or "db" (decibels, computed with
            ``power_to_db(S, amin=1e-6)`` using its default ``top_db`` clipping).
        lifter: If greater than 0, apply liftering (cepstral filtering) with the specified
            coefficient.

    Returns:
        Cepstral coefficients.
    """
    if mag_scale == "log":
        S = jnp.log(S + 1e-6)
    elif mag_scale == "db":
        S = power_to_db(S, amin=1e-6)

    M = jax.scipy.fft.dct(S, axis=-2, norm=norm)[..., :n_cc, :]

    if lifter > 0.0:
        # Size the lifter from the coefficients actually produced. The slice
        # above yields fewer than n_cc rows when n_cc > n_freqs, and a length-n_cc
        # lifter would then fail to broadcast against M.
        n_coeffs = M.shape[-2]
        li = jnp.sin(jnp.pi * jnp.arange(1, 1 + n_coeffs, dtype=S.dtype) / lifter)
        # Sinusoidal lifter weights 1 + (L / 2) * sin(pi * k / L), k = 1..n_coeffs,
        # applied along the coefficient axis (-2) and broadcast over frames.
        M = M * (1 + (lifter / 2) * li)[:, None]
    return M

korvax.to_mel_scale

to_mel_scale(
    S, /, sr, n_fft, n_mels=128, fmin=0.0, fmax=None
)

Convert a linear-frequency spectrogram to mel scale.

Parameters:

Name Type Description Default
S Float[Array, '*channels n_freqs n_frames']

Input spectrogram.

required
sr float

Sample rate of the audio signal.

required
n_fft int

FFT size (number of samples per frame).

required
n_mels int

Number of mel bands to generate.

128
fmin float

Minimum frequency (Hz).

0.0
fmax float | None

Maximum frequency (Hz). If None, defaults to sr / 2.

None

Returns:

Type Description
Float[Array, '*channels {n_mels} n_frames']

Mel-scale spectrogram.

Source code in src/korvax/transforms/mel.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def to_mel_scale(
    S: Float[Array, "*channels n_freqs n_frames"],
    /,
    sr: float,
    n_fft: int,
    n_mels: int = 128,
    fmin: float = 0.0,
    fmax: float | None = None,
) -> Float[Array, "*channels {n_mels} n_frames"]:
    """Project a linear-frequency spectrogram onto the mel scale.

    A mel filterbank is built with ``mel_filterbank`` (its default ``htk`` and
    ``norm`` settings) and contracted against the frequency axis (``-2``) of
    ``S``; leading batch dimensions and the frame axis are preserved.

    Args:
        S: Input spectrogram.
        sr: Sample rate of the audio signal.
        n_fft: FFT size (number of samples per frame).
        n_mels: Number of mel bands to generate.
        fmin: Minimum frequency (Hz).
        fmax: Maximum frequency (Hz). If None, defaults to `sr / 2`.

    Returns:
        Mel-scale spectrogram.
    """
    filterbank_kwargs = dict(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

    # The filterbank is built only from the scalar arguments above, so it is
    # computed eagerly rather than being traced into any surrounding jit.
    with jax.ensure_compile_time_eval():
        fbank = mel_filterbank(**filterbank_kwargs)

    # Contract over the frequency axis f: (..., f, n) x (m, f) -> (..., m, n).
    return jnp.einsum("...fn,mf->...mn", S, fbank)

Perceptual Loudness Weighting

korvax.A_weighting

A_weighting(frequencies, /, min_db=-80.0)

Compute A-weighting curve for given frequencies.

Parameters:

Name Type Description Default
frequencies Float[ArrayLike, ' n_freqs']

Frequencies in Hz.

required
min_db float | None

Minimum decibel value for clipping. If None, no clipping applied.

-80.0

Returns:

Type Description
Float[Array, ' n_freqs']

A-weighting values in dB.

Source code in src/korvax/convert.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def A_weighting(
    frequencies: Float[ArrayLike, " n_freqs"], /, min_db: float | None = -80.0
) -> Float[Array, " n_freqs"]:
    """Compute A-weighting curve for given frequencies.

    Args:
        frequencies: Frequencies in Hz. Integer inputs are promoted to floating
            point before squaring.
        min_db: Minimum decibel value for clipping. If None, no clipping applied.

    Returns:
        A-weighting values in dB. A frequency of 0 Hz gives ``-inf`` before
        clipping (``min_db`` when clipping is enabled).
    """
    # Square with a float exponent so integer inputs are promoted first. With
    # JAX's default int32, squaring frequencies above ~46.3 kHz overflows and
    # the subsequent log10 of a negative value yields NaN.
    f = jnp.asarray(frequencies) ** 2.0
    const = jnp.array([12194.217, 20.598997, 107.65265, 737.86223]) ** 2.0
    # The 2.0 dB offset normalizes the curve to ~0 dB at 1 kHz.
    weights: jnp.ndarray = 2.0 + 20.0 * (
        jnp.log10(const[0])
        + 2 * jnp.log10(f)
        - jnp.log10(f + const[0])
        - jnp.log10(f + const[1])
        - 0.5 * jnp.log10(f + const[2])
        - 0.5 * jnp.log10(f + const[3])
    )

    if min_db is None:
        return weights
    else:
        return jnp.maximum(min_db, weights)

korvax.B_weighting

B_weighting(frequencies, /, min_db=-80.0)

Compute B-weighting curve for given frequencies.

Parameters:

Name Type Description Default
frequencies Float[ArrayLike, ' n_freqs']

Frequencies in Hz.

required
min_db float | None

Minimum decibel value for clipping. If None, no clipping applied.

-80.0

Returns:

Type Description
Float[Array, ' n_freqs']

B-weighting values in dB.

Source code in src/korvax/convert.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def B_weighting(
    frequencies: Float[ArrayLike, " n_freqs"], /, min_db: float | None = -80.0
) -> Float[Array, " n_freqs"]:
    """Compute B-weighting curve for given frequencies.

    Args:
        frequencies: Frequencies in Hz. Integer inputs are promoted to floating
            point before squaring.
        min_db: Minimum decibel value for clipping. If None, no clipping applied.

    Returns:
        B-weighting values in dB. A frequency of 0 Hz gives ``-inf`` before
        clipping (``min_db`` when clipping is enabled).
    """
    # Square with a float exponent so integer inputs are promoted first. With
    # JAX's default int32, squaring frequencies above ~46.3 kHz overflows and
    # the subsequent log10 of a negative value yields NaN.
    f = jnp.asarray(frequencies) ** 2.0
    const = jnp.array([12194.217, 20.598997, 158.48932]) ** 2.0
    # The 0.17 dB offset normalizes the curve to ~0 dB at 1 kHz.
    weights: jnp.ndarray = 0.17 + 20.0 * (
        jnp.log10(const[0])
        + 1.5 * jnp.log10(f)
        - jnp.log10(f + const[0])
        - jnp.log10(f + const[1])
        - 0.5 * jnp.log10(f + const[2])
    )

    if min_db is None:
        return weights
    else:
        return jnp.maximum(min_db, weights)

korvax.C_weighting

C_weighting(frequencies, /, min_db=-80.0)

Compute C-weighting curve for given frequencies.

Parameters:

Name Type Description Default
frequencies Float[ArrayLike, ' n_freqs']

Frequencies in Hz.

required
min_db float | None

Minimum decibel value for clipping. If None, no clipping applied.

-80.0

Returns:

Type Description
Float[Array, ' n_freqs']

C-weighting values in dB.

Source code in src/korvax/convert.py
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def C_weighting(
    frequencies: Float[ArrayLike, " n_freqs"], /, min_db: float | None = -80.0
) -> Float[Array, " n_freqs"]:
    """Evaluate the C-weighting curve at the given frequencies.

    Args:
        frequencies: Frequencies in Hz.
        min_db: Floor (in dB) applied to the curve. Pass None to disable the floor.

    Returns:
        C-weighting values in dB (``-inf`` at 0 Hz when ``min_db`` is None).
    """
    f_sq = jnp.asarray(frequencies) ** 2.0
    # Squared pole frequencies (Hz^2) of the C-weighting transfer function.
    pole_hi, pole_lo = jnp.array([12194.217, 20.598997]) ** 2.0

    log_gain = (
        jnp.log10(pole_hi)
        + jnp.log10(f_sq)
        - jnp.log10(f_sq + pole_hi)
        - jnp.log10(f_sq + pole_lo)
    )
    # 0.062 dB offset keeps the response near 0 dB at 1 kHz.
    weights = 0.062 + 20.0 * log_gain

    return weights if min_db is None else jnp.maximum(min_db, weights)

korvax.D_weighting

D_weighting(frequencies, /, min_db=-80.0)

Compute D-weighting curve for given frequencies.

Parameters:

Name Type Description Default
frequencies Float[ArrayLike, ' n_freqs']

Frequencies in Hz.

required
min_db float | None

Minimum decibel value for clipping. If None, no clipping applied.

-80.0

Returns:

Type Description
Float[Array, ' n_freqs']

D-weighting values in dB.

Source code in src/korvax/convert.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
def D_weighting(
    frequencies: Float[ArrayLike, " n_freqs"], /, min_db: float | None = -80.0
) -> Float[Array, " n_freqs"]:
    """Compute D-weighting curve for given frequencies.

    Args:
        frequencies: Frequencies in Hz. Integer inputs are promoted to floating
            point before squaring.
        min_db: Minimum decibel value for clipping. If None, no clipping applied.

    Returns:
        D-weighting values in dB. A frequency of 0 Hz gives ``-inf`` before
        clipping (``min_db`` when clipping is enabled).
    """
    # Square with a float exponent so integer inputs are promoted first. With
    # JAX's default int32, squaring frequencies above ~46.3 kHz overflows and
    # the subsequent log10 of a negative value yields NaN.
    f = jnp.asarray(frequencies) ** 2.0
    const = jnp.array([8.3046305e-3, 1018.7, 1039.6, 3136.5, 3424, 282.7, 1160]) ** 2.0
    weights = 20.0 * (
        0.5 * jnp.log10(f)
        - jnp.log10(const[0])
        + 0.5
        * (
            +jnp.log10((const[1] - f) ** 2 + const[2] * f)
            - jnp.log10((const[3] - f) ** 2 + const[4] * f)
            - jnp.log10(const[5] + f)
            - jnp.log10(const[6] + f)
        )
    )

    if min_db is None:
        return weights
    else:
        return jnp.maximum(min_db, weights)

Utilities

korvax.mel_filterbank

mel_filterbank(
    *,
    sr,
    n_fft,
    n_mels=128,
    fmin=0.0,
    fmax=None,
    htk=False,
    norm="slaney",
    dtype=None,
)

Create a mel-scale filterbank.

Parameters:

Name Type Description Default
sr float

Sample rate of the audio signal.

required
n_fft int

FFT size (number of samples per frame).

required
n_mels int

Number of mel bands to generate.

128
fmin float

Minimum frequency (Hz).

0.0
fmax float | None

Maximum frequency (Hz). If None, defaults to sr / 2.

None
htk bool

If True, use HTK formula for mel scale. Otherwise, use Slaney formula.

False
norm Literal['slaney'] | float | None

Normalization mode. If "slaney", use Slaney-style normalization (each filter is scaled by 2 divided by its bandwidth in Hz). If a float p, normalize each filter to unit L^p norm. If None, no normalization.

'slaney'
dtype DTypeLike | None

Data type for the filterbank. If None, defaults to default float type.

None

Returns:

Type Description
Float[Array, ' {n_mels} {n_fft}//2+1']

Mel filterbank matrix.

Source code in src/korvax/transforms/mel.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def mel_filterbank(
    *,
    sr: float,
    n_fft: int,
    n_mels: int = 128,
    fmin: float = 0.0,
    fmax: float | None = None,
    htk: bool = False,
    norm: Literal["slaney"] | float | None = "slaney",
    dtype: DTypeLike | None = None,
) -> Float[Array, " {n_mels} {n_fft}//2+1"]:
    """Build a matrix of triangular mel filters over the rFFT bins.

    Row ``i`` is a triangle that starts at mel edge ``i``, peaks with value 1 at
    edge ``i + 1`` and returns to zero at edge ``i + 2``; edges are
    ``n_mels + 2`` mel-spaced frequencies between ``fmin`` and ``fmax``.

    Args:
        sr: Sample rate of the audio signal.
        n_fft: FFT size (number of samples per frame).
        n_mels: Number of mel bands to generate.
        fmin: Minimum frequency (Hz).
        fmax: Maximum frequency (Hz). If None, defaults to `sr / 2`.
        htk: If True, use HTK formula for mel scale. Otherwise, use Slaney formula.
        norm: Normalization mode. If "slaney", scale each filter by
            ``2 / bandwidth`` (Slaney-style). Otherwise the value is forwarded
            as ``ord`` to ``util.normalize`` along the frequency axis (a float
            selects an Lp norm; None means no normalization).
        dtype: Data type for the filterbank. If None, defaults to default float type.

    Returns:
        Mel filterbank matrix.
    """
    fmax = sr / 2 if fmax is None else fmax

    bin_hz = fft_frequencies(sr=sr, n_fft=n_fft).astype(dtype)
    edges = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk).astype(dtype)
    widths = jnp.diff(edges)

    # Evaluate every (filter, bin) pair at once by broadcasting column vectors
    # of per-filter edges/widths against the row of bin frequencies.
    left_edge = edges[:n_mels, None]
    right_edge = edges[2:, None]
    rising = (-left_edge + bin_hz) / widths[:n_mels, None]
    falling = (right_edge - bin_hz) / widths[1:, None]
    mels = jnp.maximum(0.0, jnp.minimum(rising, falling))

    if norm != "slaney":
        return util.normalize(mels, ord=norm, axis=-1)

    # Slaney normalization: divide each triangle by half its Hz bandwidth.
    mels *= (2.0 / (edges[2 : n_mels + 2] - edges[:n_mels]))[:, None]
    return mels

korvax.mel_to_hz

mel_to_hz(mels, /, htk=False)

Convert mel scale to frequencies in Hz.

Parameters:

Name Type Description Default
mels Float[ArrayLike, '*dims']

Mel-scale values.

required
htk bool

If True, use HTK formula. Otherwise, use Slaney formula.

False

Returns:

Type Description
Float[Array, '*dims']

Frequencies in Hz.

Source code in src/korvax/convert.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def mel_to_hz(
    mels: Float[ArrayLike, "*dims"], /, htk: bool = False
) -> Float[Array, "*dims"]:
    """Map mel-scale values back to frequencies in Hz.

    Args:
        mels: Mel-scale values.
        htk: Use the HTK formula when True; otherwise the Slaney (piecewise
            linear/logarithmic) formula.

    Returns:
        Frequencies in Hz, with the same shape as ``mels``.
    """
    m = jnp.asarray(mels)

    if not htk:
        # Slaney scale: linear below the 1 kHz break point (200/3 Hz per mel),
        # logarithmic above it.
        hz_per_mel = 200.0 / 3
        break_hz = 1000.0
        break_mel = (break_hz - 0.0) / hz_per_mel
        log_step = jnp.log(6.4) / 27.0

        linear_hz = 0.0 + hz_per_mel * m
        log_hz = break_hz * jnp.exp(log_step * (m - break_mel))
        return jnp.where(m >= break_mel, log_hz, linear_hz)

    return 700.0 * (10 ** (m / 2595.0) - 1.0)

korvax.hz_to_mel

hz_to_mel(frequencies, /, htk=False)

Convert frequencies in Hz to mel scale.

Parameters:

Name Type Description Default
frequencies Float[ArrayLike, '*dims']

Frequencies in Hz.

required
htk bool

If True, use HTK formula. Otherwise, use Slaney formula.

False

Returns:

Type Description
Float[Array, '*dims']

Mel-scale values.

Source code in src/korvax/convert.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def hz_to_mel(
    frequencies: Float[ArrayLike, "*dims"], /, htk: bool = False
) -> Float[Array, "*dims"]:
    """Convert frequencies in Hz to mel scale.

    The Slaney branch is safe to differentiate at any frequency, including 0 Hz.

    Args:
        frequencies: Frequencies in Hz.
        htk: If True, use HTK formula. Otherwise, use Slaney formula.

    Returns:
        Mel-scale values.
    """
    frequencies = jnp.asarray(frequencies)

    if htk:
        return 2595.0 * jnp.log10(1.0 + frequencies / 700.0)

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (frequencies - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0  # beginning of log region (Hz)
    min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
    logstep = jnp.log(6.4) / 27.0  # step size for log region

    # jnp.where evaluates and differentiates BOTH branches. Without this guard,
    # log(f / 1000) at f <= 0 is -inf/NaN in the unselected branch, and its
    # gradient (0 * inf) turns into NaN. Substituting a safe value where the log
    # branch is not selected leaves forward values unchanged and gradients
    # finite; a where (not maximum) keeps the full gradient at exactly 1 kHz.
    in_log_region = frequencies >= min_log_hz
    safe_freqs = jnp.where(in_log_region, frequencies, min_log_hz)

    return jnp.where(
        in_log_region,
        min_log_mel + jnp.log(safe_freqs / min_log_hz) / logstep,
        mels,
    )

korvax.db_to_amplitude

db_to_amplitude(S_db, /, ref=1.0)

Convert a decibel-scale spectrogram to amplitude scale.

Parameters:

Name Type Description Default
S_db Float[ArrayLike, '*dims']

Input spectrogram in dB.

required
ref float

Reference value for decibel calculation.

1.0

Returns:

Type Description
Float[Array, '*dims']

Amplitude spectrogram.

Source code in src/korvax/convert.py
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
def db_to_amplitude(
    S_db: Float[ArrayLike, "*dims"],
    /,
    ref: float = 1.0,
) -> Float[Array, "*dims"]:
    """Map a decibel-scale spectrogram back to linear amplitude.

    Computes ``ref * 10 ** (S_db / 20)``, the inverse of a 20*log10 amplitude
    conversion.

    Args:
        S_db: Input spectrogram in dB.
        ref: Amplitude that corresponds to 0 dB.

    Returns:
        Amplitude spectrogram.
    """
    amplitude_ratio = 10.0 ** (jnp.asarray(S_db) / 20.0)
    return ref * amplitude_ratio

korvax.amplitude_to_db

amplitude_to_db(S, /, ref=1.0, amin=1e-08, top_db=80.0)

Convert an amplitude spectrogram to decibel scale.

Parameters:

Name Type Description Default
S Inexact[ArrayLike, '*dims']

Input amplitude spectrogram.

required
ref Float[ArrayLike, ''] | Callable[[Float[ArrayLike, '*']], Float[ArrayLike, '']]

Reference value for decibel calculation. Can be a scalar or callable that computes a reference from the input.

1.0
amin float

Minimum threshold for input values.

1e-08
top_db float | None

Maximum decibel range. Values below max - top_db are clipped. If None, no clipping applied.

80.0

Returns:

Type Description
Float[Array, '*dims']

Amplitude spectrogram in dB.

Source code in src/korvax/convert.py
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
def amplitude_to_db(
    S: Inexact[ArrayLike, "*dims"],
    /,
    ref: Float[ArrayLike, ""]
    | Callable[[Float[ArrayLike, "*"]], Float[ArrayLike, ""]] = 1.0,
    amin: float = 1e-8,
    top_db: float | None = 80.0,
) -> Float[Array, "*dims"]:
    """Convert an amplitude spectrogram to decibel scale.

    The magnitude ``|S|`` is squared and passed to ``power_to_db`` with ``ref``
    and ``amin`` squared as well, giving ``20 * log10(|S| / ref)`` subject to the
    ``amin`` floor and ``top_db`` clipping.

    Args:
        S: Input amplitude spectrogram (real or complex).
        ref: Reference value for decibel calculation. Can be a scalar or callable
            that computes a reference from the input magnitude ``|S|``.
        amin: Minimum threshold for input values.
        top_db: Maximum decibel range. Values below `max - top_db` are clipped.
            If None, no clipping applied.

    Returns:
        Amplitude spectrogram in dB.
    """
    # Always take the magnitude. Complex inputs need it, and for real inputs it
    # ensures a callable ``ref`` (e.g. ``jnp.max``) sees |S| rather than signed
    # values; the squared power itself is unaffected.
    mag = jnp.abs(jnp.asarray(S))
    if callable(ref):
        ref_value = ref(mag)
    else:
        ref_value = jnp.abs(ref)

    power = mag**2

    return power_to_db(power, ref=ref_value**2, amin=amin**2, top_db=top_db)  # pyright: ignore[reportArgumentType]

korvax.power_to_db

power_to_db(S, /, ref=1.0, amin=1e-10, top_db=80.0)

Convert a power spectrogram to decibel scale.

Parameters:

Name Type Description Default
S Inexact[ArrayLike, '*dims']

Input power spectrogram.

required
ref Float[ArrayLike, ''] | Callable[[Float[ArrayLike, '*']], Float[ArrayLike, '']]

Reference value for decibel calculation. Can be a scalar or callable that computes a reference from the input.

1.0
amin float

Minimum threshold for input values.

1e-10
top_db float | None

Maximum decibel range. Values below max - top_db are clipped. If None, no clipping applied.

80.0

Returns:

Type Description
Float[Array, '*dims']

Power spectrogram in dB.

Source code in src/korvax/convert.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def power_to_db(
    S: Inexact[ArrayLike, "*dims"],
    /,
    ref: Float[ArrayLike, ""]
    | Callable[[Float[ArrayLike, "*"]], Float[ArrayLike, ""]] = 1.0,
    amin: float = 1e-10,
    top_db: float | None = 80.0,
) -> Float[Array, "*dims"]:
    """Express a power spectrogram in decibels relative to ``ref``.

    Computes ``10 * log10(max(amin, S)) - 10 * log10(max(amin, ref))``. Complex
    input is reduced to its magnitude ``|S|`` first.

    Args:
        S: Input power spectrogram.
        ref: Reference power, either a scalar (its absolute value is used) or a
            callable applied to the power array to derive one.
        amin: Floor applied to both the power and the reference before the log.
        top_db: Dynamic-range limit. Values more than ``top_db`` below the
            maximum (taken over the whole array, including any leading batch
            dimensions) are raised to ``max - top_db``. None disables this.

    Returns:
        Power spectrogram in dB.
    """
    is_complex = jnp.issubdtype(jnp.result_type(S), jnp.complexfloating)
    power = jnp.abs(S) if is_complex else jnp.asarray(S)
    ref_value = ref(power) if callable(ref) else jnp.abs(ref)

    def _floored_db(x):
        return 10.0 * jnp.log10(jnp.maximum(amin, x))

    log_spec = _floored_db(power) - _floored_db(ref_value)

    if top_db is None:
        return log_spec
    return jnp.maximum(log_spec, log_spec.max() - top_db)

korvax.db_to_power

db_to_power(S_db, /, ref=1.0)

Convert a decibel-scale spectrogram to power scale.

Parameters:

Name Type Description Default
S_db Float[ArrayLike, '*dims']

Input spectrogram in dB.

required
ref Float[ArrayLike, '']

Reference value for decibel calculation.

1.0

Returns:

Type Description
Float[Array, '*dims']

Power spectrogram.

Source code in src/korvax/convert.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
def db_to_power(
    S_db: Float[ArrayLike, "*dims"],
    /,
    ref: Float[ArrayLike, ""] = 1.0,
) -> Float[Array, "*dims"]:
    """Map a decibel-scale spectrogram back to linear power.

    Computes ``ref * 10 ** (S_db / 10)``, the inverse of a 10*log10 power
    conversion.

    Args:
        S_db: Input spectrogram in dB.
        ref: Power that corresponds to 0 dB.

    Returns:
        Power spectrogram.
    """
    reference = jnp.asarray(ref)
    power_ratio = 10.0 ** (jnp.asarray(S_db) / 10.0)
    return reference * power_ratio

korvax.fft_frequencies

fft_frequencies(*, sr=22050, n_fft=2048)

Compute the center frequencies of FFT bins.

Parameters:

Name Type Description Default
sr float

Sample rate of the audio signal.

22050
n_fft int

FFT size (number of samples per frame).

2048

Returns:

Type Description
Float[Array, ' {n_fft}//2+1']

Center frequencies of FFT bins in Hz.

Source code in src/korvax/convert.py
130
131
132
133
134
135
136
137
138
139
140
141
142
def fft_frequencies(
    *, sr: float = 22050, n_fft: int = 2048
) -> Float[Array, " {n_fft}//2+1"]:
    """Return the frequency (in Hz) of each non-negative rFFT bin.

    Bin ``k`` sits at ``k * sr / n_fft`` for ``k = 0 .. n_fft // 2``.

    Args:
        sr: Sample rate of the audio signal.
        n_fft: FFT size (number of samples per frame).

    Returns:
        Center frequencies of FFT bins in Hz.
    """
    sample_spacing = 1.0 / sr
    return jnp.fft.rfftfreq(n_fft, sample_spacing)

korvax.mel_frequencies

mel_frequencies(
    n_mels=128, /, fmin=0.0, fmax=11025.0, htk=False
)

Compute an array of mel-spaced frequencies.

Parameters:

Name Type Description Default
n_mels int

Number of mel bands.

128
fmin float

Minimum frequency (Hz).

0.0
fmax float

Maximum frequency (Hz).

11025.0
htk bool

If True, use HTK formula. Otherwise, use Slaney formula.

False

Returns:

Type Description
Float[Array, ' {n_mels}']

Array of frequencies in Hz.

Source code in src/korvax/convert.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def mel_frequencies(
    n_mels: int = 128, /, fmin: float = 0.0, fmax: float = 11025.0, htk: bool = False
) -> Float[Array, " {n_mels}"]:
    """Return ``n_mels`` frequencies evenly spaced on the mel scale.

    The endpoints ``fmin`` and ``fmax`` are converted to mels, the interval is
    divided uniformly (endpoints included), and the grid is mapped back to Hz.

    Args:
        n_mels: Number of mel bands.
        fmin: Minimum frequency (Hz).
        fmax: Maximum frequency (Hz).
        htk: Use the HTK mel formula when True; otherwise the Slaney formula.

    Returns:
        Array of frequencies in Hz.
    """
    mel_lo = hz_to_mel(fmin, htk=htk)
    mel_hi = hz_to_mel(fmax, htk=htk)
    mel_grid = jnp.linspace(mel_lo, mel_hi, n_mels)
    return mel_to_hz(mel_grid, htk=htk)

korvax.cqt_frequencies

cqt_frequencies(
    n_bins, /, fmin, bins_per_octave=12, tuning=0.0
)

Compute the center frequencies of constant-Q transform bins.

Parameters:

Name Type Description Default
n_bins int

Number of frequency bins.

required
fmin float

Minimum frequency (Hz).

required
bins_per_octave int

Number of bins per octave.

12
tuning float

Tuning offset in fractions of a bin.

0.0

Returns:

Type Description
Float[Array, ' {n_bins}']

Center frequencies of CQT bins in Hz.

Source code in src/korvax/convert.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def cqt_frequencies(
    n_bins: int, /, fmin: float, bins_per_octave: int = 12, tuning: float = 0.0
) -> Float[Array, " {n_bins}"]:
    """Return geometrically spaced constant-Q bin frequencies.

    Bin ``k`` is at ``fmin * 2 ** ((k + tuning) / bins_per_octave)``.

    Args:
        n_bins: Number of frequency bins.
        fmin: Frequency (Hz) of the first bin before the tuning shift.
        bins_per_octave: Number of bins per octave.
        tuning: Tuning offset in fractions of a bin.

    Returns:
        Center frequencies of CQT bins in Hz.
    """
    octaves_above_fmin = jnp.arange(0, n_bins) / bins_per_octave
    tuning_factor = 2.0 ** (tuning / bins_per_octave)
    return tuning_factor * fmin * 2.0 ** octaves_above_fmin