
Commit 2ec7491

[PT FE]: support aten::view_as and aten::std
1 parent 1501e29 commit 2ec7491

File tree

src/frontends/pytorch/src/op/var_mean.cpp
src/frontends/pytorch/src/op_table.cpp
tests/layer_tests/pytorch_tests/test_reshape_as.py
tests/layer_tests/pytorch_tests/test_var_mean.py

4 files changed: +51 −26 lines changed

src/frontends/pytorch/src/op/var_mean.cpp

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,7 @@
 #include "openvino/op/reduce_mean.hpp"
 #include "openvino/op/reduce_prod.hpp"
 #include "openvino/op/shape_of.hpp"
+#include "openvino/op/sqrt.hpp"
 #include "openvino/op/subtract.hpp"
 #include "utils.hpp"

@@ -80,6 +81,18 @@ OutputVector translate_var(const NodeContext& context) {
     return {res[0]};
 }

+OutputVector translate_std(const NodeContext& context) {
+    auto res = translate_var_mean(context);
+    auto var = res[0];
+    return {context.mark_node(std::make_shared<v0::Sqrt>(var))};
+}
+
+OutputVector translate_std_mean(const NodeContext& context) {
+    auto res = translate_var_mean(context);
+    auto var = res[0];
+    return {context.mark_node(std::make_shared<v0::Sqrt>(var)), res[1]};
+}
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
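Both new converters reuse translate_var_mean and take the square root of its variance output, relying on the identity std(x) = sqrt(var(x)), which holds for the biased and unbiased estimators alike. A quick PyTorch check of that identity (illustrative, not part of the commit):

    import torch

    x = torch.randn(1, 3, 224, 224)

    # std is the square root of var for both estimator variants,
    # so a converter can reuse the var computation and append a Sqrt
    for unbiased in (True, False):
        assert torch.allclose(torch.std(x, unbiased=unbiased),
                              torch.sqrt(torch.var(x, unbiased=unbiased)))

    # std_mean also forwards the mean unchanged, matching res[1] above
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    assert torch.allclose(std, torch.sqrt(torch.var(x, dim=1, keepdim=True)))
    assert torch.allclose(mean, x.mean(dim=1, keepdim=True))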

src/frontends/pytorch/src/op_table.cpp

Lines changed: 5 additions & 0 deletions
@@ -149,6 +149,8 @@ OP_CONVERTER(translate_softmax);
 OP_CONVERTER(translate_sort);
 OP_CONVERTER(translate_square);
 OP_CONVERTER(translate_squeeze);
+OP_CONVERTER(translate_std);
+OP_CONVERTER(translate_std_mean);
 OP_CONVERTER(translate_sub);
 OP_CONVERTER(translate_sum);
 OP_CONVERTER(translate_t);
@@ -407,6 +409,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::sqrt", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sqrt>},
         {"aten::square", op::translate_square},
         {"aten::squeeze", op::quantizable_op<op::translate_squeeze>},
+        {"aten::std", op::translate_std},
+        {"aten::std_mean", op::translate_std_mean},
         {"aten::sub", op::translate_sub},
         {"aten::sub_", op::inplace_op<op::translate_sub>},
         {"aten::sum", op::translate_sum},
@@ -442,6 +446,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::var", op::translate_var},
         {"aten::var_mean", op::translate_var_mean},
         {"aten::view", op::quantizable_op<op::translate_reshape>},
+        {"aten::view_as", op::translate_reshape_as},
         {"aten::where", op::translate_where},
         {"aten::zero_", op::inplace_op<op::translate_zeros_like>},
         {"aten::zeros", op::translate_zeros},

tests/layer_tests/pytorch_tests/test_reshape_as.py

Lines changed: 9 additions & 10 deletions
@@ -8,26 +8,25 @@
 from pytorch_layer_test_class import PytorchLayerTest


-@pytest.mark.parametrize('input_tesnors', ((np.ones((3, 6)), np.ones((2, 9))),
-                                           (np.ones((2, 2, 3)), np.ones((6, 2))),
-                                           (np.ones((6, 2)), np.ones((2, 2, 3)))))
 class TestReshapeAs(PytorchLayerTest):

-    def _prepare_input(self):
-        return self.input_tesnors
+    def _prepare_input(self, shape1, shape2):
+        return (np.ones(shape1, dtype=np.float32), np.ones(shape2, dtype=np.float32))

-    def create_model(self):
+    def create_model(self, op):
         class aten_reshape_as(torch.nn.Module):

             def forward(self, input_tensor, shape_tensor):
                 return input_tensor.reshape_as(shape_tensor)

         ref_net = None

-        return aten_reshape_as(), ref_net, "aten::reshape_as"
+        return aten_reshape_as(), ref_net, f"aten::{op}"

     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_reshape_as(self, ie_device, precision, ir_version, input_tesnors):
-        self.input_tesnors = input_tesnors
-        self._test(*self.create_model(), ie_device, precision, ir_version)
+    @pytest.mark.parametrize("op", ["reshape_as", "view_as"])
+    @pytest.mark.parametrize("input_tensor_shapes", (((3, 6), (2, 9)), ((2, 2, 3), (6, 2)), ((6, 2), (2, 2, 3))))
+    def test_reshape_as(self, op, input_tensor_shapes, ie_device, precision, ir_version):
+        self._test(*self.create_model(op), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"shape1": input_tensor_shapes[0], "shape2": input_tensor_shapes[1]})
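One caveat: the module above always calls reshape_as in forward, so even when op is "view_as" the traced graph contains aten::reshape_as. For the view_as parameter to actually exercise the new aten::view_as mapping, the module would need to dispatch on the op name. A minimal sketch of such a dispatch, with hypothetical names, assuming the test framework traces the module so the branch is resolved at trace time:

    import torch

    class aten_view_or_reshape_as(torch.nn.Module):
        # hypothetical dispatching module, not part of the commit
        def __init__(self, op):
            super().__init__()
            self.op = op

        def forward(self, input_tensor, shape_tensor):
            # the branch is fixed during tracing, so the traced graph
            # contains only the selected aten op
            if self.op == "view_as":
                return input_tensor.view_as(shape_tensor)
            return input_tensor.reshape_as(shape_tensor)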

tests/layer_tests/pytorch_tests/test_var_mean.py

Lines changed: 24 additions & 16 deletions
@@ -11,47 +11,55 @@ def _prepare_input(self):
         import numpy as np
         return (np.random.randn(1, 3, 224, 224).astype(np.float32),)

-    def create_model(self, unbiased, dim=None, keepdim=True, two_args_case=True, return_mean=False):
+    def create_model(self, unbiased, dim=None, keepdim=True, two_args_case=True, op_type="var"):
         import torch

+        ops = {
+            "var": torch.var,
+            "var_mean": torch.var_mean,
+            "std": torch.std,
+            "std_mean": torch.std_mean
+        }
+
+        op = ops[op_type]
+
         class aten_var(torch.nn.Module):
-            def __init__(self, dim, unbiased, keepdim, return_mean):
+            def __init__(self, dim, unbiased, keepdim, op):
                 super(aten_var, self).__init__()
                 self.unbiased = unbiased
                 self.dim = dim
                 self.keepdim = keepdim
-                self.op = torch.var if not return_mean else torch.var_mean
+                self.op = op

             def forward(self, x):
                 return self.op(x, self.dim, unbiased=self.unbiased, keepdim=self.keepdim)

         class aten_var2args(torch.nn.Module):
-            def __init__(self, unbiased, return_mean):
+            def __init__(self, unbiased, op):
                 super(aten_var2args, self).__init__()
                 self.unbiased = unbiased
-                self.op = torch.var if not return_mean else torch.var_mean
-
+                self.op = op
+
             def forward(self, x):
-                return torch.var(x, self.unbiased)
+                return self.op(x, self.unbiased)

         ref_net = None
-        op_name = "aten::var" if not return_mean else "aten::var_mean"
+        op_name = f"aten::{op_type}"
         if two_args_case:
-            return aten_var2args(unbiased, return_mean), ref_net, op_name
-        return aten_var(dim, unbiased, keepdim, return_mean), ref_net, op_name
+            return aten_var2args(unbiased, op), ref_net, op_name
+        return aten_var(dim, unbiased, keepdim, op), ref_net, op_name

     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.parametrize("unbiased", [True, False])
-    @pytest.mark.parametrize("return_mean", [True, False])
-    def test_var2args(self, unbiased, return_mean, ie_device, precision, ir_version):
-        self._test(*self.create_model(unbiased, return_mean), ie_device, precision, ir_version)
+    @pytest.mark.parametrize("op_type", ["var", "var_mean", "std", "std_mean"])
+    def test_var2args(self, unbiased, op_type, ie_device, precision, ir_version):
+        self._test(*self.create_model(unbiased, op_type=op_type), ie_device, precision, ir_version)

     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.parametrize("unbiased", [False, True])
     @pytest.mark.parametrize("dim", [None, 0, 1, 2, 3, -1, -2, (0, 1), (-1, -2), (0, 1, -1), (0, 1, 2, 3)])
     @pytest.mark.parametrize("keepdim", [True, False])
-    @pytest.mark.parametrize("return_mean", [True, False])
-    def test_var(self, unbiased, dim, keepdim, return_mean, ie_device, precision, ir_version):
-        self._test(*self.create_model(unbiased, dim, keepdim, two_args_case=False, return_mean=return_mean), ie_device, precision, ir_version)
+    @pytest.mark.parametrize("op_type", ["var", "var_mean", "std", "std_mean"])
+    def test_var(self, unbiased, dim, keepdim, op_type, ie_device, precision, ir_version):
+        self._test(*self.create_model(unbiased, dim, keepdim, two_args_case=False, op_type=op_type), ie_device, precision, ir_version)
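The two_args_case flag keeps the two call shapes separate because they lower to different aten overloads: op(x, unbiased) reduces over all elements, while op(x, dim, unbiased=..., keepdim=...) reduces over the selected dimensions. Both overloads exist for all four ops; for example (illustrative):

    import torch

    x = torch.randn(2, 3, 4)

    # two-argument overload: unbiased flag only, full reduction
    std_all, mean_all = torch.std_mean(x, False)

    # dim/keepdim overload: reduce over chosen axes, keep their dims
    std_dim, mean_dim = torch.std_mean(x, (0, 1), unbiased=True, keepdim=True)
    print(std_all.shape, std_dim.shape)  # torch.Size([]) torch.Size([1, 1, 4])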
