Skip to content

Commit a46f4e0

Browse files
committed
Persist ACP sessions to default sqlite db unless specified with --session-db flag
Generally makes agent and session setup follow more the patterns used for the api and run commands Also forces tilde expansion on the run and api commands when passing the session-db as --session-db=~/somepath or --session-db "~/somepath" Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 3406abc commit a46f4e0

7 files changed

Lines changed: 370 additions & 23 deletions

File tree

‎cmd/root/acp.go‎

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
package root
22

33
import (
4+
"path/filepath"
5+
46
"github.com/spf13/cobra"
57

68
"github.com/docker/cagent/pkg/acp"
79
"github.com/docker/cagent/pkg/config"
10+
"github.com/docker/cagent/pkg/paths"
811
"github.com/docker/cagent/pkg/telemetry"
912
)
1013

1114
type acpFlags struct {
1215
runConfig config.RuntimeConfig
16+
sessionDB string
1317
}
1418

1519
func newACPCmd() *cobra.Command {
@@ -28,6 +32,7 @@ func newACPCmd() *cobra.Command {
2832
}
2933

3034
addRuntimeConfigFlags(cmd, &flags.runConfig)
35+
cmd.Flags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database")
3136

3237
return cmd
3338
}
@@ -38,5 +43,11 @@ func (f *acpFlags) runACPCommand(cmd *cobra.Command, args []string) error {
3843
ctx := cmd.Context()
3944
agentFilename := args[0]
4045

41-
return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig)
46+
// Expand tilde in session database path
47+
sessionDB, err := expandTilde(f.sessionDB)
48+
if err != nil {
49+
return err
50+
}
51+
52+
return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig, sessionDB)
4253
}

‎cmd/root/api.go‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,13 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {
136136

137137
slog.Debug("Starting server", "agents", agentsPath, "addr", ln.Addr().String())
138138

139-
sessionStore, err := session.NewSQLiteSessionStore(f.sessionDB)
139+
// Expand tilde in session database path
140+
sessionDB, err := expandTilde(f.sessionDB)
141+
if err != nil {
142+
return err
143+
}
144+
145+
sessionStore, err := session.NewSQLiteSessionStore(sessionDB)
140146
if err != nil {
141147
return fmt.Errorf("creating session store: %w", err)
142148
}

‎cmd/root/run.go‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,13 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes
334334
return nil, nil, err
335335
}
336336

337-
sessStore, err := session.NewSQLiteSessionStore(f.sessionDB)
337+
// Expand tilde in session database path
338+
sessionDB, err := expandTilde(f.sessionDB)
339+
if err != nil {
340+
return nil, nil, err
341+
}
342+
343+
sessStore, err := session.NewSQLiteSessionStore(sessionDB)
338344
if err != nil {
339345
return nil, nil, fmt.Errorf("creating session store: %w", err)
340346
}

‎cmd/root/tilde_test.go‎

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package root
2+
3+
import (
4+
"path/filepath"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/docker/cagent/pkg/paths"
11+
)
12+
13+
func TestExpandTilde(t *testing.T) {
14+
t.Parallel()
15+
16+
homeDir := paths.GetHomeDir()
17+
require.NotEmpty(t, homeDir, "Home directory should be available for tests")
18+
19+
tests := []struct {
20+
name string
21+
input string
22+
expected string
23+
wantErr bool
24+
}{
25+
{
26+
name: "expands_tilde_prefix",
27+
input: "~/session.db",
28+
expected: filepath.Join(homeDir, "session.db"),
29+
},
30+
{
31+
name: "expands_tilde_with_nested_path",
32+
input: "~/.cagent/session.db",
33+
expected: filepath.Join(homeDir, ".cagent", "session.db"),
34+
},
35+
{
36+
name: "expands_tilde_with_deep_path",
37+
input: "~/path/to/some/file.db",
38+
expected: filepath.Join(homeDir, "path", "to", "some", "file.db"),
39+
},
40+
{
41+
name: "absolute_path_unchanged",
42+
input: "/absolute/path/session.db",
43+
expected: "/absolute/path/session.db",
44+
},
45+
{
46+
name: "relative_path_unchanged",
47+
input: "relative/path/session.db",
48+
expected: "relative/path/session.db",
49+
},
50+
{
51+
name: "tilde_in_middle_unchanged",
52+
input: "/some/~/path/session.db",
53+
expected: "/some/~/path/session.db",
54+
},
55+
{
56+
name: "tilde_without_slash_unchanged",
57+
input: "~something",
58+
expected: "~something",
59+
},
60+
{
61+
name: "just_tilde_slash_expands",
62+
input: "~/",
63+
expected: homeDir,
64+
},
65+
{
66+
name: "empty_string_unchanged",
67+
input: "",
68+
expected: "",
69+
},
70+
{
71+
name: "dot_path_unchanged",
72+
input: "./session.db",
73+
expected: "./session.db",
74+
},
75+
}
76+
77+
for _, tt := range tests {
78+
t.Run(tt.name, func(t *testing.T) {
79+
t.Parallel()
80+
81+
result, err := expandTilde(tt.input)
82+
83+
if tt.wantErr {
84+
require.Error(t, err)
85+
return
86+
}
87+
88+
require.NoError(t, err)
89+
assert.Equal(t, tt.expected, result)
90+
})
91+
}
92+
}

