[xla:cpu:benchmarks] Add scripts to run Gemma2 Keras model.

PiperOrigin-RevId: 713674776
This commit is contained in:
Penporn Koanantakool
2025-01-09 08:04:14 -08:00
committed by TensorFlower Gardener
parent 7b49ba401a
commit d7a41d2c45
7 changed files with 235 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
# Gemma2 2B Keras model
Scripts to run Gemma2 2B Keras model on CPU.
Model link: https://www.kaggle.com/models/google/gemma-2/keras
Instructions:
* Set up your Kaggle API key by following
  [these instructions](https://www.kaggle.com/docs/api#authentication).
* `$ bash setup.sh`
    * This only needs to be run once. It will create a virtual environment at
      a location read from `config.sh` and install the necessary dependencies.
    * Change the `VENV_BASE` variable in `config.sh` before running `setup.sh`
      if you want to use a different location.
* `$ KERAS_BACKEND=jax bash run.sh`
    * This script activates the right virtual environment and runs the
      benchmark in `benchmark.py`.
    * Set `KERAS_BACKEND=tensorflow` or `torch` to run with the TensorFlow or
      PyTorch backend.
* (Optional) Delete the virtual environment: `$ bash cleanup.sh`
To try other model variations with different numbers of parameters, modify the
following line in `benchmark.py`:
```
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
```
Replace "gemma2_2b_en" with other preset names, e.g.,
"gemma2_instruct_2b_en", "gemma2_9b_en", etc. See the full preset list
[here](https://github.com/keras-team/keras-hub/blob/86607dc921999e33f5b8a0bcf81ec987b60c9dee/keras_hub/src/models/gemma/gemma_presets.py#L5-L200).

View File

@@ -0,0 +1,107 @@
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark Gemma2-2B Keras performance."""
import time
import keras_nlp
import numpy as np
# Value passed to `run` as `max_len` for the multi-token measurement.
_NUM_OUTPUT_TOKENS = 30
# Prompt sent to the model on every `generate` call.
_QUERY = "What is JAX in 3 bullet points?"
# When True, the query, per-run timings and generated text are printed too.
_VERBOSE = True
def compute_stats(array):
  """Summarizes samples as a mean plus a ± percentage spread.

  Samples lying outside the 1.5 * IQR fences are dropped first, following the
  way benchstat computes its range.

  Args:
    array: The numeric samples to summarize.

  Returns:
    A `(mean, diff)` tuple, where `mean` is the mean of the retained samples
    and `diff` is the largest deviation from that mean, expressed as a
    percentage of the mean.
  """
  first_quartile, third_quartile = (np.percentile(array, q) for q in (25, 75))
  fence = 1.5 * (third_quartile - first_quartile)
  lower_bound = first_quartile - fence
  upper_bound = third_quartile + fence

  # Keep only the samples inside the fences (inclusive on both ends).
  inliers = [
      sample for sample in array if lower_bound <= sample <= upper_bound
  ]

  center = np.mean(inliers)
  spread = max(np.max(inliers) - center, center - np.min(inliers))
  return (center, spread / center * 100.0)
def run(gemma_lm, max_len):
  """Benchmarks inferences with at most `max_len` output tokens.

  Performs one warm-up inference followed by five measured inferences of
  `_QUERY`. Only the `generate` call itself is timed, using the monotonic
  `time.perf_counter` clock; the output-size bookkeeping happens after the
  clock is stopped.

  NOTE(review): the token count is `len(output.split(" "))`, i.e. the number of
  space-separated pieces of the returned text. This is a proxy for the number
  of model tokens, not an exact count.

  Args:
    gemma_lm: The Gemma2 Keras model.
    max_len: The maximum number of output tokens per one inference. It is
      passed to `generate` as `max_length=max_len + 1`.

  Returns:
    A `(mean, diff, num_actual_output_tokens)` tuple: the mean latency in
    milliseconds, its ± %diff range, and the token count observed in the
    warm-up run (every measured run is asserted to match it).
  """
  # Warm up.
  start = time.perf_counter()
  output = gemma_lm.generate(_QUERY, max_length=max_len + 1)
  warmup_time = (time.perf_counter() - start) * 1000
  num_actual_output_tokens = len(output.split(" "))
  if _VERBOSE:
    print("=== Max len: %d ===" % max_len)
    print("Warmup: %lf ms" % warmup_time)
    print("Output:\n%s\n" % output)

  times = []
  for i in range(1, 6):
    start = time.perf_counter()
    output = gemma_lm.generate(_QUERY, max_length=max_len + 1)
    elapsed_time = (time.perf_counter() - start) * 1000
    # Sanity check done after the clock is stopped so that it does not count
    # towards the measured latency.
    assert num_actual_output_tokens == len(output.split(" "))
    times.append(elapsed_time)
    if _VERBOSE:
      print("%d: %lf ms" % (i, elapsed_time))

  mean, diff = compute_stats(times)
  if _VERBOSE:
    print("Mean: %lf ± %d%% ms\n" % (mean, diff))
  return (mean, diff, num_actual_output_tokens)
def main():
  """Loads the Gemma2 preset, benchmarks it, and prints TTFT and TPOT.

  TTFT is the mean latency of the `max_len=1` run. TPOT is the extra latency
  of the `_NUM_OUTPUT_TOKENS` run over the TTFT run, divided by the number of
  additional tokens counted in that run.
  """
  if _VERBOSE:
    print("Query: %s" % _QUERY)
  gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
  mean_1, diff_1, _ = run(gemma_lm, 1)
  mean_n, diff_n, num_output_tokens = run(gemma_lm, _NUM_OUTPUT_TOKENS)
  # Interpolate with `%`; passing the value as a second argument to `print`
  # would emit the raw "%d" placeholder followed by the number.
  print("Generated %d tokens" % num_output_tokens)
  if num_output_tokens > 1:
    tpot = (mean_n - mean_1) / (num_output_tokens - 1)
  else:
    # With a single counted token there is no per-output-token interval to
    # average over; report NaN instead of dividing by zero.
    tpot = float("nan")
  print("TTFT: %lf ± %d%% ms" % (mean_1, diff_1))
  print("TPOT: %lf ± %d%% ms" % (tpot, diff_n))


if __name__ == "__main__":
  main()

View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Print each command as it runs and abort on the first failure.
set -x
set -e

# Load VENV_BASE / GEMMA2_VENV from the config.sh that sits next to this
# script, so the script also works when invoked from another directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config.sh"

# Quote the path so directories containing spaces are handled correctly, and
# use `:?` to abort if the variable is unset or empty instead of running a
# bare `rm -rf`.
rm -rf -- "${GEMMA2_VENV:?GEMMA2_VENV is not set}"

View File

@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Print each command as it runs and abort on the first failure. When this file
# is sourced, these options also take effect in the sourcing script.
set -x
set -e
# Parent directory under which the virtual environment is created. Edit this
# value to use a different location.
export VENV_BASE=~/venv
# Path of the virtual environment used by the Gemma2 Keras scripts.
export GEMMA2_VENV=${VENV_BASE}/gemma2-keras

View File

@@ -0,0 +1,5 @@
keras==3.8.0
keras_nlp==0.18.1
tensorflow==2.18.0
jax==0.4.38
torch==2.5.1

View File

@@ -0,0 +1,23 @@
#!/bin/bash
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Print each command as it runs and abort on the first failure.
set -x
set -e

# Resolve sibling files relative to this script rather than the current
# directory, so the script can be invoked from anywhere.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config.sh"

# Activate the virtual environment; quoted so paths with spaces work.
source "${GEMMA2_VENV}/bin/activate"

# Run the benchmark with the interpreter from the activated environment.
python "${SCRIPT_DIR}/benchmark.py"

View File

@@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Print each command as it runs and abort on the first failure.
set -x
set -e

# Resolve sibling files relative to this script rather than the current
# directory, so the script can be invoked from anywhere.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config.sh"

# Create the virtual environment; paths are quoted so that locations
# containing spaces work.
mkdir -p "${VENV_BASE}"
python3 -m venv "${GEMMA2_VENV}"

# Install the pinned dependencies into the freshly created environment.
source "${GEMMA2_VENV}/bin/activate"
pip install -r "${SCRIPT_DIR}/requirements.txt"