Skip to content

Commit e50e0f2

Browse files
authored
feat: add TransactionalEnforcer (#1541)
1 parent 33e2e50 commit e50e0f2

6 files changed

Lines changed: 1332 additions & 0 deletions

File tree

‎enforcer_transactional.go‎

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Copyright 2025 The casbin Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package casbin
16+
17+
import (
18+
"context"
19+
"errors"
20+
"sync"
21+
22+
"github.com/casbin/casbin/v2/persist"
23+
)
24+
25+
// TransactionalEnforcer extends Enforcer with transaction support.
26+
// It provides atomic policy operations through transactions.
27+
type TransactionalEnforcer struct {
28+
*Enforcer // Embedded enforcer for all standard functionality
29+
currentTx *Transaction // Current active transaction (nil if none)
30+
txMutex sync.RWMutex // Protects transaction state
31+
}
32+
33+
// NewTransactionalEnforcer creates a new TransactionalEnforcer.
34+
// It accepts the same parameters as NewEnforcer.
35+
func NewTransactionalEnforcer(params ...interface{}) (*TransactionalEnforcer, error) {
36+
enforcer, err := NewEnforcer(params...)
37+
if err != nil {
38+
return nil, err
39+
}
40+
41+
return &TransactionalEnforcer{
42+
Enforcer: enforcer,
43+
}, nil
44+
}
45+
46+
// BeginTransaction starts a new transaction.
47+
// Returns an error if a transaction is already in progress or if the adapter doesn't support transactions.
48+
func (te *TransactionalEnforcer) BeginTransaction(ctx context.Context) (*Transaction, error) {
49+
te.txMutex.Lock()
50+
defer te.txMutex.Unlock()
51+
52+
if te.currentTx != nil {
53+
return nil, errors.New("transaction already in progress")
54+
}
55+
56+
// Check if adapter supports transactions.
57+
txAdapter, ok := te.adapter.(persist.TransactionalAdapter)
58+
if !ok {
59+
return nil, errors.New("adapter does not support transactions")
60+
}
61+
62+
// Start database transaction.
63+
txContext, err := txAdapter.BeginTransaction(ctx)
64+
if err != nil {
65+
return nil, err
66+
}
67+
68+
// Create transaction buffer with current model snapshot.
69+
buffer := NewTransactionBuffer(te.model)
70+
71+
tx := &Transaction{
72+
enforcer: te,
73+
buffer: buffer,
74+
txContext: txContext,
75+
ctx: ctx,
76+
}
77+
78+
te.currentTx = tx
79+
return tx, nil
80+
}
81+
82+
// GetCurrentTransaction returns the current active transaction, or nil if none.
83+
func (te *TransactionalEnforcer) GetCurrentTransaction() *Transaction {
84+
te.txMutex.RLock()
85+
defer te.txMutex.RUnlock()
86+
return te.currentTx
87+
}
88+
89+
// IsInTransaction returns true if there is an active transaction.
90+
func (te *TransactionalEnforcer) IsInTransaction() bool {
91+
te.txMutex.RLock()
92+
defer te.txMutex.RUnlock()
93+
return te.currentTx != nil
94+
}
95+
96+
// clearTransaction clears the current transaction (called internally).
97+
func (te *TransactionalEnforcer) clearTransaction() {
98+
te.txMutex.Lock()
99+
defer te.txMutex.Unlock()
100+
te.currentTx = nil
101+
}
102+
103+
// WithTransaction executes a function within a transaction.
104+
// If the function returns an error, the transaction is rolled back.
105+
// Otherwise, it's committed automatically.
106+
func (te *TransactionalEnforcer) WithTransaction(ctx context.Context, fn func(*Transaction) error) error {
107+
tx, err := te.BeginTransaction(ctx)
108+
if err != nil {
109+
return err
110+
}
111+
112+
defer func() {
113+
if r := recover(); r != nil {
114+
_ = tx.Rollback()
115+
panic(r)
116+
}
117+
}()
118+
119+
err = fn(tx)
120+
if err != nil {
121+
_ = tx.Rollback()
122+
return err
123+
}
124+
125+
return tx.Commit()
126+
}

‎persist/transaction.go‎

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2025 The casbin Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package persist
16+
17+
import "context"
18+
19+
// TransactionalAdapter defines the interface for adapters that support transactions.
20+
// Adapters implementing this interface can participate in Casbin transactions.
21+
type TransactionalAdapter interface {
22+
Adapter
23+
// BeginTransaction starts a new transaction and returns a transaction context.
24+
BeginTransaction(ctx context.Context) (TransactionContext, error)
25+
}
26+
27+
// TransactionContext represents a database transaction context.
28+
// It provides methods to commit or rollback the transaction and get an adapter
29+
// that operates within this transaction.
30+
type TransactionContext interface {
31+
// Commit commits the transaction.
32+
Commit() error
33+
// Rollback rolls back the transaction.
34+
Rollback() error
35+
// GetAdapter returns an adapter that operates within this transaction.
36+
GetAdapter() Adapter
37+
}
38+
39+
// PolicyOperation represents a policy operation that can be buffered in a transaction.
40+
type PolicyOperation struct {
41+
Type OperationType // The type of operation (add, remove, update)
42+
Section string // The section of the policy (p, g)
43+
PolicyType string // The policy type (p, p2, g, g2, etc.)
44+
Rules [][]string // The policy rules to operate on
45+
OldRules [][]string // For update operations, the old rules to replace
46+
}
47+
48+
// OperationType represents the type of policy operation.
49+
type OperationType int
50+
51+
const (
52+
// OperationAdd represents adding policy rules.
53+
OperationAdd OperationType = iota
54+
// OperationRemove represents removing policy rules.
55+
OperationRemove
56+
// OperationUpdate represents updating policy rules.
57+
OperationUpdate
58+
)

0 commit comments

Comments
 (0)