-
Notifications
You must be signed in to change notification settings - Fork 348
Expand file tree
/
Copy pathapi.go
More file actions
125 lines (103 loc) · 3.62 KB
/
api.go
File metadata and controls
125 lines (103 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package root
import (
"errors"
"fmt"
"log/slog"
"os"
"time"
"github.com/spf13/cobra"
"github.com/docker/docker-agent/pkg/cli"
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/server"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/telemetry"
)
// apiFlags holds the values of the command-line flags accepted by the "api"
// command. The fields are bound to their flags in newAPICmd and read back in
// runAPICommand.
type apiFlags struct {
	// listenAddr is the --listen address the server listens on; sessionDB is
	// the --session-db path to the session database.
	listenAddr, sessionDB string

	// pullIntervalMins is the --pull-interval value, in minutes, between
	// automatic pulls of an OCI reference. Zero disables auto-pulling.
	pullIntervalMins int

	// fakeResponses (--fake) names a cassette file AI responses are replayed
	// from; recordPath (--record) names a cassette file AI API interactions
	// are recorded to.
	fakeResponses, recordPath string

	// runConfig is the runtime configuration handed to addRuntimeConfigFlags
	// and later passed on to the proxies and the server.
	runConfig config.RuntimeConfig
}
// newAPICmd builds the "api" cobra command, which takes exactly one argument
// (an agent file or an agents directory) and starts the API server.
//
// The command gets its own apiFlags value: every flag registered here writes
// into that value, and the command's RunE is that value's runAPICommand
// method.
func newAPICmd() *cobra.Command {
	opts := &apiFlags{}

	apiCmd := &cobra.Command{
		Use:   "api <agent-file>|<agents-dir>",
		Short: "Start the API server",
		Args:  cobra.ExactArgs(1),
		RunE:  opts.runAPICommand,
	}

	// All flags of this command are registered as persistent flags.
	pf := apiCmd.PersistentFlags()
	pf.StringVarP(&opts.listenAddr, "listen", "l", "127.0.0.1:8080", "Address to listen on")
	pf.StringVarP(&opts.sessionDB, "session-db", "s", "session.db", "Path to the session database")
	pf.IntVar(&opts.pullIntervalMins, "pull-interval", 0, "Auto-pull OCI reference every N minutes (0 = disabled)")
	pf.StringVar(&opts.fakeResponses, "fake", "", "Replay AI responses from cassette file (for testing)")
	pf.StringVar(&opts.recordPath, "record", "", "Record AI API interactions to cassette file")

	// Replaying and recording cannot be requested in the same invocation.
	apiCmd.MarkFlagsMutuallyExclusive("fake", "record")

	addRuntimeConfigFlags(apiCmd, &opts.runConfig)

	return apiCmd
}
// runAPICommand is the RunE handler of the "api" command. args[0] is the
// agents path (agent file, agents directory, or a reference accepted by
// config.ResolveSources).
//
// In order, it reports the command to telemetry, detaches standard input,
// sets up the fake and recording proxies, validates --pull-interval, opens the
// listener, opens the SQLite session store, resolves the agent sources, and
// finally creates the server and serves on the listener.
//
// The first error encountered is returned. Because the result is named, the
// deferred telemetry call reports that same error. Resources acquired along
// the way are released by deferred cleanups; a failing cleanup is logged and
// does not change the returned error.
func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) (commandErr error) {
	ctx := cmd.Context()

	telemetry.TrackCommand(ctx, "serve", append([]string{"api"}, args...))
	// Keep this as a closure: commandErr must be read when the function
	// returns, not when the defer statement is evaluated.
	defer func() {
		telemetry.TrackCommandError(ctx, "serve", append([]string{"api"}, args...), commandErr)
	}()

	printer := cli.NewPrinter(cmd.OutOrStdout())
	agentsPath := args[0]

	// API mode must never prompt the user, so standard input is removed.
	os.Stdin = nil

	// Fake (replay) proxy, for when --fake is specified.
	stopFakeProxy, err := setupFakeProxy(f.fakeResponses, 0, &f.runConfig)
	if err != nil {
		return err
	}
	defer func() {
		cleanupErr := stopFakeProxy()
		if cleanupErr == nil {
			return
		}
		slog.Error("Failed to cleanup fake proxy", "error", cleanupErr)
	}()

	// Recording proxy, for when --record is specified.
	_, stopRecordingProxy, err := setupRecordingProxy(f.recordPath, &f.runConfig)
	if err != nil {
		return err
	}
	defer func() {
		cleanupErr := stopRecordingProxy()
		if cleanupErr == nil {
			return
		}
		slog.Error("Failed to cleanup recording proxy", "error", cleanupErr)
	}()

	// Periodic pulling is rejected unless the path is an OCI or URL reference.
	if f.pullIntervalMins > 0 {
		isRemote := config.IsOCIReference(agentsPath) || config.IsURLReference(agentsPath)
		if !isRemote {
			return errors.New("--pull-interval flag can only be used with OCI or URL references, not local files")
		}
	}

	listener, closeListener, err := newListener(ctx, f.listenAddr)
	if err != nil {
		return err
	}
	defer closeListener()

	boundAddr := listener.Addr().String()
	printer.Println("Listening on", boundAddr)
	slog.Debug("Starting server", "agents", agentsPath, "addr", listener.Addr().String())

	// The session database path may start with a tilde; expand it first.
	dbPath, err := expandTilde(f.sessionDB)
	if err != nil {
		return err
	}

	store, err := session.NewSQLiteSessionStore(dbPath)
	if err != nil {
		return fmt.Errorf("creating session store: %w", err)
	}
	defer func() {
		closeErr := store.Close()
		if closeErr == nil {
			return
		}
		slog.Error("Failed to close session store", "error", closeErr)
	}()

	sources, err := config.ResolveSources(agentsPath, f.runConfig.EnvProvider())
	if err != nil {
		return fmt.Errorf("resolving agent sources: %w", err)
	}

	// The flag is expressed in minutes; the server takes a time.Duration.
	pullInterval := time.Duration(f.pullIntervalMins) * time.Minute
	srv, err := server.New(ctx, store, &f.runConfig, pullInterval, sources)
	if err != nil {
		return fmt.Errorf("creating server: %w", err)
	}

	return srv.Serve(ctx, listener)
}