mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Automated Code Change
PiperOrigin-RevId: 829434646
This commit is contained in:
committed by
TensorFlower Gardener
parent
3f34740ef5
commit
cc4a122a24
@@ -57,13 +57,13 @@ class GenTypeScriptOp {
|
||||
~GenTypeScriptOp();
|
||||
|
||||
// Returns the generated code as a string:
|
||||
string Code();
|
||||
std::string Code();
|
||||
|
||||
private:
|
||||
void ProcessArgs();
|
||||
void ProcessAttrs();
|
||||
void AddAttrForArg(const string& attr, int arg_index);
|
||||
string InputForAttr(const OpDef::AttrDef& op_def_attr);
|
||||
void AddAttrForArg(const std::string& attr, int arg_index);
|
||||
std::string InputForAttr(const OpDef::AttrDef& op_def_attr);
|
||||
|
||||
void AddMethodSignature();
|
||||
void AddOpAttrs();
|
||||
@@ -73,7 +73,7 @@ class GenTypeScriptOp {
|
||||
const ApiDef& api_def_;
|
||||
|
||||
// Placeholder string for all generated code:
|
||||
string result_;
|
||||
std::string result_;
|
||||
|
||||
// Holds in-order vector of Op inputs:
|
||||
std::vector<ArgDefs> input_op_args_;
|
||||
@@ -82,7 +82,7 @@ class GenTypeScriptOp {
|
||||
std::vector<OpAttrs> op_attrs_;
|
||||
|
||||
// Stores attributes-to-arguments by name:
|
||||
typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
|
||||
typedef std::unordered_map<std::string, std::vector<int>> AttrArgIdxMap;
|
||||
AttrArgIdxMap attr_arg_idx_map_;
|
||||
|
||||
// Holds number of outputs:
|
||||
@@ -94,7 +94,7 @@ GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
|
||||
|
||||
GenTypeScriptOp::~GenTypeScriptOp() = default;
|
||||
|
||||
string GenTypeScriptOp::Code() {
|
||||
std::string GenTypeScriptOp::Code() {
|
||||
ProcessArgs();
|
||||
ProcessAttrs();
|
||||
|
||||
@@ -144,7 +144,7 @@ void GenTypeScriptOp::ProcessAttrs() {
|
||||
}
|
||||
}
|
||||
|
||||
void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
|
||||
void GenTypeScriptOp::AddAttrForArg(const std::string& attr, int arg_index) {
|
||||
// Keep track of attributes-to-arguments by name. These will be used for
|
||||
// construction Op attributes that require information about the inputs.
|
||||
auto iter = attr_arg_idx_map_.find(attr);
|
||||
@@ -155,8 +155,8 @@ void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
|
||||
}
|
||||
}
|
||||
|
||||
string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
|
||||
string inputs;
|
||||
std::string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
|
||||
std::string inputs;
|
||||
auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
|
||||
if (arg_list != attr_arg_idx_map_.end()) {
|
||||
for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
|
||||
@@ -235,7 +235,7 @@ void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
|
||||
}
|
||||
|
||||
void StartFile(WritableFile* ts_file) {
|
||||
const string header =
|
||||
const std::string header =
|
||||
R"header(/**
|
||||
* @license
|
||||
* Copyright 2018 Google Inc. All Rights Reserved.
|
||||
@@ -266,7 +266,7 @@ import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
|
||||
} // namespace
|
||||
|
||||
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||
const string& ts_filename) {
|
||||
const std::string& ts_filename) {
|
||||
Env* env = Env::Default();
|
||||
|
||||
std::unique_ptr<WritableFile> ts_file = nullptr;
|
||||
|
||||
@@ -24,7 +24,7 @@ namespace tensorflow {
|
||||
|
||||
// Generated code is written to the file ts_filename:
|
||||
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||
const string& ts_filename);
|
||||
const std::string& ts_filename);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
@@ -81,8 +81,9 @@ op {
|
||||
)";
|
||||
|
||||
// Generate TypeScript code
|
||||
void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
|
||||
string* ts_file_text) {
|
||||
void GenerateTsOpFileText(const std::string& op_def_str,
|
||||
const std::string& api_def_str,
|
||||
std::string* ts_file_text) {
|
||||
Env* env = Env::Default();
|
||||
OpList op_defs;
|
||||
protobuf::TextFormat::ParseFromString(
|
||||
@@ -93,7 +94,7 @@ void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
|
||||
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
|
||||
}
|
||||
|
||||
const string& tmpdir = testing::TmpDir();
|
||||
const std::string& tmpdir = testing::TmpDir();
|
||||
const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
|
||||
|
||||
WriteTSOps(op_defs, api_def_map, ts_file_path);
|
||||
@@ -101,10 +102,10 @@ void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, TestImports) {
|
||||
string ts_file_text;
|
||||
std::string ts_file_text;
|
||||
GenerateTsOpFileText("", "", &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
const std::string expected = R"(
|
||||
import * as tfc from '@tensorflow/tfjs-core';
|
||||
import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
|
||||
)";
|
||||
@@ -112,38 +113,38 @@ import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, InputSingleAndList) {
|
||||
const string api_def = R"pb(
|
||||
const std::string api_def = R"pb(
|
||||
op { graph_op_name: "Foo" arg_order: "dim" arg_order: "images" }
|
||||
)pb";
|
||||
|
||||
string ts_file_text;
|
||||
std::string ts_file_text;
|
||||
GenerateTsOpFileText("", api_def, &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
const std::string expected = R"(
|
||||
export function Foo(dim: tfc.Tensor, images: tfc.Tensor[]): tfc.Tensor {
|
||||
)";
|
||||
ExpectContainsStr(ts_file_text, expected);
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, TestVisibility) {
|
||||
const string api_def = R"(
|
||||
const std::string api_def = R"(
|
||||
op {
|
||||
graph_op_name: "Foo"
|
||||
visibility: HIDDEN
|
||||
}
|
||||
)";
|
||||
|
||||
string ts_file_text;
|
||||
std::string ts_file_text;
|
||||
GenerateTsOpFileText("", api_def, &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
const std::string expected = R"(
|
||||
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
|
||||
)";
|
||||
ExpectDoesNotContainStr(ts_file_text, expected);
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, SkipDeprecated) {
|
||||
const string op_def = R"(
|
||||
const std::string op_def = R"(
|
||||
op {
|
||||
name: "DeprecatedFoo"
|
||||
input_arg {
|
||||
@@ -172,14 +173,14 @@ op {
|
||||
}
|
||||
)";
|
||||
|
||||
string ts_file_text;
|
||||
std::string ts_file_text;
|
||||
GenerateTsOpFileText(op_def, "", &ts_file_text);
|
||||
|
||||
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, MultiOutput) {
|
||||
const string op_def = R"(
|
||||
const std::string op_def = R"(
|
||||
op {
|
||||
name: "MultiOutputFoo"
|
||||
input_arg {
|
||||
@@ -212,20 +213,20 @@ op {
|
||||
}
|
||||
)";
|
||||
|
||||
string ts_file_text;
|
||||
std::string ts_file_text;
|
||||
GenerateTsOpFileText(op_def, "", &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
const std::string expected = R"(
|
||||
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
|
||||
)";
|
||||
ExpectContainsStr(ts_file_text, expected);
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, OpAttrs) {
|
||||
string ts_file_text;
|
||||
std::string ts_file_text;
|
||||
GenerateTsOpFileText("", "", &ts_file_text);
|
||||
|
||||
const string expectedFooAttrs = R"(
|
||||
const std::string expectedFooAttrs = R"(
|
||||
const opAttrs = [
|
||||
createTensorsTypeOpAttr('T', images),
|
||||
{name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
|
||||
|
||||
Reference in New Issue
Block a user