[static_runtime] Add script to auto-generate view ops (#77105)
Summary: Add a script that goes through the view ops in "native_functions.yaml", auto-registers them into static runtime, and auto-generates an op unit test for each. Overall there are 96 grouped view ops: 21 are already registered by hand; 9 (including sparse ops, training-related ops, etc.) are not targets of static runtime; 30 have list args or list returns; and 7 have non-basic types such as "Dimname", "MemoryFormat", etc. In summary, the script auto-generates 29 view ops for now. Run `buck run //caffe2/torch/fb/jit:gen_static_runtime_ops` to generate static runtime ops; the results with this script are:

```
total grouped native ops: 1582
grouped native ops with out variant: 548
generated functions groups with out variant: 241
view grouped native ops: 96
generated functions view groups: 29
overall generated : 270
```

The generated view ops are added in D36258968.

Test Plan:
Generate static runtime ops: `buck run //caffe2/torch/fb/jit:gen_static_runtime_ops`
Unit tests: `buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Differential Revision: D36258767

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77105
Approved by: https://github.com/mikeiovine
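The breakdown above accounts for all 96 grouped view ops; a quick arithmetic check (plain Python, just restating the numbers from the summary):

```
# 96 grouped view ops = hand-written + out-of-scope + list args/returns + non-basic types + auto-generated
hand_written, out_of_scope, list_arg_or_ret, non_basic_types, auto_generated = 21, 9, 30, 7, 29
assert hand_written + out_of_scope + list_arg_or_ret + non_basic_types + auto_generated == 96
```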
committed by PyTorch MergeBot
parent f69c990ecc
commit 1803a592f4
@@ -1,10 +1,13 @@
-from torchgen.model import NativeFunctionsGroup
+from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
 
-from typing import Dict
+from typing import Dict, Union
 
 
-def func_name_base_str(g: NativeFunctionsGroup) -> str:
-    return str(g.functional.func.name.name.base)
+def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str:
+    if isinstance(g, NativeFunctionsGroup):
+        return str(g.functional.func.name.name.base)
+    else:
+        return str(g.view.root_name)
 
 
 is_hand_written_ops_ = frozenset(
@@ -35,6 +38,19 @@ is_hand_written_ops_ = frozenset(
         "sign",
         "sub",
         "tanh",
+        "detach",
+        "expand_as",
+        "flatten",
+        "narrow",
+        "reshape_as",
+        "select",
+        "slice",
+        "softmax",
+        "split",
+        "squeeze",
+        "transpose",
+        "view",
+        "where",
     )
 )
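The generator below calls `config.is_hand_written(g)` to skip these ops; that helper is not part of this hunk, but presumably it is a simple membership check against `is_hand_written_ops_` via the updated `func_name_base_str`, along the lines of this sketch (hypothetical, for orientation only):

```
# Hypothetical sketch of config.is_hand_written: ops whose base name is in
# is_hand_written_ops_ (defined above) are skipped by the code generator.
from typing import Union

from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup


def is_hand_written(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    return func_name_base_str(g) in is_hand_written_ops_
```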
@@ -349,3 +365,8 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N
         arg_map["size"] = "24"
         arg_map["out_int32"] = "false"
         return
+    if op_name in ("diagonal", "linalg_diagonal"):
+        arg_map["offset"] = "0"
+        arg_map["dim0"] = "1"
+        arg_map["dim1"] = "2"
+        return
@@ -1,12 +1,13 @@
 from torchgen import gen
 from torchgen.context import native_function_manager
-from torchgen.model import DispatchKey, NativeFunctionsGroup
+from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
 from torchgen.static_runtime import generator
 
 import argparse
 import itertools
 import os
-from typing import Sequence
+from typing import Sequence, Union
 from libfb.py.log import set_simple_logging
 
 # Given a list of `grouped_native_functions` sorted by their op names, return a list of
 # lists each of which groups ops that share the base name. For example, `mean` and
@@ -14,15 +15,15 @@ from typing import Sequence
 
 
 def group_functions_by_op_name(
-    grouped_native_functions: Sequence[NativeFunctionsGroup],
-) -> Sequence[Sequence[NativeFunctionsGroup]]:
+    grouped_native_functions: Sequence[
+        Union[NativeFunctionsGroup, NativeFunctionsViewGroup]
+    ]
+) -> Sequence[Sequence[Union[NativeFunctionsGroup, NativeFunctionsViewGroup]]]:
     if not grouped_native_functions:
         return []
     groups = []
     current_op_name = None
     current_group = None
 
-    def is_supported(g: NativeFunctionsGroup) -> bool:
+    def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
         with native_function_manager(g):
             return generator.is_supported(g)
@@ -31,7 +32,10 @@ def group_functions_by_op_name(
         list(group)
         for k, group in (
             itertools.groupby(
-                eligible_ops, key=lambda g: g.functional.func.name.name.base
+                eligible_ops,
+                key=lambda g: g.functional.func.name.name.base
+                if isinstance(g, NativeFunctionsGroup)
+                else g.view.root_name,
             )
         )
     ]
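A standalone sketch of the grouping step above, with plain strings standing in for the torchgen group objects (overloads such as `mean` and `mean.dim` share the base name `mean`, so they end up in the same group):

```
# Simplified illustration of group_functions_by_op_name's use of itertools.groupby;
# the real code keys on g.functional.func.name.name.base or g.view.root_name instead.
import itertools

eligible_ops = ["mean", "mean.dim", "transpose", "transpose.int", "view"]
groups = [
    list(group)
    for _, group in itertools.groupby(eligible_ops, key=lambda name: name.split(".")[0])
]
print(groups)  # [['mean', 'mean.dim'], ['transpose', 'transpose.int'], ['view']]
```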
@@ -147,34 +151,73 @@ def main() -> None:
         parsed_yaml.native_functions,
         parsed_yaml.backend_indices,
     )
-    grouped_native_functions = gen.get_grouped_native_functions(native_functions)
-    structured_native_functions = [
-        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
-    ]
-    supported_function_groups = group_functions_by_op_name(structured_native_functions)
-
-    gen_out_variant_dispatcher = generator.GenOutVariantDispatcher()
-    result = [
-        gen_out_variant_dispatcher(groups, backend_indices[DispatchKey.CPU])
-        for groups in supported_function_groups
-    ]
-
-    gen_out_variant_dispatcher_test_case = generator.GenOutVariantDispatcherTestCase()
-    test_result = [
-        gen_out_variant_dispatcher_test_case(groups)
-        for groups in supported_function_groups
-    ]
-
-    write_cpp(result, options.generated_ops_cpp_path)
+    op_generator = generator.GenOpDispatcher()
+    test_case_generator = generator.GenOpTestCase()
+
+    native_functions_groups = [
+        g
+        for g in gen.get_grouped_native_functions(native_functions)
+        if isinstance(g, NativeFunctionsGroup)
+    ]
+
+    supported_functions_groups = group_functions_by_op_name(native_functions_groups)
+
+    out_variant_op_result = [
+        op_generator.out_variant(groups, backend_indices[DispatchKey.CPU])
+        for groups in supported_functions_groups
+    ]
+    out_variant_test_result = [
+        test_case_generator.out_variant(groups) for groups in supported_functions_groups
+    ]
+
+    native_functions_view_groups = [
+        g
+        for g in gen.get_grouped_by_view_native_functions(native_functions)
+        if isinstance(g, NativeFunctionsViewGroup)
+    ]
+
+    supported_functions_view_groups = group_functions_by_op_name(
+        native_functions_view_groups
+    )
+
+    view_op_result = [
+        op_generator.view(groups, backend_indices[DispatchKey.CPU])
+        for groups in supported_functions_view_groups
+    ]
+    view_test_result = [
+        test_case_generator.view(groups) for groups in supported_functions_view_groups
+    ]
+
+    op_result = out_variant_op_result + ["\n\n"] + view_op_result
+    test_result = out_variant_test_result + ["\n\n"] + view_test_result
+
+    write_cpp(op_result, options.generated_ops_cpp_path)
     write_test_cpp(test_result, options.generated_ops_test_cpp_path)
 
-    print("total grouped native ops: %d" % len(grouped_native_functions))
-    print("structured grouped native ops: %d" % len(structured_native_functions))
-    supported_grouped_functions = sum(
-        [len(groups) for groups in supported_function_groups]
-    )
-    print("generated grouped native ops: %d" % supported_grouped_functions)
+    print(
+        "\ntotal grouped native ops: %d"
+        % len(gen.get_grouped_native_functions(native_functions))
+    )
+
+    print("grouped native ops with out variant: %d" % len(native_functions_groups))
+    supported_functions_num = sum(
+        [len(groups) for groups in supported_functions_groups]
+    )
+    print("generated functions groups with out variant: %d" % supported_functions_num)
+
+    print("\nview grouped native ops: %d" % len(native_functions_view_groups))
+    supported_view_functions_num = sum(
+        [len(groups) for groups in supported_functions_view_groups]
+    )
+    print("generated functions view groups: %d" % supported_view_functions_num)
+
+    print(
+        "\noverall generated : %d"
+        % (supported_functions_num + supported_view_functions_num)
+    )
 
 
 if __name__ == "__main__":
     set_simple_logging(escape_newlines=False)
     main()
@@ -11,12 +11,17 @@ from torchgen.model import (
     NativeFunctionsGroup,
     TensorOptionsArguments,
     Type,
+    NativeFunctionsViewGroup,
 )
 from torchgen.static_runtime import config
 
 import math
 import logging
+import json
 from typing import List, Optional, Sequence, Tuple, Union
 
 logger: logging.Logger = logging.getLogger()
 
 
 def has_alias(
     arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
@@ -38,6 +43,13 @@ BLOCKED_OPS = frozenset(
         "hspmm",
         # sparse ops
         "sspaddmm",
+        "coalesce",
+        "_indices",
+        "indices",
+        "_values",
+        "values",
+        "crow_indices",
+        "col_indices",
         # deprecated ops
         "floor_divide",
         "ger",
@@ -50,41 +62,72 @@ BLOCKED_OPS = frozenset(
         "lu_solve",
         "linalg_cholesky",
         "linalg_householder_product",
         "linalg_ldl_solve",
         "_compute_linear_combination",
         # training related ops
         "_make_dual",
+        # cannot call directly
+        "_fw_primal",
+        # no documentation
+        "_index_reduce",
     )
 )
 
 
-def is_supported(g: NativeFunctionsGroup) -> bool:
-    base_op_name = g.out.func.name.name.base
-    if base_op_name in BLOCKED_OPS:
-        return False
+def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
+    base_op_name = ""
+    func = None
+    if isinstance(g, NativeFunctionsViewGroup):
+        base_op_name = g.view.root_name
+        func = g.view.func
+    else:
+        base_op_name = g.out.func.name.name.base
+        func = g.out.func
     if config.is_hand_written(g):
+        logger.info(f"HAND WRITTEN: {base_op_name}")
         return False
+    if base_op_name in BLOCKED_OPS:
+        logger.info(f"BLOCKED: {base_op_name}")
+        return False
+    for arg in func.schema_order_arguments():
+        maybe_method = ivalue_type_conversion_method(arg.type)
+        if not maybe_method:
+            # Type converting is unsupported yet.
+            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
+            return False
+
+    if isinstance(g, NativeFunctionsViewGroup):
+        if "at::Tensor" != cpp.returns_type(func.returns).cpp_type():
+            # Returns a non-Tensor value.
+            logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
+            return False
+        return True
+
+    # For out variant ops, we need to check the arguments of its functional func.
+    for arg in g.functional.func.schema_order_arguments():
+        maybe_method = ivalue_type_conversion_method(arg.type)
+        if not maybe_method:
+            # Type converting is unsupported yet.
+            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(g.functional.func)}")
+            return False
+
     if not g.structured:
         # In case of unstructured op, we check if it has out variant implementation.
         # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
         # parameter.
         if (
             not hasattr(g, "out")
-            or not str(g.out.func).endswith("Tensor(a!) out) -> Tensor(a!)")
-            or not str(g.out.func.name).endswith(".out")
+            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
+            or not str(func.name).endswith(".out")
         ):
             return False
-    if has_alias(g.out.func.arguments.non_out):
+    if "at::Tensor &" != cpp.returns_type(func.returns).cpp_type():
+        logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
+        return False
+    if has_alias(func.arguments.non_out):
         # This op may create an alias of inputs.
+        logger.info(f"INPUTS ALIAS: {base_op_name}")
         return False
-    if len(g.out.func.arguments.out) > 1:
-        # More than 1 output values.
-        return False
-    if "at::Tensor &" != cpp.returns_type(g.out.func.returns).cpp_type():
-        # Returns a non-Tensor value.
-        return False
-    for arg in g.out.func.schema_order_arguments():
-        maybe_method = ivalue_type_conversion_method(arg.type)
-        if not maybe_method:
-            # Type converting is unsupported yet.
-            return False
     return True
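The view branch above only accepts ops whose schema returns a single `Tensor`, which is why list-returning view ops (the `split`/`chunk` family) are not auto-generated. A toy string check on two representative schemas, not the torchgen logic:

```
# Toy illustration of the single-Tensor-return gate; torchgen actually checks
# cpp.returns_type(func.returns).cpp_type() == "at::Tensor" instead.
def returns_single_tensor(schema: str) -> bool:
    ret = schema.split("->")[-1].strip()
    return ret in ("Tensor", "Tensor(a)")


assert returns_single_tensor(
    "transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"
)
assert not returns_single_tensor(
    "chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"
)
```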
@@ -137,6 +180,8 @@ should_use_int_tensor_ops_ = frozenset(
         "bitwise_and",
         "bitwise_or",
         "bitwise_xor",
+        "bitwise_left_shift",
+        "bitwise_right_shift",
         "gcd",
         "lcm",
         "scatter",
@@ -145,12 +190,17 @@ should_use_int_tensor_ops_ = frozenset(
         "_convert_indices_from_csr_to_coo",
     )
 )
+should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))
 
 
 def should_use_int_tensor(op_name: str) -> bool:
     return op_name in should_use_int_tensor_ops_
 
 
+def should_use_complex_tensor(op_name: str) -> bool:
+    return op_name in should_use_complex_tensor_ops_
+
+
 test_tensor_dim_ops_1_ = frozenset(
     (
         "addmv",
@@ -165,7 +215,7 @@ test_tensor_dim_ops_1_ = frozenset(
     )
 )
 test_tensor_dim_ops_2_ = frozenset(
-    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation")
+    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
 )
@@ -177,16 +227,31 @@ def test_tensor_dim(op_name: str) -> int:
     return 3
 
 
+test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
+test_tensor_shape_json = json.loads(test_tensor_shapes_string)
+
+
+def test_tensor_shape(op_name: str) -> str:
+    if op_name in test_tensor_shape_json:
+        return test_tensor_shape_json[op_name]
+    else:
+        return ""
+
+
 def test_value_expression(
     arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
 ) -> str:
-    num_tensors = 16 if index == 0 else 64
-    num_dim = test_tensor_dim(op_name)
-    size_per_dim = math.ceil(num_tensors / float(num_dim))
-    size_per_dim += size_per_dim % 2
-    tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim))
+    tensor_size_ex = test_tensor_shape(op_name)
+    if tensor_size_ex == "":
+        num_tensors = 16 if index == 0 else 64
+        num_dim = test_tensor_dim(op_name)
+        size_per_dim = math.ceil(num_tensors / float(num_dim))
+        size_per_dim += size_per_dim % 2
+        tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim))
     if should_use_int_tensor(op_name):
         tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
+    elif should_use_complex_tensor(op_name):
+        tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
     else:
         tensor_expression = f"at::rand({tensor_size_ex})"
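A standalone sketch of the test-shape selection above: an op listed in the JSON override table gets its shape verbatim (e.g. `view_as_complex` needs a trailing dimension of 2), everything else falls back to the computed per-dimension size:

```
# Self-contained restatement of test_tensor_shape plus the fallback in
# test_value_expression; no torchgen imports needed.
import json
import math

test_tensor_shape_json = json.loads('{"view_as_complex": "{2, 2}"}')


def pick_test_shape(op_name: str, index: int, num_dim: int) -> str:
    override = test_tensor_shape_json.get(op_name, "")
    if override:
        return override
    num_tensors = 16 if index == 0 else 64
    size_per_dim = math.ceil(num_tensors / float(num_dim))
    size_per_dim += size_per_dim % 2
    return "{%s}" % ",".join([str(size_per_dim)] * num_dim)


print(pick_test_shape("view_as_complex", 0, 2))  # {2, 2}
print(pick_test_shape("transpose", 0, 3))        # {6,6,6}
```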
@@ -212,8 +277,7 @@ def test_value_expression(
     return value_expression
 
 
-def generate_test_value_definitions(g: NativeFunctionsGroup, index: int) -> str:
-    schema = g.functional.func
+def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
     assert not schema.is_out_fn()
     schema_name = schema.name.name.base
     arg_map = {}
@@ -227,8 +291,7 @@ def generate_test_value_definitions(g: NativeFunctionsGroup, index: int) -> str:
     return ";\n ".join(arg_populations) + ";"
 
 
-def generate_test_value_names(g: NativeFunctionsGroup, index: int) -> str:
-    schema = g.functional.func
+def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
     assert not schema.is_out_fn()
     return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
@@ -245,7 +308,7 @@ generate_test_ir_arguments_base_ty_to_type_str_ = {
 
 
 def generate_test_ir_arguments(
-    g: NativeFunctionsGroup,
+    schema: FunctionSchema,
 ) -> List[Tuple[str, Optional[str]]]:
     def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
         t = arg.type
@@ -261,14 +324,10 @@ def generate_test_ir_arguments(
             type_str = f"{type_str}?"
         return ("%" + arg.name, type_str)
 
-    schema = g.functional.func
-    assert not schema.is_out_fn()
     return [ir_argument(arg) for arg in schema.schema_order_arguments()]
 
 
-def generate_arg_extraction(g: NativeFunctionsGroup) -> str:
-    schema = g.functional.func
-    assert not schema.is_out_fn()
+def generate_arg_extraction(schema: FunctionSchema) -> str:
     arg_populations = []
     for i, arg in enumerate(schema.schema_order_arguments()):
         maybe_method = ivalue_type_conversion_method(arg.type)
@@ -306,6 +365,19 @@ def generate_non_out_variant_call(
     return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
 
 
+def generate_call_to_view_ops(
+    g: NativeFunctionsViewGroup, backend_index: BackendIndex
+) -> str:
+    schema = g.view.func
+    kernel_name = cpp.name(schema)
+    kernel = backend_index.get_kernel(g.view)
+    if kernel:
+        kernel_name = kernel.kernel
+    arg_names = (arg.name for arg in schema.schema_order_arguments())
+    namespace_name = "native"
+    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
+
+
 def generate_out_variant_call(
     g: NativeFunctionsGroup, backend_index: BackendIndex
 ) -> str:
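For a view op, `generate_call_to_view_ops` produces a call into `at::native::` using the kernel name from the backend index (or the schema's cpp name). A toy rendering with hypothetical values, just to show the shape of the string:

```
# Illustration only: the real kernel_name and arg_names come from cpp.name(schema) /
# backend_index.get_kernel(g.view) and schema.schema_order_arguments().
kernel_name = "transpose"  # assumed kernel name for the example
arg_names = ("self", "dim0", "dim1")
namespace_name = "native"
print(f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})')
# -> at::native::transpose(self,dim0,dim1)
```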
@@ -360,8 +432,8 @@ def op_name_from_group(g: NativeFunctionsGroup) -> str:
     return g.functional.func.name.name.base
 
 
-class GenOutVariantDispatcher:
-    def __call__(
+class GenOpDispatcher:
+    def out_variant(
         self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
     ) -> str:
         if not groups:
@@ -371,7 +443,7 @@ class GenOutVariantDispatcher:
             with native_function_manager(g):
                 assert is_supported(g)
-                generated_type_variant = self.op_generator(g, backend_index)
+                assert isinstance(g, NativeFunctionsGroup)
+                generated_type_variant = self.out_variant_op_generator(g, backend_index)
                 generated_type_variants.append(generated_type_variant)
         op_name = op_name_from_group(groups[0])
         body = "\n".join(generated_type_variants)
@@ -387,11 +459,39 @@ REGISTER_OPERATOR_FUNCTOR(
 """
         return generated
 
-    def op_generator(self, g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
+    def view(
+        self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
+    ) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsViewGroup)
+                generated_type_variant = self.view_op_generator(g, backend_index)
+                generated_type_variants.append(generated_type_variant)
+        op_name = config.func_name_base_str(groups[0])
+        body = "\n".join(generated_type_variants)
+        generated = f"""
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::{op_name},
+    aten_{op_name},
+    [](Node* n) -> SROperator {{
+      {body}
+      LogAndDumpSchema(n);
+      return nullptr;
+    }});
+"""
+        return generated
+
+    def out_variant_op_generator(
+        self, g: NativeFunctionsGroup, backend_index: BackendIndex
+    ) -> str:
         functional = g.functional
         schema = str(functional.func)
         op_name = op_name_from_group(g)
-        populated_argument = generate_arg_extraction(g)
+        populated_argument = generate_arg_extraction(g.functional.func)
         functional_variant_call = generate_non_out_variant_call(g, backend_index)
         assert len(g.out.func.arguments.out) == 1
         out_variable_name = str(g.out.func.arguments.out[0].name)
@@ -411,9 +511,25 @@ REGISTER_OPERATOR_FUNCTOR(
   }}"""
         return generated
 
+    def view_op_generator(
+        self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
+    ) -> str:
+        schema = str(g.view.func)
+        op_name = config.func_name_base_str(g)
+        populated_argument = generate_arg_extraction(g.view.func)
+        functional_variant_call = generate_call_to_view_ops(g, backend_index)
+        generated = f"""
+  if (n->matches(torch::schema("aten::{schema}"))) {{
+    return [](ProcessedNode* p_node) {{
+      {populated_argument}
+      p_node->Output(0) = {functional_variant_call};
+    }};
+  }}"""
+        return generated
+
 
-class GenOutVariantDispatcherTestCase:
-    def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
+class GenOpTestCase:
+    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
         if not groups:
             return ""
         generated_type_variants = []
@@ -421,19 +537,31 @@ class GenOutVariantDispatcherTestCase:
             with native_function_manager(g):
                 assert is_supported(g)
-                generated_type_variant = self.test_case_generator(g)
+                assert isinstance(g, NativeFunctionsGroup)
+                generated_type_variant = self.out_variant_op_test_case_generator(g)
                 generated_type_variants.append(generated_type_variant)
         return "\n".join(generated_type_variants)
 
-    def test_case_generator(self, g: NativeFunctionsGroup) -> str:
-        functional = g.functional
-        schema = str(functional.func)
-        assert schema.find("(") > 0
-        type_variant_op_name = schema[: schema.find("(")].replace(".", "_")
+    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsViewGroup)
+                generated_type_variant = self.view_op_test_case_generator(g)
+                generated_type_variants.append(generated_type_variant)
+        return "\n".join(generated_type_variants)
+
+    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
+        schema = g.functional.func
+        schema_str = str(schema)
+        assert schema_str.find("(") > 0
+        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
         op_name = op_name_from_group(g)
         assert type_variant_op_name.startswith(op_name)
 
-        arg_types = generate_test_ir_arguments(g)
+        arg_types = generate_test_ir_arguments(schema)
         arg_declarations = ", ".join(
             (
                 arg_name if arg_type is None else f"{arg_name}: {arg_type}"
@@ -442,15 +570,15 @@ class GenOutVariantDispatcherTestCase:
         )
         arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
         assert (
-            len(functional.func.returns) == 1
-            and isinstance(functional.func.returns[0].type, BaseType)
-            and functional.func.returns[0].type.name is BaseTy.Tensor
+            len(schema.returns) == 1
+            and isinstance(schema.returns[0].type, BaseType)
+            and schema.returns[0].type.name is BaseTy.Tensor
         )
-        test_value_definitions = generate_test_value_definitions(g, 0)
-        test_value_names = generate_test_value_names(g, 0)
-        test_value_definitions2 = generate_test_value_definitions(g, 1)
-        test_value_names2 = generate_test_value_names(g, 1)
-        check_resize = "true" if should_check_resize(functional.func) else "false"
+        test_value_definitions = generate_test_value_definitions(schema, 0)
+        test_value_names = generate_test_value_names(schema, 0)
+        test_value_definitions2 = generate_test_value_definitions(schema, 1)
+        test_value_names2 = generate_test_value_names(schema, 1)
+        check_resize = "true" if should_check_resize(schema) else "false"
         generated = f"""
 TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
   const std::string script = R"IR(
@@ -472,3 +600,44 @@ TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
 }}
 """
         return generated
+
+    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
+        schema = g.view.func
+        schema_str = str(schema)
+        assert schema_str.find("(") > 0
+        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
+        op_name = g.view.root_name
+        assert type_variant_op_name.startswith(op_name)
+
+        arg_types = generate_test_ir_arguments(schema)
+        arg_declarations = ", ".join(
+            (
+                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
+                for arg_name, arg_type in arg_types
+            )
+        )
+        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
+        assert (
+            len(schema.returns) == 1
+            and isinstance(schema.returns[0].type, BaseType)
+            and schema.returns[0].type.name is BaseTy.Tensor
+        )
+        test_value_definitions = generate_test_value_definitions(schema, 0)
+        test_value_names = generate_test_value_names(schema, 0)
+        generated = f"""
+TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
+  const std::string script = R"IR(
+    graph({arg_declarations}):
+        %bias: None = prim::Constant()
+        %ret = aten::{op_name}({arg_names})
+        %cloned = aten::clone(%ret, %bias)
+        return (%cloned)
+  )IR";
+
+  {test_value_definitions}
+  std::vector<IValue> args{{{test_value_names}}};
+  testStaticRuntime(script, args);
+}}
+"""
+
+        return generated