[Static Runtime] [RFC] Codegen support for ops with unstructured kernels (#76203)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76203

Request for comments:

This change adds extra code generator support to generate out variant wrappers for operators with unstructured kernels.

The current version generates 105 new out variant wrappers in addition to the existing 136 auto-generated out variant wrappers.

This change shows that a simple tweak increases the generated op coverage to 16% (241/1559) of all native ops described in native_functions.yaml, whether they are structured or not.
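
For illustration, a wrapper generated for an unstructured op looks roughly like the sketch below. This is a hand-written approximation of the generator's output template (the `at::native::` call path is what the unstructured-kernel support adds), not verbatim generated code; `aten::mv` is used only as an example.

```
// Rough sketch of a generated out-variant wrapper for an unstructured op,
// as it would appear in generated_ops.cpp. Approximates the
// GenOutVariantDispatcher template; not verbatim generator output.
REGISTER_OPERATOR_FUNCTOR(
    aten::mv,
    aten_mv,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::mv(Tensor self, Tensor vec) -> Tensor"))) {
        return [](ProcessedNode* p_node) {
          const auto& self = p_node->Input(0).toTensor();
          const auto& vec = p_node->Input(1).toTensor();
          if (p_node->Output(0).isNone()) {
            // First run: no output tensor yet, call the functional kernel.
            // Unstructured ops dispatch through at::native instead of at::cpu.
            p_node->Output(0) = at::native::mv(self, vec);
            return;
          }
          // Later runs: reuse the preallocated output via the out-variant kernel.
          auto& out = p_node->Output(0).toTensor();
          fastResizeToZero(out);
          at::native::mv_out(self, vec, out);
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });
```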

Command to generate out variant wrappers:
```
buck run //caffe2/torch/fb/jit:gen_static_runtime_ops
```
- AFTER this change
```
total grouped native ops: 1559
structured grouped native ops: 545
generated grouped native ops: 241
```

- BEFORE this change
```
total grouped native ops: 1503
structured grouped native ops: 540
generated grouped native ops: 136
```

To enable CI tests and make it easier to review, the generated ops are added in a separate diff: D35945633

More details:
We added a block list that excludes around 10 operators that are deprecated or whose unit tests would fail. All generated ops *compile*, but a compiled unit test may still fail when an op lacks hand-picked test input values. Of the 42 ops whose unit tests initially failed: 1 ("index_select") duplicates an existing hand-written op; 32 were fixed; and 9 were removed and blocked from generation, either because they are not commonly used in internal models (e.g., "cholesky", "linalg_householder_product", and the sparse kernel "sspaddmm") or because they cause errors in Static Runtime (e.g., "conj_physical", which triggers an error in the memory planner, and "binary_cross_entropy").
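
To make "hand-picked test input values" concrete, each generated op also gets a cpptest that runs the op through Static Runtime with the two argument sets produced by `override_test_values` (index 0 and index 1). The sketch below is a hedged approximation of such a test for `aten::mv`; the exact IR, helper signature, and defaults in the generated test file may differ.

```
// Hedged sketch of an auto-generated test case; approximates the shape of
// benchmarks/static_runtime/test_generated_ops.cc, not verbatim generated code.
TEST(StaticRuntime, autogen_mv) {
  const std::string script = R"IR(
    graph(%self: Tensor, %vec: Tensor):
        %none: None = prim::Constant()
        %ret = aten::mv(%self, %vec)
        %cloned = aten::clone(%ret, %none)
        return (%cloned)
  )IR";

  // First argument set (index == 0 in override_test_values).
  auto self0 = at::rand({6, 6});
  auto vec0 = at::rand({6});
  std::vector<c10::IValue> args{self0, vec0};

  // Second, larger argument set (index == 1) exercises output resizing.
  auto self1 = at::rand({22, 22});
  auto vec1 = at::rand({22});
  std::vector<c10::IValue> args2{self1, vec1};

  testStaticRuntime(script, args);
  testStaticRuntime(script, args, args2);
}
```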

Test Plan:
Op generation:
```
buck run //caffe2/torch/fb/jit:gen_static_runtime_ops
```

Test generated ops:
```
buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest
```

Reviewed By: tenpercent

Differential Revision: D34913736

fbshipit-source-id: a6f408321653c3589ae1c76826177fc403d59c44
(cherry picked from commit 6f4501730478dbaeeea7f3ad4f9d29bf6787e7c1)
Author: Hui Guo
Date: 2022-05-04 12:22:43 -07:00
Committed by: PyTorch MergeBot
Parent: fc64dbdc01
Commit: ca0f267022
3 changed files with 270 additions and 27 deletions

View File

@@ -9,6 +9,7 @@ def func_name_base_str(g: NativeFunctionsGroup) -> str:
is_hand_written_ops_ = frozenset(
(
"abs",
"add",
"addmm",
"all",
@@ -16,12 +17,18 @@ is_hand_written_ops_ = frozenset(
"argmin",
"bmm",
"clamp",
"clamp_min",
"cumsum",
"div",
"fmod",
"index_select",
"leaky_relu",
"linear",
"log",
"matmul",
"mul",
"narrow_copy",
"nonzero",
"pow",
"remainder",
"sigmoid",
@@ -39,6 +46,140 @@ def is_hand_written(g: NativeFunctionsGroup) -> bool:
def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
assert index == 0 or index == 1
if op_name == "addr":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["vec1"] = "at::rand({6})"
arg_map["vec2"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["vec1"] = "at::rand({22})"
arg_map["vec2"] = "at::rand({22})"
return
if op_name == "mv":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["vec"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["vec"] = "at::rand({22})"
return
if op_name == "addbmm":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
else:
arg_map["self"] = "at::rand({22, 22})"
return
if op_name == "cross":
if index == 0:
arg_map["self"] = "at::rand({3, 3, 3})"
arg_map["other"] = "at::rand({3, 3, 3})"
else:
arg_map["self"] = "at::rand({22, 3, 22})"
arg_map["other"] = "at::rand({22, 3, 22})"
return
if op_name == "take":
if index == 0:
arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)"
else:
arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)"
return
if op_name == "take_along_dim":
if index == 0:
arg_map["indices"] = "at::argsort(self0, 1)"
else:
arg_map["indices"] = "at::argsort(self1, 1)"
return
if op_name == "masked_select":
if index == 0:
arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5"
else:
arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5"
return
if op_name == "orgqr":
if index == 0:
arg_map["input2"] = "at::rand({6, 6})"
else:
arg_map["input2"] = "at::rand({22, 22})"
return
if op_name == "ormqr":
if index == 0:
arg_map["input2"] = "at::rand({6, 6})"
else:
arg_map["input2"] = "at::rand({22, 22})"
return
if op_name == "quantile":
if index == 0:
arg_map["q"] = "at::rand({6})"
arg_map["interpolation"] = '"linear"'
else:
arg_map["q"] = "at::rand({22})"
arg_map["interpolation"] = '"linear"'
return
if op_name == "nanquantile":
if index == 0:
arg_map["q"] = "at::rand({6})"
arg_map["interpolation"] = '"linear"'
else:
arg_map["q"] = "at::rand({22})"
arg_map["interpolation"] = '"linear"'
return
if op_name == "multi_margin_loss":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
arg_map["weight"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
arg_map["weight"] = "at::rand({22})"
return
if op_name == "multilabel_margin_loss":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)"
return
if op_name == "nll_loss":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
arg_map["weight"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
arg_map["weight"] = "at::rand({22})"
return
if op_name == "nll_loss2d":
if index == 0:
arg_map["self"] = "at::rand({6, 6, 6, 6})"
arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
arg_map["weight"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22, 22, 22})"
arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
arg_map["weight"] = "at::rand({22})"
return
if op_name in (
"fft_fft",
"fft_ifft",
"fft_rfft",
"fft_irfft",
"fft_hfft",
"fft_ihfft",
):
arg_map["norm"] = '"forward"'
return
if op_name == "linalg_tensorinv":
if index == 0:
arg_map["self"] = "at::rand({6, 6, 6, 6})"
arg_map["ind"] = "2"
else:
arg_map["self"] = "at::rand({22, 22, 22, 22})"
arg_map["ind"] = "2"
return
if op_name == "addmv":
if index == 0:
arg_map["self"] = "at::rand({2})"
@@ -171,6 +312,13 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N
if "reduce" in arg_map:
arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'
return
if op_name == "scatter_reduce":
arg_map["reduce"] = '"mean"'
if index == 0:
arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
else:
arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
return
if op_name == "special_zeta":
if index == 0:
arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"

View File

@@ -1,7 +1,7 @@
from torchgen import gen
from torchgen.context import native_function_manager
from torchgen.model import NativeFunctionsGroup
from torchgen.static_runtime import gen_structured
from torchgen.model import DispatchKey, NativeFunctionsGroup
from torchgen.static_runtime import generator
import argparse
import itertools
@@ -11,6 +11,8 @@ from typing import Sequence
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
# lists each of which groups ops that share the base name. For example, `mean` and
# `mean.dim` are grouped together by this function.
def group_functions_by_op_name(
grouped_native_functions: Sequence[NativeFunctionsGroup],
) -> Sequence[Sequence[NativeFunctionsGroup]]:
@@ -22,7 +24,7 @@ def group_functions_by_op_name(
def is_supported(g: NativeFunctionsGroup) -> bool:
with native_function_manager(g):
return gen_structured.is_supported(g)
return generator.is_supported(g)
eligible_ops = (g for g in grouped_native_functions if is_supported(g))
groups = [
@@ -33,6 +35,7 @@ def group_functions_by_op_name(
)
)
]
return groups
@@ -122,19 +125,19 @@ def main() -> None:
"-s",
"--source-path",
help="path to source directory for ATen",
default="aten/src/ATen",
default="caffe2/aten/src/ATen",
)
parser.add_argument(
"-p",
"--generated-ops-cpp-path",
help="path to directory to generate op dispatcher .cpp file",
default="torch/csrc/jit/runtime/static/generated_ops.cpp",
default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp",
)
parser.add_argument(
"-t",
"--generated-ops-test-cpp-path",
help="path to directory to generate op dispatcher .cpp file",
default="benchmarks/static_runtime/test_generated_ops.cc",
default="caffe2/benchmarks/static_runtime/test_generated_ops.cc",
)
options = parser.parse_args()
native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
@@ -150,14 +153,13 @@ def main() -> None:
]
supported_function_groups = group_functions_by_op_name(structured_native_functions)
gen_out_variant_dispatcher = gen_structured.GenOutVariantDispatcher()
gen_out_variant_dispatcher = generator.GenOutVariantDispatcher()
result = [
gen_out_variant_dispatcher(groups) for groups in supported_function_groups
gen_out_variant_dispatcher(groups, backend_indices[DispatchKey.CPU])
for groups in supported_function_groups
]
gen_out_variant_dispatcher_test_case = (
gen_structured.GenOutVariantDispatcherTestCase()
)
gen_out_variant_dispatcher_test_case = generator.GenOutVariantDispatcherTestCase()
test_result = [
gen_out_variant_dispatcher_test_case(groups)
for groups in supported_function_groups

View File

@@ -2,6 +2,7 @@ import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
FunctionSchema,
OptionalType,
@@ -30,11 +31,46 @@ def has_alias(
return False
BLOCKED_OPS = frozenset(
(
# non cpu ops
"sparse_sampled_addmm",
"hspmm",
# sparse ops
"sspaddmm",
# deprecated ops
"floor_divide",
"ger",
# buggy ops
"conj_physical", # P495807361
"binary_cross_entropy", # P496394764
"arccosh",
# uncommon ops
"cholesky",
"lu_solve",
"linalg_cholesky",
"linalg_householder_product",
"_compute_linear_combination",
)
)
def is_supported(g: NativeFunctionsGroup) -> bool:
if not g.structured:
base_op_name = g.out.func.name.name.base
if base_op_name in BLOCKED_OPS:
return False
if config.is_hand_written(g):
return False
if not g.structured:
# In case of unstructured op, we check if it has out variant implementation.
# The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
# parameter.
if (
not hasattr(g, "out")
or not str(g.out.func).endswith("Tensor(a!) out) -> Tensor(a!)")
or not str(g.out.func.name).endswith(".out")
):
return False
if has_alias(g.out.func.arguments.non_out):
# This op may create an alias of inputs.
return False
@@ -80,7 +116,9 @@ def ivalue_type_conversion_method(
if isinstance(arg_type, BaseType):
base_ty_object = arg_type.name
elif isinstance(arg_type, OptionalType):
assert isinstance(arg_type.elem, BaseType)
if not isinstance(arg_type.elem, BaseType):
# ListType is currently unsupported.
return None
base_ty_object = arg_type.elem.name
else:
return None
@@ -120,9 +158,15 @@ test_tensor_dim_ops_1_ = frozenset(
"_convert_indices_from_coo_to_csr",
"_convert_indices_from_csr_to_coo",
"nll_loss_backward",
"dot",
"vdot",
"outer",
"ger",
)
)
test_tensor_dim_ops_2_ = frozenset(("addmm", "mm"))
test_tensor_dim_ops_2_ = frozenset(
("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation")
)
def test_tensor_dim(op_name: str) -> int:
@@ -237,32 +281,79 @@ def generate_arg_extraction(g: NativeFunctionsGroup) -> str:
return ";\n ".join(arg_populations) + ";"
def generate_non_out_variant_call(g: NativeFunctionsGroup) -> str:
def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
kernel = backend_index.get_kernel(g.functional)
if g.structured or kernel is None:
return cpp.name(g.functional.func)
return kernel.kernel
def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
kernel = backend_index.get_kernel(g.out)
if g.structured or kernel is None:
return cpp.name(g.out.func)
return kernel.kernel
def generate_non_out_variant_call(
g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
schema = g.functional.func
assert not schema.is_out_fn()
kernel_name = get_kernel_name(g, backend_index)
arg_names = (arg.name for arg in schema.schema_order_arguments())
return f'at::cpu::{cpp.name(schema)}({",".join(arg_names)})'
namespace_name = "cpu" if g.structured else "native"
return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
def generate_out_variant_call(g: NativeFunctionsGroup) -> str:
def generate_out_variant_call(
g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
schema = g.out.func
assert schema.is_out_fn()
arg_names = [out_arg.name for out_arg in schema.arguments.out]
arg_names = []
kernel_name = get_out_kernel_name(g, backend_index)
if g.structured:
# structured op starts with the output tensor argument.
arg_names = [out_arg.name for out_arg in schema.arguments.out]
else:
arg_names = []
for arg in schema.arguments.non_out:
if isinstance(arg, SelfArgument):
arg_names.append(arg.argument.name)
else:
assert isinstance(arg, Argument)
arg_names.append(arg.name)
if not g.structured:
assert len(schema.arguments.out) == 1
arg_names.append(schema.arguments.out[0].name)
cpp_func_name = cpp.name(schema)
cpp_arg_names = ",".join(arg_names)
return f"at::cpu::{cpp_func_name}({cpp_arg_names})"
namespace_name = "cpu" if g.structured else "native"
return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
no_memory_resize_ops = frozenset(
(
"isin.Scalar_Tensor",
"index_add",
"dot",
"vdot",
"nuclear_norm",
"histc",
"l1_loss",
"multi_margin_loss",
"multilabel_margin_loss",
"nll_loss",
"nll_loss2d",
)
)
def should_check_resize(schema: FunctionSchema) -> bool:
schema_str = str(schema)
type_variant_op_name = schema_str[: schema_str.find("(")]
return type_variant_op_name not in ("isin.Scalar_Tensor", "index_add")
return type_variant_op_name not in no_memory_resize_ops
def op_name_from_group(g: NativeFunctionsGroup) -> str:
@@ -270,7 +361,9 @@ def op_name_from_group(g: NativeFunctionsGroup) -> str:
class GenOutVariantDispatcher:
def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
def __call__(
self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
) -> str:
if not groups:
return ""
generated_type_variants = []
@@ -278,7 +371,7 @@ class GenOutVariantDispatcher:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsGroup)
generated_type_variant = self.gen_structured(g)
generated_type_variant = self.op_generator(g, backend_index)
generated_type_variants.append(generated_type_variant)
op_name = op_name_from_group(groups[0])
body = "\n".join(generated_type_variants)
@@ -294,15 +387,15 @@ REGISTER_OPERATOR_FUNCTOR(
"""
return generated
def gen_structured(self, g: NativeFunctionsGroup) -> str:
def op_generator(self, g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
functional = g.functional
schema = str(functional.func)
op_name = op_name_from_group(g)
populated_argument = generate_arg_extraction(g)
functional_variant_call = generate_non_out_variant_call(g)
functional_variant_call = generate_non_out_variant_call(g, backend_index)
assert len(g.out.func.arguments.out) == 1
out_variable_name = str(g.out.func.arguments.out[0].name)
out_variant_call = generate_out_variant_call(g)
out_variant_call = generate_out_variant_call(g, backend_index)
generated = f"""
if (n->matches(torch::schema("aten::{schema}"))) {{
return [](ProcessedNode* p_node) {{
@@ -328,11 +421,11 @@ class GenOutVariantDispatcherTestCase:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsGroup)
generated_type_variant = self.gen_structured_test_case(g)
generated_type_variant = self.test_case_generator(g)
generated_type_variants.append(generated_type_variant)
return "\n".join(generated_type_variants)
def gen_structured_test_case(self, g: NativeFunctionsGroup) -> str:
def test_case_generator(self, g: NativeFunctionsGroup) -> str:
functional = g.functional
schema = str(functional.func)
assert schema.find("(") > 0