Resample

korvax.resample

resample(
    x,
    /,
    orig_sr,
    target_sr,
    lowpass_filter_width=6,
    rolloff=0.99,
    resampling_method="sinc_interp_hann",
    beta=None,
    scale=False,
)

Resample a waveform using sinc interpolation.

This function is a JAX port of torchaudio.functional.resample. When jitted, it is as fast as soxr (HQ) on CPU, and the JAX implementation is also fully differentiable and works on GPU/TPU.

Note: Unlike the rest of korvax, this function requires sampling rates to be specified as integers.
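The source listing on this page calls `math.gcd(orig_sr, target_sr)`, which only accepts integers and is one likely reason for the integer requirement. The sketch below shows how two rates reduce to a small integer ratio. `reduce_rates` is a hypothetical helper written for illustration, not part of korvax.

```python
import math


def reduce_rates(orig_sr: int, target_sr: int) -> tuple[int, int]:
    """Reduce a pair of integer sampling rates by their greatest common divisor.

    Hypothetical helper for illustration only; not part of korvax.
    """
    if orig_sr <= 0 or target_sr <= 0:
        raise ValueError("Sampling rates must be positive.")
    g = math.gcd(orig_sr, target_sr)  # math.gcd raises TypeError on floats
    return orig_sr // g, target_sr // g


print(reduce_rates(44100, 48000))  # (147, 160)
print(reduce_rates(48000, 16000))  # (3, 1)
```

If a rate arrives as a float such as `44100.0`, converting it with `int(...)` before the call is enough, provided the value is a whole number.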

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Float[Array, '*channels old_n_samples']` | The input signal of dimension `(..., time)`. | required |
| `orig_sr` | `int` | The original sampling rate of the signal. | required |
| `target_sr` | `int` | The desired sampling rate. | required |
| `lowpass_filter_width` | `int` | Controls the sharpness of the filter. A larger value gives a sharper filter but is less efficient. | `6` |
| `rolloff` | `float` | The roll-off frequency of the filter as a fraction of the Nyquist frequency. Lower values reduce aliasing but also attenuate high frequencies. | `0.99` |
| `resampling_method` | `str` | The windowing function to use. Options: `"sinc_interp_hann"`, `"sinc_interp_kaiser"`. | `'sinc_interp_hann'` |
| `beta` | `float \| None` | The shape parameter for the Kaiser window. Only used if `resampling_method` is `"sinc_interp_kaiser"`. If `None`, defaults to `14.7696...`. | `None` |
| `scale` | `bool` | If `True`, divide the output by `sqrt(target_sr / orig_sr)` so that its total energy approximately matches that of the input. | `False` |
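To make the `rolloff` parameter concrete, here is a minimal sketch of the cutoff frequency it implies. It assumes the relevant Nyquist frequency is half of the lower of the two sampling rates, which is the usual convention for anti-aliasing in sinc resamplers; the page itself does not state which Nyquist frequency is meant. `lowpass_cutoff_hz` is a hypothetical helper, not a korvax function.

```python
def lowpass_cutoff_hz(orig_sr: int, target_sr: int, rolloff: float = 0.99) -> float:
    """Approximate low-pass cutoff in Hz implied by `rolloff`.

    Assumption: the cutoff is a fraction of the Nyquist frequency of the
    lower of the two rates. Hypothetical helper for illustration only.
    """
    nyquist = min(orig_sr, target_sr) / 2
    return rolloff * nyquist


# Downsampling 44.1 kHz audio to 16 kHz: the limiting Nyquist frequency is 8 kHz.
print(lowpass_cutoff_hz(44100, 16000))               # about 7920 Hz
print(lowpass_cutoff_hz(44100, 16000, rolloff=0.9))  # about 7200 Hz
```

Under this assumption, lowering `rolloff` moves the cutoff further below the Nyquist frequency, which is why it trades high-frequency content for less aliasing.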

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, '*channels new_n_samples']` | The waveform at the new sampling rate. The new shape is `(..., int(ceil(target_sr * old_n_samples / orig_sr)))`. |
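The output length formula above can be checked with plain Python. `resampled_length` is a hypothetical helper that mirrors the documented expression; it does not call korvax.

```python
import math


def resampled_length(old_n_samples: int, orig_sr: int, target_sr: int) -> int:
    """Number of output samples, following the documented shape formula.

    Hypothetical helper for illustration only.
    """
    return int(math.ceil(target_sr * old_n_samples / orig_sr))


print(resampled_length(44100, 44100, 16000))  # 16000: one second stays one second
print(resampled_length(1000, 48000, 44100))   # 919: 918.75 is rounded up
print(resampled_length(16000, 16000, 48000))  # 48000
```

Because of the `ceil`, the output can be up to one sample longer than the exact ratio would give, so round-tripping a signal through two resampling steps does not always restore the original length.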

Source code in src/korvax/_resample.py
def resample(
    x: Float[Array, "*channels old_n_samples"],
    /,
    orig_sr: int,
    target_sr: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
    resampling_method: str = "sinc_interp_hann",
    beta: float | None = None,
    scale: bool = False,
) -> Float[Array, "*channels new_n_samples"]:
    """Resample a waveform using sinc interpolation.

    This function is a JAX port of [`torchaudio.functional.resample`](https://docs.pytorch.org/audio/main/generated/torchaudio.functional.resample.html).
    When jitted, it is as fast as [`soxr`](https://github.com/dofuuz/python-soxr) (HQ) on CPU,
    and the JAX implementation is also fully differentiable and works on GPU/TPU.

    Note: Unlike the rest of korvax, this function requires sampling rates to be specified as integers.

    Args:
        x: The input signal of dimension `(..., time)`.
        orig_sr: The original frequency of the signal.
        target_sr: The desired frequency.
        lowpass_filter_width: Controls the sharpness of the filter.
            A larger value gives a sharper filter but is less efficient. Defaults to ``6``.
        rolloff: The roll-off frequency of the filter as a
            fraction of the Nyquist frequency. Lower values reduce aliasing but
            also attenuate high frequencies. Defaults to ``0.99``.
        resampling_method: The windowing function to use.
            Options: [``"sinc_interp_hann"``, ``"sinc_interp_kaiser"``].
            Defaults to ``"sinc_interp_hann"``.
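        beta: The shape parameter for the Kaiser window.
            Only used if `resampling_method` is ``"sinc_interp_kaiser"``.
            Defaults to ``14.7696...``.
        scale: If ``True``, divide the output by ``sqrt(target_sr / orig_sr)``
            so that its total energy approximately matches that of the input.
            Defaults to ``False``.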

    Returns:
        The waveform at the new frequency. The new shape is `(..., int(ceil(target_sr * old_n_samples / orig_sr)))`.
    """
    if orig_sr <= 0 or target_sr <= 0:
        raise ValueError("Original and new frequencies must be positive.")

    if orig_sr == target_sr:
        return x

    with jax.ensure_compile_time_eval():
        gcd = math.gcd(orig_sr, target_sr)
        kernel, width = _get_sinc_resample_kernel(
            orig_sr,
            target_sr,
            gcd,
            lowpass_filter_width,
            rolloff,
            resampling_method,
            beta,
            dtype=x.dtype,
        )

    x = _apply_sinc_resample_kernel(x, orig_sr, target_sr, gcd, kernel, width)

    if scale:
        # Rescale so the output has approximately the same total energy as the input.
        x = x / jnp.sqrt(target_sr / orig_sr)

    return x
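The final `scale` step divides the output by `sqrt(target_sr / orig_sr)`. A minimal sketch of why that keeps total energy roughly constant follows. It uses sample repetition as a crude stand-in for resampling, not the sinc kernel from the listing, and `energy` and `hold_upsample` are hypothetical helpers.

```python
import math


def energy(signal):
    """Total energy: the sum of squared samples."""
    return sum(v * v for v in signal)


def hold_upsample(signal, factor):
    """Crude stand-in for resampling: repeat each sample `factor` times."""
    return [v for v in signal for _ in range(factor)]


orig_sr, target_sr = 8000, 16000
x = [math.sin(2 * math.pi * 440 * n / orig_sr) for n in range(800)]

# Doubling the rate doubles the number of samples, and so the summed energy.
y = hold_upsample(x, target_sr // orig_sr)

# Dividing by sqrt(ratio) divides the energy by ratio, undoing that growth.
y_scaled = [v / math.sqrt(target_sr / orig_sr) for v in y]

print(energy(y) / energy(x))         # close to 2
print(energy(y_scaled) / energy(x))  # close to 1
```

With `scale=False`, the default, sample amplitudes are preserved instead, so the summed energy grows or shrinks with the number of samples.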