Skip to content
Prev Previous commit
Next Next commit
Refactor Query() concurrent error handling
To match the refactoring in Collection.AddDocuments
from the previous commit
  • Loading branch information
philippgille committed Mar 3, 2024
commit a4fd27914905f5b05a0043f306b12c310957ba0d
27 changes: 15 additions & 12 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chromem

import (
"context"
"fmt"
"runtime"
"strings"
"sync"
Expand Down Expand Up @@ -109,14 +110,23 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
concurrency = numDocs
}

ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

docChan := make(chan *Document, concurrency*2)
var globalErr error
globalErrLock := sync.Mutex{}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
setGlobalErr := func(err error) {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
}

wg := sync.WaitGroup{}
docChan := make(chan *Document, concurrency*2)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
Expand All @@ -129,14 +139,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu

sim, err := cosineSimilarity(queryVectors, doc.Embedding)
if err != nil {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
setGlobalErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
return
}

Expand Down