Loss Functions

If you're looking for an out-of-the-box MR-STFT loss, see mrstft_loss.

Korvax provides a general interface for frame-based loss calculation. A loss is defined by three components:

  • a transform function that converts time-domain signals into time-frequency representations (e.g. STFT, VQT...)
  • a loss function that computes a distance metric between two such representations (e.g. L1/L2, Wasserstein, spectral convergence...)
  • an optional scaling function applied to each frame (e.g. Mel, A-weighting...)

This module documents the general interface, implements some common frame-wise loss functions, and provides ready-to-use MR-STFT loss configurations. For transform functions implemented in Korvax, see Transforms.
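As a rough illustration of how the three components fit together, the sketch below builds each one from plain `jax.numpy` instead of Korvax functions. The names `toy_transform`, `log_scale` and `l1_loss` are hypothetical stand-ins, and the framing and window choices are assumptions made only for this example.

```python
import jax.numpy as jnp


def toy_transform(x, win_length=8, hop_length=4):
    """Hypothetical transform: framed FFT magnitudes, shape (..., n_bins, n_frames)."""
    n_frames = 1 + (x.shape[-1] - win_length) // hop_length
    starts = jnp.arange(n_frames) * hop_length
    idx = starts[:, None] + jnp.arange(win_length)[None, :]  # (n_frames, win_length)
    frames = x[..., idx] * jnp.hanning(win_length)           # windowed frames
    spec = jnp.abs(jnp.fft.rfft(frames, axis=-1))            # (..., n_frames, n_bins)
    return jnp.swapaxes(spec, -1, -2)                        # (..., n_bins, n_frames)


def log_scale(S, eps=1e-7):
    """Hypothetical scaling function: log magnitude with a small offset."""
    return jnp.log(S + eps)


def l1_loss(S_x, S_y):
    """Hypothetical loss function: mean absolute error."""
    return jnp.mean(jnp.abs(S_x - S_y))


n = jnp.arange(64)
x = jnp.sin(2 * jnp.pi * 0.10 * n)
y = jnp.sin(2 * jnp.pi * 0.12 * n)

# transform -> scale -> loss
S_x, S_y = toy_transform(x), toy_transform(y)
loss = l1_loss(log_scale(S_x), log_scale(S_y))
```

In Korvax itself, time_frequency_loss takes such functions as transform_fn, scale_fn and loss_fn and performs the composition for you.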

General Interface

korvax.loss.time_frequency_loss

time_frequency_loss(
    x,
    y,
    /,
    transform_fn,
    loss_fn,
    scale_fn=None,
    weights=None,
)

Compute a time-frequency loss between two signals.

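If loss_fn and scale_fn are sequences, they must have the same length: loss_fn[i] is applied to the representations of x and y after scaling with scale_fn[i]. The resulting losses are combined as a weighted sum, using the provided weights or, if weights is None, equal weights of 1 / n_losses (i.e. the mean of the individual losses).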

Parameters:

Name Type Description Default
x Float[Array, '*channels n_samples']

Input signal.

required
y Float[Array, '*channels n_samples']

Target signal.

required
transform_fn TransformFn

Function to compute the time-frequency representation.

required
loss_fn LossFn | Sequence[LossFn]

Loss function(s) to apply in the time-frequency domain.

required
scale_fn ScaleFn | Sequence[ScaleFn] | None

Optional scaling function(s) to apply to the time-frequency representations before computing the loss.

None
weights Sequence[float] | Float[ArrayLike, ' n_losses'] | None

Optional weights for each loss function. If None, equal weighting is used.

None

Returns:

Type Description
Float[Array, '']

The scalar loss value.
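The weighted combination described above can be sketched in a few lines. This is a simplified stand-in that operates directly on already-transformed arrays; `combine_losses`, `l1` and `l2` are hypothetical names used only for illustration, not part of Korvax.

```python
import jax.numpy as jnp


def combine_losses(S_x, S_y, loss_fns, scale_fns, weights=None):
    """Weighted sum of loss_fns[i](scale_fns[i](S_x), scale_fns[i](S_y))."""
    assert len(loss_fns) == len(scale_fns)
    n_losses = len(loss_fns)
    if weights is None:
        # Equal weighting: every term contributes 1 / n_losses.
        weights = jnp.ones(n_losses) / n_losses
    else:
        weights = jnp.asarray(weights)
    assert len(weights) == n_losses
    terms = jnp.stack(
        [loss(scale(S_x), scale(S_y)) for loss, scale in zip(loss_fns, scale_fns)]
    )
    return jnp.sum(weights * terms)


def l1(a, b):
    return jnp.mean(jnp.abs(a - b))


def l2(a, b):
    return jnp.mean((a - b) ** 2)


def identity(S):
    return S


S_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
S_y = jnp.array([[1.0, 0.0], [3.0, 0.0]])

# Individual terms: l1 = 1.5, l2 = 5.0
equal = combine_losses(S_x, S_y, [l1, l2], [identity, identity])
weighted = combine_losses(S_x, S_y, [l1, l2], [identity, identity], weights=[1.0, 0.1])
```

With equal weighting the result is the mean of the two terms (3.25); with weights [1.0, 0.1] it is 1.5 + 0.1 * 5.0 = 2.0.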

Example

Define a linear STFT loss with loudness A-weighting applied as scaling:

import functools

import korvax

a_weighted_stft_loss = functools.partial(
    korvax.loss.time_frequency_loss,
    transform_fn=functools.partial(
        korvax.spectrogram,
        win_length=2048,
        power=1,
    ),
    loss_fn=functools.partial(korvax.loss.elementwise_loss, metric="L1"),
    scale_fn=lambda S: korvax.amplitude_to_db(S) + korvax.A_weighting(
        korvax.fft_frequencies(sr=16000, n_fft=2048)
    )[:, None],
)

Example

Define a combination of spectral optimal transport and L1 log magnitude loss on power spectrograms:

import functools

import korvax

combined_lin_sot_loss = functools.partial(
    korvax.loss.time_frequency_loss,
    transform_fn=functools.partial(
        korvax.spectrogram,
        win_length=2048,
        power=2,
    ),
    loss_fn=[
        korvax.loss.spectral_optimal_transport_loss,
        functools.partial(korvax.loss.elementwise_loss, metric="L1"),
    ],
    scale_fn=[
        lambda S: S,
        korvax.power_to_db
    ],
    weights=[1.0, 0.1],
)

Example

Define a multi-resolution STFT loss (also see mrstft_loss):

import functools

import korvax


def my_mrstft_loss(x, y):
    hops = [128, 256, 512]
    wins = [512, 1024, 2048]

    loss_fn = functools.partial(
        korvax.loss.elementwise_loss,
        metric="L1",
    )

    loss = 0.0
    for hop, win in zip(hops, wins):
        transform = functools.partial(
            korvax.spectrogram,
            win_length=win,
            hop_length=hop,
            power=1,
        )
        loss += korvax.loss.time_frequency_loss(
            x,
            y,
            transform_fn=transform,
            loss_fn=loss_fn,
            scale_fn=korvax.amplitude_to_db,
        )

    return loss / len(hops)

Source code in src/korvax/loss.py
def time_frequency_loss(
    x: Float[Array, "*channels n_samples"],
    y: Float[Array, "*channels n_samples"],
    /,
    transform_fn: TransformFn,
    loss_fn: LossFn | Sequence[LossFn],
    scale_fn: ScaleFn | Sequence[ScaleFn] | None = None,
    weights: Sequence[float] | Float[ArrayLike, " n_losses"] | None = None,
) -> Float[Array, ""]:
    """Compute a time-frequency loss between two signals.

    If loss_fn and scale_fn are sequences, they need to be the same length.
    The resulting losses are combined as a weighted sum, either
    using the provided `weights` or equal weighting if `weights` is `None`.

    Args:
        x: Input signal.
        y: Target signal.
        transform_fn: Function to compute the time-frequency representation.
        loss_fn: Loss function(s) to apply in the time-frequency domain.
        scale_fn: Optional scaling function(s) to apply to the time-frequency representations before computing the loss.
        weights: Optional weights for each loss function. If `None`, equal weighting is used.

    Returns:
        The scalar loss value.

    Example:
        Define a linear STFT loss with loudness A-weighting applied as scaling:

        ```python
        a_weighted_stft_loss = functools.partial(
            korvax.loss.time_frequency_loss,
            transform_fn=functools.partial(
                korvax.spectrogram,
                win_length=2048,
                power=1,
            ),
            loss_fn=functools.partial(korvax.loss.elementwise_loss, metric="L1"),
            scale_fn=lambda S: korvax.amplitude_to_db(S) + korvax.A_weighting(
                korvax.fft_frequencies(sr=16000, n_fft=2048)
            )[:, None],
        )
        ```

    Example:
        Define a combination of spectral optimal transport and L1 log magnitude loss on power spectrograms:

        ```python
        combined_lin_sot_loss = functools.partial(
            korvax.loss.time_frequency_loss,
            transform_fn=functools.partial(
                korvax.spectrogram,
                win_length=2048,
                power=2,
            ),
            loss_fn=[
                korvax.loss.spectral_optimal_transport_loss,
                functools.partial(korvax.loss.elementwise_loss, metric="L1"),
            ],
            scale_fn=[
                lambda S: S,
                korvax.power_to_db
            ],
            weights=[1.0, 0.1],
        )
        ```

    Example:
        Define a multi-resolution STFT loss (also see [`mrstft_loss`][..mrstft_loss]):

        ```python
        def my_mrstft_loss(x, y):
            hops = [128, 256, 512]
            wins = [512, 1024, 2048]

            loss_fn = functools.partial(
                korvax.loss.elementwise_loss,
                metric="L1",
            )

            loss = 0.0
            for hop, win in zip(hops, wins):
                transform = functools.partial(
                    korvax.spectrogram,
                    win_length=win,
                    hop_length=hop,
                    power=1,
                )
                loss += korvax.loss.time_frequency_loss(
                    x,
                    y,
                    transform_fn=transform,
                    loss_fn=loss_fn,
                    scale_fn=korvax.amplitude_to_db,
                )

            return loss / len(hops)
        ```
    """

    if not isinstance(loss_fn, Sequence):
        loss_fn = [loss_fn]

    if scale_fn is None:
        scale_fn = [lambda S: S] * len(loss_fn)

    if not isinstance(scale_fn, Sequence):
        scale_fn = [scale_fn]

    assert (n_losses := len(scale_fn)) == len(loss_fn)

    if weights is None:
        weights = jnp.ones(len(loss_fn), dtype=x.dtype) / n_losses
    else:
        weights = jnp.array(weights, dtype=x.dtype)

    assert len(weights) == n_losses

    loss_total = jnp.array(0.0, dtype=x.dtype)

    x = transform_fn(x)
    y = transform_fn(y)

    for i in range(n_losses):
        loss = loss_fn[i](scale_fn[i](x), scale_fn[i](y))
        loss_total += weights[i] * loss

    return loss_total

korvax.loss.TransformFn module-attribute

TransformFn = Callable[
    [Float[Array, "*channels n_samples"]],
    Inexact[Array, "*channels n_bins n_frames"],
]

korvax.loss.LossFn module-attribute

LossFn = Callable[
    [
        Inexact[Array, "*channels n_bins n_frames"],
        Inexact[Array, "*channels n_bins n_frames"],
    ],
    Float[Array, ""],
]

korvax.loss.ScaleFn module-attribute

ScaleFn = Callable[
    [Inexact[Array, "*channels n_bins_in n_frames"]],
    Inexact[Array, "*channels n_bins_out n_frames"],
]
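Note that ScaleFn may change the number of frequency bins (n_bins_in to n_bins_out), which is what allows filterbank-style scalings such as Mel. The sketch below shows a scaling function of that shape; `make_band_pooling` is a hypothetical helper that averages adjacent bins and is not a Korvax function.

```python
import jax.numpy as jnp


def make_band_pooling(n_bins_in, n_bins_out):
    """Hypothetical ScaleFn factory: average adjacent bins into n_bins_out bands."""
    assert n_bins_in % n_bins_out == 0
    width = n_bins_in // n_bins_out
    # (n_bins_out, n_bins_in) matrix with 1 / width over each band's columns.
    bank = jnp.repeat(jnp.eye(n_bins_out), width, axis=1) / width

    def scale_fn(S):
        # Contract the bin axis; leading channel axes and the frame axis are kept.
        return jnp.einsum("oi,...if->...of", bank, S)

    return scale_fn


pool = make_band_pooling(n_bins_in=8, n_bins_out=4)
S = jnp.ones((2, 8, 5))  # (channels, n_bins_in, n_frames)
pooled = pool(S)         # (channels, n_bins_out, n_frames)
```

Because the same scaling function is applied to both signals, any LossFn that follows sees two arrays of matching shape.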

Loss Functions

korvax.loss.elementwise_loss

elementwise_loss(S_x, S_y, /, metric='L1')

Compute elementwise L1 or L2 loss between two arrays.

Parameters:

Name Type Description Default
S_x Float[Array, '*dims']

Input array.

required
S_y Float[Array, '*dims']

Target array.

required
metric Literal['L1', 'L2']

Distance metric to use. Either "L1" (mean absolute error) or "L2" (mean squared error).

'L1'

Returns:

Type Description
Float[Array, '']

Scalar loss value.
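To make the difference between the two metrics concrete, the following sketch evaluates both formulas by hand with `jax.numpy` on a small array containing one large deviation. The input values are made up for illustration.

```python
import jax.numpy as jnp

S_x = jnp.array([0.0, 0.0, 0.0, 0.0])
S_y = jnp.array([1.0, 1.0, 1.0, 5.0])

# "L1": mean absolute error -> (1 + 1 + 1 + 5) / 4
l1 = jnp.mean(jnp.abs(S_x - S_y))

# "L2": mean squared error -> (1 + 1 + 1 + 25) / 4
l2 = jnp.mean((S_x - S_y) ** 2)
```

The single large deviation dominates the L2 value (7.0) much more than the L1 value (2.0).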

Source code in src/korvax/loss.py
def elementwise_loss(
    S_x: Float[Array, "*dims"],
    S_y: Float[Array, "*dims"],
    /,
    metric: Literal["L1", "L2"] = "L1",
) -> Float[Array, ""]:
    """Compute elementwise L1 or L2 loss between two arrays.

    Args:
        S_x: Input array.
        S_y: Target array.
        metric: Distance metric to use. Either "L1" (mean absolute error) or "L2"
            (mean squared error).

    Returns:
        Scalar loss value.
    """
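    if metric == "L1":
        loss = jnp.abs(S_x - S_y)
    elif metric == "L2":
        loss = (S_x - S_y) ** 2
    else:
        # Fail loudly instead of hitting an unbound `loss` below.
        raise ValueError(f"Unknown metric {metric!r}; expected 'L1' or 'L2'.")

    return jnp.mean(loss)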

korvax.loss.spectral_convergence_loss

spectral_convergence_loss(S_x, S_y)

Compute spectral convergence loss between two spectrograms.

Parameters:

Name Type Description Default
S_x Float[Array, '*channels n_freq n_frames']

Input spectrogram.

required
S_y Float[Array, '*channels n_freq n_frames']

Target spectrogram.

required

Returns:

Type Description
Float[Array, '']

Scalar loss value.
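Spectral convergence is the Frobenius norm of the difference divided by the Frobenius norm of the target, averaged over channels. The sketch below reproduces that formula with plain `jax.numpy`; the fixed `eps` is an assumption standing in for the library's own epsilon handling.

```python
import jax.numpy as jnp


def spectral_convergence(S_x, S_y, eps=1e-8):
    """||S_y - S_x||_F / ||S_y||_F, averaged over any leading channel axes."""
    num = jnp.sqrt(jnp.sum((S_y - S_x) ** 2, axis=(-2, -1)))
    den = jnp.sqrt(jnp.sum(S_y**2, axis=(-2, -1)))
    return jnp.mean(num / (den + eps))


S_y = jnp.array([[3.0, 0.0], [0.0, 4.0]])  # Frobenius norm 5
S_x = jnp.zeros_like(S_y)

worst = spectral_convergence(S_x, S_y)            # all target energy is missing
perfect = spectral_convergence(S_y, S_y)          # identical spectrograms
scaled = spectral_convergence(2 * S_x, 2 * S_y)   # same relative error
```

Because of the normalization by the target norm, scaling both spectrograms by the same factor leaves the loss essentially unchanged.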

Source code in src/korvax/loss.py
def spectral_convergence_loss(
    S_x: Float[Array, "*channels n_freq n_frames"],
    S_y: Float[Array, "*channels n_freq n_frames"],
) -> Float[Array, ""]:
    """Compute spectral convergence loss between two spectrograms.

    Args:
        S_x: Input spectrogram.
        S_y: Target spectrogram.

    Returns:
        Scalar loss value.
    """
    numerator = jnp.linalg.norm(S_y - S_x, ord="fro", axis=(-2, -1))
    denominator = jnp.linalg.norm(S_y, ord="fro", axis=(-2, -1))
    loss = numerator / (denominator + util.feps(denominator))
    return jnp.mean(loss)

korvax.loss.spectral_optimal_transport_loss

spectral_optimal_transport_loss(
    S_x,
    S_y,
    /,
    positions=None,
    p=2,
    normalize=True,
    balanced=True,
    quantile_lowpass=False,
)

Compute the frame-wise 1D Wasserstein distance, known as spectral optimal transport [1, 2].

The implementation and API are based on sot-loss.

Parameters:

Name Type Description Default
S_x Float[Array, '*channels n_bins n_frames']

Input spectrogram.

required
S_y Float[Array, '*channels n_bins n_frames']

Target spectrogram.

required
positions Float[Array, ' n_bins'] | None

Positions of frequency bins. If None, uses uniform spacing in [0, 1).

None
p int

Order of the Wasserstein distance (typically 1 or 2).

2
normalize bool

Whether to normalize spectrograms to sum to 1.

True
balanced bool

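If True, S_x and S_y are brought to the same total mass per frame: with normalize=True each is divided by its own total mass, and with normalize=False S_y is rescaled to the total mass of S_x. If False and normalize=True, both are divided by the total mass of S_x, so their relative mass is preserved; if both are False, the inputs are used as-is.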

True
quantile_lowpass bool

If True, zeroes out bins in S_y for quantiles above 1.0. Useful when balanced is False.

False

Returns:

Type Description
Float[Array, '']

Scalar loss value.
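For intuition, the special case p=1 has a simple closed form: the 1D Wasserstein distance equals the area between the two cumulative distributions. The sketch below implements only that case and assumes non-negative, non-empty frames; it is not the library's implementation, and `wasserstein1_frame` and `sot_p1` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def wasserstein1_frame(a, b, positions):
    """W1 between two spectral frames, treated as distributions over `positions`."""
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    # Area between the two CDFs, integrated over the gaps between positions.
    cdf_gap = jnp.abs(jnp.cumsum(a) - jnp.cumsum(b))[:-1]
    return jnp.sum(cdf_gap * jnp.diff(positions))


def sot_p1(S_x, S_y, positions):
    """Mean frame-wise W1 for spectrograms of shape (n_bins, n_frames)."""
    per_frame = jax.vmap(wasserstein1_frame, in_axes=(1, 1, None))(S_x, S_y, positions)
    return per_frame.mean()


n_bins = 8
positions = jnp.linspace(0.0, 1.0, num=n_bins, endpoint=False)

peak = jnp.zeros(n_bins).at[2].set(1.0)
near = jnp.zeros(n_bins).at[3].set(1.0)  # peak shifted by one bin
far = jnp.zeros(n_bins).at[5].set(1.0)   # peak shifted by three bins

d_near = wasserstein1_frame(peak, near, positions)
d_far = wasserstein1_frame(peak, far, positions)

S_x = jnp.stack([peak, peak], axis=1)
S_y = jnp.stack([near, far], axis=1)
loss = sot_p1(S_x, S_y, positions)
```

Unlike an elementwise loss, which gives the same value for both shifts, the transport distance grows with how far the spectral peak has moved (0.125 versus 0.375 here).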

References

[1] E. Cazelles, A. Robert, F. Tobar, "The Wasserstein-Fourier Distance for Stationary Time Series," IEEE Transactions on Signal Processing, vol. 69, pp. 709-721, 2020.

[2] B. Torres, G. Peeters, G. Richard, "Unsupervised Harmonic Parameter Estimation Using Differentiable DSP and Spectral Optimal Transport," in Proc. ICASSP 2024, pp. 1176-1180, 2024.

Source code in src/korvax/loss.py
def spectral_optimal_transport_loss(
    S_x: Float[Array, "*channels n_bins n_frames"],
    S_y: Float[Array, "*channels n_bins n_frames"],
    /,
    positions: Float[Array, " n_bins"] | None = None,
    p: int = 2,
    normalize: bool = True,
    balanced: bool = True,
    quantile_lowpass: bool = False,
) -> Float[Array, ""]:
    """Compute the frame-wise 1D Wasserstein distance, known as spectral optimal transport [1, 2].

    The implementation and API are based on [sot-loss](https://github.com/bernardo-torres/spectral-optimal-transport/).

    Args:
        S_x: Input spectrogram.
        S_y: Target spectrogram.
        positions: Positions of frequency bins. If None, uses uniform spacing in [0, 1).
        p: Order of the Wasserstein distance (typically 1 or 2).
        normalize: Whether to normalize spectrograms to sum to 1.
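        balanced: If True, `S_x` and `S_y` are brought to the same total mass per frame: with `normalize=True` each is divided by its own total mass, and with `normalize=False` `S_y` is rescaled to the total mass of `S_x`. If False and `normalize=True`, both are divided by the total mass of `S_x`, so their relative mass is preserved; if both are False, the inputs are used as-is.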
        quantile_lowpass: If True, zeroes out bins in `S_y` for quantiles above 1.0. Useful when `balanced` is False.

    Returns:
        Scalar loss value.

    References:
        [1] E. Cazelles, A. Robert, F. Tobar, "The Wasserstein-Fourier Distance for Stationary Time Series," IEEE Transactions on Signal Processing, vol. 69, pp. 709-721, 2020.

        [2] B. Torres, G. Peeters, G. Richard, "Unsupervised Harmonic Parameter Estimation Using Differentiable DSP and Spectral Optimal Transport," in Proc. ICASSP 2024, pp. 1176-1180, 2024.
    """
    n_bins = S_x.shape[-2]

    if positions is None:
        positions = jnp.linspace(0, 1, num=n_bins, endpoint=False, dtype=S_x.dtype)

    S_x = S_x.swapaxes(-1, -2).reshape(-1, n_bins)
    S_y = S_y.swapaxes(-1, -2).reshape(-1, n_bins)

    total_mass_x = jnp.sum(S_x, axis=-1, keepdims=True) + util.feps(S_x)
    total_mass_y = jnp.sum(S_y, axis=-1, keepdims=True) + util.feps(S_y)

    if normalize:
        S_x = S_x / total_mass_x
        if balanced:
            S_y = S_y / total_mass_y
        else:
            S_y = S_y / total_mass_x
    elif balanced:
        S_y = S_y * (total_mass_x / total_mass_y)

    return jax.vmap(
        partial(
            _wasserstein_1d,
            positions=positions,
            p=p,
            limit_quantile_range=quantile_lowpass,
        )
    )(S_x, S_y).mean()

Ready-to-use Configurations

For convenience, common configurations of the above loss functions are provided.

korvax.loss.mrstft_loss

mrstft_loss(
    x,
    y,
    /,
    hop_lengths=(32, 64, 128, 256, 512, 1024),
    win_lengths=(64, 128, 256, 512, 1024, 2048),
    fft_sizes=None,
    window="hann",
    w_lin=1.0,
    w_log=1.0,
    lin_dist="L1",
    log_dist="L1",
    log_fac=1.0,
    log_eps=1e-07,
    power=1,
)

Multi-resolution STFT loss (also known as multi-scale spectral loss).

  • Linear magnitudes are computed as abs(STFT)**power.
  • Log magnitudes are computed as log(log_fac * abs(STFT)**power + log_eps).

Parameters:

Name Type Description Default
x Float[Array, '*channels n_samples']

Input signal.

required
y Float[Array, '*channels n_samples']

Target signal.

required
hop_lengths Sequence[int]

Sequence of hop lengths for STFTs.

(32, 64, 128, 256, 512, 1024)
win_lengths Sequence[int]

Sequence of window lengths for STFTs.

(64, 128, 256, 512, 1024, 2048)
fft_sizes Sequence[int] | None

Sequence of FFT sizes for STFTs. If None, uses win_lengths.

None
window str | float | tuple

Window function specification.

'hann'
w_lin float

Weight for linear magnitude loss.

1.0
w_log float

Weight for log magnitude loss.

1.0
lin_dist Literal['L1', 'L2']

Distance metric for linear magnitude loss.

'L1'
log_dist Literal['L1', 'L2']

Distance metric for log magnitude loss.

'L1'
log_fac float

Scaling factor for log magnitude.

1.0
log_eps float

Additive constant for magnitude before taking log.

1e-07
power float | int

Exponent for the magnitude spectrogram.

1

Returns:

Type Description
Float[Array, '']

The scalar loss value.
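The contribution of a single resolution can be sketched as a weighted sum of a linear-magnitude term and a log-magnitude term. The helper below assumes the "L1" metric for both terms and operates on precomputed magnitude arrays; `resolution_term` is a hypothetical name, not a Korvax function.

```python
import jax.numpy as jnp


def resolution_term(S_x, S_y, w_lin=1.0, w_log=1.0, log_fac=1.0, log_eps=1e-7):
    """w_lin * L1(linear magnitudes) + w_log * L1(log magnitudes)."""
    lin = jnp.mean(jnp.abs(S_x - S_y))
    log_x = jnp.log(log_fac * S_x + log_eps)
    log_y = jnp.log(log_fac * S_y + log_eps)
    log = jnp.mean(jnp.abs(log_x - log_y))
    return w_lin * lin + w_log * log


# Two quiet magnitude spectrograms that differ by a factor of 10.
S_x = jnp.full((4, 3), 1e-3)
S_y = jnp.full((4, 3), 1e-2)

lin_only = resolution_term(S_x, S_y, w_log=0.0)  # about 0.009
log_only = resolution_term(S_x, S_y, w_lin=0.0)  # about log(10)
```

The linear term barely registers the mismatch at low levels, while the log term responds to the ratio between the magnitudes. mrstft_loss averages such terms over all configured resolutions.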

Source code in src/korvax/loss.py
def mrstft_loss(
    x: Float[Array, "*channels n_samples"],
    y: Float[Array, "*channels n_samples"],
    /,
    hop_lengths: Sequence[int] = (32, 64, 128, 256, 512, 1024),
    win_lengths: Sequence[int] = (64, 128, 256, 512, 1024, 2048),
    fft_sizes: Sequence[int] | None = None,
    window: str | float | tuple = "hann",
    w_lin: float = 1.0,
    w_log: float = 1.0,
    lin_dist: Literal["L1", "L2"] = "L1",
    log_dist: Literal["L1", "L2"] = "L1",
    log_fac: float = 1.0,
    log_eps: float = 1e-7,
    power: float | int = 1,
) -> Float[Array, ""]:
    """Multi-resolution STFT loss (also known as multi-scale spectral loss).

    * Linear magnitudes are computed as `abs(STFT)**power`.
    * Log magnitudes are computed as `log(log_fac * abs(STFT)**power + log_eps)`.

    Args:
        x: Input signal.
        y: Target signal.
        hop_lengths: Sequence of hop lengths for STFTs.
        win_lengths: Sequence of window lengths for STFTs.
        fft_sizes: Sequence of FFT sizes for STFTs. If None, uses `win_lengths`.
        window: Window function specification.
        w_lin: Weight for linear magnitude loss.
        w_log: Weight for log magnitude loss.
        lin_dist: Distance metric for linear magnitude loss.
        log_dist: Distance metric for log magnitude loss.
        log_fac: Scaling factor for log magnitude.
        log_eps: Additive constant for magnitude before taking log.
        power: Exponent for the magnitude spectrogram.

    Returns:
        The scalar loss value.
    """
    if fft_sizes is None:
        fft_sizes = win_lengths

    assert (n_res := len(win_lengths)) == len(hop_lengths) == len(fft_sizes)

    def log_scale(S):
        return jnp.log(log_fac * S + log_eps)

    scale_fns = [lambda S: S, log_scale]

    loss_fns = [
        partial(elementwise_loss, metric=lin_dist),
        partial(elementwise_loss, metric=log_dist),
    ]

    weights = [w_lin, w_log]

    loss_total = jnp.array(0.0, dtype=x.dtype)

    for i in range(n_res):
        transform_fn = partial(
            spectrogram,
            win_length=win_lengths[i],
            hop_length=hop_lengths[i],
            n_fft=fft_sizes[i],
            window=window,
            power=power,
        )

        loss_total += time_frequency_loss(
            x,
            y,
            transform_fn=transform_fn,
            scale_fn=scale_fns,
            loss_fn=loss_fns,
            weights=weights,
        )

    return loss_total / n_res

korvax.loss.smooth_mrstft_loss

smooth_mrstft_loss(x, y)

Implements the "smooth" multi-resolution STFT loss configuration specified in [1].

Parameters:

Name Type Description Default
x Float[Array, '*channels n_samples']

Input signal.

required
y Float[Array, '*channels n_samples']

Target signal.

required

Returns:

Type Description
Float[Array, '']

The scalar loss value.
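This configuration passes log_eps=1.0 to mrstft_loss, so the log magnitude becomes log(S + 1). One visible effect of that choice, compared with the default log_eps=1e-07, is sketched below on made-up magnitude values; the interpretation is an observation about the formula, not a summary of [1].

```python
import jax.numpy as jnp

S = jnp.array([0.0, 1e-3, 1.0])

default_log = jnp.log(1.0 * S + 1e-7)  # default log_eps of mrstft_loss
smooth_log = jnp.log(1.0 * S + 1.0)    # log_eps=1.0, as in this configuration
```

With log_eps=1.0 the log magnitude is exactly zero for silent bins and stays bounded near zero, whereas the default offset maps silence to a large negative value (about -16.1).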

References

[1] S. Schwär and M. Müller, "Multi-Scale Spectral Loss Revisited," IEEE Signal Processing Letters, vol. 30, pp. 1712-1716, 2023.

Source code in src/korvax/loss.py
def smooth_mrstft_loss(
    x: Float[Array, "*channels n_samples"],
    y: Float[Array, "*channels n_samples"],
) -> Float[Array, ""]:
    """Implements the "smooth" multi-resolution STFT loss configuration specified in [1].

    Args:
        x: Input signal.
        y: Target signal.

    Returns:
        The scalar loss value.

    References:
        [1] S. Schwär and M. Müller, "Multi-Scale Spectral Loss Revisited," IEEE Signal Processing Letters, vol. 30, pp. 1712-1716, 2023.
    """
    return mrstft_loss(
        x,
        y,
        hop_lengths=(32, 63, 128, 254, 510, 1026),
        win_lengths=(67, 127, 257, 509, 1021, 2053),
        window="flattop",
        w_lin=0.0,
        w_log=1.0,
        log_dist="L2",
        log_fac=1.0,
        log_eps=1.0,
    )