In [3]:
import torch
import numpy as np

import matplotlib.pyplot as plt
from IPython.display import Audio, Image
from ipywidgets import interact, fixed, widgets

from data import setup_loader
from features import waveform_voiced_only

In [4]:
# Build the train and dev loaders with identical options (batch size 16, no
# shuffling) and keep a direct handle to the dataset behind each loader.
_loader_kwargs = {"batch_size": 16, "shuffle": False}

dl_train = setup_loader(data_dir="data/train", **_loader_kwargs)
ds_train = dl_train.dataset

dl_dev = setup_loader(data_dir="data/dev", **_loader_kwargs)
ds_dev = dl_dev.dataset

In [5]:
# Figure size (width, height) passed as `figsize=` whenever one of the plot_*
# helpers below has to create its own figure because no Axes was supplied.
FIGSIZE = (6, 2)

def plot_waveform(waveform, sample_rate, title="Waveform", ax=None):
    """Draw the first channel of ``waveform`` against time in seconds.

    Args:
        waveform: Object exposing ``.numpy()``; row 0 of that array must be
            1-D (presumably a ``(1, num_samples)`` torch tensor - confirm
            against callers).
        sample_rate: Samples per second, used to convert sample indices to
            seconds.
        title: Title set on the axes.
        ax: Axes to draw on. When ``None`` a new figure of size ``FIGSIZE``
            is created, shown and closed by this function.
    """
    samples = waveform.numpy()[0]
    # Single-element unpacking doubles as a check that the channel is 1-D.
    (n_samples,) = samples.shape
    seconds = torch.arange(n_samples) / sample_rate

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(1, 1, figsize=FIGSIZE)

    ax.plot(seconds, samples, linewidth=1)
    ax.grid(True)
    ax.set_title(title)

    # Only display/close figures this function created itself; a caller that
    # passed in an Axes stays in charge of its own figure.
    if owns_figure:
        plt.show()
        plt.close()

def plot_waveform_thr(waveform, sample_rate, cond, thr, title="Waveform", ax=None):
    """Draw a waveform as two traces, split by whether ``cond`` exceeds ``thr``.

    Args:
        waveform: Object exposing ``.numpy()``; row 0 of that array must be
            1-D (presumably a ``(1, num_samples)`` torch tensor - confirm
            against callers).
        sample_rate: Samples per second, used to convert sample indices to
            seconds.
        cond: Per-sample values compared element-wise against ``thr``
            (assumed to have the same length as the waveform - TODO confirm).
        thr: Threshold applied to ``cond``.
        title: Title set on the axes.
        ax: Axes to draw on. When ``None`` a new figure of size ``FIGSIZE``
            is created, shown and closed by this function.
    """
    samples = waveform.numpy()[0]
    # Single-element unpacking doubles as a check that the channel is 1-D.
    (n_samples,) = samples.shape
    seconds = torch.arange(n_samples) / sample_rate

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(1, 1, figsize=FIGSIZE)

    # masked_where hides the entries where its condition is true, so the
    # first trace is the part with cond <= thr and the second the part with
    # cond > thr.
    shown_at_or_below = np.ma.masked_where(cond > thr, samples)
    shown_above = np.ma.masked_where(cond <= thr, samples)

    ax.plot(seconds, shown_at_or_below, seconds, shown_above, linewidth=1)
    ax.grid(True)
    ax.set_title(title)

    # Only display/close figures this function created itself.
    if owns_figure:
        plt.show()
        plt.close()

