1+ import asyncio
2+ from typing import Any , Dict , List , Optional , Union
3+
4+ import httpx
5+
6+ class PlexeClient :
7+ def __init__ (self , api_key : Optional [str ] = None , timeout : float = 120.0 ):
8+ self .api_key = api_key
9+ if not api_key :
10+ import os
11+ self .api_key = os .environ .get ("PLEXE_API_KEY" )
12+ if not self .api_key :
13+ raise ValueError ("PLEXE_API_KEY must be provided or set as environment variable" )
14+
15+ self .base_url = "https://api.plexe.ai/v1"
16+ self .client = httpx .Client (timeout = timeout )
17+ self .async_client = httpx .AsyncClient (timeout = timeout )
18+
19+ def _headers (self ):
20+ return {
21+ "Authorization" : f"Bearer { self .api_key } " ,
22+ "Content-Type" : "application/json" ,
23+ }
24+
25+ def create (self , task_description : str ) -> tuple [str , int , str ]:
26+ """Create a new ML model from a task description.
27+
28+ Args:
29+ task_description: Description of what the model should do
30+
31+ Returns:
32+ Tuple of (model_id, version, description)
33+ """
34+ if not task_description :
35+ raise ValueError ("Task description must be provided" )
36+
37+ response = self .client .post (
38+ f"{ self .base_url } /create" ,
39+ json = {"description" : task_description },
40+ headers = self ._headers ()
41+ )
42+ response .raise_for_status ()
43+ data = response .json ()
44+ return data ["model_id" ], data ["version" ], data ["description" ]
45+
46+ async def acreate (self , task_description : str ) -> tuple [str , int , str ]:
47+ """Async version of create()"""
48+ if not task_description :
49+ raise ValueError ("Task description must be provided" )
50+
51+ response = await self .async_client .post (
52+ f"{ self .base_url } /create" ,
53+ json = {"description" : task_description },
54+ headers = self ._headers ()
55+ )
56+ response .raise_for_status ()
57+ data = response .json ()
58+ return data ["model_id" ], data ["version" ], data ["description" ]
59+
60+ def run (self , model_id : str , text_input : str = "" , version : int = - 1 ) -> Dict [str , Any ]:
61+ """Run predictions using a model.
62+
63+ Args:
64+ model_id: ID of the model to use
65+ text_input: Input text for the model
66+ version: Model version to use (-1 for latest)
67+
68+ Returns:
69+ Dictionary containing prediction results
70+ """
71+ response = self .client .post (
72+ f"{ self .base_url } /run" ,
73+ json = {
74+ "model_id" : model_id ,
75+ "text" : text_input ,
76+ "version" : version
77+ },
78+ headers = self ._headers ()
79+ )
80+ response .raise_for_status ()
81+ return response .json ()
82+
83+ async def arun (self , model_id : str , text_input : str = "" , version : int = - 1 ) -> Dict [str , Any ]:
84+ """Async version of run()"""
85+ response = await self .async_client .post (
86+ f"{ self .base_url } /run" ,
87+ json = {
88+ "model_id" : model_id ,
89+ "text" : text_input ,
90+ "version" : version
91+ },
92+ headers = self ._headers ()
93+ )
94+ response .raise_for_status ()
95+ return response .json ()
96+
97+ def batch_run (self , model_id : str , inputs : List [Dict [str , Any ]], version : int = - 1 ) -> List [Dict [str , Any ]]:
98+ """Run batch predictions.
99+
100+ Args:
101+ model_id: ID of the model to use
102+ inputs: List of input dictionaries
103+ version: Model version to use (-1 for latest)
104+
105+ Returns:
106+ List of prediction results
107+ """
108+ async def run_batch ():
109+ tasks = [
110+ self .arun (model_id = model_id , text_input = x .get ("text" , "" ), version = version )
111+ for x in inputs
112+ ]
113+ return await asyncio .gather (* tasks )
114+
115+ return asyncio .run (run_batch ())
0 commit comments