Skip to content

Commit dcb3395

Browse files
committed
multi: Change all ocurrences to use NewMessage
This changes all ocurrences of message instantiation to use NewMessage(). This unifies all code for message init under a single code path. In the future, it may be possible to make all message fields unexported in order to better enforce message invariants.
1 parent 25dbad9 commit dcb3395

File tree

7 files changed

+145
-175
lines changed

7 files changed

+145
-175
lines changed

‎codec.go‎

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ func (d *Decoder) Decode() (*Message, error) {
5959
if err != nil {
6060
return nil, exc.WrapError("decode", err)
6161
}
62+
63+
// Special case an empty message to return a new MultiSegment message
64+
// ready for writing. This maintains compatibility to tests and older
65+
// implementation of message and arenas.
66+
if hdr.maxSegment() == 0 && total == 0 {
67+
msg, _ := NewMultiSegmentMessage(nil)
68+
return msg, nil
69+
}
70+
6271
// TODO(someday): if total size is greater than can fit in one buffer,
6372
// attempt to allocate buffer per segment.
6473
if total > maxSize-uint64(len(hdr)) || total > uint64(maxInt) {
@@ -77,7 +86,8 @@ func (d *Decoder) Decode() (*Message, error) {
7786
return nil, exc.WrapError("decode", err)
7887
}
7988

80-
return &Message{Arena: arena}, nil
89+
msg, _, err := NewMessage(arena)
90+
return msg, err
8191
}
8292

8393
func (d *Decoder) readHeader(maxSize uint64) (streamHeader, error) {
@@ -167,7 +177,8 @@ func Unmarshal(data []byte) (*Message, error) {
167177
return nil, exc.WrapError("unmarshal", err)
168178
}
169179

170-
return &Message{Arena: arena}, nil
180+
msg, _, err := NewMessage(arena)
181+
return msg, err
171182
}
172183

173184
// UnmarshalPacked reads a packed serialized stream into a message.

‎codec_test.go‎

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ func TestEncoder(t *testing.T) {
1313
t.Parallel()
1414

1515
for i, test := range serializeTests {
16-
if test.decodeFails {
16+
if test.decodeFails || test.newMessageFails {
1717
continue
1818
}
19-
msg := &Message{Arena: test.arena()}
19+
msg, _, err := NewMessage(test.arena())
20+
require.NoError(t, err)
2021
var buf bytes.Buffer
2122
enc := NewEncoder(&buf)
22-
err := enc.Encode(msg)
23+
err = enc.Encode(msg)
2324
out := buf.Bytes()
2425
if err != nil {
2526
if !test.encodeFails {
@@ -198,26 +199,26 @@ func TestDecoder_MaxMessageSize(t *testing.T) {
198199
func TestStreamHeaderPadding(t *testing.T) {
199200
t.Parallel()
200201

201-
msg := &Message{
202-
Arena: MultiSegment([][]byte{
202+
msg, _, err := NewMessage(
203+
MultiSegment([][]byte{
203204
incrementingData(8),
204205
incrementingData(8),
205206
incrementingData(8),
206-
}),
207-
}
207+
}))
208+
require.NoError(t, err)
208209
var buf bytes.Buffer
209210
enc := NewEncoder(&buf)
210-
err := enc.Encode(msg)
211+
err = enc.Encode(msg)
211212
buf.Reset()
212213
if err != nil {
213214
t.Fatalf("Encode error: %v", err)
214215
}
215-
msg = &Message{
216-
Arena: MultiSegment([][]byte{
216+
msg, _, err = NewMessage(
217+
MultiSegment([][]byte{
217218
incrementingData(8),
218219
incrementingData(8),
219-
}),
220-
}
220+
}))
221+
require.NoError(t, err)
221222
err = enc.Encode(msg)
222223
out := buf.Bytes()
223224
if err != nil {

‎integration_test.go‎

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,13 +1832,11 @@ func BenchmarkDecode(b *testing.B) {
18321832
func TestPointerTraverseDefense(t *testing.T) {
18331833
t.Parallel()
18341834
const limit = 128
1835-
msg := &capnp.Message{
1836-
Arena: capnp.SingleSegment([]byte{
1837-
0, 0, 0, 0, 1, 0, 0, 0, // root 1-word struct pointer to next word
1838-
0, 0, 0, 0, 0, 0, 0, 0, // struct's data
1839-
}),
1840-
TraverseLimit: limit * 8,
1841-
}
1835+
msg, _ := capnp.NewSingleSegmentMessage([]byte{
1836+
0, 0, 0, 0, 1, 0, 0, 0, // root 1-word struct pointer to next word
1837+
0, 0, 0, 0, 0, 0, 0, 0, // struct's data
1838+
})
1839+
msg.TraverseLimit = limit * 8
18421840

18431841
for i := 0; i < limit; i++ {
18441842
_, err := msg.Root()
@@ -1855,13 +1853,11 @@ func TestPointerTraverseDefense(t *testing.T) {
18551853
func TestPointerDepthDefense(t *testing.T) {
18561854
t.Parallel()
18571855
const limit = 64
1858-
msg := &capnp.Message{
1859-
Arena: capnp.SingleSegment([]byte{
1860-
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
1861-
0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself
1862-
}),
1863-
DepthLimit: limit,
1864-
}
1856+
msg, _ := capnp.NewSingleSegmentMessage([]byte{
1857+
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
1858+
0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself
1859+
})
1860+
msg.DepthLimit = limit
18651861
root, err := msg.Root()
18661862
if err != nil {
18671863
t.Fatal("Root:", err)
@@ -1894,14 +1890,12 @@ func TestPointerDepthDefense(t *testing.T) {
18941890
func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) {
18951891
t.Parallel()
18961892
const limit = 63
1897-
msg := &capnp.Message{
1898-
Arena: capnp.SingleSegment([]byte{
1899-
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
1900-
0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word)
1901-
0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word
1902-
}),
1903-
DepthLimit: limit,
1904-
}
1893+
msg, _ := capnp.NewSingleSegmentMessage([]byte{
1894+
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
1895+
0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word)
1896+
0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word
1897+
})
1898+
msg.DepthLimit = limit
19051899

19061900
toStruct := func(p capnp.Ptr, err error) (capnp.Struct, error) {
19071901
if err != nil {
@@ -2083,11 +2077,10 @@ func TestSetEmptyTextWithDefault(t *testing.T) {
20832077

20842078
func TestFuzzedListOutOfBounds(t *testing.T) {
20852079
t.Parallel()
2086-
msg := &capnp.Message{
2087-
Arena: capnp.SingleSegment([]byte(
2088-
"\x00\x00\x00\x00\x03\x00\x01\x00\x0f\x000000000000" +
2089-
"000000000000\x01\x00\x00\x00\x13\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00")),
2090-
}
2080+
msg, _ := capnp.NewSingleSegmentMessage([]byte(
2081+
"\x00\x00\x00\x00\x03\x00\x01\x00\x0f\x000000000000" +
2082+
"000000000000\x01\x00\x00\x00\x13\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00"))
2083+
20912084
z, err := air.ReadRootZ(msg)
20922085
if err != nil {
20932086
t.Fatal("ReadRootZ:", err)

‎internal/fuzztest/fuzztest.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func Fuzz(data []byte) int {
2121
data = append(data, 0)
2222
}
2323
}
24-
msg := &capnp.Message{Arena: capnp.SingleSegment(data)}
24+
msg, _ := capnp.NewSingleSegmentMessage(data)
2525
z, err := air.ReadRootZ(msg)
2626
if err != nil {
2727
return 0

‎list_test.go‎

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,10 @@ import (
88
)
99

1010
func TestToListDefault(t *testing.T) {
11-
msg := &Message{Arena: SingleSegment([]byte{
11+
_, seg := NewSingleSegmentMessage([]byte{
1212
0, 0, 0, 0, 0, 0, 0, 0,
1313
42, 0, 0, 0, 0, 0, 0, 0,
14-
})}
15-
seg, err := msg.Segment(0)
16-
if err != nil {
17-
t.Fatal(err)
18-
}
14+
})
1915
tests := []struct {
2016
ptr Ptr
2117
def []byte
@@ -56,15 +52,11 @@ func TestToListDefault(t *testing.T) {
5652
}
5753

5854
func TestTextListBytesAt(t *testing.T) {
59-
msg := &Message{Arena: SingleSegment([]byte{
55+
_, seg := NewSingleSegmentMessage([]byte{
6056
0, 0, 0, 0, 0, 0, 0, 0,
6157
0x01, 0, 0, 0, 0x22, 0, 0, 0,
6258
'f', 'o', 'o', 0, 0, 0, 0, 0,
63-
})}
64-
seg, err := msg.Segment(0)
65-
if err != nil {
66-
t.Fatal(err)
67-
}
59+
})
6860
list := TextList{
6961
seg: seg,
7062
off: 8,

‎message_test.go‎

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,11 @@ func TestAlloc(t *testing.T) {
9191
})
9292
}
9393
{
94-
msg := &Message{Arena: MultiSegment([][]byte{
94+
_, seg := NewMultiSegmentMessage([][]byte{
9595
incrementingData(24)[:8:8],
9696
incrementingData(24)[:8],
9797
incrementingData(24)[:8],
98-
})}
99-
seg, err := msg.Segment(1)
100-
if err != nil {
101-
t.Fatal(err)
102-
}
98+
})
10399
tests = append(tests, allocTest{
104100
name: "prefers given segment",
105101
seg: seg,
@@ -109,14 +105,10 @@ func TestAlloc(t *testing.T) {
109105
})
110106
}
111107
{
112-
msg := &Message{Arena: MultiSegment([][]byte{
108+
_, seg := NewMultiSegmentMessage([][]byte{
113109
incrementingData(24)[:8],
114110
incrementingData(24),
115-
})}
116-
seg, err := msg.Segment(1)
117-
if err != nil {
118-
t.Fatal(err)
119-
}
111+
})
120112
tests = append(tests, allocTest{
121113
name: "given segment full with another available",
122114
seg: seg,
@@ -126,14 +118,10 @@ func TestAlloc(t *testing.T) {
126118
})
127119
}
128120
{
129-
msg := &Message{Arena: MultiSegment([][]byte{
121+
msg, seg := NewMultiSegmentMessage([][]byte{
130122
incrementingData(24),
131123
incrementingData(24),
132-
})}
133-
seg, err := msg.Segment(1)
134-
if err != nil {
135-
t.Fatal(err)
136-
}
124+
})
137125

138126
// Make arena not read-only again.
139127
msg.Arena.(*MultiSegmentArena).bp = &bufferpool.Default
@@ -308,7 +296,14 @@ func TestMarshal(t *testing.T) {
308296
if test.decodeFails {
309297
continue
310298
}
311-
msg := &Message{Arena: test.arena()}
299+
msg, _, err := NewMessage(test.arena())
300+
if err != nil != test.newMessageFails {
301+
t.Errorf("serializeTests[%d] %s: NewMessage unexpected error: %v", i, test.name, err)
302+
continue
303+
}
304+
if err != nil {
305+
continue
306+
}
312307
out, err := msg.Marshal()
313308
if err != nil {
314309
if !test.encodeFails {
@@ -373,7 +368,8 @@ func TestWriteTo(t *testing.T) {
373368
continue
374369
}
375370

376-
msg := &Message{Arena: test.arena()}
371+
msg, _, err := NewMessage(test.arena())
372+
require.NoError(t, err)
377373
n, err := msg.WriteTo(&buf)
378374
if test.encodeFails {
379375
require.Error(t, err, test.name)
@@ -566,9 +562,7 @@ func TestTotalSize(t *testing.T) {
566562
}
567563
}
568564

569-
msg := &Message{
570-
Arena: MultiSegment(segs),
571-
}
565+
msg, _ := NewMultiSegmentMessage(segs)
572566

573567
size, err := msg.TotalSize()
574568
assert.Nil(t, err, "TotalSize() returned an error")

0 commit comments

Comments
 (0)