Skip to content

Commit fd42422

Browse files
committed
Address PR review feedback: Revert global allocator, enhance planner
- Reverted global allocator 256KB bucketing to prevent memory bloat - Moved bucketing logic to AttentionMemoryPlanner (local scope) - Implemented thread safety (mutex) in AttentionMemoryPlanner - Switched to 'First Fit' reuse strategy to prevent metadata explosion - Fixed unit test logic for workspace prediction - Added unit tests for reuse and metadata stability
1 parent 557c9c5 commit fd42422

File tree

4 files changed

+67
-35
lines changed

4 files changed

+67
-35
lines changed

onnxruntime/core/providers/cuda/cuda_allocator.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,8 @@ void* CUDAAllocator::Alloc(size_t size) {
4343
CheckDevice(true);
4444
void* p = nullptr;
4545
if (size > 0) {
46-
// Heuristic H1: Bucket Allocations
47-
// Round up to 256 KB to reduce fragmentation
48-
constexpr size_t kBucketSize = 256 * 1024;
49-
size_t bucketed_size = ((size + kBucketSize - 1) / kBucketSize) * kBucketSize;
50-
5146
// BFCArena was updated recently to handle the exception and adjust the request size
52-
CUDA_CALL_THROW(cudaMalloc((void**)&p, bucketed_size));
47+
CUDA_CALL_THROW(cudaMalloc((void**)&p, size));
5348
}
5449
return p;
5550
}

onnxruntime/core/providers/cuda/transformers/attention_memory_planner.cc

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,31 @@ size_t AttentionMemoryPlanner::PredictWorkspaceSize(int64_t batch_size, int64_t
1515
return std::min(kLimit, static_cast<size_t>(predicted * 1.25));
1616
}
1717

