Skip to content

Commit 041e3c0

Browse files
authored
Merge pull request #72 from philippgille/import-from-reader
Import from io.Reader
2 parents 82f4efe + ff98d0a commit 041e3c0

File tree

3 files changed

+109
-19
lines changed

3 files changed

+109
-19
lines changed

‎db.go‎

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,21 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
186186
//
187187
// - filePath: Mandatory, must not be empty
188188
// - encryptionKey: Optional, must be 32 bytes long if provided
189+
//
190+
// Deprecated: Use [DB.ImportFromFile] instead.
189191
func (db *DB) Import(filePath string, encryptionKey string) error {
192+
// Backward-compatibility shim: forward both arguments unchanged to
// ImportFromFile and return its error as-is.
return db.ImportFromFile(filePath, encryptionKey)
193+
}
194+
195+
// ImportFromFile imports the DB from a file at the given path. The file must be
196+
// encoded as gob and can optionally be compressed with flate (as gzip) and encrypted
197+
// with AES-GCM.
198+
// This works for both the in-memory and persistent DBs.
199+
// Existing collections are overwritten.
200+
//
201+
// - filePath: Mandatory, must not be empty
202+
// - encryptionKey: Optional, must be 32 bytes long if provided
203+
func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
190204
if filePath == "" {
191205
return fmt.Errorf("file path is empty")
192206
}
@@ -246,6 +260,61 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
246260
return nil
247261
}
248262

263+
// ImportFromReader imports the DB from a reader. The stream must be encoded as
264+
// gob and can optionally be compressed with flate (as gzip) and encrypted with
265+
// AES-GCM.
266+
// This works for both the in-memory and persistent DBs.
267+
// Existing collections are overwritten.
268+
// If the reader has to be closed, it's the caller's responsibility.
269+
//
270+
// - reader: An implementation of [io.ReadSeeker]
271+
// - encryptionKey: Optional, must be 32 bytes long if provided
272+
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
273+
if encryptionKey != "" {
274+
// AES 256 requires a 32 byte key
275+
if len(encryptionKey) != 32 {
276+
return errors.New("encryption key must be 32 bytes long")
277+
}
278+
}
279+
280+
// Create persistence structs with exported fields so that they can be decoded
281+
// from gob.
282+
type persistenceCollection struct {
283+
Name string
284+
Metadata map[string]string
285+
Documents map[string]*Document
286+
}
287+
persistenceDB := struct {
288+
Collections map[string]*persistenceCollection
289+
}{
290+
Collections: make(map[string]*persistenceCollection, len(db.collections)),
291+
}
292+
293+
db.collectionsLock.Lock()
294+
defer db.collectionsLock.Unlock()
295+
296+
err := readFromReader(reader, &persistenceDB, encryptionKey)
297+
if err != nil {
298+
return fmt.Errorf("couldn't read stream: %w", err)
299+
}
300+
301+
for _, pc := range persistenceDB.Collections {
302+
c := &Collection{
303+
Name: pc.Name,
304+
305+
metadata: pc.Metadata,
306+
documents: pc.Documents,
307+
}
308+
if db.persistDirectory != "" {
309+
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
310+
c.compress = db.compress
311+
}
312+
db.collections[c.Name] = c
313+
}
314+
315+
return nil
316+
}
317+
249318
// Export exports the DB to a file at the given path. The file is encoded as gob,
250319
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
251320
// This works for both the in-memory and persistent DBs.

‎db_test.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func TestDB_ImportExport(t *testing.T) {
147147
new := NewDB()
148148

149149
// Import
150-
err = new.Import(tc.filePath, tc.encryptionKey)
150+
err = new.ImportFromFile(tc.filePath, tc.encryptionKey)
151151
if err != nil {
152152
t.Fatal("expected no error, got", err)
153153
}

‎persistence.go‎

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,43 @@ func readFromFile(filePath string, obj any, encryptionKey string) error {
163163
}
164164
}
165165

