[XLA:CPU] Adds HLO-level tracing.

PiperOrigin-RevId: 241989294
This commit is contained in:
A. Unique TensorFlower
2019-04-04 13:00:30 -07:00
committed by TensorFlower Gardener
parent 65ee4a62b8
commit 359e02ef20
6 changed files with 125 additions and 0 deletions

View File

@@ -253,6 +253,7 @@ cc_library(
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor/host:host_stream",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -508,6 +509,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
@@ -86,7 +87,11 @@ extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kKeyValueSortSymbolName =
"__xla_cpu_runtime_KeyValueSort";
// Symbol names of the HLO tracing runtime entry points. Both carry the
// "__xla_cpu_runtime_" prefix that kXlaCpuRuntimeSymbolNamePrefix holds.
extern const char* const kTracingStartSymbolName =
"__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
} // namespace cpu
} // namespace xla
@@ -104,6 +109,24 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
} // namespace
extern "C" {
// Runtime hook that opens a TraceMe activity labelled `name` and hands the
// activity id back to the caller. `run_options_ptr` is accepted but not read.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY xla::int64 __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  using tensorflow::profiler::TraceMe;
  VLOG(3) << "TracingStart " << name;
  const xla::int64 activity_id = TraceMe::ActivityStart(name);
  return activity_id;
}
// Runtime hook that closes the TraceMe activity identified by `id`.
// `run_options_ptr` is accepted but not read.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    xla::int64 id) {
  using tensorflow::profiler::TraceMe;
  VLOG(3) << "TracingEnd " << id;
  TraceMe::ActivityEnd(id);
}
} // extern "C"
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,

View File