‎pkg/acp/agent.go‎

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ import (
66
"encoding/json"
77
"fmt"
88
"log/slog"
9+
"os"
910
"path/filepath"
1011
"slices"
1112
"strings"
1213
"sync"
1314

1415
"github.com/coder/acp-go-sdk"
15-
"github.com/google/uuid"
1616

1717
"github.com/docker/cagent/pkg/config"
1818
"github.com/docker/cagent/pkg/runtime"
@@ -26,9 +26,10 @@ import (
2626

2727
// Agent implements the ACP Agent interface for cagent
2828
type Agent struct {
29-
agentSource config.Source
30-
runConfig *config.RuntimeConfig
31-
sessions map[string]*Session
29+
agentSource config.Source
30+
runConfig *config.RuntimeConfig
31+
sessionStore session.Store
32+
sessions map[string]*Session
3233

3334
conn *acp.AgentSideConnection
3435
team *team.Team
@@ -47,11 +48,12 @@ type Session struct {
4748
}
4849

4950
// NewAgent creates a new ACP agent
50-
func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig) *Agent {
51+
func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig, sessionStore session.Store) *Agent {
5152
return &Agent{
52-
agentSource: agentSource,
53-
runConfig: runConfig,
54-
sessions: make(map[string]*Session),
53+
agentSource: agentSource,
54+
runConfig: runConfig,
55+
sessionStore: sessionStore,
56+
sessions: make(map[string]*Session),
5557
}
5658
}
5759

@@ -108,30 +110,75 @@ func (a *Agent) Initialize(ctx context.Context, params acp.InitializeRequest) (a
108110
}
109111

110112
// NewSession implements [acp.Agent]
111-
func (a *Agent) NewSession(_ context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) {
112-
sid := uuid.New().String()
113-
slog.Debug("ACP NewSession called", "session_id", sid, "cwd", params.Cwd)
113+
func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) {
114+
slog.Debug("ACP NewSession called", "cwd", params.Cwd)
114115

115116
// Log warning if MCP servers are provided (not yet supported)
116117
if len(params.McpServers) > 0 {
117118
slog.Warn("MCP servers provided by client are not yet supported", "count", len(params.McpServers))
118119
}
119120

120-
rt, err := runtime.New(a.team, runtime.WithCurrentAgent("root"))
121+
// Validate and normalize working directory
122+
var workingDir string
123+
if wd := strings.TrimSpace(params.Cwd); wd != "" {
124+
absWd, err := filepath.Abs(wd)
125+
if err != nil {
126+
return acp.NewSessionResponse{}, fmt.Errorf("invalid working directory: %w", err)
127+
}
128+
info, err := os.Stat(absWd)
129+
if err != nil {
130+
return acp.NewSessionResponse{}, fmt.Errorf("working directory does not exist: %w", err)
131+
}
132+
if !info.IsDir() {
133+
return acp.NewSessionResponse{}, fmt.Errorf("working directory must be a directory")
134+
}
135+
workingDir = absWd
136+
}
137+
138+
rt, err := runtime.New(a.team,
139+
runtime.WithCurrentAgent("root"),
140+
runtime.WithSessionStore(a.sessionStore),
141+
)
121142
if err != nil {
122143
return acp.NewSessionResponse{}, fmt.Errorf("failed to create runtime: %w", err)
123144
}
124145

146+
// Get root agent config for session settings
147+
rootAgent, err := a.team.Agent("root")
148+
if err != nil {
149+
return acp.NewSessionResponse{}, fmt.Errorf("failed to get root agent: %w", err)
150+
}
151+
152+
// Build session options (title will be set after we have the session ID)
153+
sessOpts := []session.Opt{
154+
session.WithMaxIterations(rootAgent.MaxIterations()),
155+
session.WithThinking(rootAgent.ThinkingConfigured()),
156+
}
157+
if workingDir != "" {
158+
sessOpts = append(sessOpts, session.WithWorkingDir(workingDir))
159+
}
160+
161+
// Create session - use its auto-generated ID
162+
sess := session.New(sessOpts...)
163+
sess.Title = "ACP Session " + sess.ID
164+
165+
// Persist session to the store
166+
if err := a.sessionStore.AddSession(ctx, sess); err != nil {
167+
return acp.NewSessionResponse{}, fmt.Errorf("failed to persist session: %w", err)
168+
}
169+
170+
slog.Debug("ACP session created", "session_id", sess.ID)
171+
125172
a.mu.Lock()
126-
a.sessions[sid] = &Session{
127-
id: sid,
128-
sess: session.New(session.WithTitle("ACP Session " + sid)),
173+
a.sessions[sess.ID] = &Session{
174+
id: sess.ID,
175+
sess: sess,
129176
rt: rt,
130-
workingDir: params.Cwd,
177+
workingDir: workingDir,
131178
}
132179
a.mu.Unlock()
133180

134-
return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil
181+
return acp.NewSessionResponse{SessionId: acp.SessionId(sess.ID)}, nil
135182
}
136183

137184
// Authenticate implements [acp.Agent]

0 commit comments

Comments
 (0)