diff --git a/main_onnx_amd.py b/main_onnx_amd.py index 5ac6df8..1b76191 100644 --- a/main_onnx_amd.py +++ b/main_onnx_amd.py @@ -116,7 +116,7 @@ def main(): if len(im.shape) == 3: im = im[None] - outputs = ort_sess.run(None, {'images': im}) + outputs = ort_sess.run(None, {'images': np.array(im)}) im = torch.from_numpy(outputs[0]).to('cpu')