diff --git a/buckbuild.bzl b/buckbuild.bzl index 380d330600a..305a0cf3c89 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -960,7 +960,6 @@ def define_buck_targets( "Functions.h": ":gen_aten_libtorch[autograd/generated/Functions.h]", "VariableType.h": ":gen_aten_libtorch[autograd/generated/VariableType.h]", "variable_factories.h": ":gen_aten_libtorch[autograd/generated/variable_factories.h]", - "ViewFuncs.h": ":gen_aten_libtorch[autograd/generated/ViewFuncs.h]", # Don't build python bindings on mobile. #"python_functions.h", }, @@ -1467,7 +1466,6 @@ def define_buck_targets( "torch/csrc/jit/mobile/train/random.cpp", "torch/csrc/jit/mobile/train/sequential.cpp", ":gen_aten_libtorch[autograd/generated/Functions.cpp]", - ":gen_aten_libtorch[autograd/generated/ViewFuncs.cpp]", ], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"], diff --git a/build.bzl b/build.bzl index 5ab9f92acec..6490a7f3839 100644 --- a/build.bzl +++ b/build.bzl @@ -261,7 +261,6 @@ _GENERATED_AUTOGRAD_PYTHON_HEADERS = [ _GENERATED_AUTOGRAD_CPP_HEADERS = [ "torch/csrc/autograd/generated/Functions.h", "torch/csrc/autograd/generated/VariableType.h", - "torch/csrc/autograd/generated/ViewFuncs.h", "torch/csrc/autograd/generated/variable_factories.h", ] @@ -304,7 +303,6 @@ GENERATED_AUTOGRAD_CPP = [ "torch/csrc/autograd/generated/VariableType_2.cpp", "torch/csrc/autograd/generated/VariableType_3.cpp", "torch/csrc/autograd/generated/VariableType_4.cpp", - "torch/csrc/autograd/generated/ViewFuncs.cpp", "torch/csrc/autograd/generated/TraceType_0.cpp", "torch/csrc/autograd/generated/TraceType_1.cpp", "torch/csrc/autograd/generated/TraceType_2.cpp", diff --git a/build_variables.bzl b/build_variables.bzl index 6240f7aefa6..efa452cee32 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -24,7 +24,6 @@ def libtorch_generated_sources(gencode_pattern): "torch/csrc/autograd/generated/VariableType_2.cpp", "torch/csrc/autograd/generated/VariableType_3.cpp", "torch/csrc/autograd/generated/VariableType_4.cpp", - "torch/csrc/autograd/generated/ViewFuncs.cpp", "torch/csrc/autograd/generated/TraceType_0.cpp", "torch/csrc/autograd/generated/TraceType_1.cpp", "torch/csrc/autograd/generated/TraceType_2.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 17582104933..949e1d0e803 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -351,7 +351,6 @@ configure_file("${TORCH_SRC_DIR}/csrc/api/include/torch/version.h.in" set(GENERATED_CXX_TORCH "${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp" - "${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.cpp" ) if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER) @@ -381,7 +380,6 @@ endif() set(GENERATED_H_TORCH "${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.h" "${TORCH_SRC_DIR}/csrc/autograd/generated/variable_factories.h" - "${TORCH_SRC_DIR}/csrc/autograd/generated/ViewFuncs.h" ) if(NOT INTERN_DISABLE_AUTOGRAD) diff --git a/pt_template_srcs.bzl b/pt_template_srcs.bzl index 6d42026ba6c..6bfcfc6f231 100644 --- a/pt_template_srcs.bzl +++ b/pt_template_srcs.bzl @@ -131,8 +131,6 @@ def get_generate_code_bin_outs(): "autograd/generated/VariableType_3.cpp": ["autograd/generated/VariableType_3.cpp"], "autograd/generated/VariableType_4.cpp": ["autograd/generated/VariableType_4.cpp"], "autograd/generated/variable_factories.h": ["autograd/generated/variable_factories.h"], - "autograd/generated/ViewFuncs.cpp": ["autograd/generated/ViewFuncs.cpp"], - "autograd/generated/ViewFuncs.h": 
["autograd/generated/ViewFuncs.h"], } if is_arvr_mode(): diff --git a/test/test_autograd.py b/test/test_autograd.py index ec570fab4c6..bce6b3538c6 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -8792,7 +8792,6 @@ get_out().sum().backward() _assert_match_metadata(new_inp, inp) new_out = out._view_func(new_inp) _assert_match_metadata(new_out, out) - self.assertEqual(new_out, out) # reverse view_func new_out = out.detach() @@ -8831,7 +8830,7 @@ get_out().sum().backward() _test_fn( lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1), torch.randn(2, 3, 4)) _test_fn( - lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0), torch.randn(2, 3, 4)) + lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, -1), torch.randn(2, 3, 4)) # chains with missing view_func()s use as_strided() to cover the gaps def chain_with_only_parent_view_func(x): @@ -8839,7 +8838,7 @@ get_out().sum().backward() x = x.split_with_sizes([1, 3], -1)[0] with torch.autograd._force_original_view_tracking(False): - x = x.chunk(2, 0) + x = x.chunk(2, -1) return x @@ -8850,50 +8849,12 @@ get_out().sum().backward() x = x.split_with_sizes([1, 3], -1)[0] with torch.autograd._force_original_view_tracking(True): - x = x.chunk(2, 0) + x = x.chunk(2, -1) return x _test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4)) - def test_view_func_replay_with_modified_state(self): - with torch.autograd._force_original_view_tracking(True): - base = torch.randn(3, 4, 5) - view = base.select(1, 2) - - def symint_visitor_fn(x): - # modify saved index - return x + 1 - - # ensure modifying state changes view replay - new_base = torch.randn_like(base) - new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn) - self.assertEqual(new_view, new_base.select(1, 3)) - - # ensure saved state reverts back afterwards - self.assertEqual(view._view_func(new_base), new_base.select(1, 2)) - - # check modifying tensor state. 
currently, slice_inverse() is the only - # view that saves a tensor - base = torch.randn(3, 4, 5) - sliced = base[:, 2:3, :].detach() - view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1) - - replacement_shape = (1, 2, 3) - - def tensor_visitor_fn(x): - # return tensor with a smaller shape than the saved one - return torch.randn(*replacement_shape) - - # ensure modifying state changes view replay - new_sliced = torch.ones_like(base)[:, 2:3, :].detach() - new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn) - self.assertEqual(new_view.shape, replacement_shape) - self.assertEqual(new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1))) - - # ensure saved state reverts back afterwards - self.assertEqual(view._view_func(sliced), base) - def test_setup_context_when_forward_has_default_args(self): class PowFunction(Function): @staticmethod diff --git a/test/test_ops.py b/test/test_ops.py index 22023059ea6..74a1f86b01e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1564,7 +1564,6 @@ class TestCompositeCompliance(TestCase): _assert_match_metadata(new_inp, inp) new_out = out._view_func_unsafe(new_inp) _assert_match_metadata(new_out, out) - self.assertEqual(new_out, out) # reverse view_func new_out = out.detach() diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl index 5b410274ff7..7ad0882e44f 100644 --- a/tools/BUCK.bzl +++ b/tools/BUCK.bzl @@ -124,8 +124,6 @@ def define_tools_targets( "autograd/templates/TraceType.cpp", "autograd/templates/VariableType.cpp", "autograd/templates/VariableType.h", - "autograd/templates/ViewFuncs.cpp", - "autograd/templates/ViewFuncs.h", "autograd/templates/annotated_fn_args.py.in", "autograd/templates/python_enum_tag.cpp", "autograd/templates/python_fft_functions.cpp", diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 0d4aa91d3fa..c4d1df00a95 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -43,7 +43,6 @@ from .gen_inplace_or_view_type import gen_inplace_or_view_type from .gen_trace_type import gen_trace_type from .gen_variable_factories import gen_variable_factories from .gen_variable_type import gen_variable_type -from .gen_view_funcs import gen_view_funcs from .load_derivatives import load_derivatives @@ -96,9 +95,6 @@ def gen_autograd( # Generate variable_factories.h gen_variable_factories(out, native_functions_path, tags_path, template_path) - # Generate ViewFuncs.h/cpp - gen_view_funcs(out, fns_with_diff_infos, template_path) - def gen_autograd_python( native_functions_path: str, diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index df9841312fa..6e713579445 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -4,7 +4,7 @@ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp # The fallback is expected to mimick this codegen, so we should keep the two in sync. 
-from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Sequence, Tuple from torchgen.api import cpp from torchgen.api.autograd import ( @@ -172,7 +172,7 @@ for (auto ${view_idx} : c10::irange(${var}.size())) { SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate( """\ -std::unique_ptr func(nullptr); +std::function func=nullptr; std::function rev_func=nullptr; if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided() || @@ -184,9 +184,11 @@ if (${is_view_with_metadata_change} || """ ) -REPLAY_VIEW_FUNC = CodeTemplate( +REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate( """\ -func = std::make_unique<${view_func_name}>(${view_func_args}); +func = [=](const at::Tensor& ${input_base}) { + return ${replay_view_call}${view_indexing}; +}; """ ) @@ -344,13 +346,24 @@ def get_view_info(f: NativeFunction) -> Optional[str]: return view_info -def emit_view_func( +# For view replay calls, we generate an ordinary Dispatcher::call() instead, because: +# - We want to replay the entire call into the op, including any previously-set dispatch keys (including autograd!). +# - The view replay call also is not part of the hot path. +def emit_view_call( + f: NativeFunction, input_base: str, unpacked_args: Sequence[str] +) -> str: + # View replay functions use the standard Dispatcher::call API. + return CALL_DISPATCH.substitute( + unambiguous_name=f.func.name.unambiguous_name(), unpacked_args=unpacked_args + ) + + +def emit_view_lambda( f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None ) -> str: """Generate an additional lambda function to recover views in backward when as_strided is not supported. See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details. """ - # TODO: Clean this logic up if we get rid of reverse view funcs or reify them. 
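For orientation, here is a minimal sketch of roughly what the restored REPLAY_VIEW_LAMBDA_FUNC template expands to after CALL_DISPATCH substitution, using a hypothetical narrow() view. The helper name, the choice of op, and the use of at::narrow() in place of the generated Dispatcher::call are illustrative assumptions, not codegen output.

// Illustrative sketch only (not codegen output): roughly what
// REPLAY_VIEW_LAMBDA_FUNC + CALL_DISPATCH produce for a hypothetical
// narrow() view, with at::narrow() standing in for the dispatcher call.
#include <functional>

#include <ATen/ATen.h>

std::function<at::Tensor(const at::Tensor&)> make_replay_func(
    int64_t dim, int64_t start, int64_t length) {
  // Capture by value ([=]) so the saved view arguments outlive the forward op.
  return [=](const at::Tensor& input_base) {
    return at::narrow(input_base, dim, start, length);
  };
}

For multi-output views the generated lambda additionally indexes the result via the ${view_indexing} substitution, e.g. returning something of the shape at::chunk(input_base, chunks, dim)[view_idx].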
input_base = "input_base" replay_view_func = "" updated_args: List[str] = [] @@ -399,14 +412,11 @@ def emit_view_func( else: updated_args.append(arg) - from .gen_view_funcs import view_func_name - - view_func_args = [b.name for b in bindings if b.name != "self"] - if view_idx is not None: - view_func_args.append(f"{view_idx}") - replay_view_func += REPLAY_VIEW_FUNC.substitute( - view_func_name=view_func_name(f, include_namespace=True), - view_func_args=view_func_args, + replay_view_call = emit_view_call(f, input_base, updated_args) + replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute( + input_base=input_base, + replay_view_call=replay_view_call, + view_indexing=("" if view_idx is None else f"[{view_idx}]"), ) input_view = "input_view" @@ -483,26 +493,26 @@ def emit_view_body( if is_tensor_list_type(return_info.type): creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE") view_idx = "view_idx" - view_func = emit_view_func( + view_lambda = emit_view_lambda( f, extract_bindings(f), view_idx=view_idx ).strip() as_view_call = ( f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], " "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, " - "/* view_func */ std::move(func), /* rev_view_func */ rev_func, " + "/* view_func */ func, /* rev_view_func */ rev_func, " f"/* creation_meta */ {creation_meta});" ) call += MULTI_OUTPUT_VIEW_ITERATION.substitute( - var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}" + var=var, view_idx=view_idx, body=f"{view_lambda}\n{as_view_call}" ) rhs_value = f"std::move({var})" else: - call += emit_view_func(f, extract_bindings(f), view_idx=None) + call += emit_view_lambda(f, extract_bindings(f), view_idx=None) creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT") rhs_value = ( f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, " "/* is_fw_differentiable */ true, " - f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})" + f"/* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})" ) else: # This could be supported but we don't need it at the moment, so keeping things simple. diff --git a/tools/autograd/gen_view_funcs.py b/tools/autograd/gen_view_funcs.py deleted file mode 100644 index c9f7561dca1..00000000000 --- a/tools/autograd/gen_view_funcs.py +++ /dev/null @@ -1,334 +0,0 @@ -# Generates ViewFuncs.h/cpp -# -# NOTE: If any changes are being made to the ViewFunc codegen please also check -# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp -# The fallback is expected to mimic this codegen, so we should keep the two in sync. 
- -from typing import List, Tuple - -import torchgen.api.dispatcher as dispatcher -from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo -from torchgen.api.translate import translate -from torchgen.api.types import ( - BaseCType, - Binding, - NamedCType, - SymIntT, - tensorT, - VectorCType, -) -from torchgen.code_template import CodeTemplate -from torchgen.model import Argument, NativeFunction, OptionalType -from torchgen.utils import FileManager - -from .gen_inplace_or_view_type import ( - CALL_DISPATCH, - extract_bindings, - get_view_info, - modifies_arguments, - use_derived, -) - -FUNCTION_DECLARATION = CodeTemplate( - """\ -#define ${uppercase_op}_AVAILABLE -struct ${op} : public ${superclass} { - ${op}(${constructor_args}) ${initializer_list} - {}; - virtual ~${op}() override {}; - virtual std::vector get_symints() const override; - virtual size_t num_symints() const override; - virtual std::vector get_tensors() const override; - virtual size_t num_tensors() const override; - virtual at::Tensor operator()(const at::Tensor&) const override; - virtual std::unique_ptr clone_and_set( - std::optional> = c10::nullopt, - std::optional> = c10::nullopt) const override; - -protected: - virtual void set_symints(std::vector) override; - virtual void set_tensors(std::vector) override; - -private: - ${state} -}; - -""" -) - -FUNCTION_DEFINITION = CodeTemplate( - """\ -std::vector ${op}::get_symints() const { - ${get_symints} -} - -size_t ${op}::num_symints() const { - return static_cast(${num_symints}); -} - -void ${op}::set_symints(std::vector ${symints_vec}) { - TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints()); - ${set_symints} -} - -std::vector ${op}::get_tensors() const { - ${get_tensors} -} - -size_t ${op}::num_tensors() const { - return static_cast(${num_tensors}); -} - -void ${op}::set_tensors(std::vector ${tensors_vec}) { - TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors()); - ${set_tensors} -} - -at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const { - return ${op_call}; -} - -std::unique_ptr ${op}::clone_and_set( - std::optional> ${symints_vec}, - std::optional> ${tensors_vec}) const { - auto output = std::make_unique<${op}>(${clone_args}); - if (${symints_vec}.has_value()) { - output->set_symints(std::move(*(${symints_vec}))); - } - if (${tensors_vec}.has_value()) { - output->set_tensors(std::move(*(${tensors_vec}))); - } - return output; -} - -""" -) - - -# e.g. 
as_strided -> AsStridedViewFunc for camel case or -# as_strided_view_func otherwise -def view_func_name( - f: NativeFunction, include_namespace: bool = False, camel_case: bool = True -) -> str: - name = f.func.name.unambiguous_name() - view_func_name = f"{name.replace('.', '_')}_view_func" - if camel_case: - is_private = view_func_name.startswith("_") - view_func_name = "".join( - [p.title() for p in view_func_name.replace(".", "_").split("_")] - ) - if is_private: - # put the leading underscore back in - view_func_name = f"_{view_func_name}" - namespace = "torch::autograd::generated::" if include_namespace else "" - return f"{namespace}{view_func_name}" - - -def is_symint_or_tensor(arg: Argument) -> bool: - return arg.type.is_tensor_like() or arg.type.is_symint_like() - - -def remove_const_ref(binding: Binding) -> Binding: - return Binding( - name=binding.name, - nctype=binding.nctype.remove_const_ref(), - argument=binding.argument, - default=binding.default, - ) - - -def returns_multi_tensor(fn: NativeFunction) -> bool: - returns = fn.func.returns - assert len(returns) == 1 - returns_list_like = returns[0].type.is_list_like() is not None - returns_tensor_like = returns[0].type.is_tensor_like() - return returns_list_like and returns_tensor_like - - -# Generates strings with logic for getting / setting state of a particular type. -# -# Args: -# bindings (list): List of state bindings of interest (may be empty) -# state_vec_type (NamedCType): Type of vector to either return or copy from -# -# Returns: -# tuple: (list of getter logic strings, list of setter logic strings, string -# with num items expression) -def generate_state_getter_setter( - bindings: List[Binding], - state_vec_type: NamedCType, -) -> Tuple[List[str], List[str], str]: - getter_logic = [] - setter_logic = [] - - state_vec = state_vec_type.name - getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};") - if len(bindings) > 0: - setter_logic.append("auto i = 0;") - - num_exprs = [] - for i, b in enumerate(bindings): - assert isinstance(b.argument, Argument) - if b.argument.type.is_list_like(): - # Handle list-likes. - num_expr = f"{b.name}.size()" - num_exprs.append(num_expr) - getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());" - setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());" - elif isinstance(b.argument.type, OptionalType): - # Handle optionals. - num_expr = f"({b.name}.has_value() ? 1 : 0)" - num_exprs.append(num_expr) - conditional = f"if({b.name}.has_value())" - getter = ( - f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));" - ) - setter = f"{conditional} {b.name} = {state_vec}[i];" - else: - num_expr = "1" - num_exprs.append(num_expr) - getter = f"{state_vec}.push_back({b.name});" - setter = f"{b.name} = {state_vec}[i];" - - getter_logic.append(getter) - setter_logic.append(setter) - if i < len(bindings) - 1: - setter_logic.append(f"i += {num_expr};") - - # Reserve / assert based on the total number of items expression. 
- num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs) - if len(bindings) > 0: - getter_logic.insert(1, f"{state_vec}.reserve({num_items});") - - getter_logic.append(f"return {state_vec};") - - return getter_logic, setter_logic, num_items - - -def process_function(fn: NativeFunction, template: CodeTemplate) -> str: - bindings = extract_bindings(fn) - non_self_bindings = [b for b in bindings if b.name != "self"] - - non_self_args = fn.func.arguments.flat_all[1:] - non_self_value_bindings = [ - dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args - ] - - # Generate constructor / clone args for the generated struct. - constructor_args = [b.defn() for b in non_self_bindings] - clone_args = [b.name for b in non_self_bindings] - - # Generate state variable declarations for the generated struct. - state_variables = [ - f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings - ] - - # Generate initializer list expressions for the generated struct. - # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as - # vectors. - init_exprs = translate( - non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True - ) - initializers = [] - for b, init_expr in zip(non_self_bindings, init_exprs): - name = b.nctype.name - assert isinstance(name, str) - initializers.append(f"{name}({init_expr.expr})") - - # Generate call to underlying view op - call_input_name = "input_base" - op_call_args = [call_input_name, *(b.name for b in non_self_bindings)] - op_call = CALL_DISPATCH.substitute( - unambiguous_name=fn.func.name.unambiguous_name(), - unpacked_args=op_call_args, - ) - - # Multi-output views additionally require a view_idx for disambiguation. - if returns_multi_tensor(fn): - view_idx_name = "view_idx" - view_idx_typename = "int64_t" - view_idx_decl = f"{view_idx_typename} {view_idx_name}" - constructor_args.append(view_idx_decl) - clone_args.append(view_idx_name) - state_variables.append(f"{view_idx_decl};") - initializers.append(f"{view_idx_name}({view_idx_name})") - op_call += f"[{view_idx_name}]" - - # Generate initializer list for the generated struct. - initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else "" - - # Generate getter / setter logic for any symints. - symint_bindings = [ - b - for b in non_self_bindings - if isinstance(b.argument, Argument) and b.argument.type.is_symint_like() - ] - symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT))) - get_symints, set_symints, num_symints = generate_state_getter_setter( - symint_bindings, symints_vec_type - ) - - # Generate getter / setter logic for any tensors. 
- tensor_bindings = [ - b - for b in non_self_bindings - if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like() - ] - tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT))) - get_tensors, set_tensors, num_tensors = generate_state_getter_setter( - tensor_bindings, tensors_vec_type - ) - - return template.substitute( - op=view_func_name(fn), - uppercase_op=view_func_name(fn, camel_case=False).upper(), - superclass="torch::autograd::ViewFunc", - initializer_list=initializer_list, - state=state_variables, - constructor_args=constructor_args, - clone_args=clone_args, - symints_vec=symints_vec_type.name, - get_symints=get_symints, - set_symints=set_symints, - num_symints=num_symints, - tensors_vec=tensors_vec_type.name, - get_tensors=get_tensors, - set_tensors=set_tensors, - num_tensors=num_tensors, - call_input_name=call_input_name, - op_call=op_call, - ) - - -def gen_view_funcs( - out: str, - fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], - template_path: str, -) -> None: - # don't need the info parts, just the function - fns = [fn.func for fn in fns_with_infos if use_derived(fn)] - # only want out-of-place views - view_fns = [ - fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn) - ] - - declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns] - definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns] - ops_headers = [f"#include " for fn in view_fns] - - file_basename = "ViewFuncs" - fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) - for suffix in [".h", ".cpp"]: - fname = file_basename + suffix - fm.write_with_template( - fname, - fname, - lambda: { - "generated_comment": "@" - + f"generated from {fm.template_dir_for_comments()}/" - + fname, - "view_func_declarations": declarations, - "view_func_definitions": definitions, - "ops_headers": ops_headers, - }, - ) diff --git a/tools/autograd/templates/ADInplaceOrViewType.cpp b/tools/autograd/templates/ADInplaceOrViewType.cpp index e8276697eee..7a19047dd5c 100644 --- a/tools/autograd/templates/ADInplaceOrViewType.cpp +++ b/tools/autograd/templates/ADInplaceOrViewType.cpp @@ -1,6 +1,5 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include "torch/csrc/autograd/VariableTypeUtils.h" -#include "torch/csrc/autograd/generated/ViewFuncs.h" #include #include diff --git a/tools/autograd/templates/ViewFuncs.cpp b/tools/autograd/templates/ViewFuncs.cpp deleted file mode 100644 index 0f104023ba2..00000000000 --- a/tools/autograd/templates/ViewFuncs.cpp +++ /dev/null @@ -1,71 +0,0 @@ -#include "torch/csrc/autograd/generated/ViewFuncs.h" - -// ${generated_comment} - -using at::Tensor; -using at::Scalar; -using at::IntArrayRef; -using at::TensorList; - -namespace torch::autograd { - -std::vector ChainedViewFunc::get_symints() const { - auto symints = first->get_symints(); - auto second_symints = second->get_symints(); - symints.reserve(symints.size() + second_symints.size()); - symints.insert( - symints.end(), - std::make_move_iterator(second_symints.begin()), - std::make_move_iterator(second_symints.end())); - return symints; -} - -std::vector ChainedViewFunc::get_tensors() const { - auto tensors = first->get_tensors(); - auto second_tensors = second->get_tensors(); - tensors.reserve(tensors.size() + second_tensors.size()); - tensors.insert( - tensors.end(), - std::make_move_iterator(second_tensors.begin()), - std::make_move_iterator(second_tensors.end())); - return tensors; -} - -at::Tensor 
ChainedViewFunc::operator()(const at::Tensor& input_base) const { - return (*second)((*first)(input_base)); -} - -std::unique_ptr ChainedViewFunc::clone_and_set( - std::optional> symints, - std::optional> tensors) const { - std::optional> first_symints; - std::optional> second_symints; - if (symints.has_value()) { - TORCH_INTERNAL_ASSERT(symints->size() == num_symints()); - first_symints = std::vector( - symints->begin(), symints->begin() + first->num_symints()); - second_symints = std::vector( - symints->begin() + first->num_symints(), symints->end()); - } - - std::optional> first_tensors; - std::optional> second_tensors; - if (tensors.has_value()) { - TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors()); - first_tensors = std::vector( - tensors->begin(), tensors->begin() + first->num_tensors()); - second_tensors = std::vector( - tensors->begin() + first->num_tensors(), tensors->end()); - } - - return std::make_unique( - first->clone_and_set(first_symints, first_tensors), - second->clone_and_set(second_symints, second_tensors)); -} - -namespace generated { - -${view_func_definitions} - -} // namespace torch::autograd -} // namespace generated diff --git a/tools/autograd/templates/ViewFuncs.h b/tools/autograd/templates/ViewFuncs.h deleted file mode 100644 index f9acce387a9..00000000000 --- a/tools/autograd/templates/ViewFuncs.h +++ /dev/null @@ -1,106 +0,0 @@ -#pragma once - -// ${generated_comment} - -#include - -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -$ops_headers -#endif - -namespace torch { namespace autograd { - -/// Base class for view functions, providing reapplication of a view on a new base. -/// Each view op should get a codegenerated subclass of this class containing -/// any state needed to reconstruct the view. The class also provides convenience -/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification, -/// where we want to use symbolic values or fake tensors instead. -struct TORCH_API ViewFunc { - virtual ~ViewFunc() {} - /// Returns any SymInts in the saved state. - virtual std::vector get_symints() const { return {}; } - /// Returns the number of SymInts in the saved state. - virtual size_t num_symints() const { return 0; } - /// Returns any tensors in the saved state. - virtual std::vector get_tensors() const { return {}; } - /// Returns the number of tensors in the saved state. - virtual size_t num_tensors() const { return 0; } - /// Reapplies the view on the given base using the saved state. - virtual at::Tensor operator()(const at::Tensor&) const = 0; - /// Returns a clone of this ViewFunc, optionally with the specified saved state. - virtual std::unique_ptr clone_and_set( - std::optional> = c10::nullopt, - std::optional> = c10::nullopt) const = 0; - -protected: - /// Sets the values of any SymInts in the saved state. The input vector size must - /// match the number of SymInts in the saved state (i.e. the size of the list - /// returned by get_symints()). - virtual void set_symints(std::vector) {} - /// Sets the values of any Tensors in the saved state. The input vector size must - /// match the number of Tensors in the saved state (i.e. the size of the list - /// returned by get_tensors()). - virtual void set_tensors(std::vector) {} -}; - -/// ViewFunc that represents a chain of two ViewFuncs. 
-struct ChainedViewFunc : public ViewFunc { - ChainedViewFunc( - std::unique_ptr first, - std::unique_ptr second) - : first(std::move(first)), - second(std::move(second)) {} - virtual ~ChainedViewFunc() override {}; - virtual std::vector get_symints() const override; - virtual size_t num_symints() const override { - return first->num_symints() + second->num_symints(); - } - virtual std::vector get_tensors() const override; - virtual size_t num_tensors() const override { - return first->num_tensors() + second->num_tensors(); - } - virtual at::Tensor operator()(const at::Tensor&) const override; - virtual std::unique_ptr clone_and_set( - std::optional> = c10::nullopt, - std::optional> = c10::nullopt) const override; - -private: - std::unique_ptr first; - std::unique_ptr second; -}; - -/// ViewFunc that errors with a specified error message when called. -struct ErroringViewFunc : public ViewFunc { - ErroringViewFunc(const std::string& error_msg) : error_msg(error_msg) {} - virtual ~ErroringViewFunc() override {}; - virtual at::Tensor operator()(const at::Tensor&) const override { - TORCH_CHECK(false, error_msg); - } - virtual std::unique_ptr clone_and_set( - std::optional> = c10::nullopt, - std::optional> = c10::nullopt) const override { - return std::make_unique(error_msg); - } - -private: - std::string error_msg; -}; - -namespace generated { - -using at::Scalar; -using at::Tensor; -using at::IntArrayRef; -using at::ArrayRef; -using at::Type; -using at::ScalarType; -using c10::optional; -using c10::fmap; - -${view_func_declarations} - -}}} // namespace torch::autograd::generated diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 38a63640c11..83e963d64f6 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -475,10 +474,13 @@ static Tensor _fw_primal( at::AutoDispatchBelowADInplaceOrView guard; return at::alias(self); })(); - std::unique_ptr func(nullptr); + std::function func = nullptr; std::function rev_func = nullptr; if (!self.unsafeGetTensorImpl()->support_as_strided()) { - func = std::make_unique(self.sym_sizes()); + auto size_vec = self.sizes().vec(); + func = [=](const at::Tensor& input_base) { + return input_base.view(size_vec); + }; rev_func = [=](const at::Tensor& input_view) { TORCH_INTERNAL_ASSERT( false, @@ -508,10 +510,13 @@ static Tensor _make_dual( at::AutoDispatchBelowADInplaceOrView guard; return at::alias(primal); })(); - std::unique_ptr func(nullptr); + std::function func = nullptr; std::function rev_func = nullptr; if (!primal.unsafeGetTensorImpl()->support_as_strided()) { - func = std::make_unique(primal.sym_sizes()); + auto size_vec = primal.sizes().vec(); + func = [=](const at::Tensor& input_base) { + return input_base.view(size_vec); + }; rev_func = [=](const at::Tensor& input_view) { TORCH_INTERNAL_ASSERT( false, diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 9794ca9a4ad..8f87317d847 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -160,7 +160,7 @@ inline at::Tensor as_view( const at::Tensor& tensor, bool is_bw_differentiable, bool is_fw_differentiable, - std::unique_ptr view_func = nullptr, + std::function view_func = nullptr, std::function rev_view_func = nullptr, CreationMeta creation_meta = CreationMeta::DEFAULT, bool allow_tensor_metadata_change = true) { @@ -208,13 +208,11 @@ 
inline at::Tensor as_view( c10::optional new_fw_info; if (is_bw_differentiable) { - auto bw_view_func = view_func ? view_func->clone_and_set() : nullptr; if (diff_view_meta && diff_view_meta->has_bw_view()) { const auto& base_bw_info = diff_view_meta->get_backward_view(); - new_bw_info = base_bw_info.chain( - base, tensor, std::move(bw_view_func), rev_view_func); + new_bw_info = base_bw_info.chain(base, tensor, view_func, rev_view_func); } else { - new_bw_info = ViewInfo(base, std::move(bw_view_func), rev_view_func); + new_bw_info = ViewInfo(base, view_func, rev_view_func); } } else { TORCH_CHECK( diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index f6de7599c6b..bea070a973a 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -213,7 +213,7 @@ void AutogradMeta::set_fw_grad( // - Copy the given new_grad into this view // - Use this view as the new new_grad if (this_view_meta->has_fw_view()) { - auto& view_info = this_view_meta->get_forward_view(); + auto view_info = this_view_meta->get_forward_view(); auto& base = view_info.base_; if (!base._fw_grad(level).defined()) { diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index 3659a826b91..b334e6f097f 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -530,14 +530,18 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( (*stack)[stack->size() - num_returns + aliased_output_idx]; // See NOTE [ View + Inplace detection ] for more details about this logic - // We always need this view_func because otherwise if we do in-place - // on this view, we would implicitly use AsStridedBackward instead - // of the NotImplemented node. For the cross-dtype/non-strided - // cases, we would create something like this anyway - auto error_msg = - ("Mutating the view " + op_name + - "which does not have a derivative implemented is forbidden."); - auto erroring_view_func = std::make_unique(error_msg); + const auto erroring_view_func = [op_name = op_name](const at::Tensor&) { + // We always need this view_func because otherwise if we do in-place + // on this view, we would implicitly use AsStridedBackward instead + // of the NotImplemented node. 
For the cross-dtype/non-strided + // cases, we would create something like this anyway + TORCH_CHECK( + false, + "Mutating the view ", + op_name, + " which does not have a derivative implemented is forbidden."); + return at::Tensor(); + }; const auto erroring_rev_view_func = [op_name = op_name](const at::Tensor&) { TORCH_CHECK( @@ -556,7 +560,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( /* tensor=*/sub_output, /* is_bw_differentiable=*/true, /* is_fw_differentiable=*/true, - /* view_func=*/std::move(erroring_view_func), + /* view_func=*/erroring_view_func, /* rev_view_func=*/erroring_rev_view_func, /* creation_meta=*/ InferenceMode::is_enabled() @@ -573,7 +577,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( /* tensor=*/std::move(aliased_output_iv).toTensor(), /* is_bw_differentiable=*/true, /* is_fw_differentiable=*/true, - /* view_func=*/std::move(erroring_view_func), + /* view_func=*/erroring_view_func, /* rev_view_func=*/erroring_rev_view_func, /* creation_meta=*/ InferenceMode::is_enabled() diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index a47ac8f77ca..2f5015f9a2c 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -55,7 +55,7 @@ variable_list CopyBackwards::apply_with_saved( CopySlices::CopySlices( const Variable& base_var, at::TensorGeometry view_, - std::unique_ptr view_fn_, + std::function view_fn_, std::shared_ptr fn_) : Node(), base(base_var), @@ -98,7 +98,7 @@ inline variable_list CopySlices::apply_impl( at::Tensor grad_slice; if (view_fn) { - grad_slice = (*view_fn)(result); + grad_slice = view_fn(result); } else { auto offset = view.sym_storage_offset() - base.sym_storage_offset(); grad_slice = diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index 6e99ed6ae2a..29f9259170f 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -79,7 +79,7 @@ struct TORCH_API CopyBackwards : public Node { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // // We need to perform grad_view = fn(grad_view), but out-of-place. -// view_fn_ is an optional function saved in DifferentiableViewMeta +// view_fn_ is an optional lambda function saved in DifferentiableViewMeta // from forward pass, so that we can recover we when as_strided is not // supported. It preserves the invariants: // view = view_fn_(base) @@ -160,7 +160,7 @@ struct TORCH_API CopySlices : public Node { CopySlices( const Variable& base_var, at::TensorGeometry view_, - std::unique_ptr view_fn_, + std::function view_fn_, std::shared_ptr fn_); // common code between apply/apply_with_saved @@ -178,7 +178,7 @@ struct TORCH_API CopySlices : public Node { // view and view_fn are redundant and view_fn will be used if available. // See Note [View + Inplace update for base tensor] for details. at::TensorGeometry view; - std::unique_ptr view_fn; + std::function view_fn; std::shared_ptr fn; }; diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 27712b91fc5..4140af9284e 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -524,36 +524,16 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) { Py_RETURN_NONE; } -// Maps the given python callable over a vector of items, returning a vector -// of the same type of items. 
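The python_variable.cpp hunks that follow move _view_func and _view_func_unsafe back to single-argument METH_O bindings. As a reminder of that calling convention, here is a generic CPython sketch; the function and method names are hypothetical and not taken from python_variable.cpp.

// Generic CPython sketch of the METH_O convention (names are hypothetical):
// the slot receives exactly one positional argument and no keyword parsing
// happens.
#include <Python.h>

static PyObject* example_view_func(PyObject* self, PyObject* arg) {
  // With METH_O, `arg` is the single object passed at the Python call site.
  Py_INCREF(arg);  // placeholder: echo the argument back unchanged
  return arg;
}

static PyMethodDef example_methods[] = {
    {"_example_view_func", example_view_func, METH_O, nullptr},
    {nullptr, nullptr, 0, nullptr},
};

From Python this surfaces as a plain single-argument call, e.g. view._view_func(new_base), with no keyword arguments parsed.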
-template -static std::vector map_py_func( - const py::function& func, - const std::vector& items) { - std::vector new_items; - new_items.reserve(items.size()); - for (auto& item : items) { - new_items.push_back(py::cast(func(item))); - } - return new_items; -} - static PyObject* view_func_impl( - PyObject* _self, - PyObject* args, - PyObject* kwargs, + PyObject* self_, + PyObject* arg, bool check_has_same_meta) { HANDLE_TH_ERRORS - const auto& self = THPVariable_Unpack(_self); - - static PythonArgParser parser({ - "_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)", - }); - ParsedArgs<3> parsed_args{}; - auto r = parser.parse(_self, args, kwargs, parsed_args); - auto new_base = r.tensor(0); - PyObject* symint_visitor_fn = r.pyobject(1); - PyObject* tensor_visitor_fn = r.pyobject(2); + const auto& self = THPVariable_Unpack(self_); + TORCH_CHECK( + THPVariable_Check(arg), + "_view_func expect a single argument that is a Tensor"); + const auto& new_base = THPVariable_Unpack(arg); // Ensure that self is indeed a backward differentiable view // If not, we return an undefined Tensor (None) and let the user handle it. @@ -566,29 +546,7 @@ static PyObject* view_func_impl( torch::autograd::utils::has_same_meta(new_base, view_info.base_)) { // Do the actual view replay if (view_info.has_view_fn()) { - auto& view_func = view_info.view_fn(); - - // Determine new SymInt / tensor state as needed. - c10::optional> new_symints = c10::nullopt; - if (symint_visitor_fn != Py_None) { - new_symints = map_py_func( - py::cast(symint_visitor_fn), - view_func.get_symints()); - } - - c10::optional> new_tensors = c10::nullopt; - if (tensor_visitor_fn != Py_None) { - new_tensors = map_py_func( - py::cast(tensor_visitor_fn), - view_func.get_tensors()); - } - - // call view func - if (new_symints.has_value() || new_tensors.has_value()) { - out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base); - } else { - out = view_func(new_base); - } + out = view_info.view_fn()(new_base); } else { out = new_base.as_strided( self.sizes(), self.strides(), self.storage_offset()); @@ -599,18 +557,12 @@ static PyObject* view_func_impl( END_HANDLE_TH_ERRORS } -static PyObject* THPVariable_view_func( - PyObject* self_, - PyObject* args, - PyObject* kwargs) { - return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/true); +static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) { + return view_func_impl(self_, arg, /*check_has_same_meta=*/true); } -static PyObject* THPVariable_view_func_unsafe( - PyObject* self_, - PyObject* args, - PyObject* kwargs) { - return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/false); +static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) { + return view_func_impl(self_, arg, /*check_has_same_meta=*/false); } static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) { @@ -1715,14 +1667,8 @@ static PyMethodDef extra_methods[] = { METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr}, - {"_view_func", - castPyCFunctionWithKeywords(THPVariable_view_func), - METH_VARARGS | METH_KEYWORDS, - nullptr}, - {"_view_func_unsafe", - castPyCFunctionWithKeywords(THPVariable_view_func_unsafe), - METH_VARARGS | METH_KEYWORDS, - nullptr}, + {"_view_func", THPVariable_view_func, METH_O, nullptr}, + {"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr}, {"_rev_view_func_unsafe", THPVariable_rev_view_func_unsafe, 
METH_O, diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 5eab310766c..821eea07c4b 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include @@ -26,19 +25,6 @@ namespace torch { namespace autograd { -// Returns a ViewFunc with a corresponding view that matches the shape, -// stride, and storage offset of the given tensor. -// NB: On mobile, the as_strided() op and thus the generated AsStridedViewFunc -// may not be available. -static std::unique_ptr create_view_func_matching(const Variable& t) { -#ifdef AS_STRIDED_VIEW_FUNC_AVAILABLE - return std::make_unique( - t.sym_sizes(), t.sym_strides(), t.sym_storage_offset()); -#else - return std::make_unique("as_strided() not available"); -#endif -} - DifferentiableViewMeta::DifferentiableViewMeta( at::TensorImpl* self_impl, c10::optional backward_info, @@ -72,7 +58,7 @@ DifferentiableViewMeta::DifferentiableViewMeta( ViewInfo ViewInfo::chain( const Variable& base, const Variable& tensor, - std::unique_ptr view_func, + std::function view_func, std::function rev_view_func) const { // Set `view_func` using the root base as input. // `view_func` is used to recover views in backward when either as_strided is @@ -83,8 +69,12 @@ ViewInfo ViewInfo::chain( if (view_func) { // both current_view and it's parent have a view_func if (view_fn_) { - view_func = std::make_unique( - view_fn_->clone_and_set(), std::move(view_func)); + // Copy parent view function to gain ownership + auto prev_fn = view_fn_; + view_func = [=](const at::Tensor& root_base) { + auto temp = prev_fn(root_base); + return view_func(temp); + }; // assume view_fn_ / rev_view_fn_ always exist together or neither are set auto prev_rev_fn = rev_view_fn_; @@ -95,9 +85,13 @@ ViewInfo ViewInfo::chain( } else { // current_view has a view_func and but it's parent doesn't have one if (base.unsafeGetTensorImpl()->support_as_strided()) { - auto match_base_view_func = create_view_func_matching(base); - view_func = std::make_unique( - std::move(match_base_view_func), std::move(view_func)); + auto size = base.sym_sizes().vec(); + auto stride = base.sym_strides().vec(); + auto storage_offset = base.sym_storage_offset(); + view_func = [=](const at::Tensor& root_base) { + auto temp = root_base.as_strided_symint(size, stride, storage_offset); + return view_func(temp); + }; // assume view_fn_ / rev_view_fn_ always exist together or neither are // set @@ -117,7 +111,12 @@ ViewInfo ViewInfo::chain( auto error_msg = ("Attempted to chain views when the parent view has no view_func() and " "does not support as_strided(). 
This is not supported."); - view_func = std::make_unique(error_msg); + + view_func = [=](const at::Tensor& root_base) { + TORCH_CHECK(false, error_msg); + return root_base; + }; + rev_view_func = [=](const at::Tensor& root_view) { TORCH_CHECK(false, error_msg); return root_view; @@ -126,9 +125,15 @@ ViewInfo ViewInfo::chain( } } else if (view_fn_) { // if current_view doesn't have a view_func but it's parent has one - auto match_tensor_view_func = create_view_func_matching(tensor); - view_func = std::make_unique( - view_fn_->clone_and_set(), std::move(match_tensor_view_func)); + // Copy parent view function to gain ownership + auto prev_view_fn = view_fn_; + auto size = tensor.sym_sizes().vec(); + auto stride = tensor.sym_strides().vec(); + auto storage_offset = tensor.sym_storage_offset(); + view_func = [=](const at::Tensor& root_base) { + auto temp = prev_view_fn(root_base); + return temp.as_strided_symint(size, stride, storage_offset); + }; // assume view_fn_ / rev_view_fn_ always exist together or neither are set auto prev_rev_view_fn = rev_view_fn_; @@ -227,12 +232,12 @@ void rebase_history(const Variable& self, Edge gradient_edge) { TORCH_CHECK( gradient_edge.function->num_inputs() == 1, "Functions which modify views in-place must return a single Variable"); - const auto& view_info = diff_view_meta->get_backward_view(); + auto view_info = diff_view_meta->get_backward_view(); diff_view_meta->output_nr_ = gradient_edge.input_nr; auto copy_slices = std::make_shared( view_info.base_, at::TensorGeometry(self), - view_info.has_view_fn() ? view_info.view_fn().clone_and_set() : nullptr, + view_info.view_fn_, std::move(gradient_edge.function)); if (self.requires_grad()) { // If self did not previously require grad, there are no hooks to move @@ -651,7 +656,7 @@ const std::shared_ptr& VariableHooks::grad_fn( if (diff_view_meta && diff_view_meta->has_bw_view()) { // See NOTE [ View + Inplace detection ] std::lock_guard lock(diff_view_meta->mutex_); - auto& view_info = diff_view_meta->get_backward_view(); + auto view_info = diff_view_meta->get_backward_view(); if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) { return diff_view_meta->grad_fn_; } @@ -691,7 +696,7 @@ const std::shared_ptr& VariableHooks::grad_fn( // in VariableType_x.cpp // that would provide a way to recreate the grad_fn chain. if (view_info.has_view_fn()) { - auto& view_fn = view_info.view_fn(); + auto view_fn = view_info.view_fn(); Tensor diff_view; { // We can reach this path with grad_mode disabled, e.g. engine diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 1c98c3317af..2ff5b149aeb 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -7,7 +7,6 @@ #include #include #include -#include #include #include @@ -332,8 +331,7 @@ struct TORCH_API ViewInfo { /// By default we use as_strided to recover views which is more efficient. /// view_fn is only saved when as_strided is not supported. /// If view_fn has value, we use it to recover views in backward. - std::unique_ptr view_fn_; - + std::function view_fn_; /// Analogue of view_fn but in reverse: given a view -> produce the base by /// applying the inverse view. 
   std::function<at::Tensor(const at::Tensor&)> rev_view_fn_;
@@ -344,10 +342,10 @@ struct TORCH_API ViewInfo {
     return view_fn_ != nullptr;
   }
 
-  const ViewFunc& view_fn() const {
+  std::function<at::Tensor(const at::Tensor&)> view_fn() const {
     TORCH_CHECK(
         has_view_fn(), "Can only access the view function if it exists.");
-    return *view_fn_;
+    return view_fn_;
   }
 
   std::function<at::Tensor(const at::Tensor&)> rev_view_fn() const {
@@ -368,12 +366,12 @@ struct TORCH_API ViewInfo {
   ViewInfo chain(
       const Variable& base,
       const Variable& tensor,
-      std::unique_ptr<ViewFunc> view_func = nullptr,
+      std::function<at::Tensor(const at::Tensor&)> view_func = nullptr,
       std::function<at::Tensor(const at::Tensor&)> rev_view_func = nullptr) const;
 
   ViewInfo(
       Variable base,
-      std::unique_ptr<ViewFunc> view_fn,
+      std::function<at::Tensor(const at::Tensor&)> view_fn,
       std::function<at::Tensor(const at::Tensor&)> rev_view_fn)
       : base_(std::move(base)),
         view_fn_(std::move(view_fn)),
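Taken together, the variable.cpp/variable.h changes return ViewInfo to plain std::function view functions. Below is a minimal sketch of the composition pattern ViewInfo::chain() goes back to, assuming only ATen; the helper names are illustrative, and the real code composes as_strided_symint() over SymInt sizes/strides rather than the plain as_strided() shown here.

// Minimal sketch of the lambda composition that ViewInfo::chain() returns to.
// Helper names are illustrative; the real code uses as_strided_symint() with
// SymInt sizes/strides rather than the plain as_strided() shown here.
#include <cstdint>
#include <functional>
#include <vector>

#include <ATen/ATen.h>

using ViewFn = std::function<at::Tensor(const at::Tensor&)>;

// Both the parent and the child have a view_func: apply them in order.
ViewFn chain_view_fns(ViewFn parent, ViewFn child) {
  return [parent = std::move(parent), child = std::move(child)](
             const at::Tensor& root_base) { return child(parent(root_base)); };
}

// Only the child has a view_func and the base supports as_strided(): rebuild
// the parent step from its saved geometry, then apply the child on top.
ViewFn chain_with_as_strided_parent(
    std::vector<int64_t> sizes,
    std::vector<int64_t> strides,
    int64_t storage_offset,
    ViewFn child) {
  return [=](const at::Tensor& root_base) {
    return child(root_base.as_strided(sizes, strides, storage_offset));
  };
}

Note that the std::function form is opaque: unlike the removed ViewFunc structs, it exposes no get_symints()/get_tensors() accessors, which is why the symint_visitor_fn/tensor_visitor_fn test is deleted earlier in this diff.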