Skip to content

Commit 77d0d4a

Browse files
authored
early-return: detect short deviated statements (#1396)
1 parent 456cbd0 commit 77d0d4a

File tree

5 files changed

+370
-0
lines changed

5 files changed

+370
-0
lines changed

‎internal/ifelse/branch.go‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ func (b Branch) IsShort() bool {
102102
return true
103103
case 1:
104104
return isShortStmt(b.block[0])
105+
case 2:
106+
return isShortStmt(b.block[1])
105107
}
106108
return false
107109
}

‎internal/ifelse/branch_test.go‎

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
package ifelse
2+
3+
import (
4+
"go/ast"
5+
"go/token"
6+
"testing"
7+
)
8+
9+
func TestBlockBranch(t *testing.T) {
10+
t.Run("empty", func(t *testing.T) {
11+
block := &ast.BlockStmt{List: []ast.Stmt{}}
12+
b := BlockBranch(block)
13+
if b.BranchKind != Empty {
14+
t.Errorf("want Empty branch, got %v", b.BranchKind)
15+
}
16+
})
17+
t.Run("non empty", func(t *testing.T) {
18+
stmt := &ast.ReturnStmt{}
19+
block := &ast.BlockStmt{List: []ast.Stmt{stmt}}
20+
b := BlockBranch(block)
21+
if b.BranchKind != Return {
22+
t.Errorf("want Return branch, got %v", b.BranchKind)
23+
}
24+
})
25+
}
26+
27+
func TestStmtBranch(t *testing.T) {
28+
cases := []struct {
29+
name string
30+
stmt ast.Stmt
31+
kind BranchKind
32+
call *Call
33+
}{
34+
{
35+
name: "ReturnStmt",
36+
stmt: &ast.ReturnStmt{},
37+
kind: Return,
38+
},
39+
{
40+
name: "BreakStmt",
41+
stmt: &ast.BranchStmt{Tok: token.BREAK},
42+
kind: Break,
43+
},
44+
{
45+
name: "ContinueStmt",
46+
stmt: &ast.BranchStmt{Tok: token.CONTINUE},
47+
kind: Continue,
48+
},
49+
{
50+
name: "GotoStmt",
51+
stmt: &ast.BranchStmt{Tok: token.GOTO},
52+
kind: Goto,
53+
},
54+
{
55+
name: "EmptyStmt",
56+
stmt: &ast.EmptyStmt{},
57+
kind: Empty,
58+
},
59+
{
60+
name: "ExprStmt with DeviatingFunc (panic)",
61+
stmt: &ast.ExprStmt{
62+
X: &ast.CallExpr{
63+
Fun: &ast.Ident{Name: "panic"},
64+
},
65+
},
66+
kind: Panic,
67+
call: &Call{Name: "panic"},
68+
},
69+
{
70+
name: "ExprStmt with DeviatingFunc (os.Exit)",
71+
stmt: &ast.ExprStmt{
72+
X: &ast.CallExpr{
73+
Fun: &ast.SelectorExpr{
74+
X: &ast.Ident{Name: "os"},
75+
Sel: &ast.Ident{Name: "Exit"},
76+
},
77+
},
78+
},
79+
kind: Exit,
80+
call: &Call{Pkg: "os", Name: "Exit"},
81+
},
82+
{
83+
name: "ExprStmt with non-deviating func",
84+
stmt: &ast.ExprStmt{
85+
X: &ast.CallExpr{
86+
Fun: &ast.Ident{Name: "foo"},
87+
},
88+
},
89+
kind: Regular,
90+
},
91+
{
92+
name: "LabeledStmt wrapping ReturnStmt",
93+
stmt: &ast.LabeledStmt{
94+
Label: &ast.Ident{Name: "lbl"},
95+
Stmt: &ast.ReturnStmt{},
96+
},
97+
kind: Return,
98+
},
99+
{
100+
name: "LabeledStmt wrapping ExprStmt",
101+
stmt: &ast.LabeledStmt{
102+
Label: &ast.Ident{Name: "lbl"},
103+
Stmt: &ast.ExprStmt{X: &ast.CallExpr{Fun: &ast.Ident{Name: "foo"}}},
104+
},
105+
kind: Regular,
106+
},
107+
{
108+
name: "BlockStmt with ReturnStmt",
109+
stmt: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{}}},
110+
kind: Return,
111+
},
112+
{
113+
name: "BlockStmt with ExprStmt",
114+
stmt: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: &ast.Ident{Name: "foo"}}}}},
115+
kind: Regular,
116+
},
117+
}
118+
for _, c := range cases {
119+
b := StmtBranch(c.stmt)
120+
if b.BranchKind != c.kind {
121+
t.Errorf("%s: want %v, got %v", c.name, c.kind, b.BranchKind)
122+
}
123+
if c.call != nil {
124+
if b.Call != *c.call {
125+
t.Errorf("%s: want Call %+v, got %+v", c.name, *c.call, b.Call)
126+
}
127+
}
128+
}
129+
}
130+
131+
func TestBranch_String_LongString(t *testing.T) {
132+
tests := []struct {
133+
name string
134+
branch Branch
135+
wantStr string
136+
wantLong string
137+
}{
138+
{
139+
name: "Return branch",
140+
branch: Branch{BranchKind: Return},
141+
wantStr: "{ ... return }",
142+
wantLong: "a return statement",
143+
},
144+
{
145+
name: "Panic branch with Call",
146+
branch: Branch{BranchKind: Panic, Call: Call{Name: "panic"}},
147+
wantStr: "{ ... panic() }",
148+
wantLong: "call to panic function",
149+
},
150+
{
151+
name: "Exit branch with Call",
152+
branch: Branch{BranchKind: Exit, Call: Call{Pkg: "os", Name: "Exit"}},
153+
wantStr: "{ ... os.Exit() }",
154+
wantLong: "call to os.Exit function",
155+
},
156+
{
157+
name: "Empty branch",
158+
branch: Branch{BranchKind: Empty},
159+
wantStr: "{ }",
160+
wantLong: "an empty block",
161+
},
162+
{
163+
name: "Regular branch",
164+
branch: Branch{BranchKind: Regular},
165+
wantStr: "{ ... }",
166+
wantLong: "a regular statement",
167+
},
168+
}
169+
for _, tt := range tests {
170+
if got := tt.branch.String(); got != tt.wantStr {
171+
t.Errorf("%s: String() = %q, want %q", tt.name, got, tt.wantStr)
172+
}
173+
if got := tt.branch.LongString(); got != tt.wantLong {
174+
t.Errorf("%s: LongString() = %q, want %q", tt.name, got, tt.wantLong)
175+
}
176+
}
177+
}
178+
179+
func TestBranch_HasDecls(t *testing.T) {
180+
tests := []struct {
181+
name string
182+
block []ast.Stmt
183+
want bool
184+
}{
185+
{
186+
name: "DeclStmt",
187+
block: []ast.Stmt{&ast.DeclStmt{}},
188+
want: true,
189+
},
190+
{
191+
name: "AssignStmt with :=",
192+
block: []ast.Stmt{&ast.AssignStmt{Tok: token.DEFINE}},
193+
want: true,
194+
},
195+
{
196+
name: "ExprStmt",
197+
block: []ast.Stmt{&ast.ExprStmt{}},
198+
want: false,
199+
},
200+
}
201+
for _, tt := range tests {
202+
b := Branch{block: tt.block}
203+
if got := b.HasDecls(); got != tt.want {
204+
t.Errorf("%s: want HasDecls to be %v, got %v", tt.name, tt.want, got)
205+
}
206+
}
207+
}
208+
209+
func TestBranch_IsShort(t *testing.T) {
210+
tests := []struct {
211+
name string
212+
block []ast.Stmt
213+
want bool
214+
}{
215+
{
216+
name: "nil block",
217+
block: nil,
218+
want: true,
219+
},
220+
{
221+
name: "single ExprStmt",
222+
block: []ast.Stmt{&ast.ExprStmt{}},
223+
want: true,
224+
},
225+
{
226+
name: "single BlockStmt",
227+
block: []ast.Stmt{&ast.BlockStmt{}},
228+
want: false,
229+
},
230+
{
231+
name: "two short statements",
232+
block: []ast.Stmt{&ast.ExprStmt{}, &ast.ExprStmt{}},
233+
want: true,
234+
},
235+
{
236+
name: "second non-short statement",
237+
block: []ast.Stmt{&ast.ExprStmt{}, &ast.BlockStmt{}},
238+
want: false,
239+
},
240+
{
241+
name: "three statements (should return false)",
242+
block: []ast.Stmt{&ast.ExprStmt{}, &ast.ExprStmt{}, &ast.ExprStmt{}},
243+
want: false,
244+
},
245+
}
246+
for _, tt := range tests {
247+
b := Branch{block: tt.block}
248+
if got := b.IsShort(); got != tt.want {
249+
t.Errorf("%s: want IsShort to be %v, got %v", tt.name, tt.want, got)
250+
}
251+
}
252+
}

