diff --git a/binaries/aot_model_compiler.cc b/binaries/aot_model_compiler.cc
index 939e83c7b3f..2ff895c235b 100644
--- a/binaries/aot_model_compiler.cc
+++ b/binaries/aot_model_compiler.cc
@@ -1,6 +1,8 @@
 #include
 #include
 
+#include
+#include
 #include
 #include
 #include
@@ -25,9 +27,16 @@ C10_DEFINE_string(model_version, "", "The version of the model.");
 C10_DEFINE_string(
     input_dims,
     "",
-    "For input float TensorCPUs, specify the dimension using comma "
-    "separated numbers. If multiple inputs needed, use semicolon "
-    "to separate the dimension of different tensors.");
+    "The dimensions of input TensorCPUs using comma-separated numbers. "
+    "If multiple inputs are needed, use semicolon to separate "
+    "the dimensions of different tensors.");
+C10_DEFINE_string(
+    input_types,
+    "float",
+    "The dtypes of input TensorCPUs. "
+    "If multiple inputs are needed, use semicolon to separate "
+    "the dtypes of different tensors. "
+    "Supported dtypes: float, int64, uint8.");
 C10_DEFINE_string(method_name, "forward", "The name of the method.");
 C10_DEFINE_string(
     output_llvm,
@@ -68,18 +77,39 @@ std::vector<std::vector<int64_t>> parseInputShapes() {
   return inputs;
 }
 
+std::vector<at::ScalarType> parseInputTypes() {
+  std::vector<std::string> inputTypes = split(';', FLAGS_input_types);
+  std::vector<at::ScalarType> scalarTypes;
+  for (const auto& inputType : inputTypes) {
+    at::ScalarType scalarType;
+    if (inputType == "float") {
+      scalarType = at::ScalarType::Float;
+    } else if (inputType == "uint8") {
+      scalarType = at::ScalarType::Byte;
+    } else if (inputType == "int64") {
+      scalarType = at::ScalarType::Long;
+    } else {
+      CAFFE_THROW("Unsupported input type: ", inputType);
+    }
+    scalarTypes.push_back(scalarType);
+  }
+  return scalarTypes;
+}
+
 c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
   c10::Dict<c10::IValue, c10::IValue> compile_spec(
       c10::StringType::get(), c10::AnyType::get());
   c10::Dict<c10::IValue, c10::IValue> method_spec(
       c10::StringType::get(), c10::AnyType::get());
-  auto input_shapes = parseInputShapes();
-  method_spec.insert("sizes", input_shapes);
+  auto inputShapes = parseInputShapes();
+  auto inputTypes = parseInputTypes();
+  method_spec.insert("sizes", inputShapes);
+  method_spec.insert("types", inputTypes);
   compile_spec.insert(FLAGS_method_name, method_spec);
   return compile_spec;
 }
 
-std::vector<std::vector<int64_t>> getInputSizes (
+std::vector<std::vector<int64_t>> getInputSizes(
     const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
   auto input_shapes = compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
   std::vector<std::vector<int64_t>> inputSizes;
@@ -90,6 +120,17 @@ std::vector<std::vector<int64_t>> getInputSizes (
   return inputSizes;
 }
 
+std::vector<at::ScalarType> getInputTypes(
+    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
+  auto inputTypesList = compile_spec.at(FLAGS_method_name).toGenericDict().at("types").toList();
+  std::vector<at::ScalarType> inputTypes;
+  for (const auto& inputType : inputTypesList) {
+    auto type = ((c10::IValue) inputType).toScalarType();
+    inputTypes.emplace_back(type);
+  }
+  return inputTypes;
+}
+
 std::string getNncKernelId() {
   // TODO: calculate the version_token.
   const std::string version_token = "VERTOKEN";
@@ -122,10 +163,11 @@ c10::IValue preprocess(
   auto method = mod.get_method(FLAGS_method_name);
   auto graph = toGraphFunction(method.function()).graph()->copy();
   auto sizes = getInputSizes(compile_spec);
+  auto types = getInputTypes(compile_spec);
   auto kernel_func_name = getNncKernelFuncName(FLAGS_method_name);
 
   auto compiled = torch::jit::mobile::nnc::aotCompile(
-      FLAGS_method_name, graph, sizes, kernel_func_name);
+      FLAGS_method_name, graph, sizes, types, kernel_func_name);
   writeOutputLlvmAssembly(compiled.second);
 
   auto func = std::move(compiled.first);
@@ -148,6 +190,7 @@ int main(int argc, char** argv) {
       " --model_name=<model name>"
      " --model_version=<model version>"
       " --input_dims=<input dimensions>"
+      " --input_types=<input types>"
       " [--method_name=<method name>]"
       " [--output_llvm=<output llvm file path>]"
       " [--output_model=<output model file path>]");
@@ -162,6 +205,8 @@
   CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
+  CAFFE_ENFORCE(split(';', FLAGS_input_dims).size() == split(';', FLAGS_input_types).size(),
+      "Number of input_dims and input_types should be the same");
 
   std::string output_model_name = FLAGS_output_model;
   if (output_model_name.empty()) {
@@ -173,11 +218,12 @@
   m.eval();
   auto frozen_m = torch::jit::freeze_module(m.clone());
   auto graph = frozen_m.get_method(FLAGS_method_name).graph();
-  auto input_shapes = parseInputShapes();
+  auto inputShapes = parseInputShapes();
+  auto inputTypes = parseInputTypes();
   std::vector<c10::optional<at::Tensor>> example_inputs;
-  example_inputs.reserve(input_shapes.size());
-  for (const auto& input_shape : input_shapes) {
-    example_inputs.emplace_back(at::rand(input_shape));
+  example_inputs.reserve(inputShapes.size());
+  for (int i = 0; i < inputShapes.size(); ++i) {
+    example_inputs.emplace_back(at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
   }
 
   torch::jit::RemoveTensorMutation(graph);
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
index 35e67d8835d..a756d737139 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -37,13 +37,13 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
 }
 
 std::vector<mobile::nnc::InputSpec> toInputSpecs(
-    const std::vector<std::vector<int64_t>>& inputSizes) {
+    const std::vector<std::vector<int64_t>>& inputSizes,
+    const std::vector<at::ScalarType>& inputTypes) {
   std::vector<mobile::nnc::InputSpec> specs;
-  for (const auto& sizes : inputSizes) {
+  for (int i = 0; i < inputSizes.size(); ++i) {
     mobile::nnc::InputSpec spec;
-    spec.sizes_ = sizes;
-    // TODO: Use user specified input type. For now using Long for BI model
-    spec.dtype_ = c10::ScalarType::Long;
+    spec.sizes_ = inputSizes[i];
+    spec.dtype_ = inputTypes[i];
     specs.emplace_back(std::move(spec));
   }
   return specs;
 }
@@ -52,10 +52,11 @@ std::vector<mobile::nnc::InputSpec> toInputSpecs(
 std::unique_ptr<Function> compileMethod(
     std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
     const std::string& method_name,
-    const std::vector<std::vector<int64_t>>& sizes) {
+    const std::vector<std::vector<int64_t>>& sizes,
+    const std::vector<at::ScalarType>& types) {
   auto func = std::make_unique<Function>();
   func->set_name(method_name);
-  func->set_input_specs(toInputSpecs(sizes));
+  func->set_input_specs(toInputSpecs(sizes, types));
 
   auto params = c10::impl::GenericList(c10::AnyType::get());
   auto const_descriptors = kernel->getConstantDescriptors();
@@ -110,15 +111,21 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& g,
     const std::vector<std::vector<int64_t>>& sizes,
+    const std::vector<at::ScalarType>& types,
     const std::string& kernel_func_name) {
   GRAPH_DEBUG("Input sizes ", sizes);
+  GRAPH_DEBUG("Input types ", types);
   GRAPH_DEBUG("Method name ", method_name);
+  GRAPH_DEBUG("Kernel func name ", kernel_func_name);
+
+  CAFFE_ENFORCE(
+      sizes.size() == types.size(),
+      "Number of input sizes and input types should be the same");
 
   std::vector<c10::IValue> example_values;
   std::vector<c10::optional<at::Tensor>> example_inputs;
-  for (const auto& size : sizes) {
-    // TODO: Use user specified input type. For now using Long for BI model
-    auto example_input = at::rand(size).to(at::dtype(at::kLong));
+  for (int i = 0; i < sizes.size(); ++i) {
+    auto example_input = at::rand(sizes[i]).to(at::dtype(types[i]));
     example_values.emplace_back(example_input);
     example_inputs.emplace_back(example_input);
   }
@@ -141,7 +148,7 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
 
   const std::string compiled_assembly = kernel->getCodeText();
 
-  auto func = compileMethod(kernel, method_name, sizes);
+  auto func = compileMethod(kernel, method_name, sizes, types);
   return std::make_pair(std::move(func), compiled_assembly);
 }
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.h b/torch/csrc/jit/mobile/nnc/aot_compiler.h
index b0c610439f8..4edac10542c 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.h
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.h
@@ -15,6 +15,7 @@ TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& subgraph,
    const std::vector<std::vector<int64_t>>& sizes,
+    const std::vector<at::ScalarType>& types,
    const std::string& kernel_func_name = "func");
 
 } // namespace nnc
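
Example invocation pairing the new flag with --input_dims (illustrative; the model file and names below are made up, and each semicolon-separated dtype corresponds to the shape at the same position):

    ./aot_model_compiler \
        --model=model.pt \
        --model_name=pika \
        --model_version=v1 \
        --input_dims="1,3,224,224;2,2" \
        --input_types="float;int64"

Note that --input_types defaults to a single "float", so a model with more than one input must pass one dtype per input or the new CAFFE_ENFORCE count check in main will fail.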
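A side note on the example-input construction used in both files: at::rand only produces floating-point tensors, so the patch samples float values first and then casts to the requested dtype. A minimal standalone sketch of that idiom (not part of the patch; the shape and dtype here are arbitrary):

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // Sample float values in [0, 1), then cast to the target dtype.
      // Integral dtypes truncate to 0, which is fine: only the shape and
      // dtype of the example input matter for compilation.
      auto example =
          at::rand({1, 3, 224, 224}).to(at::dtype(at::ScalarType::Byte));
      std::cout << example.scalar_type() << " " << example.sizes() << "\n";
      return 0;
    }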