mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
[xla:cpu:benchmarks] Add scripts to run Gemma2 Keras model.
PiperOrigin-RevId: 713674776
This commit is contained in:
committed by
TensorFlower Gardener
parent
7b49ba401a
commit
d7a41d2c45
32
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md
vendored
Normal file
32
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
# Gemma2 2B Keras model
|
||||
|
||||
Scripts to run the Gemma2 2B Keras model on CPU.
|
||||
|
||||
Model link: https://www.kaggle.com/models/google/gemma-2/keras
|
||||
|
||||
Instructions:
|
||||
|
||||
*   Set up your Kaggle API key by following
    [these instructions](https://www.kaggle.com/docs/api#authentication).
*   `$ bash setup.sh`
    *   This only needs to be run once. It will create a virtual environment at
        a location read from `config.sh` and install the necessary dependencies.
    *   Change the `VENV_BASE` variable in `config.sh` before running `setup.sh`
        if you want to use a different location.
*   `$ KERAS_BACKEND=jax bash run.sh`
    *   This script activates the right virtual environment and runs the
        benchmark in `benchmark.py`.
    *   Set `KERAS_BACKEND=tensorflow` or `torch` to run with the TensorFlow or
        PyTorch backend.
*   (Optional) Delete the virtual environment: `$ bash cleanup.sh`
|
||||
|
||||
To try other model variations with different numbers of parameters, modify the
|
||||
following line in `benchmark.py`:
|
||||
|
||||
```
|
||||
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
|
||||
```
|
||||
|
||||
Replace "gemma2_2b_en" with other preset names, e.g.,
"gemma2_instruct_2b_en", "gemma2_9b_en", etc. See the full preset list
[here](https://github.com/keras-team/keras-hub/blob/86607dc921999e33f5b8a0bcf81ec987b60c9dee/keras_hub/src/models/gemma/gemma_presets.py#L5-L200).
|
||||
107
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py
vendored
Normal file
107
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright 2025 The OpenXLA Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Benchmark Gemma2-2B Keras performance."""
|
||||
|
||||
import time
|
||||
import keras_nlp
|
||||
import numpy as np
|
||||
|
||||
# Value passed as `max_len` for the multi-token benchmark run in `main()`.
_NUM_OUTPUT_TOKENS = 30
# Prompt sent to the model for every warm-up and timed inference.
_QUERY = "What is JAX in 3 bullet points?"
# When True, per-run timings and the generated text are printed as well.
_VERBOSE = True
|
||||
|
||||
|
||||
def compute_stats(array):
  """Reports mean and ± range for the given array.

  The range computation follows benchstat's: samples outside
  [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR] are dropped as outliers before the mean and
  the spread are computed.

  Args:
    array: A non-empty, one-dimensional sequence of numeric samples.

  Returns:
    A `(mean, diff)` tuple. `mean` is the mean of the samples that survive
    outlier removal; `diff` is the largest distance between any surviving
    sample and that mean, expressed as a percentage of the mean. `diff` is 0.0
    when all surviving samples are identical.

  Raises:
    ValueError: If `array` is empty.
  """
  samples = np.asarray(array, dtype=float)
  if samples.size == 0:
    raise ValueError("compute_stats() requires at least one sample.")

  q1, q3 = np.percentile(samples, [25, 75])
  iqr = q3 - q1
  low = q1 - 1.5 * iqr
  high = q3 + 1.5 * iqr

  # Remove outliers.
  kept = samples[(samples >= low) & (samples <= high)]

  mean = np.mean(kept)
  max_diff = max(np.max(kept) - mean, mean - np.min(kept))

  # Identical samples have no spread. Short-circuiting also avoids 0 / 0 (NaN)
  # when every sample is zero, which callers could not format with "%d".
  if max_diff == 0:
    return (mean, 0.0)

  return (mean, max_diff / mean * 100.0)
|
||||
|
||||
|
||||
def run(gemma_lm, max_len, num_iters=5):
  """Benchmarks inferences with at most `max_len` output tokens.

  One untimed-for-statistics warm-up inference is followed by `num_iters` timed
  inferences of the same query.

  Args:
    gemma_lm: The Gemma2 Keras model. Only its `generate` method is used.
    max_len: The maximum number of output tokens per one inference.
    num_iters: The number of timed inferences to run after the warm-up.

  Returns:
    A `(mean, diff, num_actual_output_tokens)` tuple: the mean latency in
    milliseconds and its ± %diff as computed by `compute_stats`, and the number
    of output tokens per inference. The token count is approximated by the
    number of space-separated pieces in the generated text.
  """
  # NOTE(review): `max_length` is forwarded as `max_len + 1`; confirm against
  # the `generate` API whether it bounds only the newly generated tokens or
  # the whole sequence including the prompt.

  # Warm up. `perf_counter` is monotonic, so system clock adjustments cannot
  # distort the measured intervals.
  start = time.perf_counter()
  output = gemma_lm.generate(_QUERY, max_length=max_len + 1)
  warmup_time = (time.perf_counter() - start) * 1000
  # Counted after the clock is stopped so that only `generate` is timed.
  num_actual_output_tokens = len(output.split(" "))

  if _VERBOSE:
    print("=== Max len: %d ===" % max_len)
    print("Warmup: %lf ms" % warmup_time)
    print("Output:\n%s\n" % output)

  times = []
  for i in range(1, num_iters + 1):
    start = time.perf_counter()
    output = gemma_lm.generate(_QUERY, max_length=max_len + 1)
    elapsed_time = (time.perf_counter() - start) * 1000
    # Every run must produce as much output as the warm-up did; otherwise the
    # timings are not comparable. Checked outside the timed region.
    assert num_actual_output_tokens == len(output.split(" ")), (
        "Output length changed between inferences."
    )
    times.append(elapsed_time)
    if _VERBOSE:
      print("%d: %lf ms" % (i, elapsed_time))

  mean, diff = compute_stats(times)
  if _VERBOSE:
    print("Mean: %lf ± %d%% ms\n" % (mean, diff))

  return (mean, diff, num_actual_output_tokens)
|
||||
|
||||
|
||||
def main():
  """Loads the Gemma2 2B preset and reports TTFT and TPOT on stdout.

  TTFT (time to first token) is the mean latency of a run limited to one
  output token. TPOT (time per output token) is the extra latency of the
  `_NUM_OUTPUT_TOKENS` run over the one-token run, divided by the number of
  additional tokens generated.
  """
  if _VERBOSE:
    print("Query: %s" % _QUERY)

  gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
  mean_1, diff_1, _ = run(gemma_lm, 1)
  mean_n, diff_n, num_output_tokens = run(gemma_lm, _NUM_OUTPUT_TOKENS)

  # Use `%` formatting: passing the count as a second argument to print()
  # would emit the literal "%d" followed by the number.
  print("Generated %d tokens" % num_output_tokens)
  print("TTFT: %lf ± %d%% ms" % (mean_1, diff_1))
  if num_output_tokens > 1:
    tpot = (mean_n - mean_1) / (num_output_tokens - 1)
    print("TPOT: %lf ± %d%% ms" % (tpot, diff_n))
  else:
    # With a single output token there is no per-additional-token cost to
    # compute, and the formula above would divide by zero.
    print("TPOT: n/a (fewer than two output tokens were generated)")


if __name__ == "__main__":
  main()
|
||||
22
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh
vendored
Normal file
22
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2025 The OpenXLA Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# Echo every command and stop at the first failure.
set -x
set -e

# Locate config.sh next to this script so it works from any directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config.sh"

# Delete the virtual environment. `:?` aborts instead of running `rm -rf` with
# an empty path if config.sh did not define GEMMA2_VENV.
rm -rf "${GEMMA2_VENV:?GEMMA2_VENV is not set}"
|
||||
21
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh
vendored
Normal file
21
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2025 The OpenXLA Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# This file is sourced by setup.sh, run.sh and cleanup.sh; the options below
# therefore apply to the sourcing shell as well.
set -x
set -e

# Directory under which the virtual environment is created. Change this before
# running setup.sh to use a different location.
export VENV_BASE=~/venv
# Full path of the virtual environment used for the Gemma2 Keras benchmark.
export GEMMA2_VENV="${VENV_BASE}/gemma2-keras"
|
||||
5
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt
vendored
Normal file
5
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
keras==3.8.0
|
||||
keras_nlp==0.18.1
|
||||
tensorflow==2.18.0
|
||||
jax==0.4.38
|
||||
torch==2.5.1
|
||||
23
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh
vendored
Normal file
23
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2025 The OpenXLA Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# Echo every command and stop at the first failure.
set -x
set -e

# Resolve sibling files relative to this script so it works from any directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config.sh"

# Activate the virtual environment created by setup.sh.
source "${GEMMA2_VENV}/bin/activate"

# KERAS_BACKEND, if set by the caller, is inherited by the Python process.
python "${SCRIPT_DIR}/benchmark.py"
|
||||
25
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh
vendored
Normal file
25
third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2025 The OpenXLA Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# Echo every command and stop at the first failure.
set -x
set -e

# Resolve sibling files relative to this script so it works from any directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config.sh"

# Create the virtual environment at the location defined in config.sh.
mkdir -p "${VENV_BASE}"
python3 -m venv "${GEMMA2_VENV}"
source "${GEMMA2_VENV}/bin/activate"

# Install the pinned dependencies into the activated environment.
pip install -r "${SCRIPT_DIR}/requirements.txt"
|
||||
Reference in New Issue
Block a user