Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
Name: name,
Metadata: m,
}
err := persist(metadataPath, pc, compress, "")
err := persistToFile(metadataPath, pc, compress, "")
if err != nil {
return nil, fmt.Errorf("couldn't persist collection metadata: %w", err)
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
// Persist the document
if c.persistDirectory != "" {
docPath := c.getDocPath(doc.ID)
err := persist(docPath, doc, c.compress, "")
err := persistToFile(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
}
Expand All @@ -252,7 +252,6 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
// - whereDocument: Conditional filtering on documents. Optional.
// - ids: The ids of the documents to delete. If empty, all documents are deleted.
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {

// must have at least one of where, whereDocument or ids
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
return fmt.Errorf("must have at least one of where, whereDocument or ids")
Expand Down Expand Up @@ -294,15 +293,14 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s
// Remove the document from disk
if c.persistDirectory != "" {
docPath := c.getDocPath(docID)
err := remove(docPath)
err := removeFile(docPath)
if err != nil {
return fmt.Errorf("couldn't remove document at %q: %w", docPath, err)
}
}
}

return nil

}

// Count returns the number of documents in the collection.
Expand Down
73 changes: 69 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
Expand Down Expand Up @@ -142,7 +143,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
Name string
Metadata map[string]string
}{}
err := read(fPath, &pc, "")
err := readFromFile(fPath, &pc, "")
if err != nil {
return nil, fmt.Errorf("couldn't read collection metadata: %w", err)
}
Expand All @@ -151,7 +152,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
} else if strings.HasSuffix(collectionDirEntry.Name(), ext) {
// Read document
d := &Document{}
err := read(fPath, d, "")
err := readFromFile(fPath, d, "")
if err != nil {
return nil, fmt.Errorf("couldn't read document: %w", err)
}
Expand Down Expand Up @@ -223,7 +224,7 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()

err = read(filePath, &persistenceDB, encryptionKey)
err = readFromFile(filePath, &persistenceDB, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't read file: %w", err)
}
Expand Down Expand Up @@ -254,7 +255,22 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
//
// Deprecated: Use [DB.ExportToFile] instead.
func (db *DB) Export(filePath string, compress bool, encryptionKey string) error {
return db.ExportToFile(filePath, compress, encryptionKey)
}

// ExportToFile exports the DB to a file at the given path. The file is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
// If the file exists, it's overwritten, otherwise created.
//
// - filePath: If empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc")
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
if filePath == "" {
filePath = "./chromem-go.gob"
if compress {
Expand Down Expand Up @@ -295,7 +311,56 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error
}
}

err := persist(filePath, persistenceDB, compress, encryptionKey)
err := persistToFile(filePath, persistenceDB, compress, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't export DB: %w", err)
}

return nil
}

// ExportToWriter exports the DB to a writer. The stream is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
// If the writer has to be closed, it's the caller's responsibility.
//
// - writer: An implementation of [io.Writer]
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// Create persistence structs with exported fields so that they can be encoded
// as gob.
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}

db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()

for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}

err := persistToWriter(writer, persistenceDB, compress, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't export DB: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func TestDB_ImportExport(t *testing.T) {
}

// Export
err = orig.Export(tc.filePath, tc.compress, tc.encryptionKey)
err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down
56 changes: 36 additions & 20 deletions persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ func hash2hex(name string) string {
return hex.EncodeToString(hash[:4])
}

// persist persists an object to a file at the given path. The object is serialized
// persistToFile persists an object to a file at the given path. The object is serialized
// as gob, optionally compressed with flate (as gzip) and optionally encrypted with
// AES-GCM. The encryption key must be 32 bytes long. If the file exists, it's
// overwritten, otherwise created.
func persist(filePath string, obj any, compress bool, encryptionKey string) error {
func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down Expand Up @@ -66,25 +66,41 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro
}
defer f.Close()

return persistToWriter(f, obj, compress, encryptionKey)
}

