Google Gen AI SDK

Imports

from google import genai
from google.genai import types

Create a client

Please run one of the following code blocks to create a client for different services (Google AI or Vertex). Feel free to switch the client and run all the examples to see how it behaves under different APIs.

# Only run this block for Google AI API
client = genai.Client(api_key='YOUR_API_KEY')
# Only run this block for Vertex AI API
client = genai.Client(
    vertexai=True, project='your-project-id', location='us-central1'
)

Types

Parameter types can be specified as either dictionaries(TypedDict) or pydantic Models. Pydantic model types are available in the types module.

Models

The client.models modules exposes model inferencing and model getters.

Generate Content

response = client.models.generate_content(
    model='gemini-1.5-pro-002', contents='What is your name?'
)
print(response.text)

System Instructions and Other Configs

response = client.models.generate_content(
    model='gemini-1.5-pro-002',
    contents='high',
    config=types.GenerateContentConfig(
        system_instruction='I say high, you say low',
        temperature= 0.3,
    ),
)
print(response.text)

Typed Config

All API methods support pydantic types for parameters as well as dictionaries. You can get the type from google.genai.types.

response = client.models.generate_content(
    model='gemini-1.5-flash-002',
    contents=types.Part.from_text('Why is sky blue?'),
    config=types.GenerateContentConfig(
        temperature=0,
        top_p=0.95,
        top_k=20,
        candidate_count=1,
        seed=5,
        max_output_tokens=100,
        stop_sequences=["STOP!"],
        presence_penalty=0.0,
        frequency_penalty=0.0,
    )
)

response

Safety Settings

response = client.models.generate_content(
    model='gemini-1.5-flash',
    contents='Say something bad.',
    config=types.GenerateContentConfig(
        safety_settings= [types.SafetySetting(
            category='HARM_CATEGORY_HATE_SPEECH',
            threshold='BLOCK_ONLY_HIGH',
        )]
    ),
)
print(response.text)

Function Calling

Automatic Python function Support

You can pass a python function directly and it will be automatically called and responded.

def get_current_weather(location: str,) -> int:
  """Returns the current weather.

  Args:
    location: The city and state, e.g. San Francisco, CA
  """
  return 'sunny'

response = client.models.generate_content(
    model='gemini-1.5-pro-002',
    contents="What is the weather like in Boston?",
    config=types.GenerateContentConfig(tools=[get_current_weather],)
)

response.text
function = dict(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
      "type": "OBJECT",
      "properties": {
          "location": {
              "type": "STRING",
              "description": "The city and state, e.g. San Francisco, CA",
          },
      },
      "required": ["location"],
    }
)

tool = types.Tool(function_declarations=[function])


response = client.models.generate_content(
    model='gemini-1.5-pro-002',
    contents="What is the weather like in Boston?",
    config=types.GenerateContentConfig(tools=[tool],)
)

response.candidates[0].content.parts[0].function_call
function_call_part = response.candidates[0].content.parts[0]

function_reponse = get_current_weather(**function_call_part.function_call.args)


function_response_part = types.Part.from_function_response(
    name=function_call_part.function_call.name,
    response={'result': function_reponse}
)

response = client.models.generate_content(
    model='gemini-1.5-pro-002',
    contents=[
        types.Part.from_text("What is the weather like in Boston?"),
        function_call_part,
        function_response_part,
    ])

response

JSON Response Schema

Pydantic Model Schema support

Schemas can be provided as Pydantic Models.

from pydantic import BaseModel

class CountryInfo(BaseModel):
  name: str
  population: int
  capital: str
  continent: str
  gdp: int
  official_language: str
  total_area_sq_mi: int


response = client.models.generate_content(
    model='gemini-1.5-flash',
    contents='Give me information of the United States.',
    config=types.GenerateContentConfig(
        response_mime_type= 'application/json',
        response_schema= CountryInfo,
    ),
)
print(response.text)
response = client.models.generate_content(
    model='gemini-1.5-flash',
    contents='Give me information of the United States.',
    config={
        'response_mime_type': 'application/json',
        'response_schema': {
            'required': [
                'name',
                'population',
                'capital',
                'continent',
                'gdp',
                'official_language',
                'total_area_sq_mi',
            ],
            'properties': {
                'name': {'type': 'STRING'},
                'population': {'type': 'INTEGER'},
                'capital': {'type': 'STRING'},
                'continent': {'type': 'STRING'},
                'gdp': {'type': 'INTEGER'},
                'official_language': {'type': 'STRING'},
                'total_area_sq_mi': {'type': 'INTEGER'},
            },
            'type': 'OBJECT',
        },
    },
)
print(response.text)

Streaming

for chunk in client.models.generate_content_stream(
    model='gemini-1.5-flash', contents='Tell me a story in 300 words.'
):
  print(chunk.text)

