Skip to content

Commit 273b3a3

Browse files
committed
Improve errors
Fail with suggestions for sub commands that has sub commands Cobra doesn't do this out of the box, fixing it upstream would be a breaking change, so do it here.
1 parent 63f58e9 commit 273b3a3

File tree

2 files changed

+158
-29
lines changed

2 files changed

+158
-29
lines changed

‎cobrakai.go‎

Lines changed: 114 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package cobrakai
22

33
import (
44
"context"
5+
"errors"
6+
"fmt"
7+
"strings"
58

69
"github.com/spf13/cobra"
710
)
@@ -26,14 +29,8 @@ type Commander interface {
2629
Commands() []Commander
2730
}
2831

29-
// Executer is the execution entry point.
30-
// The args are usually filled with os.Args[1:].
31-
type Executer interface {
32-
Execute(ctx context.Context, args []string) (*Commandeer, error)
33-
}
34-
3532
// New creates a new Executer from the command tree in Commander.
36-
func New(rootCmd Commander) (Executer, error) {
33+
func New(rootCmd Commander) (*Exec, error) {
3734
rootCd := &Commandeer{
3835
Command: rootCmd,
3936
}
@@ -62,7 +59,7 @@ func New(rootCmd Commander) (Executer, error) {
6259
return nil, err
6360
}
6461

65-
return &root{c: rootCd}, nil
62+
return &Exec{c: rootCd}, nil
6663

6764
}
6865

@@ -95,15 +92,29 @@ func (c *Commandeer) init() error {
9592
return initc(cd)
9693
}
9794

95+
type runErr struct {
96+
err error
97+
}
98+
99+
func (r *runErr) Error() string {
100+
return fmt.Sprintf("run error: %v", r.err)
101+
}
102+
98103
func (c *Commandeer) compile() error {
99104
c.CobraCommand = &cobra.Command{
100105
Use: c.Command.Name(),
101106
RunE: func(cmd *cobra.Command, args []string) error {
102-
return c.Command.Run(cmd.Context(), args)
107+
if err := c.Command.Run(cmd.Context(), args); err != nil {
108+
return &runErr{err: err}
109+
}
110+
return nil
103111
},
104112
PreRunE: func(cmd *cobra.Command, args []string) error {
105113
return c.init()
106114
},
115+
SilenceErrors: true,
116+
SilenceUsage: true,
117+
SuggestionsMinimumDistance: 2,
107118
}
108119

109120
// This is where the flags, short and long description etc. are added
@@ -120,28 +131,108 @@ func (c *Commandeer) compile() error {
120131
return nil
121132
}
122133

123-
type root struct {
134+
// Exec provides methods to execute the command tree.
135+
type Exec struct {
124136
c *Commandeer
125137
}
126138

127-
func (r *root) Execute(ctx context.Context, args []string) (*Commandeer, error) {
139+
// Execute executes the command tree starting from the root command.
140+
// The args are usually filled with os.Args[1:].
141+
func (r *Exec) Execute(ctx context.Context, args []string) (*Commandeer, error) {
128142
r.c.CobraCommand.SetArgs(args)
129143
cobraCommand, err := r.c.CobraCommand.ExecuteContextC(ctx)
130-
if err != nil {
131-
return nil, err
132-
}
133-
// Find the commandeer that was executed.
134-
var find func(*cobra.Command, *Commandeer) *Commandeer
135-
find = func(what *cobra.Command, in *Commandeer) *Commandeer {
136-
if in.CobraCommand == what {
137-
return in
144+
var cd *Commandeer
145+
if cobraCommand != nil {
146+
if err == nil {
147+
err = checkArgs(cobraCommand, args)
138148
}
139-
for _, in2 := range in.commandeers {
140-
if found := find(what, in2); found != nil {
141-
return found
149+
150+
// Find the commandeer that was executed.
151+
var find func(*cobra.Command, *Commandeer) *Commandeer
152+
find = func(what *cobra.Command, in *Commandeer) *Commandeer {
153+
if in.CobraCommand == what {
154+
return in
155+
}
156+
for _, in2 := range in.commandeers {
157+
if found := find(what, in2); found != nil {
158+
return found
159+
}
142160
}
161+
return nil
162+
}
163+
cd = find(cobraCommand, r.c)
164+
}
165+
166+
return cd, wrapErr(err)
167+
}
168+
169+
// CommandError is returned when a command fails because of a user error (unknown command, invalid flag etc.).
170+
type CommandError struct {
171+
Err error
172+
}
173+
174+
func (e *CommandError) Error() string {
175+
return fmt.Sprintf("command error: %v", e.Err)
176+
}
177+
178+
// IsCommandError reports whether any error in err's tree matches CommandError.
179+
func IsCommandError(err error) bool {
180+
switch err.(type) {
181+
case *CommandError:
182+
return true
183+
}
184+
return errors.Is(err, &CommandError{})
185+
}
186+
187+
func wrapErr(err error) error {
188+
if err == nil {
189+
return nil
190+
}
191+
192+
if rerr, ok := err.(*runErr); ok {
193+
err = rerr.err
194+
}
195+
196+
// All other errors are coming from Cobra.
197+
return &CommandError{Err: err}
198+
}
199+
200+
// Cobra only does suggestions for the root command.
201+
// See https://github.com/spf13/cobra/pull/1500
202+
func checkArgs(cmd *cobra.Command, args []string) error {
203+
// no subcommand, always take args.
204+
if !cmd.HasSubCommands() {
205+
return nil
206+
}
207+
208+
var commandName string
209+
for _, arg := range args {
210+
if strings.HasPrefix(arg, "-") {
211+
break
143212
}
213+
commandName = arg
214+
}
215+
216+
if commandName == "" || cmd.Name() == commandName {
144217
return nil
145218
}
146-
return find(cobraCommand, r.c), nil
219+
220+
return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), findSuggestions(cmd, commandName))
221+
}
222+
223+
func findSuggestions(cmd *cobra.Command, arg string) string {
224+
if cmd.DisableSuggestions {
225+
return ""
226+
}
227+
if cmd.SuggestionsMinimumDistance <= 0 {
228+
cmd.SuggestionsMinimumDistance = 2
229+
}
230+
suggestionsString := ""
231+
if suggestions := cmd.SuggestionsFor(arg); len(suggestions) > 0 {
232+
suggestionsString += "\n\nDid you mean this?\n"
233+
for _, s := range suggestions {
234+
suggestionsString += fmt.Sprintf("\t%v\n", s)
235+
}
236+
}
237+
return suggestionsString
147238
}

‎cobrakai_test.go‎

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@ import (
1111
"github.com/spf13/cobra"
1212
)
1313

14-
func TestCobraKai(t *testing.T) {
15-
c := qt.New(t)
16-
17-
rootCmd := &rootCommand{name: "root",
14+
func testCommands() *rootCommand {
15+
return &rootCommand{name: "root",
1816
commands: []cobrakai.Commander{
1917
&lvl1Command{name: "foo"},
2018
&lvl1Command{name: "bar",
@@ -25,6 +23,13 @@ func TestCobraKai(t *testing.T) {
2523
},
2624
}
2725

26+
}
27+
28+
func TestCobraKai(t *testing.T) {
29+
c := qt.New(t)
30+
31+
rootCmd := testCommands()
32+
2833
x, err := cobrakai.New(rootCmd)
2934
c.Assert(err, qt.IsNil)
3035
// This can be anything, just used to make sure the same context is passed all the way.
@@ -71,6 +76,40 @@ func TestCobraKai(t *testing.T) {
7176

7277
}
7378

79+
func TestErrors(t *testing.T) {
80+
c := qt.New(t)
81+
82+
c.Run("unknown similar command", func(c *qt.C) {
83+
x, err := cobrakai.New(testCommands())
84+
c.Assert(err, qt.IsNil)
85+
_, err = x.Execute(context.Background(), []string{"fooo"})
86+
c.Assert(err, qt.Not(qt.IsNil))
87+
c.Assert(err.Error(), qt.Contains, "unknown")
88+
c.Assert(err.Error(), qt.Contains, "Did you mean this?")
89+
c.Assert(cobrakai.IsCommandError(err), qt.Equals, true)
90+
})
91+
92+
c.Run("unknown similar sub command", func(c *qt.C) {
93+
x, err := cobrakai.New(testCommands())
94+
c.Assert(err, qt.IsNil)
95+
_, err = x.Execute(context.Background(), []string{"bar", "bazz"})
96+
c.Assert(err, qt.Not(qt.IsNil))
97+
c.Assert(err.Error(), qt.Contains, "unknown")
98+
c.Assert(err.Error(), qt.Contains, "Did you mean this?")
99+
c.Assert(cobrakai.IsCommandError(err), qt.Equals, true)
100+
})
101+
102+
c.Run("unknown flag", func(c *qt.C) {
103+
x, err := cobrakai.New(testCommands())
104+
c.Assert(err, qt.IsNil)
105+
_, err = x.Execute(context.Background(), []string{"bar", "--unknown"})
106+
c.Assert(err, qt.Not(qt.IsNil))
107+
c.Assert(err.Error(), qt.Contains, "unknown")
108+
c.Assert(cobrakai.IsCommandError(err), qt.Equals, true)
109+
})
110+
111+
}
112+
74113
func Example() {
75114
rootCmd := &rootCommand{name: "root",
76115
commands: []cobrakai.Commander{
@@ -83,12 +122,11 @@ func Example() {
83122
},
84123
}
85124

86-
args := []string{"bar", "baz", "--localFlagName", "baz_local", "--persistentFlagName", "baz_persistent"}
87125
x, err := cobrakai.New(rootCmd)
88126
if err != nil {
89127
log.Fatal(err)
90128
}
91-
cd, err := x.Execute(context.Background(), args)
129+
cd, err := x.Execute(context.Background(), []string{"bar", "baz", "--localFlagName", "baz_local", "--persistentFlagName", "baz_persistent"})
92130
if err != nil {
93131
log.Fatal(err)
94132
}

0 commit comments

Comments
 (0)