[static_runtime] Add script to auto-generate view ops (#77105)

Summary:
Add a script that goes through the view ops in "native_functions.yaml", auto-registers them in static runtime, and auto-generates an op unit test for each.

Overall there are 96 grouped view ops. Among them, 21 are already registered by hand; 9 (sparse ops, training-related ops, etc.) are not targets of static runtime; 30 have list arguments or list returns; and 7 have non-basic argument types such as "Dimname" and "MemoryFormat". In summary, this script auto-generates 29 view ops for now.
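As a rough illustration, the selection boils down to collecting `NativeFunctionsViewGroup`s from the parsed YAML and keeping only the groups that `generator.is_supported` accepts (hand-written, blocked, list-typed, and non-basic-typed ops are filtered out there). The snippet below is a simplified sketch of that flow, not the exact generator code; the YAML paths are assumptions.

```
# Simplified sketch of how the script selects view ops to auto-generate.
# The real logic lives in the gen_static_runtime_ops script and
# torchgen/static_runtime/generator.py; the YAML paths here are illustrative.
from torchgen import gen
from torchgen.context import native_function_manager
from torchgen.model import NativeFunctionsViewGroup
from torchgen.static_runtime import generator

parsed_yaml = gen.parse_native_yaml(
    "aten/src/ATen/native/native_functions.yaml",  # assumed path
    "aten/src/ATen/native/tags.yaml",              # assumed path
)
native_functions = parsed_yaml.native_functions

# Collect the grouped view ops (96 of them at the time of this change).
view_groups = [
    g
    for g in gen.get_grouped_by_view_native_functions(native_functions)
    if isinstance(g, NativeFunctionsViewGroup)
]

# Keep only the groups the generator can handle; is_supported() rejects
# hand-written, blocked, list-arg/list-return, and non-basic-typed ops.
supported = []
for g in view_groups:
    with native_function_manager(g):
        if generator.is_supported(g):
            supported.append(g)

print("view grouped native ops: %d" % len(view_groups))
print("generated functions view groups: %d" % len(supported))  # 29 at this diff
```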

Run `buck run //caffe2/torch/fb/jit:gen_static_runtime_ops` to generate the static runtime ops. With this script, the results are:

```
total grouped native ops: 1582
grouped native ops with out variant: 548
generated functions groups with out variant: 241

view grouped native ops: 96
generated functions view groups: 29

overall generated : 270
```

The generated view ops are added in D36258968.

Test Plan:
Generate static runtime ops: `buck run //caffe2/torch/fb/jit:gen_static_runtime_ops`

Unit tests: `buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Differential Revision: D36258767

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77105
Approved by: https://github.com/mikeiovine
Author: Hui Guo
Date: 2022-05-26 03:12:22 +00:00
Committed by: PyTorch MergeBot
Parent: f69c990ecc
Commit: 1803a592f4

3 changed files with 320 additions and 87 deletions


@@ -1,10 +1,13 @@
from torchgen.model import NativeFunctionsGroup
from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
from typing import Dict
from typing import Dict, Union
def func_name_base_str(g: NativeFunctionsGroup) -> str:
return str(g.functional.func.name.name.base)
def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str:
if isinstance(g, NativeFunctionsGroup):
return str(g.functional.func.name.name.base)
else:
return str(g.view.root_name)
is_hand_written_ops_ = frozenset(
@@ -35,6 +38,19 @@ is_hand_written_ops_ = frozenset(
"sign",
"sub",
"tanh",
"detach",
"expand_as",
"flatten",
"narrow",
"reshape_as",
"select",
"slice",
"softmax",
"split",
"squeeze",
"transpose",
"view",
"where",
)
)
@@ -349,3 +365,8 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N
arg_map["size"] = "24"
arg_map["out_int32"] = "false"
return
if op_name in ("diagonal", "linalg_diagonal"):
arg_map["offset"] = "0"
arg_map["dim0"] = "1"
arg_map["dim1"] = "2"
return


@@ -1,12 +1,13 @@
from torchgen import gen
from torchgen.context import native_function_manager
from torchgen.model import DispatchKey, NativeFunctionsGroup
from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
from torchgen.static_runtime import generator
import argparse
import itertools
import os
from typing import Sequence
from typing import Sequence, Union
from libfb.py.log import set_simple_logging
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
# lists each of which groups ops that share the base name. For example, `mean` and
@@ -14,15 +15,15 @@ from typing import Sequence
def group_functions_by_op_name(
grouped_native_functions: Sequence[NativeFunctionsGroup],
) -> Sequence[Sequence[NativeFunctionsGroup]]:
grouped_native_functions: Sequence[
Union[NativeFunctionsGroup, NativeFunctionsViewGroup]
]
) -> Sequence[Sequence[Union[NativeFunctionsGroup, NativeFunctionsViewGroup]]]:
if not grouped_native_functions:
return []
groups = []
current_op_name = None
current_group = None
def is_supported(g: NativeFunctionsGroup) -> bool:
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
with native_function_manager(g):
return generator.is_supported(g)
@@ -31,7 +32,10 @@ def group_functions_by_op_name(
list(group)
for k, group in (
itertools.groupby(
eligible_ops, key=lambda g: g.functional.func.name.name.base
eligible_ops,
key=lambda g: g.functional.func.name.name.base
if isinstance(g, NativeFunctionsGroup)
else g.view.root_name,
)
)
]
@@ -147,34 +151,73 @@ def main() -> None:
parsed_yaml.native_functions,
parsed_yaml.backend_indices,
)
grouped_native_functions = gen.get_grouped_native_functions(native_functions)
structured_native_functions = [
g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
]
supported_function_groups = group_functions_by_op_name(structured_native_functions)
gen_out_variant_dispatcher = generator.GenOutVariantDispatcher()
result = [
gen_out_variant_dispatcher(groups, backend_indices[DispatchKey.CPU])
for groups in supported_function_groups
op_generator = generator.GenOpDispatcher()
test_case_generator = generator.GenOpTestCase()
native_functions_groups = [
g
for g in gen.get_grouped_native_functions(native_functions)
if isinstance(g, NativeFunctionsGroup)
]
gen_out_variant_dispatcher_test_case = generator.GenOutVariantDispatcherTestCase()
test_result = [
gen_out_variant_dispatcher_test_case(groups)
for groups in supported_function_groups
supported_functions_groups = group_functions_by_op_name(native_functions_groups)
out_variant_op_result = [
op_generator.out_variant(groups, backend_indices[DispatchKey.CPU])
for groups in supported_functions_groups
]
out_variant_test_result = [
test_case_generator.out_variant(groups) for groups in supported_functions_groups
]
write_cpp(result, options.generated_ops_cpp_path)
native_functions_view_groups = [
g
for g in gen.get_grouped_by_view_native_functions(native_functions)
if isinstance(g, NativeFunctionsViewGroup)
]
supported_functions_view_groups = group_functions_by_op_name(
native_functions_view_groups
)
view_op_result = [
op_generator.view(groups, backend_indices[DispatchKey.CPU])
for groups in supported_functions_view_groups
]
view_test_result = [
test_case_generator.view(groups) for groups in supported_functions_view_groups
]
op_result = out_variant_op_result + ["\n\n"] + view_op_result
test_result = out_variant_test_result + ["\n\n"] + view_test_result
write_cpp(op_result, options.generated_ops_cpp_path)
write_test_cpp(test_result, options.generated_ops_test_cpp_path)
print("total grouped native ops: %d" % len(grouped_native_functions))
print("structured grouped native ops: %d" % len(structured_native_functions))
supported_grouped_functions = sum(
[len(groups) for groups in supported_function_groups]
print(
"\ntotal grouped native ops: %d"
% len(gen.get_grouped_native_functions(native_functions))
)
print("grouped native ops with out variant: %d" % len(native_functions_groups))
supported_functions_num = sum(
[len(groups) for groups in supported_functions_groups]
)
print("generated functions groups with out variant: %d" % supported_functions_num)
print("\nview grouped native ops: %d" % len(native_functions_view_groups))
supported_view_functions_num = sum(
[len(groups) for groups in supported_functions_view_groups]
)
print("generated functions view groups: %d" % supported_view_functions_num)
print(
"\noverall generated : %d"
% (supported_functions_num + supported_view_functions_num)
)
print("generated grouped native ops: %d" % supported_grouped_functions)
if __name__ == "__main__":
set_simple_logging(escape_newlines=False)
main()


@@ -11,12 +11,17 @@ from torchgen.model import (
NativeFunctionsGroup,
TensorOptionsArguments,
Type,
NativeFunctionsViewGroup,
)
from torchgen.static_runtime import config
import math
import logging
import json
from typing import List, Optional, Sequence, Tuple, Union
logger: logging.Logger = logging.getLogger()
def has_alias(
arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
@@ -38,6 +43,13 @@ BLOCKED_OPS = frozenset(
"hspmm",
# sparse ops
"sspaddmm",
"coalesce",
"_indices",
"indices",
"_values",
"values",
"crow_indices",
"col_indices",
# deprecated ops
"floor_divide",
"ger",
@@ -50,41 +62,72 @@ BLOCKED_OPS = frozenset(
"lu_solve",
"linalg_cholesky",
"linalg_householder_product",
"linalg_ldl_solve",
"_compute_linear_combination",
# training related ops
"_make_dual",
# cannot call directly
"_fw_primal",
# no documentation
"_index_reduce",
)
)
def is_supported(g: NativeFunctionsGroup) -> bool:
base_op_name = g.out.func.name.name.base
if base_op_name in BLOCKED_OPS:
return False
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
base_op_name = ""
func = None
if isinstance(g, NativeFunctionsViewGroup):
base_op_name = g.view.root_name
func = g.view.func
else:
base_op_name = g.out.func.name.name.base
func = g.out.func
if config.is_hand_written(g):
logger.info(f"HAND WRITTEN: {base_op_name}")
return False
if base_op_name in BLOCKED_OPS:
logger.info(f"BLOCKED: {base_op_name}")
return False
for arg in func.schema_order_arguments():
maybe_method = ivalue_type_conversion_method(arg.type)
if not maybe_method:
# Type converting is unsupported yet.
logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
return False
if isinstance(g, NativeFunctionsViewGroup):
if "at::Tensor" != cpp.returns_type(func.returns).cpp_type():
# Returns a non-Tensor value.
logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
return False
return True
# For out variant ops, we need to check the arguments of its functional func.
for arg in g.functional.func.schema_order_arguments():
maybe_method = ivalue_type_conversion_method(arg.type)
if not maybe_method:
# Type converting is unsupported yet.
logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(g.functional.func)}")
return False
if not g.structured:
# In case of unstructured op, we check if it has out variant implementation.
# The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
# parameter.
if (
not hasattr(g, "out")
or not str(g.out.func).endswith("Tensor(a!) out) -> Tensor(a!)")
or not str(g.out.func.name).endswith(".out")
or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
or not str(func.name).endswith(".out")
):
return False
if has_alias(g.out.func.arguments.non_out):
if "at::Tensor &" != cpp.returns_type(func.returns).cpp_type():
logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
return False
if has_alias(func.arguments.non_out):
# This op may create an alias of inputs.
logger.info(f"INPUTS ALIAS: {base_op_name}")
return False
if len(g.out.func.arguments.out) > 1:
# More than 1 output values.
return False
if "at::Tensor &" != cpp.returns_type(g.out.func.returns).cpp_type():
# Returns a non-Tensor value.
return False
for arg in g.out.func.schema_order_arguments():
maybe_method = ivalue_type_conversion_method(arg.type)
if not maybe_method:
# Type converting is unsupported yet.
return False
return True
@@ -137,6 +180,8 @@ should_use_int_tensor_ops_ = frozenset(
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"bitwise_left_shift",
"bitwise_right_shift",
"gcd",
"lcm",
"scatter",
@@ -145,12 +190,17 @@ should_use_int_tensor_ops_ = frozenset(
"_convert_indices_from_csr_to_coo",
)
)
should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))
def should_use_int_tensor(op_name: str) -> bool:
return op_name in should_use_int_tensor_ops_
def should_use_complex_tensor(op_name: str) -> bool:
return op_name in should_use_complex_tensor_ops_
test_tensor_dim_ops_1_ = frozenset(
(
"addmv",
@@ -165,7 +215,7 @@ test_tensor_dim_ops_1_ = frozenset(
)
)
test_tensor_dim_ops_2_ = frozenset(
("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation")
("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
)
@@ -177,16 +227,31 @@ def test_tensor_dim(op_name: str) -> int:
return 3
test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json = json.loads(test_tensor_shapes_string)
def test_tensor_shape(op_name: str) -> str:
if op_name in test_tensor_shape_json:
return test_tensor_shape_json[op_name]
else:
return ""
def test_value_expression(
arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
) -> str:
num_tensors = 16 if index == 0 else 64
num_dim = test_tensor_dim(op_name)
size_per_dim = math.ceil(num_tensors / float(num_dim))
size_per_dim += size_per_dim % 2
tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim))
tensor_size_ex = test_tensor_shape(op_name)
if tensor_size_ex == "":
num_tensors = 16 if index == 0 else 64
num_dim = test_tensor_dim(op_name)
size_per_dim = math.ceil(num_tensors / float(num_dim))
size_per_dim += size_per_dim % 2
tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim))
if should_use_int_tensor(op_name):
tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
elif should_use_complex_tensor(op_name):
tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
else:
tensor_expression = f"at::rand({tensor_size_ex})"
@@ -212,8 +277,7 @@ def test_value_expression(
return value_expression
def generate_test_value_definitions(g: NativeFunctionsGroup, index: int) -> str:
schema = g.functional.func
def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
assert not schema.is_out_fn()
schema_name = schema.name.name.base
arg_map = {}
@@ -227,8 +291,7 @@ def generate_test_value_definitions(g: NativeFunctionsGroup, index: int) -> str:
return ";\n ".join(arg_populations) + ";"
def generate_test_value_names(g: NativeFunctionsGroup, index: int) -> str:
schema = g.functional.func
def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
assert not schema.is_out_fn()
return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
@@ -245,7 +308,7 @@ generate_test_ir_arguments_base_ty_to_type_str_ = {
def generate_test_ir_arguments(
g: NativeFunctionsGroup,
schema: FunctionSchema,
) -> List[Tuple[str, Optional[str]]]:
def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
t = arg.type
@@ -261,14 +324,10 @@ def generate_test_ir_arguments(
type_str = f"{type_str}?"
return ("%" + arg.name, type_str)
schema = g.functional.func
assert not schema.is_out_fn()
return [ir_argument(arg) for arg in schema.schema_order_arguments()]
def generate_arg_extraction(g: NativeFunctionsGroup) -> str:
schema = g.functional.func
assert not schema.is_out_fn()
def generate_arg_extraction(schema: FunctionSchema) -> str:
arg_populations = []
for i, arg in enumerate(schema.schema_order_arguments()):
maybe_method = ivalue_type_conversion_method(arg.type)
@@ -306,6 +365,19 @@ def generate_non_out_variant_call(
return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
def generate_call_to_view_ops(
g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
schema = g.view.func
kernel_name = cpp.name(schema)
kernel = backend_index.get_kernel(g.view)
if kernel:
kernel_name = kernel.kernel
arg_names = (arg.name for arg in schema.schema_order_arguments())
namespace_name = "native"
return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
def generate_out_variant_call(
g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
@@ -360,8 +432,8 @@ def op_name_from_group(g: NativeFunctionsGroup) -> str:
return g.functional.func.name.name.base
class GenOutVariantDispatcher:
def __call__(
class GenOpDispatcher:
def out_variant(
self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
) -> str:
if not groups:
@@ -371,7 +443,7 @@ class GenOutVariantDispatcher:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsGroup)
generated_type_variant = self.op_generator(g, backend_index)
generated_type_variant = self.out_variant_op_generator(g, backend_index)
generated_type_variants.append(generated_type_variant)
op_name = op_name_from_group(groups[0])
body = "\n".join(generated_type_variants)
@@ -387,11 +459,39 @@ REGISTER_OPERATOR_FUNCTOR(
"""
return generated
def op_generator(self, g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
def view(
self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
) -> str:
if not groups:
return ""
generated_type_variants = []
for g in groups:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsViewGroup)
generated_type_variant = self.view_op_generator(g, backend_index)
generated_type_variants.append(generated_type_variant)
op_name = config.func_name_base_str(groups[0])
body = "\n".join(generated_type_variants)
generated = f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::{op_name},
aten_{op_name},
[](Node* n) -> SROperator {{
{body}
LogAndDumpSchema(n);
return nullptr;
}});
"""
return generated
def out_variant_op_generator(
self, g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
functional = g.functional
schema = str(functional.func)
op_name = op_name_from_group(g)
populated_argument = generate_arg_extraction(g)
populated_argument = generate_arg_extraction(g.functional.func)
functional_variant_call = generate_non_out_variant_call(g, backend_index)
assert len(g.out.func.arguments.out) == 1
out_variable_name = str(g.out.func.arguments.out[0].name)
@@ -411,9 +511,25 @@ REGISTER_OPERATOR_FUNCTOR(
}}"""
return generated
def view_op_generator(
self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
schema = str(g.view.func)
op_name = config.func_name_base_str(g)
populated_argument = generate_arg_extraction(g.view.func)
functional_variant_call = generate_call_to_view_ops(g, backend_index)
generated = f"""
if (n->matches(torch::schema("aten::{schema}"))) {{
return [](ProcessedNode* p_node) {{
{populated_argument}
p_node->Output(0) = {functional_variant_call};
}};
}}"""
return generated
class GenOutVariantDispatcherTestCase:
def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
class GenOpTestCase:
def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
if not groups:
return ""
generated_type_variants = []
@@ -421,19 +537,31 @@ class GenOutVariantDispatcherTestCase:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsGroup)
generated_type_variant = self.test_case_generator(g)
generated_type_variant = self.out_variant_op_test_case_generator(g)
generated_type_variants.append(generated_type_variant)
return "\n".join(generated_type_variants)
def test_case_generator(self, g: NativeFunctionsGroup) -> str:
functional = g.functional
schema = str(functional.func)
assert schema.find("(") > 0
type_variant_op_name = schema[: schema.find("(")].replace(".", "_")
def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
if not groups:
return ""
generated_type_variants = []
for g in groups:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsViewGroup)
generated_type_variant = self.view_op_test_case_generator(g)
generated_type_variants.append(generated_type_variant)
return "\n".join(generated_type_variants)
def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
schema = g.functional.func
schema_str = str(schema)
assert schema_str.find("(") > 0
type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
op_name = op_name_from_group(g)
assert type_variant_op_name.startswith(op_name)
arg_types = generate_test_ir_arguments(g)
arg_types = generate_test_ir_arguments(schema)
arg_declarations = ", ".join(
(
arg_name if arg_type is None else f"{arg_name}: {arg_type}"
@@ -442,15 +570,15 @@ class GenOutVariantDispatcherTestCase:
)
arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
assert (
len(functional.func.returns) == 1
and isinstance(functional.func.returns[0].type, BaseType)
and functional.func.returns[0].type.name is BaseTy.Tensor
len(schema.returns) == 1
and isinstance(schema.returns[0].type, BaseType)
and schema.returns[0].type.name is BaseTy.Tensor
)
test_value_definitions = generate_test_value_definitions(g, 0)
test_value_names = generate_test_value_names(g, 0)
test_value_definitions2 = generate_test_value_definitions(g, 1)
test_value_names2 = generate_test_value_names(g, 1)
check_resize = "true" if should_check_resize(functional.func) else "false"
test_value_definitions = generate_test_value_definitions(schema, 0)
test_value_names = generate_test_value_names(schema, 0)
test_value_definitions2 = generate_test_value_definitions(schema, 1)
test_value_names2 = generate_test_value_names(schema, 1)
check_resize = "true" if should_check_resize(schema) else "false"
generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
const std::string script = R"IR(
@@ -472,3 +600,44 @@ TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
}}
"""
return generated
def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
schema = g.view.func
schema_str = str(schema)
assert schema_str.find("(") > 0
type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
op_name = g.view.root_name
assert type_variant_op_name.startswith(op_name)
arg_types = generate_test_ir_arguments(schema)
arg_declarations = ", ".join(
(
arg_name if arg_type is None else f"{arg_name}: {arg_type}"
for arg_name, arg_type in arg_types
)
)
arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
assert (
len(schema.returns) == 1
and isinstance(schema.returns[0].type, BaseType)
and schema.returns[0].type.name is BaseTy.Tensor
)
test_value_definitions = generate_test_value_definitions(schema, 0)
test_value_names = generate_test_value_names(schema, 0)
generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
const std::string script = R"IR(
graph({arg_declarations}):
%bias: None = prim::Constant()
%ret = aten::{op_name}({arg_names})
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
{test_value_definitions}
std::vector<IValue> args{{{test_value_names}}};
testStaticRuntime(script, args);
}}
"""
return generated