pytorch/aten/src/ATen/native/TriangularOps.cpp
zhudada 6658a04c7c [Code Clean] Better error handling in aten/src/ATen/native/* and aten/src/ATen/mkl/Exceptions.h (#165290)
Replace the vanilla C++ runtime_error exceptions with TORCH_CHECK.
Including:

- aten/src/ATen/native/*
- aten/src/ATen/mkl/Exceptions.h

Partially fixes #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165290
Approved by: https://github.com/fffrog, https://github.com/albanD
2025-11-28 14:02:08 +00:00
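The change the commit describes is mechanical: an ad-hoc throw of a vanilla C++ exception becomes a TORCH_CHECK, which raises c10::Error carrying the failing condition, the message, and the source location, and which Python surfaces as a RuntimeError. A minimal sketch of the pattern (check_square_before/check_square_after are hypothetical helpers, not code from this file):

#include <stdexcept>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

// Before: a plain C++ exception; no torch error type, no source location.
// (Hypothetical helper for illustration, not code from this file.)
void check_square_before(const at::Tensor& t) {
  if (t.size(-1) != t.size(-2)) {
    throw std::runtime_error("expected a square matrix");
  }
}

// After: TORCH_CHECK(cond, msg...) throws c10::Error when cond is false;
// the variadic message arguments are stringified and concatenated.
void check_square_after(const at::Tensor& t) {
  TORCH_CHECK(
      t.size(-1) == t.size(-2),
      "expected a square matrix, got sizes ",
      t.sizes());
}

Since c10::Error is what the Python layer translates into RuntimeError, user-visible behavior stays the same; the gain is a consistent torch error type with source location and richer messages.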

200 lines
6.2 KiB
C++

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/TriangularOpsUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>

#include <algorithm> // std::clamp, std::min, std::max used below
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/trace_backward_native.h>
#include <ATen/ops/tril_native.h>
#include <ATen/ops/triu_native.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::meta {
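
// Meta functions: validate the input rank and declare the output, which
// always has the same sizes and options as the input.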
TORCH_META_FUNC(tril)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "tril: input tensor must have at least 2 dimensions");
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

TORCH_META_FUNC(triu)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "triu: input tensor must have at least 2 dimensions");
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}
} // namespace at::meta

namespace at::native {
namespace {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
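// Core kernel for a single (n, m) matrix: zero out the elements outside the
// kept triangle and, for out-of-place calls, copy the kept elements from
// self. Rows are processed in parallel via at::parallel_for.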
template <typename scalar_t>
void apply_triu_tril_single(
    scalar_t* result,
    const scalar_t* self,
    bool inplace,
    int64_t k,
    int64_t n,
    int64_t m,
    int64_t res_row_stride,
    int64_t res_col_stride,
    int64_t self_row_stride,
    int64_t self_col_stride,
    bool upper) {
  constexpr int64_t zero = 0;
  // Clamp k to [-n, m] to prevent i + k arithmetic overflow, especially if
  // k approaches INT64_MAX/INT64_MIN.
  k = std::clamp(k, -n, m);

  if (upper) {
    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
      for (int64_t i : c10::irange(start, end)) {
        for (int64_t j = 0; j < std::min(m, i + k); j++) {
          result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
        }
        if (!inplace) { // copy the rest of the self if not inplace
          for (int64_t j = std::max(zero, i + k); j < m; j++) {
            result[i * res_row_stride + j * res_col_stride] = c10::load(&self[i * self_row_stride + j * self_col_stride]);
          }
        }
      }
    });
  } else {
    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
      for (int64_t i : c10::irange(start, end)) {
        for (int64_t j = std::max(zero, i + k + 1); j < m; j++) {
          result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
        }
        if (!inplace) { // copy the rest of the self if not inplace
          for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
            result[i * res_row_stride + j * res_col_stride] = c10::load(&self[i * self_row_stride + j * self_col_stride]);
          }
        }
      }
    });
  }
}
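
// Batched driver: applies the single-matrix kernel once per batch entry,
// with the batches themselves processed in parallel. When result aliases
// self (in-place update), self's strides are reused for the output;
// otherwise the result's own strides are queried.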
template <typename scalar_t>
void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int64_t k, bool upper) {
  auto n = self.size(-2);
  auto m = self.size(-1);
  auto self_data = self.const_data_ptr<scalar_t>();
  auto self_stride = (self.dim() > 2 && self.stride(-3) > 0) ? self.stride(-3) : 1;
  auto batchsize = batchCountTrilTriu(result);
  auto self_row_stride = self.stride(-2);
  auto self_col_stride = self.stride(-1);

  auto result_data = result.data_ptr<scalar_t>();
  int64_t result_stride = 0, result_row_stride = 0, result_col_stride = 0;
  if (result_data != self_data) {
    result_stride = (result.dim() > 2 && result.stride(-3) > 0) ? result.stride(-3) : 1;
    result_row_stride = result.stride(-2);
    result_col_stride = result.stride(-1);
  } else {
    result_stride = self_stride;
    result_row_stride = self_row_stride;
    result_col_stride = self_col_stride;
  }

  parallel_for(0, batchsize, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      const scalar_t* self_batch = &self_data[b * self_stride];
      scalar_t* result_batch = &result_data[b * result_stride];
      apply_triu_tril_single<scalar_t>(
          result_batch,
          self_batch,
          inplace,
          k,
          n,
          m,
          result_row_stride,
          result_col_stride,
          self_row_stride,
          self_col_stride,
          upper);
    }
  });
}

struct UpperTriangle {
  static constexpr const char* op_name = "triu";
  static constexpr bool upper = true;
};

struct LowerTriangle {
  static constexpr const char* op_name = "tril";
  static constexpr bool upper = false;
};
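
// Shared implementation behind tril_cpu/triu_cpu: make the input batch
// contiguous if needed, dispatch on dtype, and route an in-place op through
// a contiguous temporary when it cannot be updated in place.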
template <typename Triangle>
void compute_triu_tril(const Tensor& self, int64_t k, const Tensor& result) {
  if (self.numel() == 0) {
    return;
  }

  bool inplace_op = self.is_same(result);

  bool inplace_update = false;
  Tensor self_c;
  std::tie(inplace_update, self_c) = checkTrilTriuBatchContiguous(self, inplace_op);

  Tensor result_c;
  if (inplace_op && !inplace_update) {
    result_c = at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else {
    result_c = result;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      ScalarType::ComplexHalf,
      ScalarType::BFloat16,
      ScalarType::Half,
      ScalarType::Bool,
      self.scalar_type(),
      Triangle::op_name,
      [&]{
        apply_triu_tril<scalar_t>(result_c, self_c, inplace_op && inplace_update, k, Triangle::upper);
      });

  if (inplace_op && !inplace_update) {
    result.copy_(result_c);
  }
}
} // namespace

TORCH_IMPL_FUNC(tril_cpu)(const Tensor& self, int64_t k, const Tensor& result) {
  compute_triu_tril<LowerTriangle>(self, k, result);
}

TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor& result) {
  compute_triu_tril<UpperTriangle>(self, k, result);
}
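
// Backward of trace(): d(trace(X))/dX is the identity matrix, so the scalar
// grad is scattered onto the diagonal of a flat zero tensor. In a flattened
// (r, c) matrix the diagonal sits at indices 0, c + 1, 2 * (c + 1), ...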
Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
  TORCH_CHECK(sizes.size() == 2, "expected matrix input");

  auto grad_input = at::zeros_symint(sizes[0] * sizes[1], grad.options());
  auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
  // for composite compliance, use out-of-place variant of
  // `index_fill` if grad tensor is a Tensor Subclass.
  if (isTensorSubclassLike(grad)) {
    grad_input = grad_input.index_fill(0, indices, grad);
  } else {
    grad_input.index_fill_(0, indices, grad);
  }
  return grad_input.view_symint(sizes);
}
} // namespace at::native