Skip to content

Commit 0ce30ed

Browse files
committed
Add tests
1 parent 4b2dc59 commit 0ce30ed

3 files changed

Lines changed: 183 additions & 1 deletion

File tree

‎plexe/tests/__init__.py‎

Whitespace-only changes.

‎plexe/tests/test_integration.py‎

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import os
2+
import time
3+
import pytest
4+
import asyncio
5+
from pathlib import Path
6+
from plexe import PlexeAI, build, abuild, infer, ainfer, batch_infer
7+
8+
API_KEY = os.getenv("PLEXE_API_KEY") or ""
9+
if not API_KEY:
10+
pytest.skip("PLEXE_API_KEY environment variable not set", allow_module_level=True)
11+
12+
TEST_MODEL_NAME = "test_prediction_model"
13+
TEST_GOAL = "Predict the outcomes of english premier league games based on prior results using the attached dataset"
14+
15+
@pytest.fixture
16+
def client():
17+
"""Create a PlexeAI client instance for testing."""
18+
return PlexeAI(api_key=API_KEY)
19+
20+
@pytest.fixture
21+
def sample_data_file(tmp_path):
22+
"""Create a temporary sample data file for testing."""
23+
data_file = tmp_path / "test_data.csv"
24+
data_content = """text,sentiment
25+
This product is amazing!,positive
26+
I love this service,positive
27+
Terrible experience,negative
28+
Not worth the money,negative
29+
Pretty good overall,positive"""
30+
data_file.write_text(data_content)
31+
return data_file
32+
33+
@pytest.fixture
34+
def sample_input_data():
35+
"""Sample input data for inference testing."""
36+
return {"text": "This is a great product!"}
37+
38+
def wait_for_model(client, model_name: str, model_version: str, timeout: int = 300):
39+
"""Wait for model to be ready."""
40+
start_time = time.time()
41+
while time.time() - start_time < timeout:
42+
status = client.get_status(model_name, model_version)
43+
if status["status"] == "completed":
44+
return True
45+
elif status["status"] == "failed":
46+
raise Exception(f"Model failed: {status.get('error', 'Unknown error')}")
47+
time.sleep(10)
48+
raise TimeoutError(f"Model did not complete within {timeout} seconds")
49+
50+
async def async_wait_for_model(client, model_name: str, model_version: str, timeout: int = 300):
51+
"""Wait for model to be ready asynchronously."""
52+
start_time = time.time()
53+
while time.time() - start_time < timeout:
54+
status = await client.aget_status(model_name, model_version)
55+
if status["status"] == "completed":
56+
return True
57+
elif status["status"] == "failed":
58+
raise Exception(f"Model failed: {status.get('error', 'Unknown error')}")
59+
await asyncio.sleep(10)
60+
raise TimeoutError(f"Model did not complete within {timeout} seconds")
61+
62+
class TestPlexeAIIntegration:
63+
"""Integration tests for PlexeAI client."""
64+
65+
def test_client_initialization(self):
66+
"""Test client initialization with API key."""
67+
client = PlexeAI(api_key=API_KEY)
68+
assert client.api_key == API_KEY
69+
assert client.base_url == "https://api.plexe.ai/v0"
70+
71+
def test_build_and_inference_flow(self, client, sample_data_file, sample_input_data):
72+
"""Test full flow: build model with direct data files to avoid timing issues."""
73+
try:
74+
model_version = build(
75+
goal=TEST_GOAL,
76+
model_name=TEST_MODEL_NAME,
77+
upload_id="2d4da8f9-aaf1-4262-a36c-5e9167ca4d5b",
78+
api_key=API_KEY
79+
)
80+
assert isinstance(model_version, str)
81+
82+
# Wait for model to be ready
83+
wait_for_model(client, TEST_MODEL_NAME, model_version)
84+
85+
# Run inference
86+
result = infer(
87+
model_name=TEST_MODEL_NAME,
88+
model_version=model_version,
89+
input_data=sample_input_data,
90+
api_key=API_KEY
91+
)
92+
assert isinstance(result, dict)
93+
assert "prediction" in result
94+
95+
# Run batch inference
96+
batch_inputs = [
97+
{"text": "Great service!"},
98+
{"text": "Not satisfied with the product"}
99+
]
100+
results = batch_infer(
101+
model_name=TEST_MODEL_NAME,
102+
model_version=model_version,
103+
inputs=batch_inputs,
104+
api_key=API_KEY
105+
)
106+
assert isinstance(results, list)
107+
assert len(results) == len(batch_inputs)
108+
109+
except Exception as e:
110+
raise e
111+
112+
@pytest.mark.asyncio
113+
async def test_async_build_and_inference_flow(self, client, sample_data_file, sample_input_data):
114+
"""Test full async flow: build model with direct data files to avoid timing issues."""
115+
try:
116+
# Build model asynchronously using data_files directly
117+
model_version = await abuild(
118+
goal=TEST_GOAL,
119+
model_name=f"{TEST_MODEL_NAME}_async",
120+
upload_id="2d4da8f9-aaf1-4262-a36c-5e9167ca4d5b",
121+
api_key=API_KEY
122+
)
123+
assert isinstance(model_version, str)
124+
125+
# Wait for model to be ready
126+
await async_wait_for_model(client, f"{TEST_MODEL_NAME}_async", model_version)
127+
128+
# Run inference asynchronously
129+
result = await ainfer(
130+
model_name=f"{TEST_MODEL_NAME}_async",
131+
model_version=model_version,
132+
input_data=sample_input_data,
133+
api_key=API_KEY
134+
)
135+
assert isinstance(result, dict)
136+
assert "prediction" in result
137+
138+
# Optional batch inference test
139+
batch_inputs = [
140+
{"text": "Great service!"},
141+
{"text": "Not satisfied with the product"}
142+
]
143+
results = batch_infer(
144+
model_name=TEST_MODEL_NAME,
145+
model_version=model_version,
146+
inputs=batch_inputs,
147+
api_key=API_KEY
148+
)
149+
assert isinstance(results, list)
150+
assert len(results) == len(batch_inputs)
151+
152+
except Exception as e:
153+
raise e
154+
155+
def test_file_upload_and_cleanup(self, client, sample_data_file):
156+
"""Test file upload and cleanup."""
157+
upload_id = client.upload_files(sample_data_file)
158+
assert isinstance(upload_id, str)
159+
160+
# Wait a bit to ensure file is processed
161+
time.sleep(2)
162+
163+
cleanup_result = client.cleanup_upload(upload_id)
164+
assert isinstance(cleanup_result, dict)
165+
166+
def test_error_handling(self, client):
167+
"""Test error handling for invalid requests."""
168+
with pytest.raises(ValueError):
169+
build(
170+
goal=TEST_GOAL,
171+
model_name=TEST_MODEL_NAME,
172+
data_files=None,
173+
upload_id=None,
174+
api_key=API_KEY
175+
)
176+
177+
with pytest.raises(ValueError):
178+
build(
179+
goal=TEST_GOAL,
180+
model_name=TEST_MODEL_NAME,
181+
data_files="nonexistent.csv",
182+
api_key=API_KEY
183+
)

‎pyproject.toml‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ dev = [
2727
[tool.pytest.ini_options]
2828
testpaths = ["tests"]
2929
python_files = ["test_*.py"]
30-
addopts = "-v --cov=plexeai --cov-report=term-missing"
3130

3231
[tool.coverage.run]
3332
source = ["plexe"]

0 commit comments

Comments
 (0)