
Commit 12060b1

Hoder-zyf and you-n-g authored
fix: improve the logic of json_schema and refine the reasoning extraction logic for reasoning model (#1044)
* fix: fix a small bug in response_schema
* feat: support response_format parameter in chat completion
* fix: fix between json_mode and response_format
* Update base.py
* Update deprec.py
* add unittest and refine logic
* fix the reasoning extraction logic and refine prompt for deepseek adaptation
* refactor: introduce workflow_check and streamline task parsing
* refine prompt

Co-authored-by: Young <afe.young@gmail.com>
1 parent 393002a · commit 12060b1

7 files changed

Lines changed: 164 additions & 50 deletions


rdagent/oai/backend/base.py

Lines changed: 34 additions & 10 deletions
```diff
@@ -11,7 +11,7 @@
 from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, cast
+from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast
 
 import pytz
 from pydantic import BaseModel, TypeAdapter
@@ -36,13 +36,14 @@
 class JSONParser:
     """JSON parser supporting multiple strategies"""
 
-    def __init__(self) -> None:
+    def __init__(self, add_json_in_prompt: bool = False) -> None:
         self.strategies: List[Callable[[str], str]] = [
             self._direct_parse,
             self._extract_from_code_block,
             self._fix_python_syntax,
             self._extract_with_fix_combined,
         ]
+        self.add_json_in_prompt = add_json_in_prompt
 
     def parse(self, content: str) -> str:
         """Parse JSON content, automatically trying multiple strategies"""
@@ -55,7 +56,16 @@ def parse(self, content: str) -> str:
                 continue
 
         # All strategies failed
-        raise json.JSONDecodeError("Failed to parse JSON after all attempts", original_content, 0)
+        if not self.add_json_in_prompt:
+            error = json.JSONDecodeError(
+                "Failed to parse JSON after all attempts, maybe because 'messages' must contain the word 'json' in some form",
+                original_content,
+                0,
+            )
+            error.message = "Failed to parse JSON after all attempts, maybe because 'messages' must contain the word 'json' in some form"  # type: ignore[attr-defined]
+            raise error
+        else:
+            raise json.JSONDecodeError("Failed to parse JSON after all attempts", original_content, 0)
 
     def _direct_parse(self, content: str) -> str:
         """Strategy 1: Direct parsing (including handling extra data)"""
```
```diff
@@ -528,12 +538,16 @@ def _create_chat_completion_auto_continue(
         seed: Optional[int] = None,
         json_target_type: Optional[str] = None,
         add_json_in_prompt: bool = False,
+        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         **kwargs: Any,
     ) -> str:
         """
         Call the chat completion function and automatically continue the conversation if the finish_reason is length.
         """
 
+        if response_format is None and json_mode:
+            response_format = {"type": "json_object"}
+
         # 0) return directly if cache is hit
         if seed is None and LLM_SETTINGS.use_auto_chat_cache_seed_gen:
             seed = LLM_CACHE_SEED_GEN.get_next_seed()
```
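With this change, `json_mode` survives only as a back-compat alias: when the caller passes no `response_format`, `json_mode=True` is mapped onto OpenAI's `{"type": "json_object"}` before any request goes out. A minimal sketch of that normalization rule (the helper name is hypothetical, not part of the commit):

```python
from typing import Optional, Type, Union

from pydantic import BaseModel


def normalize_response_format(
    json_mode: bool,
    response_format: Optional[Union[dict, Type[BaseModel]]] = None,
) -> Optional[Union[dict, Type[BaseModel]]]:
    """Map the legacy json_mode flag onto response_format (hypothetical helper)."""
    if response_format is None and json_mode:
        return {"type": "json_object"}
    return response_format  # an explicit response_format always wins


assert normalize_response_format(json_mode=True) == {"type": "json_object"}
assert normalize_response_format(json_mode=False) is None
```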
```diff
@@ -555,11 +569,11 @@
         # Loop to get a full response
         try_n = 6
         for _ in range(try_n):  # for some long code, 3 times may not enough for reasoning models
-            if json_mode and add_json_in_prompt:
+            if response_format == {"type": "json_object"} and add_json_in_prompt:
                 self._add_json_in_prompt(new_messages)
             response, finish_reason = self._create_chat_completion_inner_function(
                 messages=new_messages,
-                json_mode=json_mode,
+                response_format=response_format,
                 **kwargs,
             )
             all_response += response
```
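The surrounding loop keeps requesting and concatenating output while the provider reports `finish_reason == "length"`. A self-contained sketch of that auto-continue pattern, with `call_model` as a hypothetical stub for `_create_chat_completion_inner_function`:

```python
def call_model(messages: list[dict[str, str]]) -> tuple[str, str]:
    """Hypothetical stub; the real call goes to the LLM backend."""
    return '{"answer": 42}', "stop"


messages = [{"role": "user", "content": "Return a JSON object."}]
all_response = ""
for _ in range(6):  # reasoning models may need more than 3 rounds for long code
    response, finish_reason = call_model(messages)
    all_response += response
    if finish_reason != "length":
        break
    # Feed the partial answer back so the next call continues where it stopped.
    messages.append({"role": "assistant", "content": response})
```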
```diff
@@ -571,21 +585,31 @@
 
         # 2) refine the response and return
         if LLM_SETTINGS.reasoning_think_rm:
+            # Strategy 1: Try to match complete <think>...</think> pattern
             match = re.search(r"<think>(.*?)</think>(.*)", all_response, re.DOTALL)
-            _, all_response = match.groups() if match else ("", all_response)
+            if match:
+                _, all_response = match.groups()
+            else:
+                # Strategy 2: If no complete match, try to match only </think>
+                match = re.search(r"</think>(.*)", all_response, re.DOTALL)
+                if match:
+                    all_response = match.group(1)
+                # If no match at all, keep original content
 
         # 3) format checking
-        if json_mode or json_target_type:
-            parser = JSONParser()
+        if response_format == {"type": "json_object"} or json_target_type:
+            parser = JSONParser(add_json_in_prompt=add_json_in_prompt)
             all_response = parser.parse(all_response)
         if json_target_type:
             # deepseek will enter this branch
             TypeAdapter(json_target_type).validate_json(all_response)
 
-        if (response_format := kwargs.get("response_format")) is not None:
+        if response_format is not None:
             if not isinstance(response_format, dict) and issubclass(response_format, BaseModel):
                 # It may raise TypeError if initialization fails
                 response_format(**json.loads(all_response))
+            elif response_format == {"type": "json_object"}:
+                logger.info(f"Using OpenAI response format: {response_format}")
             else:
                 logger.warning(f"Unknown response_format: {response_format}, skipping validation.")
         if self.dump_chat_cache:
```
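The two-stage extraction matters for reasoning models that emit only a closing `</think>` tag (the commit message cites DeepSeek adaptation). A runnable sketch of the same logic; the function name is hypothetical, since the commit inlines this in `_create_chat_completion_auto_continue`:

```python
import re


def strip_reasoning(all_response: str) -> str:
    """Hypothetical helper mirroring the two-stage <think> removal above."""
    # Strategy 1: a complete <think>...</think> pair.
    match = re.search(r"<think>(.*?)</think>(.*)", all_response, re.DOTALL)
    if match:
        return match.group(2)
    # Strategy 2: only a closing </think> (some models omit the opening tag).
    match = re.search(r"</think>(.*)", all_response, re.DOTALL)
    if match:
        return match.group(1)
    # No match at all: keep the original content.
    return all_response


assert strip_reasoning('<think>reasoning</think>{"a": 1}') == '{"a": 1}'
assert strip_reasoning('reasoning</think>{"a": 1}') == '{"a": 1}'
assert strip_reasoning('{"a": 1}') == '{"a": 1}'
```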
```diff
@@ -642,7 +666,7 @@ def _create_embedding_inner_function(  # type: ignore[no-untyped-def]
     def _create_chat_completion_inner_function(  # type: ignore[no-untyped-def]  # noqa: C901, PLR0912, PLR0915
         self,
         messages: list[dict[str, Any]],
-        json_mode: bool = False,
+        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         *args,
         **kwargs,
     ) -> tuple[str, str | None]:
```
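Closing out this file: `JSONParser.parse()` (first hunks above) walks an ordered strategy list and raises only after every strategy has failed, with the new `add_json_in_prompt` flag choosing the more diagnostic error message. A runnable sketch of that strategy-chain pattern (a simplified stand-in for the committed class, which registers four strategies):

```python
import json
from typing import Callable, List


class MiniJSONParser:
    """Simplified stand-in for base.py's JSONParser (illustration only)."""

    def __init__(self, add_json_in_prompt: bool = False) -> None:
        # The real class registers four strategies; one is enough to show the chain.
        self.strategies: List[Callable[[str], str]] = [self._direct_parse]
        self.add_json_in_prompt = add_json_in_prompt

    def parse(self, content: str) -> str:
        for strategy in self.strategies:
            try:
                return strategy(content)
            except json.JSONDecodeError:
                continue  # fall through to the next strategy
        # All strategies failed: hint at the likely cause when the caller
        # never injected the word "json" into the prompt.
        msg = "Failed to parse JSON after all attempts"
        if not self.add_json_in_prompt:
            msg += ", maybe because 'messages' must contain the word 'json' in some form"
        raise json.JSONDecodeError(msg, content, 0)

    def _direct_parse(self, content: str) -> str:
        json.loads(content)  # raises json.JSONDecodeError on invalid input
        return content


print(MiniJSONParser().parse('{"a": 1}'))  # -> {"a": 1}
```

Keeping each recovery heuristic as a separate strategy keeps the fallback order explicit and easy to extend.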

rdagent/oai/backend/deprec.py

Lines changed: 11 additions & 9 deletions
```diff
@@ -12,12 +12,13 @@
 import uuid
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Optional, cast
+from typing import Any, Optional, Type, Union, cast
 
 import numpy as np
 import openai
 import tiktoken
 from openai.types.chat import ChatCompletion
+from pydantic import BaseModel
 
 from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass, import_class
 from rdagent.log import LogColors
@@ -294,7 +295,7 @@ def _create_embedding_inner_function(  # type: ignore[no-untyped-def]
     def _create_chat_completion_inner_function(  # type: ignore[no-untyped-def]  # noqa: C901, PLR0912, PLR0915
         self,
         messages: list[dict[str, Any]],
-        json_mode: bool = False,
+        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         add_json_in_prompt: bool = False,
         *args,
         **kwargs,
```
```diff
@@ -414,13 +415,14 @@ def _create_chat_completion_inner_function(  # type: ignore[no-untyped-def]  # no
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
-        if json_mode:
-            if add_json_in_prompt:
-                for message in messages[::-1]:
-                    message["content"] = message["content"] + "\nPlease respond in json format."
-                    if message["role"] == LLM_SETTINGS.system_prompt_role:
-                        # NOTE: assumption: systemprompt is always the first message
-                        break
+
+        # FIX what if the model does not support response_schema
+        if response_format == {"type": "json_object"} and add_json_in_prompt:
+            for message in messages[::-1]:
+                message["content"] = message["content"] + "\nPlease respond in json format."
+                if message["role"] == LLM_SETTINGS.system_prompt_role:
+                    # NOTE: assumption: systemprompt is always the first message
+                    break
             call_kwargs["response_format"] = {"type": "json_object"}
         response = self.chat_client.chat.completions.create(**call_kwargs)
```
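The fallback walks the messages in reverse, appending the JSON reminder until it has patched the system prompt, which is assumed to be the first message. A minimal sketch under that assumption:

```python
# Stand-in for LLM_SETTINGS.system_prompt_role.
system_prompt_role = "system"

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the dataset."},
]

# Append the reminder from the last message backwards; stop once the
# system prompt (assumed first) has been patched.
for message in messages[::-1]:
    message["content"] = message["content"] + "\nPlease respond in json format."
    if message["role"] == system_prompt_role:
        break
```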

rdagent/oai/backend/litellm.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -1,5 +1,5 @@
 import copyreg
-from typing import Any, Literal, cast
+from typing import Any, Literal, Optional, Type, Union, cast
 
 import numpy as np
 from litellm import (
@@ -11,6 +11,7 @@
     supports_response_schema,
     token_counter,
 )
+from pydantic import BaseModel
 
 from rdagent.log import LogColors
 from rdagent.log import rdagent_logger as logger
```
```diff
@@ -86,23 +87,24 @@ def _create_embedding_inner_function(
     def _create_chat_completion_inner_function(  # type: ignore[no-untyped-def]  # noqa: C901, PLR0912, PLR0915
         self,
         messages: list[dict[str, Any]],
-        json_mode: bool = False,
+        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         *args,
         **kwargs,
     ) -> tuple[str, str | None]:
         """
         Call the chat completion function
         """
-        if json_mode and supports_response_schema(model=LITELLM_SETTINGS.chat_model):
-            kwargs["response_format"] = {"type": "json_object"}
 
-        elif not supports_response_schema(model=LITELLM_SETTINGS.chat_model) and "response_format" in kwargs:
+        if response_format and not supports_response_schema(model=LITELLM_SETTINGS.chat_model):
             # Deepseek will enter this branch
             logger.warning(
                 f"{LogColors.RED}Model {LITELLM_SETTINGS.chat_model} does not support response schema, ignoring response_format argument.{LogColors.END}",
                 tag="llm_messages",
             )
-            kwargs.pop("response_format")
+            response_format = None
+
+        if response_format:
+            kwargs["response_format"] = response_format
 
         if LITELLM_SETTINGS.log_llm_chat_content:
             logger.info(self._build_log_messages(messages), tag="llm_messages")
```
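The gate defers to litellm's `supports_response_schema`: models that cannot honor a schema (the diff's comment names DeepSeek) have the argument dropped so the prompt-based JSON fallback applies. A sketch of the rule, assuming litellm is installed; the model name is only an example:

```python
from litellm import supports_response_schema

chat_model = "deepseek/deepseek-chat"  # example model name
response_format: dict | None = {"type": "json_object"}

if response_format and not supports_response_schema(model=chat_model):
    # The model cannot honor a schema; drop it and rely on prompt-based JSON.
    response_format = None

kwargs: dict = {}
if response_format:
    kwargs["response_format"] = response_format
```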

rdagent/scenarios/data_science/proposal/exp_gen/prompts.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -344,5 +344,5 @@ output_format:
     Design a specific and detailed Pipeline task based on the given hypothesis. The output should be detailed enough to directly implement the corresponding code.
     The output should follow JSON format. The schema is as follows:
     {
-        "description": "A precise and comprehensive description of the main workflow script (`main.py`)",
+        "description": "A detailed, step-by-step implementation guide for `main.py` that synthesizes planned modifications and code structure into a comprehensive coding plan. Must be formatted in Markdown with level-3 headings (###) organizing logical sections, key decision points, and implementation steps. Should provide sufficient detail covering implementation flow, algorithms, data handling, and key logic points for unambiguous developer execution.",
     }
```

rdagent/scenarios/data_science/proposal/exp_gen/prompts_v2.yaml

Lines changed: 22 additions & 8 deletions
```diff
@@ -355,23 +355,36 @@ task_gen:
     If you are confident in a specific value based on strong evidence, prior experiments, or clear rationale, specify the value clearly.
     {% include "scenarios.data_science.share:spec.hyperparameter" %}
 
+
     {% if task_output_format is not none %}
-    ## [Partial Response Format 1] Task Output Format:
+
+    # Output Format
+
+    {% if not workflow_check %}
+
+    {{ task_output_format }}
+
+    {% else %}
+
+    There are two steps in the task. But you should adhere to the final output format.
+
+    ## [Partial Response Format 1]
+    ### Step1: **Task Output Format** :
     {{ task_output_format }}
 
-    {% if workflow_check %}
-    # Step 2: Workflow Update
+    ### Step 2: **Workflow Update** :
     Since components have dependencies, your second task is to update the workflow to reflect the changes made to the target component. Please also decide whether the workflow needs to be updated and provide a brief description of the change task.
     {{ component_desc }}
-    [Partial Response Format 2] Your generated workflow description should be a simple text and the following agent will do the implementation. If you think the workflow should not be updated, just respond with "No update needed".
-    {% endif %}
 
-    Your final output should strictly adhere to the following JSON format.
+    ## [Partial Response Format 2] Your generated workflow description should be a simple text and the following agent will do the implementation. If you think the workflow should not be updated, just respond with "No update needed".
+
+    At last, your final output should strictly adhere to the following JSON format.
     {
-        "task_design": ---The dict corresponding to task output format---,
-        {% if workflow_check %}"workflow_update": ---A string corresponding to workflow description--- {% endif %}
+        "task_design": a dict which strictly adheres to the **Task Output Format** in Step 1,
+        "workflow_update": "A string which is a precise and comprehensive description of the Workflow Update, or 'No update needed' if no changes are required."
     }
     {% endif %}
+    {% endif %}
 
   user: |-
     # Competition Scenario Description
@@ -489,3 +502,4 @@ output_format:
     }
 
 
+
```
rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 26 additions & 16 deletions
```diff
@@ -729,11 +729,11 @@ def task_gen(
         else:
             component_info = get_component(hypotheses[0].component)
         data_folder_info = self.scen.processed_data_folder_description
+        workflow_check = not pipeline and hypotheses[0].component != "Workflow"
         sys_prompt = T(".prompts_v2:task_gen.system").r(
             task_output_format=component_info["task_output_format"] if not self.supports_response_schema else None,
-            # task_output_format=component_info["task_output_format"],
             component_desc=component_desc,
-            workflow_check=not pipeline and hypotheses[0].component != "Workflow",
+            workflow_check=workflow_check,
         )
         user_prompt = T(".prompts_v2:task_gen.user").r(
             scenario_desc=scenario_desc,
```
```diff
@@ -743,37 +743,47 @@ def task_gen(
             failed_exp_and_feedback_list_desc=failed_exp_feedback_list_desc,
             eda_improvement=fb_to_sota_exp.eda_improvement if fb_to_sota_exp else None,
         )
+
         response = APIBackend().build_messages_and_create_chat_completion(
             user_prompt=user_prompt,
             system_prompt=sys_prompt,
             response_format=CodingSketch if self.supports_response_schema else {"type": "json_object"},
             json_target_type=Dict[str, str | Dict[str, str]] if not self.supports_response_schema else None,
         )
+
         task_dict = json.loads(response)
-        task_design = (
-            task_dict.get("task_design", {}) if not self.supports_response_schema else task_dict.get("sketch", {})
-        )
-        logger.info(f"Task design:\n{task_design}")
+
+        # 1) explain the response and get main task_description
+        not_found_str = f"{component_info['target_name']} description not provided"
+        if self.supports_response_schema:
+            # task_dict: {"sketch": str, ...}
+            task_desc = task_dict.get("sketch", not_found_str)
+        else:
+            if workflow_check:
+                # task_dict: {"task_design": ...., "workflow_update": ....}
+                task_desc = task_dict.get("task_design", {}).get("description", not_found_str)
+            else:
+                # task_dict: {"description": ....}
+                task_desc = task_dict.get("description", not_found_str)
+        # task_desc: str, a description of the task
+
+        # 2) create the main task
+        logger.info(f"Task design:\n{task_desc}")
         task_name = hypotheses[0].component
-        description = (
-            task_design
-            if isinstance(task_design, str)
-            else task_design.get("description", f"{component_info['target_name']} description not provided")
-        )
         task_class = component_info["task_class"]
         task = task_class(
             name=task_name,
-            description=description,
+            description=task_desc,
         )
-        new_workflow_desc = task_dict.get("workflow_update", "No update needed")
         exp = DSExperiment(pending_tasks_list=[[task]], hypothesis=hypotheses[0])
-        # exp.experiment_workspace.inject_code_from_folder(sota_exp.experiment_workspace.workspace_path)
         if sota_exp is not None:
             exp.experiment_workspace.inject_code_from_file_dict(sota_exp.experiment_workspace)
-        if not pipeline and new_workflow_desc != "No update needed":
+
+        # 3) create the workflow update task
+        if workflow_check:
             workflow_task = WorkflowTask(
                 name="Workflow",
-                description=new_workflow_desc,
+                description=task_dict.get("workflow_update", "No update needed"),
             )
             exp.pending_tasks_list.append([workflow_task])
         return exp
```
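For reference, the three `task_dict` shapes this parsing handles, written out as hypothetical literals with field values elided:

```python
# 1) supports_response_schema: the CodingSketch schema yields a "sketch" string.
with_schema = {"sketch": "### Step 1: load data\n..."}

# 2) plain JSON mode with workflow_check=True: nested task_design plus workflow_update.
with_workflow_check = {
    "task_design": {"description": "..."},
    "workflow_update": "No update needed",
}

# 3) plain JSON mode with workflow_check=False: a flat description.
without_workflow_check = {"description": "..."}
```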