‎test/early_return_test.go‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ func TestEarlyReturn(t *testing.T) {
1313
testRule(t, "early_return_scope", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"preserve-scope"}})
1414
testRule(t, "early_return_jump", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"allowJump"}})
1515
testRule(t, "early_return_jump", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"allow-jump"}})
16+
testRule(t, "early_return_jump_scope", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"allow-jump", "preserve-scope"}})
1617
}

‎testdata/early_return_jump.go‎

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
package fixtures
44

5+
import (
6+
"fmt"
7+
"log"
8+
"log/slog"
9+
"net/http"
10+
"os"
11+
)
12+
513
func fn1() {
614
if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/
715
println()
@@ -113,3 +121,79 @@ func fn10() {
113121
}
114122
}
115123
}
124+
125+
func fn11() {
126+
if a() {
127+
println()
128+
os.Exit(1)
129+
}
130+
}
131+
132+
func fn12() {
133+
if a() {
134+
println()
135+
return
136+
}
137+
}
138+
139+
func fn13() {
140+
if err := a(); err != nil {
141+
println()
142+
panic(err)
143+
}
144+
}
145+
146+
func fn14() {
147+
if err := a(); err != nil {
148+
println()
149+
log.Fatal(err)
150+
}
151+
}
152+
153+
func fn15() {
154+
if err := a(); err != nil {
155+
println()
156+
log.Panic(err)
157+
}
158+
}
159+
160+
func fn16() {
161+
if err := a(); err != nil { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting (move short variable declaration to its own line if necessary)/
162+
println()
163+
println()
164+
log.Panic(err)
165+
}
166+
}
167+
168+
func fn17() {
169+
if err := a(); err != nil { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting (move short variable declaration to its own line if necessary)/
170+
println()
171+
println()
172+
println()
173+
panic(err)
174+
}
175+
}
176+
177+
func MustEncode[T any](w http.ResponseWriter, status int, v T) {
178+
if err := Encode(w, status, v); err != nil {
179+
slog.Error("Error encoding response", "err", err)
180+
return
181+
}
182+
}
183+
184+
func (c *client) renewAuthInfo() {
185+
err := RenewLease(c.ctx, c, "auth", c.authCreds, func() (*hashiVault.Secret, error) {
186+
authInfo, err := c.auth(c.v)
187+
if err != nil {
188+
return nil, fmt.Errorf("unable to renew auth info: %w", err)
189+
}
190+
191+
c.authCreds = authInfo
192+
193+
return authInfo, nil
194+
})
195+
if err != nil {
196+
slog.Error("unable to renew auth info", slog.String(loggingKeyError, err.Error()))
197+
os.Exit(1)
198+
}
199+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Test data for the early-return rule with allowJump option enabled
2+
3+
package fixtures
4+
5+
import (
6+
"os"
7+
)
8+
9+
func fn1() {
10+
if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/
11+
println()
12+
println()
13+
println()
14+
}
15+
}
16+
17+
func fn3() {
18+
if a() {
19+
println()
20+
os.Exit(1)
21+
}
22+
}
23+
24+
func fn4() {
25+
// No initializer, match as normal
26+
if cond { // MATCH /if c { ... } else { ... return } can be simplified to if !c { ... return } .../
27+
fn2()
28+
} else {
29+
return
30+
}
31+
}

0 commit comments

Comments
 (0)