Async

client.aio exposes all the analogous async methods that are available on client

For example, client.aio.models.generate_content is the async version of client.models.generate_content

request = await client.aio.models.generate_content(
    model='gemini-1.5-flash', contents='Tell me a story in 300 words.'
)

print(response.text)

Streaming

async for response in client.aio.models.generate_content_stream(
    model='gemini-1.5-flash', contents='Tell me a story in 300 words.'
):
  print(response.text)

Count Tokens and Compute Tokens

response = client.models.count_tokens(
    model='gemini-1.5-flash',
    contents='What is your name?',
)
print(response)

Compute Tokens

Compute tokens is not supported by Google AI.

response = client.models.compute_tokens(
    model='gemini-1.5-flash',
    contents='What is your name?',
)
print(response)

Async

response = await client.aio.models.count_tokens(
    model='gemini-1.5-flash',
    contents='What is your name?',
)
print(response)

Embed Content

response = client.models.embed_content(
    model='text-embedding-004',
    contents='What is your name?',
)
response
# multiple contents with config
response = client.models.embed_content(
    model='text-embedding-004',
    contents=['What is your name?', 'What is your age?'],
    config=types.EmbedContentConfig(output_dimensionality= 10)
)

response

Imagen

Generate Image

Support for generate image in Google AI is behind an allowlist

# Generate Image
response1 = client.models.generate_image(
    model='imagen-3.0-generate-001',
    prompt='An umbrella in the foreground, and a rainy night sky in the background',
    config=types.GenerateImageConfig(
        negative_prompt= "human",
        number_of_images= 1,
        include_rai_reason= True,
        output_mime_type= "image/jpeg"
    )
)
response1.generated_images[0].image.show()

Upscale Image

# Upscale the generated image from above
response2 = client.models.upscale_image(
    model='imagen-3.0-generate-001',
    image=response1.generated_images[0].image,
    config=types.UpscaleImageConfig(upscale_factor="x2")
)
response2.generated_images[0].image.show()

Edit Image

# Edit the generated image from above
from google.genai.types import RawReferenceImage, MaskReferenceImage
raw_ref_image = RawReferenceImage(
    reference_id=1,
    reference_image=response1.generated_images[0].image,
)

# Model computes a mask of the background
mask_ref_image = MaskReferenceImage(
    reference_id=2,
    config=types.MaskReferenceConfig(
        mask_mode='MASK_MODE_BACKGROUND',
        mask_dilation=0,
    ),
)

response3 = client.models.edit_image(
    model='imagen-3.0-capability-preview-0930',
    prompt='Sunlight and clear sky',
    reference_images=[raw_ref_image, mask_ref_image],
    config=types.EditImageConfig(
        edit_mode= 'EDIT_MODE_INPAINT_INSERTION',
        number_of_images= 1,
        negative_prompt= 'human',
        include_rai_reason= True,
        output_mime_type= 'image/jpeg',
    ),
)
response3.generated_images[0].image.show()

Files (Only Google AI)

!gsutil cp gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf .
!gsutil cp gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf .

Upload

file1 = client.files.upload(path='2312.11805v3.pdf')
file2 = client.files.upload(path='2403.05530.pdf')

print(file1)
print(file2)

Delete

file3 = client.files.upload(path='2312.11805v3.pdf')

client.files.delete(name=file3.name)

Caches

client.caches contains the control plane APIs for cached content

Create

if client._api_client.vertexai:
  file_uris = [
      'gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf',
      'gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf'
  ]
else:
  file_uris = [file1.uri, file2.uri]

cached_content = client.caches.create(
      model='gemini-1.5-pro-002',
      contents=[
          types.Content(
              role='user',
              parts=[
                types.Part.from_uri(
                    file_uri=file_uris[0],
                    mime_type='application/pdf'),
                types.Part.from_uri(
                    file_uri=file_uris[1],
                    mime_type='application/pdf',)])
      ],
      config=types.CreateCachedContentConfig(
          display_name='test cache',
          system_instruction='What is the sum of the two pdfs?',
          ttl='3600s',
      ),
  )

Get

client.caches.get(name=cached_content.name)

Generate Content

client.models.generate_content(
    model='gemini-1.5-pro-002',
    contents='Summarize the pdfs',
    config=types.GenerateContentConfig(
        cached_content=cached_content.name,
    )
)

Tunings

client.tunings contains tuning job APIs and supports supervised fine tuning through tune and distiallation through distill

Tune

  • Vertex supports tuning from GCS source

  • Google AI supports tuning from inline examples

if client._api_client.vertexai:
  model = 'gemini-1.5-pro-002'
  training_dataset=types.TuningDataset(
        gcs_uri='gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl',
  )
