Skip to content

Commit 4683792

Browse files
Add MySQL support for delayed message publishing (#51)
Co-authored-by: adrian.zajkowski <adrian.zajkowski@nordsec.com>
1 parent a57c448 commit 4683792

File tree

5 files changed

+636
-0
lines changed

5 files changed

+636
-0
lines changed

‎pkg/sql/delayed_mysql.go‎

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
package sql
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"strings"
7+
8+
"github.com/ThreeDotsLabs/watermill"
9+
"github.com/ThreeDotsLabs/watermill/components/delay"
10+
"github.com/ThreeDotsLabs/watermill/message"
11+
)
12+
13+
type DelayedMySQLPublisherConfig struct {
14+
// DelayPublisherConfig is a configuration for the delay.Publisher.
15+
DelayPublisherConfig delay.PublisherConfig
16+
17+
// OverridePublisherConfig allows overriding the default PublisherConfig.
18+
OverridePublisherConfig func(config *PublisherConfig) error
19+
20+
Logger watermill.LoggerAdapter
21+
}
22+
23+
func (c *DelayedMySQLPublisherConfig) setDefaults() {
24+
if c.Logger == nil {
25+
c.Logger = watermill.NopLogger{}
26+
}
27+
}
28+
29+
// NewDelayedMySQLPublisher creates a new Publisher that stores messages in MySQL with a delay.
30+
// The delay can be set per message with the Watermill's components/delay metadata.
31+
func NewDelayedMySQLPublisher(db ContextExecutor, config DelayedMySQLPublisherConfig) (message.Publisher, error) {
32+
config.setDefaults()
33+
34+
publisherConfig := PublisherConfig{
35+
SchemaAdapter: delayedMySQLSchemaAdapter{
36+
MySQLQueueSchema: MySQLQueueSchema{},
37+
},
38+
AutoInitializeSchema: true,
39+
}
40+
41+
if config.OverridePublisherConfig != nil {
42+
err := config.OverridePublisherConfig(&publisherConfig)
43+
if err != nil {
44+
return nil, err
45+
}
46+
}
47+
48+
var publisher message.Publisher
49+
var err error
50+
51+
publisher, err = NewPublisher(db, publisherConfig, config.Logger)
52+
if err != nil {
53+
return nil, err
54+
}
55+
56+
publisher, err = delay.NewPublisher(publisher, config.DelayPublisherConfig)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
return publisher, nil
62+
}
63+
64+
type DelayedMySQLSubscriberConfig struct {
65+
// OverrideSubscriberConfig allows overriding the default SubscriberConfig.
66+
OverrideSubscriberConfig func(config *SubscriberConfig) error
67+
68+
// DeleteOnAck deletes the message from the queue when it's acknowledged.
69+
DeleteOnAck bool
70+
71+
// AllowNoDelay allows receiving messages without the delay metadata.
72+
// By default, such messages will be skipped.
73+
// If set to true, messages without delay metadata will be received immediately.
74+
AllowNoDelay bool
75+
76+
Logger watermill.LoggerAdapter
77+
}
78+
79+
func (c *DelayedMySQLSubscriberConfig) setDefaults() {
80+
if c.Logger == nil {
81+
c.Logger = watermill.NopLogger{}
82+
}
83+
}
84+
85+
// NewDelayedMySQLSubscriber creates a new Subscriber that reads messages from MySQL with a delay.
86+
// The delay can be set per message with the Watermill's components/delay metadata.
87+
func NewDelayedMySQLSubscriber(db Beginner, config DelayedMySQLSubscriberConfig) (message.Subscriber, error) {
88+
config.setDefaults()
89+
90+
where := "delayed_until <= NOW()"
91+
92+
if config.AllowNoDelay {
93+
where += " OR delayed_until IS NULL"
94+
}
95+
96+
schemaAdapter := delayedMySQLSchemaAdapter{
97+
MySQLQueueSchema: MySQLQueueSchema{
98+
GenerateWhereClause: func(params GenerateWhereClauseParams) (string, []any) {
99+
return where, nil
100+
},
101+
},
102+
}
103+
104+
subscriberConfig := SubscriberConfig{
105+
SchemaAdapter: schemaAdapter,
106+
OffsetsAdapter: MySQLQueueOffsetsAdapter{
107+
DeleteOnAck: config.DeleteOnAck,
108+
},
109+
InitializeSchema: true,
110+
}
111+
112+
if config.OverrideSubscriberConfig != nil {
113+
err := config.OverrideSubscriberConfig(&subscriberConfig)
114+
if err != nil {
115+
return nil, err
116+
}
117+
}
118+
119+
sub, err := NewSubscriber(db, subscriberConfig, config.Logger)
120+
if err != nil {
121+
return nil, err
122+
}
123+
124+
return sub, nil
125+
}
126+
127+
type delayedMySQLSchemaAdapter struct {
128+
MySQLQueueSchema
129+
}
130+
131+
func (a delayedMySQLSchemaAdapter) SchemaInitializingQueries(params SchemaInitializingQueriesParams) ([]Query, error) {
132+
createMessagesTable := `
133+
CREATE TABLE IF NOT EXISTS ` + a.MessagesTable(params.Topic) + ` (
134+
` + "`offset`" + ` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
135+
` + "`uuid`" + ` VARCHAR(36) NOT NULL,
136+
` + "`payload`" + ` ` + a.payloadColumnType(params.Topic) + ` DEFAULT NULL,
137+
` + "`metadata`" + ` JSON DEFAULT NULL,
138+
` + "`acked`" + ` BOOLEAN NOT NULL DEFAULT FALSE,
139+
` + "`created_at`" + ` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
140+
` + "`delayed_until`" + ` TIMESTAMP NULL DEFAULT NULL,
141+
INDEX ` + "`delayed_until_idx`" + ` (` + "`delayed_until`" + `)
142+
);
143+
`
144+
145+
return []Query{{Query: createMessagesTable}}, nil
146+
}
147+
148+
func (a delayedMySQLSchemaAdapter) InsertQuery(params InsertQueryParams) (Query, error) {
149+
insertQuery := fmt.Sprintf(
150+
`INSERT INTO %s (uuid, payload, metadata, delayed_until) VALUES %s`,
151+
a.MessagesTable(params.Topic),
152+
delayedMySQLInsertMarkers(len(params.Msgs)),
153+
)
154+
155+
args, err := delayedMySQLInsertArgs(params.Msgs)
156+
if err != nil {
157+
return Query{}, err
158+
}
159+
160+
return Query{insertQuery, args}, nil
161+
}
162+
163+
func delayedMySQLInsertMarkers(count int) string {
164+
result := strings.Builder{}
165+
166+
for range count {
167+
result.WriteString("(?,?,?,?),")
168+
}
169+
170+
return strings.TrimRight(result.String(), ",")
171+
}
172+
173+
func delayedMySQLInsertArgs(msgs message.Messages) ([]any, error) {
174+
var args []any
175+
176+
for _, msg := range msgs {
177+
metadata, err := json.Marshal(msg.Metadata)
178+
if err != nil {
179+
return nil, fmt.Errorf("could not marshal metadata into JSON for message %s: %w", msg.UUID, err)
180+
}
181+
182+
args = append(args, msg.UUID, msg.Payload, metadata)
183+
184+
// Extract delayed_until from metadata
185+
delayedUntilStr := msg.Metadata.Get(delay.DelayedUntilKey)
186+
if delayedUntilStr == "" {
187+
args = append(args, nil)
188+
} else {
189+
// Convert ISO 8601 to MySQL TIMESTAMP format: "2025-10-22T09:58:00Z" -> "2025-10-22 09:58:00"
190+
delayedUntilStr = strings.Replace(delayedUntilStr, "T", " ", 1)
191+
delayedUntilStr = strings.TrimSuffix(delayedUntilStr, "Z")
192+
193+
args = append(args, delayedUntilStr)
194+
}
195+
}
196+
197+
return args, nil
198+
}

