How to Evaluate Model Complexity in PyTorch

2024-03-15 11:40:50

Introduction

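FLOPS (Floating-Point Operations per Second)

The number of floating-point operations a device performs per second. It measures the computing speed of a device, so it is mainly a hardware performance metric.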

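GFLOPS (Giga-FLOPS)

One billion (10^9) floating-point operations per second; this is also a hardware performance metric. (Written with a lowercase "s", GFLOPs usually means 10^9 FLOPs, i.e. an operation count rather than a rate.)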
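As a rough illustration of a hardware FLOPS figure, a GPU's theoretical peak FP32 throughput is commonly estimated as cores × clock × 2, because one fused multiply-add (FMA) counts as two floating-point operations. The sketch below plugs in the RTX 3090's publicly listed figures (10496 CUDA cores, about 1.695 GHz boost clock); treat these numbers as assumptions to check against your own card's spec sheet. `peak_flops` is a hypothetical helper, and real workloads rarely reach this peak.

```python
def peak_flops(cores: int, clock_hz: float, flops_per_core_per_cycle: int = 2) -> float:
    """Theoretical peak FLOPS, assuming every core issues one FMA (2 FLOPs) per cycle."""
    return cores * clock_hz * flops_per_core_per_cycle


# Assumed RTX 3090 figures: 10496 CUDA cores at a ~1.695 GHz boost clock.
rtx3090 = peak_flops(cores=10496, clock_hz=1.695e9)
print(f"{rtx3090:.3e} FLOPS = {rtx3090 / 1e9:.0f} GFLOPS")  # 3.558e+13 FLOPS = 35581 GFLOPS
```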

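FLOPs (Floating-Point Operations)

The total number of floating-point operations performed by a task or algorithm. It measures the amount of computation, so it is mainly used to gauge the complexity of a model or algorithm.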
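For a single layer, FLOPs can be counted by hand. Below is a minimal sketch for a standard 2-D convolution, assuming no dilation and counting one multiply-accumulate (MAC) as two FLOPs; `conv_out_size` and `conv2d_cost` are hypothetical helpers. It also shows why profilers that report MACs, such as thop, give numbers roughly half the FLOP count.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis (dilation assumed to be 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_cost(c_in, c_out, h, w, k, stride=1, padding=0, groups=1, bias=True):
    """Return (params, MACs, FLOPs) for one Conv2d layer with a k x k kernel."""
    h_out = conv_out_size(h, k, stride, padding)
    w_out = conv_out_size(w, k, stride, padding)
    # Each output element needs (c_in / groups) * k * k multiply-accumulates.
    macs_per_output = (c_in // groups) * k * k
    macs = c_out * h_out * w_out * macs_per_output
    params = c_out * macs_per_output + (c_out if bias else 0)
    # One MAC = one multiply + one add = 2 FLOPs (bias additions ignored here).
    flops = 2 * macs
    return params, macs, flops


# 3x3 conv, 3 -> 16 channels, on a 2160 x 3840 (H x W) image with "same" padding.
params, macs, flops = conv2d_cost(3, 16, 2160, 3840, k=3, padding=1)
print(params, macs, flops)  # 448 3583180800 7166361600
```

Even this single small layer costs about 7 GFLOPs per frame at 4K, which is why full-resolution models quickly reach trillions of operations.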

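Evaluating model complexity

The script below uses thop to count a model's parameters and multiply-accumulate operations (MACs), estimates the size of its weights, and times repeated forward passes to estimate frames per second (fps).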

import time
import torch
from thop import profile, clever_format
from net.net import net  # replace with your own model

# Input resolution (4K UHD)
width = 3840
height = 2160

"""
NVIDIA RTX 3090
Number of parameters: 341.767K
Size of model: 1.30 MB
Computational complexity: 2.828T FLOPs
device: cuda - fps: 1304.787
"""

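def compute_FLOPs_and_model_size(model):
    # PyTorch expects NCHW: (batch, channels, height, width)
    device = next(model.parameters()).device
    dummy_input = torch.randn(1, 3, height, width, device=device)
    # thop.profile returns multiply-accumulate operations (MACs) and the parameter count
    macs, params = profile(model, inputs=(dummy_input,), verbose=False)
    return macs, params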
 
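@torch.no_grad()
def compute_fps(model, shape, epoch=100, device=None, warmup=10):
    if device:
        device = torch.device(device)
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    # Warm-up runs so one-time setup costs (context creation, memory allocation) are not timed
    for _ in range(warmup):
        model(torch.randn(shape, device=device))

    total_time = 0.0
    for _ in range(epoch):
        data = torch.randn(shape, device=device)
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        model(data)
        # CUDA kernels run asynchronously; wait for them before stopping the clock
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        total_time += time.perf_counter() - start

    return epoch / total_time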
 
 
def test_model_flops():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = net().to(device)  # use your own model here

    MACs, params = compute_FLOPs_and_model_size(model)

    # Size of the weights assuming float32 (4 bytes per parameter), in MiB
    model_size = params * 4.0 / 1024 / 1024
    # 1 MAC = 2 FLOPs
    macs, flops, params = clever_format([MACs, 2 * MACs, params], "%.3f")

    print('Number of parameters: {}'.format(params))
    print('Size of model: {:.2f} MB'.format(model_size))
    print('Computational complexity: {} MACs (~{} FLOPs)'.format(macs, flops))
 
def test_fps():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = net().to(device)  # use your own model here

    # NCHW input shape: (batch, channels, height, width)
    fps = compute_fps(model, (1, 3, height, width), device=device)
    print('device: {} - fps: {:.3f}'.format(device.type, fps))
 
 
if __name__ == '__main__':
    test_model_flops()
    test_fps()
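A quick sanity check for any measured frame rate: dividing the device's peak throughput by the FLOPs per frame gives an upper bound on fps. The sketch below plugs in the figures recorded in the script's docstring (2.828T MACs, i.e. about 5.66 TFLOPs per frame, and roughly 341,767 parameters) together with an assumed RTX 3090 FP32 peak of about 35.6 TFLOPS. `fps_upper_bound` and `model_size_mib` are hypothetical helpers.

```python
def fps_upper_bound(flops_per_frame: float, peak_flops: float) -> float:
    """No device can finish a frame in less than flops_per_frame / peak_flops seconds."""
    return peak_flops / flops_per_frame


def model_size_mib(num_params: int, bytes_per_param: int = 4) -> float:
    """Weight storage in MiB; 4 bytes per parameter corresponds to float32."""
    return num_params * bytes_per_param / 1024 / 1024


macs_per_frame = 2.828e12             # thop result recorded in the article
flops_per_frame = 2 * macs_per_frame  # 1 MAC = 2 FLOPs
rtx3090_peak = 35.6e12                # assumed FP32 peak, in FLOPS

print(f"fps upper bound: {fps_upper_bound(flops_per_frame, rtx3090_peak):.2f}")  # 6.29
print(f"model size: {model_size_mib(341_767):.2f} MB")  # 1.30
```

The model size matches the recorded 1.30 MB, but an upper bound of roughly 6 fps is far below the recorded 1304.787 fps. That gap is consistent with a timing loop that reads the clock before the asynchronous CUDA kernels have finished; calling torch.cuda.synchronize() before each timestamp avoids this.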


Copyright:

那棵树看起来生气了

Article link:

(When reposting, please credit the original source and include a link to this article.)