Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,47 @@
"""


_SUPPORTED_FILE_CONTENT_MIME_TYPES = frozenset({
# Images
'image/png',
'image/jpeg',
'image/webp',
'image/heic',
'image/heif',
# Documents & Text
'application/pdf',
'text/plain',
'text/csv',
'text/html',
'text/md',
'text/x-python',
'text/javascript',
# Audio
'audio/wav',
'audio/mp3',
'audio/aiff',
'audio/aac',
'audio/ogg',
'audio/flac',
'audio/mpeg',
'audio/mpga',
'audio/m4a',
'audio/pcm',
'audio/webm',
# Video
'video/mp4',
'video/mpeg',
'video/mov',
'video/quicktime',
'video/avi',
'video/x-flv',
'video/mpg',
'video/webm',
'video/wmv',
'video/3gpp',
})


class _ResourceExhaustedError(ClientError):
"""Represents a resources exhausted error received from the Model."""

Expand Down Expand Up @@ -455,9 +496,26 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
for part in content.parts:
# Create copies to avoid mutating the original objects
if part.inline_data:
mime_type = (part.inline_data.mime_type or '').lower()
if mime_type not in _SUPPORTED_FILE_CONTENT_MIME_TYPES:
identifier = part.inline_data.display_name or 'inline_file'
part.text = (
part.text or ''
) + f'\n[File reference: "{identifier}"]'
part.inline_data = None
part.inline_data = copy.copy(part.inline_data)
_remove_display_name_if_present(part.inline_data)

if part.file_data:
mime_type = (part.file_data.mime_type or '').lower()
identifier = (
part.file_data.display_name or part.file_data.file_uri
)
if mime_type not in _SUPPORTED_FILE_CONTENT_MIME_TYPES:
part.text = (
part.text or ''
) + f'\n[File reference: "{identifier}"]'
part.file_data = None
part.file_data = copy.copy(part.file_data)
_remove_display_name_if_present(part.file_data)

Expand Down
46 changes: 46 additions & 0 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2167,3 +2167,49 @@ async def __aexit__(self, *args):
# Verify the final speech_config is still None
assert config_arg.speech_config is None
assert isinstance(connection, GeminiLlmConnection)


@pytest.mark.asyncio
async def test_preprocess_request_unsupported_mime_type(gemini_llm):
"""Verifies that MS Office files are escaped to a text reference."""
unsupported_part = types.Part(
file_data=types.FileData(
mime_type="application/vnd.ms-excel",
file_uri="gs://bucket/data.xls",
display_name="data.xls",
)
)
req = LlmRequest(
model="gemini-2.0-flash",
contents=[types.Content(parts=[unsupported_part])],
)

await gemini_llm._preprocess_request(req)

processed_part = req.contents[0].parts[0]
# File_data should be stripped to avoid the 400 error
assert processed_part.file_data is None
# Text fallback should be present
assert '[File reference: "data.xls"]' in processed_part.text


@pytest.mark.asyncio
async def test_preprocess_request_supported_mime_type(gemini_llm):
"""Verifies that PDF files are passed through without modification."""
supported_part = types.Part(
file_data=types.FileData(
mime_type="application/pdf",
file_uri="gs://bucket/doc.pdf",
display_name="doc.pdf",
)
)
req = LlmRequest(
model="gemini-2.0-flash", contents=[types.Content(parts=[supported_part])]
)

await gemini_llm._preprocess_request(req)

processed_part = req.contents[0].parts[0]
# file_data should still be intact
assert processed_part.file_data is not None
assert processed_part.file_data.mime_type == "application/pdf"
Loading