Automated Code Change

PiperOrigin-RevId: 829434646
This commit is contained in:
A. Unique TensorFlower
2025-11-07 07:46:10 -08:00
committed by TensorFlower Gardener
parent 3f34740ef5
commit cc4a122a24
3 changed files with 31 additions and 30 deletions

View File

@@ -57,13 +57,13 @@ class GenTypeScriptOp {
~GenTypeScriptOp();
// Returns the generated code as a string:
string Code();
std::string Code();
private:
void ProcessArgs();
void ProcessAttrs();
void AddAttrForArg(const string& attr, int arg_index);
string InputForAttr(const OpDef::AttrDef& op_def_attr);
void AddAttrForArg(const std::string& attr, int arg_index);
std::string InputForAttr(const OpDef::AttrDef& op_def_attr);
void AddMethodSignature();
void AddOpAttrs();
@@ -73,7 +73,7 @@ class GenTypeScriptOp {
const ApiDef& api_def_;
// Placeholder string for all generated code:
string result_;
std::string result_;
// Holds in-order vector of Op inputs:
std::vector<ArgDefs> input_op_args_;
@@ -82,7 +82,7 @@ class GenTypeScriptOp {
std::vector<OpAttrs> op_attrs_;
// Stores attributes-to-arguments by name:
typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
typedef std::unordered_map<std::string, std::vector<int>> AttrArgIdxMap;
AttrArgIdxMap attr_arg_idx_map_;
// Holds number of outputs:
@@ -94,7 +94,7 @@ GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
GenTypeScriptOp::~GenTypeScriptOp() = default;
string GenTypeScriptOp::Code() {
std::string GenTypeScriptOp::Code() {
ProcessArgs();
ProcessAttrs();
@@ -144,7 +144,7 @@ void GenTypeScriptOp::ProcessAttrs() {
}
}
void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
void GenTypeScriptOp::AddAttrForArg(const std::string& attr, int arg_index) {
// Keep track of attributes-to-arguments by name. These will be used for
// construction Op attributes that require information about the inputs.
auto iter = attr_arg_idx_map_.find(attr);
@@ -155,8 +155,8 @@ void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
}
}
string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
string inputs;
std::string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
std::string inputs;
auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
if (arg_list != attr_arg_idx_map_.end()) {
for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
@@ -235,7 +235,7 @@ void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
}
void StartFile(WritableFile* ts_file) {
const string header =
const std::string header =
R"header(/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
@@ -266,7 +266,7 @@ import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
} // namespace
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& ts_filename) {
const std::string& ts_filename) {
Env* env = Env::Default();
std::unique_ptr<WritableFile> ts_file = nullptr;

View File

@@ -24,7 +24,7 @@ namespace tensorflow {
// Generated code is written to the file ts_filename:
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& ts_filename);
const std::string& ts_filename);
} // namespace tensorflow

View File

@@ -81,8 +81,9 @@ op {
)";
// Generate TypeScript code
void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
string* ts_file_text) {
void GenerateTsOpFileText(const std::string& op_def_str,
const std::string& api_def_str,
std::string* ts_file_text) {
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(
@@ -93,7 +94,7 @@ void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
}
const string& tmpdir = testing::TmpDir();
const std::string& tmpdir = testing::TmpDir();
const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
WriteTSOps(op_defs, api_def_map, ts_file_path);
@@ -101,10 +102,10 @@ void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
}
TEST(TsOpGenTest, TestImports) {
string ts_file_text;
std::string ts_file_text;
GenerateTsOpFileText("", "", &ts_file_text);
const string expected = R"(
const std::string expected = R"(
import * as tfc from '@tensorflow/tfjs-core';
import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";
@@ -112,38 +113,38 @@ import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
}
TEST(TsOpGenTest, InputSingleAndList) {
const string api_def = R"pb(
const std::string api_def = R"pb(
op { graph_op_name: "Foo" arg_order: "dim" arg_order: "images" }
)pb";
string ts_file_text;
std::string ts_file_text;
GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
const std::string expected = R"(
export function Foo(dim: tfc.Tensor, images: tfc.Tensor[]): tfc.Tensor {
)";
ExpectContainsStr(ts_file_text, expected);
}
TEST(TsOpGenTest, TestVisibility) {
const string api_def = R"(
const std::string api_def = R"(
op {
graph_op_name: "Foo"
visibility: HIDDEN
}
)";
string ts_file_text;
std::string ts_file_text;
GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
const std::string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
)";
ExpectDoesNotContainStr(ts_file_text, expected);
}
TEST(TsOpGenTest, SkipDeprecated) {
const string op_def = R"(
const std::string op_def = R"(
op {
name: "DeprecatedFoo"
input_arg {
@@ -172,14 +173,14 @@ op {
}
)";
string ts_file_text;
std::string ts_file_text;
GenerateTsOpFileText(op_def, "", &ts_file_text);
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
}
TEST(TsOpGenTest, MultiOutput) {
const string op_def = R"(
const std::string op_def = R"(
op {
name: "MultiOutputFoo"
input_arg {
@@ -212,20 +213,20 @@ op {
}
)";
string ts_file_text;
std::string ts_file_text;
GenerateTsOpFileText(op_def, "", &ts_file_text);
const string expected = R"(
const std::string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
)";
ExpectContainsStr(ts_file_text, expected);
}
TEST(TsOpGenTest, OpAttrs) {
string ts_file_text;
std::string ts_file_text;
GenerateTsOpFileText("", "", &ts_file_text);
const string expectedFooAttrs = R"(
const std::string expectedFooAttrs = R"(
const opAttrs = [
createTensorsTypeOpAttr('T', images),
{name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}