def plot_waveform_energy(waveform, sample_rate, wf_energy, title="Waveform, Energy", ax=None):
    """Draw a waveform with a second curve overlaid in red on a twin y-axis.

    Args:
        waveform: Object exposing ``.numpy()``; row 0 of that array must be
            1-D (presumably a ``(1, num_samples)`` torch tensor - confirm
            against callers).
        sample_rate: Samples per second, used to convert sample indices to
            seconds.
        wf_energy: Values plotted against the same time axis as the waveform
            (assumed to hold one value per sample - TODO confirm).
        title: Title set on the waveform axes.
        ax: Axes for the waveform. When ``None`` a new figure of size
            ``FIGSIZE`` is created, shown and closed by this function.

    Returns:
        The twin axes holding the ``wf_energy`` curve, so callers can draw
        further artists on that scale.
    """
    samples = waveform.numpy()[0]
    # Single-element unpacking doubles as a check that the channel is 1-D.
    (n_samples,) = samples.shape
    seconds = np.arange(n_samples) / sample_rate

    owns_figure = ax is None
    if owns_figure:
        _, wave_ax = plt.subplots(1, 1, figsize=FIGSIZE)
    else:
        wave_ax = ax

    wave_ax.grid(True)
    wave_ax.plot(seconds, samples, linewidth=1)

    # The second curve gets its own y-scale but shares the x (time) axis.
    energy_ax = wave_ax.twinx()
    energy_ax.plot(seconds, wf_energy, linewidth=1, c="r")
    wave_ax.set_title(title)

    # Only display/close figures this function created itself.
    if owns_figure:
        plt.show()
        plt.close()

    return energy_ax

def plot_waveform_specgram(waveform, sample_rate, title="Spectrogram", ax=None):
    """Draw a spectrogram of the first channel of ``waveform``.

    Args:
        waveform: Object exposing ``.numpy()``; row 0 of that array must be
            1-D (presumably a ``(1, num_samples)`` torch tensor - confirm
            against callers).
        sample_rate: Sampling frequency, forwarded to ``specgram`` as ``Fs``.
        title: Title set on the axes.
        ax: Axes to draw on. When ``None`` a new figure of size ``FIGSIZE``
            is created, shown and closed by this function.
    """
    samples = waveform.numpy()[0]
    # The length itself is not needed; the unpacking is kept only because it
    # raises if the selected channel is not 1-D.
    (_n_samples,) = samples.shape

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(1, 1, figsize=FIGSIZE)

    ax.specgram(samples, Fs=sample_rate)
    ax.set_title(title)

    # Only display/close figures this function created itself.
    if owns_figure:
        plt.show()
        plt.close()

