@@ -11,47 +11,55 @@ def _prepare_input(self):
11
11
import numpy as np
12
12
return (np .random .randn (1 , 3 , 224 , 224 ).astype (np .float32 ),)
13
13
14
- def create_model (self , unbiased , dim = None , keepdim = True , two_args_case = True , return_mean = False ):
14
+ def create_model (self , unbiased , dim = None , keepdim = True , two_args_case = True , op_type = "var" ):
15
15
import torch
16
16
17
+ ops = {
18
+ "var" : torch .var ,
19
+ "var_mean" : torch .var_mean ,
20
+ "std" : torch .std ,
21
+ "std_mean" : torch .std_mean
22
+ }
23
+
24
+ op = ops [op_type ]
25
+
17
26
class aten_var (torch .nn .Module ):
18
- def __init__ (self , dim , unbiased , keepdim , return_mean ):
27
+ def __init__ (self , dim , unbiased , keepdim , op ):
19
28
super (aten_var , self ).__init__ ()
20
29
self .unbiased = unbiased
21
30
self .dim = dim
22
31
self .keepdim = keepdim
23
- self .op = torch . var if not return_mean else torch . var_mean
32
+ self .op = op
24
33
25
34
def forward (self , x ):
26
35
return self .op (x , self .dim , unbiased = self .unbiased , keepdim = self .keepdim )
27
36
28
37
class aten_var2args (torch .nn .Module ):
29
- def __init__ (self , unbiased , return_mean ):
38
+ def __init__ (self , unbiased , op ):
30
39
super (aten_var2args , self ).__init__ ()
31
40
self .unbiased = unbiased
32
- self .op = torch .var if not return_mean else torch .var_mean
33
-
41
+ self .op = op
34
42
def forward (self , x ):
35
- return torch . var (x , self .unbiased )
43
+ return self . op (x , self .unbiased )
36
44
37
45
ref_net = None
38
- op_name = "aten::var" if not return_mean else "aten::var_mean "
46
+ op_name = f "aten::{ op_type } "
39
47
if two_args_case :
40
- return aten_var2args (unbiased , return_mean ), ref_net , op_name
41
- return aten_var (dim , unbiased , keepdim , return_mean ), ref_net , op_name
48
+ return aten_var2args (unbiased , op ), ref_net , op_name
49
+ return aten_var (dim , unbiased , keepdim , op ), ref_net , op_name
42
50
43
51
@pytest .mark .nightly
44
52
@pytest .mark .precommit
45
53
@pytest .mark .parametrize ("unbiased" , [True , False ])
46
- @pytest .mark .parametrize ("return_mean " , [True , False ])
47
- def test_var2args (self , unbiased , return_mean , ie_device , precision , ir_version ):
48
- self ._test (* self .create_model (unbiased , return_mean ), ie_device , precision , ir_version )
54
+ @pytest .mark .parametrize ("op_type " , ["var" , "var_mean" , "std" , "std_mean" ])
55
+ def test_var2args (self , unbiased , op_type , ie_device , precision , ir_version ):
56
+ self ._test (* self .create_model (unbiased , op_type = op_type ), ie_device , precision , ir_version )
49
57
50
58
@pytest .mark .nightly
51
59
@pytest .mark .precommit
52
60
@pytest .mark .parametrize ("unbiased" , [False , True ])
53
61
@pytest .mark .parametrize ("dim" , [None , 0 , 1 , 2 , 3 , - 1 , - 2 , (0 , 1 ), (- 1 , - 2 ), (0 , 1 , - 1 ), (0 , 1 , 2 , 3 )])
54
62
@pytest .mark .parametrize ("keepdim" , [True , False ])
55
- @pytest .mark .parametrize ("return_mean " , [True , False ])
56
- def test_var (self , unbiased , dim , keepdim , return_mean , ie_device , precision , ir_version ):
57
- self ._test (* self .create_model (unbiased , dim , keepdim , two_args_case = False , return_mean = return_mean ), ie_device , precision , ir_version )
63
+ @pytest .mark .parametrize ("op_type " , ["var" , "var_mean" , "std" , "std_mean" ])
64
+ def test_var (self , unbiased , dim , keepdim , op_type , ie_device , precision , ir_version ):
65
+ self ._test (* self .create_model (unbiased , dim , keepdim , two_args_case = False , op_type = op_type ), ie_device , precision , ir_version )
0 commit comments