[TensorExpr] Move AOT compilation logic from aot_compiler.cpp to NNC's to_backend (#70375)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70375
Differential Revision: D33303645
Test Plan: Imported from OSS
Reviewed By: VitalyFedyunin, priyaramani
Pulled By: ZolotukhinM
fbshipit-source-id: 01ab9fab9bb0d63f89b06a146d3c5fb6ed7fe52d
(cherry picked from commit aac8e0ed90)
Committed by: PyTorch MergeBot
Parent: 64668e61b8
Commit: a60e2ae037
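The gist of the change: the graph cleanup and NNC AOT compilation that aot_model_compiler.cc used to perform inline now live behind the "nnc" backend's preprocess hook, so the binary can lower a module through the generic to_backend path. A minimal caller sketch (illustrative only, not code from this commit; it assumes codegen_backend_module from backend_detail.h, which the includes below pull in):

#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/serialization/import.h>

torch::jit::Module lowerToNnc(
    const std::string& model_path,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
  auto m = torch::jit::load(model_path);
  m.eval();
  // Value type for each method's compile-spec entry.
  auto any_dict_ty = c10::DictType::create(
      c10::StringType::get(), c10::AnyType::get());
  // "nnc" routes to the preprocess() registered below via
  // backend_preprocess_register.
  return torch::jit::detail::codegen_backend_module(
      "nnc", m, compile_spec, any_dict_ty);
}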
@@ -7,14 +7,7 @@
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>

@@ -61,125 +54,20 @@ std::vector<std::string> split(
  return pieces;
}

std::vector<std::vector<int64_t>> parseInputShapes() {
  CAFFE_ENFORCE_GT(
      FLAGS_input_dims.size(), 0, "Input dims must be specified.");
  std::vector<std::string> input_dims_list = split(';', FLAGS_input_dims);
  std::vector<std::vector<int64_t>> inputs;
  for (const auto& input_dims_item : input_dims_list) {
    auto input_dims_str = split(',', input_dims_item);
    std::vector<int64_t> input_dims;
    input_dims.reserve(input_dims_str.size());
    for (const auto& s : input_dims_str) {
      input_dims.push_back(c10::stoi(s));
    }
    inputs.push_back(input_dims);
  }
  return inputs;
}
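For reference, a sketch of how parseInputShapes() interprets the flag; the flag value is assumed for illustration, not taken from this commit:

// Assumed invocation: --input_dims="1,3,224,224;1,10"
// ';' separates inputs and ',' separates dimensions, so:
auto shapes = parseInputShapes();
// shapes == {{1, 3, 224, 224}, {1, 10}}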
std::vector<at::ScalarType> parseInputTypes() {
  std::vector<std::string> inputTypes = split(';', FLAGS_input_types);
  std::vector<at::ScalarType> scalarTypes;
  for (const auto& inputType : inputTypes) {
    at::ScalarType scalarType;
    if (inputType == "float") {
      scalarType = at::ScalarType::Float;
    } else if (inputType == "uint8") {
      scalarType = at::ScalarType::Byte;
    } else if (inputType == "int64") {
      scalarType = at::ScalarType::Long;
    } else {
      CAFFE_THROW("Unsupported input type: ", inputType);
    }
    scalarTypes.push_back(scalarType);
  }
  return scalarTypes;
}
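Likewise for the type flag, a small illustration with an assumed value:

// Assumed invocation: --input_types="float;uint8"
auto types = parseInputTypes();
// types == {at::ScalarType::Float, at::ScalarType::Byte};
// any other name lands in the CAFFE_THROW branch above.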
c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  auto inputShapes = parseInputShapes();
  auto inputTypes = parseInputTypes();
  method_spec.insert("sizes", inputShapes);
  method_spec.insert("types", inputTypes);
  method_spec.insert("sizes", FLAGS_input_dims);
  method_spec.insert("types", FLAGS_input_types);
  method_spec.insert("asmfile", FLAGS_output_llvm);
  method_spec.insert("model_name", FLAGS_model_name);
  method_spec.insert("model_version", FLAGS_model_version);
  compile_spec.insert(FLAGS_method_name, method_spec);
  return compile_spec;
}
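The doubled "sizes"/"types" inserts above appear to be the two sides of the diff collapsed together: the old code inserted the parsed vectors, while the new code seemingly passes the raw flag strings (plus the asmfile path) and leaves parsing to the NNC AOT compiler. Either way, the returned spec has this shape (values assumed for illustration):

// compile_spec = {
//   "forward": {                            // keyed by FLAGS_method_name
//     "sizes":         ...,                 // input shapes
//     "types":         ...,                 // input scalar types
//     "asmfile":       "model.compiled.ll", // FLAGS_output_llvm
//     "model_name":    "mobilenet",         // FLAGS_model_name
//     "model_version": "v1",                // FLAGS_model_version
//   }
// }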
std::vector<std::vector<int64_t>> getInputSizes(
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
  auto input_shapes =
      compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
  std::vector<std::vector<int64_t>> inputSizes;
  for (const auto& input_shape : input_shapes) {
    auto sizes = ((c10::IValue)input_shape).toIntVector();
    inputSizes.emplace_back(sizes);
  }
  return inputSizes;
}

std::vector<at::ScalarType> getInputTypes(
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
  auto inputTypesList =
      compile_spec.at(FLAGS_method_name).toGenericDict().at("types").toList();
  std::vector<at::ScalarType> inputTypes;
  for (const auto& inputType : inputTypesList) {
    auto type = ((c10::IValue)inputType).toScalarType();
    inputTypes.emplace_back(type);
  }
  return inputTypes;
}

std::string getNncKernelId() {
  // TODO: calculate the version_token.
  const std::string version_token = "VERTOKEN";
  return FLAGS_model_name + ":" + FLAGS_model_version + ":" +
      FLAGS_method_name + ":" + version_token;
}

std::string getNncKernelFuncName(const std::string& method_name) {
  return "nnc_" + FLAGS_model_name + "_" + FLAGS_model_version + "_" +
      method_name;
}
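With assumed flags, the two helpers produce strings like these:

// Assuming --model_name=mobilenet --model_version=v1 --method_name=forward:
//   getNncKernelId()                == "mobilenet:v1:forward:VERTOKEN"
//   getNncKernelFuncName("forward") == "nnc_mobilenet_v1_forward"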
void writeOutputLlvmAssembly(const std::string& asm_code) {
  std::string output_llvm_file_name = FLAGS_output_llvm;
  if (output_llvm_file_name.empty()) {
    output_llvm_file_name =
        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
  }

  std::ofstream output(output_llvm_file_name);
  output << asm_code;
  std::cout << "The compiled llvm assembly code was saved to "
            << output_llvm_file_name << std::endl;
}
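When --output_llvm is empty, the output name is derived from the model path. For example (assumed value):

// Assuming --model=mobilenet.pt and no --output_llvm:
//   "mobilenet.pt".substr(0, find('.')) + ".compiled.ll"
//   == "mobilenet.compiled.ll"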
c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
  auto method = mod.get_method(FLAGS_method_name);
  auto graph = toGraphFunction(method.function()).graph()->copy();
  auto sizes = getInputSizes(compile_spec);
  auto types = getInputTypes(compile_spec);
  auto kernel_func_name = getNncKernelFuncName(FLAGS_method_name);

  auto compiled = torch::jit::mobile::nnc::aotCompile(
      FLAGS_method_name, graph, sizes, types, kernel_func_name);
  writeOutputLlvmAssembly(compiled.second);

  auto func = std::move(compiled.first);
  func->set_nnc_kernel_id(getNncKernelId());

  torch::jit::mobile::nnc::CompilationUnit cu;
  cu.register_function(std::move(func));
  return cu.serialize();
}

static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);
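The registration above is what makes "nnc" resolvable as a to_backend target. For orientation, the hook type it must satisfy looks roughly like this (reproduced from memory of torch/csrc/jit/backends/backend_detail.h; check the header for the authoritative form):

using BackendPreprocessFunction = std::function<c10::IValue(
    const torch::jit::Module&,
    const c10::Dict<c10::IValue, c10::IValue>&,
    const torch::jit::BackendDebugHandleGenerator&)>;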
} // namespace

int main(int argc, char** argv) {
@@ -205,7 +93,9 @@ int main(int argc, char** argv) {
  CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(split(';', FLAGS_input_dims).size() == split(';', FLAGS_input_types).size(),
  CAFFE_ENFORCE(
      split(';', FLAGS_input_dims).size() ==
          split(';', FLAGS_input_types).size(),
      "Number of input_dims and input_types should be the same");

  std::string output_model_name = FLAGS_output_model;
@@ -217,27 +107,6 @@ int main(int argc, char** argv) {
  auto m = torch::jit::load(FLAGS_model);
  m.eval();
  auto frozen_m = torch::jit::freeze_module(m.clone());
  auto graph = frozen_m.get_method(FLAGS_method_name).graph();
  auto inputShapes = parseInputShapes();
  auto inputTypes = parseInputTypes();
  std::vector<c10::optional<at::Tensor>> example_inputs;
  example_inputs.reserve(inputShapes.size());
  for (int i = 0; i < inputShapes.size(); ++i) {
    example_inputs.emplace_back(
        at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
  }

  torch::jit::RemoveTensorMutation(graph);
  torch::jit::EliminateDeadCode(graph->block());
  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);

  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
  torch::jit::OptimizeFrozenGraph(graph, true);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);

  auto compile_spec = createCompileSpec();
  auto any_dict_ty =
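The diff view is truncated mid-statement at "auto any_dict_ty =". A plausible completion of the rewritten main(), assuming it hands the frozen module to the generic backend codegen (a sketch, not verbatim from the commit):

  auto any_dict_ty = c10::DictType::create(
      c10::StringType::get(), c10::AnyType::get());
  auto compiled_module = torch::jit::detail::codegen_backend_module(
      "nnc", frozen_m, compile_spec, any_dict_ty);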