最近想将fastacvnet++部署到RK3576上,大佬能否指点一下
提个小建议,源代码里面
att_topk = torch.gather(att_weights_prob, 2, ind_k) #源代码
可以换成
#-------------------------------------------------修改部分-----------------------------------------------------#
# -------------------------------------------------v0.1-----------------------------------------------------#
att_4d = att_weights_prob.squeeze(1) # [B,D,H,W]
ind_4d = ind_k.squeeze(1) # [B,k,H,W]
att_topk = torch.gather(att_4d, 1, ind_4d)
disparity_sample_topk = ind_k.squeeze(1).float()
# ------------------------------------------------修改部分-----------------------------------------------------#
因为源代码导出onnx这里的node是5D-tensor,而大部分嵌入式平台对4D-tensor更友好
最近想将fastacvnet++部署到RK3576上,大佬能否指点一下
提个小建议,源代码里面
att_topk = torch.gather(att_weights_prob, 2, ind_k) #源代码
可以换成
#-------------------------------------------------修改部分-----------------------------------------------------#
# -------------------------------------------------v0.1-----------------------------------------------------#
att_4d = att_weights_prob.squeeze(1) # [B,D,H,W]
ind_4d = ind_k.squeeze(1) # [B,k,H,W]
att_topk = torch.gather(att_4d, 1, ind_4d)
disparity_sample_topk = ind_k.squeeze(1).float()
# ------------------------------------------------修改部分-----------------------------------------------------#
因为源代码导出onnx这里的node是5D-tensor,而大部分嵌入式平台对4D-tensor更友好