Files
tensorflow/third_party/xla/MODULE.bazel
2025-12-26 15:20:26 -08:00

224 lines
8.3 KiB
Python

# Declares this repository as the Bazel module named "xla".
module(name = "xla")
##############################################################
# Bazel module dependencies
# Direct, versioned dependencies of this module. Where `repo_name` is given,
# the dependency is referenced from this module under that repository name
# instead of its module name.
# go/keep-sorted start
bazel_dep(name = "abseil-cpp", version = "20250814.0", repo_name = "com_google_absl")
bazel_dep(name = "abseil-py", version = "2.1.0", repo_name = "absl_py")
bazel_dep(name = "bazel_features", version = "1.36.0")
bazel_dep(name = "bazel_skylib", version = "1.8.1")
bazel_dep(name = "boringssl", version = "0.20250818.0")
bazel_dep(name = "curl", version = "8.11.0")
bazel_dep(name = "google_benchmark", version = "1.8.5", repo_name = "com_google_benchmark")
bazel_dep(name = "googletest", version = "1.17.0", repo_name = "com_google_googletest")
bazel_dep(name = "grpc", version = "1.74.1", repo_name = "com_github_grpc_grpc")
bazel_dep(name = "gutil", version = "20250502.0", repo_name = "com_google_gutil")
bazel_dep(name = "jsoncpp", version = "1.9.6", repo_name = "jsoncpp_git")
bazel_dep(name = "or-tools", version = "9.12", repo_name = "com_google_ortools")
bazel_dep(name = "platforms", version = "1.0.0")
bazel_dep(name = "protobuf", version = "32.1", repo_name = "com_google_protobuf")
bazel_dep(name = "pybind11_abseil", version = "202402.0")
bazel_dep(name = "pybind11_bazel", version = "2.13.6")
bazel_dep(name = "pybind11_protobuf", version = "0.0.0-20250210-f02a2b7")
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "riegeli", version = "0.0.0-20250822-9f2744d", repo_name = "com_google_riegeli")
bazel_dep(name = "rules_cc", version = "0.2.0")
bazel_dep(name = "rules_java", version = "8.16.1")
bazel_dep(name = "rules_license", version = "1.0.0")
bazel_dep(name = "rules_python", version = "1.6.0")
bazel_dep(name = "rules_shell", version = "0.6.1")
bazel_dep(name = "snappy", version = "1.2.1")
bazel_dep(name = "zlib", version = "1.3.1.bcr.5")
# go/keep-sorted end
# Only for compatibility; these modules are not used directly (hence the
# DO_NOT_USE_ repo names). Change repo_name to None after upgrading Bazel to
# the latest 7.x.
bazel_dep(name = "eigen", version = "4.0.0-20241125.bcr.3", repo_name = "DO_NOT_USE_eigen")
bazel_dep(name = "grpc-java", version = "1.75.0", repo_name = "DO_NOT_USE_grpc_java")
# rules_ml_toolchain is declared without a version; its source comes from the
# archive_override that follows.
# TODO: publish an official version of rules_ml_toolchain to BCR
bazel_dep(name = "rules_ml_toolchain")

