1+ import os
2+ import time
3+ import pytest
4+ import asyncio
5+ from pathlib import Path
6+ from plexe import PlexeAI , build , abuild , infer , ainfer , batch_infer
7+
8+ API_KEY = os .getenv ("PLEXE_API_KEY" ) or ""
9+ if not API_KEY :
10+ pytest .skip ("PLEXE_API_KEY environment variable not set" , allow_module_level = True )
11+
12+ TEST_MODEL_NAME = "test_prediction_model"
13+ TEST_GOAL = "Predict the outcomes of english premier league games based on prior results using the attached dataset"
14+
15+ @pytest .fixture
16+ def client ():
17+ """Create a PlexeAI client instance for testing."""
18+ return PlexeAI (api_key = API_KEY )
19+
20+ @pytest .fixture
21+ def sample_data_file (tmp_path ):
22+ """Create a temporary sample data file for testing."""
23+ data_file = tmp_path / "test_data.csv"
24+ data_content = """text,sentiment
25+ This product is amazing!,positive
26+ I love this service,positive
27+ Terrible experience,negative
28+ Not worth the money,negative
29+ Pretty good overall,positive"""
30+ data_file .write_text (data_content )
31+ return data_file
32+
33+ @pytest .fixture
34+ def sample_input_data ():
35+ """Sample input data for inference testing."""
36+ return {"text" : "This is a great product!" }
37+
38+ def wait_for_model (client , model_name : str , model_version : str , timeout : int = 300 ):
39+ """Wait for model to be ready."""
40+ start_time = time .time ()
41+ while time .time () - start_time < timeout :
42+ status = client .get_status (model_name , model_version )
43+ if status ["status" ] == "completed" :
44+ return True
45+ elif status ["status" ] == "failed" :
46+ raise Exception (f"Model failed: { status .get ('error' , 'Unknown error' )} " )
47+ time .sleep (10 )
48+ raise TimeoutError (f"Model did not complete within { timeout } seconds" )
49+
50+ async def async_wait_for_model (client , model_name : str , model_version : str , timeout : int = 300 ):
51+ """Wait for model to be ready asynchronously."""
52+ start_time = time .time ()
53+ while time .time () - start_time < timeout :
54+ status = await client .aget_status (model_name , model_version )
55+ if status ["status" ] == "completed" :
56+ return True
57+ elif status ["status" ] == "failed" :
58+ raise Exception (f"Model failed: { status .get ('error' , 'Unknown error' )} " )
59+ await asyncio .sleep (10 )
60+ raise TimeoutError (f"Model did not complete within { timeout } seconds" )
61+
62+ class TestPlexeAIIntegration :
63+ """Integration tests for PlexeAI client."""
64+
65+ def test_client_initialization (self ):
66+ """Test client initialization with API key."""
67+ client = PlexeAI (api_key = API_KEY )
68+ assert client .api_key == API_KEY
69+ assert client .base_url == "https://api.plexe.ai/v0"
70+
71+ def test_build_and_inference_flow (self , client , sample_data_file , sample_input_data ):
72+ """Test full flow: build model with direct data files to avoid timing issues."""
73+ try :
74+ model_version = build (
75+ goal = TEST_GOAL ,
76+ model_name = TEST_MODEL_NAME ,
77+ upload_id = "2d4da8f9-aaf1-4262-a36c-5e9167ca4d5b" ,
78+ api_key = API_KEY
79+ )
80+ assert isinstance (model_version , str )
81+
82+ # Wait for model to be ready
83+ wait_for_model (client , TEST_MODEL_NAME , model_version )
84+
85+ # Run inference
86+ result = infer (
87+ model_name = TEST_MODEL_NAME ,
88+ model_version = model_version ,
89+ input_data = sample_input_data ,
90+ api_key = API_KEY
91+ )
92+ assert isinstance (result , dict )
93+ assert "prediction" in result
94+
95+ # Run batch inference
96+ batch_inputs = [
97+ {"text" : "Great service!" },
98+ {"text" : "Not satisfied with the product" }
99+ ]
100+ results = batch_infer (
101+ model_name = TEST_MODEL_NAME ,
102+ model_version = model_version ,
103+ inputs = batch_inputs ,
104+ api_key = API_KEY
105+ )
106+ assert isinstance (results , list )
107+ assert len (results ) == len (batch_inputs )
108+
109+ except Exception as e :
110+ raise e
111+
112+ @pytest .mark .asyncio
113+ async def test_async_build_and_inference_flow (self , client , sample_data_file , sample_input_data ):
114+ """Test full async flow: build model with direct data files to avoid timing issues."""
115+ try :
116+ # Build model asynchronously using data_files directly
117+ model_version = await abuild (
118+ goal = TEST_GOAL ,
119+ model_name = f"{ TEST_MODEL_NAME } _async" ,
120+ upload_id = "2d4da8f9-aaf1-4262-a36c-5e9167ca4d5b" ,
121+ api_key = API_KEY
122+ )
123+ assert isinstance (model_version , str )
124+
125+ # Wait for model to be ready
126+ await async_wait_for_model (client , f"{ TEST_MODEL_NAME } _async" , model_version )
127+
128+ # Run inference asynchronously
129+ result = await ainfer (
130+ model_name = f"{ TEST_MODEL_NAME } _async" ,
131+ model_version = model_version ,
132+ input_data = sample_input_data ,
133+ api_key = API_KEY
134+ )
135+ assert isinstance (result , dict )
136+ assert "prediction" in result
137+
138+ # Optional batch inference test
139+ batch_inputs = [
140+ {"text" : "Great service!" },
141+ {"text" : "Not satisfied with the product" }
142+ ]
143+ results = batch_infer (
144+ model_name = TEST_MODEL_NAME ,
145+ model_version = model_version ,
146+ inputs = batch_inputs ,
147+ api_key = API_KEY
148+ )
149+ assert isinstance (results , list )
150+ assert len (results ) == len (batch_inputs )
151+
152+ except Exception as e :
153+ raise e
154+
155+ def test_file_upload_and_cleanup (self , client , sample_data_file ):
156+ """Test file upload and cleanup."""
157+ upload_id = client .upload_files (sample_data_file )
158+ assert isinstance (upload_id , str )
159+
160+ # Wait a bit to ensure file is processed
161+ time .sleep (2 )
162+
163+ cleanup_result = client .cleanup_upload (upload_id )
164+ assert isinstance (cleanup_result , dict )
165+
166+ def test_error_handling (self , client ):
167+ """Test error handling for invalid requests."""
168+ with pytest .raises (ValueError ):
169+ build (
170+ goal = TEST_GOAL ,
171+ model_name = TEST_MODEL_NAME ,
172+ data_files = None ,
173+ upload_id = None ,
174+ api_key = API_KEY
175+ )
176+
177+ with pytest .raises (ValueError ):
178+ build (
179+ goal = TEST_GOAL ,
180+ model_name = TEST_MODEL_NAME ,
181+ data_files = "nonexistent.csv" ,
182+ api_key = API_KEY
183+ )
0 commit comments