@@ -66,6 +66,9 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
extern const char* const kKeyValueSortSymbolName;
extern const char* const kTracingStartSymbolName;
extern const char* const kTracingEndSymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
extern const char* const kXlaCpuRuntimeSymbolNamePrefix;
@@ -80,6 +83,13 @@ XfeedManager* GetXfeedManager(int device_ordinal);
extern "C" {
// Begins a trace activity labelled `name` and returns an id for it; pass that
// id to __xla_cpu_runtime_TracingEnd to close the activity.
extern xla::int64 __xla_cpu_runtime_TracingStart(
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
const char* name);
// Ends the trace activity identified by `id`, a value previously returned by
// __xla_cpu_runtime_TracingStart.
extern void __xla_cpu_runtime_TracingEnd(
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
xla::int64 id);
// Some things common to all of the runtime entry points below:
//
// * The shape pointer and shape_length reflect values that can be deserialized

View File

@@ -176,6 +176,13 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
arch_type_ == llvm::Triple::ArchType::x86_64;
profiling_state_ = ProfilingState(use_rdtscp);
bool emit_tracing =
hlo_module_config_.hlo_profiling_enabled() &&
hlo_module_config_.debug_options().xla_backend_extra_options().count(
"xla_hlo_trace");
tracing_state_.set_enabled(emit_tracing);
TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
llvm::Function* ir_function = compute_function_->function();
InsertOrDie(&emitted_functions_, computation, ir_function);
@@ -2883,9 +2890,70 @@ void IrEmitter::ProfilingState::RecordCompleteComputation(
}
}
// Emits, at the builder's current insertion point, a call to the runtime
// function named by runtime::kTracingStartSymbolName, passing the run options
// pointer and the name of `hlo`. The returned activity id is remembered in
// activity_ids_ so that EmitTracingEnd can pass it back for the same `hlo`.
// Does nothing when tracing is disabled.
void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
                                               HloInstruction* hlo,
                                               llvm::Value* run_options) {
  if (!enabled_) {
    return;
  }

  llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
  // LLVM does not permit pointers to void; i8* is the stand-in for the opaque
  // `const void*` run-options parameter of the runtime function.
  llvm::Type* void_ptr_type = int8_ptr_type;
  llvm::FunctionType* fn_type =
      llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
                              /*isVarArg=*/false);

  llvm::Function* function = b->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  const char* fn_name = runtime::kTracingStartSymbolName;
  llvm::FunctionCallee trace_func =
      module->getOrInsertFunction(fn_name, fn_type);
  // getOrInsertFunction may hand back a cast of an existing declaration with a
  // different type; only annotate when the callee is a plain function.
  if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
  auto* activity_id =
      b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
                                 b->CreateBitCast(hlo_name, int8_ptr_type)});
  activity_id->setName(IrName(hlo, "activity_id"));
  activity_ids_[hlo] = activity_id;
}
// Emits, at the builder's current insertion point, a call to the runtime
// function named by runtime::kTracingEndSymbolName, passing the run options
// pointer and the activity id that EmitTracingStart recorded for `hlo`.
// Does nothing when tracing is disabled.
void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
                                             HloInstruction* hlo,
                                             llvm::Value* run_options) {
  if (!enabled_) {
    return;
  }

  // LLVM does not permit pointers to void; i8* is the stand-in for the opaque
  // `const void*` run-options parameter of the runtime function.
  llvm::Type* void_ptr_type = b->getInt8Ty()->getPointerTo();
  llvm::FunctionType* fn_type =
      llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
                              /*isVarArg=*/false);

  llvm::Function* function = b->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  const char* fn_name = runtime::kTracingEndSymbolName;
  llvm::FunctionCallee trace_func =
      module->getOrInsertFunction(fn_name, fn_type);
  // getOrInsertFunction may hand back a cast of an existing declaration with a
  // different type; only annotate when the callee is a plain function.
  if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // A matching EmitTracingStart must already have been emitted for `hlo`;
  // at() fails if no activity id was recorded.
  auto* activity_id = activity_ids_.at(hlo);
  b->CreateCall(trace_func,
                {b->CreateBitCast(run_options, void_ptr_type), activity_id});
}
Status IrEmitter::Preprocess(HloInstruction* hlo) {
VLOG(3) << "Visiting: " << hlo->ToString();
if (instruction_to_profile_idx_.count(hlo)) {
// Only trace the same HLOs that the profiler does.
tracing_state_.EmitTracingStart(&b_, hlo,
GetExecutableRunOptionsArgument());
profiling_state_.RecordCycleStart(&b_, hlo);
}
return Status::OK();
@@ -2895,6 +2963,10 @@ Status IrEmitter::Postprocess(HloInstruction* hlo) {
if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
}
// Only trace the same HLOs that the profiler does.
if (instruction_to_profile_idx_.count(hlo)) {
tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
}
return Status::OK();
}

View File

@@ -531,6 +531,22 @@ class IrEmitter : public DfsHloVisitorWithDefault,
ProfilingState profiling_state_;
// Emits calls to the runtime tracing hooks around HLO instructions. Emission
// is off until set_enabled(true) is called.
class TracingState {
 public:
  TracingState() = default;

  // Switches emission of tracing calls on or off.
  void set_enabled(bool value) { enabled_ = value; }

  // Emits the call that starts a trace activity for `hlo` and remembers the
  // resulting activity id.
  void EmitTracingStart(llvm::IRBuilder<>* b, HloInstruction* hlo,
                        llvm::Value* run_options);

  // Emits the call that ends the trace activity previously started for `hlo`.
  void EmitTracingEnd(llvm::IRBuilder<>* b, HloInstruction* hlo,
                      llvm::Value* run_options);

 private:
  bool enabled_ = false;

  // For each traced HLO, the IR value holding the activity id returned by the
  // tracing-start runtime call (tensorflow::profiler::TraceMe).
  std::unordered_map<const HloInstruction*, llvm::Value*> activity_ids_;
};
TracingState tracing_state_;
// Given a load instruction and a shape or buffer size, annotate the load's
// result with the alignment required by the shape or size.
void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);

View File

@@ -245,6 +245,8 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));