18-
void* AttentionMemoryPlanner::Allocate(size_t size, const std::vector<int64_t>& shape) {
18+
void* AttentionMemoryPlanner::Allocate(size_t size) {
19+
std::lock_guard<std::mutex> lock(mutex_);
1920
size_t bucketed_size = BucketSize(size);
2021

21-
// Heuristic H3: Tensor Lifetime Reuse
22-
// Try to find a free block with exact shape match (preferred)
23-
for (auto& alloc : allocations_) {
24-
if (alloc.free && alloc.size >= bucketed_size) {
25-
if (alloc.shape == shape) {
26-
alloc.free = false;
27-
return alloc.ptr;
28-
}
29-
}
30-
}
22+
// Heuristic H3: Tensor Lifetime Reuse (First Fit among free blocks with size >= requested)
23+
// We look for a free block that is large enough.
24+
// Since we bucket, we are likely to find exact matches or slightly larger ones.
25+
// We take the first block that fits (First Fit), which usually ends the scan early.
26+
// Ideally we might want Best Fit, but First Fit is faster and usually sufficient with bucketing.
3127

32-
// Fallback: find any free block large enough
3328
for (auto& alloc : allocations_) {
3429
if (alloc.free && alloc.size >= bucketed_size) {
35-
alloc.free = false;
36-
alloc.shape = shape;
37-
return alloc.ptr;
30+
alloc.free = false;
31+
return alloc.ptr;
3832
}
3933
}
4034

4135
// Allocate new
4236
void* p = allocator_->Alloc(bucketed_size);
43-
allocations_.push_back({p, bucketed_size, false, shape});
37+
allocations_.push_back({p, bucketed_size, false});
4438
return p;
4539
}
4640

4741
void AttentionMemoryPlanner::Free(void* p) {
42+
std::lock_guard<std::mutex> lock(mutex_);
4843
for (auto& alloc : allocations_) {
4944
if (alloc.ptr == p) {
5045
alloc.free = true;

onnxruntime/core/providers/cuda/transformers/attention_memory_planner.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22
#include <vector>
33
#include <map>
4+
#include <mutex>
45
#include "core/common/common.h"
56
#include "core/framework/allocator.h"
67

@@ -13,25 +14,27 @@ class AttentionMemoryPlanner {
1314
AttentionMemoryPlanner(AllocatorPtr allocator, size_t stream_idx)
1415
: allocator_(allocator), stream_idx_(stream_idx) {}
1516

16-
void* Allocate(size_t size, const std::vector<int64_t>& shape);
17+
void* Allocate(size_t size);
1718
void Free(void* p);
1819

1920
static size_t PredictWorkspaceSize(int64_t batch_size, int64_t num_heads, int64_t seq_len, int64_t head_dim, size_t element_size);
2021

2122
private:
2223
struct Allocation {
2324
void* ptr;
24-
size_t size;
25+
size_t size; // Actual allocated size (bucketed)
2526
bool free;
26-
std::vector<int64_t> shape;
2727
};
2828

2929
AllocatorPtr allocator_;
3030
size_t stream_idx_;
3131
std::vector<Allocation> allocations_;
32+
std::mutex mutex_;
3233

3334
size_t BucketSize(size_t size) const {
3435
constexpr size_t kBucketSize = 256 * 1024; // 256 KB
36+
// Only bucket if size is large enough to matter, otherwise we waste too much on small tensors
37+
if (size < kBucketSize) return size;
3538
return ((size + kBucketSize - 1) / kBucketSize) * kBucketSize;
3639
}
3740
};

onnxruntime/test/providers/cuda/attention_mem_tests.cc

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,67 @@ class MockAllocator : public IAllocator {
1919
};
2020

2121
TEST(AttentionMemoryPlannerTest, PredictWorkspaceSize) {
22+
// predicted = 1 * 32 * 1024 * 128 * 4 = 16,777,216 bytes = 16 MB
23+
// result = min(kLimit = 512 MB, predicted * 1.25) = 16 MB * 1.25 = 20 MB = 20,971,520 bytes
2224
size_t size = AttentionMemoryPlanner::PredictWorkspaceSize(1, 32, 1024, 128, 4);
23-
// 1 * 32 * 1024 * 128 * 4 = 16,777,216 bytes = 16 MB
24-
EXPECT_EQ(size, 16777216);
25+
EXPECT_EQ(size, 20971520);
2526
}
2627

27-
TEST(AttentionMemoryPlannerTest, AllocationReuse) {
28+
TEST(AttentionMemoryPlannerTest, AllocationReuse_BestFit) {
2829
auto allocator = std::make_shared<MockAllocator>();
2930
AttentionMemoryPlanner planner(allocator, 0);
3031

31-
std::vector<int64_t> shape1 = {1, 32, 1024, 128};
32-
void* p1 = planner.Allocate(100, shape1);
32+
// Allocate 1MB (will be bucketed to 1MB if bucket size is 256KB)
33+
size_t size1 = 1024 * 1024;
34+
void* p1 = planner.Allocate(size1);
3335

3436
planner.Free(p1);
3537

36-
void* p2 = planner.Allocate(100, shape1);
37-
EXPECT_EQ(p1, p2); // Should reuse exact shape
38+
// Allocate slightly smaller size, should reuse p1
39+
size_t size2 = size1 - 1024;
40+
void* p2 = planner.Allocate(size2);
41+
EXPECT_EQ(p1, p2); // Should reuse the same pointer
3842

3943
planner.Free(p2);
44+
}
45+
46+
TEST(AttentionMemoryPlannerTest, MetadataStability_Autoregressive) {
47+
auto allocator = std::make_shared<MockAllocator>();
48+
AttentionMemoryPlanner planner(allocator, 0);
49+
50+
// Simulate autoregressive generation: seq_len increases, so buffer size increases
51+
// We want to ensure we don't keep allocating new blocks without reusing old ones if they fit.
52+
// Note: In a real scenario, we'd likely free the old smaller buffer and allocate a new larger one.
53+
// If we free the old one, it becomes available.
54+
55+
void* p_prev = nullptr;
56+
57+
// Step 1: Allocate 100KB
58+
void* p1 = planner.Allocate(100 * 1024);
59+
p_prev = p1;
60+
61+
// Step 2: Free p1, Allocate 110KB
62+
// Since 100KB < 256KB, BucketSize returns the requested size unchanged
63+
// (the "if (size < kBucketSize) return size;" branch applies),
64+
// so small allocations are exact, not rounded up to a bucket.
65+
66+
planner.Free(p1);
67+
68+
// If we allocate larger, we can't reuse the smaller block.
69+
void* p2 = planner.Allocate(110 * 1024);
70+
EXPECT_NE(p1, p2); // Can't reuse smaller block for larger request
71+
72+
planner.Free(p2);
73+
74+
// Step 3: Large allocations (bucketed)
75+
// Allocate 1MB
76+
void* p3 = planner.Allocate(1024 * 1024);
77+
planner.Free(p3);
4078

41-
std::vector<int64_t> shape2 = {1, 32, 1024, 64};
42-
void* p3 = planner.Allocate(100, shape2);
43-
EXPECT_EQ(p3, p2); // Should reuse size-compatible buffer
79+
// Allocate 0.9MB (should reuse 1MB bucket)
80+
void* p4 = planner.Allocate(900 * 1024);
81+
EXPECT_EQ(p3, p4);
82+
planner.Free(p4);
4483
}
4584

4685
}

0 commit comments

Comments
 (0)