Skip to content

Commit d52cc4e

Browse files
authored
fix(generate_image): updating generate_image tool to support additional models in Amazon Bedrock (#89)
* switching region to us-east-1 as sdxl model is no longer available in us-west-2 * fix(image-gen): add support for Amazon Nova Canvas Fix image generation functionality by adding support for Amazon Nova Canvas model in addition to existing Stable Diffusion models. This enhancement allows users to generate images using both model types with appropriate parameters. * fix(gen_image): region selection and number of images parmaters * fix: updating toolspec to include region id and remove number of images parameter (defaulting the tool to generate one image always * fix(generate_image): update the tool to initial state where it only support stable diffusion models * Update src/strands_tools/generate_image.py Co-authored-by: Mackenzie Zastrow <[email protected]> * fix(generate_image): removing conditional for model specific * Update src/strands_tools/generate_image.py
1 parent 3d8adc2 commit d52cc4e

File tree

2 files changed

+136
-96
lines changed

2 files changed

+136
-96
lines changed

src/strands_tools/generate_image.py

Lines changed: 125 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
Key Features:
1010
1111
1. Image Generation:
12-
• Text-to-image conversion using Stable Diffusion
13-
• Support for multiple model variants (primarily stable-diffusion-xl-v1)
14-
• Customizable generation parameters (seed, steps, cfg_scale)
15-
• Style preset selection for consistent aesthetics
12+
• Text-to-image conversion using Stable Diffusion models
13+
• Support for the following models:
14+
• stability.sd3-5-large-v1:0
15+
• stability.stable-image-core-v1:1
16+
• stability.stable-image-ultra-v1:1
17+
• Customizable generation parameters (seed, aspect_ratio, output_format, negative_prompt)
1618
1719
2. Output Management:
1820
• Automatic local saving with intelligent filename generation
@@ -36,14 +38,22 @@
3638
# Basic usage with default parameters
3739
agent.tool.generate_image(prompt="A steampunk robot playing chess")
3840
39-
# Advanced usage with custom parameters
41+
# Advanced usage with Stable Diffusion
4042
agent.tool.generate_image(
4143
prompt="A futuristic city with flying cars",
42-
model_id="stability.stable-diffusion-xl-v1",
43-
seed=42,
44-
steps=50,
45-
cfg_scale=12,
46-
style_preset="cinematic"
44+
model_id="stability.sd3-5-large-v1:0",
45+
aspect_ratio="5:4",
46+
output_format="jpeg",
47+
negative_prompt="bad lighting, harsh lighting, abstract, surreal, twisted, multiple levels",
48+
)
49+
50+
# Using another Stable Diffusion model
51+
agent.tool.generate_image(
52+
prompt="A photograph of a cup of coffee from the side",
53+
model_id="stability.stable-image-ultra-v1:1",
54+
aspect_ratio="1:1",
55+
output_format="png",
56+
negative_prompt="blurry, distorted",
4757
)
4858
```
4959
@@ -60,9 +70,16 @@
6070
import boto3
6171
from strands.types.tools import ToolResult, ToolUse
6272

73+
STABLE_DIFFUSION_MODEL_ID = [
74+
"stability.sd3-5-large-v1:0",
75+
"stability.stable-image-core-v1:1",
76+
"stability.stable-image-ultra-v1:1",
77+
]
78+
79+
6380
TOOL_SPEC = {
6481
"name": "generate_image",
65-
"description": "Generates an image using Stable Diffusion based on a given prompt",
82+
"description": "Generates an image using Stable Diffusion models based on a given prompt",
6683
"inputSchema": {
6784
"json": {
6885
"type": "object",
@@ -73,23 +90,32 @@
7390
},
7491
"model_id": {
7592
"type": "string",
76-
"description": "Model id for image model, stability.stable-diffusion-xl-v1.",
93+
"description": "Model id for image model, stability.sd3-5-large-v1:0, \
94+
stability.stable-image-core-v1:1, or stability.stable-image-ultra-v1:1",
95+
},
96+
"region": {
97+
"type": "string",
98+
"description": "AWS region for the image generation model (default: us-west-2)",
7799
},
78100
"seed": {
79101
"type": "integer",
80102
"description": "Optional: Seed for random number generation (default: random)",
81103
},
82-
"steps": {
83-
"type": "integer",
84-
"description": "Optional: Number of steps for image generation (default: 30)",
104+
"aspect_ratio": {
105+
"type": "string",
106+
"description": "Optional: Controls the aspect ratio of the generated image for \
107+
Stable Diffusion models. Default 1:1. Enum: 16:9, 1:1, 21:9, 2:3, 3:2, 4:5, 5:4, 9:16, 9:21",
85108
},
86-
"cfg_scale": {
87-
"type": "number",
88-
"description": "Optional: CFG scale for image generation (default: 10)",
109+
"output_format": {
110+
"type": "string",
111+
"description": "Optional: Specifies the format of the output image for Stable Diffusion models. \
112+
Supported formats: JPEG, PNG.",
89113
},
90-
"style_preset": {
114+
"negative_prompt": {
91115
"type": "string",
92-
"description": "Optional: Style preset for image generation (default: 'photographic')",
116+
"description": "Optional: Keywords of what you do not wish to see in the output image. \
117+
Default: bad lighting, harsh lighting. \
118+
Max: 10.000 characters.",
93119
},
94120
},
95121
"required": ["prompt"],
@@ -98,19 +124,28 @@
98124
}
99125

100126

127+
# Create a filename based on the prompt
128+
def create_filename(prompt: str) -> str:
129+
"""Generate a filename from the prompt text."""
130+
words = re.findall(r"\w+", prompt.lower())[:5]
131+
filename = "_".join(words)
132+
filename = re.sub(r"[^\w\-_\.]", "_", filename)
133+
return filename[:100] # Limit filename length
134+
135+
101136
def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
102137
"""
103-
Generate images from text prompts using Stable Diffusion via Amazon Bedrock.
138+
Generate images from text prompts using Stable Diffusion models via Amazon Bedrock.
104139
105140
This function transforms textual descriptions into high-quality images using
106-
Stable Diffusion models available through Amazon Bedrock. It provides extensive
141+
image generation models available through Amazon Bedrock. It provides extensive
107142
customization options and handles the complete process from API interaction to
108143
image storage and result formatting.
109144
110145
How It Works:
111146
------------
112147
1. Extracts and validates parameters from the tool input
113-
2. Configures the request payload with appropriate parameters
148+
2. Configures the request payload with appropriate parameters based on model type
114149
3. Invokes the Bedrock image generation model through AWS SDK
115150
4. Processes the response to extract the base64-encoded image
116151
5. Creates an appropriate filename based on the prompt content
@@ -120,11 +155,13 @@ def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
120155
Generation Parameters:
121156
--------------------
122157
- prompt: The textual description of the desired image
123-
- model_id: Specific model to use (defaults to stable-diffusion-xl-v1)
158+
- model_id: Specific model to use (defaults to stability.stable-image-core-v1:1)
124159
- seed: Controls randomness for reproducible results
125-
- style_preset: Artistic style to apply (e.g., photographic, cinematic)
126-
- cfg_scale: Controls how closely the image follows the prompt
127-
- steps: Number of diffusion steps (higher = more refined but slower)
160+
- aspect_ratio: Controls the aspect ratio of the generated image
161+
- output_format: Specifies the format of the output image (e.g., png or jpeg)
162+
- negative_prompt: Keywords of what you do not wish to see in the output image
163+
164+
128165
129166
Common Usage Scenarios:
130167
---------------------
@@ -137,11 +174,8 @@ def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
137174
Args:
138175
tool: ToolUse object containing the parameters for image generation.
139176
- prompt: The text prompt describing the desired image.
140-
- model_id: Optional model identifier (default: "stability.stable-diffusion-xl-v1").
141-
- seed: Optional random seed (default: random integer).
142-
- style_preset: Optional style preset name (default: "photographic").
143-
- cfg_scale: Optional CFG scale value (default: 10).
144-
- steps: Optional number of diffusion steps (default: 30).
177+
- model_id: Optional model identifier.
178+
- Additional parameters specific to the chosen model type.
145179
**kwargs: Additional keyword arguments (unused).
146180
147181
Returns:
@@ -161,78 +195,84 @@ def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult:
161195
tool_use_id = tool["toolUseId"]
162196
tool_input = tool["input"]
163197

164-
# Extract input parameters
198+
# Extract common and Stable Diffusion input parameters
199+
aspect_ratio = tool_input.get("aspect_ratio", "1:1")
200+
output_format = tool_input.get("output_format", "jpeg")
165201
prompt = tool_input.get("prompt", "A stylized picture of a cute old steampunk robot.")
166-
model_id = tool_input.get("model_id", "stability.stable-diffusion-xl-v1")
202+
model_id = tool_input.get("model_id", "stability.stable-image-core-v1:1")
203+
region = tool_input.get("region", "us-west-2")
167204
seed = tool_input.get("seed", random.randint(0, 4294967295))
168-
style_preset = tool_input.get("style_preset", "photographic")
169-
cfg_scale = tool_input.get("cfg_scale", 10)
170-
steps = tool_input.get("steps", 30)
205+
negative_prompt = tool_input.get("negative_prompt", "bad lighting, harsh lighting")
171206

172207
# Create a Bedrock Runtime client
173-
client = boto3.client("bedrock-runtime", region_name="us-west-2")
208+
client = boto3.client("bedrock-runtime", region_name=region)
209+
210+
# Initialize variables for later use
211+
base64_image_data = None
174212

175-
# Format the request payload
213+
# create the request body
176214
native_request = {
177-
"text_prompts": [{"text": prompt}],
178-
"style_preset": style_preset,
215+
"prompt": prompt,
216+
"aspect_ratio": aspect_ratio,
179217
"seed": seed,
180-
"cfg_scale": cfg_scale,
181-
"steps": steps,
218+
"output_format": output_format,
219+
"negative_prompt": negative_prompt,
182220
}
183221
request = json.dumps(native_request)
184222

185223
# Invoke the model
186224
response = client.invoke_model(modelId=model_id, body=request)
187225

188226
# Decode the response body
189-
model_response = json.loads(response["body"].read())
227+
model_response = json.loads(response["body"].read().decode("utf-8"))
190228

191229
# Extract the image data
192-
base64_image_data = model_response["artifacts"][0]["base64"]
193-
194-
# Create a filename based on the prompt
195-
def create_filename(prompt: str) -> str:
196-
"""Generate a filename from the prompt text."""
197-
words = re.findall(r"\w+", prompt.lower())[:5]
198-
filename = "_".join(words)
199-
filename = re.sub(r"[^\w\-_\.]", "_", filename)
200-
return filename[:100] # Limit filename length
201-
202-
filename = create_filename(prompt)
203-
204-
# Save the generated image to a local folder
205-
output_dir = "output"
206-
if not os.path.exists(output_dir):
207-
os.makedirs(output_dir)
208-
209-
i = 1
210-
base_image_path = os.path.join(output_dir, f"{filename}.png")
211-
image_path = base_image_path
212-
while os.path.exists(image_path):
213-
image_path = os.path.join(output_dir, f"{filename}_{i}.png")
214-
i += 1
215-
216-
with open(image_path, "wb") as file:
217-
file.write(base64.b64decode(base64_image_data))
218-
230+
base64_image_data = model_response["images"][0]
231+
232+
# If we have image data, process and save it
233+
if base64_image_data:
234+
filename = create_filename(prompt)
235+
236+
# Save the generated image to a local folder
237+
output_dir = "output"
238+
if not os.path.exists(output_dir):
239+
os.makedirs(output_dir)
240+
241+
i = 1
242+
base_image_path = os.path.join(output_dir, f"{filename}.png")
243+
image_path = base_image_path
244+
while os.path.exists(image_path):
245+
image_path = os.path.join(output_dir, f"{filename}_{i}.png")
246+
i += 1
247+
248+
with open(image_path, "wb") as file:
249+
file.write(base64.b64decode(base64_image_data))
250+
251+
return {
252+
"toolUseId": tool_use_id,
253+
"status": "success",
254+
"content": [
255+
{"text": f"The generated image has been saved locally to {image_path}. "},
256+
{
257+
"image": {
258+
"format": output_format,
259+
"source": {"bytes": base64.b64decode(base64_image_data)},
260+
}
261+
},
262+
],
263+
}
264+
else:
265+
raise Exception("No image data found in the response.")
266+
except Exception as e:
219267
return {
220268
"toolUseId": tool_use_id,
221-
"status": "success",
269+
"status": "error",
222270
"content": [
223-
{"text": f"The generated image has been saved locally to {image_path}. "},
224271
{
225-
"image": {
226-
"format": "png",
227-
"source": {"bytes": base64.b64decode(base64_image_data)},
228-
}
229-
},
272+
"text": f"Error generating image: {str(e)} \n Try other supported models for this tool are: \n \
273+
1. stability.sd3-5-large-v1:0 \n \
274+
2. stability.stable-image-core-v1:1 \n \
275+
3. stability.stable-image-ultra-v1:1"
276+
}
230277
],
231278
}
232-
233-
except Exception as e:
234-
return {
235-
"toolUseId": tool_use_id,
236-
"status": "error",
237-
"content": [{"text": f"Error generating image: {str(e)}"}],
238-
}

tests/test_generate_image.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def mock_boto3_client():
3131
# Set up mock response
3232
mock_body = MagicMock()
3333
mock_body.read.return_value = json.dumps(
34-
{"artifacts": [{"base64": base64.b64encode(b"mock_image_data").decode("utf-8")}]}
34+
{"images": [base64.b64encode(b"mock_image_data").decode("utf-8")]}
3535
).encode("utf-8")
3636

3737
mock_client_instance = MagicMock()
@@ -76,9 +76,9 @@ def test_generate_image_direct(mock_boto3_client, mock_os_path_exists, mock_os_m
7676
"input": {
7777
"prompt": "A cute robot",
7878
"seed": 123,
79-
"steps": 30,
80-
"cfg_scale": 10,
81-
"style_preset": "photographic",
79+
"aspect_ratio": "5:4",
80+
"output_format": "png",
81+
"negative_prompt": "blurry, low resolution, pixelated, grainy, unrealistic",
8282
},
8383
}
8484

@@ -94,11 +94,11 @@ def test_generate_image_direct(mock_boto3_client, mock_os_path_exists, mock_os_m
9494
args, kwargs = mock_client_instance.invoke_model.call_args
9595
request_body = json.loads(kwargs["body"])
9696

97-
assert request_body["text_prompts"][0]["text"] == "A cute robot"
97+
assert request_body["prompt"] == "A cute robot"
9898
assert request_body["seed"] == 123
99-
assert request_body["steps"] == 30
100-
assert request_body["cfg_scale"] == 10
101-
assert request_body["style_preset"] == "photographic"
99+
assert request_body["aspect_ratio"] == "5:4"
100+
assert request_body["output_format"] == "png"
101+
assert request_body["negative_prompt"] == "blurry, low resolution, pixelated, grainy, unrealistic"
102102

103103
# Verify directory creation
104104
mock_os_makedirs.assert_called_once()
@@ -128,9 +128,9 @@ def test_generate_image_default_params(mock_boto3_client, mock_os_path_exists, m
128128
request_body = json.loads(kwargs["body"])
129129

130130
assert request_body["seed"] == 42 # From our mocked random.randint
131-
assert request_body["steps"] == 30
132-
assert request_body["cfg_scale"] == 10
133-
assert request_body["style_preset"] == "photographic"
131+
assert request_body["aspect_ratio"] == "1:1"
132+
assert request_body["output_format"] == "jpeg"
133+
assert request_body["negative_prompt"] == "bad lighting, harsh lighting"
134134

135135
assert result["status"] == "success"
136136

0 commit comments

Comments
 (0)