else:
  model = 'models/gemini-1.0-pro-001'
  training_dataset=types.TuningDataset(
        examples=[
            types.TuningExample(
                text_input=f"Input text {i}",
                output=f"Output text {i}",
            )
            for i in range(5)
        ],
    )
tuning_job = client.tunings.tune(
    base_model=model,
    training_dataset=training_dataset,
    config=types.CreateTuningJobConfig(
        epoch_count= 1,
        tuned_model_display_name="test_dataset_examples model"
    )
)
tuning_job

Get Tuning Job

tuning_job = client.tunings.get(name=tuning_job.name)
tuning_job
import time

running_states = set([
    "JOB_STATE_PENDING",
    "JOB_STATE_RUNNING",
])

while tuning_job.state in running_states:
    print(tuning_job.state)
    tuning_job = client.tunings.get(name=tuning_job.name)
    time.sleep(10)

Use Tuned Model

response = client.models.generate_content(
    model=tuning_job.tuned_model.endpoint,
    contents='What is your name?',
)

response.text

Get Tuned Model

tuned_model = client.models.get(model=tuning_job.tuned_model.model)
tuned_model

List Tuned Models

for model in client.models.list(config={'page_size': 10}):
  print(model)
pager = client.models.list(config={'page_size': 10})
print(pager.page_size)
print(pager.page[0])
pager.next_page()
print(pager.page[0])

Async

async for job in await client.aio.models.list(config={'page_size': 10}):
  print(job)
async_pager = await client.aio.models.list(config={'page_size': 10})
print(async_pager.page_size)
print(async_pager.page[0])
await async_pager.next_page()
print(async_pager.page[0])

Update Tuned Model

model = pager.page[0]

model = client.models.update(
    model=model.name,
    config=types.UpdateModelConfig(
        display_name='my tuned model',
        description='my tuned model description'))

model

Distillation

Only supported on Vertex. Requires allowlist.

distillation_job = client.tunings.distill(
    student_model="gemma-2b-1.1-it",
    teacher_model="gemini-1.5-pro-002",
    training_dataset=genai.types.DistillationDataset(
        gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
    ),
    config=genai.types.CreateDistillationJobConfig(
        epoch_count=1,
        pipeline_root_directory=(
            "gs://vertex-sdk-dev-staging-us-central1/tmp/distillation_pipeline_root"
        ),
    ),
)
distillation_job
tcompleted_states = set([
    "JOB_STATE_SUCCEEDED",
    "JOB_STATE_FAILED",
    "JOB_STATE_CANCELLED",
    "JOB_STATE_PAUSED"
])

while distillation_job.state not in completed_states:
    print(distillation_job.state)
    distillation_job = client.tunings.get(name=distillation_job.name)
    time.sleep(10)
distillation_job

List Tuning Jobs

for job in client.tunings.list(config={'page_size': 10}):
  print(job)
pager = client.tunings.list(config={'page_size': 10})
print(pager.page_size)
print(pager.page[0])
pager.next_page()
print(pager.page[0])

Async

async for job in await client.aio.tunings.list(config={'page_size': 10}):
  print(job)
async_pager = await client.aio.tunings.list(config={'page_size': 10})
print(async_pager.page_size)
print(async_pager.page[0])
await async_pager.next_page()
print(async_pager.page[0])

Batch Prediction

Only supported in Vertex AI.

Create

# Specify model and source file only, destination and job display name will be auto-populated
job = client.batches.create(
    model='gemini-1.5-flash-002',
    src='bq://vertex-sdk-dev.unified_genai_tests_batches.generate_content_requests',
    # # optionally specify destination and display_name by yourself
    # config = {
    #     'dest': 'bq://vertex-sdk-dev.unified_genai_tests_batches.generate_content_responses',
    #     'display_name': 'create_batch_job_demo'
    # }
)

job
# Get a job by name
job = client.batches.get(name=job.name)

job.state
completed_states = set([
    "JOB_STATE_SUCCEEDED",
    "JOB_STATE_FAILED",
    "JOB_STATE_CANCELLED",
    "JOB_STATE_PAUSED"
])

while job.state not in completed_states:
    print(job.state)
    job = client.batches.get(name=job.name)
    time.sleep(30)

job

List

for job in client.batches.list(config={'page_size': 10}):
  print(job)
pager = client.batches.list(config={'page_size': 10})
print(pager.page_size)
print(pager.page[0])
pager.next_page()
print(pager.page[0])

Async

async for job in await client.aio.batches.list(config={'page_size': 10}):
  print(job)
async_pager = await client.aio.tunings.list(config={'page_size': 10})
print(async_pager.page_size)
print(async_pager.page[0])
await async_pager.next_page()
print(async_pager.page[0])

Delete

# Delete the job resource
delete_job = client.batches.delete(name=job.name)

delete_job

Reference