2020-02-27 12:18:24 -08:00
|
|
|
#include <torch/csrc/jit/serialization/python_print.h>
|
2019-04-27 16:10:04 -07:00
|
|
|
#include <ATen/core/qualified_name.h>
|
2019-01-24 11:05:07 -08:00
|
|
|
#include <c10/util/Exception.h>
|
2019-10-02 11:27:58 -07:00
|
|
|
#include <c10/util/StringUtil.h>
|
2020-03-26 11:15:49 -07:00
|
|
|
#include <torch/csrc/jit/api/module.h>
|
|
|
|
|
#include <torch/csrc/jit/frontend/error_report.h>
|
2020-06-24 12:39:42 -07:00
|
|
|
#include <torch/csrc/jit/frontend/versioned_symbols.h>
|
2020-02-27 12:18:24 -08:00
|
|
|
#include <torch/csrc/jit/ir/attributes.h>
|
|
|
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
|
#include <torch/csrc/jit/ir/ir_views.h>
|
Canonicalize all includes in PyTorch. (#14849)
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
2018-12-08 19:32:01 -08:00
|
|
|
#include <torch/csrc/jit/resource_guard.h>
|
2018-11-27 11:46:17 -08:00
|
|
|
|
2020-06-24 12:39:42 -07:00
|
|
|
#include <algorithm>
|
|
|
|
|
|
2019-04-27 16:10:04 -07:00
|
|
|
using c10::QualifiedName;
|
|
|
|
|
|
2018-11-12 10:15:44 -08:00
|
|
|
namespace torch {
|
|
|
|
|
namespace jit {
|
|
|
|
|
|
2018-11-27 17:08:09 -08:00
|
|
|
// Returns true if `c` may appear at position `pos` of a valid Python
// identifier: letters and '_' anywhere, digits only after the first character.
static bool isValidIdentifierChar(char c, size_t pos) {
  // The <cctype> classification functions have undefined behavior when passed
  // a plain char whose value is negative (possible for bytes >= 0x80 when
  // char is signed), so widen through unsigned char first.
  const auto uc = static_cast<unsigned char>(c);
  return islower(uc) || isupper(uc) || c == '_' || (pos > 0 && isdigit(uc));
}
|
|
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
static bool isValidIdentifier(const std::string& name) {
|
2018-11-27 17:08:09 -08:00
|
|
|
if (name.size() == 0)
|
|
|
|
|
return false;
|
|
|
|
|
for (size_t i = 0; i < name.size(); ++i) {
|
|
|
|
|
if (!isValidIdentifierChar(name[i], i))
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
// some names are valid identifiers but off limits because
// they are keywords or namespaces used in the output
const static std::unordered_set<std::string> reserved_names = {
    // identifiers in the environment while parsing
    "_", // avoid the confusing unnamed _
    "as",
    "aten",
    "attribute",
    "CONSTANTS",
    "fork",
    "getattr",
    "inf",
    "nan",
    "ops",
    "__torch__",
    // the python keywords (each listed exactly once; duplicates in an
    // unordered_set initializer are silently dropped, so they were only noise)
    "and",
    "assert",
    "async",
    "await",
    "break",
    "class",
    "continue",
    "def",
    "del",
    "elif",
    "else",
    "except",
    "False",
    "finally",
    "for",
    "from",
    "global",
    "if",
    "import",
    "in",
    "is",
    "lambda",
    "None",
    "nonlocal",
    "not",
    "or",
    "pass",
    "raise",
    "return",
    "True",
    "try",
    "while",
    "with",
    "yield",
    // serializer-internal pseudo-ops that must not be shadowed
    "uninitialized",
    "unchecked_cast",
};
|
2018-11-15 15:28:56 -08:00
|
|
|
|
2019-10-15 15:58:05 -07:00
|
|
|
struct PythonPrintImpl {
|
2019-07-01 21:11:12 -07:00
|
|
|
using SourceRangeStack = std::vector<SourceRange>;
|
2019-07-26 17:43:55 -07:00
|
|
|
SourceRangeStack source_range_stack_ = {SourceRange()};
|
2019-07-01 21:11:12 -07:00
|
|
|
|
|
|
|
|
struct WithSourceRange {
|
2019-07-01 21:11:12 -07:00
|
|
|
explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
|
2019-07-01 21:11:12 -07:00
|
|
|
TORCH_INTERNAL_ASSERT(stack);
|
2019-07-01 21:11:12 -07:00
|
|
|
if (auto gen_source = n->sourceRange().findSourceRangeThatGenerated()) {
|
|
|
|
|
stack->push_back(std::move(gen_source.value()));
|
|
|
|
|
} else {
|
2019-07-29 22:01:36 -07:00
|
|
|
stack->push_back(n->sourceRange());
|
2019-07-01 21:11:12 -07:00
|
|
|
}
|
2019-07-01 21:11:12 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~WithSourceRange() {
|
|
|
|
|
stack->pop_back();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SourceRangeStack* stack;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// A string stream that additionally records, for each appended chunk, which
// SourceRange was current (the top of `srs_`) at the byte offset where the
// chunk begins. `ranges_` thus maps byte offsets of the output text back to
// source ranges, which is what source-range serialization consumes via
// ranges().
class TaggedStringStream {
 public:
  // `srs` is a non-owning pointer to the active source-range stack; the
  // caller must keep it alive for the lifetime of this stream.
  TaggedStringStream(const SourceRangeStack* srs) : srs_(srs) {}

  TaggedStringStream& operator<<(const std::string& s) {
    // This prevents having redundant entries at the same offset,
    // which can happen for example in printValueList when begin
    // and end are the empty string.
    if (s.size() == 0) {
      return *this;
    }

    // Only start a new tagged range when the current top-of-stack range
    // differs from the last recorded one; tellp() is taken BEFORE appending
    // so the tag marks where this chunk starts.
    if (!ranges_.size() || ranges_.back().range != srs_->back()) {
      ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
    }
    oss_ << s;
    return *this;
  }

  TaggedStringStream& operator<<(const TaggedStringStream& rhs) {
    // Splice rhs's tags in, shifting each recorded offset (`range.bytes`) by
    // our current length so they point at the right place in the combined
    // text. Same de-duplication of consecutive identical ranges as above.
    for (const auto& range : rhs.ranges_) {
      if (!ranges_.size() || ranges_.back().range != range.range) {
        ranges_.emplace_back((size_t)oss_.tellp() + range.bytes, range.range);
      }
    }
    oss_ << rhs.oss_.str();
    return *this;
  }

  // This overload is here to prevent people from shooting themselves in the
  // foot. I would be highly surprised if someone actually wanted to write out
  // the address of a TaggedStringStream in the pretty print.
  TaggedStringStream& operator<<(
      const std::shared_ptr<TaggedStringStream>& rhs) {
    (*this) << *rhs;
    return *this;
  }

  // Fallback for any other streamable type (ints, doubles, chars, ...):
  // tag the current offset with the active range, then defer formatting to
  // the underlying ostringstream.
  template <typename T>
  TaggedStringStream& operator<<(const T& t) {
    if (!ranges_.size() || ranges_.back().range != srs_->back()) {
      ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
    }
    oss_ << t;
    return *this;
  }

  // The accumulated text, without any tagging information.
  std::string str() const {
    return oss_.str();
  }

  // (offset, range) tags in increasing offset order, one per run of output
  // produced under the same source range.
  const std::vector<TaggedRange>& ranges() const {
    return ranges_;
  }

 private:
  std::ostringstream oss_;
  std::vector<TaggedRange> ranges_;
  const SourceRangeStack* srs_;
};
|
|
|
|
|
|
2019-03-15 12:00:50 -07:00
|
|
|
// Helper to avoid duplicating class types
|
2019-08-14 11:21:42 -07:00
|
|
|
void registerDependency(const c10::NamedTypePtr& type) {
|
2019-08-19 18:41:08 -07:00
|
|
|
// Need to do actual equality comparison, not a pointer equality. This is
|
|
|
|
|
// because for some types (e.g. FunctionType), we may have multiple
|
|
|
|
|
// TypePtr's that represent the same underlying thing.
|
|
|
|
|
auto it = std::find_if(
|
|
|
|
|
deps_table_.cbegin(),
|
|
|
|
|
deps_table_.cend(),
|
|
|
|
|
[&](const c10::NamedTypePtr& dep) { return *dep == *type; });
|
|
|
|
|
|
|
|
|
|
if (it == deps_table_.cend()) {
|
2019-08-14 11:21:42 -07:00
|
|
|
deps_table_.push_back(type);
|
2019-03-15 12:00:50 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2018-11-12 10:15:44 -08:00
|
|
|
// scanValue, scanNode, scanBlock:
|
|
|
|
|
// decide if it is safe to omit the output of a temporary variable,
|
|
|
|
|
// and inline the expression into its use
|
|
|
|
|
// we only do this if
|
|
|
|
|
// (1) it is a constant, or
|
|
|
|
|
// (2) the temporary is unnamed, is single output, is used once,
|
2018-12-26 06:52:25 -08:00
|
|
|
// and would appear in the same order when the expression tree is
|
|
|
|
|
// reparsed.
|
2018-11-12 10:15:44 -08:00
|
|
|
// The last case can be checked
|
2020-01-28 04:44:18 -08:00
|
|
|
// because when we emit an expression tree in the parser,
|
2018-12-26 06:52:25 -08:00
|
|
|
// we do a left-to-right postorder traversal of the expression tree (emit
|
|
|
|
|
// children, then emit op). The reverse of this is a right-to-left preorder
|
|
|
|
|
// traversal of the tree. By doing a right-to-left preorder traversal of the
|
|
|
|
|
// inputs of a node, while also scanning the list of emitted nodes backward,
|
|
|
|
|
// we can see if they line up with what would happen when parsed the node as
|
|
|
|
|
// an expression. While they line up we collapse them into an inline
|
|
|
|
|
// expression.
|
2018-11-12 10:15:44 -08:00
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
// The inductive step is that the right-most input should be produced by the
|
|
|
|
|
// node immediately before the current node if it is in tree order.
|
2018-11-12 10:15:44 -08:00
|
|
|
|
2018-11-27 11:46:17 -08:00
|
|
|
// Returns true if value `v` may be printed inline at its (single) use site
// instead of being assigned to a named temporary first. Each guard below
// rejects a case where inlining would change how the printed code re-parses.
bool canInline(Value* v) {
  Node* n = v->node();
  // there must be only 1 output value, otherwise we need an assignment to
  // handle the multiple output values
  if (n->outputs().size() != 1)
    return false;
  // if it is used more than once, then we need a variable
  if (v->uses().size() != 1)
    return false;
  auto use = v->uses().at(0);
  // if it has a name set, then it was written as a variable so preserve that
  // unless it is being fed directly to the end of the block.
  // in which case it is not as useful to give it a name just to return it
  if (v->hasDebugName() && use.user->kind() != prim::Return)
    return false;
  // don't try to inline control blocks
  if (n->blocks().size() != 0)
    return false;
  // if it is a loop-carried input, we need a variable
  // otherwise the condition or trip count may be emitted in the wrong order
  // w.r.t. to it
  // (offsets 0 and 1 of a prim::Loop are the trip count and condition;
  // offset >= 2 are the loop-carried values)
  if (use.user->kind() == prim::Loop && use.offset >= 2)
    return false;

  // subgraph may use this more than once, so disable inlining
  if (use.user->kind() == prim::fork || use.user->kind() == prim::rpc_async)
    return false;

  // isinstance appearing in an if expression
  // causes type refinement to occur, but we have
  // already handled the refinement and inserted cast
  // expressions. By not inlining it into the if condition,
  // we prevent it from happening again.
  if (v->node()->kind() == prim::isinstance) {
    return false;
  }

  return true;
}
|
|
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
// block_point is the current node in the reverse linear scan of the emitted
// nodes. v is the current value in the tree traversal that may match with
// block_point's output. Returns the new scan position after consuming any
// nodes that were inlined into v's expression.
Node* scanValue(Node* block_point, Value* v) {
  Node* n = v->node();
  // A non-constant node must not already have been marked inline: each
  // non-constant node is scanned at most once.
  AT_ASSERT(n->kind() == prim::Constant || output_inline_.count(n) == 0);

  if (n == block_point &&
      canInline(v)) { // the node must be at the expected point of the typical
                      // tree traversal
    // recursively see if we can inline the inputs to this input
    block_point = scanNode(block_point);
    output_inline_.insert(n);
  } else if (n->kind() == prim::Constant) {
    // constant nodes can always be inlined, we will de-dup them on parsing
    // and put them at the top of the function regardless
    output_inline_.insert(n);
  }
  return block_point;
}
|
2018-11-27 11:46:17 -08:00
|
|
|
// Walk backwards from `n`, skipping over prim::Constant nodes, and return
// the first non-constant predecessor. Constants never block re-treeing
// because they are side-effect free and printed separately.
Node* previousNonConstant(Node* n) {
  Node* cur = n;
  while (true) {
    cur = cur->prev();
    if (cur->kind() != prim::Constant) {
      return cur;
    }
  }
}
|
2018-11-12 10:15:44 -08:00
|
|
|
|
2018-11-27 11:46:17 -08:00
|
|
|
// Scan `n` to decide which of its inputs can be printed inline, returning
// the updated "tree point" — the node at which the next candidate for
// inlining must be located.
Node* scanNode(Node* n) {
  // Nodes already marked inline need no further analysis.
  if (output_inline_.count(n)) {
    return n;
  }
  for (Block* sub : n->blocks()) {
    scanBlock(sub);
  }
  Node* tree_point = previousNonConstant(n);
  // Visit inputs right-to-left, mirroring the order in which their
  // producing expressions would appear in a typical tree traversal.
  auto inputs = n->inputs();
  for (size_t i = inputs.size(); i > 0; --i) {
    tree_point = scanValue(tree_point, inputs[i - 1]);
  }
  return tree_point;
}
|
|
|
|
|
|
2018-11-27 11:46:17 -08:00
|
|
|
// Scan an entire block for inlining opportunities: the implicit return
// node first, then the body nodes in reverse order, so each node is
// scanned before the nodes that produce its inputs.
void scanBlock(Block* b) {
  scanNode(b->return_node());
  for (Node* body_node : b->nodes().reverse()) {
    scanNode(body_node);
  }
}
|
|
|
|
|
|
Address jittering issues in python_print (#14064)
Summary:
export - print a method with python_print
import - import a method with import_method
We want to ensure:
export(g) == export(import(export(g)))
That is after after exporting/importing once, the graph will stay exactly
the same. This is less strict that g == import(export(g)) which would
require us to maintain a lot more information about the structure of the
IR and about the names of debug symbols.
This PR addresses this with the following fixes:
* print out double-precision numbers with high enough precision such
that they always parse in the same way
* when creating loop-carried dependencies, sort them
by variable name, ensuring a consistent order
* parse nan correctly
* DCE: remove unused outputs of if statements, and loop-carried dependencies
in loops that are dead both after the loop and inside the body of the
loop.
* Do not set uniqueName for variables whose names are _[0-9]+, these
are probably rare in user code, and we need a way to communicate
that we do not care about a variable name when re-parsing the graph.
Otherwise temporary variable names will jitter around.
* Expand the definition of a constant in printing code to None,
and family.
* Allow re-treeing to work as long as the only thing in its way is a
constant node. These do not have side effects but are sometimes
inserted in a different order when tracing compared to how we print them.
* Print all constant nodes out first in the order in which they are used_val
(or, if they are inlined, ensure they get assigned CONSTANT.cX number
in a consistent order). Cleanup tuples (this is done in the compiler,
but not in the tracer, leading to some tuple indexing jitter if not
done).
* use strtod_l, not std::stod which can throw exceptions
Other:
* Add REL_WITH_DEB_INFO to setup.py. It already existed for the
cmake files. Threading it into setup.py allows us to turn on
debug symbols with optimization everywhere.
* enable round trip testing for all generated graphs. This only adds
~6 seconds to total build time but tests printing for every graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14064
Differential Revision: D13094637
Pulled By: zdevito
fbshipit-source-id: 0a1c6912194d965f15d6b0c6cf838ccc551f161d
2018-11-21 06:36:26 -08:00
|
|
|
// Return the index of `t` in the tensor constant table, appending it if an
// equal tensor (same options and same values) is not already present.
size_t getOrAddTensorConstant(at::Tensor t) {
  // XXX - N^2 warning: like ConstantPool, this is quadratic in the number
  // of constants because tensors are compared pairwise rather than hashed.
  // We will probably need to optimize this at some point using hashing.
  size_t idx = 0;
  for (const at::Tensor& existing : tensor_table_) {
    if (t.options().type_equal(existing.options()) && t.equal(existing)) {
      return idx;
    }
    ++idx;
  }
  tensor_table_.emplace_back(std::move(t));
  return tensor_table_.size() - 1;
}
|
|
|
|
|
|
2018-11-27 11:46:17 -08:00
|
|
|
std::unordered_set<Node*> seen_constants;
|
|
|
|
|
// Collect, in first-use order, every not-yet-seen constant that feeds `n`
// (directly or through its sub-blocks) into `constants`.
void buildConstantList(Node* n, std::vector<Node*>& constants) {
  for (Value* input : n->inputs()) {
    Node* producer = input->node();
    const bool is_new_constant =
        producer->kind() == prim::Constant && !seen_constants.count(producer);
    if (is_new_constant) {
      constants.push_back(producer);
      seen_constants.insert(producer);
    }
  }
  for (Block* sub : n->blocks()) {
    buildConstantList(sub, constants);
  }
}
|
2018-11-27 11:46:17 -08:00
|
|
|
// Block overload: gather constants used by every node in the block,
// finishing with the implicit return node.
void buildConstantList(Block* b, std::vector<Node*>& constants) {
  for (Node* body_node : b->nodes()) {
    buildConstantList(body_node, constants);
  }
  buildConstantList(b->return_node(), constants);
}
|
2019-02-05 12:16:56 -08:00
|
|
|
|
2019-06-21 20:51:17 -07:00
|
|
|
// get a new name unique across calls to debugName() and
|
2018-11-12 10:15:44 -08:00
|
|
|
// anything we have used.
|
2019-02-05 12:16:56 -08:00
|
|
|
std::unordered_map<std::string, size_t> next_id;
|
2018-11-12 10:15:44 -08:00
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
// Produce a name derived from `candidate` that collides with neither
// `used` nor the reserved-word set, recording the result in `used`.
std::string genNameImpl(
    const std::string& candidate,
    std::unordered_set<std::string>& used) {
  std::string name = candidate;
  for (;;) {
    const bool taken =
        used.count(name) != 0 || reserved_names.count(name) != 0;
    if (!taken) {
      break;
    }
    // Append a per-candidate counter until the name is free.
    name = candidate + c10::to_string(next_id[name]++);
  }
  used.insert(name);
  return name;
}
|
2018-11-15 15:28:56 -08:00
|
|
|
// Generate a name based on `candidate` that is unique among all names
// this printer has used so far.
std::string genName(const std::string& candidate) {
  return genNameImpl(candidate, used_names_);
}
|
|
|
|
|
|
2018-11-12 10:15:44 -08:00
|
|
|
// unique names might not be valid identifiers,
|
|
|
|
|
// force them to be by rewriting them
|
|
|
|
|
// unique names might not be valid identifiers,
// force them to be by rewriting them: keep [A-Za-z0-9_], replace every
// other character with '_', and prefix a '_' when the name is empty or
// starts with a digit.
// Note: the <cctype> classification functions take an int that must be
// representable as unsigned char (or EOF); passing a raw (possibly
// negative) char is undefined behavior for non-ASCII bytes, so cast first.
static std::string makeValidIdentifier(const std::string& candidate) {
  std::stringstream ss;
  if (candidate.size() == 0 ||
      isdigit(static_cast<unsigned char>(candidate[0])))
    ss << "_";
  for (char c : candidate) {
    const auto uc = static_cast<unsigned char>(c);
    if (isupper(uc) || islower(uc) || isdigit(uc) || c == '_')
      ss << c;
    else
      ss << '_';
  }
  return ss.str();
}
|
|
|
|
|
// if we have to assign 'v' a name, what should it be?
|
2019-06-21 20:51:17 -07:00
|
|
|
// use the debugName if it was set, otherwise generate a name.
|
2018-11-27 11:46:17 -08:00
|
|
|
// Choose the name `v` should get if it must be assigned one: the
// sanitized debug name when the user provided one, otherwise "_" (which
// signals that the exact name is insignificant and may be regenerated).
std::string genUniqueNameFor(Value* v) {
  const std::string candidate =
      v->hasDebugName() ? makeValidIdentifier(v->debugNameBase()) : "_";
  return genName(candidate);
}
|
|
|
|
|
|
|
|
|
|
// map from Value to how it should be printed at each use
|
2019-07-01 21:11:12 -07:00
|
|
|
std::unordered_map<Value*, std::shared_ptr<TaggedStringStream>> expr_table_;
|
|
|
|
|
std::unordered_map<Value*, std::string> ident_refs_;
|
|
|
|
|
|
|
|
|
|
// NB: we MUST pass around the shared pointers to these streams by value.
|
|
|
|
|
// There is an interaction in splitLongInlines where the string value for
|
|
|
|
|
// both the RHS and the LHS of an expression are live at the same time,
|
|
|
|
|
// however the value for the RHS is overwritten in the table.
|
|
|
|
|
// Produce the text that should be printed at a use of `v`.
// Identifier refs take precedence over expression refs: presence in the
// ident table means a statement assigning this value was already emitted.
// NB: the shared_ptr must be returned by value — see the comment above
// about splitLongInlines overwriting the table entry for the RHS.
std::shared_ptr<TaggedStringStream> useOf(Value* v) const {
  auto ident_it = ident_refs_.find(v);
  if (ident_it != ident_refs_.end()) {
    auto stream = std::make_shared<TaggedStringStream>(&source_range_stack_);
    (*stream) << ident_it->second;
    return stream;
  }
  auto expr_it = expr_table_.find(v);
  if (expr_it != expr_table_.end()) {
    return expr_it->second;
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "Value was not present in either expressions"
      " table or ident refs table");
}
|
2018-11-27 11:46:17 -08:00
|
|
|
// Record that uses of `v` should print as the identifier `s`.
void assignValue(Value* v, const std::string& s) {
  ident_refs_[v] = s;
}
|
|
|
|
|
// Record that uses of `v` should print as the inline expression `s`.
void assignValue(Value* v, std::shared_ptr<TaggedStringStream> s) {
  expr_table_[v] = std::move(s);
}
|
2018-11-27 11:46:17 -08:00
|
|
|
// Alias: make `v` print exactly the way `w` currently prints.
void assignValue(Value* v, Value* w) {
  assignValue(v, useOf(w));
}
|
2018-11-27 11:46:17 -08:00
|
|
|
// Give every value in `values` a fresh unique identifier and record it
// so later uses print that name.
void assignValuesToTheirUniqueNames(at::ArrayRef<Value*> values) {
  for (Value* value : values) {
    assignValue(value, genUniqueNameFor(value));
  }
}
|
|
|
|
|
|
|
|
|
|
size_t level = 0;
|
|
|
|
|
// indent to the current indent level
|
2019-07-01 21:11:12 -07:00
|
|
|
// Emit the current indentation (two spaces per level) into the body
// stream and return the stream for chaining.
TaggedStringStream& indent() {
  for (size_t remaining = level; remaining > 0; --remaining) {
    body_ << "  ";
  }
  return body_;
}
|
|
|
|
|
|
|
|
|
|
// Increase the indent level for the lifetime of the returned guard; the
// level is restored when the guard goes out of scope.
ResourceGuard WithIndented() {
  level++;
  return ResourceGuard([this] { level--; });
}
|
|
|
|
|
|
|
|
|
|
template <class T0, class T1, class F>
|
2018-12-26 06:52:25 -08:00
|
|
|
void zipWith(at::ArrayRef<T0> list_a, at::ArrayRef<T1> list_b, F action)
|
|
|
|
|
const {
|
2018-11-12 10:15:44 -08:00
|
|
|
auto it_a = list_a.begin();
|
|
|
|
|
auto it_b = list_b.begin();
|
|
|
|
|
|
|
|
|
|
if (list_a.size() != list_b.size()) {
|
2019-01-31 14:06:44 -08:00
|
|
|
AT_ERROR("Python printer expected 2 lists of same size");
|
2018-11-12 10:15:44 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (; it_a != list_a.end(); ++it_a, ++it_b) {
|
|
|
|
|
action(*it_a, *it_b);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
void printValueList(
|
2019-07-01 21:11:12 -07:00
|
|
|
TaggedStringStream& stmt,
|
2018-12-26 06:52:25 -08:00
|
|
|
at::ArrayRef<Value*> list,
|
|
|
|
|
const char* begin = "",
|
|
|
|
|
const char* end = "") {
|
2018-11-12 10:15:44 -08:00
|
|
|
stmt << begin;
|
|
|
|
|
auto delimiter = "";
|
2018-11-27 11:46:17 -08:00
|
|
|
for (auto* value : list) {
|
2018-11-12 10:15:44 -08:00
|
|
|
stmt << delimiter;
|
|
|
|
|
stmt << useOf(value);
|
|
|
|
|
delimiter = ", ";
|
|
|
|
|
}
|
|
|
|
|
stmt << end;
|
|
|
|
|
}
|
|
|
|
|
|
2019-07-10 14:38:12 -07:00
|
|
|
// Emit `base[index]` for inputs = {base, index}, parenthesizing the base
// when its printed form is not a plain identifier.
void printValueIndex(TaggedStringStream& stmt, at::ArrayRef<Value*> inputs) {
  const std::string base = useOf(inputs[0])->str();
  if (isValidIdentifier(base)) {
    stmt << base;
  } else {
    stmt << "(" << base << ")";
  }
  stmt << "[" << useOf(inputs[1]) << "]";
}
|
|
|
|
|
|
2019-01-31 14:06:44 -08:00
|
|
|
void printDict(
|
2019-07-01 21:11:12 -07:00
|
|
|
TaggedStringStream& stmt,
|
2019-01-31 14:06:44 -08:00
|
|
|
at::ArrayRef<Value*> key_value_pairs,
|
|
|
|
|
const char* begin = "{",
|
|
|
|
|
const char* end = "}") {
|
|
|
|
|
stmt << begin;
|
|
|
|
|
auto delimiter = "";
|
|
|
|
|
for (size_t i = 0; i < key_value_pairs.size(); i += 2) {
|
|
|
|
|
stmt << delimiter;
|
|
|
|
|
auto key = key_value_pairs[i];
|
|
|
|
|
auto value = key_value_pairs[i + 1];
|
|
|
|
|
|
|
|
|
|
stmt << useOf(key) << ": " << useOf(value);
|
|
|
|
|
|
|
|
|
|
delimiter = ", ";
|
|
|
|
|
}
|
|
|
|
|
stmt << end;
|
|
|
|
|
}
|
|
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
// Emit `lhs = rhs` (tuple-style on both sides); emits nothing when there
// are no assignment targets.
void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
  if (lhs.empty()) {
    return;
  }
  indent();
  printValueList(body_, lhs);
  body_ << " = ";
  printValueList(body_, rhs);
  body_ << "\n";
}
|
|
|
|
|
|
|
|
|
|
// A type annotation is needed on the assignment target when the declared
// (lhs) type differs from the type of the value being assigned (rhs).
bool requiresAnnotation(Value* lhs, Value* rhs) {
  return *lhs->type() != *rhs->type();
}
|
|
|
|
|
|
|
|
|
|
// Emit one `name[: type] = value` statement per lhs/rhs pair; the type
// annotation is added only when the two sides' static types differ.
void printAnnotatedAssignment(
    at::ArrayRef<Value*> lhs,
    at::ArrayRef<Value*> rhs) {
  for (size_t i = 0; i < lhs.size(); ++i) {
    Value* target = lhs[i];
    Value* source = rhs[i];
    indent();
    body_ << useOf(target);
    if (requiresAnnotation(target, source)) {
      body_ << ": " << target->type()->annotation_str(type_printer_);
    }
    body_ << " = " << useOf(source) << "\n";
  }
}
|
|
|
|
|
|
2018-11-27 11:46:17 -08:00
|
|
|
// Print an If node as a Python if/else statement. Outputs are named
// up-front so both branches can assign to the same identifiers.
void printIf(IfView stmt) {
  assignValuesToTheirUniqueNames(stmt.outputs());
  const bool has_outputs = stmt.outputs().size() > 0;
  indent() << "if " << useOf(stmt.cond()) << ":\n";
  {
    auto guard = WithIndented();
    // Branch body, then the assignments of branch results to the outputs.
    printBlock(stmt.thenBlock(), has_outputs);
    printAssignment(stmt.outputs(), stmt.thenOutputs());
  }
  indent() << "else:\n";
  {
    auto guard = WithIndented();
    printBlock(stmt.elseBlock(), has_outputs);
    printAssignment(stmt.outputs(), stmt.elseOutputs());
  }
}
|
|
|
|
|
|
2018-11-27 11:46:17 -08:00
|
|
|
// Print a Loop node as either a Python `for` or `while` statement.
// Loop carried dependencies are handled by assigning their initial
// values to the node->outputs() before the loop,
// and assign node->outputs() to the new values at the end of each trip.
void printLoop(LoopView stmt) {
  auto loop_type = stmt.loopType();
  // A loop that mixes for- and while-semantics (e.g. after an optimization
  // pass fused them) has no direct Python spelling.
  if (loop_type == LoopView::ModifiedLoop) {
    throw ErrorReport(stmt.node()->sourceRange())
        << "loop cannot be printed as python "
        << "because it has gone through an optimization "
        << "that combined while and for loops. File a bug";
  }

  bool emit_as_for_loop = loop_type == LoopView::For;

  // Name the loop node's outputs first; body inputs alias them below.
  assignValuesToTheirUniqueNames(stmt.carriedOutputs());
  // Add aliases for loop-carried dependencies
  zipWith(
      stmt.bodyCarriedInputs(), // Start at 1 to ignore trip count
      stmt.carriedOutputs(),
      [&](Value* block_input, Value* node_output) {
        assignValue(block_input, node_output);
      });

  // Print initial assignments of loop node outputs = loop node inputs
  printAnnotatedAssignment(stmt.carriedOutputs(), stmt.carriedInputs());

  assignValuesToTheirUniqueNames(stmt.currentTripCount());
  // Loop header
  if (emit_as_for_loop) {
    indent();
    body_ << "for " << useOf(stmt.currentTripCount()) << " in range("
          << useOf(stmt.maxTripCount()) << "):\n";
  } else {
    // note: trip_count_in_block is unused because this is a while loop,
    // so we reuse the Value* as a stand-in for the loop condition
    printAssignment(stmt.currentTripCount(), stmt.inputCond());
    indent();
    body_ << "while " << useOf(stmt.currentTripCount()) << ":\n";
  }
  // Loop body
  {
    ResourceGuard indent = WithIndented();
    // Update block outputs to block inputs for next loop iteration
    // skip the assignment to the new condition in for loops because
    // the condition is always True
    size_t offset = emit_as_for_loop ? 1 : 0;
    auto body_block = stmt.bodyBlock();
    ArrayRef<Value*> loop_carried_block_inputs =
        body_block->inputs().slice(offset);
    printBlock(body_block, loop_carried_block_inputs.size() > 0);
    printAssignment(
        loop_carried_block_inputs, body_block->outputs().slice(offset));
  }
}
|
|
|
|
|
|
2019-02-28 13:06:10 -08:00
|
|
|
bool isLongLine(const std::string& str) {
|
|
|
|
|
return str.size() + level * 2 >= 40;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// True if `node` was inlined into its single use and the resulting
// expression text has grown past the printer's line-length limit.
bool isLongInline(Node* node) {
  if (output_inline_.count(node) == 0) {
    return false;
  }
  return isLongLine(useOf(node->output())->str());
}
|
|
|
|
|
|
|
|
|
|
// True if `input` comes from an inlined expression that is not a
// constant. Constants never need to be un-inlined for ordering reasons.
bool isNonConstantInline(Value* input) {
  Node* producer = input->node();
  return producer->kind() != prim::Constant &&
      output_inline_.count(producer) != 0;
}
|
|
|
|
|
|
|
|
|
|
// [reordering of inlines]
// We inline anything that is semantically legal to inline, but sometimes
// we find that these lines get too long. In that case we break the lines
// and it is important that we un-inline all the inputs preceding the long
// input:
// r = foo(x.add_(b), some_long + expression)
// wrong!
// _0 = some_long + expression
// r = foo(x.add_(b), _0) # wrong! _0 runs before mutating add_
// legal!
// _0 = x.add_(b)
// _1 = some_long + expression
// r = foo(_0, _1)
|
2020-04-24 15:12:12 -07:00
|
|
|
|
|
|
|
|
// Un-inline the non-constant inlined inputs that must be emitted before
// `v`'s (single) use, printing each as a standalone assignment. The
// candidates are collected in reverse and then emitted back-to-front so
// definitions appear in execution order.
void splitLongInlines(Value* v) {
  std::vector<Value*> collected_reversed;
  const Use use = v->uses().at(0);
  scanLongInlines(use.user, use.offset, collected_reversed);
  for (auto rit = collected_reversed.rbegin();
       rit != collected_reversed.rend();
       ++rit) {
    Value* split_value = *rit;
    printOutputDefinition(split_value->node(), *useOf(split_value));
  }
}
|
|
|
|
|
|
|
|
|
|
// Walk backwards over the inputs of `user` starting at `offset`,
// collecting (in reverse order) every non-constant inlined value that
// must be un-inlined to preserve evaluation order — see the note on
// [reordering of inlines] above.
//
// `visited_split_inline_uses_` records, per node, the highest input
// index already scanned, so repeated calls stop at the previously
// visited position instead of re-collecting earlier inputs. If `user`
// is itself inlined and has not been visited before, the scan continues
// recursively into the node that uses `user`'s output.
void scanLongInlines(
    Node* user,
    int64_t offset,
    std::vector<Value*>& to_split_reversed) {
  auto it = visited_split_inline_uses_.find(user);
  bool present = it != visited_split_inline_uses_.end();
  // Scan from `offset` down to just past the last visited index (or to
  // input 0 on the first visit), gathering inline candidates in reverse.
  for (int64_t i = offset; i >= (present ? it->second + 1 : 0); --i) {
    Value* prev_arg = user->input(i);
    if (isNonConstantInline(prev_arg)) {
      to_split_reversed.push_back(prev_arg);
    }
  }
  // Remember how far this node has been scanned.
  visited_split_inline_uses_[user] = offset;
  if (!present && output_inline_.count(user)) {
    Use u = user->output()->uses().at(0);
    scanLongInlines(u.user, int64_t(u.offset) - 1, to_split_reversed);
    // -1 because the actual use is still being
    // emitted so it cannot be split
  }
}
|
|
|
|
|
|
2019-07-01 21:11:12 -07:00
|
|
|
template <typename T>
|
|
|
|
|
void printOutputDefinition(Node* node, const T& expr) {
|
2019-02-28 13:06:10 -08:00
|
|
|
assignValuesToTheirUniqueNames(node->outputs());
|
|
|
|
|
indent();
|
|
|
|
|
// Print outputs
|
|
|
|
|
if (node->outputs().size() > 0) {
|
2019-04-19 12:48:39 -07:00
|
|
|
printValueList(body_, node->outputs());
|
|
|
|
|
body_ << " = ";
|
2019-02-28 13:06:10 -08:00
|
|
|
}
|
2019-07-01 21:11:12 -07:00
|
|
|
body_ << expr << "\n";
|
2019-02-28 13:06:10 -08:00
|
|
|
}
|
|
|
|
|
|
2019-03-15 12:00:50 -07:00
|
|
|
// Recursively check contained types for any class dependencies
|
|
|
|
|
void registerClassDependencies(const TypePtr& type) {
|
|
|
|
|
if (const auto classType = type->cast<ClassType>()) {
|
2019-08-14 11:21:42 -07:00
|
|
|
registerDependency(classType);
|
2019-06-19 10:38:46 -07:00
|
|
|
} else if (const auto tupleType = type->cast<TupleType>()) {
|
2019-08-14 11:21:42 -07:00
|
|
|
if (tupleType->name()) {
|
2019-08-14 11:21:42 -07:00
|
|
|
registerDependency(tupleType);
|
2019-06-19 10:38:46 -07:00
|
|
|
}
|
2019-08-27 22:52:48 -07:00
|
|
|
} else if (const auto interfaceType = type->cast<InterfaceType>()) {
|
|
|
|
|
registerDependency(interfaceType);
|
2019-03-15 12:00:50 -07:00
|
|
|
}
|
|
|
|
|
for (const auto& containedType : type->containedTypes()) {
|
|
|
|
|
registerClassDependencies(containedType);
|
|
|
|
|
}
|
|
|
|
|
}
|
2019-10-01 16:37:34 -07:00
|
|
|
// Record every named type this node touches — through its inputs, its
// outputs, or its type-valued attributes — in the dependency table.
void scanTypeDependencies(Node* node) {
  for (const auto in : node->inputs()) {
    registerClassDependencies(in->type());
  }
  for (const auto out : node->outputs()) {
    registerClassDependencies(out->type());
  }
  // Attributes holding a type (ty) or a list of types (tys) can also
  // introduce dependencies; other attribute kinds carry none.
  for (const auto& attr_name : node->attributeNames()) {
    const auto attr_kind = node->kindOf(attr_name);
    if (attr_kind == AttributeKind::ty) {
      registerClassDependencies(node->ty(attr_name));
    } else if (attr_kind == AttributeKind::tys) {
      for (const TypePtr& t : node->tys(attr_name)) {
        registerClassDependencies(t);
      }
    }
  }
}
|
2019-03-15 12:00:50 -07:00
|
|
|
|
2020-06-24 12:39:42 -07:00
|
|
|
// Raise the minimum serialization version required for this graph if
// `node`'s kind demands a newer one than we have seen so far.
void checkVersion(const Node* const node) {
  const auto required = get_min_version_for_kind(node->kind());
  if (required > min_version_) {
    min_version_ = required;
  }
}
|
|
|
|
|
|
2019-10-01 16:37:34 -07:00
|
|
|
// Print a single node as one or more python statements into body_.
//
// `print_const` controls whether prim::Constant nodes are emitted here;
// when false they are skipped (inlined constants are handled at their
// use sites instead). Structured nodes (Return/Loop/If/unpacks/with
// blocks/closures) each get bespoke statement forms; everything else is
// rendered as an expression via printRHS and either emitted as an
// assignment or inlined into its use.
void printNode(Node* node, bool print_const) {
  WithSourceRange guard(&source_range_stack_, node);
  // Register any named types this node references before printing it.
  scanTypeDependencies(node);
  // Some node kinds require a newer serialization version.
  checkVersion(node);
  if (!print_const && node->kind() == prim::Constant)
    return;
  switch (node->kind()) {
    case prim::Return:
      if (enforce_importable_ && node->inputs().size() != 1) {
        throw ErrorReport(node->sourceRange())
            << "Exportable methods must have a single return value. "
            << "Normal use of ScriptMethods should enforce this";
      }
      if (node->inputs().size() > 0) {
        indent();
        body_ << "return ";
        printValueList(body_, node->inputs());
        body_ << "\n";
      }
      break;
    case prim::Loop:
      printLoop(LoopView(node));
      break;
    case prim::If:
      printIf(IfView(node));
      break;
    case prim::TupleUnpack:
    case prim::ListUnpack:
      assignValuesToTheirUniqueNames(node->outputs());
      indent();
      // TupleUnpack(unpacked) turns into an assignment op that forces
      // the unpack to be inserted when parsed back in:
      // a, b, = unpacked
      // a, = unpacked # trailing comma forces an unpack to happen
      if (node->outputs().size() > 0) {
        printValueList(body_, node->outputs(), "", ", = ");
      }
      body_ << useOf(node->input()) << "\n";
      break;
    case prim::SetAttr: {
      // Emitted as: <obj>.<name> = <value>
      const auto obj = node->inputs().at(0);
      const auto newVal = node->inputs().at(1);
      // NOTE: `type` is otherwise unused; expect<ClassType>() asserts
      // that the target of the attribute set is a class instance.
      const auto type = obj->type()->expect<ClassType>();
      const auto& attrname = node->s(attr::name);
      indent();
      body_ << useOf(obj) << "." << attrname << " = " << useOf(newVal)
            << "\n";
    } break;
    case prim::fork: {
      // the subgraph gets emitted as another function
      auto name = genName("__forked_function");
      std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
      indent();
      body_ << "def " << name << "():\n";
      // Alias the subgraph's inputs to the fork node's inputs so the
      // emitted closure body refers to the outer names directly.
      for (size_t i = 0; i < node->inputs().size(); ++i) {
        assignValue(graph->inputs().at(i), node->inputs().at(i));
      }
      printBody(graph->block());
      std::stringstream ss;
      ss << "fork(" << name << ")";
      printOutputDefinition(node, ss.str());
    } break;
    case prim::Enter: {
      // Start of a `with` block; level is bumped so subsequent nodes
      // print indented inside it until the matching prim::Exit.
      const auto in = node->inputs().at(0);
      const auto out = node->outputs().at(0);
      indent();
      body_ << "with " << useOf(in);
      if (out->uses().size() > 0) {
        assignValue(out, genUniqueNameFor(out));
        body_ << " as " << useOf(out);
      }
      body_ << ":\n";
      level++;
    } break;
    case prim::Exit: {
      // If the previous node is a prim::Enter, the with block the generated
      // this Enter/Exit pair must have been empty.
      if (node->prev()->kind() == prim::Enter) {
        indent();
        body_ << "pass\n";
      }
      level--;
    } break;
    case prim::Function: {
      // A closure: emitted as a nested `def` with annotated parameters.
      if (enforce_importable_) {
        throw ErrorReport(node->sourceRange())
            << "closures are not exportable";
      }
      assignValuesToTheirUniqueNames(node->outputs());
      auto name = useOf(node->output())->str();
      std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
      indent();
      body_ << "def " << name << "(";
      assignValuesToTheirUniqueNames(graph->inputs());
      for (size_t i = 0; i < graph->inputs().size(); ++i) {
        Value* v = graph->inputs().at(i);
        if (i > 0) {
          body_ << ", ";
        }
        body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_);
      }
      body_ << "):\n";
      printBody(graph->block());
    } break;
    default:
      // Ordinary expression node: render the RHS, then decide whether to
      // emit it as an assignment or inline it into its use.
      auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
      printRHS(*ss, node);

      // we prevent long constants from inlining here.
      // it is not safe to do the same thing for non-constants here
      // because of [reordering of inlines]
      if (output_inline_.count(node) == 0 ||
          (node->kind() == prim::Constant && isLongLine(ss->str()))) {
        printOutputDefinition(node, *ss);
      } else {
        // this node is safe to inline, so assign the output value
        // to that expression directly
        assignValue(node->output(), ss);
        if (isLongLine(ss->str())) {
          splitLongInlines(node->output());
        }
      }
  }
}
|
|
|
|
|
|
2019-07-01 21:11:12 -07:00
|
|
|
// Render the constant `v` into `stmt` via IValue::repr, with two
// serialization-specific customizations: tensors are emitted as
// references into the CONSTANTS table, and named tuples are prefixed
// with their constructor name.
void printConstant(TaggedStringStream& stmt, const IValue& v) {
  const auto formatter = [&](std::ostream& out, const IValue& val) {
    if (val.isTensor()) {
      // Tensors cannot be spelled out inline; stash them in the
      // constant table and reference them by index.
      out << "CONSTANTS.c" << getOrAddTensorConstant(val.toTensor());
      return true; // fully handled; repr() must not print the value
    }
    if (val.isTuple() && val.type()->expect<TupleType>()->schema()) {
      // Named tuple: print the constructor name, then let the normal
      // tuple printing continue with the element list.
      out << val.type()->expect<TupleType>()->annotation_str(type_printer_);
    }
    return false; // let repr() keep printing
  };

  std::stringstream buffer;
  v.repr(buffer, formatter);
  stmt << buffer.str();
}
|
|
|
|
|
|
2019-09-04 13:50:48 -07:00
|
|
|
// Print the python-callable name for the operator symbol `kind`.
void printOpName(TaggedStringStream& stmt, Symbol kind) {
  // A few ops must serialize under fixed, hard-coded names to preserve
  // the original code semantics and stay backward compatible. This will
  // be handled properly once namespace semantics exist for serialized
  // ops; until then these overrides keep archives consistent.
  static const std::unordered_map<Symbol, std::string> override_symbols = {
      {aten::backward, "torch.autograd.backward"},
      {aten::grad, "torch.autograd.grad"},
  };
  const auto overridden = override_symbols.find(kind);
  if (overridden != override_symbols.end()) {
    stmt << overridden->second;
  } else if (kind.is_aten()) {
    // special case aten -> torch because we want to rename
    // the aten namespace, but this change will take more time
    // doing it here ensures we do not have fix up archives later
    stmt << "torch." << kind.toUnqualString();
  } else {
    stmt << "ops." << kind.ns().toUnqualString() << "."
         << kind.toUnqualString();
  }
}
|
|
|
|
|
|
2018-11-12 10:15:44 -08:00
|
|
|
// Prints the RHS value of a Node, e.g. `aten.add(x, y)`, into `stmt`.
// Each special node kind gets a bespoke surface form; everything else
// falls through to the generic `op(args...)` rendering that consults
// the node's schema for kwarg-only arguments.
void printRHS(TaggedStringStream& stmt, Node* node) {
  switch (node->kind()) {
    case prim::PythonOp: {
      // Raw python calls are printed as ^name<scalars>(inputs); they can
      // only appear when export is not required to be importable.
      auto value = static_cast<const PythonOp*>(node);
      if (enforce_importable_) {
        throw ErrorReport(node->sourceRange())
            << "Could not export Python function call '" << value->name()
            << "'. Remove calls to Python functions before export. "
            << "Did you forget add @script or @script_method annotation? "
            << "If this is a nn.ModuleList, add it to __constants__";
      }
      std::stringstream scalars_stream;
      stmt << "^" << value->name();
      value->writeScalars(scalars_stream);
      stmt << scalars_stream.str();
      printValueList(stmt, node->inputs(), "(", ")");
    } break;
    case prim::Uninitialized: {
      stmt << "uninitialized("
           << node->output()->type()->annotation_str(type_printer_) << ")";
    } break;
    case prim::Constant: {
      if (node->outputs().size() == 1 &&
          node->output()->type()->kind() == TypeKind::FunctionType) {
        // Function-typed constants print as the function's qualified
        // name and become a dependency of this method.
        auto fn = node->output()->type()->expect<FunctionType>();
        registerDependency(fn);
        stmt << fn->annotation_str(type_printer_);
      } else if (!node->mustBeNone()) {
        IValue v = toIValue(node->output()).value();
        printConstant(stmt, v);
      } else {
        stmt << "None";
      }
    } break;
    case aten::ScalarImplicit:
    case aten::FloatImplicit:
    case aten::IntImplicit: {
      // Implicit conversions must be annotated so the type survives a
      // round trip through the parser.
      stmt << "annotate("
           << node->output()->type()->annotation_str(type_printer_) << ", "
           << useOf(node->input()) << ")";
    } break;
    case aten::Int: {
      printValueList(stmt, node->inputs(), "int(", ")");
    } break;
    case aten::Float: {
      printValueList(stmt, node->inputs(), "float(", ")");
    } break;
    case aten::Bool: {
      printValueList(stmt, node->inputs(), "bool(", ")");
    } break;
    case aten::str: {
      printValueList(stmt, node->inputs(), "str(", ")");
    } break;
    case aten::__getitem__: {
      printValueIndex(stmt, node->inputs());
    } break;
    case prim::Print: {
      printValueList(stmt, node->inputs(), "print(", ")");
    } break;
    case aten::sorted: {
      printValueList(stmt, node->inputs(), "sorted(", ")");
    } break;
    case prim::TupleConstruct: {
      // Named tuples get their constructor name in front of the parens.
      if (auto qualname =
              node->output()->type()->expect<TupleType>()->name()) {
        stmt << node->output()->type()->annotation_str(type_printer_);
      }
      // A single-element tuple needs a trailing comma: (x,)
      printValueList(
          stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
    } break;
    case prim::TupleIndex: {
      stmt << "(" << useOf(node->inputs().at(0)) << ")["
           << useOf(node->inputs().at(1)) << "]";
    } break;
    case prim::TupleSlice: {
      stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::beg) << ":"
           << node->i(attr::end) << "]";
    } break;
    case prim::ListConstruct: {
      ListTypePtr list_type = node->output()->type()->expect<ListType>();
      TypePtr elem_type = list_type->getElementType();
      // Empty lists must be annotated with their type so the compiler knows
      // what type is supposed to be inside them
      if (node->inputs().size() == 0) {
        stmt << "annotate("
             << node->output()->type()->annotation_str(type_printer_)
             << ", [])";
        // If we can't infer the type based on what's inside, explicitly
        // annotate it to disambiguate.
        // This happens for List[Tensor] vs. List[Optional[Tensor]]
      } else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
        stmt << "annotate("
             << node->output()->type()->annotation_str(type_printer_) << ", ";
        printValueList(stmt, node->inputs(), "[", "]");
        stmt << ")";
        // Otherwise just print a list
      } else {
        printValueList(stmt, node->inputs(), "[", "]");
      }
    } break;
    case prim::DictConstruct: {
      auto dict_type = node->output()->type()->expect<DictType>();
      // There are cases where we must annotate the dict with an explicit type
      // to help the compiler out:
      // - the dict is empty
      // - the dict has potentially ambiguous element types
      //   (e.g. Tensor vs. Optional[Tensor])
      if (node->inputs().size() == 0 ||
          !elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
          !elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
        stmt << "annotate("
             << node->output()->type()->annotation_str(type_printer_) << ", ";
        printDict(stmt, node->inputs());
        stmt << ")";
        // Otherwise just print a dict
      } else {
        printDict(stmt, node->inputs());
      }
    } break;
    case prim::CreateObject: {
      // Objects are created with __new__ so no __init__ runs on import.
      const auto classType = node->output()->type()->expect<ClassType>();
      stmt << classType->annotation_str(type_printer_) << ".__new__("
           << classType->annotation_str(type_printer_) << ")";
    } break;
    case prim::GetAttr: {
      const auto obj = node->inputs().at(0);
      // NOTE: `classType` is otherwise unused; expect<ClassType>()
      // asserts the attribute access targets a class instance.
      const auto classType = obj->type()->expect<ClassType>();
      const auto& field = node->s(attr::name);
      if (isValidIdentifier(field)) {
        stmt << useOf(obj) << "." << field;
      } else {
        // Field names that are not valid python identifiers go through
        // getattr with a quoted string.
        stmt << "getattr(" << useOf(obj) << ", ";
        std::stringstream field_stream;
        c10::printQuotedString(field_stream, field);
        stmt << field_stream.str() << ")";
      }
    } break;
    case prim::CallFunction: {
      // Input 0 is the callee; remaining inputs are the arguments.
      stmt << useOf(node->inputs().at(0)) << "(";
      for (size_t i = 1; i < node->inputs().size(); i++) {
        stmt << useOf(node->inputs()[i]) << ", ";
      }
      stmt << ")";
    } break;
    case prim::CallMethod: {
      // Input 0 is the receiver; remaining inputs are the arguments.
      const auto& self = node->inputs().at(0);
      const auto& methodName = node->s(attr::name);
      stmt << "(" << useOf(self) << ")"
           << "." << methodName << "(";
      for (size_t i = 1; i < node->inputs().size(); i++) {
        stmt << useOf(node->inputs()[i]) << ", ";
      }
      stmt << ")";

      // The receiver's type (class or interface) becomes a dependency.
      if (auto selfClass = self->type()->cast<ClassType>()) {
        registerDependency(selfClass);
        const Function& method = selfClass->getMethod(node->s(attr::name));
        TORCH_INTERNAL_ASSERT(
            method.qualname() ==
            QualifiedName(selfClass->name()->qualifiedName(), methodName));
      } else if (auto selfInterface = self->type()->cast<InterfaceType>()) {
        registerDependency(selfInterface);
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "method call to unhandled type in serialization");
      }

    } break;
    case aten::_unwrap_optional: {
      printOpName(stmt, node->kind());
      stmt << "(";
      // we cannot recover the type of unwrap_optional(None),
      // using normal schema matching, so we route around this by rewriting
      // the call to unwrap_optional(annotated(Optional[T], None))
      if (node->input()->type()->isSubtypeOf(NoneType::get()) ||
          node->input()->mustBeNone()) {
        auto input_type = OptionalType::create(node->output()->type());
        stmt << "annotate(" << input_type->annotation_str(type_printer_)
             << ", " << useOf(node->input()) << ")";
      } else {
        stmt << useOf(node->input());
      }
      stmt << ")";
    } break;
    // unchecked_unwrap_optional is no longer generated by the compiler,
    // but may end up here if it was first loaded from a old model and
    // re-saved. On re-save we upgrade it to an unchecked_cast, which is an
    // equivalent op
    case prim::unchecked_unwrap_optional:
    case prim::unchecked_cast: {
      stmt << "unchecked_cast("
           << node->output()->type()->annotation_str(type_printer_) << ", "
           << useOf(node->input()) << ")";
    } break;
    case prim::isinstance: {
      stmt << "isinstance(" << useOf(node->input()) << ", ";
      const auto& types = node->tys(attr::types);
      if (types.size() == 1) {
        stmt << types.at(0)->annotation_str(type_printer_);
      } else {
        // check multiple things, e.g. (str, list, int)
        stmt << "(";
        bool first = true;
        for (const TypePtr& typ : types) {
          if (!first) {
            stmt << ", ";
          }
          stmt << typ->annotation_str(type_printer_);
          first = false;
        }
        stmt << ")";
      }
      stmt << ")";
    } break;
    case prim::tolist: {
      // Annotated so the nested-list element type is recoverable.
      stmt << "annotate("
           << node->output()->type()->annotation_str(type_printer_) << ", ";
      stmt << useOf(node->input(0)) << ".tolist()"
           << ")";
    } break;
    default: {
      // Generic operator call: op_name(arg0, kwarg=..., ...).
      printOpName(stmt, node->kind());
      const FunctionSchema& schema = node->schema();
      stmt << "(";
      for (size_t i = 0; i < node->inputs().size(); ++i) {
        if (i > 0) {
          stmt << ", ";
        }
        auto v = useOf(node->inputs().at(i));
        // print the kwarg name if it is a kwarg only argument.
        if (i < schema.arguments().size()) {
          auto arg = schema.arguments().at(i);
          if (arg.kwarg_only()) {
            stmt << arg.name() << "=";
          }
        } else {
          // vararg functions like format can have extra arguments
          AT_ASSERT(schema.is_vararg());
        }
        stmt << *v;
      }
      stmt << ")";
    } break;
  }
}
|
|
|
|
|
|
2019-07-01 21:11:12 -07:00
|
|
|
TaggedStringStream& printBlock(Block* root, bool block_has_other_statements) {
|
2018-12-26 06:52:25 -08:00
|
|
|
// pythons weird 'pass' syntax creates a bunch of places where we have to
|
|
|
|
|
// check if this block would be empty. But not everything in a block is a
|
|
|
|
|
// node. Sometimes if, loop, and return statements will follow this block
|
2018-11-15 15:28:56 -08:00
|
|
|
// and block_has_other_statements == true.
|
|
|
|
|
if (!block_has_other_statements &&
|
|
|
|
|
root->nodes().begin() == root->nodes().end()) {
|
|
|
|
|
indent();
|
2019-04-19 12:48:39 -07:00
|
|
|
body_ << "pass\n";
|
2018-11-15 15:28:56 -08:00
|
|
|
}
|
2018-11-27 11:46:17 -08:00
|
|
|
for (auto* node : root->nodes()) {
|
Address jittering issues in python_print (#14064)
Summary:
export - print a method with python_print
import - import a method with import_method
We want to ensure:
export(g) == export(import(export(g)))
That is after after exporting/importing once, the graph will stay exactly
the same. This is less strict that g == import(export(g)) which would
require us to maintain a lot more information about the structure of the
IR and about the names of debug symbols.
This PR addresses this with the following fixes:
* print out double-precision numbers with high enough precision such
that they always parse in the same way
* when creating loop-carried dependencies, sort them
by variable name, ensuring a consistent order
* parse nan correctly
* DCE: remove unused outputs of if statements, and loop-carried dependencies
in loops that are dead both after the loop and inside the body of the
loop.
* Do not set uniqueName for variables whose names are _[0-9]+, these
are probably rare in user code, and we need a way to communicate
that we do not care about a variable name when re-parsing the graph.
Otherwise temporary variable names will jitter around.
* Expand the definition of a constant in printing code to None,
and family.
* Allow re-treeing to work as long as the only thing in its way is a
constant node. These do not have side effects but are sometimes
inserted in a different order when tracing compared to how we print them.
* Print all constant nodes out first in the order in which they are used_val
(or, if they are inlined, ensure they get assigned CONSTANT.cX number
in a consistent order). Cleanup tuples (this is done in the compiler,
but not in the tracer, leading to some tuple indexing jitter if not
done).
* use strtod_l, not std::stod which can throw exceptions
Other:
* Add REL_WITH_DEB_INFO to setup.py. It already existed for the
cmake files. Threading it into setup.py allows us to turn on
debug symbols with optimization everywhere.
* enable round trip testing for all generated graphs. This only adds
~6 seconds to total build time but tests printing for every graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14064
Differential Revision: D13094637
Pulled By: zdevito
fbshipit-source-id: 0a1c6912194d965f15d6b0c6cf838ccc551f161d
2018-11-21 06:36:26 -08:00
|
|
|
printNode(node, /*print_const=*/false);
|
2018-11-12 10:15:44 -08:00
|
|
|
}
|
2019-04-19 12:48:39 -07:00
|
|
|
return body_;
|
2018-11-12 10:15:44 -08:00
|
|
|
}
|
|
|
|
|
|
2019-12-09 15:09:59 -08:00
|
|
|
template <typename dtype>
|
|
|
|
|
IValue createBroadList(dtype value, const int64_t& N) {
|
|
|
|
|
c10::List<dtype> repeated;
|
|
|
|
|
repeated.reserve(N);
|
|
|
|
|
for (int i = 0; i < N; ++i) {
|
|
|
|
|
repeated.push_back(value);
|
|
|
|
|
}
|
|
|
|
|
return repeated;
|
|
|
|
|
}
|
|
|
|
|
|
2018-12-26 06:52:25 -08:00
|
|
|
void printDefaultValue(
|
2019-12-09 15:09:59 -08:00
|
|
|
const Argument& arg,
|
2019-07-01 21:11:12 -07:00
|
|
|
TaggedStringStream& stmt,
|
2018-12-26 06:52:25 -08:00
|
|
|
const IValue& value) {
|
2019-12-09 15:09:59 -08:00
|
|
|
stmt << "=";
|
|
|
|
|
// handle broadcasting lists
|
|
|
|
|
if (arg.type()->kind() == ListType::Kind &&
|
2018-12-18 10:27:26 -08:00
|
|
|
(value.isInt() || value.isDouble() || value.isBool())) {
|
2019-12-09 15:09:59 -08:00
|
|
|
TORCH_INTERNAL_ASSERT(arg.N(), "expected broadcastinglist");
|
|
|
|
|
if (value.isInt()) {
|
|
|
|
|
printConstant(stmt, createBroadList<int64_t>(value.toInt(), *arg.N()));
|
|
|
|
|
} else if (value.isBool()) {
|
|
|
|
|
printConstant(stmt, createBroadList<bool>(value.toBool(), *arg.N()));
|
|
|
|
|
} else if (value.isDouble()) {
|
|
|
|
|
printConstant(
|
|
|
|
|
stmt, createBroadList<double>(value.toDouble(), *arg.N()));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
printConstant(stmt, value);
|
2018-12-18 10:27:26 -08:00
|
|
|
}
|
2018-11-27 11:46:17 -08:00
|
|
|
}
|
2019-12-09 15:09:59 -08:00
|
|
|
|
2019-04-25 15:43:53 -07:00
|
|
|
void printBody(Block* body) {
|
Address jittering issues in python_print (#14064)
Summary:
export - print a method with python_print
import - import a method with import_method
We want to ensure:
export(g) == export(import(export(g)))
That is after after exporting/importing once, the graph will stay exactly
the same. This is less strict that g == import(export(g)) which would
require us to maintain a lot more information about the structure of the
IR and about the names of debug symbols.
This PR addresses this with the following fixes:
* print out double-precision numbers with high enough precision such
that they always parse in the same way
* when creating loop-carried dependencies, sort them
by variable name, ensuring a consistent order
* parse nan correctly
* DCE: remove unused outputs of if statements, and loop-carried dependencies
in loops that are dead both after the loop and inside the body of the
loop.
* Do not set uniqueName for variables whose names are _[0-9]+, these
are probably rare in user code, and we need a way to communicate
that we do not care about a variable name when re-parsing the graph.
Otherwise temporary variable names will jitter around.
* Expand the definition of a constant in printing code to None,
and family.
* Allow re-treeing to work as long as the only thing in its way is a
constant node. These do not have side effects but are sometimes
inserted in a different order when tracing compared to how we print them.
* Print all constant nodes out first in the order in which they are used_val
(or, if they are inlined, ensure they get assigned CONSTANT.cX number
in a consistent order). Cleanup tuples (this is done in the compiler,
but not in the tracer, leading to some tuple indexing jitter if not
done).
* use strtod_l, not std::stod which can throw exceptions
Other:
* Add REL_WITH_DEB_INFO to setup.py. It already existed for the
cmake files. Threading it into setup.py allows us to turn on
debug symbols with optimization everywhere.
* enable round trip testing for all generated graphs. This only adds
~6 seconds to total build time but tests printing for every graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14064
Differential Revision: D13094637
Pulled By: zdevito
fbshipit-source-id: 0a1c6912194d965f15d6b0c6cf838ccc551f161d
2018-11-21 06:36:26 -08:00
|
|
|
// we always print constants at the top of the function, in the order
|
|
|
|
|
// in which they are used.
|
2018-11-27 11:46:17 -08:00
|
|
|
std::vector<Node*> constants;
|
2019-04-25 15:43:53 -07:00
|
|
|
buildConstantList(body, constants);
|
Address jittering issues in python_print (#14064)
Summary:
export - print a method with python_print
import - import a method with import_method
We want to ensure:
export(g) == export(import(export(g)))
That is after after exporting/importing once, the graph will stay exactly
the same. This is less strict that g == import(export(g)) which would
require us to maintain a lot more information about the structure of the
IR and about the names of debug symbols.
This PR addresses this with the following fixes:
* print out double-precision numbers with high enough precision such
that they always parse in the same way
* when creating loop-carried dependencies, sort them
by variable name, ensuring a consistent order
* parse nan correctly
* DCE: remove unused outputs of if statements, and loop-carried dependencies
in loops that are dead both after the loop and inside the body of the
loop.
* Do not set uniqueName for variables whose names are _[0-9]+, these
are probably rare in user code, and we need a way to communicate
that we do not care about a variable name when re-parsing the graph.
Otherwise temporary variable names will jitter around.
* Expand the definition of a constant in printing code to None,
and family.
* Allow re-treeing to work as long as the only thing in its way is a
constant node. These do not have side effects but are sometimes
inserted in a different order when tracing compared to how we print them.
* Print all constant nodes out first in the order in which they are used_val
(or, if they are inlined, ensure they get assigned CONSTANT.cX number
in a consistent order). Cleanup tuples (this is done in the compiler,
but not in the tracer, leading to some tuple indexing jitter if not
done).
* use strtod_l, not std::stod which can throw exceptions
Other:
* Add REL_WITH_DEB_INFO to setup.py. It already existed for the
cmake files. Threading it into setup.py allows us to turn on
debug symbols with optimization everywhere.
* enable round trip testing for all generated graphs. This only adds
~6 seconds to total build time but tests printing for every graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14064
Differential Revision: D13094637
Pulled By: zdevito
fbshipit-source-id: 0a1c6912194d965f15d6b0c6cf838ccc551f161d
2018-11-21 06:36:26 -08:00
|
|
|
|
2018-11-12 10:15:44 -08:00
|
|
|
// current graph is used to de-dup names within a single graph
|
2019-04-25 15:43:53 -07:00
|
|
|
scanBlock(body);
|
2018-11-12 10:15:44 -08:00
|
|
|
{
|
|
|
|
|
auto guard = WithIndented();
|
Address jittering issues in python_print (#14064)
Summary:
export - print a method with python_print
import - import a method with import_method
We want to ensure:
export(g) == export(import(export(g)))
That is after after exporting/importing once, the graph will stay exactly
the same. This is less strict that g == import(export(g)) which would
require us to maintain a lot more information about the structure of the
IR and about the names of debug symbols.
This PR addresses this with the following fixes:
* print out double-precision numbers with high enough precision such
that they always parse in the same way
* when creating loop-carried dependencies, sort them
by variable name, ensuring a consistent order
* parse nan correctly
* DCE: remove unused outputs of if statements, and loop-carried dependencies
in loops that are dead both after the loop and inside the body of the
loop.
* Do not set uniqueName for variables whose names are _[0-9]+, these
are probably rare in user code, and we need a way to communicate
that we do not care about a variable name when re-parsing the graph.
Otherwise temporary variable names will jitter around.
* Expand the definition of a constant in printing code to None,
and family.
* Allow re-treeing to work as long as the only thing in its way is a
constant node. These do not have side effects but are sometimes
inserted in a different order when tracing compared to how we print them.
* Print all constant nodes out first in the order in which they are used_val
(or, if they are inlined, ensure they get assigned CONSTANT.cX number
in a consistent order). Cleanup tuples (this is done in the compiler,
but not in the tracer, leading to some tuple indexing jitter if not
done).
* use strtod_l, not std::stod which can throw exceptions
Other:
* Add REL_WITH_DEB_INFO to setup.py. It already existed for the
cmake files. Threading it into setup.py allows us to turn on
debug symbols with optimization everywhere.
* enable round trip testing for all generated graphs. This only adds
~6 seconds to total build time but tests printing for every graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14064
Differential Revision: D13094637
Pulled By: zdevito
fbshipit-source-id: 0a1c6912194d965f15d6b0c6cf838ccc551f161d
2018-11-21 06:36:26 -08:00
|
|
|
// Print initial constant table (most are just inlined into their use,
|
|
|
|
|
// but some like long strings do get emitted)
|
2018-11-27 11:46:17 -08:00
|
|
|
for (Node* n : constants) {
|
Address jittering issues in python_print (#14064)
Summary:
export - print a method with python_print
import - import a method with import_method
We want to ensure:
export(g) == export(import(export(g)))
That is after after exporting/importing once, the graph will stay exactly
the same. This is less strict that g == import(export(g)) which would
require us to maintain a lot more information about the structure of the
IR and about the names of debug symbols.
This PR addresses this with the following fixes:
* print out double-precision numbers with high enough precision such
that they always parse in the same way
* when creating loop-carried dependencies, sort them
by variable name, ensuring a consistent order
* parse nan correctly
* DCE: remove unused outputs of if statements, and loop-carried dependencies
in loops that are dead both after the loop and inside the body of the
loop.
* Do not set uniqueName for variables whose names are _[0-9]+, these
are probably rare in user code, and we need a way to communicate
that we do not care about a variable name when re-parsing the graph.
Otherwise temporary variable names will jitter around.
* Expand the definition of a constant in printing code to None,
and family.
* Allow re-treeing to work as long as the only thing in its way is a
constant node. These do not have side effects but are sometimes
inserted in a different order when tracing compared to how we print them.
* Print all constant nodes out first in the order in which they are used_val
(or, if they are inlined, ensure they get assigned CONSTANT.cX number
in a consistent order). Cleanup tuples (this is done in the compiler,
but not in the tracer, leading to some tuple indexing jitter if not
done).
* use strtod_l, not std::stod which can throw exceptions
Other:
* Add REL_WITH_DEB_INFO to setup.py. It already existed for the
cmake files. Threading it into setup.py allows us to turn on
debug symbols with optimization everywhere.
* enable round trip testing for all generated graphs. This only adds
~6 seconds to total build time but tests printing for every graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14064
Differential Revision: D13094637
Pulled By: zdevito
fbshipit-source-id: 0a1c6912194d965f15d6b0c6cf838ccc551f161d
2018-11-21 06:36:26 -08:00
|
|
|
printNode(n, /*print_const=*/true);
|
|
|
|
|
}
|
2018-11-12 10:15:44 -08:00
|
|
|
// Print body
|
2019-04-26 19:14:10 -07:00
|
|
|
printBlock(body, body->return_node()->inputs().size() > 0);
|
2019-04-25 15:43:53 -07:00
|
|
|
printNode(body->return_node(), /*print_const=*/false);
|
2018-11-12 10:15:44 -08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2019-04-26 19:14:10 -07:00
|
|
|
public:
|
2019-08-27 22:52:48 -07:00
|
|
|
  // Print `func` as Python source:
  //   def <name>(<annotated args with defaults>) -> <return annotation>:
  //       <body>
  // Requires a graph function; the single schema return
  // (schema.returns().at(0)) is printed as the return annotation.
  void printFunction(
      const Function& func,
      bool print_first_argument_type = true) {
    TORCH_INTERNAL_ASSERT(func.isGraphFunction());
    const FunctionSchema& schema = func.getSchema();
    Graph& graph = *func.graph();
    used_names_.clear(); // each graph can reuse local names

    WithSourceRange guard(&source_range_stack_, graph.param_node());

    indent();
    body_ << "def " << func.name() << "(";
    auto param_it = graph.inputs().begin();
    for (const Argument& arg : schema.arguments()) {
      std::string arg_name = genName(arg.name());
      if (param_it == graph.inputs().begin()) {
        // the first argument may omit its type when it is implied by context
        // the flag print_first_argument_type determines when to do this
        body_ << arg_name;
        if (print_first_argument_type) {
          body_ << ": " << arg.type()->annotation_str(type_printer_);
        }
      } else {
        body_ << ",\n    " << arg_name << ": "
              << arg.type()->annotation_str(type_printer_);
      }
      if (arg.default_value()) {
        printDefaultValue(arg, body_, *arg.default_value());
      }
      // Bind the printed name to the corresponding graph input so later uses
      // of that Value are printed under this name.
      assignValue(*param_it++, arg_name);
    }

    body_ << ") -> "
          << schema.returns().at(0).type()->annotation_str(type_printer_)
          << ":\n";
    printBody(graph.block());
  }
|
|
|
|
|
|
2019-08-27 22:52:48 -07:00
|
|
|
  // Print a method: same as printFunction, but the first argument (self)
  // omits its type annotation, since it is implied by the enclosing class.
  void printMethod(const Function& func) {
    printFunction(func, /*print_first_argument_type=*/false);
  }
|
|
|
|
|
|
2019-10-15 15:58:05 -07:00
|
|
|
PythonPrintImpl(
|
2018-11-29 17:51:45 -08:00
|
|
|
std::vector<at::Tensor>& tensor_table,
|
2019-08-14 11:21:42 -07:00
|
|
|
std::vector<c10::NamedTypePtr>& deps_table,
|
2020-04-01 00:00:30 -07:00
|
|
|
c10::TypePrinter type_printer,
|
2019-10-15 15:58:05 -07:00
|
|
|
bool enforce_importable)
|
2019-10-16 22:45:50 -07:00
|
|
|
: body_(&source_range_stack_),
|
2019-07-01 21:11:12 -07:00
|
|
|
tensor_table_(tensor_table),
|
2019-08-14 11:21:42 -07:00
|
|
|
deps_table_(deps_table),
|
2020-04-01 00:00:30 -07:00
|
|
|
type_printer_(type_printer),
|
2019-10-16 22:45:50 -07:00
|
|
|
enforce_importable_(enforce_importable) {}
|
2018-11-12 10:15:44 -08:00
|
|
|
|
2020-01-28 10:58:28 -08:00
|
|
|
  // Print a ClassType (including modules) as a Python class definition:
  // module metadata (__parameters__/__buffers__), attribute annotations,
  // Final constants, and finally each method.
  void printClass(const ClassTypePtr& classType) {
    // If any of the methods are not Graph functions, this indicates that
    // this class is a custom-bound C++ class. Skip serialization
    // of this class, we will depend on the ClassType being defined
    // in the target process.
    for (auto& method : classType->methods()) {
      if (!method->isGraphFunction()) {
        return;
      }
    }

    bool is_module = classType->is_module();
    body_ << "class " << classType->name()->name();
    if (is_module) {
      body_ << "(Module)";
    }

    body_ << ":\n";
    {
      const auto guard = WithIndented();
      size_t numAttrs = classType->numAttributes();
      // For modules, we need to print special information about the module's
      // attributes and parameters.
      if (is_module) {
        std::vector<std::string> params;
        std::vector<std::string> buffers;
        // Populate the __parameters__ field. This tells the importer which
        // attributes are parameters.
        for (size_t i = 0; i < numAttrs; i++) {
          if (classType->is_parameter(i)) {
            params.push_back(classType->getAttributeName(i));
          }
          if (classType->is_buffer(i)) {
            buffers.push_back(classType->getAttributeName(i));
          }
        }
        indent();
        body_ << "__parameters__ = [";
        for (const auto& param : params) {
          body_ << "\"" << param << "\", ";
        }
        body_ << "]\n";
#ifndef FBCODE_CAFFE2
        // Note: Forward compat gated. TODO: @voznesenskym to remove when ready.
        indent();
        body_ << "__buffers__ = [";
        for (const auto& buffer : buffers) {
          body_ << "\"" << buffer << "\", ";
        }
        body_ << "]\n";
#endif
      }

      // Emit a type annotation for every attribute, and record any class
      // dependencies its type introduces.
      for (size_t i = 0; i < numAttrs; i++) {
        const auto& name = classType->getAttributeName(i);
        const auto& type = classType->getAttribute(i);
        registerClassDependencies(type);

        indent();

        // Handling for when the attribute name is not a valid Python
        // identifier. This happens for, e.g. ModuleList.
        if (!isValidIdentifier(name)) {
          if (i == 0) {
            // Initialize the annotations dict if necessary.
            body_ << "__annotations__ = []\n";
            indent();
          }
          // Print out a direct manipulation of the annotations dict, like:
          // __annotations__["0"] = SomeType
          body_ << "__annotations__["
                << "\"" << name
                << "\"] = " << type->annotation_str(type_printer_) << "\n";
        } else {
          // Otherwise: just emit a python 3 attribute annotation, like:
          // foo : SomeType
          body_ << name << " : " << type->annotation_str(type_printer_) << "\n";
        }
      }

      // Class constants are printed as Final-annotated assignments.
      size_t numConstants = classType->numConstants();
      for (size_t i = 0; i < numConstants; i++) {
        const auto& name = classType->getConstantName(i);
        IValue v = classType->getConstant(i);

        indent();
        body_ << name << " : "
              << "Final[" << v.type()->annotation_str(type_printer_) << "] = ";
        // Render the constant through a scratch stream so source-range tags
        // are preserved when appended to body_.
        auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
        printConstant(*ss, v);
        body_ << ss->str() << "\n";
      }

      // TODO fields
      for (auto& method : classType->methods()) {
        printFunction(*method);
      }
    }
  }
|
|
|
|
|
|
2019-08-27 22:52:48 -07:00
|
|
|
  // Dispatch on the concrete kind of NamedType and print its Python
  // definition: free functions, classes/modules, NamedTuples, and interfaces.
  void printNamedType(const c10::NamedTypePtr& type) {
    if (auto functionType = type->cast<FunctionType>()) {
      printFunction(*functionType->function());
    } else if (auto classType = type->cast<ClassType>()) {
      printClass(classType);
    } else if (auto tupleType = type->cast<TupleType>()) {
      // NamedTuples print as `class Name(NamedTuple):` with one annotated
      // field per schema argument.
      TORCH_INTERNAL_ASSERT(tupleType->schema());
      body_ << "class " << tupleType->name()->name();
      body_ << "(NamedTuple):\n";
      {
        const auto guard = WithIndented();
        for (const auto& attr : tupleType->schema()->arguments()) {
          TORCH_INTERNAL_ASSERT(attr.type());
          indent();
          body_ << attr.name() << " : "
                << attr.type()->annotation_str(type_printer_) << "\n";
        }
      }
    } else if (auto interfaceType = type->cast<InterfaceType>()) {
      body_ << "class " << interfaceType->name()->name();
      if (interfaceType->is_module()) {
        body_ << "(ModuleInterface):\n";
      } else {
        body_ << "(Interface):\n";
      }
      {
        auto guard = WithIndented();
        for (const FunctionSchema& method : interfaceType->methods()) {
          indent();
          body_ << "def " << method.name() << "(self";
          // Interface methods are instance methods; the schema must carry an
          // explicit leading `self` argument, which we print without a type.
          TORCH_INTERNAL_ASSERT(
              method.arguments().size() > 0 &&
              method.arguments().at(0).name() == "self");
          for (const Argument& arg :
               at::ArrayRef<Argument>(method.arguments()).slice(1)) {
            auto type = arg.type();
            registerClassDependencies(type);
            body_ << ", " << arg.name() << ": "
                  << type->annotation_str(type_printer_);
          }
          auto return_type = method.returns().at(0).type();
          registerClassDependencies(return_type);
          body_ << ") -> " << return_type->annotation_str(type_printer_)
                << ":\n";
          indent();
          // Interface methods have no body; emit a stub.
          body_ << "  pass\n";
        }
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unhandled NamedType");
    }
  }
|
2019-08-11 15:43:28 -07:00
|
|
|
|
2019-10-15 15:58:05 -07:00
|
|
|
~PythonPrintImpl() {}
|
|
|
|
|
|
2019-10-16 22:45:50 -07:00
|
|
|
  // Accumulated Python source text, tagged with source ranges for debug info.
  TaggedStringStream body_;
  // When printing this node, is it safe to write it inline (i.e. without
  // assigning a temporary variable)
  std::unordered_set<Node*> output_inline_;

  // see [reordering of inlines]
  // used to track parts of an inline statement we already scanned
  // for splitting long lines, so that we do not revisit them causing n^2
  // behavior. stores the maximum offset into inputs that has already been
  // scanned for the node.
  std::unordered_map<Node*, int64_t> visited_split_inline_uses_;

  // what valid identifiers are in use for the current function
  std::unordered_set<std::string> used_names_;

  // constants are written to this table, and given then named CONSTANTS.cN
  // where N is the index into this table.
  std::vector<at::Tensor>& tensor_table_;

  // Any NamedTypes (classes, functions, NamedTuples) used are written to this
  // table.
  std::vector<c10::NamedTypePtr>& deps_table_;

  // A function that, given a named type, returns us the correct string to print
  // for it.
  c10::TypePrinter type_printer_;

  // when we print this, should we error if the resulting output would
  // not be able to be reparsed?
  bool enforce_importable_;

  // The least version that supports all printed ops
  uint64_t min_version_ = 0;
|
2018-11-12 10:15:44 -08:00
|
|
|
};
|
|
|
|
|
|
2019-10-15 15:58:05 -07:00
|
|
|
// Public facade constructor: forwards all configuration to the shared
// PythonPrintImpl. See PythonPrintImpl's constructor for parameter semantics.
PythonPrint::PythonPrint(
    std::vector<at::Tensor>& tensor_table,
    std::vector<c10::NamedTypePtr>& deps_table,
    c10::TypePrinter type_printer,
    bool enforce_importable)
    : pImpl(std::make_shared<PythonPrintImpl>(
          tensor_table,
          deps_table,
          // Move the by-value callable down into the impl instead of copying.
          std::move(type_printer),
          enforce_importable)) {}
|
|
|
|
|
|
|
|
|
|
// Forward to the impl: print the Python definition of a NamedType
// (function, class/module, NamedTuple, or interface).
void PythonPrint::printNamedType(const c10::NamedTypePtr& type) {
  pImpl->printNamedType(type);
}
|
|
|
|
|
|
2019-10-15 15:58:05 -07:00
|
|
|
// Forward to the impl: print a free function (first argument type included).
void PythonPrint::printFunction(const Function& func) {
  pImpl->printFunction(func);
}
|
|
|
|
|
|
|
|
|
|
// Forward to the impl: print a method (the `self` argument's type is omitted).
void PythonPrint::printMethod(const Function& func) {
  pImpl->printMethod(func);
}
|
2018-11-12 10:15:44 -08:00
|
|
|
|
2019-10-16 22:45:50 -07:00
|
|
|
// Return everything printed so far as a single source string.
std::string PythonPrint::str() const {
  return pImpl->body_.str();
}
|
|
|
|
|
|
|
|
|
|
// Return the source-range records collected while printing (maps offsets in
// the emitted text back to original source ranges).
const SourceRangeRecords& PythonPrint::ranges() const {
  return pImpl->body_.ranges();
}
|
|
|
|
|
|
2020-06-24 12:39:42 -07:00
|
|
|
// Return the smallest file-format version that supports every op printed so
// far (tracked by the impl in min_version_).
uint64_t PythonPrint::minVersion() const {
  return pImpl->min_version_;
}
|
|
|
|
|
|
2019-10-15 15:58:05 -07:00
|
|
|
// Defined out-of-line, where PythonPrintImpl is a complete type (pImpl idiom).
PythonPrint::~PythonPrint() = default;
|
|
|
|
|
|
2018-11-12 10:15:44 -08:00
|
|
|
} // namespace jit
|
|
|
|
|
} // namespace torch
|