# Pinned to a single commit archive. To calculate `integrity` for a new URL:
#   wget -O temp_module_archive.tar.gz <archive URL>
#   HASH=$(openssl dgst -sha256 -binary temp_module_archive.tar.gz | openssl base64 -A)
#   echo "sha256-${HASH}"
archive_override(
    module_name = "rules_ml_toolchain",
    integrity = "sha256-B4AvIZFqETvnj/IRCJEjm9UYOtCdjEL2+bBOTgv6VQU=",
    strip_prefix = "rules_ml_toolchain-802e0dbbcc3cd82ac5b0accbff6f95b70106d0d1",
    urls = [
        "https://github.com/google-ml-infra/rules_ml_toolchain/archive/802e0dbbcc3cd82ac5b0accbff6f95b70106d0d1.tar.gz",
    ],
)
# Apply the in-tree patch on top of the grpc module.
# TODO: Upstream the patch?
single_version_override(
    module_name = "grpc",
    patch_strip = 1,
    patches = [
        "//third_party/grpc:grpc.patch",
    ],
)
# Apply the in-tree patches on top of the abseil-cpp module. The list order is
# kept as-is, since patches are applied in the order given.
# TODO: Upstream these patches?
single_version_override(
    module_name = "abseil-cpp",
    patch_strip = 1,
    patches = [
        "//third_party/absl:btree.patch",
        "//third_party/absl:build_dll.patch",
        "//third_party/absl:endian.patch",
        "//third_party/absl:rules_cc.patch",
        "//third_party/absl:check_op.patch",
        "//third_party/absl:check_op_2.patch",
    ],
)
# Use an unreleased version of googletest: a commit archive plus one in-tree
# patch.
# NOTE(review): no `integrity` is set on this override, so the downloaded
# archive is not checksum-verified -- consider pinning one.
archive_override(
    module_name = "googletest",
    patch_strip = 1,
    patches = [
        "//third_party/googletest:0001-Add-ASSERT_OK-EXPECT_OK-ASSERT_OK_AND_ASSIGN-macros.patch",
    ],
    strip_prefix = "googletest-28e9d1f26771c6517c3b4be10254887673c94018",
    urls = [
        "https://github.com/google/googletest/archive/28e9d1f26771c6517c3b4be10254887673c94018.zip",
    ],
)
##############################################################
# C++ dependencies
# These repositories are generated by the in-tree third_party module extension
# and brought into this module's scope by name.
# TODO: most of them can be Bazel modules, but we need a release strategy for them
third_party = use_extension(
    "//third_party/extensions:third_party.bzl",
    "third_party_ext",
)
use_repo(
    third_party,
    "FXdiv",
    "XNNPACK",
    "cpuinfo",
    "cudnn_frontend_archive",
    "dlpack",
    "ducc",
    "eigen_archive",
    "farmhash_archive",
    "farmhash_gpu_archive",
    "gloo",
    "highwayhash",
    "hwloc",
    "implib_so",
    "llvm-raw",
    "llvm_openmp",
    "ml_dtypes_py",
    "mpitrampoline",
    "nanobind",
    "nvshmem",
    "onednn",
    "pthreadpool",
    "rocm_device_libs",
    "shardy",
    "slinky",
    "stablehlo",
    "triton",
)
##############################################################
# Python toolchain and pypi dependencies
# Configure the rules_python `python` extension: 3.11 is set as the default
# version and is the only toolchain version requested here.
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.defaults(python_version = "3.11")
python.toolchain(python_version = "3.11")
# Make the extension-generated "pythons_hub" repo visible to this module.
use_repo(python, "pythons_hub")
pip = use_extension(
    "@rules_python//python/extensions:pip.bzl",
    "pip",
)

# Extra BUILD content for the numpy wheel: one header library per include
# directory (`numpy/_core/include` and `numpy/core/include`), plus a
# `numpy_headers` target depending on both. Presumably the two directories
# correspond to the NumPy 2.x and 1.x layouts -- confirm.
pip.whl_mods(
    additive_build_content = """
load("@rules_cc//cc:cc_library.bzl", "cc_library")
cc_library(
name = "numpy_headers_2",
hdrs = glob(["site-packages/numpy/_core/include/**/*.h"]),
strip_include_prefix="site-packages/numpy/_core/include/",
)
cc_library(
name = "numpy_headers_1",
hdrs = glob(["site-packages/numpy/core/include/**/*.h"]),
strip_include_prefix="site-packages/numpy/core/include/",
)
cc_library(
name = "numpy_headers",
deps = [":numpy_headers_2", ":numpy_headers_1"],
)
""",
    hub_name = "pypi_mods",
    whl_name = "numpy",
)

