From d7a41d2c4581b0499e7cf08045a335f32b1fa438 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 9 Jan 2025 08:04:14 -0800 Subject: [PATCH] [xla:cpu:benchmarks] Add scripts to run Gemma2 Keras model. PiperOrigin-RevId: 713674776 --- .../cpu/benchmarks/e2e/gemma2/keras/README.md | 32 ++++++ .../benchmarks/e2e/gemma2/keras/benchmark.py | 107 ++++++++++++++++++ .../benchmarks/e2e/gemma2/keras/cleanup.sh | 22 ++++ .../cpu/benchmarks/e2e/gemma2/keras/config.sh | 21 ++++ .../e2e/gemma2/keras/requirements.txt | 5 + .../cpu/benchmarks/e2e/gemma2/keras/run.sh | 23 ++++ .../cpu/benchmarks/e2e/gemma2/keras/setup.sh | 25 ++++ 7 files changed, 235 insertions(+) create mode 100644 third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md create mode 100644 third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py create mode 100644 third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh create mode 100644 third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh create mode 100644 third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt create mode 100644 third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh create mode 100644 third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md new file mode 100644 index 00000000000..35337b27d05 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md @@ -0,0 +1,32 @@ +# Gemma2 2B Keras model + +Scripts to run Gemma2 2B Keras model on CPU. + +Model link: https://www.kaggle.com/models/google/gemma-2/keras + +Instructions: + +* Set up your Kaggle API key by following + [these instructions](https://www.kaggle.com/docs/api#authentication). +* `$ bash setup.sh` + * This only needs to be run once. It will create a virtual environment at + a location read from `config.sh` and install the necessary dependencies. + * Change the `VENV_BASE` variable in `config.sh` before running `setup.sh` + if you want to use a different location. +* `$ KERAS_BACKEND=jax bash run.sh` + * This script activates the right virtual environment and runs the + benchmark in `benchmark.py`. + * Set `KERAS_BACKEND=tensorflow` or `torch` to run with TensorFlow or + PyTorch backend. +* (Optional) Delete the virtual environment: `$ bash cleanup.sh` + +To try other model variations with different numbers of parameters, modify the +following line in `benchmark.py`: + +``` +gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en") +``` + +Replace "gemma2_2b_en" with other preset names, e.g., +"gemma2_instruct_2b_en","gemma2_9b_en", etc. See the full preset list +[here](https://github.com/keras-team/keras-hub/blob/86607dc921999e33f5b8a0bcf81ec987b60c9dee/keras_hub/src/models/gemma/gemma_presets.py#L5-L200). diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py new file mode 100644 index 00000000000..46d5e4355c1 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py @@ -0,0 +1,107 @@ +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark Gemma2-2B Keras performance.""" + +import time +import keras_nlp +import numpy as np + +_NUM_OUTPUT_TOKENS = 30 +_QUERY = "What is JAX in 3 bullet points?" +_VERBOSE = True + + +def compute_stats(array): + """Reports mean and ± range for the given array. + + The range computation follows benchstat's. + + Args: + array: The array to compute stats for. + + Returns: + mean and ± %diff range. + """ + q1 = np.percentile(array, 25) + q3 = np.percentile(array, 75) + low = q1 - 1.5 * (q3 - q1) + high = q3 + 1.5 * (q3 - q1) + + # Remove outliers. + filtered_array = list(filter(lambda x: low <= x and x <= high, array)) + + mean = np.mean(filtered_array) + min_val = np.min(filtered_array) + max_val = np.max(filtered_array) + max_diff = max(max_val - mean, mean - min_val) + diff = max_diff / mean * 100.0 + + return (mean, diff) + + +def run(gemma_lm, max_len): + """Benchmarks inferences with at most `max_len` output tokens. + + Args: + gemma_lm: The Gemma2 Keras model. + max_len: The maximum number of output tokens per one inference. + + Returns: + mean ± %diff and the actual number of output tokens generated per inference. + """ + # Warm up. + start = time.time() + output = gemma_lm.generate(_QUERY, max_length=max_len + 1) + num_actual_output_tokens = len(output.split(" ")) + warmup_time = (time.time() - start) * 1000 + + if _VERBOSE: + print("=== Max len: %d ===" % max_len) + print("Warmup: %lf ms" % warmup_time) + print("Output:\n%s\n" % output) + + times = [] + for i in range(1, 6): + start = time.time() + output = gemma_lm.generate(_QUERY, max_length=max_len + 1) + assert num_actual_output_tokens == len(output.split(" ")) + elapsed_time = (time.time() - start) * 1000 + times.append(elapsed_time) + if _VERBOSE: + print("%d: %lf ms" % (i, elapsed_time)) + + mean, diff = compute_stats(times) + if _VERBOSE: + print("Mean: %lf ± %d%% ms\n" % (mean, diff)) + + return (mean, diff, num_actual_output_tokens) + + +def main(): + if _VERBOSE: + print("Query: %s" % _QUERY) + + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en") + mean_1, diff_1, _ = run(gemma_lm, 1) + mean_n, diff_n, num_output_tokens = run(gemma_lm, _NUM_OUTPUT_TOKENS) + + print("Generated %d tokens", num_output_tokens) + tpot = (mean_n - mean_1) / (num_output_tokens - 1) + print("TTFT: %lf ± %d%% ms" % (mean_1, diff_1)) + print("TPOT: %lf ± %d%% ms" % (tpot, diff_n)) + + +if __name__ == "__main__": + main() diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh new file mode 100644 index 00000000000..8cb893f5b1d --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +source config.sh + +rm -rf ${GEMMA2_VENV} diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh new file mode 100644 index 00000000000..55f1139b818 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +export VENV_BASE=~/venv +export GEMMA2_VENV=${VENV_BASE}/gemma2-keras diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt new file mode 100644 index 00000000000..d9866bf65ba --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt @@ -0,0 +1,5 @@ +keras==3.8.0 +keras_nlp==0.18.1 +tensorflow==2.18.0 +jax==0.4.38 +torch==2.5.1 diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh new file mode 100644 index 00000000000..876625a6565 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +source config.sh +source ${GEMMA2_VENV}/bin/activate + +python benchmark.py diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh new file mode 100644 index 00000000000..2258692d608 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +source config.sh + +mkdir -p ${VENV_BASE} +python3 -m venv ${GEMMA2_VENV} +source ${GEMMA2_VENV}/bin/activate +pip install -r requirements.txt