Porting a Tree-Crown Detection Algorithm to FPGA

GA666666 · 2025-07-23

I. Algorithm Migration

1. Approach

Use a Xilinx ZYNQ FPGA platform together with PYNQ (Python productivity for Zynq).
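
In this scheme the convolutional workload runs on the DPU (Deep Learning Processor Unit) in the FPGA fabric, while Python on the ARM cores drives it through PYNQ. As a minimal sketch of what that looks like on the board (assuming the pynq-dpu package; "dpu.bit" and "deepforest.xmodel" are placeholder names for the overlay and compiled model produced later in the flow):

    import numpy as np
    from pynq_dpu import DpuOverlay

    # Program the FPGA with the DPU overlay, then load the compiled model.
    overlay = DpuOverlay("dpu.bit")
    overlay.load_model("deepforest.xmodel")

    dpu = overlay.runner  # VART runner bound to the DPU
    shape_in = tuple(dpu.get_input_tensors()[0].dims)
    shape_out = tuple(dpu.get_output_tensors()[0].dims)

    # Allocate I/O buffers with the shapes the runner expects.
    input_data = [np.empty(shape_in, dtype=np.float32, order="C")]
    output_data = [np.empty(shape_out, dtype=np.float32, order="C")]

    # Fill input_data[0] with a preprocessed image, then run one inference.
    job_id = dpu.execute_async(input_data, output_data)
    dpu.wait(job_id)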

2. Workflow

  1. Obtain the deepforest NEON.pt model
  2. Convert the .pt model to an ONNX model with PyTorch
  3. Convert the model to a DPU model with the Vitis AI tools (see the quantization sketch after this list)
  4. Install PYNQ and the DPU runtime dependencies on the development board
  5. Benchmark the model's inference speed
  6. Optimize the model size
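
For step 3, Vitis AI first quantizes the float model to INT8 and then compiles it for the DPU. A minimal sketch of the quantization pass with the Vitis AI PyTorch quantizer (pytorch_nndct, available in the Vitis AI container) follows; `model` and `calib_images` are placeholders, and a detection head with control flow may need extra handling, so this shows the general API only:

    import torch
    from pytorch_nndct.apis import torch_quantizer

    dummy_input = torch.randn(1, 3, 400, 400)

    # Calibration pass: run representative images to collect activation ranges.
    quantizer = torch_quantizer("calib", model, (dummy_input,))
    quant_model = quantizer.quant_model
    with torch.no_grad():
        for img in calib_images:  # a small, representative image set
            quant_model(img.unsqueeze(0))
    quantizer.export_quant_config()

    # Test pass: re-run once and export the xmodel, which vai_c_xir
    # then compiles for the target DPU architecture.
    quantizer = torch_quantizer("test", model, (dummy_input,))
    with torch.no_grad():
        quantizer.quant_model(dummy_input)
    quantizer.export_xmodel(deploy_check=False)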

3. Overall Architecture

4. Toolchain

  1. FPGA: Vitis AI
  2. PyTorch
  3. PYNQ

5. Model Testing
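
No timing figures are recorded yet. As a baseline before the DPU port, the ONNX model exported in section II below can be timed on the CPU with ONNX Runtime; a rough sketch:

    import time
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("alexnet.onnx")
    name = session.get_inputs()[0].name
    x = np.random.rand(1, 3, 400, 400).astype("float32")

    # Warm-up so one-time initialization does not skew the average.
    for _ in range(3):
        session.run(None, {name: x})

    n = 20
    t0 = time.perf_counter()
    for _ in range(n):
        session.run(None, {name: x})
    print(f"mean latency: {(time.perf_counter() - t0) / n * 1000:.1f} ms")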

II. Progress

  1. ONNX model conversion

    1.1 Rebuilding the DeepForest model from source

    import os
    import typing
    import warnings

    import numpy as np
    import torch
    import torchvision
    from PIL import Image
    from deepforest import predict, visualize
    from torchvision.models.detection.retinanet import RetinaNet
    from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights


    def load_backbone():
        # COCO-pretrained RetinaNet; only its ResNet50-FPN backbone is reused.
        backbone = torchvision.models.detection.retinanet_resnet50_fpn(
            weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1)
        return backbone


    def create_model(num_classes, nms_thresh, score_thresh, backbone=None):
        # Rebuild the DeepForest RetinaNet with a single tree-crown class.
        if not backbone:
            resnet = load_backbone()
            backbone = resnet.backbone
        model = RetinaNet(backbone=backbone, num_classes=num_classes)
        model.nms_thresh = nms_thresh
        model.score_thresh = score_thresh
        return model


    def predict_image(image: typing.Optional[np.ndarray] = None,
                      path: typing.Optional[str] = None,
                      return_plot: bool = False,
                      thickness: int = 1,
                      color: typing.Optional[tuple] = (0, 165, 255)):
        # Uses the module-level `model` and `label_dict` created in section 1.2.
        if path:
            image = np.array(Image.open(path).convert("RGB")).astype("float32")

        # Sanity check on the input image.
        if not isinstance(image, np.ndarray):
            raise TypeError("Input image is of type {}, expected numpy; if reading "
                            "from PIL, wrap in "
                            "np.array(image).astype('float32')".format(type(image)))

        model.eval()
        model.score_thresh = 0.3

        if image.dtype != "float32":
            warnings.warn(f"Image type is {image.dtype}, transforming to float32. "
                          f"This assumes that the range of pixel values is 0-255, "
                          f"as opposed to 0-1. To suppress this warning, transform "
                          f"the image with image.astype('float32')")
            image = image.astype("float32")

        # HWC image in [0, 255] -> CHW float tensor in [0, 1].
        image = torch.tensor(image, device='cpu').permute(2, 0, 1)
        image = image / 255

        with torch.no_grad():
            prediction = model(image.unsqueeze(0))

        # Return None when there are no predictions.
        if len(prediction[0]["boxes"]) == 0:
            return None

        df = visualize.format_boxes(prediction[0])
        df = predict.across_class_nms(df, iou_threshold=0.05)

        if return_plot:
            # Make sure the tensor is on the CPU before numpy conversion.
            image = image.cpu()

            # cv2 expects no batch dim, a BGR channels-last image, 0-255.
            image = np.array(image.squeeze(0))
            image = np.rollaxis(image, 0, 3)
            image = image[:, :, ::-1] * 255
            image = image.astype("uint8")
            image = visualize.plot_predictions(image,
                                               df,
                                               color=color,
                                               thickness=thickness)

            return image
        else:
            if path:
                df["image_path"] = os.path.basename(path)

            # Map numeric class ids back to string labels ({0: "Tree"}).
            numeric_to_label_dict = {v: k for k, v in label_dict.items()}
            df["label"] = df.label.apply(lambda x: numeric_to_label_dict[x])

        return df
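
For reference, the NEON.pt weights loaded in 1.2 can be extracted from the official DeepForest release model; one way to do that is sketched below (use_release() is the deepforest 1.x API and may differ in newer versions):

    import torch
    from deepforest import main as df_main

    # Download the pretrained NEON tree-crown model and save the state_dict
    # of its inner torchvision RetinaNet, which is what create_model() expects.
    m = df_main.deepforest()
    m.use_release()
    torch.save(m.model.state_dict(), "NEON.pt")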

    1.2 Converting the model to ONNX

    # Single-class tree-crown detector with DeepForest's NEON thresholds.
    num_classes = 1
    label_dict = {"Tree": 0}
    score_thresh = 0.3
    nms_thresh = 0.05

    model = create_model(num_classes, nms_thresh, score_thresh)
    model.load_state_dict(torch.load('NEON.pt'))
    model.eval()

    # Trace the model with a fixed 1x3x400x400 input and export to ONNX.
    dummy_input = torch.randn(1, 3, 400, 400)

    torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True)
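
The default export works, but pinning the opset version and tensor names makes the hand-off to later tools easier to script against. An alternative export call (the output names are assumptions based on RetinaNet's boxes/scores/labels outputs):

    torch.onnx.export(
        model,
        dummy_input,
        "alexnet.onnx",
        opset_version=11,
        input_names=["images"],
        output_names=["boxes", "scores", "labels"],
    )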

    1.3 Validating the ONNX model

    import onnx
    import onnxruntime as ort

    onnx_path = "alexnet.onnx"
    model = onnx.load(onnx_path)

    # Structural check of the exported graph.
    onnx.checker.check_model(onnx_path)

    session = ort.InferenceSession(onnx_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    print(input_name)
    print(output_name)

    # Preprocess a test crop exactly as predict_image() does.
    path = 'old_project/train_data_folder/2019_YELL_2_528000_4978000_image_crop2_10.png'
    image = np.array(Image.open(path).convert("RGB")).astype("float32")

    if image.dtype != "float32":
        image = image.astype("float32")
    image = torch.tensor(image, device='cpu').permute(2, 0, 1)
    image = image / 255
    image = image.unsqueeze(0)
    print(image.shape)

    output_data = session.run([output_name], {input_name: image.numpy()})
    print(output_data)
    
    images
    3345
    torch.Size([1, 3, 400, 400])
    [array([[313.0144    ,   0.        , 390.55798   ,  60.181404  ],
        [ 21.69411   , 243.14648   ,  68.56987   , 291.38635   ],
        [ 29.7658    , 308.8676    ,  74.60773   , 349.61908   ],
        [ 15.569448  , 352.18268   ,  58.870987  , 393.23517   ],
        [ 86.635864  , 237.48146   , 130.74779   , 281.5208    ],
        [138.22173   , 288.21066   , 182.91176   , 335.15622   ],
        [ 89.589966  , 153.33046   , 170.03311   , 233.51076   ],
        [306.25778   ,  73.54499   , 344.38382   , 109.25332   ],
        [ 86.35097   , 353.92947   , 133.08755   , 398.1063    ],
        [308.7377    , 340.64532   , 347.70377   , 379.64368   ],
        [185.83337   ,  32.636623  , 252.12595   , 101.247086  ],
        [141.25052   ,  95.265274  , 183.53696   , 132.68791   ],
        [166.71738   , 134.74255   , 198.74393   , 169.84103   ],
        [  0.42692566, 153.42303   ,   9.319502  , 180.90723   ],
        [123.06708   , 274.01016   , 147.02313   , 298.4445    ],
        [199.37772   , 373.3064    , 259.72678   , 399.9558    ],
        [263.13043   , 305.88364   , 310.36835   , 361.31857   ],
        [  0.        , 339.92944   ,  23.363434  , 366.08337   ],
        [306.46985   , 304.89267   , 346.48383   , 341.3264    ]],
       dtype=float32)]
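
To confirm the exported graph matches the PyTorch model numerically, the two runtimes can be compared on the same input. A sketch (note that 1.3 rebinds `model` to the ONNX proto, so the PyTorch model is rebuilt under a separate name; `image`, `session`, `input_name`, and `output_name` are as above):

    # Rebuild the PyTorch model from section 1.2 for comparison.
    torch_model = create_model(num_classes, nms_thresh, score_thresh)
    torch_model.load_state_dict(torch.load('NEON.pt'))
    torch_model.eval()

    with torch.no_grad():
        torch_boxes = torch_model(image)[0]["boxes"].numpy()

    onnx_boxes = session.run([output_name], {input_name: image.numpy()})[0]

    # Compare the overlapping prefix; small element-wise drift from
    # operator implementation differences is expected.
    n = min(len(torch_boxes), len(onnx_boxes))
    print("max abs box difference:", np.abs(torch_boxes[:n] - onnx_boxes[:n]).max())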
    