diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc index a0b90387876..313137ae495 100644 --- a/tensorflow/core/kernels/summary_interface.cc +++ b/tensorflow/core/kernels/summary_interface.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/summary_interface.h" + +#include #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -19,12 +22,10 @@ limitations under the License. #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/summary_interface.h" #include "tensorflow/core/lib/histogram/histogram.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/png/png_io.h" #include "tensorflow/core/lib/wav/wav_io.h" -#include "tensorflow/core/util/event.pb.h" #include "tensorflow/core/util/events_writer.h" namespace tensorflow { @@ -250,28 +251,34 @@ class SummaryWriterImpl : public SummaryWriterInterface { Status WriteTensor(int64 global_step, Tensor t, const string& tag, const string& serialized_metadata) override { - Summary s; - Summary::Value* v = s.add_value(); + std::unique_ptr e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + Summary::Value* v = e->mutable_summary()->add_value(); t.AsProtoTensorContent(v->mutable_tensor()); v->set_tag(tag); v->mutable_metadata()->ParseFromString(serialized_metadata); - return Enqueue(global_step, s); + return WriteEvent(std::move(e)); } Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { - Summary s; - Summary::Value* v = s.add_value(); + std::unique_ptr e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + Summary::Value* v = e->mutable_summary()->add_value(); v->set_tag(tag); float value; TF_RETURN_IF_ERROR(TensorValueAt(t, 0, &value)); v->set_simple_value(value); - return Enqueue(global_step, s); + return WriteEvent(std::move(e)); } Status WriteHistogram(int64 global_step, Tensor t, const string& tag) override { - Summary s; - Summary::Value* v = s.add_value(); + std::unique_ptr e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + Summary::Value* v = e->mutable_summary()->add_value(); v->set_tag(tag); histogram::Histogram histo; for (int64 i = 0; i < t.NumElements(); i++) { @@ -287,7 +294,7 @@ class SummaryWriterImpl : public SummaryWriterInterface { } histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); - return Enqueue(global_step, s); + return WriteEvent(std::move(e)); } Status WriteImage(int64 global_step, Tensor tensor, const string& tag, @@ -306,7 +313,10 @@ class SummaryWriterImpl : public SummaryWriterInterface { return errors::InvalidArgument("Tensor too large for summary ", tensor.shape().DebugString()); } - Summary s; + std::unique_ptr e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + Summary* s = e->mutable_summary(); // The casts and h * w cannot overflow because of the limits above. const int batch_size = static_cast(tensor.dim_size(0)); const int h = static_cast(tensor.dim_size(1)); @@ -321,20 +331,20 @@ class SummaryWriterImpl : public SummaryWriterInterface { &values(i, 0, 0), Eigen::DSizes(hw, depth)); }; TF_RETURN_IF_ERROR( - AddImages(tag, max_images, batch_size, w, h, depth, ith_image, &s)); + AddImages(tag, max_images, batch_size, w, h, depth, ith_image, s)); } else if (tensor.dtype() == DT_HALF) { TF_RETURN_IF_ERROR(NormalizeAndAddImages( - tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s)); + tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s)); } else if (tensor.dtype() == DT_FLOAT) { TF_RETURN_IF_ERROR(NormalizeAndAddImages( - tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s)); + tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s)); } else { return errors::InvalidArgument( "Only DT_INT8, DT_HALF, and DT_FLOAT images are supported. Got ", DataTypeString(tensor.dtype())); } - return Enqueue(global_step, s); + return WriteEvent(std::move(e)); } Status WriteAudio(int64 global_step, Tensor tensor, const string& tag, @@ -346,10 +356,13 @@ class SummaryWriterImpl : public SummaryWriterInterface { const int64 length_frames = tensor.dim_size(1); const int64 num_channels = tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1); - Summary s; + std::unique_ptr e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + Summary* s = e->mutable_summary(); const int N = std::min(max_outputs, batch_size); for (int i = 0; i < N; ++i) { - Summary::Value* v = s.add_value(); + Summary::Value* v = s->add_value(); if (max_outputs > 1) { v->set_tag(strings::StrCat(tag, "/audio/", i)); } else { @@ -375,16 +388,12 @@ class SummaryWriterImpl : public SummaryWriterInterface { channels_by_frames.data(), sample_rate_truncated, num_channels, length_frames, sa->mutable_encoded_audio_string())); } - - return Enqueue(global_step, s); + return WriteEvent(std::move(e)); } - string DebugString() override { return "SummaryWriterImpl"; } - - private: - Status Enqueue(int64 global_step, const Summary& summary) { + Status WriteEvent(std::unique_ptr event) override { mutex_lock ml(mu_); - queue_.emplace_back(global_step, summary, env_->NowMicros()); + queue_.emplace_back(std::move(event)); if (queue_.size() >= max_queue_ || env_->NowMicros() - last_flush_ > 1000 * flush_millis_) { return InternalFlush(); @@ -392,13 +401,16 @@ class SummaryWriterImpl : public SummaryWriterInterface { return Status::OK(); } + string DebugString() override { return "SummaryWriterImpl"; } + + private: + double GetWallTime() { + return static_cast(env_->NowMicros()) / 1.0e6; + } + Status InternalFlush() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - for (const EventInfo& e : queue_) { - Event event; - event.set_step(std::get<0>(e)); - *event.mutable_summary() = std::get<1>(e); - event.set_wall_time(static_cast(std::get<2>(e)) / 1.0e6); - events_writer_->WriteEvent(event); + for (const std::unique_ptr& e : queue_) { + events_writer_->WriteEvent(*e); } queue_.clear(); if (!events_writer_->Flush()) { @@ -413,9 +425,8 @@ class SummaryWriterImpl : public SummaryWriterInterface { const int flush_millis_; uint64 last_flush_; Env* env_; - using EventInfo = std::tuple; mutex mu_; - std::vector queue_ GUARDED_BY(mu_); + std::vector> queue_ GUARDED_BY(mu_); // A pointer to allow deferred construction. std::unique_ptr events_writer_ GUARDED_BY(mu_); std::vector> registered_summaries_ diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h index 1b5d0b27485..ccf3459e56b 100644 --- a/tensorflow/core/kernels/summary_interface.h +++ b/tensorflow/core/kernels/summary_interface.h @@ -15,8 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ #define TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ +#include #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/util/event.pb.h" namespace tensorflow { @@ -43,6 +45,8 @@ class SummaryWriterInterface : public ResourceBase { virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag, int max_outputs_, float sample_rate) = 0; + + virtual Status WriteEvent(std::unique_ptr e) = 0; }; // Creates a SummaryWriterInterface instance which writes to a file. It will diff --git a/tensorflow/core/kernels/summary_interface_test.cc b/tensorflow/core/kernels/summary_interface_test.cc index 379e045ca3e..58e021a0b3e 100644 --- a/tensorflow/core/kernels/summary_interface_test.cc +++ b/tensorflow/core/kernels/summary_interface_test.cc @@ -12,11 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -#include +#include "tensorflow/core/kernels/summary_interface.h" #include "tensorflow/core/framework/summary.pb.h" -#include "tensorflow/core/kernels/summary_interface.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/io/path.h" @@ -43,8 +41,8 @@ class SummaryInterfaceTest : public ::testing::Test { protected: Status SummaryTestHelper( const string& test_name, - std::function writer_fn, - std::function test_fn) { + const std::function& writer_fn, + const std::function& test_fn) { static std::set* tests = new std::set(); CHECK(tests->insert(test_name).second) << ": " << test_name; @@ -182,6 +180,24 @@ TEST_F(SummaryInterfaceTest, WriteAudio) { })); } +TEST_F(SummaryInterfaceTest, WriteEvent) { + TF_CHECK_OK( + SummaryTestHelper("event_test", + [](SummaryWriterInterface* writer) { + std::unique_ptr e{new Event}; + e->set_step(7); + e->mutable_summary()->add_value()->set_tag("hi"); + TF_RETURN_IF_ERROR(writer->WriteEvent(std::move(e))); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 7); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "hi"); + })); +} + TEST_F(SummaryInterfaceTest, WallTime) { env_.AdvanceByMillis(7023); TF_CHECK_OK(SummaryTestHelper(