# Resolve pip packages from the 3.11 lock file into the "xla_pypi" hub, applying
# the numpy modifications declared above and exposing `numpy_headers` as an
# extra alias on the numpy package.
pip.parse(
    extra_hub_aliases = {
        "numpy": ["numpy_headers"],
    },
    hub_name = "xla_pypi",
    python_version = "3.11",
    requirements_lock = "//:requirements_lock_3_11.txt",
    whl_modifications = {
        "@pypi_mods//:numpy.json": "numpy",
    },
)

# The "xla_pypi" hub is imported under the local name "pypi".
use_repo(
    pip,
    "pypi_mods",
    pypi = "xla_pypi",
)
# In-tree extension that provides the "python_version_repo" repository.
python_version_ext = use_extension(
    "//third_party/extensions:python_version.bzl",
    "python_version_ext",
)
use_repo(python_version_ext, "python_version_repo")
##############################################################
# Other dependencies via module extensions
# Each stanza loads one module extension and imports the repositories it
# generates into this module's scope.

### TSL
tsl_extension = use_extension(
    "//third_party/extensions:tsl.bzl",
    "tsl_extension",
)
use_repo(tsl_extension, "tsl")

### LLVM
llvm = use_extension(
    "//third_party/extensions:llvm.bzl",
    "llvm_extension",
)
use_repo(llvm, "llvm-project")

### pybind
pybind11_internal_configure = use_extension(
    "@pybind11_bazel//:internal_configure.bzl",
    "internal_configure_extension",
)
use_repo(pybind11_internal_configure, "pybind11")

### RBE
rbe_config = use_extension(
    "//third_party/extensions:rbe_config.bzl",
    "rbe_config_ext",
)
use_repo(rbe_config, "ml_build_config_platform")

remote_execution_configure = use_extension(
    "//third_party/extensions:remote_execution_configure.bzl",
    "remote_execution_configure_ext",
)
use_repo(remote_execution_configure, "local_config_remote_execution")
### From @rules_ml_toolchain
# Configuration repositories generated by extensions that live in
# @rules_ml_toolchain, followed by the C++ toolchains registered from it.
cuda_configure = use_extension(
    "@rules_ml_toolchain//third_party/extensions:cuda_configure.bzl",
    "cuda_configure_ext",
)
use_repo(cuda_configure, "local_config_cuda")

nccl_configure = use_extension(
    "@rules_ml_toolchain//third_party/extensions:nccl_configure.bzl",
    "nccl_configure_ext",
)
use_repo(nccl_configure, "local_config_nccl")

sycl_configure = use_extension(
    "@rules_ml_toolchain//third_party/extensions:sycl_configure.bzl",
    "sycl_configure_ext",
)
use_repo(sycl_configure, "local_config_sycl")

cuda_redist_init_ext = use_extension(
    "@rules_ml_toolchain//third_party/extensions:cuda_redist_init.bzl",
    "cuda_redist_init_ext",
)
use_repo(cuda_redist_init_ext, "cuda_cudart")

# Registration order is preserved: x86_64 targets first, then aarch64, each
# with its plain and `_cuda` variant.
register_toolchains(
    "@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64",
    "@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda",
    "@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64",
    "@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64_cuda",
)
### Other local config repos
# In-tree extensions producing the ROCm and TensorRT configuration repos and
# the PJRT wheel nightly-timestamp / release-candidate-number repos.
rocm_configure = use_extension(
    "//third_party/extensions:rocm_configure.bzl",
    "rocm_configure_ext",
)
use_repo(rocm_configure, "local_config_rocm")

tensorrt_configure = use_extension(
    "//third_party/extensions:tensorrt_configure.bzl",
    "tensorrt_configure_ext",
)
use_repo(tensorrt_configure, "local_config_tensorrt")

pjrt_nightly_timestamp = use_extension(
    "//build_tools/pjrt_wheels:nightly.bzl",
    "nightly_timestamp_repo_bzlmod",
)
use_repo(pjrt_nightly_timestamp, "nightly_timestamp")

pjrt_rc_number = use_extension(
    "//build_tools/pjrt_wheels:release_candidate.bzl",
    "rc_number_repo_bzlmod",
)
use_repo(pjrt_rc_number, "rc_number")