diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp index 893e936f8da..b78ca94dc51 100644 --- a/c10/core/SymInt.cpp +++ b/c10/core/SymInt.cpp @@ -53,12 +53,11 @@ bool SymInt::has_hint() const { #define DEFINE_BINARY(API, OP, METHOD, RET) \ RET SymInt::API(const SymInt& sci) const { \ if (auto ma = maybe_as_int()) { \ - if (auto mb = sci.maybe_as_int()) { \ - return RET(OP(*ma, *mb)); \ - } else { \ - auto b = sci.toSymNode(); \ - return RET(b->wrap_int(*ma)->METHOD(b)); \ - } \ + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( \ + !sci.maybe_as_int(), \ + "should have hit fast path in the header in this case."); \ + auto b = sci.toSymNode(); \ + return RET(b->wrap_int(*ma)->METHOD(b)); \ } else { \ if (auto mb = sci.maybe_as_int()) { \ auto a = toSymNodeImplUnowned(); \ @@ -69,19 +68,19 @@ bool SymInt::has_hint() const { } \ } -DEFINE_BINARY(operator+, std::plus<>(), add, SymInt) -DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt) -DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt) -DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt) -DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt) -DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool) -DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool) -DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool) -DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool) -DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool) -DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool) -DEFINE_BINARY(min, std::min, sym_min, SymInt) -DEFINE_BINARY(max, std::max, sym_max, SymInt) +DEFINE_BINARY(operator_add_slow_path, std::plus<>(), add, SymInt) +DEFINE_BINARY(operator_sub_slow_path, std::minus<>(), sub, SymInt) +DEFINE_BINARY(operator_mul_slow_path, std::multiplies<>(), mul, SymInt) +DEFINE_BINARY(operator_div_slow_path, std::divides<>(), floordiv, SymInt) +DEFINE_BINARY(operator_mod_slow_path, std::modulus<>(), mod, SymInt) +DEFINE_BINARY(sym_eq_slow_path, std::equal_to<>(), eq, SymBool) +DEFINE_BINARY(sym_ne_slow_path, std::not_equal_to<>(), ne, SymBool) +DEFINE_BINARY(sym_lt_slow_path, std::less<>(), lt, SymBool) +DEFINE_BINARY(sym_le_slow_path, std::less_equal<>(), le, SymBool) +DEFINE_BINARY(sym_gt_slow_path, std::greater<>(), gt, SymBool) +DEFINE_BINARY(sym_ge_slow_path, std::greater_equal<>(), ge, SymBool) +DEFINE_BINARY(min_slow_path, std::min, sym_min, SymInt) +DEFINE_BINARY(max_slow_path, std::max, sym_max, SymInt) SymInt::operator SymFloat() const { if (auto ma = maybe_as_int()) { @@ -161,15 +160,15 @@ SymInt operator-(const SymInt& s) { } } -void SymInt::operator*=(const SymInt& sci) { +void SymInt::operator_imul_slow_path(const SymInt& sci) { *this = *this * sci; } -void SymInt::operator/=(const SymInt& sci) { +void SymInt::operator_idiv_slow_path(const SymInt& sci) { *this = *this / sci; } -void SymInt::operator+=(const SymInt& sci) { +void SymInt::operator_iadd_slow_path(const SymInt& sci) { *this = *this + sci; } diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index d28bbe7a9b2..9b1c776cbe2 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -177,23 +178,136 @@ class C10_API SymInt { #endif } - SymInt operator+(const SymInt& sci) const; - SymInt operator-(const SymInt& sci) const; - SymInt operator*(const SymInt& sci) const; - SymInt operator/(const SymInt& sci) const; - SymInt operator%(const SymInt& sci) const; - void operator*=(const SymInt& sci); - void operator+=(const SymInt& sci); - void operator/=(const SymInt& sci); + SymInt operator+(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma + *mb); + } + } + return operator_add_slow_path(sci); + } + + SymInt operator-(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma - *mb); + } + } + return operator_sub_slow_path(sci); + } + + SymInt operator*(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma * *mb); + } + } + return operator_mul_slow_path(sci); + } + + SymInt operator/(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma / *mb); + } + } + return operator_div_slow_path(sci); + } + + SymInt operator%(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma % *mb); + } + } + return operator_mod_slow_path(sci); + } + + void operator*=(const SymInt& sci) { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + *this = SymInt(*ma * *mb); + return; + } + } + operator_imul_slow_path(sci); + } + + void operator+=(const SymInt& sci) { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + *this = SymInt(*ma + *mb); + return; + } + } + operator_iadd_slow_path(sci); + } + + void operator/=(const SymInt& sci) { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + *this = SymInt(*ma / *mb); + return; + } + } + operator_idiv_slow_path(sci); + } SymInt clone() const; - SymBool sym_eq(const SymInt&) const; - SymBool sym_ne(const SymInt&) const; - SymBool sym_lt(const SymInt&) const; - SymBool sym_le(const SymInt&) const; - SymBool sym_gt(const SymInt&) const; - SymBool sym_ge(const SymInt&) const; + SymBool sym_eq(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma == *mb); + } + } + return sym_eq_slow_path(sci); + } + + SymBool sym_ne(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma != *mb); + } + } + return sym_ne_slow_path(sci); + } + + SymBool sym_lt(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma < *mb); + } + } + return sym_lt_slow_path(sci); + } + + SymBool sym_le(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma <= *mb); + } + } + return sym_le_slow_path(sci); + } + + SymBool sym_gt(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma > *mb); + } + } + return sym_gt_slow_path(sci); + } + + SymBool sym_ge(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma >= *mb); + } + } + return sym_ge_slow_path(sci); + } bool operator==(const SymInt& o) const { return sym_eq(o).guard_bool(__FILE__, __LINE__); @@ -214,8 +328,23 @@ class C10_API SymInt { return sym_ge(o).guard_bool(__FILE__, __LINE__); } - SymInt min(const SymInt& sci) const; - SymInt max(const SymInt& sci) const; + SymInt min(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(std::min(*ma, *mb)); + } + } + return min_slow_path(sci); + } + + SymInt max(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(std::max(*ma, *mb)); + } + } + return max_slow_path(sci); + } // If both are symbolic, this checks if // they share the same node. @@ -260,6 +389,23 @@ class C10_API SymInt { private: void promote_to_negative(); + SymInt operator_add_slow_path(const SymInt& sci) const; + SymInt operator_sub_slow_path(const SymInt& sci) const; + SymInt operator_mul_slow_path(const SymInt& sci) const; + SymInt operator_div_slow_path(const SymInt& sci) const; + SymInt operator_mod_slow_path(const SymInt& sci) const; + void operator_imul_slow_path(const SymInt& sci); + void operator_iadd_slow_path(const SymInt& sci); + void operator_idiv_slow_path(const SymInt& sci); + SymBool sym_eq_slow_path(const SymInt& sci) const; + SymBool sym_ne_slow_path(const SymInt& sci) const; + SymBool sym_lt_slow_path(const SymInt& sci) const; + SymBool sym_le_slow_path(const SymInt& sci) const; + SymBool sym_gt_slow_path(const SymInt& sci) const; + SymBool sym_ge_slow_path(const SymInt& sci) const; + + SymInt min_slow_path(const SymInt& sci) const; + SymInt max_slow_path(const SymInt& sci) const; std::optional maybe_as_int_slow_path() const; diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index 7cefa1e4a77..e408543f536 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -35,4 +36,169 @@ TEST(SymIntTest, Overflows) { } #endif +namespace { + +// We need a SymNodeImpl that 1) has working arithmetic with +// predictable results and 2) causes SymInt::maybe_as_int to return +// nullopt so that we can hit all 4 cases (zero/one/both arguments +// have null maybe_as_int) in the operator implementations. +class ConstantIntPretendingToBeSymbolicSymNodeImpl + : public ConstantSymNodeImpl { + public: + using ConstantSymNodeImpl::ConstantSymNodeImpl; + std::optional constant_int() override { + return std::nullopt; + } + std::optional maybe_as_int() override { + return std::nullopt; + } + // Needs to be implemented for arithmetic to actually + // work. NestedIntSymNodeImpl does this, for example. + c10::SymNode wrap_int(int64_t num) override { + return SymNode( + c10::make_intrusive(num)); + } + + c10::SymNode wrap_bool(bool b) override { + return SymNode(c10::make_intrusive>(b)); + } + + SymNode add(const SymNode& other) override { + return wrap_int(int_() + other->int_()); + } + + SymNode sub(const SymNode& other) override { + return wrap_int(int_() - other->int_()); + } + + SymNode mul(const SymNode& other) override { + return wrap_int(int_() * other->int_()); + } + + SymNode floordiv(const SymNode& other) override { + return wrap_int(int_() / other->int_()); + } + + SymNode sym_min(const SymNode& other) override { + return wrap_int(std::min(int_(), other->int_())); + } + + SymNode sym_max(const SymNode& other) override { + return wrap_int(std::max(int_(), other->int_())); + } + + SymNode mod(const SymNode& other) override { + return wrap_int(int_() % other->int_()); + } + + SymNode eq(const SymNode& other) override { + return wrap_bool(int_() == other->int_()); + } + + SymNode ne(const SymNode& other) override { + return wrap_bool(int_() != other->int_()); + } + + SymNode lt(const SymNode& other) override { + return wrap_bool(int_() < other->int_()); + } + + SymNode le(const SymNode& other) override { + return wrap_bool(int_() <= other->int_()); + } + + SymNode gt(const SymNode& other) override { + return wrap_bool(int_() > other->int_()); + } + + SymNode ge(const SymNode& other) override { + return wrap_bool(int_() >= other->int_()); + } +}; + +SymInt create_symbolic_symint(int64_t value) { + return SymInt( + SymNode(c10::make_intrusive( + value))); +} + +auto unwrap(const SymInt& x) { + return x.guard_int(__FILE__, __LINE__); +} + +auto unwrap(bool b) { + return b; +} + +template