diff --git a/arrow/flight/flight_test.go b/arrow/flight/flight_test.go index 98d1734c..8d75aac2 100644 --- a/arrow/flight/flight_test.go +++ b/arrow/flight/flight_test.go @@ -21,6 +21,8 @@ import ( "errors" "fmt" "io" + "sync" + "sync/atomic" "testing" "github.com/apache/arrow-go/v18/arrow" @@ -484,3 +486,251 @@ type flightStreamWriter struct{} func (f *flightStreamWriter) Send(data *flight.FlightData) error { return nil } var _ flight.DataStreamWriter = (*flightStreamWriter)(nil) + +// callbackRecordReader wraps a record reader and invokes a callback on each Next() call. +// It tracks whether batches are properly released and the reader itself is released. +type callbackRecordReader struct { + mem memory.Allocator + schema *arrow.Schema + numBatches int + currentBatch atomic.Int32 + onNext func(batchIndex int) // callback invoked before returning from Next() + released atomic.Bool + batchesCreated atomic.Int32 + totalRetains atomic.Int32 + totalReleases atomic.Int32 + createdBatches []arrow.RecordBatch // track all created batches for cleanup + mu sync.Mutex +} + +func newCallbackRecordReader(mem memory.Allocator, schema *arrow.Schema, numBatches int, onNext func(int)) *callbackRecordReader { + return &callbackRecordReader{ + mem: mem, + schema: schema, + numBatches: numBatches, + onNext: onNext, + } +} + +func (r *callbackRecordReader) Schema() *arrow.Schema { + return r.schema +} + +func (r *callbackRecordReader) Next() bool { + current := r.currentBatch.Load() + if int(current) >= r.numBatches { + return false + } + r.currentBatch.Add(1) + + if r.onNext != nil { + r.onNext(int(current)) + } + + return true +} + +func (r *callbackRecordReader) RecordBatch() arrow.RecordBatch { + bldr := array.NewInt64Builder(r.mem) + defer bldr.Release() + + currentBatch := r.currentBatch.Load() + bldr.AppendValues([]int64{int64(currentBatch)}, nil) + arr := bldr.NewArray() + + rec := array.NewRecordBatch(r.schema, []arrow.Array{arr}, 1) + arr.Release() + + tracked := &trackedRecordBatch{ + RecordBatch: rec, + onRetain: func() { + r.totalRetains.Add(1) + }, + onRelease: func() { + r.totalReleases.Add(1) + }, + } + + r.mu.Lock() + r.createdBatches = append(r.createdBatches, tracked) + r.mu.Unlock() + + r.batchesCreated.Add(1) + return tracked +} + +func (r *callbackRecordReader) ReleaseAll() { + r.mu.Lock() + defer r.mu.Unlock() + for _, batch := range r.createdBatches { + batch.Release() + } + r.createdBatches = nil +} + +func (r *callbackRecordReader) Retain() {} + +func (r *callbackRecordReader) Release() { + r.released.Store(true) +} + +func (r *callbackRecordReader) Record() arrow.RecordBatch { + return r.RecordBatch() +} + +func (r *callbackRecordReader) Err() error { + return nil +} + +// trackedRecordBatch wraps a RecordBatch to track Retain/Release calls. +type trackedRecordBatch struct { + arrow.RecordBatch + onRetain func() + onRelease func() +} + +func (t *trackedRecordBatch) Retain() { + if t.onRetain != nil { + t.onRetain() + } + t.RecordBatch.Retain() +} + +func (t *trackedRecordBatch) Release() { + if t.onRelease != nil { + t.onRelease() + } + t.RecordBatch.Release() +} + +func TestStreamChunksFromReader_OK(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil) + + rdr := newCallbackRecordReader(mem, schema, 5, nil) + defer rdr.ReleaseAll() + + ch := make(chan flight.StreamChunk, 5) + + ctx := context.Background() + + go flight.StreamChunksFromReader(ctx, rdr, ch) + + var chunksReceived int + for chunk := range ch { + if chunk.Err != nil { + t.Errorf("unexpected error chunk: %v", chunk.Err) + continue + } + if chunk.Data != nil { + chunksReceived++ + chunk.Data.Release() + } + } + + require.Equal(t, 5, chunksReceived, "should receive all 5 batches") + require.True(t, rdr.released.Load(), "reader should be released") + +} + +// TestStreamChunksFromReader_HandlesCancellation verifies that context cancellation +// causes StreamChunksFromReader to exit cleanly and release the reader. +func TestStreamChunksFromReader_HandlesCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mem := memory.DefaultAllocator + schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil) + + rdr := newCallbackRecordReader(mem, schema, 10, nil) + defer rdr.ReleaseAll() + ch := make(chan flight.StreamChunk) // unbuffered channel + + go flight.StreamChunksFromReader(ctx, rdr, ch) + + chunksReceived := 0 + for chunk := range ch { + if chunk.Data != nil { + chunksReceived++ + chunk.Data.Release() + } + + // Cancel context after 2 batches (simulating server detecting client disconnect) + if chunksReceived == 2 { + cancel() + } + } + + // After canceling context, StreamChunksFromReader exits and closes the channel. + // The for-range loop above exits when the channel closes. + // By the time we reach here, the channel is closed, which means StreamChunksFromReader's + // defer stack has already executed, so the reader must be released. + + require.True(t, rdr.released.Load(), "reader must be released when context is canceled") + +} + +// TestStreamChunksFromReader_CancellationReleasesBatches verifies that batches are +// properly tracked and demonstrates memory leaks without cleanup, then proves cleanup fixes it. +func TestStreamChunksFromReader_CancellationReleasesBatches(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + + schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create reader that will produce 10 batches, but we'll cancel after 3 + reader := newCallbackRecordReader(mem, schema, 10, func(batchIndex int) { + if batchIndex == 2 { + cancel() + } + }) + + ch := make(chan flight.StreamChunk, 5) + + // Start streaming + go flight.StreamChunksFromReader(ctx, reader, ch) + + // Consume chunks until channel closes + var chunksReceived int + for chunk := range ch { + if chunk.Err != nil { + t.Errorf("unexpected error chunk: %v", chunk.Err) + continue + } + if chunk.Data != nil { + chunksReceived++ + chunk.Data.Release() + } + } + + // Verify the reader was released + require.True(t, reader.released.Load(), "reader should be released") + + // We should have received at most 3-4 chunks (depending on timing) + // The important part is we didn't receive all 10 + require.LessOrEqual(t, chunksReceived, 4, "should not receive all 10 chunks, got %d", chunksReceived) + require.Greater(t, chunksReceived, 0, "should receive at least 1 chunk") + + // Check that Retain and Release don't balance - proving there's a leak without manual cleanup + retains := reader.totalRetains.Load() + releases := reader.totalReleases.Load() + batchesCreated := reader.batchesCreated.Load() + + // Each batch starts with refcount=1, then StreamChunksFromReader calls Retain() (refcount=2) + // For sent batches: we call Release() (refcount=1), batch still has initial ref + // For unsent batches due to cancellation: they keep refcount=1 from creation + // So we expect: releases < retains + batchesCreated + require.Less(t, releases, retains+batchesCreated, + "without cleanup, releases should be less than retains+created: retains=%d, releases=%d, created=%d", + retains, releases, batchesCreated) + + // Now manually release all created batches to show proper cleanup fixes the leak + reader.ReleaseAll() + + // After cleanup, memory should be freed + mem.AssertSize(t, 0) +} diff --git a/arrow/flight/flightsql/example/sqlite_server.go b/arrow/flight/flightsql/example/sqlite_server.go index fc7d76a2..dca7b2d6 100644 --- a/arrow/flight/flightsql/example/sqlite_server.go +++ b/arrow/flight/flightsql/example/sqlite_server.go @@ -354,7 +354,7 @@ func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.G } schema := rdr.Schema() - go flight.StreamChunksFromReader(rdr, ch) + go flight.StreamChunksFromReader(ctx, rdr, ch) return schema, ch, nil } @@ -485,7 +485,7 @@ func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query } ch := make(chan flight.StreamChunk) - go flight.StreamChunksFromReader(rdr, ch) + go flight.StreamChunksFromReader(ctx, rdr, ch) return schema, ch, nil } diff --git a/arrow/flight/flightsql/server.go b/arrow/flight/flightsql/server.go index d5102a27..25c89bf5 100644 --- a/arrow/flight/flightsql/server.go +++ b/arrow/flight/flightsql/server.go @@ -381,7 +381,7 @@ func (b *BaseServer) GetFlightInfoSqlInfo(_ context.Context, _ GetSqlInfo, desc } // DoGetSqlInfo returns a flight stream containing the list of sqlinfo results -func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) { +func (b *BaseServer) DoGetSqlInfo(ctx context.Context, cmd GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) { if b.Alloc == nil { b.Alloc = memory.DefaultAllocator } @@ -430,7 +430,7 @@ func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) (*arrow.Sch } // StreamChunksFromReader will call release on the reader when done - go flight.StreamChunksFromReader(rdr, ch) + go flight.StreamChunksFromReader(ctx, rdr, ch) return schema_ref.SqlInfo, ch, nil } @@ -927,19 +927,24 @@ func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightServ wr := flight.NewRecordWriter(stream, ipc.WithSchema(sc)) defer wr.Close() - for chunk := range cc { - if chunk.Err != nil { - return chunk.Err - } - - wr.SetFlightDescriptor(chunk.Desc) - if err = wr.WriteWithAppMetadata(chunk.Data, chunk.AppMetadata); err != nil { - return err + for { + select { + case <-stream.Context().Done(): + return stream.Context().Err() + case chunk, ok := <-cc: + if !ok { + return nil + } + if chunk.Err != nil { + return chunk.Err + } + wr.SetFlightDescriptor(chunk.Desc) + if err := wr.WriteWithAppMetadata(chunk.Data, chunk.AppMetadata); err != nil { + return err + } + chunk.Data.Release() } - chunk.Data.Release() } - - return err } type putMetadataWriter struct { diff --git a/arrow/flight/record_batch_reader.go b/arrow/flight/record_batch_reader.go index 7b744075..e6990a57 100644 --- a/arrow/flight/record_batch_reader.go +++ b/arrow/flight/record_batch_reader.go @@ -18,6 +18,7 @@ package flight import ( "bytes" + "context" "errors" "fmt" "io" @@ -212,24 +213,38 @@ type haserr interface { // StreamChunksFromReader is a convenience function to populate a channel // from a record reader. It is intended to be run using a separate goroutine -// by calling `go flight.StreamChunksFromReader(rdr, ch)`. +// by calling `go flight.StreamChunksFromReader(ctx, rdr, ch)`. // // If the record reader panics, an error chunk will get sent on the channel. // // This will close the channel and release the reader when it completes. -func StreamChunksFromReader(rdr array.RecordReader, ch chan<- StreamChunk) { +func StreamChunksFromReader(ctx context.Context, rdr array.RecordReader, ch chan<- StreamChunk) { defer close(ch) defer func() { if err := recover(); err != nil { - ch <- StreamChunk{Err: utils.FormatRecoveredError("panic while reading", err)} + select { + case ch <- StreamChunk{Err: utils.FormatRecoveredError("panic while reading", err)}: + case <-ctx.Done(): + } } }() defer rdr.Release() for rdr.Next() { + select { + case <-ctx.Done(): + return + default: + } + rec := rdr.RecordBatch() rec.Retain() - ch <- StreamChunk{Data: rec} + select { + case ch <- StreamChunk{Data: rec}: + case <-ctx.Done(): + rec.Release() + return + } } if e, ok := rdr.(haserr); ok {