Skip to content

Commit 1c16bc2

Browse files
authored
Support for adapter-specific pre/post processing through the adapter management APIs (#2951)
1 parent c1ac412 commit 1c16bc2

File tree

17 files changed

+2032
-190
lines changed

17 files changed

+2032
-190
lines changed

.github/workflows/integration.yml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,33 @@ jobs:
9191
--fail \
9292
| jq '.token' | tr -d '"' )
9393
./start_instance.sh action_g6 $token djl-serving
94+
- name: Create new G6 instance
95+
id: create_g6_5
96+
run: |
97+
cd /home/ubuntu/djl_benchmark_script/scripts
98+
token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \
99+
https://api.github.com/repos/deepjavalibrary/djl-serving/actions/runners/registration-token \
100+
--fail \
101+
| jq '.token' | tr -d '"' )
102+
./start_instance.sh action_g6 $token djl-serving
103+
- name: Create new G6 instance
104+
id: create_g6_6
105+
run: |
106+
cd /home/ubuntu/djl_benchmark_script/scripts
107+
token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \
108+
https://api.github.com/repos/deepjavalibrary/djl-serving/actions/runners/registration-token \
109+
--fail \
110+
| jq '.token' | tr -d '"' )
111+
./start_instance.sh action_g6 $token djl-serving
112+
- name: Create new G6 instance
113+
id: create_g6_7
114+
run: |
115+
cd /home/ubuntu/djl_benchmark_script/scripts
116+
token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \
117+
https://api.github.com/repos/deepjavalibrary/djl-serving/actions/runners/registration-token \
118+
--fail \
119+
| jq '.token' | tr -d '"' )
120+
./start_instance.sh action_g6 $token djl-serving
94121
- name: Create new Graviton instance
95122
id: create_aarch64
96123
run: |
@@ -133,6 +160,9 @@ jobs:
133160
g6_instance_id_2: ${{ steps.create_g6_2.outputs.action_g6_instance_id }}
134161
g6_instance_id_3: ${{ steps.create_g6_3.outputs.action_g6_instance_id }}
135162
g6_instance_id_4: ${{ steps.create_g6_4.outputs.action_g6_instance_id }}
163+
g6_instance_id_5: ${{ steps.create_g6_5.outputs.action_g6_instance_id }}
164+
g6_instance_id_6: ${{ steps.create_g6_6.outputs.action_g6_instance_id }}
165+
g6_instance_id_7: ${{ steps.create_g6_7.outputs.action_g6_instance_id }}
136166
aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }}
137167
cpu_instance_id: ${{ steps.create_cpu.outputs.action_cpu_instance_id }}
138168
p4d_instance_id_1: ${{ steps.create_p4d_1.outputs.action_lmic_p4d_instance_id }}
@@ -191,6 +221,9 @@ jobs:
191221
- test: TestVllmAsyncLora_g6
192222
instance: g6
193223
failure-prefix: lmi
224+
- test: TestVllmAsyncLoraWithCustomCode_g6
225+
instance: g6
226+
failure-prefix: lmi
194227
- test: TestMultiModalVllm_g6
195228
instance: g6
196229
failure-prefix: lmi
@@ -310,6 +343,12 @@ jobs:
310343
./stop_instance.sh $instance_id
311344
instance_id=${{ needs.create-runners.outputs.g6_instance_id_4 }}
312345
./stop_instance.sh $instance_id
346+
instance_id=${{ needs.create-runners.outputs.g6_instance_id_5 }}
347+
./stop_instance.sh $instance_id
348+
instance_id=${{ needs.create-runners.outputs.g6_instance_id_6 }}
349+
./stop_instance.sh $instance_id
350+
instance_id=${{ needs.create-runners.outputs.g6_instance_id_7 }}
351+
./stop_instance.sh $instance_id
313352
instance_id=${{ needs.create-runners.outputs.aarch64_instance_id }}
314353
./stop_instance.sh $instance_id
315354
instance_id=${{ needs.create-runners.outputs.cpu_instance_id }}

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,5 @@ dist/
4141
*.egg-info/
4242
*.pt
4343

44+
.kiro
45+
venv
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
14+
import logging
15+
import os
16+
from typing import Optional, AsyncGenerator, Any, Dict
17+
18+
from djl_python.custom_formatter_handling import CustomFormatterHandler
19+
from djl_python.adapter_manager_mixin import AdapterManagerMixin
20+
from djl_python.inputs import Input
21+
from djl_python.outputs import Output
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
class AdapterFormatterMixin(CustomFormatterHandler, AdapterManagerMixin):
    """
    Combined mixin providing adapter management and adapter-aware formatters.

    Inherits from both CustomFormatterHandler and AdapterManagerMixin so that a
    service needing adapter support with custom formatters has a single base.

    Responsibilities:
    - Base model formatter functionality (from CustomFormatterHandler)
    - Adapter weight registration/management (from AdapterManagerMixin)
    - Custom code management: a per-adapter registry of CustomFormatterHandler
      instances loaded from each adapter's ``model.py``
    - Adapter-aware formatter application: each ``apply_*`` override uses the
      adapter's handler when one is registered and defines the relevant
      formatter, and otherwise falls back to the base model formatter
    """

    def __init__(self):
        # Both bases are initialised explicitly rather than through a single
        # cooperative super().__init__() call.
        CustomFormatterHandler.__init__(self)
        AdapterManagerMixin.__init__(self)
        # Maps adapter name -> handler holding that adapter's custom formatters.
        self.adapter_code_registry: Dict[str, CustomFormatterHandler] = {}

    def get_adapter_formatter_handler(
            self, adapter_name: str) -> Optional[CustomFormatterHandler]:
        """
        Retrieve the formatter handler registered for a specific adapter.

        Args:
            adapter_name: Unique identifier for the adapter

        Returns:
            The adapter's CustomFormatterHandler, or None when no custom code
            is registered under that name
        """
        return self.adapter_code_registry.get(adapter_name)

    def apply_input_formatter(self,
                              decoded_payload: Any,
                              adapter_name: Optional[str] = None,
                              **kwargs) -> Any:
        """
        Apply the input formatter, preferring an adapter-specific one.

        Args:
            decoded_payload: The decoded request payload
            adapter_name: Optional adapter name to use for custom formatter
            **kwargs: Additional arguments to pass to the formatter

        Returns:
            Formatted input
        """
        if adapter_name:
            adapter_formatter = self.get_adapter_formatter_handler(
                adapter_name)
            if adapter_formatter and adapter_formatter.input_formatter:
                # Lazy %-style args: no string is built when debug is disabled.
                logger.debug(
                    "Using adapter-specific input formatter for adapter '%s'",
                    adapter_name)
                return adapter_formatter.apply_input_formatter(
                    decoded_payload, **kwargs)

        # No adapter requested, or the adapter has no input formatter.
        logger.debug("Using base model input formatter")
        return super().apply_input_formatter(decoded_payload, **kwargs)

    def apply_output_formatter(self,
                               output: Any,
                               adapter_name: Optional[str] = None,
                               **kwargs) -> Any:
        """
        Apply the output formatter, preferring an adapter-specific one.

        Args:
            output: The model output
            adapter_name: Optional adapter name to use for custom formatter
            **kwargs: Additional arguments to pass to the formatter

        Returns:
            Formatted output
        """
        if adapter_name:
            adapter_formatter = self.get_adapter_formatter_handler(
                adapter_name)
            if adapter_formatter and adapter_formatter.output_formatter:
                logger.debug(
                    "Using adapter-specific output formatter for adapter '%s'",
                    adapter_name)
                return adapter_formatter.apply_output_formatter(
                    output, **kwargs)

        # No adapter requested, or the adapter has no output formatter.
        logger.debug("Using base model output formatter")
        return super().apply_output_formatter(output, **kwargs)

    async def apply_output_formatter_streaming_raw(
            self,
            response: AsyncGenerator,
            adapter_name: Optional[str] = None,
            **kwargs) -> AsyncGenerator:
        """
        Apply the streaming output formatter, preferring an adapter-specific one.

        Args:
            response: The async generator producing model outputs
            adapter_name: Optional adapter name to use for custom formatter
            **kwargs: Additional arguments to pass to the formatter

        Returns:
            Async generator with formatted outputs
        """
        if adapter_name:
            adapter_formatter = self.get_adapter_formatter_handler(
                adapter_name)
            if adapter_formatter and adapter_formatter.output_formatter:
                logger.debug(
                    "Using adapter-specific streaming output formatter for adapter '%s'",
                    adapter_name)
                async for item in adapter_formatter.apply_output_formatter_streaming_raw(
                        response, **kwargs):
                    yield item
                # The adapter handled the whole stream; skip the base path.
                return

        logger.debug("Using base model streaming output formatter")
        async for item in super().apply_output_formatter_streaming_raw(
                response, **kwargs):
            yield item

    def load_adapter_custom_code(self, adapter_name: str,
                                 adapter_path: str) -> CustomFormatterHandler:
        """
        Load custom code (model.py) for an adapter and register its handler.

        Args:
            adapter_name: Unique identifier for the adapter
            adapter_path: Path to adapter directory containing model.py

        Returns:
            CustomFormatterHandler instance with loaded formatters

        Raises:
            FileNotFoundError: If model.py doesn't exist
            ValueError: If custom code loading fails; the underlying exception
                is chained as ``__cause__``
        """
        model_py_path = os.path.join(adapter_path, "model.py")

        if not os.path.isfile(model_py_path):
            raise FileNotFoundError(
                f"model.py not found in adapter directory: {adapter_path}")

        logger.info("Loading custom code for adapter '%s' from %s",
                    adapter_name, model_py_path)

        try:
            # A fresh handler per adapter; adapter_name is passed as the
            # namespace argument of load_formatters.
            formatter_handler = CustomFormatterHandler()
            formatter_handler.load_formatters(adapter_path,
                                              namespace=adapter_name)

            # Registered only after loading succeeded.
            self.adapter_code_registry[adapter_name] = formatter_handler

            logger.info("Successfully loaded custom code for adapter '%s'",
                        adapter_name)
            return formatter_handler

        except Exception as e:
            logger.exception("Failed to load custom code for adapter '%s'",
                             adapter_name)
            # Chain explicitly so the original failure stays attached.
            raise ValueError(
                f"Failed to load custom code for adapter {adapter_name}: {e}"
            ) from e

    def unload_adapter_custom_code(self, adapter_name: str) -> bool:
        """
        Unload custom code for an adapter.

        Args:
            adapter_name: Unique identifier for the adapter

        Returns:
            True if custom code was unloaded, False if no custom code was loaded
        """
        if adapter_name not in self.adapter_code_registry:
            logger.debug("Adapter '%s' not found in code registry",
                         adapter_name)
            return False

        logger.info("Unloading custom code for adapter '%s'", adapter_name)
        del self.adapter_code_registry[adapter_name]

        return True

    async def register_adapter(self, inputs: Input) -> Output:
        """
        Register an adapter, loading its custom code before its weights.

        If the adapter directory named by the "src" property contains a
        model.py, it is loaded first. When that load fails, an Output carrying
        a 424 error payload is returned and the parent implementation is not
        called. Otherwise registration is delegated to the parent
        implementation.

        Args:
            inputs: Request whose "name" and "src" properties identify the
                adapter and its directory

        Returns:
            The error Output described above, or the parent's result
        """
        adapter_name = inputs.get_property("name")
        adapter_path = inputs.get_property("src")

        # Skip the custom-code probe when "src" is missing or empty:
        # os.path.join(None, ...) raises TypeError, and an empty string would
        # resolve to "model.py" in the current working directory. Handling of
        # the missing "src" itself is left to the parent implementation.
        if adapter_path and os.path.isfile(
                os.path.join(adapter_path, "model.py")):
            try:
                self.load_adapter_custom_code(adapter_name, adapter_path)
            except Exception as e:
                # Fail fast - don't load adapter weights if custom code fails
                outputs = Output()
                err = {"data": "", "last": True, "code": 424, "error": str(e)}
                outputs.add(Output.binary_encode(err), key="data")
                return outputs

        # Now register adapter weights using parent implementation
        return await super().register_adapter(inputs)

    async def unregister_adapter(self, inputs: Input) -> Output:
        """
        Unregister an adapter, then drop any custom code registered for it.

        The parent implementation runs first; the custom code is removed only
        after it returns, so it is kept if the parent raises.

        Args:
            inputs: Request whose "name" property identifies the adapter

        Returns:
            The parent implementation's result
        """
        adapter_name = inputs.get_property("name")

        # First unregister adapter weights using parent implementation
        result = await super().unregister_adapter(inputs)

        # Then unload custom code if present (no-op when none was loaded)
        self.unload_adapter_custom_code(adapter_name)

        return result

0 commit comments

Comments
 (0)