Skip to content

Pitch Estimation

korvax.pitch.yin

yin(
    x,
    /,
    fmin,
    fmax,
    sr,
    frame_length=2048,
    hop_length=None,
    trough_threshold=0.1,
    center=True,
    pad_kwargs=dict(),
)

Estimate fundamental frequency using the YIN algorithm.

Parameters:

Name Type Description Default
x Float[ArrayLike, '*channels n_samples']

Input signal.

required
fmin float

Minimum frequency (Hz) to search.

required
fmax float

Maximum frequency (Hz) to search.

required
sr float

Sample rate of the audio signal.

required
frame_length int

Length of each analysis frame in samples.

2048
hop_length int | None

Hop (step) length between adjacent frames. If None, defaults to frame_length // 4.

None
trough_threshold float

Absolute threshold for trough selection. Troughs of the cumulative mean normalized difference function that lie below this value are considered valid pitch candidates.

0.1
center bool

If True, pad the input so that frames are centered on their timestamps.

True
pad_kwargs dict[str, Any]

Additional keyword arguments forwarded to pad_center.

dict()

Returns:

Type Description
Float[Array, '*channels n_frames']

Estimated fundamental frequency for each frame in Hz.

Source code in src/korvax/pitch/_yin.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def yin(
    x: Float[ArrayLike, "*channels n_samples"],
    /,
    fmin: float,
    fmax: float,
    sr: float,
    frame_length: int = 2048,
    hop_length: int | None = None,
    trough_threshold: float = 0.1,
    center: bool = True,
    pad_kwargs: dict[str, Any] = dict(),
) -> Float[Array, "*channels n_frames"]:
    """Estimate fundamental frequency using the YIN algorithm.

    Args:
        x: Input signal.
        fmin: Minimum frequency (Hz) to search. Must be strictly positive.
        fmax: Maximum frequency (Hz) to search. Must be strictly positive and
            must not exceed `sr`.
        sr: Sample rate of the audio signal. Must be strictly positive.
        frame_length: Length of each analysis frame in samples.
        hop_length: Hop (step) length between adjacent frames. If None, defaults to
            `frame_length // 4`.
        trough_threshold: Absolute threshold for trough selection. Troughs of the
            cumulative mean normalized difference function that lie below this
            value are considered valid pitch candidates.
        center: If True, pad the input so that frames are centered on their timestamps.
        pad_kwargs: Additional keyword arguments forwarded to [pad_center][korvax.util.pad_center].

    Returns:
        Estimated fundamental frequency for each frame in Hz.

    Raises:
        ValueError: If `sr`, `fmin` or `fmax` is not strictly positive, if `fmax`
            exceeds `sr` (the shortest candidate period would be zero samples),
            or if `fmin`, `fmax` and `frame_length` together leave fewer than
            two candidate periods to search.
    """
    # Validate the scalar parameters before doing any array work. The
    # `not value > 0` form also rejects NaN, which `value <= 0` would let through.
    if not sr > 0:
        raise ValueError(f"sr={sr} must be strictly positive")
    if not fmin > 0:
        raise ValueError(f"fmin={fmin} must be strictly positive")
    if not fmax > 0:
        raise ValueError(f"fmax={fmax} must be strictly positive")

    # Calculate minimum and maximum periods (in samples). The longest period is
    # capped at `frame_length - 1`.
    min_period = int(math.floor(sr / fmax))
    max_period = min(int(math.ceil(sr / fmin)), frame_length - 1)

    # A period of zero samples would map to an infinite frequency below.
    if min_period < 1:
        raise ValueError(f"fmax={fmax} must not exceed sr={sr}")

    # The trough search below indexes lags 0 and 1 of the difference function,
    # so at least two candidate periods are required.
    if max_period <= min_period:
        raise ValueError(
            f"No usable period range for fmin={fmin}, fmax={fmax}, sr={sr}, "
            f"frame_length={frame_length}: "
            f"min_period={min_period}, max_period={max_period}"
        )

    x = jnp.asarray(x)

    # Set the default hop if it is not already specified.
    if hop_length is None:
        hop_length = frame_length // 4

    # Pad the time series so that frames are centered.
    # NOTE: `pad_kwargs` (including its shared `dict()` default) is only
    # forwarded here; this function does not modify it.
    if center:
        x = util.pad_center(
            x,
            size=x.shape[-1] + frame_length,
            pad_kwargs=pad_kwargs,
        )

    # Frame audio.
    frames = util.frame(x, frame_length=frame_length, hop_length=hop_length)

    # Calculate cumulative mean normalized difference function.
    # The lag axis is -2 in everything that follows.
    yin_frames = _cumulative_mean_normalized_difference(frames, min_period, max_period)

    # Sub-sample offsets used later to refine the selected lag.
    parabolic_shifts = util.parabolic_peak_shifts(yin_frames, axis=-2)

    # Find local minima. The first lag is handled explicitly: it counts as a
    # trough whenever it lies below the second lag.
    is_trough = util.localmin(yin_frames, axis=-2)
    is_trough = is_trough.at[..., 0, :].set(
        yin_frames[..., 0, :] < yin_frames[..., 1, :]
    )

    # Find minima below the trough threshold.
    is_threshold_trough = jnp.logical_and(is_trough, yin_frames < trough_threshold)

    # Absolute threshold.
    # "The solution we propose is to set an absolute threshold and choose the
    # smallest value of tau that gives a minimum of d' deeper than
    # this threshold. If none is found, the global minimum is chosen instead."

    global_min = jnp.argmin(yin_frames, axis=-2, keepdims=True)
    # argmax over a boolean mask yields the index of the first True entry,
    # i.e. the smallest qualifying lag. It yields 0 when nothing qualifies,
    # which is why the fallback below is needed.
    yin_period = jnp.argmax(is_threshold_trough, axis=-2, keepdims=True)

    no_trough_below_threshold = jnp.all(~is_threshold_trough, axis=-2, keepdims=True)
    yin_period = jnp.where(no_trough_below_threshold, global_min, yin_period)

    # Refine peak by parabolic interpolation. Indices are relative to
    # `min_period`, so add it back to obtain the period in samples, then drop
    # the singleton lag axis kept by `keepdims=True`.

    yin_period = (
        min_period
        + yin_period
        + jnp.take_along_axis(parabolic_shifts, yin_period, axis=-2)
    )[..., 0, :]

    # Convert period to fundamental frequency.
    f0: jnp.ndarray = sr / yin_period
    return f0