mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
[XLA:CPU] Adds HLO-level tracing.
PiperOrigin-RevId: 241989294
This commit is contained in:
committed by
TensorFlower Gardener
parent
65ee4a62b8
commit
359e02ef20
@@ -253,6 +253,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor/host:host_stream",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
@@ -508,6 +509,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
|
||||
@@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/stream_executor/stream_executor.h"
|
||||
|
||||
namespace xla {
|
||||
@@ -86,7 +87,11 @@ extern const char* const kParallelForkJoinSymbolName =
|
||||
"__xla_cpu_runtime_ParallelForkJoin";
|
||||
extern const char* const kKeyValueSortSymbolName =
|
||||
"__xla_cpu_runtime_KeyValueSort";
|
||||
extern const char* const kTracingStartSymbolName =
|
||||
"__xla_cpu_runtime_TracingStart";
|
||||
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
|
||||
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
@@ -104,6 +109,24 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Runtime entry point that opens a TraceMe activity labelled `name` and
// returns the activity id produced by TraceMe::ActivityStart. The id is what
// __xla_cpu_runtime_TracingEnd expects to receive later.
// `run_options_ptr` is not read by this function.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY xla::int64 __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  using tensorflow::profiler::TraceMe;

  VLOG(3) << "TracingStart " << name;
  const xla::int64 activity_id = TraceMe::ActivityStart(name);
  return activity_id;
}
|
||||
|
||||
// Runtime entry point that closes the TraceMe activity identified by `id`,
// i.e. a value previously returned by __xla_cpu_runtime_TracingStart.
// `run_options_ptr` is not read by this function.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    xla::int64 id) {
  using tensorflow::profiler::TraceMe;

  VLOG(3) << "TracingEnd " << id;
  TraceMe::ActivityEnd(id);
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
|
||||
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
|
||||
const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
|
||||
|
||||
@@ -66,6 +66,9 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
|
||||
extern const char* const kParallelForkJoinSymbolName;
|
||||
extern const char* const kKeyValueSortSymbolName;
|
||||
|
||||
extern const char* const kTracingStartSymbolName;
|
||||
extern const char* const kTracingEndSymbolName;
|
||||
|
||||
// All symbol names for XLA CPU runtime functions need to start with this
|
||||
// prefix.
|
||||
extern const char* const kXlaCpuRuntimeSymbolNamePrefix;
|
||||
@@ -80,6 +83,13 @@ XfeedManager* GetXfeedManager(int device_ordinal);
|
||||
|
||||
extern "C" {
|
||||
|
||||
extern xla::int64 __xla_cpu_runtime_TracingStart(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
const char* name);
|
||||
extern void __xla_cpu_runtime_TracingEnd(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
|
||||
xla::int64 id);
|
||||
|
||||
// Some things common to all of the runtime entry points below:
|
||||
//
|
||||
// * The shape pointer and shape_length reflect values that can be deserialized
|
||||
|
||||
@@ -176,6 +176,13 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
|
||||
bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
|
||||
arch_type_ == llvm::Triple::ArchType::x86_64;
|
||||
profiling_state_ = ProfilingState(use_rdtscp);
|
||||
|
||||
bool emit_tracing =
|
||||
hlo_module_config_.hlo_profiling_enabled() &&
|
||||
hlo_module_config_.debug_options().xla_backend_extra_options().count(
|
||||
"xla_hlo_trace");
|
||||
tracing_state_.set_enabled(emit_tracing);
|
||||
|
||||
TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
|
||||
llvm::Function* ir_function = compute_function_->function();
|
||||
InsertOrDie(&emitted_functions_, computation, ir_function);
|
||||
@@ -2883,9 +2890,70 @@ void IrEmitter::ProfilingState::RecordCompleteComputation(
|
||||
}
|
||||
}
|
||||
|
||||
// Emits, at the builder's current insertion point, a call to the runtime
// function named by runtime::kTracingStartSymbolName and records the returned
// activity id for `hlo` so that EmitTracingEnd can later close the same
// activity. Does nothing when tracing is disabled.
//
//   b:           IR builder used to create the call.
//   hlo:         instruction being traced; its name is passed to the runtime
//                as the activity label.
//   run_options: value forwarded to the runtime as its first (opaque pointer)
//                argument.
void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
                                               HloInstruction* hlo,
                                               llvm::Value* run_options) {
  if (!enabled_) {
    return;
  }

  llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
  // LLVM does not permit pointers to void, so the runtime's `const void*`
  // parameter is modelled as i8* instead.
  llvm::Type* void_ptr_type = int8_ptr_type;
  llvm::FunctionType* fn_type =
      llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
                              /*isVarArg=*/false);

  llvm::Function* function = b->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  const char* fn_name = runtime::kTracingStartSymbolName;
  llvm::FunctionCallee trace_func =
      module->getOrInsertFunction(fn_name, fn_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    // NOTE(review): the runtime implementation logs and calls into the
    // profiler; confirm that declaring the callee argmemonly is intended.
    fn->setOnlyAccessesArgMemory();
  }
  auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
  auto* activity_id =
      b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
                                 b->CreateBitCast(hlo_name, int8_ptr_type)});
  activity_id->setName(IrName(hlo, "activity_id"));
  // Remember the id so the matching end call can pass it back to the runtime.
  activity_ids_[hlo] = activity_id;
}
|
||||
|
||||
// Emits, at the builder's current insertion point, a call to the runtime
// function named by runtime::kTracingEndSymbolName, passing the activity id
// that EmitTracingStart recorded for `hlo`. Does nothing when tracing is
// disabled.
//
//   b:           IR builder used to create the call.
//   hlo:         instruction whose activity is being closed; an id must
//                already have been recorded for it by EmitTracingStart.
//   run_options: value forwarded to the runtime as its first (opaque pointer)
//                argument.
void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
                                             HloInstruction* hlo,
                                             llvm::Value* run_options) {
  if (!enabled_) {
    return;
  }

  // LLVM does not permit pointers to void, so the runtime's `const void*`
  // parameter is modelled as i8* instead.
  llvm::Type* void_ptr_type = b->getInt8Ty()->getPointerTo();
  llvm::FunctionType* fn_type =
      llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
                              /*isVarArg=*/false);

  llvm::Function* function = b->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  const char* fn_name = runtime::kTracingEndSymbolName;
  llvm::FunctionCallee trace_func =
      module->getOrInsertFunction(fn_name, fn_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    // NOTE(review): the runtime implementation logs and calls into the
    // profiler; confirm that declaring the callee argmemonly is intended.
    fn->setOnlyAccessesArgMemory();
  }
  // .at() requires that EmitTracingStart already stored an id for this HLO.
  auto* activity_id = activity_ids_.at(hlo);
  b->CreateCall(trace_func,
                {b->CreateBitCast(run_options, void_ptr_type), activity_id});
}
|
||||
|
||||
Status IrEmitter::Preprocess(HloInstruction* hlo) {
|
||||
VLOG(3) << "Visiting: " << hlo->ToString();
|
||||
if (instruction_to_profile_idx_.count(hlo)) {
|
||||
// Only trace the same HLOs that the profiler does.
|
||||
tracing_state_.EmitTracingStart(&b_, hlo,
|
||||
GetExecutableRunOptionsArgument());
|
||||
profiling_state_.RecordCycleStart(&b_, hlo);
|
||||
}
|
||||
return Status::OK();
|
||||
@@ -2895,6 +2963,10 @@ Status IrEmitter::Postprocess(HloInstruction* hlo) {
|
||||
if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
|
||||
profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
|
||||
}
|
||||
// Only trace the same HLOs that the profiler does.
|
||||
if (instruction_to_profile_idx_.count(hlo)) {
|
||||
tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@@ -531,6 +531,22 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
|
||||
ProfilingState profiling_state_;
|
||||
|
||||
class TracingState {
|
||||
public:
|
||||
TracingState() : enabled_(false) {}
|
||||
void set_enabled(bool value) { enabled_ = value; }
|
||||
void EmitTracingStart(llvm::IRBuilder<>* b, HloInstruction* hlo,
|
||||
llvm::Value* run_options);
|
||||
void EmitTracingEnd(llvm::IRBuilder<>* b, HloInstruction* hlo,
|
||||
llvm::Value* run_options);
|
||||
|
||||
private:
|
||||
bool enabled_;
|
||||
// Maps from HLO to the activity id returned by xprof::TraceMe.
|
||||
std::unordered_map<const HloInstruction*, llvm::Value*> activity_ids_;
|
||||
};
|
||||
TracingState tracing_state_;
|
||||
|
||||
// Given a load instruction and a shape or buffer size, annotate the load's
|
||||
// result with the alignment required by the shape or size.
|
||||
void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
|
||||
|
||||
@@ -245,6 +245,8 @@ bool RegisterKnownJITSymbols() {
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
|
||||
|
||||
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
|
||||
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
|
||||
|
||||
Reference in New Issue
Block a user