Token Generation with Axial Transformers

Token decoding with the AudioLM-like strategy. We decode the token indices column-by-column (each column spans the frequency axis), in a raster-scan order across time and depth. Starting from (time=0, depth=0), we predict the indices of the following depths, i.e., (time=0, depth=1) and (time=0, depth=2). Then, we decode the indices of time 1: (time=1, depth=0), (time=1, depth=1), and (time=1, depth=2). We use a single Time-Frequency Transformer (TF-Transformer) for this.
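As an illustration of this ordering only (not of the model itself), a minimal Python sketch of the raster scan, with illustrative grid sizes:

    # Raster-scan order over (time, depth) columns; each visited column
    # contains all F frequency bins at that position.
    def column_order(num_time, num_depth):
        for t in range(num_time):       # outer loop: time frames
            for d in range(num_depth):  # inner loop: quantization depths
                yield (t, d)

    # With 2 time frames and 3 depths:
    # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
    print(list(column_order(2, 3)))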


Implementation of the AudioLM-like strategy. A TF-Transformer takes a start-of-sequence token as input and decodes the first column from (frequency=0, time=0, depth=0) to (frequency=F-1, time=0, depth=0). For simplicity, the number of frequency bins F is set to 2 in the figure, while it is 32 in practice. Then, the column-by-column decoding proceeds as illustrated in the figure above. Before passing the decoded indices (denoted as Q) to the TF-Transformer, we convert them into discretized vectors by codebook lookup and apply a single linear layer; these preprocessed features are denoted as x. We then add time-frequency positional encodings, which are a concatenation of time-axis and frequency-axis encoding vectors. Decoding finishes when the transformer outputs end-of-sequence tokens. To condition the generation process, we append the reference signal encodings z to the input.
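A minimal PyTorch sketch of this preprocessing is given below. The module name, the embedding and channel sizes, the use of learned positional embeddings, and the exact placement of z along the sequence axis are assumptions for illustration, not details taken from the figure.

    import torch
    import torch.nn as nn

    class TokenPreprocessor(nn.Module):
        # Sketch: codebook lookup -> single linear layer, plus a channel-wise
        # concatenation of time-axis and frequency-axis positional encodings,
        # and the reference encodings z attached along the sequence axis.
        def __init__(self, codebook_size=1024, code_dim=64, model_dim=256,
                     max_seq=512, num_freq=32):
            super().__init__()
            self.lookup = nn.Embedding(codebook_size, code_dim)   # codebook lookup
            self.proj = nn.Linear(code_dim, model_dim)            # single linear layer
            self.pos_seq = nn.Embedding(max_seq, model_dim // 2)  # time-axis encoding
            self.pos_freq = nn.Embedding(num_freq, model_dim // 2)  # frequency-axis encoding

        def forward(self, q, z):
            # q: (batch, seq, freq) integer token indices Q
            # z: (batch, len_z, model_dim) reference signal encodings (assumed shape)
            b, s, f = q.shape
            x = self.proj(self.lookup(q))                         # (b, s, f, model_dim)
            pos = torch.cat([
                self.pos_seq(torch.arange(s, device=q.device))[:, None, :].expand(s, f, -1),
                self.pos_freq(torch.arange(f, device=q.device))[None, :, :].expand(s, f, -1),
            ], dim=-1)                                            # (s, f, model_dim)
            x = x + pos
            # Condition on the reference signal by concatenating z along the
            # sequence axis (placed as a prefix here; the placement is an assumption).
            z = z[:, :, None, :].expand(-1, -1, f, -1)
            return torch.cat([z, x], dim=1)                       # (b, len_z + s, f, model_dim)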


Token decoding with the RQ-Transformer-like strategy. We follow the same ordering as in the AudioLM-like method, but use a separate Depth Transformer to decode along the depth axis.


Implementation of the RQ-Transformer-like strategy. Instead of processing every depth, the TF-Transformer computes a context vector for every time frame. Each context vector is then passed to the Depth Transformer, which performs autoregressive decoding along the depth axis.
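The two-level control flow can be sketched as follows; tf_transformer and depth_transformer are placeholder callables with assumed signatures, so the sketch only illustrates the decoding order, not the actual model interfaces.

    import torch

    def decode_rq_style(tf_transformer, depth_transformer, conditioning,
                        num_frames, num_depth=3):
        frames = []
        for t in range(num_frames):
            # The TF-Transformer summarizes the conditioning and everything
            # decoded so far into one context vector for the current frame.
            context = tf_transformer(conditioning, frames)       # assumed: (batch, channels)
            frame = []
            for d in range(num_depth):
                # The Depth Transformer decodes autoregressively over the depth axis.
                column = depth_transformer(context, frame)       # assumed: (batch, freq) indices
                frame.append(column)
            frames.append(torch.stack(frame, dim=1))             # (batch, depth, freq)
        return torch.stack(frames, dim=1)                        # (batch, time, depth, freq)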


Token decoding with the VALL-E-like strategy. Instead of the time-depth raster-scan order, we first predict the (coarse) depth-0 tokens with a TF-Transformer. Then, we decode the entire token map of each following depth with another TF-Transformer.
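A minimal sketch of this ordering (with illustrative sizes): depth-0 columns are decoded one by one, and each remaining depth is then filled in as a whole map.

    # VALL-E-like ordering: autoregressive over time at depth 0,
    # then one whole (time x frequency) map per remaining depth.
    def valle_order(num_time, num_depth):
        for t in range(num_time):            # depth-0 columns, one per step
            yield ('column', t, 0)
        for d in range(1, num_depth):        # full token map per following depth
            yield ('full_map', d)

    # With 2 time frames and 3 depths:
    # [('column', 0, 0), ('column', 1, 0), ('full_map', 1), ('full_map', 2)]
    print(list(valle_order(2, 3)))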


Implementation of the VALL-E-like strategy. We decode the depth-0 tokens autoregressively with the first TF-Transformer, as in the AudioLM-like model. The decoded tokens are then preprocessed and passed to the second TF-Transformer. The second transformer also accepts the conditioning latent z and a depth encoding that differentiates the depths. Unlike the first transformer, the second one does not use any causal mask; each token can globally attend to every other token. Since we use residual quantization with 3 codebooks, two passes through the second transformer are needed to decode all remaining indices.
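The two-stage procedure can be sketched as below; the function names and signatures (ar_tf_transformer, nar_tf_transformer, preprocess) are assumed placeholders for illustration.

    import torch

    def decode_valle_style(ar_tf_transformer, nar_tf_transformer, preprocess,
                           z, num_frames, num_depth=3):
        # Stage 1: depth-0 tokens, decoded autoregressively as in the AudioLM-like model.
        depth0 = ar_tf_transformer(z, num_frames)                # assumed: (batch, time, freq)
        decoded = [depth0]
        # Stage 2: one global (non-causal) pass per remaining depth; with 3 codebooks
        # this means two passes through the second TF-Transformer.
        for d in range(1, num_depth):
            x = preprocess(torch.stack(decoded, dim=2))          # codebook lookup + linear, etc.
            # The depth encoding tells the second transformer which depth it predicts.
            decoded.append(nar_tf_transformer(x, z, depth_index=d))  # assumed: (batch, time, freq)
        return torch.stack(decoded, dim=2)                       # (batch, time, depth, freq)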


Transformer architecture. The TF-Transformer (left) stacks 4 identical blocks, each consisting of two transformer encoder layers: one for the frequency axis and the other for the sequence (time or flattened time/depth) axis. In other words, these two layers treat the time axis and the frequency axis as a batch axis, respectively. The TF-Transformer uses causal masks for the time-axis layers when it is used for the RQ-Transformer-like strategy or for the depth-0 decoding of the VALL-E-like method. The Depth Transformer (right) is simply a stack of two causal encoder layers. In the figure, C, F, S, T, and D denote channel, frequency, sequence, time, and depth, respectively.
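The axis folding can be sketched in PyTorch as follows; the channel width, head count, and feed-forward defaults are assumptions, and only the per-axis attention and the optional causal mask on the sequence-axis layer follow the description above.

    import torch
    import torch.nn as nn

    class TFBlock(nn.Module):
        # One TF-Transformer block: a frequency-axis encoder layer (sequence
        # folded into the batch) followed by a sequence-axis encoder layer
        # (frequency folded into the batch).
        def __init__(self, dim=256, heads=4):
            super().__init__()
            self.freq_layer = nn.TransformerEncoderLayer(dim, heads, batch_first=True)
            self.seq_layer = nn.TransformerEncoderLayer(dim, heads, batch_first=True)

        def forward(self, x, causal=False):
            # x: (batch B, sequence S, frequency F, channel C)
            b, s, f, c = x.shape
            x = self.freq_layer(x.reshape(b * s, f, c)).reshape(b, s, f, c)
            x = x.transpose(1, 2).reshape(b * f, s, c)
            # Causal mask on the sequence axis for the autoregressive use cases.
            mask = torch.full((s, s), float('-inf'), device=x.device).triu(1) if causal else None
            x = self.seq_layer(x, src_mask=mask)
            return x.reshape(b, f, s, c).transpose(1, 2)

    # The full TF-Transformer stacks 4 such blocks; the Depth Transformer is
    # analogously a stack of 2 causal encoder layers over the depth axis.
    tf_transformer = nn.ModuleList([TFBlock() for _ in range(4)])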


Conditional Generation with Reference Encoder

Reference encoder architecture. We borrow MFTAA-Net, which is known for its state-of-the-art speech enhancement performance. Specifically, we use only its encoder part with the following modifications: (i) we replace the frequency-axis stride of the downsampling convolutions with a time-axis stride, (ii) use filterbanks for the frequency downsampling instead, and (iii) add attentive pooling to summarize the features across the time axis. The AxialSelfAttention in the MFTAABlock consists of frequency-axis and time-axis self-attention. The MultiDilatedConv2d is a 4-group two-dimensional convolutional layer with a kernel size of (3, 7), a stride of (1, 2), and dilation factors of [1, 2, 4, 8] (one per group).
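Because a standard grouped convolution cannot assign a different dilation to each group, one way to realize such a layer is with four parallel convolutions over channel chunks, as in the sketch below; the channel counts, the padding, and the choice to dilate only the length-7 (strided) axis are assumptions.

    import torch
    import torch.nn as nn

    class MultiDilatedConv2d(nn.Module):
        # 4-group convolution where each channel group gets its own dilation,
        # implemented as four parallel nn.Conv2d layers.
        def __init__(self, in_ch=64, out_ch=64, dilations=(1, 2, 4, 8)):
            super().__init__()
            g_in, g_out = in_ch // len(dilations), out_ch // len(dilations)
            self.convs = nn.ModuleList([
                # Only the length-7 (strided) axis is dilated here so that all
                # four groups keep the same output size and can be concatenated.
                nn.Conv2d(g_in, g_out, kernel_size=(3, 7), stride=(1, 2),
                          dilation=(1, d), padding=(1, 3 * d))
                for d in dilations
            ])

        def forward(self, x):
            # x: (batch, channel, axis1, axis2); the second spatial axis is
            # downsampled by the stride of 2.
            chunks = x.chunk(len(self.convs), dim=1)
            return torch.cat([conv(c) for conv, c in zip(self.convs, chunks)], dim=1)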