diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index 0ffe7cb37de..e6aa08601a0 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -635,6 +635,64 @@ class TestPatternMatcher(TestCase):
 
         self.assertEqual(res1, res2)
 
+    @skipIfRocm
+    def test_addmm_activation(self):
+        def fn_addmm_relu(input, mat1, mat2):
+            return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
+
+        def fn_addmm_gelu(input, mat1, mat2):
+            return torch.nn.functional.gelu(
+                torch.addmm(input, mat1, mat2), approximate="tanh"
+            )
+
+        def fn_add_mm_relu(input, mat1, mat2):
+            return torch.nn.functional.relu(torch.add(input, torch.mm(mat1, mat2)))
+
+        def fn_add_mm_gelu(input, mat1, mat2):
+            return torch.nn.functional.gelu(
+                torch.add(input, torch.mm(mat1, mat2)), approximate="tanh"
+            )
+
+        args = [
+            torch.randn(20, device=GPU_TYPE),
+            torch.randn(10, 15, device=GPU_TYPE),
+            torch.randn(15, 20, device=GPU_TYPE),
+        ]
+
+        for fn, atol in (
+            (fn_addmm_relu, 1e-8),
+            (fn_add_mm_relu, 1e-8),
+            (fn_addmm_gelu, 1e-3),
+            (fn_add_mm_gelu, 1e-3),
+        ):
+            expected = fn(*args)
+            actual, (code,) = run_and_get_code(torch.compile(fn), *args)
+            torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
+            self.assertTrue("_addmm_activation" in code)
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            actual, (code,) = run_and_get_code(
+                torch.compile(fn, options={"max_autotune_gemm": True}), *args
+            )
+            self.assertFalse("_addmm_activation" in code)
+
+        args_not_replaced = [
+            # addmm + activation with a rank-2 input
+            # is not fusable, hence not replaced
+            torch.randn(10, 20, device=GPU_TYPE),  # input
+            torch.randn(10, 15, device=GPU_TYPE),  # mat1
+            torch.randn(15, 20, device=GPU_TYPE),  # mat2
+        ]
+
+        for fn in (fn_addmm_relu, fn_addmm_gelu):
+            actual, (code,) = run_and_get_code(
+                torch.compile(
+                    fn,
+                ),
+                *args_not_replaced,
+            )
+            self.assertFalse("_addmm_activation" in code)
+
     @inductor_config.patch(
         {
             "max_autotune_gemm_backends": "ATEN",
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index db273b06c8e..d72079b83a0 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -33,6 +33,7 @@ from ..pattern_matcher import (
     CallFunctionVarArgs,
     filter_nodes,
     fwd_only,
+    gen_register_replacement,
     get_arg_value,
     get_mutation_region_id,
     Ignored,
@@ -660,6 +661,97 @@ def lazy_init():
         extra_check=prepare_softmax_extra_check,
     )
 
+    register_addmm_activation_fusion()
+
+
+@functools.cache
+def register_addmm_activation_fusion():
+    shapes = [(5,), (3, 4), (4, 5)]
+    args_fp32 = [torch.empty(shape) for shape in shapes]
+    args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]
+
+    for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
+        name = f"{pattern.__name__}_fp32"
+        gen_register_replacement(
+            name,
+            pattern,
+            addmm_relu_replacement,
+            args_fp32,
+            trace_fn=fwd_only,
+            pass_dicts=pass_patterns[2],
+            extra_check=is_valid_addmm_activation_fusion,
+        )
+
+    for args, dtype_suffix in [(args_fp32, "fp32"), (args_bf16, "bf16")]:
+        for pattern in [addmm_gelu_pattern, addmm_gelu_pattern_2]:
+            name = f"{pattern.__name__}_{dtype_suffix}"
+            gen_register_replacement(
+                name,
+                pattern,
+                addmm_gelu_replacement,
+                args,
+                trace_fn=fwd_only,
+                pass_dicts=pass_patterns[2],
+                extra_check=is_valid_addmm_activation_fusion,
+            )
+
+
+def is_valid_addmm_activation_fusion(match):
+    if config.max_autotune_gemm:
+        return False
+    inp = match.kwargs["input"].meta["val"]
+    mat1 = match.kwargs["mat1"].meta["val"]
+    mat2 = match.kwargs["mat2"].meta["val"]
+
+    # match the dispatch logic for cuBLASLt at aten/src/ATen/native/cuda/Blas.cpp
+    if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
+        return False
+
+    if not (mat1.dim() == 2 and mat2.dim() == 2):
+        return False
+
+    if inp.size(0) != mat2.size(1):
+        return False
+
+    if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
+        return False
+
+    output = match.output_node()
+    # fuse only if some use of the output is not pointwise; otherwise add + activation fuse into the mm epilogue
+    return not all(is_pointwise_use(use) for use in output.users)
+
+
+def addmm_gelu_pattern(input, mat1, mat2):
+    output = aten.mm(mat1, mat2)
+    output = aten.add(output, input)
+    return aten.gelu(output, approximate="tanh")
+
+
+def addmm_gelu_pattern_2(input, mat1, mat2):
+    output = aten.mm(mat1, mat2)
+    output = aten.add(input, output)
+    return aten.gelu(output, approximate="tanh")
+
+
+def addmm_gelu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=True)
+
+
+def addmm_relu_pattern(input, mat1, mat2):
+    output = aten.mm(mat1, mat2)
+    output = aten.add(input, output)
+    return aten.relu(output)
+
+
+def addmm_relu_pattern_2(input, mat1, mat2):
+    output = aten.mm(mat1, mat2)
+    output = aten.add(output, input)
+    return aten.relu(output)
+
+
+def addmm_relu_replacement(input, mat1, mat2):
+    return aten._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)
+
 
 def reorder_for_locality(graph: torch.fx.Graph):
     if torch.distributed.is_available():
@@ -1461,7 +1553,7 @@ def should_prefer_unfused_addmm(match):
 
 @register_graph_pattern(
     CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
-    pass_dict=pass_patterns[2],
+    pass_dict=pass_patterns[1],
     extra_check=should_prefer_unfused_addmm,
 )
 def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py
new file mode 100644
index 00000000000..99f691e6fdd
--- /dev/null
+++ b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py
@@ -0,0 +1,59 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
+# To re-generate, run:
+# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+import torch
+import torch._inductor
+import operator
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+from torch._inductor.pattern_matcher import (
+   Arg,
+   CallFunction,
+   CallFunctionVarArgs,
+   CallMethod,
+   CallMethodVarArgs,
+   CallModule,
+   CallModuleVarArgs,
+   ExclusiveKeywordArg,
+   Ignored,
+   KeywordArg,
+   ListOf,
+   MultiOutputPattern,
+   PatternExpr,
+   RepeatedExpr,
+   _TargetArgsExpr,
+   _TargetExpr,
+   _TargetExprVarArgs,
+)
+mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'), _users=4)
+mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
+add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
+tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
+add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
+addmm_gelu_pattern_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
+
+
+mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
+add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
+tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
+add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
+addmm_gelu_pattern_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern_2.py
new file mode 100644
index 00000000000..288177ed37a
--- /dev/null
+++ b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern_2.py
@@ -0,0 +1,59 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
+# To re-generate, run:
+# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+import torch
+import torch._inductor
+import operator
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+from torch._inductor.pattern_matcher import (
+   Arg,
+   CallFunction,
+   CallFunctionVarArgs,
+   CallMethod,
+   CallMethodVarArgs,
+   CallModule,
+   CallModuleVarArgs,
+   ExclusiveKeywordArg,
+   Ignored,
+   KeywordArg,
+   ListOf,
+   MultiOutputPattern,
+   PatternExpr,
+   RepeatedExpr,
+   _TargetArgsExpr,
+   _TargetExpr,
+   _TargetExprVarArgs,
+)
+mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default, _users=4)
+mul_Tensor = CallFunction(aten.mul.Tensor, add_Tensor, Ignored())
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, add_Tensor, add_Tensor)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, add_Tensor)
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
+add_Tensor_1 = CallFunction(aten.add.Tensor, add_Tensor, mul_Tensor_3)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
+tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
+add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
+addmm_gelu_pattern_2_fp32 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2, _users=0)
+
+
+mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=4)
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default, convert_element_type_default)
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, convert_element_type_default)
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored())
+add_Tensor_1 = CallFunction(aten.add.Tensor, convert_element_type_default, mul_Tensor_3)
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor_1, Ignored())
+tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4)
+add_Tensor_2 = CallFunction(aten.add.Tensor, tanh_default, Ignored())
+mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_2)
+addmm_gelu_pattern_2_bf16 = CallFunction(prims.convert_element_type.default, mul_Tensor_5, Ignored(), _users=0)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py
new file mode 100644
index 00000000000..9deef11cf32
--- /dev/null
+++ b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py
@@ -0,0 +1,36 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
+# To re-generate, run:
+# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+import torch
+import torch._inductor
+import operator
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+from torch._inductor.pattern_matcher import (
+   Arg,
+   CallFunction,
+   CallFunctionVarArgs,
+   CallMethod,
+   CallMethodVarArgs,
+   CallModule,
+   CallModuleVarArgs,
+   ExclusiveKeywordArg,
+   Ignored,
+   KeywordArg,
+   ListOf,
+   MultiOutputPattern,
+   PatternExpr,
+   RepeatedExpr,
+   _TargetArgsExpr,
+   _TargetExpr,
+   _TargetExprVarArgs,
+)
+mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+add_Tensor = CallFunction(aten.add.Tensor, KeywordArg('input'), mm_default)
+addmm_relu_pattern_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern_2.py
new file mode 100644
index 00000000000..4a3c4731051
--- /dev/null
+++ b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern_2.py
@@ -0,0 +1,36 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
+# To re-generate, run:
+# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+import torch
+import torch._inductor
+import operator
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+from torch._inductor.pattern_matcher import (
+   Arg,
+   CallFunction,
+   CallFunctionVarArgs,
+   CallMethod,
+   CallMethodVarArgs,
+   CallModule,
+   CallModuleVarArgs,
+   ExclusiveKeywordArg,
+   Ignored,
+   KeywordArg,
+   ListOf,
+   MultiOutputPattern,
+   PatternExpr,
+   RepeatedExpr,
+   _TargetArgsExpr,
+   _TargetExpr,
+   _TargetExprVarArgs,
+)
+mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+add_Tensor = CallFunction(aten.add.Tensor, mm_default, KeywordArg('input'))
+addmm_relu_pattern_2_fp32 = CallFunction(aten.relu.default, add_Tensor, _users=0)
diff --git a/torchgen/fuse/gen_patterns.py b/torchgen/fuse/gen_patterns.py
index 0861c882e3f..b4bdf022202 100644
--- a/torchgen/fuse/gen_patterns.py
+++ b/torchgen/fuse/gen_patterns.py
@@ -2,7 +2,7 @@ import os
 
 from torch._inductor import pattern_matcher
-from torch._inductor.fx_passes import joint_graph
+from torch._inductor.fx_passes import joint_graph, post_grad
 
 
 if __name__ == "__main__":
@@ -17,3 +17,4 @@ if __name__ == "__main__":
     # to serialize the patterns as it goes.
     os.environ["PYTORCH_GEN_PATTERNS"] = "1"
     joint_graph.lazy_init()
+    post_grad.lazy_init()
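
Reviewer note (not part of the diff): a minimal standalone sketch of the contract the new pass relies on, namely that aten._addmm_activation reproduces the eager addmm + activation chains it replaces. It assumes a CUDA device, since the extra_check above only admits CUDA tensors with a contiguous rank-1 bias, and it borrows the shapes and the 1e-3 GELU tolerance from the test.

import torch

aten = torch.ops.aten

if torch.cuda.is_available():
    inp = torch.randn(20, device="cuda")  # contiguous rank-1 bias, as extra_check requires
    mat1 = torch.randn(10, 15, device="cuda")
    mat2 = torch.randn(15, 20, device="cuda")

    # relu epilogue: relu(addmm(inp, mat1, mat2))
    fused = aten._addmm_activation(inp, mat1, mat2, beta=1, alpha=1, use_gelu=False)
    eager = torch.relu(torch.addmm(inp, mat1, mat2))
    torch.testing.assert_close(fused, eager, atol=1e-8, rtol=0)

    # tanh-gelu epilogue: gelu(addmm(inp, mat1, mat2), approximate="tanh")
    fused = aten._addmm_activation(inp, mat1, mat2, beta=1, alpha=1, use_gelu=True)
    eager = torch.nn.functional.gelu(torch.addmm(inp, mat1, mat2), approximate="tanh")
    torch.testing.assert_close(fused, eager, atol=1e-3, rtol=0)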
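
The long mul/tanh chains in the serialized GELU patterns are what fwd_only tracing produces for aten.gelu(x, approximate="tanh"); the scalar constants (0.5, 0.044715, sqrt(2/pi)) correspond to the Ignored() slots. A sketch of that decomposition, checked against eager GELU:

import math

import torch


def gelu_tanh_decomposed(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1.0 + torch.tanh(inner))


x = torch.randn(64)
torch.testing.assert_close(
    gelu_tanh_decomposed(x),
    torch.nn.functional.gelu(x, approximate="tanh"),
)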