forked from openai/openai-agents-python
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_max_turns.py
More file actions
127 lines (108 loc) · 3.72 KB
/
test_max_turns.py
File metadata and controls
127 lines (108 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations
import json
import pytest
from typing_extensions import TypedDict
from agents import Agent, MaxTurnsExceeded, Runner
from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
@pytest.mark.asyncio
async def test_non_streamed_max_turns():
    """Expect ``MaxTurnsExceeded`` from ``Runner.run`` when ``max_turns=3``.

    Five turn outputs are registered on the fake model; each one pairs a
    text message ("1" through "5") with a call to ``some_function``.
    """
    tool_args = json.dumps({"a": "b"})

    fake_model = FakeModel()
    fake_model.add_multiple_turn_outputs(
        [
            [
                get_text_message(label),
                get_function_tool_call("some_function", tool_args),
            ]
            for label in ("1", "2", "3", "4", "5")
        ]
    )

    agent_under_test = Agent(
        name="test_1",
        model=fake_model,
        tools=[get_function_tool("some_function", "result")],
    )

    with pytest.raises(MaxTurnsExceeded):
        await Runner.run(agent_under_test, input="user_message", max_turns=3)
@pytest.mark.asyncio
async def test_streamed_max_turns():
    """Expect ``MaxTurnsExceeded`` on the streamed path when ``max_turns=3``.

    Five turn outputs are registered on the fake model; each one pairs a
    text message ("1" through "5") with a call to ``some_function``. Both
    the ``Runner.run_streamed`` call and the draining of its event stream
    sit inside the ``pytest.raises`` block.
    """
    tool_args = json.dumps({"a": "b"})

    fake_model = FakeModel()
    fake_model.add_multiple_turn_outputs(
        [
            [
                get_text_message(label),
                get_function_tool_call("some_function", tool_args),
            ]
            for label in ("1", "2", "3", "4", "5")
        ]
    )

    agent_under_test = Agent(
        name="test_1",
        model=fake_model,
        tools=[get_function_tool("some_function", "result")],
    )

    with pytest.raises(MaxTurnsExceeded):
        streamed_run = Runner.run_streamed(
            agent_under_test,
            input="user_message",
            max_turns=3,
        )
        # The events themselves are ignored; the stream is only consumed.
        async for _ in streamed_run.stream_events():
            pass
# Structured output type passed as ``output_type`` to the agents in the
# structured-output max-turns tests in this file.
class Foo(TypedDict):
    # The only declared key; a string value.
    a: str
@pytest.mark.asyncio
async def test_structured_output_non_streamed_max_turns():
    """Expect ``MaxTurnsExceeded`` from ``Runner.run`` with ``output_type=Foo``.

    Five turn outputs are registered on the fake model, each consisting of
    a single call to ``tool_1``; the run is limited to ``max_turns=3``.
    """
    fake_model = FakeModel()
    fake_model.add_multiple_turn_outputs(
        [[get_function_tool_call("tool_1")] for _ in range(5)]
    )

    agent_under_test = Agent(
        name="test_1",
        model=fake_model,
        output_type=Foo,
        tools=[get_function_tool("tool_1", "result")],
    )

    with pytest.raises(MaxTurnsExceeded):
        await Runner.run(agent_under_test, input="user_message", max_turns=3)
@pytest.mark.asyncio
async def test_structured_output_streamed_max_turns():
    """Expect ``MaxTurnsExceeded`` on the streamed path with ``output_type=Foo``.

    Five turn outputs are registered on the fake model, each consisting of
    a single call to ``tool_1``; the run is limited to ``max_turns=3``. Both
    the ``Runner.run_streamed`` call and the draining of its event stream
    sit inside the ``pytest.raises`` block.
    """
    fake_model = FakeModel()
    fake_model.add_multiple_turn_outputs(
        [[get_function_tool_call("tool_1")] for _ in range(5)]
    )

    agent_under_test = Agent(
        name="test_1",
        model=fake_model,
        output_type=Foo,
        tools=[get_function_tool("tool_1", "result")],
    )

    with pytest.raises(MaxTurnsExceeded):
        streamed_run = Runner.run_streamed(
            agent_under_test,
            input="user_message",
            max_turns=3,
        )
        # The events themselves are ignored; the stream is only consumed.
        async for _ in streamed_run.stream_events():
            pass