// persistToWriter persists an object to a writer. The object is serialized
// as gob, optionally compressed with flate (as gzip) and optionally encrypted with
// AES-GCM. The encryption key must be 32 bytes long.
// If the writer has to be closed, it's the caller's responsibility.
func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error {
// AES 256 requires a 32 byte key
if encryptionKey != "" {
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// We want to:
// Encode as gob -> compress with flate -> encrypt with AES-GCM -> write file.
// Encode as gob -> compress with flate -> encrypt with AES-GCM -> write to
// passed writer.
// To reduce memory usage we chain the writers instead of buffering, so we start
// from the end. For AES GCM sealing the stdlib doesn't provide a writer though.

var w io.Writer
var chainedWriter io.Writer
if encryptionKey == "" {
w = f
chainedWriter = w
} else {
w = &bytes.Buffer{}
chainedWriter = &bytes.Buffer{}
}

var gzw *gzip.Writer
var enc *gob.Encoder
if compress {
gzw = gzip.NewWriter(w)
gzw = gzip.NewWriter(chainedWriter)
enc = gob.NewEncoder(gzw)
} else {
enc = gob.NewEncoder(w)
enc = gob.NewEncoder(chainedWriter)
}

// Start encoding, it will write to the chain of writers.
Expand All @@ -93,22 +109,22 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro
}

// If compressing, close the gzip writer. Otherwise the gzip footer won't be
// written yet. When using encryption (and w is a buffer) then we'll encrypt
// an incomplete file. Without encryption when we return here and having
// written yet. When using encryption (and chainedWriter is a buffer) then
// we'll encrypt an incomplete stream. Without encryption when we return here and having
// a deferred Close(), there might be a silenced error.
if compress {
err = gzw.Close()
err := gzw.Close()
if err != nil {
return fmt.Errorf("couldn't close gzip writer: %w", err)
}
}

// Without encyrption, the chain is done and the file is written.
// Without encyrption, the chain is done and the writing is finished.
if encryptionKey == "" {
return nil
}

// Otherwise, encrypt and then write to the file
// Otherwise, encrypt and then write to the unchained target writer.
block, err := aes.NewCipher([]byte(encryptionKey))
if err != nil {
return fmt.Errorf("couldn't create new AES cipher: %w", err)
Expand All @@ -121,22 +137,22 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return fmt.Errorf("couldn't read random bytes for nonce: %w", err)
}
// w is a *bytes.Buffer
buf := w.(*bytes.Buffer)
// chainedWriter is a *bytes.Buffer
buf := chainedWriter.(*bytes.Buffer)
encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil)
_, err = f.Write(encrypted)
_, err = w.Write(encrypted)
if err != nil {
return fmt.Errorf("couldn't write encrypted data: %w", err)
}

return nil
}

// read reads an object from a file at the given path. The object is deserialized
// readFromFile reads an object from a file at the given path. The object is deserialized
// from gob. `obj` must be a pointer to an instantiated object. The file may
// optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption
// key must be 32 bytes long.
func read(filePath string, obj any, encryptionKey string) error {
func readFromFile(filePath string, obj any, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down Expand Up @@ -226,8 +242,8 @@ func read(filePath string, obj any, encryptionKey string) error {
return nil
}

// remove removes a file at the given path. If the file doesn't exist, it's a no-op.
func remove(filePath string) error {
// removeFile removes a file at the given path. If the file doesn't exist, it's a no-op.
func removeFile(filePath string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down
12 changes: 6 additions & 6 deletions persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestPersistenceWrite(t *testing.T) {

t.Run("gob", func(t *testing.T) {
tempFilePath := tempDir + ".gob"
persist(tempFilePath, obj, false, "")
persistToFile(tempFilePath, obj, false, "")

// Check if the file exists.
_, err = os.Stat(tempFilePath)
Expand Down Expand Up @@ -57,7 +57,7 @@ func TestPersistenceWrite(t *testing.T) {

t.Run("gob gzipped", func(t *testing.T) {
tempFilePath := tempDir + ".gob.gz"
persist(tempFilePath, obj, true, "")
persistToFile(tempFilePath, obj, true, "")

// Check if the file exists.
_, err = os.Stat(tempFilePath)
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestPersistenceRead(t *testing.T) {

// Read the file.
var res s
err = read(tempFilePath, &res, "")
err = readFromFile(tempFilePath, &res, "")
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestPersistenceRead(t *testing.T) {

// Read the file.
var res s
err = read(tempFilePath, &res, "")
err = readFromFile(tempFilePath, &res, "")
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestPersistenceEncryption(t *testing.T) {

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
err := persist(tc.filePath, obj, tc.compress, encryptionKey)
err := persistToFile(tc.filePath, obj, tc.compress, encryptionKey)
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand All @@ -220,7 +220,7 @@ func TestPersistenceEncryption(t *testing.T) {

// Read the file.
var res s
err = read(tc.filePath, &res, encryptionKey)
err = readFromFile(tc.filePath, &res, encryptionKey)
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down