166+
r, err := os.Open(filePath)
167+
if err != nil {
168+
return fmt.Errorf("couldn't open file: %w", err)
169+
}
170+
defer r.Close()
171+
172+
return readFromReader(r, obj, encryptionKey)
173+
}
174+
175+
// readFromReader reads an object from a Reader. The object is deserialized from gob.
176+
// `obj` must be a pointer to an instantiated object. The stream may optionally
177+
// be compressed as gzip and/or encrypted with AES-GCM. The encryption key must
178+
// be 32 bytes long.
179+
// If the reader has to be closed, it's the caller's responsibility.
180+
func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error {
181+
// AES 256 requires a 32 byte key
182+
if encryptionKey != "" {
183+
if len(encryptionKey) != 32 {
184+
return errors.New("encryption key must be 32 bytes long")
185+
}
186+
}
187+
166188
// We want to:
167-
// Read file -> decrypt with AES-GCM -> decompress with flate -> decode as gob
189+
// Read from reader -> decrypt with AES-GCM -> decompress with flate -> decode
190+
// as gob.
168191
// To reduce memory usage we chain the readers instead of buffering, so we start
169192
// from the end. For the decryption there's no reader though.
170193

171-
var r io.Reader
194+
// For the chainedReader we don't declare it as ReadSeeker so we can reassign
195+
// the gzip reader to it.
196+
var chainedReader io.Reader
172197

173198
// Decrypt if an encryption key is provided
174199
if encryptionKey != "" {
175-
encrypted, err := os.ReadFile(filePath)
200+
encrypted, err := io.ReadAll(r)
176201
if err != nil {
177-
return fmt.Errorf("couldn't read file: %w", err)
202+
return fmt.Errorf("couldn't read from reader: %w", err)
178203
}
179204
block, err := aes.NewCipher([]byte(encryptionKey))
180205
if err != nil {
@@ -194,28 +219,24 @@ func readFromFile(filePath string, obj any, encryptionKey string) error {
194219
return fmt.Errorf("couldn't decrypt data: %w", err)
195220
}
196221

197-
r = bytes.NewReader(data)
222+
chainedReader = bytes.NewReader(data)
198223
} else {
199-
var err error
200-
r, err = os.Open(filePath)
201-
if err != nil {
202-
return fmt.Errorf("couldn't open file: %w", err)
203-
}
224+
chainedReader = r
204225
}
205226

206-
// Determine if the file is compressed
227+
// Determine if the stream is compressed
207228
magicNumber := make([]byte, 2)
208-
_, err := r.Read(magicNumber)
229+
_, err := chainedReader.Read(magicNumber)
209230
if err != nil {
210-
return fmt.Errorf("couldn't read magic number to determine whether the file is compressed: %w", err)
231+
return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err)
211232
}
212233
var compressed bool
213234
if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b {
214235
compressed = true
215236
}
216237

217-
// Reset reader. Both file and bytes.Reader support seeking.
218-
if s, ok := r.(io.Seeker); !ok {
238+
// Reset reader. Both the reader from the param and bytes.Reader support seeking.
239+
if s, ok := chainedReader.(io.Seeker); !ok {
219240
return fmt.Errorf("reader doesn't support seeking")
220241
} else {
221242
_, err := s.Seek(0, 0)
@@ -225,15 +246,15 @@ func readFromFile(filePath string, obj any, encryptionKey string) error {
225246
}
226247

227248
if compressed {
228-
gzr, err := gzip.NewReader(r)
249+
gzr, err := gzip.NewReader(chainedReader)
229250
if err != nil {
230251
return fmt.Errorf("couldn't create gzip reader: %w", err)
231252
}
232253
defer gzr.Close()
233-
r = gzr
254+
chainedReader = gzr
234255
}
235256

236-
dec := gob.NewDecoder(r)
257+
dec := gob.NewDecoder(chainedReader)
237258
err = dec.Decode(obj)
238259
if err != nil {
239260
return fmt.Errorf("couldn't decode object: %w", err)

0 commit comments

Comments
 (0)