该章节教大家如何突破tensorrt加速默认推理尺寸(640*640)的限制,全网该类教程几乎没有,看到就是赚到!!!
问题描述:
可以想象这样一个场景:我们有一个1280*720的私人数据集,在进行正常的训练、验证、tensorrt转换、用engine文件推理,engine文件推理时,发现终端栏显示(160 * 640)的一个尺寸,这与我们的数据集尺寸不匹配。
我当时注意到这个问题,就在想如何做到推理与数据集尺寸匹配呢?
首先我查看了YOLOV8的官方文档,在模型训练和推理时可以设置imgsz
参数,我当即在运行脚本设置了imgsz=(720,1280)
注:参数格式imgsz=(h,w)
,先输入h,再输入w。
输入之后:
报错:input size torch.Size([1,3,640,640]) not equal to max model size (1,3,1280,1280)
会不会是训练和转换时没设置参数导致该问题?
于是我添加 imgsz=[1280,720]
参数,并重新训练
注:训练时的imgsz
参数格式是imgsz=[w,h]
训练脚本:
1 2 3 4 5 6 7 8 9 10
| from ultralytics import YOLO
# 加载模型 model = YOLO('yolov8n-seg.yaml').load('yolov8n-seg.pt') # 从YAML构建并转移权重
if __name__ == '__main__': # 训练模型 results = model.train(data='seg.yaml', epochs=100, imgsz=[1280,720]) #
metrics = model.val()
|
训练结束得到best.pt文件,我们进行tensorrt加速,pt转化engine文件:
pt转化engine脚本:
1 2 3 4 5 6 7
| from ultralytics import YOLO
# Load a model model = YOLO('best.pt') # load a custom trained
# Export the model model.export(format='engine',half=True,simplify=True,imgsz=(736,1280))
|
注:转换时的imgsz
参数格式是imgsz=[w,h]
最终还是报错:input size torch.Size([1,3,640,640]) not equal to max model size (1,3,1280,1280)`
我搜集了全网都没有找到相关答案(甚至看了YOLOv8作者的答案都无法解决问题)
解决方法:
耗时半个月,我终于找到了解决方法,以上步骤是完全正确的,大家可以放心学,我们只需要修改yolo的代码参数就能正常推理engine文件了。
首先,找到ultralytics/nn/autobackend.py
文件,打开,ctrl+f
搜索warmup
把(1,3,640,640)改成(1,3,720,1280)
ctrl+s
保存,接着我们在运行脚本添加imgsz=(720,1280)
参数后就可以推理啦
最后给大家分享我的运行脚本:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| import cv2 from ultralytics import YOLO from cv2 import getTickCount, getTickFrequency # 加载 YOLOv8 模型 model = YOLO("best.engine",task='segment')
# 获取摄像头内容,参数 0 表示使用默认的摄像头 cap = cv2.VideoCapture(0)
while cap.isOpened(): loop_start = getTickCount() success,frame = cap.read() # 读取摄像头的一帧图像
if success : results = model.predict(frame,imgsz=(720,1280)) # 对当前帧进行目标检测并显示结果 annotated_frame = results[0].plot()
# 中间放自己的显示程序 loop_time = getTickCount() - loop_start total_time = loop_time / (getTickFrequency()) FPS = int(1 / total_time) # 在图像左上角添加FPS文本 fps_text = f"FPS: {FPS:.2f}" font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1 font_thickness = 2 text_color = (0, 255, 0) # 绿色 text_position = (10, 30) # 左上角位置
cv2.putText(annotated_frame, fps_text, text_position, font, font_scale, text_color, font_thickness) cv2.namedWindow("img",cv2.WINDOW_NORMAL) cv2.imshow('img', annotated_frame) # 通过按下 'q' 键退出循环 if cv2.waitKey(1) & 0xFF == ord('q'): break
cap.release() # 释放摄像头资源 cv2.destroyAllWindows() # 关闭OpenCV窗口
|
希望我的文章能对你有所帮助,有问题联系邮箱:[email protected]