A Beginner's Tutorial on Quantizing a Custom PyTorch Model

新火種    2023-09-26

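Source: Deephub Imba

This article is about 2,500 characters long; suggested reading time is 5 minutes.

This article uses CIFAR-10 and a custom AlexNet model.

PyTorch used to offer only one quantization method, called "eager mode quantization". Quantizing a custom model with it often produced strange errors that were hard to resolve. More recently, PyTorch released a method called "FX graph mode quantization". In this article we look at FX graph mode quantization to see whether it makes quantization easier and more stable.

This article uses CIFAR-10 and a custom AlexNet model, which I modified slightly for efficiency. Because both the model and the dataset are small, everything can run on a CPU.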

import os
import cv2
import time
import torch
import numpy as np
import torchvision
from PIL import Image
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms

device = "cpu"
print(device)

# Resize CIFAR-10 images to 224x224 and normalize each channel
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

batch_size = 8

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

def print_model_size(mdl):
    # Save the state dict to a temporary file and report its size on disk
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" % (os.path.getsize("tmp.pt") / 1e6))
    os.remove("tmp.pt")

The model code is below. AlexNet is used because it contains the basic layers we use day to day:

from torch.nn import init

class mAlexNet(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.input_channel = 3
        self.num_output = num_classes

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=self.input_channel, out_channels=16, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        init.xavier_uniform_(self.layer1[0].weight, gain=nn.init.calculate_gain('conv2d'))

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=20, kernel_size=5, stride=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        init.xavier_uniform_(self.layer2[0].weight, gain=nn.init.calculate_gain('conv2d'))

        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=20, out_channels=30, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        init.xavier_uniform_(self.layer3[0].weight, gain=nn.init.calculate_gain('conv2d'))

        self.layer4 = nn.Sequential(
            nn.Linear(30*3*3, out_features=48),
            nn.ReLU(inplace=True)
        )
        init.kaiming_normal_(self.layer4[0].weight, mode='fan_in', nonlinearity='relu')

        self.layer5 = nn.Sequential(
            nn.Linear(in_features=48, out_features=self.num_output)
        )
        init.kaiming_normal_(self.layer5[0].weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Flatten the feature maps but keep the batch dimension
        x = x.reshape(x.size(0), -1)
        x = self.layer4(x)
        logits = self.layer5(x)
        return logits

model = mAlexNet(num_classes=10).to(device)

Now let's run a quick training loop with the full-precision model to get a baseline:

import torch.optim as optim

def train_model(model):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(2):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics: average loss over the last 1000 steps
            running_loss += loss.item()
            if i % 1000 == 999:
                print(f'[Ep: {epoch + 1}, Step: {i + 1:5d}] loss: {running_loss / 1000:.3f}')
                running_loss = 0.0

    return model

model = train_model(model)
PATH = './float_model.pth'
torch.save(model.state_dict(), PATH)

You can see that the loss is decreasing. Since we are only demonstrating quantization, we train for just 2 epochs, and we use accuracy only for comparison.

I will do all three possible kinds of quantization:

Dynamic quantization: makes the weights integers (post-training)
Static quantization: makes the weights and activations integers (post-training)
Quantization aware training: trains the model with integer precision
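To make "making the weights integers" concrete, here is a minimal sketch of affine int8 quantization in plain Python. It is an illustration only, not PyTorch's implementation; the helper names `choose_qparams`, `quantize`, and `dequantize` and the sample weights are assumptions made up for this example.

```python
def choose_qparams(values, qmin=-128, qmax=127):
    """Pick a scale and zero point so that [min, max] maps onto [qmin, qmax]."""
    # Extend the range to include 0.0 so that zero is exactly representable.
    lo = min(min(values), 0.0)
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = int(round(qmin - lo / scale))
    return scale, zero_point


def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Map floats to integers in [qmin, qmax]."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(q_values, scale, zero_point):
    """Map integers back to approximate floats."""
    return [(q - zero_point) * scale for q in q_values]


# Hypothetical weights, chosen only for illustration
weights = [-0.62, -0.10, 0.0, 0.25, 0.91]
scale, zero_point = choose_qparams(weights)
q = quantize(weights, scale, zero_point)
restored = dequantize(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, max_error)
```

Each weight is stored in one byte instead of four, and the round trip loses at most about one quantization step (`scale`) per value.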

Let's start with dynamic quantization:

import copy

import torch
from torch.ao.quantization import (
    get_default_qconfig_mapping,
    get_default_qat_qconfig_mapping,
    QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx

# Load float model
model_fp = mAlexNet(num_classes=10).to(device)
model_fp.load_state_dict(torch.load("./float_model.pth", map_location=device))

# Copy model to quantize
model_to_quantize = copy.deepcopy(model_fp).to(device)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)

# a tuple of one or more example inputs is needed to trace the model
example_inputs = (next(iter(trainloader))[0],)

# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized_dynamic = quantize_fx.convert_fx(model_prepared)
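As a rough picture of what "dynamic" means, the sketch below quantizes the weights of a toy linear layer once, then derives the activation scale from each input at call time, which is why no calibration data is needed. It is plain Python and uses symmetric quantization for brevity; the class `DynamicQuantLinear`, the helper `quantize_symmetric`, and all numbers are hypothetical and do not mirror PyTorch's internal kernels.

```python
def quantize_symmetric(values, qmax=127):
    """Symmetric int8 quantization: zero point is 0, scale comes from max |value|."""
    scale = max(abs(v) for v in values) / qmax or 1.0
    return [max(-qmax, min(qmax, round(v / scale))) for v in values], scale


class DynamicQuantLinear:
    """Toy linear layer: weights are quantized once, activations on every call."""

    def __init__(self, weight, bias):
        # weight is a list of rows (out_features x in_features)
        flat = [w for row in weight for w in row]
        q_flat, self.w_scale = quantize_symmetric(flat)
        n = len(weight[0])
        self.q_weight = [q_flat[i:i + n] for i in range(0, len(q_flat), n)]
        self.bias = bias  # bias stays in floating point

    def __call__(self, x):
        # The activation scale is computed from this very input: no calibration step
        q_x, x_scale = quantize_symmetric(x)
        out = []
        for q_row, b in zip(self.q_weight, self.bias):
            acc = sum(qw * qx for qw, qx in zip(q_row, q_x))  # integer accumulation
            out.append(acc * self.w_scale * x_scale + b)      # rescale to float
        return out


# Hypothetical layer and input, for illustration only
weight = [[0.5, -0.25, 0.1], [-0.3, 0.8, 0.05]]
bias = [0.1, -0.2]
x = [1.0, 2.0, -1.5]

layer = DynamicQuantLinear(weight, bias)
y_quant = layer(x)
y_float = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]
print(y_quant, y_float)
```

The quantized output stays close to the float output, while the stored weights take a quarter of the space.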

With prepare_fx and convert_fx, dynamic quantization needs only an example input to prepare the model and no calibration pass, so the code is very short. Let's compare the models:

print_model_size(model)
print_model_size(model_quantized_dynamic)

The model shrinks by 0.03 MB, to about 75% of its original size. Static quantization can make it smaller still:

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()

# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)

# calibrate on 20 batches of training data
with torch.no_grad():
    for i, (batch, _) in enumerate(trainloader):
        if i >= 20:
            break
        model_prepared(batch.to(device))

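Static quantization is very similar to dynamic quantization. The only extra step is to pass several batches of data through the prepared model so that it can be calibrated well.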
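Calibration can be pictured as observers recording the range of values that flow through the model, from which a fixed scale and zero point are derived afterwards. The sketch below is a simplified stand-in in plain Python; `ToyMinMaxObserver` is a hypothetical class written for this example, and the random batches merely imitate activations.

```python
import random


class ToyMinMaxObserver:
    """Remembers the smallest and largest value seen across all batches."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def calculate_qparams(self, qmin=0, qmax=255):
        # Extend the range to include 0.0 so that zero is exactly representable
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0
        zero_point = int(round(qmin - lo / scale))
        return scale, zero_point


random.seed(0)
observer = ToyMinMaxObserver()
for _ in range(20):  # 20 calibration batches, as in the calibration loop
    batch = [random.gauss(0.0, 1.0) for _ in range(8)]  # stand-in activations
    observer.observe(batch)

scale, zero_point = observer.calculate_qparams()
print(observer.min_val, observer.max_val, scale, zero_point)
```

More calibration batches give the observer a better estimate of the true activation range, which is why static quantization benefits from seeing representative data.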

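Let's see how these steps affect the model:

As you can see, the program does a lot of work for us, so we can focus on functionality rather than on implementation details. With this preparation done, we can perform the final quantization: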

# quantize
model_quantized_static = quantize_fx.convert_fx(model_prepared)

After quantization, model_quantized_static looks like this:

Now it is clearer: the Conv2d and ReLU layers have been fused and replaced with their quantized counterparts, which have been calibrated. We can compare these models with the original one:

print_model_size(model)
print_model_size(model_quantized_dynamic)
print_model_size(model_quantized_static)

The quantized model is about 3 times smaller than the original, which matters a great deal for large models.
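The size differences can be roughly sanity-checked with back-of-the-envelope arithmetic on the mAlexNet layer shapes. The sketch below assumes 4 bytes per float32 value and 1 byte per int8 weight, that biases stay in float32, and that dynamic quantization converts only the Linear layers while static quantization also converts the convolutions. These are simplifying assumptions; real saved files are larger because of serialization overhead and stored quantization parameters.

```python
# (name, number of weights, number of biases, is_linear) for each mAlexNet layer
layers = [
    ("layer1.conv", 3 * 16 * 11 * 11, 16, False),
    ("layer2.conv", 16 * 20 * 5 * 5, 20, False),
    ("layer3.conv", 20 * 30 * 3 * 3, 30, False),
    ("layer4.linear", 30 * 3 * 3 * 48, 48, True),
    ("layer5.linear", 48 * 10, 10, True),
]


def estimate_bytes(quantize_conv, quantize_linear):
    """Estimate parameter storage given which layer types hold int8 weights."""
    total = 0
    for _, n_weights, n_biases, is_linear in layers:
        quantized = quantize_linear if is_linear else quantize_conv
        total += n_weights * (1 if quantized else 4)  # int8 vs float32 weights
        total += n_biases * 4                         # assumption: biases stay float32
    return total


float_bytes = estimate_bytes(False, False)
dynamic_bytes = estimate_bytes(False, True)  # assumption: only Linear layers quantized
static_bytes = estimate_bytes(True, True)    # assumption: conv and Linear quantized

for name, size in [("float32", float_bytes), ("dynamic", dynamic_bytes), ("static", static_bytes)]:
    print(f"{name:8s} ~{size / 1e6:.3f} MB ({size / float_bytes:.0%} of float32)")
```

Under these assumptions, most of the savings in the dynamic case come from the first Linear layer, while static quantization also shrinks the convolution weights, which is consistent with the ordering of the measured sizes.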

Now let's see how to train a model with quantization in the loop. Quantization aware training inserts quantization operations during training. The code is as follows:

model_to_quantize = mAlexNet(num_classes=10).to(device)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()

# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)

# training loop
model_trained_prepared = train_model(model_prepared)

# quantize
model_quantized_trained = quantize_fx.convert_fx(model_trained_prepared)
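The quantization operations inserted during training are usually described as "fake quantization": values are rounded to the integer grid and immediately converted back to float, so the network learns under the rounding and clamping error it will face after conversion. The sketch below shows only that forward computation in plain Python; `fake_quantize`, the scale, and the weights are hypothetical, and the backward pass (commonly a straight-through estimator that treats the rounding as identity) is not modeled.

```python
def fake_quantize(values, scale, zero_point=0, qmin=-128, qmax=127):
    """Quantize then immediately dequantize: floats that lie on the int8 grid."""
    out = []
    for v in values:
        q = max(qmin, min(qmax, round(v / scale) + zero_point))  # round and clamp
        out.append((q - zero_point) * scale)                     # back to float
    return out


# Hypothetical weights and scale, for illustration only
weights = [0.337, -0.912, 0.05, 1.7]
scale = 0.01
w_fq = fake_quantize(weights, scale)
print(w_fq)
```

Values inside the representable range move to the nearest grid point, while values outside it (the last weight here) are clamped to the largest representable value.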

Let's compare the sizes of all the models so far.

print("Regular floating point model: ")
print_model_size(model_fp)
print("Weights only quantization: ")
print_model_size(model_quantized_dynamic)
print("Weights/activations quantization: ")
print_model_size(model_quantized_static)
print("Quantization aware trained: ")
print_model_size(model_quantized_trained)

Compared with static quantization, quantization aware training has no effect on model size, but can it improve accuracy?

def get_accuracy(model):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # the class with the highest score is the prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total

fp_model_acc = get_accuracy(model)
dy_model_acc = get_accuracy(model_quantized_dynamic)
static_model_acc = get_accuracy(model_quantized_static)
q_trained_model_acc = get_accuracy(model_quantized_trained)

print("Acc on fp_model:", fp_model_acc)
print("Acc weights only quantization:", dy_model_acc)
print("Acc weights/activations quantization:", static_model_acc)
print("Acc on quantization aware trained model:", q_trained_model_acc)

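To make the comparison easier, let's visualize the results:

The base model and the quantized models have similar accuracy, but the model size is greatly reduced, which is crucial when we want to deploy to a server or a low-power device.

Finally, some references: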

https://pytorch.org/docs/stable/quantization.html

Code for this article:

/uploads/pic/20230926/hxd0c4yyfk0
