From 8cc9ec2f6b270e84bcdda9b898ec8d74caf3e671 Mon Sep 17 00:00:00 2001
From: Priya Ramani
Date: Mon, 29 Nov 2021 21:38:29 -0800
Subject: [PATCH] Add option to get input dtype from user (#68751)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68751

Add an option to get the input dtypes from the user for AOT compilation.

Test Plan:
The BI model compiles and runs fine:

```
(pytorch) ~/fbsource/fbcode/caffe2/fb/nnc
└─ $ buck run //caffe2/binaries:aot_model_compiler -- --model=bi.pt --model_name=pytorch_dev_bytedoc --model_version=v1 '--input_dims=1,115;1' --input_types='int64;int64'
Building... 8.3 sec (99%) 7673/7674 jobs, 0/7674 updated
WARNING: Logging before InitGoogleLogging() is written to STDERR
W1116 14:32:44.632536 1332111 TensorImpl.h:1418] Warning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (function operator())
E1116 14:32:44.673710 1332111 huge_pages_allocator.cc:287] Not using huge pages because not linked with jemalloc
The compiled llvm assembly code was saved to bi.compiled.ll
The compiled model was saved to bi.compiled.pt
```

An error is thrown when the number of input_dims and input_types entries doesn't match:

```
(pytorch) ~/fbsource/fbcode/caffe2/fb/nnc
└─ $ buck run //caffe2/binaries:aot_model_compiler -- --model=bi.pt --model_name=pytorch_dev_bytedoc --model_version=v1 '--input_dims=1,115;1' --input_types='int64;int64;int64'
.
.
terminate called after throwing an instance of 'c10::Error'
  what():  [enforce fail at aot_model_compiler.cc:208] split(';', FLAGS_input_dims).size() == split(';', FLAGS_input_types).size(). Number of input_dims and input_types should be the same
.
.
.
```
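For reference, the flag semantics reduce to the following self-contained sketch: the i-th semicolon-separated group in --input_dims is paired with the i-th entry of --input_types, and each example input is materialized by at::rand followed by a dtype cast (at::rand itself only supports floating-point dtypes, hence the rand-then-cast recipe in the patch). The splitOn helper and hard-coded flag values below are illustrative stand-ins, not the binary's actual split helper:

```
// Illustrative sketch only: mirrors how --input_dims and --input_types
// are zipped into example input tensors.
#include <ATen/ATen.h>
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for the binary's split helper.
static std::vector<std::string> splitOn(char sep, const std::string& s) {
  std::vector<std::string> parts;
  std::istringstream in(s);
  for (std::string part; std::getline(in, part, sep);) {
    parts.push_back(part);
  }
  return parts;
}

int main() {
  const std::string dims = "1,115;1";       // two inputs: shapes [1,115] and [1]
  const std::string types = "int64;int64";  // one dtype per input
  auto dimGroups = splitOn(';', dims);
  auto typeNames = splitOn(';', types);
  // The patch enforces this with CAFFE_ENFORCE; a plain assert here.
  assert(dimGroups.size() == typeNames.size());

  for (size_t i = 0; i < dimGroups.size(); ++i) {
    std::vector<int64_t> shape;
    for (const auto& d : splitOn(',', dimGroups[i])) {
      shape.push_back(std::stoll(d));
    }
    // Supported dtype names, per the flag's help text.
    at::ScalarType dtype = at::kFloat;
    if (typeNames[i] == "int64") {
      dtype = at::kLong;
    } else if (typeNames[i] == "uint8") {
      dtype = at::kByte;
    }
    // Same recipe the patch uses for example inputs: rand, then cast.
    at::Tensor example = at::rand(shape).to(at::dtype(dtype));
    std::cout << "input " << i << ": sizes=" << example.sizes()
              << " dtype=" << example.dtype() << '\n';
  }
  return 0;
}
```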
+ "Supported dtypes: float, int64, uint8"); C10_DEFINE_string(method_name, "forward", "The name of the method."); C10_DEFINE_string( output_llvm, @@ -68,18 +77,39 @@ std::vector> parseInputShapes() { return inputs; } +std::vector parseInputTypes() { + std::vector inputTypes = split(';', FLAGS_input_types); + std::vector scalarTypes; + for (const auto& inputType : inputTypes) { + at::ScalarType scalarType; + if (inputType == "float") { + scalarType = at::ScalarType::Float; + } else if (inputType == "uint8") { + scalarType = at::ScalarType::Byte; + } else if (inputType == "int64") { + scalarType = at::ScalarType::Long; + } else { + CAFFE_THROW("Unsupported input type: ", inputType); + } + scalarTypes.push_back(scalarType); + } + return scalarTypes; +} + c10::Dict createCompileSpec() { c10::Dict compile_spec( c10::StringType::get(), c10::AnyType::get()); c10::Dict method_spec( c10::StringType::get(), c10::AnyType::get()); - auto input_shapes = parseInputShapes(); - method_spec.insert("sizes", input_shapes); + auto inputShapes = parseInputShapes(); + auto inputTypes = parseInputTypes(); + method_spec.insert("sizes", inputShapes); + method_spec.insert("types", inputTypes); compile_spec.insert(FLAGS_method_name, method_spec); return compile_spec; } -std::vector> getInputSizes ( +std::vector> getInputSizes( const c10::Dict& compile_spec) { auto input_shapes = compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList(); std::vector> inputSizes; @@ -90,6 +120,17 @@ std::vector> getInputSizes ( return inputSizes; } +std::vector getInputTypes( + const c10::Dict& compile_spec) { + auto inputTypesList = compile_spec.at(FLAGS_method_name).toGenericDict().at("types").toList(); + std::vector inputTypes; + for (const auto& inputType : inputTypesList) { + auto type = ((c10::IValue) inputType).toScalarType(); + inputTypes.emplace_back(type); + } + return inputTypes; +} + std::string getNncKernelId() { // TODO: calculate the version_token. 
   const std::string version_token = "VERTOKEN";
@@ -122,10 +163,11 @@ c10::IValue preprocess(
   auto method = mod.get_method(FLAGS_method_name);
   auto graph = toGraphFunction(method.function()).graph()->copy();
   auto sizes = getInputSizes(compile_spec);
+  auto types = getInputTypes(compile_spec);
   auto kernel_func_name = getNncKernelFuncName(FLAGS_method_name);
 
   auto compiled = torch::jit::mobile::nnc::aotCompile(
-      FLAGS_method_name, graph, sizes, kernel_func_name);
+      FLAGS_method_name, graph, sizes, types, kernel_func_name);
   writeOutputLlvmAssembly(compiled.second);
 
   auto func = std::move(compiled.first);
@@ -148,6 +190,7 @@ int main(int argc, char** argv) {
       " --model_name=<model name>"
      " --model_version=<model version>"
       " --input_dims=<input dimensions>"
+      " --input_types=<input dtypes>"
       " [--method_name=<method name>]"
       " [--output_llvm=<llvm assembly output file path>]"
       " [--output_model=<output model file path>]");
@@ -162,6 +205,8 @@ int main(int argc, char** argv) {
   CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
+  CAFFE_ENFORCE(split(';', FLAGS_input_dims).size() == split(';', FLAGS_input_types).size(),
+      "Number of input_dims and input_types should be the same");
 
   std::string output_model_name = FLAGS_output_model;
   if (output_model_name.empty()) {
@@ -173,11 +218,12 @@ int main(int argc, char** argv) {
   m.eval();
   auto frozen_m = torch::jit::freeze_module(m.clone());
   auto graph = frozen_m.get_method(FLAGS_method_name).graph();
-  auto input_shapes = parseInputShapes();
+  auto inputShapes = parseInputShapes();
+  auto inputTypes = parseInputTypes();
   std::vector<c10::optional<at::Tensor>> example_inputs;
-  example_inputs.reserve(input_shapes.size());
-  for (const auto& input_shape : input_shapes) {
-    example_inputs.emplace_back(at::rand(input_shape));
+  example_inputs.reserve(inputShapes.size());
+  for (int i = 0; i < inputShapes.size(); ++i) {
+    example_inputs.emplace_back(at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
   }
 
   torch::jit::RemoveTensorMutation(graph);
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
index 35e67d8835d..a756d737139 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -37,13 +37,13 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
 }
 
 std::vector<mobile::nnc::InputSpec> toInputSpecs(
-    const std::vector<std::vector<int64_t>>& inputSizes) {
+    const std::vector<std::vector<int64_t>>& inputSizes,
+    const std::vector<at::ScalarType>& inputTypes) {
   std::vector<mobile::nnc::InputSpec> specs;
-  for (const auto& sizes : inputSizes) {
+  for (int i = 0; i < inputSizes.size(); ++i) {
     mobile::nnc::InputSpec spec;
-    spec.sizes_ = sizes;
-    // TODO: Use user specified input type. For now using Long for BI model
-    spec.dtype_ = c10::ScalarType::Long;
+    spec.sizes_ = inputSizes[i];
+    spec.dtype_ = inputTypes[i];
     specs.emplace_back(std::move(spec));
   }
   return specs;
@@ -52,10 +52,11 @@ std::vector<mobile::nnc::InputSpec> toInputSpecs(
 std::unique_ptr<Function> compileMethod(
     std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
     const std::string& method_name,
-    const std::vector<std::vector<int64_t>>& sizes) {
+    const std::vector<std::vector<int64_t>>& sizes,
+    const std::vector<at::ScalarType>& types) {
   auto func = std::make_unique<Function>();
   func->set_name(method_name);
-  func->set_input_specs(toInputSpecs(sizes));
+  func->set_input_specs(toInputSpecs(sizes, types));
 
   auto params = c10::impl::GenericList(c10::AnyType::get());
   auto const_descriptors = kernel->getConstantDescriptors();
@@ -110,15 +111,21 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& g,
     const std::vector<std::vector<int64_t>>& sizes,
+    const std::vector<at::ScalarType>& types,
     const std::string& kernel_func_name) {
   GRAPH_DEBUG("Input sizes ", sizes);
+  GRAPH_DEBUG("Input types ", types);
   GRAPH_DEBUG("Method name ", method_name);
+  GRAPH_DEBUG("Kernel func name ", kernel_func_name);
+
+  CAFFE_ENFORCE(
+      sizes.size() == types.size(),
+      "Number of input sizes and input types should be the same");
 
   std::vector<at::IValue> example_values;
   std::vector<c10::optional<at::Tensor>> example_inputs;
-  for (const auto& size : sizes) {
-    // TODO: Use user specified input type. For now using Long for BI model
-    auto example_input = at::rand(size).to(at::dtype(at::kLong));
+  for (int i = 0; i < sizes.size(); ++i) {
+    auto example_input = at::rand(sizes[i]).to(at::dtype(types[i]));
     example_values.emplace_back(example_input);
     example_inputs.emplace_back(example_input);
   }
@@ -141,7 +148,7 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
 
   const std::string compiled_assembly = kernel->getCodeText();
 
-  auto func = compileMethod(kernel, method_name, sizes);
+  auto func = compileMethod(kernel, method_name, sizes, types);
   return std::make_pair(std::move(func), compiled_assembly);
 }
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.h b/torch/csrc/jit/mobile/nnc/aot_compiler.h
index b0c610439f8..4edac10542c 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.h
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.h
@@ -15,6 +15,7 @@ TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& subgraph,
     const std::vector<std::vector<int64_t>>& sizes,
+    const std::vector<at::ScalarType>& types,
     const std::string& kernel_func_name = "func");
 
 } // namespace nnc