‎pkg/sql/delayed_mysql_test.go‎

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package sql_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/ThreeDotsLabs/watermill"
12+
"github.com/ThreeDotsLabs/watermill-sql/v4/pkg/sql"
13+
"github.com/ThreeDotsLabs/watermill/components/delay"
14+
"github.com/ThreeDotsLabs/watermill/message"
15+
)
16+
17+
func TestDelayedMySQL(t *testing.T) {
18+
t.Parallel()
19+
20+
db := newMySQL(t)
21+
22+
pub, err := sql.NewDelayedMySQLPublisher(db, sql.DelayedMySQLPublisherConfig{
23+
DelayPublisherConfig: delay.PublisherConfig{
24+
DefaultDelayGenerator: func(params delay.DefaultDelayGeneratorParams) (delay.Delay, error) {
25+
return delay.For(time.Second), nil
26+
},
27+
},
28+
Logger: logger,
29+
})
30+
require.NoError(t, err)
31+
32+
sub, err := sql.NewDelayedMySQLSubscriber(db, sql.DelayedMySQLSubscriberConfig{
33+
DeleteOnAck: true,
34+
Logger: logger,
35+
})
36+
require.NoError(t, err)
37+
38+
topic := watermill.NewUUID()
39+
40+
messages, err := sub.Subscribe(context.Background(), topic)
41+
require.NoError(t, err)
42+
43+
msg := message.NewMessage(watermill.NewUUID(), []byte("{}"))
44+
45+
err = pub.Publish(topic, msg)
46+
require.NoError(t, err)
47+
48+
select {
49+
case <-messages:
50+
t.Errorf("message should not be received")
51+
case <-time.After(time.Millisecond * 200):
52+
}
53+
54+
assert.EventuallyWithT(t, func(t *assert.CollectT) {
55+
select {
56+
case received := <-messages:
57+
assert.Equal(t, msg.UUID, received.UUID)
58+
received.Ack()
59+
default:
60+
t.Errorf("message should be received")
61+
}
62+
}, time.Second, time.Millisecond*10)
63+
}
64+
65+
func TestDelayedMySQL_NoDelay(t *testing.T) {
66+
t.Parallel()
67+
68+
db := newMySQL(t)
69+
70+
pub, err := sql.NewDelayedMySQLPublisher(db, sql.DelayedMySQLPublisherConfig{
71+
DelayPublisherConfig: delay.PublisherConfig{
72+
AllowNoDelay: true,
73+
},
74+
Logger: logger,
75+
})
76+
require.NoError(t, err)
77+
78+
t.Run("skip_empty", func(t *testing.T) {
79+
t.Parallel()
80+
81+
sub, err := sql.NewDelayedMySQLSubscriber(db, sql.DelayedMySQLSubscriberConfig{
82+
DeleteOnAck: true,
83+
Logger: logger,
84+
})
85+
require.NoError(t, err)
86+
87+
topic := watermill.NewUUID()
88+
89+
messages, err := sub.Subscribe(context.Background(), topic)
90+
require.NoError(t, err)
91+
92+
msg := message.NewMessage(watermill.NewUUID(), []byte("{}"))
93+
94+
err = pub.Publish(topic, msg)
95+
require.NoError(t, err)
96+
97+
select {
98+
case <-messages:
99+
t.Errorf("message should not be received")
100+
case <-time.After(time.Second * 2):
101+
}
102+
})
103+
104+
t.Run("allow_empty", func(t *testing.T) {
105+
t.Parallel()
106+
107+
sub, err := sql.NewDelayedMySQLSubscriber(db, sql.DelayedMySQLSubscriberConfig{
108+
DeleteOnAck: true,
109+
AllowNoDelay: true,
110+
Logger: logger,
111+
})
112+
require.NoError(t, err)
113+
114+
topic := watermill.NewUUID()
115+
116+
messages, err := sub.Subscribe(context.Background(), topic)
117+
require.NoError(t, err)
118+
119+
msg := message.NewMessage(watermill.NewUUID(), []byte("{}"))
120+
121+
err = pub.Publish(topic, msg)
122+
require.NoError(t, err)
123+
124+
assert.EventuallyWithT(t, func(t *assert.CollectT) {
125+
select {
126+
case received := <-messages:
127+
assert.Equal(t, msg.UUID, received.UUID)
128+
received.Ack()
129+
default:
130+
t.Errorf("message should be received")
131+
}
132+
}, time.Second*2, time.Millisecond*10)
133+
})
134+
}

0 commit comments

Comments
 (0)