Deep Learning Paradigms in Music Information Retrieval
Explore the neural revolution in MIR: from convolutional networks processing spectrograms to attention mechanisms understanding musical structure. Learn how modern architectures are transforming music analysis.
The Deep Learning Revolution in MIR
Deep learning has fundamentally transformed Music Information Retrieval, achieving state-of-the-art results across virtually every task. From raw audio waveforms to symbolic representations, neural networks have learned to extract features and patterns that eluded traditional approaches for decades.
- End-to-end learning: learn features directly from data
- Hierarchical representations: capture patterns at multiple scales
- Transfer learning: leverage pre-trained models
- Scalability: performance improves with more data
Convolutional Neural Networks for Spectrograms
CNNs treat spectrograms as images, exploiting local patterns in time-frequency representations. This approach has been remarkably successful for tasks like genre classification, instrument recognition, and onset detection.
Key architectural considerations for music spectrograms:
Frequency-Aware Convolutions
Use different filter shapes for the frequency and time dimensions: tall filters that span many frequency bins (such as the 12×1 kernels used below) typically capture harmonic structure, while wide filters along the time axis capture rhythmic and temporal patterns.
Multi-Scale Processing
Parallel convolutions at different scales capture various musical elements
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np

class MusicCNN(nn.Module):
    def __init__(self, n_classes=10, sample_rate=22050):
        super(MusicCNN, self).__init__()
        self.sample_rate = sample_rate

        # Spectrogram computation
        self.melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=2048,
            hop_length=512,
            n_mels=128,
            f_min=0,
            f_max=sample_rate // 2
        )

        # Convolutional blocks with different receptive fields
        # Block 1: capture local patterns
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        # Block 2a: capture harmonic patterns (tall filters along frequency).
        # padding='same' keeps this branch's output aligned with the temporal
        # branch so the two can be concatenated.
        self.conv2_harmonic = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(12, 1), padding='same'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # Block 2b: capture temporal patterns (wide filters along time)
        self.conv2_temporal = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 7), padding=(0, 3)),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # Block 3: higher-level features on the concatenated branches
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        # Global pooling and classification
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(256, n_classes)

        # Attention mechanism for temporal importance
        # (defined as an extension point; not applied in the forward pass below)
        self.attention = nn.Sequential(
            nn.Linear(256, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # x shape: (batch, samples)

        # Compute mel-spectrogram
        x = self.melspec(x)        # (batch, mel_bins, time)
        x = torch.log(x + 1e-9)    # log scaling for dynamic range

        # Add channel dimension
        x = x.unsqueeze(1)         # (batch, 1, mel_bins, time)

        # Convolutional feature extraction
        x = self.conv1(x)

        # Multi-scale processing: harmonic and temporal branches in parallel
        x_harmonic = self.conv2_harmonic(x)
        x_temporal = self.conv2_temporal(x)
        x = torch.cat([x_harmonic, x_temporal], dim=1)

        x = self.conv3(x)

        # Global average pooling
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)

        # Classification
        x = self.dropout(x)
        x = self.fc(x)

        return x

    def extract_features(self, x):
        """Extract intermediate features for visualization."""
        features = {}

        x = self.melspec(x)
        features['spectrogram'] = x.clone()

        x = torch.log(x + 1e-9)
        x = x.unsqueeze(1)

        x = self.conv1(x)
        features['conv1'] = x.clone()

        x_harmonic = self.conv2_harmonic(x)
        x_temporal = self.conv2_temporal(x)
        features['harmonic'] = x_harmonic.clone()
        features['temporal'] = x_temporal.clone()

        x = torch.cat([x_harmonic, x_temporal], dim=1)
        x = self.conv3(x)
        features['conv3'] = x.clone()

        return features

# Advanced: Harmonic CNN with learnable filterbanks
class HarmonicCNN(nn.Module):
    def __init__(self, n_harmonics=6, n_classes=10):
        super(HarmonicCNN, self).__init__()

        # Learnable frequency templates, one per harmonic.
        # Each is a (128 x 1) filter applied across the full frequency axis
        # of a (batch, 1, 128, time) log-mel input.
        self.harmonic_filters = nn.Parameter(
            torch.randn(n_harmonics, 1, 128, 1)
        )

        # Process each harmonic stream. After the frequency filter the height
        # collapses to 1, so convolution and pooling act along time only.
        self.harmonic_conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(1, 3), padding=(0, 1)),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(1, 2))
            ) for _ in range(n_harmonics)
        ])

        # Combine harmonics
        self.combine = nn.Sequential(
            nn.Conv2d(32 * n_harmonics, 128, kernel_size=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        # x shape: (batch, 1, 128, time) log-mel spectrogram
        harmonic_specs = []
        for i, h_filter in enumerate(self.harmonic_filters):
            # Apply the i-th learned frequency template
            h_spec = F.conv2d(x, h_filter.unsqueeze(0))   # (batch, 1, 1, time)
            h_features = self.harmonic_conv[i](h_spec)
            harmonic_specs.append(h_features)

        # Combine all harmonic streams
        x = torch.cat(harmonic_specs, dim=1)
        x = self.combine(x)
        x = self.classifier(x)

        return x
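A quick way to sanity-check the MusicCNN sketch is to push a dummy batch of waveforms through it; the clip length, batch size, and class count below are arbitrary assumptions:

model = MusicCNN(n_classes=10)
waveforms = torch.randn(4, 3 * 22050)    # four 3-second mono clips at 22.05 kHz
logits = model(waveforms)
print(logits.shape)                      # torch.Size([4, 10])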
Recurrent Networks for Sequential Modeling
RNNs, LSTMs, and GRUs excel at modeling temporal dependencies in music, from note sequences to long-term structure.
LSTM cells maintain long-term memory through gating mechanisms:
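As a reminder of what those gates compute, the standard LSTM cell update is:

$$f_t = \sigma(W_f[h_{t-1}, x_t] + b_f), \quad i_t = \sigma(W_i[h_{t-1}, x_t] + b_i), \quad o_t = \sigma(W_o[h_{t-1}, x_t] + b_o)$$
$$\tilde{c}_t = \tanh(W_c[h_{t-1}, x_t] + b_c), \quad c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \quad h_t = o_t \odot \tanh(c_t)$$

The forget gate $f_t$ and input gate $i_t$ control how much of the previous cell state to keep and how much new information to write, which is what lets the cell track long-range musical context such as key or phrase position.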
import torch
import torch.nn as nn
import torch.nn.functional as F

class MusicTranscriptionRNN(nn.Module):
    def __init__(self, input_dim=88, hidden_dim=256, n_layers=3):
        super(MusicTranscriptionRNN, self).__init__()

        # Feature extraction
        self.feature_conv = nn.Sequential(
            nn.Conv1d(input_dim, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )

        # Self-attention over time
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=8,
            dropout=0.1
        )

        # Output layers for multi-task learning
        self.pitch_detector = nn.Linear(hidden_dim * 2, 88)       # Piano keys
        self.onset_detector = nn.Linear(hidden_dim * 2, 1)        # Onset probability
        self.velocity_estimator = nn.Linear(hidden_dim * 2, 1)    # Velocity

    def forward(self, x):
        # x shape: (batch, time, features)

        # Transpose for conv1d
        x = x.transpose(1, 2)    # (batch, features, time)
        x = self.feature_conv(x)
        x = x.transpose(1, 2)    # (batch, time, features)

        # LSTM processing
        lstm_out, (h_n, c_n) = self.lstm(x)

        # Self-attention (MultiheadAttention expects (time, batch, features) by default)
        attn_out, attn_weights = self.attention(
            lstm_out.transpose(0, 1),
            lstm_out.transpose(0, 1),
            lstm_out.transpose(0, 1)
        )
        attn_out = attn_out.transpose(0, 1)

        # Residual connection
        out = lstm_out + attn_out

        # Multi-task outputs
        pitches = torch.sigmoid(self.pitch_detector(out))
        onsets = torch.sigmoid(self.onset_detector(out))
        velocities = torch.sigmoid(self.velocity_estimator(out)) * 127

        return {
            'pitches': pitches,
            'onsets': onsets,
            'velocities': velocities,
            'attention': attn_weights
        }

# GRU-based model for real-time processing
class RealTimeMusicRNN(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=128):
        super(RealTimeMusicRNN, self).__init__()

        # Lightweight GRU for low latency
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True
        )

        # Causal convolution for online processing
        self.causal_conv = nn.Conv1d(
            hidden_dim, hidden_dim,
            kernel_size=3,
            padding=2,    # extra padding; the future-looking side is trimmed below
            dilation=1
        )

        self.output = nn.Linear(hidden_dim, 12)    # Chroma output

    def forward(self, x, hidden=None):
        # Process streaming input
        out, hidden = self.gru(x, hidden)

        # Causal convolution (mask future)
        out_conv = out.transpose(1, 2)
        out_conv = self.causal_conv(out_conv)
        out_conv = out_conv[:, :, :-2]    # remove the padding that looks into the future
        out_conv = out_conv.transpose(1, 2)

        # Combine
        out = out + out_conv
        chroma = torch.sigmoid(self.output(out))

        return chroma, hidden
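RealTimeMusicRNN is meant to be called chunk by chunk while carrying the GRU hidden state forward; a toy streaming loop (the chunk size and feature dimension are arbitrary placeholders):

model = RealTimeMusicRNN(input_dim=128)
hidden = None

for _ in range(10):                          # simulated stream of feature chunks
    chunk = torch.randn(1, 16, 128)          # (batch, frames, features)
    chroma, hidden = model(chunk, hidden)    # hidden carries context between chunks
    # chroma: (1, 16, 12) per-frame chroma activations for this chunk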
Transformer Architecture in MIR
Transformers have revolutionized sequence modeling with self-attention mechanisms that capture long-range dependencies without recurrence.
The core of the transformer is scaled dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where Q, K, and V are query, key, and value matrices derived from the input, and $d_k$ is the key dimension.
Positional Encoding for Music:
import torch
import torch.nn as nn
import math

class MusicTransformer(nn.Module):
    def __init__(self, d_model=512, n_heads=8, n_layers=6,
                 max_seq_len=2048, vocab_size=388):
        super(MusicTransformer, self).__init__()

        self.d_model = d_model

        # Token embeddings (pitch, duration, velocity, etc.)
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)

        # Relative positional encoding for music
        # (defined as an extension point; not wired into the encoder layers here)
        self.relative_pos = RelativePositionalEncoding(d_model, max_seq_len)

        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            activation='gelu'
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers
        )

        # Output heads for different musical attributes
        self.pitch_head = nn.Linear(d_model, 128)      # MIDI pitches
        self.duration_head = nn.Linear(d_model, 32)    # Quantized durations
        self.dynamics_head = nn.Linear(d_model, 8)     # Dynamic levels

        # Style embedding for conditional generation
        self.style_embedding = nn.Embedding(10, d_model)    # 10 styles

    def forward(self, x, style=None, mask=None):
        # x shape: (batch, seq_len) of integer tokens

        # Embed tokens
        x = self.token_embedding(x) * math.sqrt(self.d_model)

        # Add positional encoding
        x = self.pos_encoding(x)

        # Add style conditioning if provided
        if style is not None:
            style_emb = self.style_embedding(style).unsqueeze(1)
            x = x + style_emb

        # Transformer encoding (nn.TransformerEncoder expects (seq_len, batch, d_model)
        # unless batch_first=True is set)
        x = x.transpose(0, 1)
        x = self.transformer(x, mask=mask)
        x = x.transpose(0, 1)    # back to (batch, seq_len, d_model)

        # Multi-task outputs
        pitches = self.pitch_head(x)
        durations = self.duration_head(x)
        dynamics = self.dynamics_head(x)

        return {
            'pitches': pitches,
            'durations': durations,
            'dynamics': dynamics,
            'hidden_states': x    # exposed for downstream models such as MusicBERT
        }

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class RelativePositionalEncoding(nn.Module):
    """Music-aware relative position encoding"""
    def __init__(self, d_model, max_len):
        super(RelativePositionalEncoding, self).__init__()

        # Learnable relative position embeddings
        self.rel_pos_emb = nn.Parameter(
            torch.randn(2 * max_len - 1, d_model)
        )

    def forward(self, q, k):
        # Compute relative offsets between query and key time steps
        seq_len = q.size(1)
        pos = torch.arange(seq_len).unsqueeze(1) - torch.arange(seq_len).unsqueeze(0)
        pos = pos + seq_len - 1    # shift to non-negative indices

        # Look up one embedding per (query, key) offset
        rel_emb = self.rel_pos_emb[pos]

        return rel_emb

# Music BERT for representation learning
class MusicBERT(nn.Module):
    def __init__(self, d_model=768, n_heads=12, n_layers=12):
        super(MusicBERT, self).__init__()

        self.transformer = MusicTransformer(
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers
        )

        # Masked language modeling head
        self.mlm_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 388)    # Vocabulary size
        )

        # Next "sentence" (musical phrase) prediction
        self.nsp_head = nn.Linear(d_model, 2)

    def forward(self, x, mask_positions=None):
        # Get transformer representations
        outputs = self.transformer(x)
        hidden_states = outputs['hidden_states']

        # MLM predictions
        if mask_positions is not None:
            masked_hidden = hidden_states[mask_positions]
            mlm_logits = self.mlm_head(masked_hidden)
        else:
            mlm_logits = self.mlm_head(hidden_states)

        # NSP prediction using the first ([CLS]-style) token
        cls_hidden = hidden_states[:, 0]
        nsp_logits = self.nsp_head(cls_hidden)

        return {
            'mlm_logits': mlm_logits,
            'nsp_logits': nsp_logits,
            'hidden_states': hidden_states
        }
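Passing a batch of token sequences through the MusicTransformer above; the token values and style IDs below are random placeholders within the assumed 388-symbol vocabulary:

model = MusicTransformer()
tokens = torch.randint(0, 388, (2, 256))    # (batch, seq_len) of event tokens
style = torch.tensor([3, 7])                # optional per-sequence style IDs
out = model(tokens, style=style)
print(out['pitches'].shape)                 # torch.Size([2, 256, 128])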
Variational Autoencoders for Music
VAEs learn compressed representations of music that can be manipulated for generation and style transfer.
VAEs optimize the evidence lower bound (ELBO) on the data likelihood:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z)\big] - \beta\, D_{\mathrm{KL}}\big(q_\phi(z|x)\,\|\,p(z)\big)$$

The first term rewards faithful reconstruction; the KL term regularizes the approximate posterior toward the prior (setting β > 1 gives the β-VAE objective used for disentanglement).
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MusicVAE(nn.Module):
    def __init__(self, input_dim=88, latent_dim=256, hidden_dim=512):
        super(MusicVAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Latent space
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        # Hierarchical "conductor" for long sequences (as in hierarchical VAEs);
        # defined as an extension point but not used in the flat forward pass below
        self.conductor = nn.LSTM(latent_dim, latent_dim // 2, 2)

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return torch.sigmoid(self.decoder(z))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

    def loss_function(self, recon, x, mu, logvar, beta=1.0):
        # Reconstruction loss
        recon_loss = F.binary_cross_entropy(recon, x, reduction='sum')

        # KL divergence to the standard normal prior
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # beta > 1 gives the beta-VAE objective for disentanglement
        return recon_loss + beta * kl_div

    def interpolate(self, x1, x2, steps=10):
        """Interpolate between two music pieces in latent space."""
        mu1, _ = self.encode(x1)
        mu2, _ = self.encode(x2)

        interpolations = []
        for alpha in np.linspace(0, 1, steps):
            z = (1 - alpha) * mu1 + alpha * mu2
            recon = self.decode(z)
            interpolations.append(recon)

        return torch.stack(interpolations)
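A minimal training step for the VAE above, assuming x is a batch of piano-roll frames with values in [0, 1] (required by the binary cross-entropy reconstruction loss); the batch below is a random placeholder:

vae = MusicVAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

x = torch.rand(32, 88)                                      # placeholder batch of 88-key frames
recon, mu, logvar = vae(x)
loss = vae.loss_function(recon, x, mu, logvar, beta=2.0)    # beta > 1: beta-VAE weighting

optimizer.zero_grad()
loss.backward()
optimizer.step()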
Graph Neural Networks for Musical Structure
GNNs model relationships between musical elements, from note-to-note connections to large-scale structural patterns.
Music as a graph with nodes and edges:
- Nodes: Notes, chords, or measures
- Edges: Temporal, harmonic, or voice-leading relationships
- Features: Pitch, duration, dynamics
Message passing updates node representations:
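For the GCN layers used below, the propagation rule averages transformed features over each node's neighborhood (including a self-loop):

$$h_i^{(l+1)} = \sigma\!\left(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\hat{d}_i \hat{d}_j}}\, W^{(l)} h_j^{(l)}\right)$$

where $\hat{d}_i$ is the degree of node $i$ with self-loops counted.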
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class ChordProgressionGNN(nn.Module):
    def __init__(self, node_features=12, hidden_dim=128, n_classes=7):
        super(ChordProgressionGNN, self).__init__()

        # Graph convolution layers
        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Edge type embedding for the different relationship types
        # (not consumed by the plain GCN layers here; an extension point for
        # relational variants)
        self.edge_embedding = nn.Embedding(5, hidden_dim)

        # Attention for node importance
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        # Classification head
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def build_chord_graph(self, chord_sequence):
        """Convert a chord sequence to node features, edges, and edge types.

        Assumes each chord object exposes a 12-dimensional `pitch_classes`
        profile and an integer `root` (0-11).
        """
        nodes = []
        edges = []
        edge_types = []

        for i, chord in enumerate(chord_sequence):
            # Node features: pitch-class profile
            nodes.append(chord.pitch_classes)

            # Temporal edge from the previous chord
            if i > 0:
                edges.append([i - 1, i])
                edge_types.append(0)

            # Harmonic edges (circle-of-fifths relationships)
            for j, other in enumerate(chord_sequence):
                if i != j:
                    interval = (chord.root - other.root) % 12
                    if interval == 7:      # perfect fifth
                        edges.append([i, j])
                        edge_types.append(1)
                    elif interval == 5:    # perfect fourth
                        edges.append([i, j])
                        edge_types.append(2)

        x = torch.tensor(nodes, dtype=torch.float)
        edge_index = torch.tensor(edges, dtype=torch.long).t()    # shape (2, n_edges)
        edge_type = torch.tensor(edge_types, dtype=torch.long)
        return x, edge_index, edge_type

    def forward(self, x, edge_index, batch=None):
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)

        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)

        x = self.conv3(x, edge_index)

        # Attention weights over nodes
        attn_weights = self.attention(x)
        attn_weights = F.softmax(attn_weights, dim=0)

        # Weighted pooling over nodes
        if batch is not None:
            x = global_mean_pool(x * attn_weights, batch)
        else:
            x = (x * attn_weights).mean(dim=0, keepdim=True)

        # Classification
        return self.classifier(x)
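Running the model on a toy graph via torch_geometric's Data container; the node features and the chain of temporal edges below are placeholders:

from torch_geometric.data import Data

x = torch.rand(8, 12)                                   # 8 chords, 12-d pitch-class profiles
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6],
                           [1, 2, 3, 4, 5, 6, 7]])      # simple temporal chain
data = Data(x=x, edge_index=edge_index)

model = ChordProgressionGNN()
logits = model(data.x, data.edge_index)                 # (1, n_classes)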
Contrastive Learning for Music Representations
Self-supervised techniques learn representations without labels by contrasting positive and negative examples.
Contrastive methods learn by maximizing agreement between two augmented views of the same clip while pushing apart views of different clips. A standard objective is the InfoNCE (NT-Xent) loss:

$$\mathcal{L}_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k \neq i} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

where $z_i$ and $z_j$ are embeddings of the two views of the same excerpt, sim is cosine similarity, and $\tau$ is a temperature.
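A minimal sketch of SimCLR-style contrastive training for music clips, assuming some clip-level encoder that maps audio (or spectrograms) to fixed-size feature vectors; the class names, dimensions, and the `augment` function mentioned in the usage comment are illustrative, not from a specific library:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveMusicModel(nn.Module):
    """Encoder plus projection head for contrastive training on audio clips."""
    def __init__(self, encoder, feat_dim=256, proj_dim=128):
        super().__init__()
        self.encoder = encoder                 # any clip-level feature extractor
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, proj_dim)
        )

    def forward(self, x):
        return self.projector(self.encoder(x))

def nt_xent_loss(z1, z2, temperature=0.1):
    """NT-Xent loss over a batch of paired views z1[i] <-> z2[i]."""
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)    # (2N, D), unit norm
    sim = z @ z.t() / temperature                         # cosine similarities as logits
    n = z1.size(0)

    # Exclude self-similarity on the diagonal
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float('-inf'))

    # The positive for index i is its augmented twin at i+n (and vice versa)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)

# Usage sketch: two augmented views of the same batch of clips
# view1, view2 = augment(batch), augment(batch)   # e.g., pitch shift, masking
# loss = nt_xent_loss(model(view1), model(view2))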
Best Practices and Tips
Data Augmentation
Pitch shifting, time stretching, and mixing are crucial for robust models. Consider music-specific augmentations like key transposition.
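A sketch of a simple augmentation pipeline built from torchaudio transforms (PitchShift operates on waveforms; FrequencyMasking and TimeMasking apply SpecAugment-style masking to spectrograms). The probabilities and mask sizes below are arbitrary choices:

import random
import torch
import torchaudio

sample_rate = 22050

# Waveform-level: random transposition by a few semitones
pitch_shifts = [torchaudio.transforms.PitchShift(sample_rate, n_steps=s)
                for s in (-2, -1, 1, 2)]

# Spectrogram-level: SpecAugment-style masking
freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)
time_mask = torchaudio.transforms.TimeMasking(time_mask_param=30)

def augment_waveform(wav):
    """Randomly transpose the clip; wav shape (..., samples)."""
    if random.random() < 0.5:
        wav = random.choice(pitch_shifts)(wav)
    return wav

def augment_spectrogram(spec):
    """Mask a random frequency band and time span; spec shape (..., freq, time)."""
    return time_mask(freq_mask(spec))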
Multi-Task Learning
Training on multiple related tasks (pitch, onset, dynamics) improves overall performance through shared representations.
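For instance, the MusicTranscriptionRNN defined earlier already emits pitch, onset, and velocity predictions; a weighted sum of per-task losses is a simple way to train them jointly. The task weights below are placeholders to be tuned:

import torch.nn.functional as F

def transcription_loss(outputs, targets, w_pitch=1.0, w_onset=1.0, w_vel=0.1):
    """Combine the per-task losses of MusicTranscriptionRNN into one scalar."""
    pitch_loss = F.binary_cross_entropy(outputs['pitches'], targets['pitches'])
    onset_loss = F.binary_cross_entropy(outputs['onsets'], targets['onsets'])
    # Velocities are regressed on a 0-127 scale
    vel_loss = F.mse_loss(outputs['velocities'], targets['velocities'])
    return w_pitch * pitch_loss + w_onset * onset_loss + w_vel * vel_loss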
Transfer Learning
Pre-trained models from speech or general audio often provide excellent initialization for music tasks.
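One common recipe, sketched here: load a backbone pre-trained on a larger tagging task, freeze its convolutional layers, and retrain only the classification head. The checkpoint path and the reuse of MusicCNN as the backbone are hypothetical:

import torch
import torch.nn as nn

# Hypothetical checkpoint pre-trained on a large tagging dataset
pretrained = MusicCNN(n_classes=50)
pretrained.load_state_dict(torch.load('pretrained_tagging_model.pt'))  # hypothetical path

# Freeze the feature extractor
for param in pretrained.parameters():
    param.requires_grad = False

# Replace the classification head for the new task (e.g., 10 genres)
pretrained.fc = nn.Linear(256, 10)    # the new head trains from scratch

# Only the new head's parameters are handed to the optimizer
optimizer = torch.optim.Adam(pretrained.fc.parameters(), lr=1e-3)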
Key Takeaways
- CNNs excel at spectral patterns: Treat spectrograms as images but respect frequency-time asymmetry.
- RNNs model temporal dependencies: LSTMs and GRUs capture musical sequences and long-term structure.
- Transformers revolutionize sequence modeling: Self-attention captures global dependencies without recurrence.
- Architecture matters: Music-specific design choices (harmonic filters, relative position) improve performance.