emb|从视频到音频:使用VIT进行音频分类( 二 )


       return len(self.samples)
   
   def __getitem__(self idx):
       fp target = self.samples[idx

       img = Image.open(fp)
       if self.transform:
           img = self.transform(img)
       return img target

train_dataset = AudioDataset(root transform=transforms.Compose([
   transforms.Resize((480 480))
   transforms.ToTensor()
   transforms.Normalize((0.5 0.5 0.5) (0.5 0.5 0.5))

))
ViT模型我们将利用ViT来作为我们的模型:Vision Transformer在论文中首次介绍了一幅图像等于16x16个单词 , 并成功地展示了这种方式不依赖任何的cnn , 直接应用于图像Patches序列的纯Transformer可以很好地执行图像分类任务 。

将图像分割成Patches , 并将这些Patches的线性嵌入序列作为Transformer的输入 。 Patches的处理方式与NLP应用程序中的标记(单词)是相同的 。
由于缺乏CNN固有的归纳偏差(如局部性) , Transformer在训练数据量不足时不能很好地泛化 。 但是当在大型数据集上训练时 , 它确实在多个图像识别基准上达到或击败了最先进的水平 。
实现的结构如下所示:
class ViT(nn.Sequential):
   def __init__(self    
               in_channels: int = 3
               patch_size: int = 16
               emb_size: int = 768
               img_size: int = 356
               depth: int = 12
               n_classes: int = 1000
               **kwargs):
       super().__init__(
           PatchEmbedding(in_channels patch_size emb_size img_size)
           TransformerEncoder(depth emb_size=emb_size **kwargs)
           ClassificationHead(emb_size n_classes)
       )
训练训练循环也是传统的训练过程:
vit = ViT(
   n_classes = len(train_dataset.classes)
)

vit.to(device)

# train
train_loader = DataLoader(train_dataset batch_size=32 shuffle=True)
optimizer = optim.Adam(vit.parameters() lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer 'max' factor=0.3 patience=3 verbose=True)
criterion = nn.CrossEntropyLoss()
num_epochs = 30

for epoch in range(num_epochs):
   print('Epoch {/{'.format(epoch num_epochs - 1))
   print('-' * 10)

   vit.train()

   running_loss = 0.0
   running_corrects = 0

   for inputs labels in tqdm.tqdm(train_loader):
       inputs = inputs.to(device)
       labels = labels.to(device)

       optimizer.zero_grad()

       with torch.set_grad_enabled(True):
           outputs = vit(inputs)
           loss = criterion(outputs labels)

           _ preds = torch.max(outputs 1)
           loss.backward()
           optimizer.step()

       running_loss += loss.item() * inputs.size(0)
       running_corrects += torch.sum(preds == labels.data)

   epoch_loss = running_loss / len(train_dataset)
   epoch_acc = running_corrects.double() / len(train_dataset)
   scheduler.step(epoch_acc)