I. Algorithm Migration
1. Approach
Use the Xilinx ZYNQ FPGA platform together with PYNQ (Python productivity for Zynq).
2. Workflow
- Obtain the deepforest-NEON.pt model
- Convert the .pt model to an ONNX model with PyTorch
- Convert the ONNX model to a DPU model with the Vitis AI toolchain
- Install PYNQ and the DPU dependencies on the development board (see the deployment sketch after this list)
- Benchmark model inference speed
- Optimize model size
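For the on-board steps (installing the PYNQ/DPU dependencies and running the compiled model), the sketch below shows the usual pynq_dpu inference pattern. It is a minimal sketch, assuming the pynq-dpu package is installed on the board; the file names dpu.bit and deepforest.xmodel are placeholders for artifacts this project has not produced yet.

```python
# Minimal DPU inference sketch for the ZYNQ board, assuming pynq and pynq-dpu
# are installed; "dpu.bit" and "deepforest.xmodel" are placeholder names.
import numpy as np
from pynq_dpu import DpuOverlay

overlay = DpuOverlay("dpu.bit")          # program the PL with the DPU overlay
overlay.load_model("deepforest.xmodel")  # load the Vitis AI compiled model

dpu = overlay.runner
input_tensors = dpu.get_input_tensors()
output_tensors = dpu.get_output_tensors()

# Allocate host buffers matching the model's tensor shapes
input_data = [np.empty(tuple(input_tensors[0].dims), dtype=np.float32, order="C")]
output_data = [np.empty(tuple(output_tensors[0].dims), dtype=np.float32, order="C")]

# Fill input_data[0] with a preprocessed 400x400 RGB tile, then run one job
job_id = dpu.execute_async(input_data, output_data)
dpu.wait(job_id)
```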
3. Overall Architecture
4. Toolchain
- FPGA: Vitis AI
- PyTorch
- PYNQ
5. Model Testing
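This section is still to be filled in. For the "benchmark model inference speed" step, a minimal timing sketch against the exported ONNX model is given below; the warm-up and iteration counts are arbitrary choices, and the same pattern applies to the DPU runner on the board.

```python
# Rough latency measurement for the exported ONNX model on the host CPU
import time
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("alexnet.onnx")
input_name = session.get_inputs()[0].name
x = np.random.rand(1, 3, 400, 400).astype(np.float32)

for _ in range(3):                     # warm-up runs, not timed
    session.run(None, {input_name: x})

n = 20
start = time.perf_counter()
for _ in range(n):
    session.run(None, {input_name: x})
elapsed = time.perf_counter() - start
print(f"mean latency: {elapsed / n * 1000:.1f} ms")
```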
II. Progress
1. ONNX Model Conversion
1.1 Refactor the deepforest source model

```python
import os
import typing
import warnings

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision.models.detection.retinanet import (
    RetinaNet,
    RetinaNet_ResNet50_FPN_Weights,
)
from deepforest import predict, visualize  # helpers reused from the deepforest source


def load_backbone():
    # COCO-pretrained RetinaNet; only its ResNet50-FPN backbone is reused
    backbone = torchvision.models.detection.retinanet_resnet50_fpn(
        weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1)
    return backbone


def create_model(num_classes, nms_thresh, score_thresh, backbone=None):
    if not backbone:
        resnet = load_backbone()
        backbone = resnet.backbone
    model = RetinaNet(backbone=backbone, num_classes=num_classes)
    model.nms_thresh = nms_thresh
    model.score_thresh = score_thresh
    return model


def predict_image(image: typing.Optional[np.ndarray] = None,
                  path: typing.Optional[str] = None,
                  return_plot: bool = False,
                  thickness: int = 1,
                  color: typing.Optional[tuple] = (0, 165, 255)):
    # `model` and `numeric_to_label_dict` are the module-level objects set up in 1.2
    if path:
        image = np.array(Image.open(path).convert("RGB")).astype("float32")
    # Sanity checks on the input image
    if not isinstance(image, np.ndarray):
        raise TypeError("Input image is of type {}, expected numpy; if reading "
                        "from PIL, wrap in np.array(image).astype('float32')".format(type(image)))
    model.eval()
    model.score_thresh = 0.3
    if image.dtype != "float32":
        warnings.warn(f"Image type is {image.dtype}, transforming to float32. "
                      f"This assumes that the range of pixel values is 0-255, as "
                      f"opposed to 0-1. To suppress this warning, transform with "
                      f"image.astype('float32')")
        image = image.astype("float32")
    # HWC in [0, 255] -> CHW in [0, 1]
    image = torch.tensor(image, device='cpu').permute(2, 0, 1)
    image = image / 255
    with torch.no_grad():
        prediction = model(image.unsqueeze(0))
    # Return None when there are no predictions
    if len(prediction[0]["boxes"]) == 0:
        return None
    df = visualize.format_boxes(prediction[0])
    df = predict.across_class_nms(df, iou_threshold=0.05)
    if return_plot:
        # cv2 expects no batch dim, a BGR channels-last image, and 0-255 values
        image = image.cpu()
        image = np.array(image.squeeze(0))
        image = np.rollaxis(image, 0, 3)
        image = image[:, :, ::-1] * 255
        image = image.astype("uint8")
        image = visualize.plot_predictions(image, df, color=color, thickness=thickness)
        return image
    else:
        if path:
            df["image_path"] = os.path.basename(path)
        df["label"] = df.label.apply(lambda x: numeric_to_label_dict[x])
        return df
```

1.2 Type conversion

```python
num_classes = 1
label_dict = {"Tree": 0}
numeric_to_label_dict = {v: k for k, v in label_dict.items()}  # inverse map used by predict_image
score_thresh = 0.3
nms_thresh = 0.05

model = create_model(num_classes, nms_thresh, score_thresh)
model.load_state_dict(torch.load('NEON.pt'))
model.eval()

# Trace the model with a fixed 1x3x400x400 dummy input
# (the "alexnet.onnx" file name is kept as-is from the original notes)
dummy_input = torch.randn(1, 3, 400, 400)
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True)
```

1.3 Model validation

```python
import onnx
import onnxruntime as ort

onnx_path = "alexnet.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

session = ort.InferenceSession(onnx_path)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
print(input_name)
print(output_name)

path = 'old_project/train_data_folder/2019_YELL_2_528000_4978000_image_crop2_10.png'
image = np.array(Image.open(path).convert("RGB")).astype("float32")
image = torch.tensor(image, device='cpu').permute(2, 0, 1)
image = image / 255
image = image.unsqueeze(0)
# Note: the extra unsqueeze here is only inside the print, which is why the
# recorded output below reports torch.Size([1, 1, 3, 400, 400]); the tensor
# actually fed to the session is 1x3x400x400
print(image.unsqueeze(0).shape)

output_data = session.run([output_name], {input_name: image.numpy()})
print(output_data)
```

Recorded output (the first output tensor: predicted boxes in x1, y1, x2, y2 order):

```
images 3345 torch.Size([1, 1, 3, 400, 400])
[array([[313.0144    ,   0.        , 390.55798   ,  60.181404  ],
        [ 21.69411   , 243.14648   ,  68.56987   , 291.38635   ],
        [ 29.7658    , 308.8676    ,  74.60773   , 349.61908   ],
        [ 15.569448  , 352.18268   ,  58.870987  , 393.23517   ],
        [ 86.635864  , 237.48146   , 130.74779   , 281.5208    ],
        [138.22173   , 288.21066   , 182.91176   , 335.15622   ],
        [ 89.589966  , 153.33046   , 170.03311   , 233.51076   ],
        [306.25778   ,  73.54499   , 344.38382   , 109.25332   ],
        [ 86.35097   , 353.92947   , 133.08755   , 398.1063    ],
        [308.7377    , 340.64532   , 347.70377   , 379.64368   ],
        [185.83337   ,  32.636623  , 252.12595   , 101.247086  ],
        [141.25052   ,  95.265274  , 183.53696   , 132.68791   ],
        [166.71738   , 134.74255   , 198.74393   , 169.84103   ],
        [  0.42692566, 153.42303   ,   9.319502  , 180.90723   ],
        [123.06708   , 274.01016   , 147.02313   , 298.4445    ],
        [199.37772   , 373.3064    , 259.72678   , 399.9558    ],
        [263.13043   , 305.88364   , 310.36835   , 361.31857   ],
        [  0.        , 339.92944   ,  23.363434  , 366.08337   ],
        [306.46985   , 304.89267   , 346.48383   , 341.3264    ]], dtype=float32)]
```
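The next workflow step is the Vitis AI conversion. The sketch below follows the Vitis AI PyTorch quantizer flow from the vendor documentation (pytorch_nndct, available inside the Vitis AI docker), reusing the model built in 1.2. This is an unverified sketch for this RetinaNet; exact arguments can differ between Vitis AI releases, and the calibration loop here is a placeholder.

```python
# Hedged sketch of Vitis AI post-training quantization with pytorch_nndct;
# argument names follow the Vitis AI user guide and may vary by release.
import torch
from pytorch_nndct.apis import torch_quantizer

dummy_input = torch.randn(1, 3, 400, 400)

# Pass 1: calibration
quantizer = torch_quantizer("calib", model, (dummy_input,))
quant_model = quantizer.quant_model
quant_model(dummy_input)          # placeholder: feed real calibration images here
quantizer.export_quant_config()

# Pass 2: export the deployable xmodel, then compile it for the DPU target
quantizer = torch_quantizer("test", model, (dummy_input,))
quant_model = quantizer.quant_model
quant_model(dummy_input)
quantizer.export_xmodel(deploy_check=False)
```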