mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Reenable file upload in bazel remote config
This replaces the "transfer script via cli argument" hack by the upload support for remote configurations that landed in Bazel 3.1.0. PiperOrigin-RevId: 569079584
This commit is contained in:
committed by
TensorFlower Gardener
parent
354634b03f
commit
f2957eb767
@@ -250,7 +250,6 @@ tf_staging/third_party/gpus/cuda/build_defs.bzl.tpl:
|
||||
tf_staging/third_party/gpus/cuda/cuda_config.h.tpl:
|
||||
tf_staging/third_party/gpus/cuda/cuda_config.py.tpl:
|
||||
tf_staging/third_party/gpus/cuda_configure.bzl:
|
||||
tf_staging/third_party/gpus/find_cuda_config.py.gz.base64:
|
||||
tf_staging/third_party/gpus/find_cuda_config:.py
|
||||
tf_staging/third_party/gpus/rocm/BUILD.tpl:
|
||||
tf_staging/third_party/gpus/rocm/BUILD:
|
||||
|
||||
36
third_party/gpus/compress_find_cuda_config.py
vendored
36
third_party/gpus/compress_find_cuda_config.py
vendored
@@ -1,36 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of 'find_cuda_config.py'.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_cuda_config.py', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_cuda_config.py.gz.base64', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
36
third_party/gpus/compress_find_rocm_config.py
vendored
36
third_party/gpus/compress_find_rocm_config.py
vendored
@@ -1,36 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of 'find_rocm_config.py'.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_rocm_config.py', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_rocm_config.py.gz.base64', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
44
third_party/gpus/cuda_configure.bzl
vendored
44
third_party/gpus/cuda_configure.bzl
vendored
@@ -641,42 +641,19 @@ def _cudart_static_linkopt(cpu_value):
|
||||
"""Returns additional platform-specific linkopts for cudart."""
|
||||
return "" if cpu_value == "Darwin" else "\"-lrt\","
|
||||
|
||||
def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('\"%s\" script.py %s');" % (python_bin, " ".join(cuda_libraries))
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
|
||||
# and nccl_configure.bzl.
|
||||
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
def find_cuda_config(repository_ctx, cuda_libraries):
|
||||
"""Returns CUDA config dictionary from running find_cuda_config.py"""
|
||||
exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))
|
||||
|
||||
# Parse the dict from stdout.
|
||||
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||
|
||||
def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
||||
def _get_cuda_config(repository_ctx):
|
||||
"""Detects and returns information about the CUDA installation on the system.
|
||||
|
||||
Args:
|
||||
@@ -692,7 +669,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
||||
compute_capabilities: A list of the system's CUDA compute capabilities.
|
||||
cpu_value: The name of the host operating system.
|
||||
"""
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_script, ["cuda", "cudnn"])
|
||||
config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
|
||||
cpu_value = get_cpu_value(repository_ctx)
|
||||
toolkit_path = config["cuda_toolkit_path"]
|
||||
|
||||
@@ -1008,9 +985,8 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
"cuda:cuda_config.py",
|
||||
]}
|
||||
tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
|
||||
find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)
|
||||
cuda_config = _get_cuda_config(repository_ctx)
|
||||
|
||||
cuda_include_path = cuda_config.config["cuda_include_dir"]
|
||||
cublas_include_path = cuda_config.config["cublas_include_dir"]
|
||||
@@ -1484,12 +1460,20 @@ remote_cuda_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cuda_configure = repository_rule(
|
||||
implementation = _cuda_autoconf_impl,
|
||||
environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local CUDA toolchain.
|
||||
|
||||
|
||||
6
third_party/gpus/find_cuda_config.py
vendored
6
third_party/gpus/find_cuda_config.py
vendored
@@ -53,12 +53,6 @@ tf_<library>_header_dir: ...
|
||||
tf_<library>_library_dir: ...
|
||||
"""
|
||||
|
||||
# You can use the following command to regenerate the base64 version of this
|
||||
# script:
|
||||
# cat third_party/tensorflow/third_party/gpus/find_cuda_config.oss.py |
|
||||
# pigz -z | base64 -w0 >
|
||||
# third_party/tensorflow/third_party/gpus/find_cuda_config.py.gz.base64.oss
|
||||
|
||||
import io
|
||||
import os
|
||||
import glob
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
|
||||
eJztW21v2zgS/q5fQSgoam9dJSlugUMOOcCbZlHftUlhZ7tYtIVB27TNrSz6SCppUPS/7wxJyRIt2YqjvX6JgTa2NPNwOC8Ph7Z4RC7E+l7yxVKTVyevTsjNkpEblighf43FHemneimkikg/jskQxRQZMsXkLZtFwVFwRN7yKYizGUmTGZNEg35/Tafwx93pkQ9MKi4S8io6IR0UCN2tsPsvQLgXKVnRe5IITVLFAIIrMucxI+zrlK014QmZitU65jSZMnLH9dIM40DADPKHgxATTUGagvwaPs2LcoRqYzC+llqvz46P7+7uImqMjYRcHMdWUB2/HVxcXo0uX4LBRuW3JGZKEcn+l3IJU53cE7oGe6Z0AlbG9I4ISehCMrinBdp7J7nmyaJHlJjrOyoZoMy40pJPUl1yVmYdzLkoAO6iCQn7IzIYheSX/mgw6gHG74ObN9e/3ZDf+8Nh/+pmcDki10NycX31enAzuL6CT7+S/tUf5L+Dq9c9wsBVMAz7upZoPxjJ0Y0mdGTEWMmAubAGqTWb8jmfwrySRUoXjCzELZMJTIesmVxxhcFUYN4MUGK+4ppqc2VrUjjMeauvIAzD95InmIbXFysYfiKpvEdjyJJRHH8GIZpqITkzNpJbm32QUgIMRMeaWd4rzVZREGDCq6nkkGeKUQm5oIwr6uAxMVUZpQcRR69pFcDFFabAjGl0VWJczGVmhAFaW/tRfyqSOV+k0jgQ9ZSeiVRHxqo1xUQXGThmiIsNptlSinSxxCRhyS2XIlmxRJNbKrlJyg7Y/278vn/zphsFgzkUF9yL+cwbkju39Ox0rB8yA405TEoTasl0Kk3YCVwCB03FjJX9p+kXZueVxeC+YDEUDd7K7aq0OyrixUJ8scGwvrfxzGJiA2GqfUnl7CXaM4MYaqj7QKWTYh7MpViRCVXOqY4YNrbl9kYEfLUxEdwDrBTkgsZNUJbHYq2PpZiuQhRJkf4o2KIh7nOaxjifOGUBZmsQQM0JCeET2TuhsnfAC+4dZFIQBNOYQp1emBBdopc7l4YCIVTds4CA9QrFYBQyXjA9dsON0ZQxTq1jxGysimYWlYwwJJWmcVxQAltfZ1lrPZ2F3IVtRZySTR3UjHCCMJ6PSM5r7QNxPidh7uMQQyhU5LIBLakG3Mh8LGh/BvmjOnljomQ0NkNvCXU3rtq6V3QYLj5Ccc3GrozHSbqaMNlZ0T+F7BHwGP4Btemy6P/TE3iRn4gRIy/wM35CafhkxIvD2CTPxuigFT2S0BXLwjN08QA2X4PBQOkABOomYODFiyJHudBg9cTIQ3Cfi0isWYYcyhBWiATKBpj9PEz1/OU/w671/wptAx9KFpm3HRke2YHIM0VedD7NXnRD8sxY1zP4XaMHkTXyFoXYMgAcczFaAGWtO6ddd9N5CbimY+S6AbqOcqjSYgU8z0YOv3035WZXtU/J8wgmB8gd4yPywsGWXyExfQUugMgz4AXTWXz7DuvmpyTMIEw61ENwoEM7F8sIyIJgN1sw9C4UDY3DbhZLMHZmM95ybUXineFUUdjcKmeV6jgRHBvyTnHgxVzGst45+ZibGvJkGqczZgodWFDad5lGtIQwQ4mYAv45elWtVyN9mi1YQhq9z+b/oi1gypVImLmOuTZHF1dbnSdF4aJfrH8KnuXnPMsTSKrsNvsKLZLqbCF0z/JpecZtieaCEyCGL1nWYkyLohnedjoWUiQ0XsqWdjMg4tg+A9wA+Zrll5dY1Q7qmgogji/OK1mhqNBzNPrhcjiC5m/8rv+f62Fo3WZZ5gCMwVWOsXYs8FAMIOaLNw7DFXkFU+JcKy4j61QVRQV7ZyuPLTNQ/ObcHFYghGcHkDngfQ+8VcKOVir2JV/LNNF8xRqUfEHYs+DAugdA/Le/4FGwmdJT3e+s+zeD92RoY/g3lP8RbMJhmSms5dDg2sb7K246oc3MlkT0PBiTVV6P3C0F2G27T4sFSJ2CiCUJaEOwGXlBSnew9LtODXoNyfit3SJmDYwlFdsavpzSeJpCLwjWMQkcoLjZnUPzL0zTbHHuKKy6UHe4XAoDNkkVNiwKu/A1hfnDvhK6f0gH3KeqhzLg1tweToBbTihz12mpiTO4GKitoi8SUG2RAw3VE0Blh1okoK0xSzS04tjeNaAgJ/i4vqPEQRbxuMQm7rVhIo9UchCnvAvriZkaMdPzd4Nr8F2ZlEzfnDNT3jUXwbvdh5YdDPT+8uqxleejHNZ8eCiPbT9qqqO6ASnVXJEBKlHCs0KwHt2OlMb2dx+TmKpmGxAj2d4eBOGOcXMkExpnF17ubU42mhapIcITH+zsVJzz2u1SHkgVEKlf3vZHLWxTyjAH71RKMC1sVioLqHa/UihMb8tSgdMuYZRH9xlDQiU1Ywwj2c7mxaEde6h+lZd4wih4ADWKJ0/0sIceZOE3kt30sJMPytnQuBSH/avXWSmWi7AM6GqnkKJe7VSkJLTaNbm6t88uD+UXynyum9UJCra3sAKa+9NoNUXxsvbTEnpgjYDvfugK6qWS0Xz4AuqjoObD108PxWg+bvmsqpLa1XNTex4BbKO0vnZuxva//WvGCE6wJUawaO7PfkZw4mXtJ0Y4hBGs734oI1gTPtiL7w7jgzLGYWxQwnj/SC6oqY9qLihVnffV29/NBaWx/e5ASzoFhzVqEJxsez2CBdy829khOOEt9Sc2eGB/YN32ozfZN8P+xeWwhW22D1TYaB9tsjZ/3kow+5v4kt4yQl01Z75w7XuRVk4e0yzUVExtv1CqRq9lqMRqvWsoWeA3Dgp/AWn4q6GTba99sICbd42aCKe0BfPUShzYSlj3/ehuwlrx6IaiDHNwT1GCaaGtqCmd2s6iVJZec1GJ1Xp/UbLAJw0R3zbqMDayLZKGAdx8PZ9fasYembb3Jf8+lNNdKPuUn74E3OKgjRcfTDxtko6x4PGkU4I5nHSKMG2QTnXp1ZNOsax90qnCap90ihb4+5rGrLORbW9f4yo+f9foG1CPJvYpP7Uq+/Y5BzJGm6wBsRpdv/3Qyj7HBzr4B0UPqIWfFB/EHX5xetuc/wN3+BY47th6Etp/kp2SGZ/iaQI80CHmthKNIQnDIyhuPjyZi31HDSpONQR5tnslVUVeewrg+Sg/LLI5HmKerMn3vwY8f7xmewwbK6bwfAhE6Pvm40f7wKoWIv7CtZEOP2cPwpZOJWQaUbqeUc064wYPm3drtJo8tVqnu+9Rszq9vQ+m7FDc/fv0DsWd3867LCkFwiuVz+Tf5+QfJ6cnJy5Nqp25d5ga8/Z9Ybgjfvu+Pmg+uZ/3TG7f6r9jfvtVC0RiIBx9rCisj6Y0tbw/y1fdL+y+l53GSIgSUrNZZ5tqIqCylep08+XSHGvrhM/UGXmm8PRKZ4Nk7HfHTQv1j6e63PKo7lVkj8NFeLiTdcJPyeVweD08g1L+lBTOkigtOwDYzdWAGDQeewkCiMV4jMdVxmNyfk7C8RjnOB4bNrbTDf4CEAHDYg==
|
||||
45
third_party/gpus/rocm_configure.bzl
vendored
45
third_party/gpus/rocm_configure.bzl
vendored
@@ -365,40 +365,17 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_
|
||||
libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path), True))
|
||||
return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
|
||||
|
||||
def _exec_find_rocm_config(repository_ctx, script_path):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('\"%s\" script.py');" % (python_bin)
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
def find_rocm_config(repository_ctx, script_path):
|
||||
def find_rocm_config(repository_ctx):
|
||||
"""Returns ROCm config dictionary from running find_rocm_config.py"""
|
||||
exec_result = _exec_find_rocm_config(repository_ctx, script_path)
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config])
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))
|
||||
|
||||
# Parse the dict from stdout.
|
||||
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||
|
||||
def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||
def _get_rocm_config(repository_ctx, bash_bin):
|
||||
"""Detects and returns information about the ROCm installation on the system.
|
||||
|
||||
Args:
|
||||
@@ -413,7 +390,7 @@ def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||
miopen_version_number: The version of MIOpen on the system.
|
||||
hipruntime_version_number: The version of HIP Runtime on the system.
|
||||
"""
|
||||
config = find_rocm_config(repository_ctx, find_rocm_config_script)
|
||||
config = find_rocm_config(repository_ctx)
|
||||
rocm_toolkit_path = config["rocm_toolkit_path"]
|
||||
rocm_version_number = config["rocm_version_number"]
|
||||
miopen_version_number = config["miopen_version_number"]
|
||||
@@ -565,10 +542,8 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
"rocm:rocm_config.h",
|
||||
]}
|
||||
|
||||
find_rocm_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_rocm_config.py.gz.base64"))
|
||||
|
||||
bash_bin = get_bash_bin(repository_ctx)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin)
|
||||
|
||||
# For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft
|
||||
rocm_version_number = int(rocm_config.rocm_version_number)
|
||||
@@ -849,12 +824,20 @@ remote_rocm_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_rocm_config": attr.label(
|
||||
default = Label("@org_tensorflow//third_party/gpus:find_rocm_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
rocm_configure = repository_rule(
|
||||
implementation = _rocm_autoconf_impl,
|
||||
environ = _ENVIRONS + [_TF_ROCM_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_rocm_config": attr.label(
|
||||
default = Label("@org_tensorflow//third_party/gpus:find_rocm_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local ROCm toolchain.
|
||||
|
||||
|
||||
18
third_party/nccl/nccl_configure.bzl
vendored
18
third_party/nccl/nccl_configure.bzl
vendored
@@ -90,17 +90,11 @@ def _label(file):
|
||||
return Label("//third_party/nccl:{}".format(file))
|
||||
|
||||
def _create_local_nccl_repository(repository_ctx):
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
|
||||
if nccl_version:
|
||||
nccl_version = nccl_version.split(".")[0]
|
||||
|
||||
cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])
|
||||
cuda_config = find_cuda_config(repository_ctx, ["cuda"])
|
||||
cuda_version = cuda_config["cuda_version"].split(".")
|
||||
|
||||
if nccl_version == "":
|
||||
@@ -120,7 +114,7 @@ def _create_local_nccl_repository(repository_ctx):
|
||||
)
|
||||
else:
|
||||
# Create target for locally installed NCCL.
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["nccl"])
|
||||
config = find_cuda_config(repository_ctx, ["nccl"])
|
||||
config_wrap = {
|
||||
"%{nccl_version}": config["nccl_version"],
|
||||
"%{nccl_header_dir}": config["nccl_include_dir"],
|
||||
@@ -170,12 +164,20 @@ remote_nccl_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
nccl_configure = repository_rule(
|
||||
implementation = _nccl_autoconf_impl,
|
||||
environ = _ENVIRONS,
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the NCCL configuration.
|
||||
|
||||
|
||||
11
third_party/tensorrt/tensorrt_configure.bzl
vendored
11
third_party/tensorrt/tensorrt_configure.bzl
vendored
@@ -148,11 +148,6 @@ def _get_tensorrt_full_version(repository_ctx):
|
||||
return get_host_environ(repository_ctx, _TF_TENSORRT_VERSION, None)
|
||||
|
||||
def _create_local_tensorrt_repository(repository_ctx):
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
tpl_paths = {
|
||||
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
||||
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
||||
@@ -161,7 +156,7 @@ def _create_local_tensorrt_repository(repository_ctx):
|
||||
"plugin.BUILD": _tpl_path(repository_ctx, "plugin.BUILD"),
|
||||
}
|
||||
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda", "tensorrt"])
|
||||
config = find_cuda_config(repository_ctx, ["cuda", "tensorrt"])
|
||||
cuda_version = config["cuda_version"]
|
||||
cuda_library_path = config["cuda_library_dir"] + "/"
|
||||
trt_version = config["tensorrt_version"]
|
||||
@@ -318,12 +313,16 @@ remote_tensorrt_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(default = "@org_tensorflow//third_party/gpus:find_cuda_config.py"),
|
||||
},
|
||||
)
|
||||
|
||||
tensorrt_configure = repository_rule(
|
||||
implementation = _tensorrt_configure_impl,
|
||||
environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(default = "@org_tensorflow//third_party/gpus:find_cuda_config.py"),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local CUDA toolchain.
|
||||
|
||||
|
||||
1
third_party/xla/opensource_only.files
vendored
1
third_party/xla/opensource_only.files
vendored
@@ -48,7 +48,6 @@ third_party/gpus/cuda/build_defs.bzl.tpl:
|
||||
third_party/gpus/cuda/cuda_config.h.tpl:
|
||||
third_party/gpus/cuda/cuda_config.py.tpl:
|
||||
third_party/gpus/cuda_configure.bzl:
|
||||
third_party/gpus/find_cuda_config.py.gz.base64:
|
||||
third_party/gpus/find_cuda_config:.py
|
||||
third_party/gpus/rocm/BUILD.tpl:
|
||||
third_party/gpus/rocm/BUILD:
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of 'find_cuda_config.py'.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_cuda_config.py', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_cuda_config.py.gz.base64', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of 'find_rocm_config.py'.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_rocm_config.py', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_rocm_config.py.gz.base64', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -641,42 +641,19 @@ def _cudart_static_linkopt(cpu_value):
|
||||
"""Returns additional platform-specific linkopts for cudart."""
|
||||
return "" if cpu_value == "Darwin" else "\"-lrt\","
|
||||
|
||||
def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('\"%s\" script.py %s');" % (python_bin, " ".join(cuda_libraries))
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
|
||||
# and nccl_configure.bzl.
|
||||
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
def find_cuda_config(repository_ctx, cuda_libraries):
|
||||
"""Returns CUDA config dictionary from running find_cuda_config.py"""
|
||||
exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))
|
||||
|
||||
# Parse the dict from stdout.
|
||||
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||
|
||||
def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
||||
def _get_cuda_config(repository_ctx):
|
||||
"""Detects and returns information about the CUDA installation on the system.
|
||||
|
||||
Args:
|
||||
@@ -692,7 +669,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
||||
compute_capabilities: A list of the system's CUDA compute capabilities.
|
||||
cpu_value: The name of the host operating system.
|
||||
"""
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_script, ["cuda", "cudnn"])
|
||||
config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
|
||||
cpu_value = get_cpu_value(repository_ctx)
|
||||
toolkit_path = config["cuda_toolkit_path"]
|
||||
|
||||
@@ -1008,9 +985,8 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
"cuda:cuda_config.py",
|
||||
]}
|
||||
tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
|
||||
find_cuda_config_script = repository_ctx.path(Label("@local_xla//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)
|
||||
cuda_config = _get_cuda_config(repository_ctx)
|
||||
|
||||
cuda_include_path = cuda_config.config["cuda_include_dir"]
|
||||
cublas_include_path = cuda_config.config["cublas_include_dir"]
|
||||
@@ -1484,12 +1460,20 @@ remote_cuda_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cuda_configure = repository_rule(
|
||||
implementation = _cuda_autoconf_impl,
|
||||
environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local CUDA toolchain.
|
||||
|
||||
|
||||
@@ -53,12 +53,6 @@ tf_<library>_header_dir: ...
|
||||
tf_<library>_library_dir: ...
|
||||
"""
|
||||
|
||||
# You can use the following command to regenerate the base64 version of this
|
||||
# script:
|
||||
# cat third_party/tensorflow/third_party/gpus/find_cuda_config.oss.py |
|
||||
# pigz -z | base64 -w0 >
|
||||
# third_party/tensorflow/third_party/gpus/find_cuda_config.py.gz.base64.oss
|
||||
|
||||
import io
|
||||
import os
|
||||
import glob
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
|
||||
eJztW21v2zgS/q5fQSgoam9dJSlugUMOOcCbZlHftUlhZ7tYtIVB27TNrSz6SCppUPS/7wxJyRIt2YqjvX6JgTa2NPNwOC8Ph7Z4RC7E+l7yxVKTVyevTsjNkpEblighf43FHemneimkikg/jskQxRQZMsXkLZtFwVFwRN7yKYizGUmTGZNEg35/Tafwx93pkQ9MKi4S8io6IR0UCN2tsPsvQLgXKVnRe5IITVLFAIIrMucxI+zrlK014QmZitU65jSZMnLH9dIM40DADPKHgxATTUGagvwaPs2LcoRqYzC+llqvz46P7+7uImqMjYRcHMdWUB2/HVxcXo0uX4LBRuW3JGZKEcn+l3IJU53cE7oGe6Z0AlbG9I4ISehCMrinBdp7J7nmyaJHlJjrOyoZoMy40pJPUl1yVmYdzLkoAO6iCQn7IzIYheSX/mgw6gHG74ObN9e/3ZDf+8Nh/+pmcDki10NycX31enAzuL6CT7+S/tUf5L+Dq9c9wsBVMAz7upZoPxjJ0Y0mdGTEWMmAubAGqTWb8jmfwrySRUoXjCzELZMJTIesmVxxhcFUYN4MUGK+4ppqc2VrUjjMeauvIAzD95InmIbXFysYfiKpvEdjyJJRHH8GIZpqITkzNpJbm32QUgIMRMeaWd4rzVZREGDCq6nkkGeKUQm5oIwr6uAxMVUZpQcRR69pFcDFFabAjGl0VWJczGVmhAFaW/tRfyqSOV+k0jgQ9ZSeiVRHxqo1xUQXGThmiIsNptlSinSxxCRhyS2XIlmxRJNbKrlJyg7Y/278vn/zphsFgzkUF9yL+cwbkju39Ox0rB8yA405TEoTasl0Kk3YCVwCB03FjJX9p+kXZueVxeC+YDEUDd7K7aq0OyrixUJ8scGwvrfxzGJiA2GqfUnl7CXaM4MYaqj7QKWTYh7MpViRCVXOqY4YNrbl9kYEfLUxEdwDrBTkgsZNUJbHYq2PpZiuQhRJkf4o2KIh7nOaxjifOGUBZmsQQM0JCeET2TuhsnfAC+4dZFIQBNOYQp1emBBdopc7l4YCIVTds4CA9QrFYBQyXjA9dsON0ZQxTq1jxGysimYWlYwwJJWmcVxQAltfZ1lrPZ2F3IVtRZySTR3UjHCCMJ6PSM5r7QNxPidh7uMQQyhU5LIBLakG3Mh8LGh/BvmjOnljomQ0NkNvCXU3rtq6V3QYLj5Ccc3GrozHSbqaMNlZ0T+F7BHwGP4Btemy6P/TE3iRn4gRIy/wM35CafhkxIvD2CTPxuigFT2S0BXLwjN08QA2X4PBQOkABOomYODFiyJHudBg9cTIQ3Cfi0isWYYcyhBWiATKBpj9PEz1/OU/w671/wptAx9KFpm3HRke2YHIM0VedD7NXnRD8sxY1zP4XaMHkTXyFoXYMgAcczFaAGWtO6ddd9N5CbimY+S6AbqOcqjSYgU8z0YOv3035WZXtU/J8wgmB8gd4yPywsGWXyExfQUugMgz4AXTWXz7DuvmpyTMIEw61ENwoEM7F8sIyIJgN1sw9C4UDY3DbhZLMHZmM95ybUXineFUUdjcKmeV6jgRHBvyTnHgxVzGst45+ZibGvJkGqczZgodWFDad5lGtIQwQ4mYAv45elWtVyN9mi1YQhq9z+b/oi1gypVImLmOuTZHF1dbnSdF4aJfrH8KnuXnPMsTSKrsNvsKLZLqbCF0z/JpecZtieaCEyCGL1nWYkyLohnedjoWUiQ0XsqWdjMg4tg+A9wA+Zrll5dY1Q7qmgogji/OK1mhqNBzNPrhcjiC5m/8rv+f62Fo3WZZ5gCMwVWOsXYs8FAMIOaLNw7DFXkFU+JcKy4j61QVRQV7ZyuPLTNQ/ObcHFYghGcHkDngfQ+8VcKOVir2JV/LNNF8xRqUfEHYs+DAugdA/Le/4FGwmdJT3e+s+zeD92RoY/g3lP8RbMJhmSms5dDg2sb7K246oc3MlkT0PBiTVV6P3C0F2G27T4sFSJ2CiCUJaEOwGXlBSnew9LtODXoNyfit3SJmDYwlFdsavpzSeJpCLwjWMQkcoLjZnUPzL0zTbHHuKKy6UHe4XAoDNkkVNiwKu/A1hfnDvhK6f0gH3KeqhzLg1tweToBbTihz12mpiTO4GKitoi8SUG2RAw3VE0Blh1okoK0xSzS04tjeNaAgJ/i4vqPEQRbxuMQm7rVhIo9UchCnvAvriZkaMdPzd4Nr8F2ZlEzfnDNT3jUXwbvdh5YdDPT+8uqxleejHNZ8eCiPbT9qqqO6ASnVXJEBKlHCs0KwHt2OlMb2dx+TmKpmGxAj2d4eBOGOcXMkExpnF17ubU42mhapIcITH+zsVJzz2u1SHkgVEKlf3vZHLWxTyjAH71RKMC1sVioLqHa/UihMb8tSgdMuYZRH9xlDQiU1Ywwj2c7mxaEde6h+lZd4wih4ADWKJ0/0sIceZOE3kt30sJMPytnQuBSH/avXWSmWi7AM6GqnkKJe7VSkJLTaNbm6t88uD+UXynyum9UJCra3sAKa+9NoNUXxsvbTEnpgjYDvfugK6qWS0Xz4AuqjoObD108PxWg+bvmsqpLa1XNTex4BbKO0vnZuxva//WvGCE6wJUawaO7PfkZw4mXtJ0Y4hBGs734oI1gTPtiL7w7jgzLGYWxQwnj/SC6oqY9qLihVnffV29/NBaWx/e5ASzoFhzVqEJxsez2CBdy829khOOEt9Sc2eGB/YN32ozfZN8P+xeWwhW22D1TYaB9tsjZ/3kow+5v4kt4yQl01Z75w7XuRVk4e0yzUVExtv1CqRq9lqMRqvWsoWeA3Dgp/AWn4q6GTba99sICbd42aCKe0BfPUShzYSlj3/ehuwlrx6IaiDHNwT1GCaaGtqCmd2s6iVJZec1GJ1Xp/UbLAJw0R3zbqMDayLZKGAdx8PZ9fasYembb3Jf8+lNNdKPuUn74E3OKgjRcfTDxtko6x4PGkU4I5nHSKMG2QTnXp1ZNOsax90qnCap90ihb4+5rGrLORbW9f4yo+f9foG1CPJvYpP7Uq+/Y5BzJGm6wBsRpdv/3Qyj7HBzr4B0UPqIWfFB/EHX5xetuc/wN3+BY47th6Etp/kp2SGZ/iaQI80CHmthKNIQnDIyhuPjyZi31HDSpONQR5tnslVUVeewrg+Sg/LLI5HmKerMn3vwY8f7xmewwbK6bwfAhE6Pvm40f7wKoWIv7CtZEOP2cPwpZOJWQaUbqeUc064wYPm3drtJo8tVqnu+9Rszq9vQ+m7FDc/fv0DsWd3867LCkFwiuVz+Tf5+QfJ6cnJy5Nqp25d5ga8/Z9Ybgjfvu+Pmg+uZ/3TG7f6r9jfvtVC0RiIBx9rCisj6Y0tbw/y1fdL+y+l53GSIgSUrNZZ5tqIqCylep08+XSHGvrhM/UGXmm8PRKZ4Nk7HfHTQv1j6e63PKo7lVkj8NFeLiTdcJPyeVweD08g1L+lBTOkigtOwDYzdWAGDQeewkCiMV4jMdVxmNyfk7C8RjnOB4bNrbTDf4CEAHDYg==
|
||||
@@ -365,40 +365,17 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_
|
||||
libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path), True))
|
||||
return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
|
||||
|
||||
def _exec_find_rocm_config(repository_ctx, script_path):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('\"%s\" script.py');" % (python_bin)
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
def find_rocm_config(repository_ctx, script_path):
|
||||
def find_rocm_config(repository_ctx):
|
||||
"""Returns ROCm config dictionary from running find_rocm_config.py"""
|
||||
exec_result = _exec_find_rocm_config(repository_ctx, script_path)
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config])
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))
|
||||
|
||||
# Parse the dict from stdout.
|
||||
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||
|
||||
def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||
def _get_rocm_config(repository_ctx, bash_bin):
|
||||
"""Detects and returns information about the ROCm installation on the system.
|
||||
|
||||
Args:
|
||||
@@ -413,7 +390,7 @@ def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||
miopen_version_number: The version of MIOpen on the system.
|
||||
hipruntime_version_number: The version of HIP Runtime on the system.
|
||||
"""
|
||||
config = find_rocm_config(repository_ctx, find_rocm_config_script)
|
||||
config = find_rocm_config(repository_ctx)
|
||||
rocm_toolkit_path = config["rocm_toolkit_path"]
|
||||
rocm_version_number = config["rocm_version_number"]
|
||||
miopen_version_number = config["miopen_version_number"]
|
||||
@@ -565,10 +542,8 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
"rocm:rocm_config.h",
|
||||
]}
|
||||
|
||||
find_rocm_config_script = repository_ctx.path(Label("@local_xla//third_party/gpus:find_rocm_config.py.gz.base64"))
|
||||
|
||||
bash_bin = get_bash_bin(repository_ctx)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin)
|
||||
|
||||
# For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft
|
||||
rocm_version_number = int(rocm_config.rocm_version_number)
|
||||
@@ -849,12 +824,20 @@ remote_rocm_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_rocm_config": attr.label(
|
||||
default = Label("@local_xla//third_party/gpus:find_rocm_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
rocm_configure = repository_rule(
|
||||
implementation = _rocm_autoconf_impl,
|
||||
environ = _ENVIRONS + [_TF_ROCM_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_rocm_config": attr.label(
|
||||
default = Label("@local_xla//third_party/gpus:find_rocm_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local ROCm toolchain.
|
||||
|
||||
|
||||
@@ -90,17 +90,11 @@ def _label(file):
|
||||
return Label("//third_party/nccl:{}".format(file))
|
||||
|
||||
def _create_local_nccl_repository(repository_ctx):
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@local_xla//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
|
||||
if nccl_version:
|
||||
nccl_version = nccl_version.split(".")[0]
|
||||
|
||||
cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])
|
||||
cuda_config = find_cuda_config(repository_ctx, ["cuda"])
|
||||
cuda_version = cuda_config["cuda_version"].split(".")
|
||||
|
||||
if nccl_version == "":
|
||||
@@ -120,7 +114,7 @@ def _create_local_nccl_repository(repository_ctx):
|
||||
)
|
||||
else:
|
||||
# Create target for locally installed NCCL.
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["nccl"])
|
||||
config = find_cuda_config(repository_ctx, ["nccl"])
|
||||
config_wrap = {
|
||||
"%{nccl_version}": config["nccl_version"],
|
||||
"%{nccl_header_dir}": config["nccl_include_dir"],
|
||||
@@ -170,12 +164,20 @@ remote_nccl_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
nccl_configure = repository_rule(
|
||||
implementation = _nccl_autoconf_impl,
|
||||
environ = _ENVIRONS,
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the NCCL configuration.
|
||||
|
||||
|
||||
@@ -148,11 +148,6 @@ def _get_tensorrt_full_version(repository_ctx):
|
||||
return get_host_environ(repository_ctx, _TF_TENSORRT_VERSION, None)
|
||||
|
||||
def _create_local_tensorrt_repository(repository_ctx):
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@local_xla//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
tpl_paths = {
|
||||
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
||||
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
||||
@@ -161,7 +156,7 @@ def _create_local_tensorrt_repository(repository_ctx):
|
||||
"plugin.BUILD": _tpl_path(repository_ctx, "plugin.BUILD"),
|
||||
}
|
||||
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda", "tensorrt"])
|
||||
config = find_cuda_config(repository_ctx, ["cuda", "tensorrt"])
|
||||
cuda_version = config["cuda_version"]
|
||||
cuda_library_path = config["cuda_library_dir"] + "/"
|
||||
trt_version = config["tensorrt_version"]
|
||||
@@ -318,12 +313,16 @@ remote_tensorrt_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(default = "@local_xla//third_party/gpus:find_cuda_config.py"),
|
||||
},
|
||||
)
|
||||
|
||||
tensorrt_configure = repository_rule(
|
||||
implementation = _tensorrt_configure_impl,
|
||||
environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(default = "@local_xla//third_party/gpus:find_cuda_config.py"),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local CUDA toolchain.
|
||||
|
||||
|
||||
@@ -45,7 +45,6 @@ third_party/gpus/cuda/build_defs.bzl.tpl:
|
||||
third_party/gpus/cuda/cuda_config.h.tpl:
|
||||
third_party/gpus/cuda/cuda_config.py.tpl:
|
||||
third_party/gpus/cuda_configure.bzl:
|
||||
third_party/gpus/find_cuda_config.py.gz.base64:
|
||||
third_party/gpus/find_cuda_config:.py
|
||||
third_party/gpus/rocm/BUILD.tpl:
|
||||
third_party/gpus/rocm/BUILD:
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of 'find_cuda_config.py'.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_cuda_config.py', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_cuda_config.py.gz.base64', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of 'find_rocm_config.py'.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_rocm_config.py', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_rocm_config.py.gz.base64', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -641,42 +641,19 @@ def _cudart_static_linkopt(cpu_value):
|
||||
"""Returns additional platform-specific linkopts for cudart."""
|
||||
return "" if cpu_value == "Darwin" else "\"-lrt\","
|
||||
|
||||
def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('\"%s\" script.py %s');" % (python_bin, " ".join(cuda_libraries))
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
|
||||
# and nccl_configure.bzl.
|
||||
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
def find_cuda_config(repository_ctx, cuda_libraries):
|
||||
"""Returns CUDA config dictionary from running find_cuda_config.py"""
|
||||
exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))
|
||||
|
||||
# Parse the dict from stdout.
|
||||
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||
|
||||
def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
||||
def _get_cuda_config(repository_ctx):
|
||||
"""Detects and returns information about the CUDA installation on the system.
|
||||
|
||||
Args:
|
||||
@@ -692,7 +669,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
||||
compute_capabilities: A list of the system's CUDA compute capabilities.
|
||||
cpu_value: The name of the host operating system.
|
||||
"""
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_script, ["cuda", "cudnn"])
|
||||
config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
|
||||
cpu_value = get_cpu_value(repository_ctx)
|
||||
toolkit_path = config["cuda_toolkit_path"]
|
||||
|
||||
@@ -1008,9 +985,8 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
"cuda:cuda_config.py",
|
||||
]}
|
||||
tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
|
||||
find_cuda_config_script = repository_ctx.path(Label("@local_tsl//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)
|
||||
cuda_config = _get_cuda_config(repository_ctx)
|
||||
|
||||
cuda_include_path = cuda_config.config["cuda_include_dir"]
|
||||
cublas_include_path = cuda_config.config["cublas_include_dir"]
|
||||
@@ -1484,12 +1460,20 @@ remote_cuda_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_tsl//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
cuda_configure = repository_rule(
|
||||
implementation = _cuda_autoconf_impl,
|
||||
environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_tsl//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local CUDA toolchain.
|
||||
|
||||
|
||||
@@ -53,12 +53,6 @@ tf_<library>_header_dir: ...
|
||||
tf_<library>_library_dir: ...
|
||||
"""
|
||||
|
||||
# You can use the following command to regenerate the base64 version of this
|
||||
# script:
|
||||
# cat third_party/tensorflow/third_party/gpus/find_cuda_config.oss.py |
|
||||
# pigz -z | base64 -w0 >
|
||||
# third_party/tensorflow/third_party/gpus/find_cuda_config.py.gz.base64.oss
|
||||
|
||||
import io
|
||||
import os
|
||||
import glob
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
|
||||
eJztW21v2zgS/q5fQSgoam9dJSlugUMOOcCbZlHftUlhZ7tYtIVB27TNrSz6SCppUPS/7wxJyRIt2YqjvX6JgTa2NPNwOC8Ph7Z4RC7E+l7yxVKTVyevTsjNkpEblighf43FHemneimkikg/jskQxRQZMsXkLZtFwVFwRN7yKYizGUmTGZNEg35/Tafwx93pkQ9MKi4S8io6IR0UCN2tsPsvQLgXKVnRe5IITVLFAIIrMucxI+zrlK014QmZitU65jSZMnLH9dIM40DADPKHgxATTUGagvwaPs2LcoRqYzC+llqvz46P7+7uImqMjYRcHMdWUB2/HVxcXo0uX4LBRuW3JGZKEcn+l3IJU53cE7oGe6Z0AlbG9I4ISehCMrinBdp7J7nmyaJHlJjrOyoZoMy40pJPUl1yVmYdzLkoAO6iCQn7IzIYheSX/mgw6gHG74ObN9e/3ZDf+8Nh/+pmcDki10NycX31enAzuL6CT7+S/tUf5L+Dq9c9wsBVMAz7upZoPxjJ0Y0mdGTEWMmAubAGqTWb8jmfwrySRUoXjCzELZMJTIesmVxxhcFUYN4MUGK+4ppqc2VrUjjMeauvIAzD95InmIbXFysYfiKpvEdjyJJRHH8GIZpqITkzNpJbm32QUgIMRMeaWd4rzVZREGDCq6nkkGeKUQm5oIwr6uAxMVUZpQcRR69pFcDFFabAjGl0VWJczGVmhAFaW/tRfyqSOV+k0jgQ9ZSeiVRHxqo1xUQXGThmiIsNptlSinSxxCRhyS2XIlmxRJNbKrlJyg7Y/278vn/zphsFgzkUF9yL+cwbkju39Ox0rB8yA405TEoTasl0Kk3YCVwCB03FjJX9p+kXZueVxeC+YDEUDd7K7aq0OyrixUJ8scGwvrfxzGJiA2GqfUnl7CXaM4MYaqj7QKWTYh7MpViRCVXOqY4YNrbl9kYEfLUxEdwDrBTkgsZNUJbHYq2PpZiuQhRJkf4o2KIh7nOaxjifOGUBZmsQQM0JCeET2TuhsnfAC+4dZFIQBNOYQp1emBBdopc7l4YCIVTds4CA9QrFYBQyXjA9dsON0ZQxTq1jxGysimYWlYwwJJWmcVxQAltfZ1lrPZ2F3IVtRZySTR3UjHCCMJ6PSM5r7QNxPidh7uMQQyhU5LIBLakG3Mh8LGh/BvmjOnljomQ0NkNvCXU3rtq6V3QYLj5Ccc3GrozHSbqaMNlZ0T+F7BHwGP4Btemy6P/TE3iRn4gRIy/wM35CafhkxIvD2CTPxuigFT2S0BXLwjN08QA2X4PBQOkABOomYODFiyJHudBg9cTIQ3Cfi0isWYYcyhBWiATKBpj9PEz1/OU/w671/wptAx9KFpm3HRke2YHIM0VedD7NXnRD8sxY1zP4XaMHkTXyFoXYMgAcczFaAGWtO6ddd9N5CbimY+S6AbqOcqjSYgU8z0YOv3035WZXtU/J8wgmB8gd4yPywsGWXyExfQUugMgz4AXTWXz7DuvmpyTMIEw61ENwoEM7F8sIyIJgN1sw9C4UDY3DbhZLMHZmM95ybUXineFUUdjcKmeV6jgRHBvyTnHgxVzGst45+ZibGvJkGqczZgodWFDad5lGtIQwQ4mYAv45elWtVyN9mi1YQhq9z+b/oi1gypVImLmOuTZHF1dbnSdF4aJfrH8KnuXnPMsTSKrsNvsKLZLqbCF0z/JpecZtieaCEyCGL1nWYkyLohnedjoWUiQ0XsqWdjMg4tg+A9wA+Zrll5dY1Q7qmgogji/OK1mhqNBzNPrhcjiC5m/8rv+f62Fo3WZZ5gCMwVWOsXYs8FAMIOaLNw7DFXkFU+JcKy4j61QVRQV7ZyuPLTNQ/ObcHFYghGcHkDngfQ+8VcKOVir2JV/LNNF8xRqUfEHYs+DAugdA/Le/4FGwmdJT3e+s+zeD92RoY/g3lP8RbMJhmSms5dDg2sb7K246oc3MlkT0PBiTVV6P3C0F2G27T4sFSJ2CiCUJaEOwGXlBSnew9LtODXoNyfit3SJmDYwlFdsavpzSeJpCLwjWMQkcoLjZnUPzL0zTbHHuKKy6UHe4XAoDNkkVNiwKu/A1hfnDvhK6f0gH3KeqhzLg1tweToBbTihz12mpiTO4GKitoi8SUG2RAw3VE0Blh1okoK0xSzS04tjeNaAgJ/i4vqPEQRbxuMQm7rVhIo9UchCnvAvriZkaMdPzd4Nr8F2ZlEzfnDNT3jUXwbvdh5YdDPT+8uqxleejHNZ8eCiPbT9qqqO6ASnVXJEBKlHCs0KwHt2OlMb2dx+TmKpmGxAj2d4eBOGOcXMkExpnF17ubU42mhapIcITH+zsVJzz2u1SHkgVEKlf3vZHLWxTyjAH71RKMC1sVioLqHa/UihMb8tSgdMuYZRH9xlDQiU1Ywwj2c7mxaEde6h+lZd4wih4ADWKJ0/0sIceZOE3kt30sJMPytnQuBSH/avXWSmWi7AM6GqnkKJe7VSkJLTaNbm6t88uD+UXynyum9UJCra3sAKa+9NoNUXxsvbTEnpgjYDvfugK6qWS0Xz4AuqjoObD108PxWg+bvmsqpLa1XNTex4BbKO0vnZuxva//WvGCE6wJUawaO7PfkZw4mXtJ0Y4hBGs734oI1gTPtiL7w7jgzLGYWxQwnj/SC6oqY9qLihVnffV29/NBaWx/e5ASzoFhzVqEJxsez2CBdy829khOOEt9Sc2eGB/YN32ozfZN8P+xeWwhW22D1TYaB9tsjZ/3kow+5v4kt4yQl01Z75w7XuRVk4e0yzUVExtv1CqRq9lqMRqvWsoWeA3Dgp/AWn4q6GTba99sICbd42aCKe0BfPUShzYSlj3/ehuwlrx6IaiDHNwT1GCaaGtqCmd2s6iVJZec1GJ1Xp/UbLAJw0R3zbqMDayLZKGAdx8PZ9fasYembb3Jf8+lNNdKPuUn74E3OKgjRcfTDxtko6x4PGkU4I5nHSKMG2QTnXp1ZNOsax90qnCap90ihb4+5rGrLORbW9f4yo+f9foG1CPJvYpP7Uq+/Y5BzJGm6wBsRpdv/3Qyj7HBzr4B0UPqIWfFB/EHX5xetuc/wN3+BY47th6Etp/kp2SGZ/iaQI80CHmthKNIQnDIyhuPjyZi31HDSpONQR5tnslVUVeewrg+Sg/LLI5HmKerMn3vwY8f7xmewwbK6bwfAhE6Pvm40f7wKoWIv7CtZEOP2cPwpZOJWQaUbqeUc064wYPm3drtJo8tVqnu+9Rszq9vQ+m7FDc/fv0DsWd3867LCkFwiuVz+Tf5+QfJ6cnJy5Nqp25d5ga8/Z9Ybgjfvu+Pmg+uZ/3TG7f6r9jfvtVC0RiIBx9rCisj6Y0tbw/y1fdL+y+l53GSIgSUrNZZ5tqIqCylep08+XSHGvrhM/UGXmm8PRKZ4Nk7HfHTQv1j6e63PKo7lVkj8NFeLiTdcJPyeVweD08g1L+lBTOkigtOwDYzdWAGDQeewkCiMV4jMdVxmNyfk7C8RjnOB4bNrbTDf4CEAHDYg==
|
||||
@@ -365,40 +365,17 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_
|
||||
libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path), True))
|
||||
return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
|
||||
|
||||
def _exec_find_rocm_config(repository_ctx, script_path):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('\"%s\" script.py');" % (python_bin)
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
def find_rocm_config(repository_ctx, script_path):
|
||||
def find_rocm_config(repository_ctx):
|
||||
"""Returns ROCm config dictionary from running find_rocm_config.py"""
|
||||
exec_result = _exec_find_rocm_config(repository_ctx, script_path)
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config])
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))
|
||||
|
||||
# Parse the dict from stdout.
|
||||
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||
|
||||
def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||
def _get_rocm_config(repository_ctx, bash_bin):
|
||||
"""Detects and returns information about the ROCm installation on the system.
|
||||
|
||||
Args:
|
||||
@@ -413,7 +390,7 @@ def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||
miopen_version_number: The version of MIOpen on the system.
|
||||
hipruntime_version_number: The version of HIP Runtime on the system.
|
||||
"""
|
||||
config = find_rocm_config(repository_ctx, find_rocm_config_script)
|
||||
config = find_rocm_config(repository_ctx)
|
||||
rocm_toolkit_path = config["rocm_toolkit_path"]
|
||||
rocm_version_number = config["rocm_version_number"]
|
||||
miopen_version_number = config["miopen_version_number"]
|
||||
@@ -565,10 +542,8 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
"rocm:rocm_config.h",
|
||||
]}
|
||||
|
||||
find_rocm_config_script = repository_ctx.path(Label("@local_tsl//third_party/gpus:find_rocm_config.py.gz.base64"))
|
||||
|
||||
bash_bin = get_bash_bin(repository_ctx)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin)
|
||||
|
||||
# For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft
|
||||
rocm_version_number = int(rocm_config.rocm_version_number)
|
||||
@@ -849,12 +824,20 @@ remote_rocm_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_rocm_config": attr.label(
|
||||
default = Label("@local_tsl//third_party/gpus:find_rocm_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
rocm_configure = repository_rule(
|
||||
implementation = _rocm_autoconf_impl,
|
||||
environ = _ENVIRONS + [_TF_ROCM_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_rocm_config": attr.label(
|
||||
default = Label("@local_tsl//third_party/gpus:find_rocm_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local ROCm toolchain.
|
||||
|
||||
|
||||
@@ -90,17 +90,11 @@ def _label(file):
|
||||
return Label("//third_party/nccl:{}".format(file))
|
||||
|
||||
def _create_local_nccl_repository(repository_ctx):
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@local_tsl//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
|
||||
if nccl_version:
|
||||
nccl_version = nccl_version.split(".")[0]
|
||||
|
||||
cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])
|
||||
cuda_config = find_cuda_config(repository_ctx, ["cuda"])
|
||||
cuda_version = cuda_config["cuda_version"].split(".")
|
||||
|
||||
if nccl_version == "":
|
||||
@@ -120,7 +114,7 @@ def _create_local_nccl_repository(repository_ctx):
|
||||
)
|
||||
else:
|
||||
# Create target for locally installed NCCL.
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["nccl"])
|
||||
config = find_cuda_config(repository_ctx, ["nccl"])
|
||||
config_wrap = {
|
||||
"%{nccl_version}": config["nccl_version"],
|
||||
"%{nccl_header_dir}": config["nccl_include_dir"],
|
||||
@@ -170,12 +164,20 @@ remote_nccl_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_tsl//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
nccl_configure = repository_rule(
|
||||
implementation = _nccl_autoconf_impl,
|
||||
environ = _ENVIRONS,
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(
|
||||
default = Label("@local_tsl//third_party/gpus:find_cuda_config.py"),
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the NCCL configuration.
|
||||
|
||||
|
||||
@@ -148,11 +148,6 @@ def _get_tensorrt_full_version(repository_ctx):
|
||||
return get_host_environ(repository_ctx, _TF_TENSORRT_VERSION, None)
|
||||
|
||||
def _create_local_tensorrt_repository(repository_ctx):
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@local_tsl//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
tpl_paths = {
|
||||
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
||||
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
||||
@@ -161,7 +156,7 @@ def _create_local_tensorrt_repository(repository_ctx):
|
||||
"plugin.BUILD": _tpl_path(repository_ctx, "plugin.BUILD"),
|
||||
}
|
||||
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda", "tensorrt"])
|
||||
config = find_cuda_config(repository_ctx, ["cuda", "tensorrt"])
|
||||
cuda_version = config["cuda_version"]
|
||||
cuda_library_path = config["cuda_library_dir"] + "/"
|
||||
trt_version = config["tensorrt_version"]
|
||||
@@ -318,12 +313,16 @@ remote_tensorrt_configure = repository_rule(
|
||||
remotable = True,
|
||||
attrs = {
|
||||
"environ": attr.string_dict(),
|
||||
"_find_cuda_config": attr.label(default = "@local_tsl//third_party/gpus:find_cuda_config.py"),
|
||||
},
|
||||
)
|
||||
|
||||
tensorrt_configure = repository_rule(
|
||||
implementation = _tensorrt_configure_impl,
|
||||
environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
|
||||
attrs = {
|
||||
"_find_cuda_config": attr.label(default = "@local_tsl//third_party/gpus:find_cuda_config.py"),
|
||||
},
|
||||
)
|
||||
"""Detects and configures the local CUDA toolchain.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user