![]() | ![]() | ![]() |
В этом руководстве показано, как запустить Gemma с использованием платформы PyTorch, в том числе как использовать данные изображения для запроса моделей Gemma Release 3 и более поздних версий. Более подробную информацию о реализации Gemma PyTorch можно найти в репозитории проекта README .
Настраивать
В следующих разделах объясняется, как настроить среду разработки, в том числе как получить доступ к моделям Gemma для загрузки из Kaggle, настроить переменные аутентификации, установить зависимости и импортировать пакеты.
Системные требования
Для этой библиотеки Gemma Pytorch требуются процессоры GPU или TPU для запуска модели Gemma. Стандартной среды выполнения Colab CPU Python и среды выполнения T4 GPU Python достаточно для запуска моделей Gemma размеров 1B, 2B и 4B. Дополнительные варианты использования других графических процессоров или TPU см. в README в репозитории Gemma PyTorch.
Получите доступ к Джемме на Kaggle
Чтобы выполнить это руководство, сначала необходимо следовать инструкциям по установке на странице Gemma setup , в которых показано, как сделать следующее:
- Получите доступ к Джемме на kaggle.com .
- Выберите среду выполнения Colab с достаточными ресурсами для запуска модели Gemma.
- Создайте и настройте имя пользователя Kaggle и ключ API.
После завершения настройки Gemma перейдите к следующему разделу, где вы установите переменные среды для вашей среды Colab.
Установить переменные среды
Установите переменные среды для KAGGLE_USERNAME
и KAGGLE_KEY
. При появлении запроса «Предоставить доступ?» сообщения, согласитесь предоставить секретный доступ.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Установить зависимости
pip install -q -U torch immutabledict sentencepiece
Скачать вес модели
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '4b':
CONFIG = '4b-v1'
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Задайте пути токенизатора и контрольных точек для модели.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Настройка среды выполнения
В следующих разделах объясняется, как подготовить среду PyTorch для запуска Gemma.
Подготовьте среду запуска PyTorch.
Подготовьте среду выполнения модели PyTorch, клонировав репозиторий Gemma Pytorch.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Установите конфигурацию модели
Прежде чем запустить модель, необходимо установить некоторые параметры конфигурации, включая вариант Gemma, токенизатор и уровень квантования.
# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Настройте контекст устройства
Следующий код настраивает контекст устройства для запуска модели:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Создайте экземпляр и загрузите модель
Загрузите модель с ее весами, чтобы подготовиться к выполнению запросов.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Выполнить вывод
Ниже приведены примеры генерации в режиме чата и генерации с несколькими запросами.
Модели Gemma, настроенные с помощью инструкций, были обучены с помощью специального форматтера, который аннотирует примеры настройки инструкций дополнительной информацией как во время обучения, так и в процессе вывода. Аннотации (1) обозначают роли в разговоре, а (2) обозначают повороты в разговоре.
Соответствующие токены аннотаций:
-
user
: очередь пользователя -
model
: поворот модели -
<start_of_turn>
: начало поворота диалога. -
<start_of_image>
: тег для ввода данных изображения. -
<end_of_turn><eos>
: конец хода диалога.
Для получения дополнительной информации прочтите о форматировании подсказок для моделей Gemma, настроенных с помощью инструкций [здесь]( https://ai.google.dev/gemma/core/prompt-structure
Генерировать текст с текстом
Ниже приведен пример фрагмента кода, демонстрирующий, как отформатировать приглашение для модели Gemma, настроенной с помощью инструкций, с использованием шаблонов чата пользователя и модели в многоходовом разговоре.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Генерация текста с изображениями
В версии Gemma 3 и более поздних версиях вы можете использовать изображения в подсказках. В следующем примере показано, как включить визуальные данные в приглашение.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)
print(model.generate(
[['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
device=device,
output_len=OUTPUT_LEN,
))
Узнать больше
Теперь, когда вы узнали, как использовать Gemma в Pytorch, вы можете изучить множество других вещей, которые может делать Gemma, в ai.google.dev/gemma . См. также другие связанные ресурсы: