2626logger = logging .getLogger (__name__ )
2727
2828
29+ def get_base_url (base_url : Union [str , None ]):
30+ return base_url or os .environ .get ("PROMPTLAYER_BASE_URL" , "https://api.promptlayer.com" )
31+
32+
2933def is_workflow_results_dict (obj : Any ) -> bool :
3034 if not isinstance (obj , dict ):
3135 return False
@@ -49,9 +53,7 @@ def is_workflow_results_dict(obj: Any) -> bool:
4953
5054class PromptLayer (PromptLayerMixin ):
5155 def __init__ (
52- self ,
53- api_key : str = None ,
54- enable_tracing : bool = False ,
56+ self , api_key : Union [str , None ] = None , enable_tracing : bool = False , base_url : Union [str , None ] = None
5557 ):
5658 if api_key is None :
5759 api_key = os .environ .get ("PROMPTLAYER_API_KEY" )
@@ -62,11 +64,12 @@ def __init__(
6264 "Please set the PROMPTLAYER_API_KEY environment variable or pass the api_key parameter."
6365 )
6466
67+ self .base_url = get_base_url (base_url )
6568 self .api_key = api_key
66- self .templates = TemplateManager (api_key )
67- self .group = GroupManager (api_key )
68- self .tracer_provider , self .tracer = self ._initialize_tracer (api_key , enable_tracing )
69- self .track = TrackManager (api_key )
69+ self .templates = TemplateManager (api_key , self . base_url )
70+ self .group = GroupManager (api_key , self . base_url )
71+ self .tracer_provider , self .tracer = self ._initialize_tracer (api_key , self . base_url , enable_tracing )
72+ self .track = TrackManager (api_key , self . base_url )
7073
7174 def __getattr__ (
7275 self ,
@@ -75,15 +78,18 @@ def __getattr__(
7578 if name == "openai" :
7679 import openai as openai_module
7780
78- return PromptLayerBase (openai_module , function_name = "openai" , api_key = self .api_key , tracer = self .tracer )
81+ return PromptLayerBase (
82+ self .api_key , self .base_url , openai_module , function_name = "openai" , tracer = self .tracer
83+ )
7984 elif name == "anthropic" :
8085 import anthropic as anthropic_module
8186
8287 return PromptLayerBase (
88+ self .api_key ,
89+ self .base_url ,
8390 anthropic_module ,
8491 function_name = "anthropic" ,
8592 provider_type = "anthropic" ,
86- api_key = self .api_key ,
8793 tracer = self .tracer ,
8894 )
8995 else :
@@ -212,7 +218,7 @@ def _track_request_log(
212218 metadata = metadata ,
213219 ** body ,
214220 )
215- return track_request (** track_request_kwargs )
221+ return track_request (self . base_url , ** track_request_kwargs )
216222
217223 def run (
218224 self ,
@@ -277,12 +283,13 @@ def run_workflow(
277283
278284 results = asyncio .run (
279285 arun_workflow_request (
286+ api_key = self .api_key ,
287+ base_url = self .base_url ,
280288 workflow_id_or_name = _get_workflow_workflow_id_or_name (workflow_id_or_name , workflow_name ),
281289 input_variables = input_variables or {},
282290 metadata = metadata ,
283291 workflow_label_name = workflow_label_name ,
284292 workflow_version_number = workflow_version ,
285- api_key = self .api_key ,
286293 return_all_outputs = return_all_outputs ,
287294 )
288295 )
@@ -330,6 +337,7 @@ def log_request(
330337 ):
331338 return util_log_request (
332339 self .api_key ,
340+ self .base_url ,
333341 provider = provider ,
334342 model = model ,
335343 input = input ,
@@ -354,9 +362,7 @@ def log_request(
354362
355363class AsyncPromptLayer (PromptLayerMixin ):
356364 def __init__ (
357- self ,
358- api_key : str = None ,
359- enable_tracing : bool = False ,
365+ self , api_key : Union [str , None ] = None , enable_tracing : bool = False , base_url : Union [str , None ] = None
360366 ):
361367 if api_key is None :
362368 api_key = os .environ .get ("PROMPTLAYER_API_KEY" )
@@ -367,31 +373,30 @@ def __init__(
367373 "Please set the PROMPTLAYER_API_KEY environment variable or pass the api_key parameter."
368374 )
369375
376+ self .base_url = get_base_url (base_url )
370377 self .api_key = api_key
371- self .templates = AsyncTemplateManager (api_key )
372- self .group = AsyncGroupManager (api_key )
373- self .tracer_provider , self .tracer = self ._initialize_tracer (api_key , enable_tracing )
374- self .track = AsyncTrackManager (api_key )
378+ self .templates = AsyncTemplateManager (api_key , self . base_url )
379+ self .group = AsyncGroupManager (api_key , self . base_url )
380+ self .tracer_provider , self .tracer = self ._initialize_tracer (api_key , self . base_url , enable_tracing )
381+ self .track = AsyncTrackManager (api_key , self . base_url )
375382
376383 def __getattr__ (self , name : Union [Literal ["openai" ], Literal ["anthropic" ], Literal ["prompts" ]]):
377384 if name == "openai" :
378385 import openai as openai_module
379386
380387 openai = PromptLayerBase (
381- openai_module ,
382- function_name = "openai" ,
383- api_key = self .api_key ,
384- tracer = self .tracer ,
388+ self .api_key , self .base_url , openai_module , function_name = "openai" , tracer = self .tracer
385389 )
386390 return openai
387391 elif name == "anthropic" :
388392 import anthropic as anthropic_module
389393
390394 anthropic = PromptLayerBase (
395+ self .api_key ,
396+ self .base_url ,
391397 anthropic_module ,
392398 function_name = "anthropic" ,
393399 provider_type = "anthropic" ,
394- api_key = self .api_key ,
395400 tracer = self .tracer ,
396401 )
397402 return anthropic
@@ -413,12 +418,13 @@ async def run_workflow(
413418 ) -> Union [Dict [str , Any ], Any ]:
414419 try :
415420 return await arun_workflow_request (
421+ api_key = self .api_key ,
422+ base_url = self .base_url ,
416423 workflow_id_or_name = _get_workflow_workflow_id_or_name (workflow_id_or_name , workflow_name ),
417424 input_variables = input_variables or {},
418425 metadata = metadata ,
419426 workflow_label_name = workflow_label_name ,
420427 workflow_version_number = workflow_version ,
421- api_key = self .api_key ,
422428 return_all_outputs = return_all_outputs ,
423429 )
424430 except Exception as ex :
@@ -491,6 +497,7 @@ async def log_request(
491497 ):
492498 return await autil_log_request (
493499 self .api_key ,
500+ self .base_url ,
494501 provider = provider ,
495502 model = model ,
496503 input = input ,
@@ -530,7 +537,7 @@ async def _track_request(**body):
530537 pl_run_span_id ,
531538 ** body ,
532539 )
533- return await atrack_request (** track_request_kwargs )
540+ return await atrack_request (self . base_url , ** track_request_kwargs )
534541
535542 return _track_request
536543
@@ -554,7 +561,7 @@ async def _track_request_log(
554561 metadata = metadata ,
555562 ** body ,
556563 )
557- return await atrack_request (** track_request_kwargs )
564+ return await atrack_request (self . base_url , ** track_request_kwargs )
558565
559566 async def _run_internal (
560567 self ,
@@ -631,6 +638,6 @@ async def _run_internal(
631638
632639 return {
633640 "request_id" : request_log .get ("request_id" , None ),
634- "raw_response" : request_response ,
641+ "raw_response" : response ,
635642 "prompt_blueprint" : request_log .get ("prompt_blueprint" , None ),
636643 }
0 commit comments