[vulkan] Re-route arithmetic ops to scalar versions when second arg is zero-dim (#73108)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73108

When arithmetic ops are invoked from TorchScript, the scalar argument will sometimes be wrapped in a zero-dimensional tensor. This causes the Vulkan implementation to complain, since all input tensors are expected to have the same number of channels. The solution is to have the Tensor implementation of each op check whether the second argument is zero-dimensional and, if so, re-route the call to the Scalar implementation.

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D34354840

Pulled By: SS-JIA

fbshipit-source-id: b24799bb3dd4336791a39bea9382c14243ad58e4
committed by: Facebook GitHub Bot
parent: 423bcbff64
commit: c6dd8eb13b
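For context on the re-routing pattern in the diff below, here is a minimal sketch written against the plain ATen C++ API (libtorch) rather than the Vulkan backend; add_scalar here is a hypothetical stand-in for the kernel dispatched via VK_KERNEL(add_scalar), not the real helper. It shows how a scalar that TorchScript has wrapped in a zero-dimensional tensor can be detected and unwrapped:

// Minimal sketch, not the actual Vulkan implementation: plain ATen stands
// in for arithmetic_scalar/arithmetic_tensor, and add_scalar below is a
// stand-in for the shader dispatched via VK_KERNEL(add_scalar).
#include <torch/torch.h>
#include <iostream>

// Stand-in scalar path: computes self + other * alpha.
at::Tensor add_scalar(const at::Tensor& self, float other, float alpha) {
  return self + other * alpha;
}

at::Tensor add_tensor(
    const at::Tensor& self,
    const at::Tensor& other,
    const at::Scalar& alpha) {
  // TorchScript sometimes passes a Python scalar as a zero-dim tensor;
  // sizes().size() == 0 (equivalently dim() == 0) identifies that case.
  if (other.sizes().size() == 0) {
    return add_scalar(self, other.item<float>(), alpha.to<float>());
  }
  // Regular tensor-tensor path.
  return self.add(other, alpha);
}

int main() {
  const at::Tensor self = at::ones({1, 4, 2, 2});
  // What TorchScript may hand us for the literal 2.5: a zero-dim tensor.
  const at::Tensor wrapped = at::scalar_tensor(2.5);
  std::cout << "wrapped.dim() = " << wrapped.dim() << std::endl; // prints 0
  const at::Tensor out = add_tensor(self, wrapped, /*alpha=*/2.0);
  std::cout << out[0][0][0][0].item<float>() << std::endl; // 1 + 2.5 * 2 = 6
  return 0;
}

The diff below applies exactly this check at the top of each tensor-tensor overload.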
@@ -322,6 +322,13 @@ Tensor add_tensor(
     const Tensor& self_arg,
     const Tensor& other_arg,
     const Scalar& alpha) {
+  if (other_arg.sizes().size() == 0) {
+    return arithmetic_scalar(
+        self_arg,
+        other_arg.item<float>(),
+        c10::optional<Scalar>(alpha.to<float>()),
+        VK_KERNEL(add_scalar));
+  }
   return arithmetic_tensor(
       self_arg, other_arg, c10::optional<Scalar>(alpha), VK_KERNEL(add));
 }
@@ -354,6 +361,13 @@ Tensor sub_tensor(
     const Tensor& self_arg,
     const Tensor& other_arg,
     const Scalar& alpha) {
+  if (other_arg.sizes().size() == 0) {
+    return arithmetic_scalar(
+        self_arg,
+        other_arg.item<float>(),
+        c10::optional<Scalar>(-1 * alpha.to<float>()),
+        VK_KERNEL(add_scalar));
+  }
   return arithmetic_tensor(
       self_arg, other_arg, c10::optional<Scalar>(alpha), VK_KERNEL(sub));
 }
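Note that the scalar fall-back for sub reuses the add_scalar kernel rather than introducing a sub_scalar one: since self - alpha * other == self + (-alpha) * other, negating alpha is sufficient.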
@@ -374,6 +388,13 @@ Tensor& mul_scalar_(Tensor& self, const Scalar& other) {
 }
 
 Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) {
+  if (other_arg.sizes().size() == 0) {
+    return arithmetic_scalar(
+        self_arg,
+        other_arg.item<float>(),
+        c10::optional<Scalar>(),
+        VK_KERNEL(mul_scalar));
+  }
   return arithmetic_tensor(
       self_arg, other_arg, c10::optional<Scalar>(), VK_KERNEL(mul));
 }
@@ -400,6 +421,13 @@ Tensor& div_scalar_(Tensor& self, const Scalar& other) {
 }
 
 Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) {
+  if (other_arg.sizes().size() == 0) {
+    return arithmetic_scalar(
+        self_arg,
+        1.0 / other_arg.item<float>(),
+        c10::optional<Scalar>(),
+        VK_KERNEL(mul_scalar));
+  }
   return arithmetic_tensor(
       self_arg, other_arg, c10::optional<Scalar>(), VK_KERNEL(div));
 }
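Likewise, the scalar fall-back for div reuses the mul_scalar kernel: self / other == self * (1.0 / other) up to floating-point rounding, so passing the reciprocal avoids a dedicated div_scalar shader for this path.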