diff --git a/Caffe/layer_param.py b/Caffe/layer_param.py index fd966d7..b415687 100755 --- a/Caffe/layer_param.py +++ b/Caffe/layer_param.py @@ -93,7 +93,7 @@ def norm_param(self, eps): self.param.norm_param.CopyFrom(l2norm_param) - def permute_param(self, order1, order2, order3, order4): + def permute_param(self, order): """ add a conv_param layer if you spec the layer type "Convolution" Args: @@ -105,7 +105,7 @@ def permute_param(self, order1, order2, order3, order4): Returns: """ permute_param = pb.PermuteParameter() - permute_param.order.extend([order1, order2, order3, order4]) + permute_param.order.extend(*order) self.param.permute_param.CopyFrom(permute_param) @@ -180,4 +180,4 @@ def copy_from(self,layer_param): pass def set_enum(param,key,value): - setattr(param,key,param.Value(value)) \ No newline at end of file + setattr(param,key,param.Value(value)) diff --git a/pytorch_to_caffe.py b/pytorch_to_caffe.py index 3ab7154..4e3de94 100755 --- a/pytorch_to_caffe.py +++ b/pytorch_to_caffe.py @@ -607,12 +607,8 @@ def _permute(input, *args): log.add_blobs([x], name='permute_blob') layer = caffe_net.Layer_param(name=name, type='Permute', bottom=[log.blobs(input)], top=[log.blobs(x)]) - order1 = args[0] - order2 = args[1] - order3 = args[2] - order4 = args[3] - layer.permute_param(order1, order2, order3, order4) + layer.permute_param(*args) log.cnet.add_layer(layer) return x