Mirror of https://github.com/zebrajr/pytorch.git (synced 2026-01-15 12:15:51 +00:00)
[Static Runtime] [RFC] Codegen support for ops with unstructured kernels (#76203)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76203

Request for comments: this change extends the code generator so it can emit out-variant wrappers for operators with unstructured kernels. The current version generates 105 new out-variant wrappers on top of the 136 that were already auto-generated. It shows that a simple tweak raises the generated-op coverage to 16% (241/1559) of all native ops described in native_functions.yaml, no matter whether they are structured or not.

Command to generate the out-variant wrappers:
```
buck run //caffe2/torch/fb/jit:gen_static_runtime_ops
```

- AFTER this change
```
total grouped native ops: 1559
structured grouped native ops: 545
generated grouped native ops: 241
```
- BEFORE this change
```
total grouped native ops: 1503
structured grouped native ops: 540
generated grouped native ops: 136
```

To enable CI tests and make the review easier, the generated ops are added in a separate diff: D35945633.

More details: a block list suppresses generation of around 10 operators that are deprecated or whose unit tests would fail. All generated ops compile, but some of the compiled unit tests do not pass without hand-picked test input values for certain ops. Of the 42 ops whose unit tests initially failed: one ("index_select") is a duplicate of an existing op; 32 were fixed; and 9 were removed and blocked from generation, either because they are not commonly used in internal models ("cholesky", "linalg_householder_product", the sparse kernel "sspaddmm") or because they cause errors in Static Runtime ("conj_physical" leads to an error in the memory planner; "binary_cross_entropy" also fails).

Test Plan:
Op generation:
```
buck run //caffe2/torch/fb/jit:gen_static_runtime_ops
```
Test generated ops:
```
buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest
```

Reviewed By: tenpercent

Differential Revision: D34913736

fbshipit-source-id: a6f408321653c3589ae1c76826177fc403d59c44
(cherry picked from commit 6f4501730478dbaeeea7f3ad4f9d29bf6787e7c1)
Committed by: PyTorch MergeBot
Parent: fc64dbdc01
Commit: ca0f267022
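As context for the diff below: for each supported op, the generator emits a registration block that matches the op's schema and returns a ProcessedNode lambda (the f-string template is visible in op_generator further down). The following is a minimal, hedged sketch of the kind of string that template expands to, using a hypothetical schema and an abbreviated body rather than the real generated code:

```python
# Hypothetical schema string; real schemas come from native_functions.yaml.
schema = "take(Tensor self, Tensor index) -> Tensor"

# Mirrors the shape of the f-string template in op_generator below; the body is
# abbreviated (the real generator emits argument extraction, an optional
# resize-to-zero step, and a call to the out-variant kernel).
generated = f"""
  if (n->matches(torch::schema("aten::{schema}"))) {{
    return [](ProcessedNode* p_node) {{
      // ... extract inputs, prepare the out tensor, call the out-variant kernel ...
    }};
  }}
"""
print(generated)
```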
@@ -9,6 +9,7 @@ def func_name_base_str(g: NativeFunctionsGroup) -> str:

is_hand_written_ops_ = frozenset(
    (
        "abs",
        "add",
        "addmm",
        "all",

@@ -16,12 +17,18 @@ is_hand_written_ops_ = frozenset(
        "argmin",
        "bmm",
        "clamp",
        "clamp_min",
        "cumsum",
        "div",
        "fmod",
        "index_select",
        "leaky_relu",
        "linear",
        "log",
        "matmul",
        "mul",
        "narrow_copy",
        "nonzero",
        "pow",
        "remainder",
        "sigmoid",

@@ -39,6 +46,140 @@ def is_hand_written(g: NativeFunctionsGroup) -> bool:

def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
    assert index == 0 or index == 1
    if op_name == "addr":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6})"
            arg_map["vec1"] = "at::rand({6})"
            arg_map["vec2"] = "at::rand({6})"
        else:
            arg_map["self"] = "at::rand({22, 22})"
            arg_map["vec1"] = "at::rand({22})"
            arg_map["vec2"] = "at::rand({22})"
        return
    if op_name == "mv":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6})"
            arg_map["vec"] = "at::rand({6})"
        else:
            arg_map["self"] = "at::rand({22, 22})"
            arg_map["vec"] = "at::rand({22})"
        return
    if op_name == "addbmm":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6})"
        else:
            arg_map["self"] = "at::rand({22, 22})"
        return
    if op_name == "cross":
        if index == 0:
            arg_map["self"] = "at::rand({3, 3, 3})"
            arg_map["other"] = "at::rand({3, 3, 3})"
        else:
            arg_map["self"] = "at::rand({22, 3, 22})"
            arg_map["other"] = "at::rand({22, 3, 22})"
        return
    if op_name == "take":
        if index == 0:
            arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)"
        else:
            arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)"
        return
    if op_name == "take_along_dim":
        if index == 0:
            arg_map["indices"] = "at::argsort(self0, 1)"
        else:
            arg_map["indices"] = "at::argsort(self1, 1)"
        return
    if op_name == "masked_select":
        if index == 0:
            arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5"
        else:
            arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5"
        return
    if op_name == "orgqr":
        if index == 0:
            arg_map["input2"] = "at::rand({6, 6})"
        else:
            arg_map["input2"] = "at::rand({22, 22})"
        return
    if op_name == "ormqr":
        if index == 0:
            arg_map["input2"] = "at::rand({6, 6})"
        else:
            arg_map["input2"] = "at::rand({22, 22})"
        return
    if op_name == "quantile":
        if index == 0:
            arg_map["q"] = "at::rand({6})"
            arg_map["interpolation"] = '"linear"'
        else:
            arg_map["q"] = "at::rand({22})"
            arg_map["interpolation"] = '"linear"'
        return
    if op_name == "nanquantile":
        if index == 0:
            arg_map["q"] = "at::rand({6})"
            arg_map["interpolation"] = '"linear"'
        else:
            arg_map["q"] = "at::rand({22})"
            arg_map["interpolation"] = '"linear"'
        return
    if op_name == "multi_margin_loss":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6})"
            arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
            arg_map["weight"] = "at::rand({6})"
        else:
            arg_map["self"] = "at::rand({22, 22})"
            arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
            arg_map["weight"] = "at::rand({22})"
        return
    if op_name == "multilabel_margin_loss":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6})"
            arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)"
        else:
            arg_map["self"] = "at::rand({22, 22})"
            arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)"
        return
    if op_name == "nll_loss":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6})"
            arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
            arg_map["weight"] = "at::rand({6})"
        else:
            arg_map["self"] = "at::rand({22, 22})"
            arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
            arg_map["weight"] = "at::rand({22})"
        return
    if op_name == "nll_loss2d":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6, 6, 6})"
            arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
            arg_map["weight"] = "at::rand({6})"
        else:
            arg_map["self"] = "at::rand({22, 22, 22, 22})"
            arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
            arg_map["weight"] = "at::rand({22})"
        return
    if op_name in (
        "fft_fft",
        "fft_ifft",
        "fft_rfft",
        "fft_irfft",
        "fft_hfft",
        "fft_ihfft",
    ):
        arg_map["norm"] = '"forward"'
        return
    if op_name == "linalg_tensorinv":
        if index == 0:
            arg_map["self"] = "at::rand({6, 6, 6, 6})"
            arg_map["ind"] = "2"
        else:
            arg_map["self"] = "at::rand({22, 22, 22, 22})"
            arg_map["ind"] = "2"
        return
    if op_name == "addmv":
        if index == 0:
            arg_map["self"] = "at::rand({2})"

@@ -171,6 +312,13 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
        if "reduce" in arg_map:
            arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'
        return
    if op_name == "scatter_reduce":
        arg_map["reduce"] = '"mean"'
        if index == 0:
            arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
        else:
            arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
        return
    if op_name == "special_zeta":
        if index == 0:
            arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
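The overrides above exist because the default, shape-only test inputs are not valid for every op. For example, aten::take requires every index to be a valid flat index into self; with a {6, 6, 6} input the override caps indices at 6 * 6 * 6 = 216. A minimal sketch of how such an override would be applied to a test-argument map (the default expressions here are made up for illustration; the real defaults come from the test generator):

```python
from typing import Dict

# Hypothetical defaults; in the real flow these are derived from the op schema.
arg_map: Dict[str, str] = {
    "self": "at::rand({6, 6, 6})",
    "index": "at::randint(0, 10, {6}, torch::kInt64)",
}

op_name, index = "take", 0
# Same override as in the diff above: keep indices within 6 * 6 * 6 == 216 elements.
if op_name == "take":
    if index == 0:
        arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)"
    else:
        arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)"

print(arg_map["index"])  # C++ expression emitted into the generated test case
```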
@@ -1,7 +1,7 @@
from torchgen import gen
from torchgen.context import native_function_manager
from torchgen.model import NativeFunctionsGroup
from torchgen.static_runtime import gen_structured
from torchgen.model import DispatchKey, NativeFunctionsGroup
from torchgen.static_runtime import generator

import argparse
import itertools

@@ -11,6 +11,8 @@ from typing import Sequence

# Given a list of `grouped_native_functions` sorted by their op names, return a list of
# lists each of which groups ops that share the base name. For example, `mean` and
# `mean.dim` are grouped together by this function.


def group_functions_by_op_name(
    grouped_native_functions: Sequence[NativeFunctionsGroup],
) -> Sequence[Sequence[NativeFunctionsGroup]]:

@@ -22,7 +24,7 @@ def group_functions_by_op_name(

    def is_supported(g: NativeFunctionsGroup) -> bool:
        with native_function_manager(g):
            return gen_structured.is_supported(g)
            return generator.is_supported(g)

    eligible_ops = (g for g in grouped_native_functions if is_supported(g))
    groups = [

@@ -33,6 +35,7 @@ def group_functions_by_op_name(
            )
        )
    ]

    return groups


@@ -122,19 +125,19 @@ def main() -> None:
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
        default="caffe2/aten/src/ATen",
    )
    parser.add_argument(
        "-p",
        "--generated-ops-cpp-path",
        help="path to directory to generate op dispatcher .cpp file",
        default="torch/csrc/jit/runtime/static/generated_ops.cpp",
        default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp",
    )
    parser.add_argument(
        "-t",
        "--generated-ops-test-cpp-path",
        help="path to directory to generate op dispatcher .cpp file",
        default="benchmarks/static_runtime/test_generated_ops.cc",
        default="caffe2/benchmarks/static_runtime/test_generated_ops.cc",
    )
    options = parser.parse_args()
    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")

@@ -150,14 +153,13 @@ def main() -> None:
    ]
    supported_function_groups = group_functions_by_op_name(structured_native_functions)

    gen_out_variant_dispatcher = gen_structured.GenOutVariantDispatcher()
    gen_out_variant_dispatcher = generator.GenOutVariantDispatcher()
    result = [
        gen_out_variant_dispatcher(groups) for groups in supported_function_groups
        gen_out_variant_dispatcher(groups, backend_indices[DispatchKey.CPU])
        for groups in supported_function_groups
    ]

    gen_out_variant_dispatcher_test_case = (
        gen_structured.GenOutVariantDispatcherTestCase()
    )
    gen_out_variant_dispatcher_test_case = generator.GenOutVariantDispatcherTestCase()
    test_result = [
        gen_out_variant_dispatcher_test_case(groups)
        for groups in supported_function_groups
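The grouping described by the comment above ("mean" and "mean.dim" end up in the same group) is essentially itertools.groupby keyed on the base name. A minimal sketch with plain strings standing in for NativeFunctionsGroup objects (the real code also filters by generator.is_supported):

```python
import itertools

# String stand-ins for op names; the real code groups NativeFunctionsGroup objects.
op_names = ["add.Scalar", "add.Tensor", "mean", "mean.dim", "take.out"]

def base_name(name: str) -> str:
    return name.split(".")[0]

groups = [list(g) for _, g in itertools.groupby(sorted(op_names), key=base_name)]
print(groups)  # [['add.Scalar', 'add.Tensor'], ['mean', 'mean.dim'], ['take.out']]
```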
@@ -2,6 +2,7 @@ import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    FunctionSchema,
    OptionalType,

@@ -30,11 +31,46 @@ def has_alias(
    return False


BLOCKED_OPS = frozenset(
    (
        # non cpu ops
        "sparse_sampled_addmm",
        "hspmm",
        # sparse ops
        "sspaddmm",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "_compute_linear_combination",
    )
)


def is_supported(g: NativeFunctionsGroup) -> bool:
    if not g.structured:
        base_op_name = g.out.func.name.name.base
        if base_op_name in BLOCKED_OPS:
            return False
    if config.is_hand_written(g):
        return False
    if not g.structured:
        # In case of unstructured op, we check if it has out variant implementation.
        # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
        # parameter.
        if (
            not hasattr(g, "out")
            or not str(g.out.func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(g.out.func.name).endswith(".out")
        ):
            return False
    if has_alias(g.out.func.arguments.non_out):
        # This op may create an alias of inputs.
        return False
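The checks above are the heart of this extension: an unstructured op is eligible only if its out variant has the canonical shape, i.e. a single `Tensor(a!) out` argument returned as `Tensor(a!)`. A minimal sketch of that suffix check applied to illustrative schema strings (abbreviated, not taken from this PR):

```python
def looks_like_supported_out_variant(schema: str, overload_name: str) -> bool:
    # Same two string checks as in is_supported above.
    return (
        schema.endswith("Tensor(a!) out) -> Tensor(a!)")
        and overload_name.endswith(".out")
    )

# The first illustrative schema matches; the second (multiple outputs) does not.
print(looks_like_supported_out_variant(
    "take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)",
    "take.out",
))
print(looks_like_supported_out_variant(
    "topk.values(Tensor self, int k, *, Tensor(a!) values, Tensor(b!) indices)"
    " -> (Tensor(a!), Tensor(b!))",
    "topk.values",
))
```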
@@ -80,7 +116,9 @@ def ivalue_type_conversion_method(
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        assert isinstance(arg_type.elem, BaseType)
        if not isinstance(arg_type.elem, BaseType):
            # ListType is currently unsupported.
            return None
        base_ty_object = arg_type.elem.name
    else:
        return None

@@ -120,9 +158,15 @@ test_tensor_dim_ops_1_ = frozenset(
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
        "nll_loss_backward",
        "dot",
        "vdot",
        "outer",
        "ger",
    )
)
test_tensor_dim_ops_2_ = frozenset(("addmm", "mm"))
test_tensor_dim_ops_2_ = frozenset(
    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation")
)


def test_tensor_dim(op_name: str) -> int:

@@ -237,32 +281,79 @@ def generate_arg_extraction(g: NativeFunctionsGroup) -> str:
    return ";\n ".join(arg_populations) + ";"


def generate_non_out_variant_call(g: NativeFunctionsGroup) -> str:
def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.functional)
    if g.structured or kernel is None:
        return cpp.name(g.functional.func)
    return kernel.kernel


def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.out)
    if g.structured or kernel is None:
        return cpp.name(g.out.func)
    return kernel.kernel


def generate_non_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    schema = g.functional.func
    assert not schema.is_out_fn()
    kernel_name = get_kernel_name(g, backend_index)
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    return f'at::cpu::{cpp.name(schema)}({",".join(arg_names)})'
    namespace_name = "cpu" if g.structured else "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'


def generate_out_variant_call(g: NativeFunctionsGroup) -> str:
def generate_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    schema = g.out.func
    assert schema.is_out_fn()
    arg_names = [out_arg.name for out_arg in schema.arguments.out]
    arg_names = []
    kernel_name = get_out_kernel_name(g, backend_index)
    if g.structured:
        # structured op starts with the output tensor argument.
        arg_names = [out_arg.name for out_arg in schema.arguments.out]
    else:
        arg_names = []
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    if not g.structured:
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)
    cpp_func_name = cpp.name(schema)
    cpp_arg_names = ",".join(arg_names)
    return f"at::cpu::{cpp_func_name}({cpp_arg_names})"
    namespace_name = "cpu" if g.structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
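Taken together, the helpers above mean that a structured op keeps calling the `at::cpu::` wrapper, while an unstructured op calls its registered CPU kernel directly in `at::native::`, with the out tensor appended as the last argument. A minimal sketch of the resulting call string for a hypothetical unstructured op (the kernel name and argument list are assumed values, stand-ins for what the BackendIndex and schema would provide):

```python
structured = False                    # e.g. an op without a structured kernel
kernel_name = "take_out"              # assumed kernel name looked up via the BackendIndex
arg_names = ["self", "index", "out"]  # non-out arguments first, then the single out tensor

namespace_name = "cpu" if structured else "native"
print(f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})")
# -> at::native::take_out(self,index,out)
```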
no_memory_resize_ops = frozenset(
    (
        "isin.Scalar_Tensor",
        "index_add",
        "dot",
        "vdot",
        "nuclear_norm",
        "histc",
        "l1_loss",
        "multi_margin_loss",
        "multilabel_margin_loss",
        "nll_loss",
        "nll_loss2d",
    )
)


def should_check_resize(schema: FunctionSchema) -> bool:
    schema_str = str(schema)
    type_variant_op_name = schema_str[: schema_str.find("(")]
    return type_variant_op_name not in ("isin.Scalar_Tensor", "index_add")
    return type_variant_op_name not in no_memory_resize_ops


def op_name_from_group(g: NativeFunctionsGroup) -> str:

@@ -270,7 +361,9 @@ def op_name_from_group(g: NativeFunctionsGroup) -> str:


class GenOutVariantDispatcher:
    def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
    def __call__(
        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
    ) -> str:
        if not groups:
            return ""
        generated_type_variants = []

@@ -278,7 +371,7 @@ class GenOutVariantDispatcher:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.gen_structured(g)
                generated_type_variant = self.op_generator(g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = op_name_from_group(groups[0])
        body = "\n".join(generated_type_variants)

@@ -294,15 +387,15 @@ REGISTER_OPERATOR_FUNCTOR(
"""
        return generated

    def gen_structured(self, g: NativeFunctionsGroup) -> str:
    def op_generator(self, g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
        functional = g.functional
        schema = str(functional.func)
        op_name = op_name_from_group(g)
        populated_argument = generate_arg_extraction(g)
        functional_variant_call = generate_non_out_variant_call(g)
        functional_variant_call = generate_non_out_variant_call(g, backend_index)
        assert len(g.out.func.arguments.out) == 1
        out_variable_name = str(g.out.func.arguments.out[0].name)
        out_variant_call = generate_out_variant_call(g)
        out_variant_call = generate_out_variant_call(g, backend_index)
        generated = f"""
  if (n->matches(torch::schema("aten::{schema}"))) {{
    return [](ProcessedNode* p_node) {{

@@ -328,11 +421,11 @@ class GenOutVariantDispatcherTestCase:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.gen_structured_test_case(g)
                generated_type_variant = self.test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def gen_structured_test_case(self, g: NativeFunctionsGroup) -> str:
    def test_case_generator(self, g: NativeFunctionsGroup) -> str:
        functional = g.functional
        schema = str(functional.func)
        assert schema.find("(") > 0