tidy & write test
benclive committed Nov 29, 2024
commit 763a86181dfc68cfa71a9817439324ceeeab459e
67 changes: 27 additions & 40 deletions pkg/querier/ingester_querier.go
@@ -138,7 +138,7 @@ func NewPartitionContext(ctx context.Context) context.Context {
return context.WithValue(ctx, partitionCtxKey, &PartitionContext{})
}

func FromPartitionContext(ctx context.Context) *PartitionContext {
func ExtractPartitionContext(ctx context.Context) *PartitionContext {
v, ok := ctx.Value(partitionCtxKey).(*PartitionContext)
if !ok {
return &PartitionContext{}
@@ -147,10 +147,9 @@ func FromPartitionContext(ctx context.Context) *PartitionContext {
}

// forAllIngesters runs f, in parallel, for all ingesters
// waitForAllResponses param can be used to require results from all ingesters in the replication set. If this is set to false, the call will return as soon as we have a quorum by zone. Only valid for partition-ingesters.
func (q *IngesterQuerier) forAllIngesters(ctx context.Context, waitForAllResponses bool, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
func (q *IngesterQuerier) forAllIngesters(ctx context.Context, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
if q.querierConfig.QueryPartitionIngesters {
FromPartitionContext(ctx).SetIsPartitioned(true)
ExtractPartitionContext(ctx).SetIsPartitioned(true)
tenantID, err := user.ExtractOrgID(ctx)
if err != nil {
return nil, err
@@ -164,7 +163,7 @@ func (q *IngesterQuerier) forAllIngesters(ctx context.Context, waitForAllRespons
if err != nil {
return nil, err
}
return q.forGivenIngesterSets(ctx, waitForAllResponses, replicationSets, f)
return q.forGivenIngesterSets(ctx, replicationSets, f)
}

replicationSet, err := q.ring.GetReplicationSetForOperation(ring.Read)
@@ -176,19 +175,13 @@ func (q *IngesterQuerier) forAllIngesters(ctx context.Context, waitForAllRespons
}

// forGivenIngesterSets runs f, in parallel, for given ingester sets
// waitForAllResponses param can be used to require results from all ingesters in all replication sets. If this is set to false, the call will return as soon as we have a quorum by zone.
func (q *IngesterQuerier) forGivenIngesterSets(ctx context.Context, waitForAllResponses bool, replicationSet []ring.ReplicationSet, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
func (q *IngesterQuerier) forGivenIngesterSets(ctx context.Context, replicationSet []ring.ReplicationSet, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
// Enable minimize requests if we can, so we initially query a single ingester per replication set, as each replication-set is one partition.
// Ingesters must supply zone information for this to have an effect.
config := ring.DoUntilQuorumConfig{
MinimizeRequests: !waitForAllResponses,
MinimizeRequests: true,
}
return concurrency.ForEachJobMergeResults[ring.ReplicationSet, responseFromIngesters](ctx, replicationSet, 0, func(ctx context.Context, set ring.ReplicationSet) ([]responseFromIngesters, error) {
if waitForAllResponses {
// Tell the ring we need to return all responses from all zones
set.MaxErrors = 0
set.MaxUnavailableZones = 0
}
return q.forGivenIngesters(ctx, set, config, f)
})
}
@@ -200,7 +193,7 @@ func (q *IngesterQuerier) forGivenIngesters(ctx context.Context, replicationSet
if err != nil {
return responseFromIngesters{addr: ingester.Addr}, err
}
FromPartitionContext(ctx).AddClient(client.(logproto.QuerierClient), ingester.Addr)
ExtractPartitionContext(ctx).AddClient(client.(logproto.QuerierClient), ingester.Addr)
resp, err := f(ctx, client.(logproto.QuerierClient))
if err != nil {
return responseFromIngesters{addr: ingester.Addr}, err
@@ -221,7 +214,7 @@ func (q *IngesterQuerier) forGivenIngesters(ctx context.Context, replicationSet
}

func (q *IngesterQuerier) SelectLogs(ctx context.Context, params logql.SelectLogParams) ([]iter.EntryIterator, error) {
resps, err := q.forAllIngesters(ctx, false, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
resps, err := q.forAllIngesters(ctx, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
stats.FromContext(ctx).AddIngesterReached(1)
return client.Query(ctx, params.QueryRequest)
})
@@ -237,7 +230,7 @@ func (q *IngesterQuerier) SelectLogs(ctx context.Context, params logql.SelectLog
}

func (q *IngesterQuerier) SelectSample(ctx context.Context, params logql.SelectSampleParams) ([]iter.SampleIterator, error) {
resps, err := q.forAllIngesters(ctx, false, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
resps, err := q.forAllIngesters(ctx, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
stats.FromContext(ctx).AddIngesterReached(1)
return client.QuerySample(ctx, params.SampleQueryRequest)
})
@@ -253,7 +246,7 @@ func (q *IngesterQuerier) SelectSample(ctx context.Context, params logql.SelectS
}

func (q *IngesterQuerier) Label(ctx context.Context, req *logproto.LabelRequest) ([][]string, error) {
resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
resps, err := q.forAllIngesters(ctx, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
return client.Label(ctx, req)
})
if err != nil {
@@ -269,7 +262,7 @@ func (q *IngesterQuerier) Label(ctx context.Context, req *logproto.LabelRequest)
}

func (q *IngesterQuerier) Tail(ctx context.Context, req *logproto.TailRequest) (map[string]logproto.Querier_TailClient, error) {
resps, err := q.forAllIngesters(ctx, false, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
resps, err := q.forAllIngesters(ctx, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
return client.Tail(ctx, req)
})
if err != nil {
@@ -334,7 +327,7 @@ func (q *IngesterQuerier) TailDisconnectedIngesters(ctx context.Context, req *lo
}

func (q *IngesterQuerier) Series(ctx context.Context, req *logproto.SeriesRequest) ([][]logproto.SeriesIdentifier, error) {
resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
resps, err := q.forAllIngesters(ctx, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
return client.Series(ctx, req)
})
if err != nil {
@@ -389,28 +382,22 @@ func (q *IngesterQuerier) TailersCount(ctx context.Context) ([]uint32, error) {
}

func (q *IngesterQuerier) GetChunkIDs(ctx context.Context, from, through model.Time, matchers ...*labels.Matcher) ([]string, error) {
partitionCtx := FromPartitionContext(ctx)
var resps []responseFromIngesters
var err error
ingesterQueryFn := q.forAllIngesters

partitionCtx := ExtractPartitionContext(ctx)
if partitionCtx.IsPartitioned() {
// We need to query the same ingesters as the previous query
resps, err = partitionCtx.forQueriedIngesters(ctx, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
return querierClient.GetChunkIDs(ctx, &logproto.GetChunkIDsRequest{
Matchers: convertMatchersToString(matchers),
Start: from.Time(),
End: through.Time(),
})
})
} else {
resps, err = q.forAllIngesters(ctx, false, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
return querierClient.GetChunkIDs(ctx, &logproto.GetChunkIDsRequest{
Matchers: convertMatchersToString(matchers),
Start: from.Time(),
End: through.Time(),
})
})
ingesterQueryFn = partitionCtx.forQueriedIngesters
}

resps, err := ingesterQueryFn(ctx, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
return querierClient.GetChunkIDs(ctx, &logproto.GetChunkIDsRequest{
Matchers: convertMatchersToString(matchers),
Start: from.Time(),
End: through.Time(),
})
})

if err != nil {
return nil, err
}
@@ -424,7 +411,7 @@ func (q *IngesterQuerier) GetChunkIDs(ctx context.Context, from, through model.T
}

func (q *IngesterQuerier) Stats(ctx context.Context, _ string, from, through model.Time, matchers ...*labels.Matcher) (*index_stats.Stats, error) {
resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
resps, err := q.forAllIngesters(ctx, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
return querierClient.GetStats(ctx, &logproto.IndexStatsRequest{
From: from,
Through: through,
@@ -454,7 +441,7 @@ func (q *IngesterQuerier) Volume(ctx context.Context, _ string, from, through mo
matcherString = syntax.MatchersString(matchers)
}

resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
resps, err := q.forAllIngesters(ctx, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
return querierClient.GetVolume(ctx, &logproto.VolumeRequest{
From: from,
Through: through,
@@ -482,7 +469,7 @@ func (q *IngesterQuerier) Volume(ctx context.Context, _ string, from, through mo
}

func (q *IngesterQuerier) DetectedLabel(ctx context.Context, req *logproto.DetectedLabelsRequest) (*logproto.LabelToValuesResponse, error) {
ingesterResponses, err := q.forAllIngesters(ctx, false, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
ingesterResponses, err := q.forAllIngesters(ctx, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
return client.GetDetectedLabels(ctx, req)
})
if err != nil {
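
A minimal sketch of the pattern this change relies on, using a simplified stand-in for Loki's PartitionContext rather than the real definitions: the first query marks the request context as partitioned and records which ingesters answered (AddClient inside forGivenIngesters), and a follow-up call such as GetChunkIDs reads the same value back via ExtractPartitionContext and queries only those ingesters. All type and field names below are illustrative placeholders.

// Illustrative sketch only, not code from this PR: a simplified version of the
// two-phase partition-context flow the diff above depends on.
package main

import (
	"context"
	"fmt"
)

type partitionCtxKeyType struct{}

var partitionCtxKey = partitionCtxKeyType{}

// partitionContext mirrors the idea of Loki's PartitionContext: a mutable
// value stored in the request context and shared between query phases.
type partitionContext struct {
	isPartitioned bool
	ingesterAddrs []string
}

func newPartitionContext(ctx context.Context) context.Context {
	return context.WithValue(ctx, partitionCtxKey, &partitionContext{})
}

func extractPartitionContext(ctx context.Context) *partitionContext {
	if pc, ok := ctx.Value(partitionCtxKey).(*partitionContext); ok {
		return pc
	}
	return &partitionContext{}
}

func main() {
	ctx := newPartitionContext(context.Background())

	// Phase 1: the initial query marks the context as partitioned and records
	// the ingesters that served it.
	pc := extractPartitionContext(ctx)
	pc.isPartitioned = true
	pc.ingesterAddrs = append(pc.ingesterAddrs, "1.1.1.1", "2.2.2.2")

	// Phase 2: a later call on the same context reuses the recorded ingesters
	// instead of fanning out to the whole ring again.
	if follow := extractPartitionContext(ctx); follow.isPartitioned {
		fmt.Println("re-querying only:", follow.ingesterAddrs)
	}
}

Keeping this state in the context couples the two phases without changing the exported query method signatures, which appears to be why the waitForAllResponses parameter could be dropped from forAllIngesters and forGivenIngesterSets.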
153 changes: 140 additions & 13 deletions pkg/querier/ingester_querier_test.go
@@ -8,6 +8,7 @@ import (
"time"

"github.com/go-kit/log"
"github.com/grafana/dskit/ring/client"
"github.com/grafana/dskit/user"
"go.uber.org/atomic"

@@ -241,11 +242,10 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
}

tests := map[string]struct {
method string
testFn func(*IngesterQuerier) error
retVal interface{}
shards int
expectAllResponses bool
method string
testFn func(*IngesterQuerier) error
retVal interface{}
shards int
}{
"label": {
method: "Label",
@@ -269,8 +269,7 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
_, err := ingesterQuerier.GetChunkIDs(ctx, model.Time(0), model.Time(0))
return err
},
retVal: new(logproto.GetChunkIDsResponse),
expectAllResponses: true,
retVal: new(logproto.GetChunkIDsResponse),
},
"select_logs": {
method: "Query",
@@ -330,7 +329,7 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
ingestersPerPartition := len(ingesters) / partitions
assert.Greaterf(t, ingestersPerPartition, 1, "must have more than one ingester per partition")

ingesterQuerier, err := newTestPartitionIngesterQuerier(ingesterClient, instanceRing, newPartitionInstanceRingMock(instanceRing, ingesters, partitions, ingestersPerPartition), testData.shards)
ingesterQuerier, err := newTestPartitionIngesterQuerier(newIngesterClientMockFactory(ingesterClient), instanceRing, newPartitionInstanceRingMock(instanceRing, ingesters, partitions, ingestersPerPartition), testData.shards)
require.NoError(t, err)

ingesterQuerier.querierConfig.QueryPartitionIngesters = true
@@ -342,9 +341,6 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
testData.shards = partitions
}
expectedCalls := min(testData.shards, partitions)
if testData.expectAllResponses {
expectedCalls = expectedCalls * ingestersPerPartition
}
// Wait for responses: We expect one request per queried partition because we have request minimization enabled & ingesters are in multiple zones.
// If shuffle sharding is enabled, we expect one query per shard as we write to a subset of partitions.
require.Eventually(t, func() bool { return cnt.Load() >= int32(expectedCalls) }, time.Millisecond*100, time.Millisecond*1, "expected all ingesters to respond")
@@ -353,6 +349,137 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
}
}

func TestIngesterQuerier_QueriesSameIngestersWithPartitionContext(t *testing.T) {
t.Parallel()
userCtx := user.InjectOrgID(context.Background(), "test-user")
testCtx, cancel := context.WithTimeout(userCtx, time.Second*10)
defer cancel()

ingesters := []ring.InstanceDesc{
mockInstanceDescWithZone("1.1.1.1", ring.ACTIVE, "A"),
mockInstanceDescWithZone("2.2.2.2", ring.ACTIVE, "B"),
mockInstanceDescWithZone("3.3.3.3", ring.ACTIVE, "A"),
mockInstanceDescWithZone("4.4.4.4", ring.ACTIVE, "B"),
mockInstanceDescWithZone("5.5.5.5", ring.ACTIVE, "A"),
mockInstanceDescWithZone("6.6.6.6", ring.ACTIVE, "B"),
}

tests := map[string]struct {
method string
testFn func(context.Context, *IngesterQuerier) error
retVal interface{}
shards int
}{
"select_logs": {
method: "Query",
testFn: func(ctx context.Context, ingesterQuerier *IngesterQuerier) error {
_, err := ingesterQuerier.SelectLogs(ctx, logql.SelectLogParams{
QueryRequest: new(logproto.QueryRequest),
})
return err
},
retVal: newQueryClientMock(),
},
"select_sample": {
method: "QuerySample",
testFn: func(ctx context.Context, ingesterQuerier *IngesterQuerier) error {
_, err := ingesterQuerier.SelectSample(ctx, logql.SelectSampleParams{
SampleQueryRequest: new(logproto.SampleQueryRequest),
})
return err
},
retVal: newQuerySampleClientMock(),
},
"select_logs_shuffle_sharded": {
method: "Query",
testFn: func(ctx context.Context, ingesterQuerier *IngesterQuerier) error {
_, err := ingesterQuerier.SelectLogs(ctx, logql.SelectLogParams{
QueryRequest: new(logproto.QueryRequest),
})
return err
},
retVal: newQueryClientMock(),
shards: 2, // Must be less than number of partitions
},
}

for testName, testData := range tests {
cnt := atomic.NewInt32(0)
ctx := NewPartitionContext(testCtx)

t.Run(testName, func(t *testing.T) {
cnt.Store(0)
runFn := func(args mock.Arguments) {
ctx := args[0].(context.Context)

select {
case <-ctx.Done():
// should not be cancelled by the tracker
require.NoErrorf(t, ctx.Err(), "tracker should not cancel ctx: %v", context.Cause(ctx))
default:
cnt.Add(1)
}
}

instanceRing := newReadRingMock(ingesters, 0)
ingesterClient := newQuerierClientMock()
ingesterClient.On(testData.method, mock.Anything, mock.Anything, mock.Anything).Return(testData.retVal, nil).Run(runFn)
ingesterClient.On("GetChunkIDs", mock.Anything, mock.Anything, mock.Anything).Return(new(logproto.GetChunkIDsResponse), nil).Run(runFn)

partitions := 3
ingestersPerPartition := len(ingesters) / partitions
assert.Greaterf(t, ingestersPerPartition, 1, "must have more than one ingester per partition")

mockClientFactory := mockIngesterClientFactory{
requestedClients: make(map[string]int),
}

ingesterQuerier, err := newTestPartitionIngesterQuerier(mockClientFactory.newIngesterClientMockFactory(ingesterClient), instanceRing, newPartitionInstanceRingMock(instanceRing, ingesters, partitions, ingestersPerPartition), testData.shards)
require.NoError(t, err)

ingesterQuerier.querierConfig.QueryPartitionIngesters = true

err = testData.testFn(ctx, ingesterQuerier)
require.NoError(t, err)

if testData.shards == 0 {
testData.shards = partitions
}
expectedCalls := min(testData.shards, partitions)
expectedIngesterCalls := expectedCalls
// Wait for responses: We expect one request per queried partition because we have request minimization enabled & ingesters are in multiple zones.
// If shuffle sharding is enabled, we expect one query per shard as we write to a subset of partitions.
require.Eventually(t, func() bool { return cnt.Load() >= int32(expectedCalls) }, time.Millisecond*100, time.Millisecond*1, "expected ingesters to respond")
ingesterClient.AssertNumberOfCalls(t, testData.method, expectedCalls)

partitionCtx := ExtractPartitionContext(ctx)
require.Equal(t, expectedIngesterCalls, len(partitionCtx.ingestersUsed))
require.Equal(t, expectedIngesterCalls, len(mockClientFactory.requestedClients))

for _, ingester := range partitionCtx.ingestersUsed {
count, ok := mockClientFactory.requestedClients[ingester.addr]
require.True(t, ok)
require.Equal(t, count, 1)
}

// Now call getChunkIDs to ensure we only call the same ingesters as before.
_, err = ingesterQuerier.GetChunkIDs(ctx, model.Time(0), model.Time(1))
require.NoError(t, err)

require.Eventually(t, func() bool { return cnt.Load() >= int32(expectedCalls) }, time.Millisecond*100, time.Millisecond*1, "expected ingesters to respond")
ingesterClient.AssertNumberOfCalls(t, "GetChunkIDs", expectedCalls)

// Finally, confirm we called the same ingesters again and didn't ask for any new clients
require.Equal(t, expectedIngesterCalls, len(mockClientFactory.requestedClients))
for _, ingester := range partitionCtx.ingestersUsed {
count, ok := mockClientFactory.requestedClients[ingester.addr]
require.True(t, ok)
require.Equal(t, count, 1)
}
})
}
}

func TestQuerier_tailDisconnectedIngesters(t *testing.T) {
t.Parallel()

@@ -540,14 +667,14 @@ func newTestIngesterQuerier(readRingMock *readRingMock, ingesterClient *querierC
)
}

func newTestPartitionIngesterQuerier(ingesterClient *querierClientMock, instanceRing *readRingMock, partitionRing *ring.PartitionInstanceRing, tenantShards int) (*IngesterQuerier, error) {
func newTestPartitionIngesterQuerier(clientFactory client.PoolFactory, instanceRing *readRingMock, partitionRing *ring.PartitionInstanceRing, tenantShards int) (*IngesterQuerier, error) {
return newIngesterQuerier(
mockQuerierConfig(),
mockIngesterClientConfig(),
instanceRing,
partitionRing,
func(string) int { return tenantShards },
newIngesterClientMockFactory(ingesterClient),
clientFactory,
constants.Loki,
log.NewNopLogger(),
)
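
The new test depends on a mockIngesterClientFactory with a requestedClients map, which is not shown in this commit's hunks. The sketch below is an assumed illustration of that counting idea, not Loki's actual test helper: each time the querier asks for a client for an ingester address, the factory increments a counter, so the assertions can check that only the partition-context ingesters were requested, exactly once each.

// Hedged sketch, assumed rather than taken from this diff: a client factory
// that counts client requests per ingester address.
package main

import "fmt"

// stubClient stands in for the mocked logproto.QuerierClient.
type stubClient struct{ addr string }

type countingClientFactory struct {
	requestedClients map[string]int // ingester address -> times a client was built
}

// clientFor records the request and hands back a stub client. The real helper
// would be adapted to whatever pool-factory signature newIngesterQuerier expects.
func (f *countingClientFactory) clientFor(addr string) (*stubClient, error) {
	f.requestedClients[addr]++
	return &stubClient{addr: addr}, nil
}

func main() {
	factory := &countingClientFactory{requestedClients: map[string]int{}}

	// One ingester per queried partition, so each address should end up with a
	// count of exactly 1, matching the test's per-ingester assertion.
	for _, addr := range []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} {
		if _, err := factory.clientFor(addr); err != nil {
			panic(err)
		}
	}
	fmt.Println(factory.requestedClients) // map[1.1.1.1:1 2.2.2.2:1 3.3.3.3:1]
}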