In [7]:
def visualize_vad(dataset, idx, rel_voice_thr, vad_mfcc_0, maybe_skip_s):
    """Show audio players and diagnostic plots for the voiced-only slicing of one sample.

    The sample's image and one audio player per processing stage are displayed
    in a row, followed by a six-panel figure of the intermediate waveforms
    returned by ``waveform_voiced_only``.

    Args:
        dataset: Indexable collection whose items are mappings providing at
            least "label", "session", "img_path", "snd" and "snd_sr".
        idx: Index of the sample to visualise.
        rel_voice_thr: Forwarded as ``relative_voice_energy_threshold``; also
            drawn as a horizontal line on the energy plot and used as the
            split threshold in the thresholded plot.
        vad_mfcc_0: When truthy ``vad_source="mfcc0"`` is used, otherwise
            ``vad_source="energy"``.
        maybe_skip_s: Forwarded as ``skip_start_seconds``.

    Raises:
        AssertionError: If the sound is not 2-D with a single row, or its
            sample rate is not 16000.
    """
    sample = dataset[idx]
    label = sample["label"]
    session = sample["session"]
    img_path = sample["img_path"]
    snd = sample["snd"]
    snd_sr = sample["snd_sr"]
    assert snd.ndim == 2
    assert snd.size(0) == 1
    assert snd_sr == 16000

    snd_sliced, aux = waveform_voiced_only(
        snd,
        snd_sr,
        skip_start_seconds=maybe_skip_s,
        relative_voice_energy_threshold=rel_voice_thr,
        vad_source="mfcc0" if vad_mfcc_0 else "energy",
    )
    snd_effect = aux["snd_effect"]
    snd_skip = aux["snd_skip"]
    snd_emph = aux["snd_emph"]
    energy_interp = aux["energy_interp"]
    energy_emph_interp = aux["energy_emph_interp"]
    start_was_cut = aux["start_was_cut"]

    # The image widget needs the raw encoded bytes, so the file is read from
    # disk rather than using the decoded image stored in the sample.
    with open(img_path, "rb") as img_file:
        img_bytes = img_file.read()
    panels = [
        widgets.VBox([widgets.Label(str(label)), widgets.Image(value=img_bytes, format="png")])
    ]

    # NOTE(review): the player named "2. Trimmed" uses aux["snd_skip"] while
    # plot 2 below uses aux["snd_effect"] - confirm this pairing is intended.
    sounds = {"1. Original": snd, "2. Trimmed": snd_skip, "3. Emphasized": snd_emph, "5. Sliced": snd_sliced}
    # The loop variable must not be called `label`: that would overwrite the
    # sample's label, which is still needed for the figure title further down.
    for snd_name, audio in sounds.items():
        if audio.size(1) == 0:
            panels.append(widgets.VBox([widgets.Label(snd_name), widgets.Label("Empty sound")]))
            continue
        player_out = widgets.Output()
        with player_out:
            # Only the final, sliced result starts playing automatically.
            display(Audio(data=audio, rate=snd_sr, normalize=True, autoplay=(snd_name == "5. Sliced")))
        panels.append(widgets.VBox([widgets.Label(snd_name), player_out]))
    display(widgets.HBox(panels))

    # Mosaic: two rows of side-by-side panels, then two full-width rows.
    layout = "AB\nCD\nEE\nFF"
    fig, axs = plt.subplot_mosaic(layout, figsize=(12, 8))
    fig.set_layout_engine("constrained")
    fig.suptitle(f"Subject {label}, session {session}")
    plot_waveform(snd, snd_sr, "1. Original", axs["A"])
    plot_waveform(snd_effect, snd_sr, "2. Trimmed silence", axs["B"])
    twin_ax = plot_waveform_energy(snd_effect, snd_sr, energy_interp, "3. Short Time Energy", axs["C"])
    # Mark the chosen threshold on the twin axis that carries the energy curve.
    twin_ax.axhline(rel_voice_thr, 0, alpha=0.8, color="orange")
    plot_waveform(snd_emph, snd_sr, f"4. Emphasized (start cut-off: {start_was_cut})", axs["D"])
    plot_waveform_thr(snd_emph, snd_sr, energy_emph_interp, rel_voice_thr, "5. Tresholded", axs["E"])
    plot_waveform(snd_sliced, snd_sr, "6. Sliced", axs["F"])
    plt.show()
    # Close this specific figure so repeated widget updates do not accumulate
    # open figures.
    plt.close(fig)

In [8]:
# Interactive controls: each widget below feeds one keyword argument of
# visualize_vad, which is re-run whenever a control's value changes.
idx_w = widgets.IntSlider(
    min=0,
    max=len(ds_train) - 1,
    step=1,
    value=0,
    continuous_update=False,  # re-render only when the slider is released
    description="Sample",
)
rel_voice_thr_w = widgets.BoundedFloatText(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.1,
    description="Rel. voice E",
)
vad_mfcc_0 = widgets.Checkbox(value=False, description="0th MFCC instead of E")
# BoundedFloatText rather than FloatText: the plain FloatText widget has no
# min/max traits, so the intended 0-5 s range would not be enforced and
# negative skip durations could be entered.
skip_w = widgets.BoundedFloatText(
    value=1,
    min=0,
    max=5,
    step=0.5,
    description="Maybe skip [s]",
)

ui = widgets.HBox([idx_w, rel_voice_thr_w, vad_mfcc_0, skip_w])
out = widgets.interactive_output(visualize_vad, {
    "dataset": fixed(ds_train),  # constant argument, not exposed as a control
    "idx": idx_w,
    "rel_voice_thr": rel_voice_thr_w,
    "vad_mfcc_0": vad_mfcc_0,
    "maybe_skip_s": skip_w,
})
display(ui, out)

HBox(children=(IntSlider(value=0, continuous_update=False, description='Sample', max=185), BoundedFloatText(va…

Output()