-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvllm_allocator_adaptor_c.cpp
304 lines (242 loc) · 10.3 KB
/
vllm_allocator_adaptor_c.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
// file: vllm_allocator_adaptor_c.cpp
//
// An adaptor that lets Python functions back PyTorch's pluggable allocator.
// Important: allocation sizes, CUdeviceptr values and CUmemGenericAllocationHandle* pointers
// must be passed to and from Python as unsigned long long.
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
#include <cuda.h>
// Evaluates a CUDA driver API call and, if the result is non-zero, writes the
// driver's error text plus the call site (__FILE__:__LINE__) to std::cerr.
// The macro only reports the failure: it neither aborts nor returns, so
// execution continues with the statement after the macro.
// The message pointer starts as nullptr and is replaced by a fixed fallback
// text when the driver does not supply a string, so a null char* is never
// streamed into std::cerr.
#define CUDA_CHECK(condition)                                              \
  do {                                                                     \
    CUresult error = (condition);                                          \
    if (error != 0) {                                                      \
      const char* error_string = nullptr;                                  \
      cuGetErrorString(error, &error_string);                              \
      std::cerr << "[vllm_allocator_adaptor_c] CUDA Error: "               \
                << (error_string ? error_string : "unknown error")         \
                << " at " << __FILE__ << ":" << __LINE__ << std::endl;     \
    }                                                                      \
  } while (0)
// Global references to the Python callables stored by init_module.
// NOTE: these are borrowed references: this module never INCREFs or DECREFs
// them, so the caller must keep both callables alive while they are in use.
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;
extern "C" {
// Makes sure the calling thread has a current CUDA context.
// If cuCtxGetCurrent reports none, the primary context of `device` is
// retained and made current.
void ensure_context(unsigned long long device)
{
    // Start from nullptr: CUDA_CHECK only logs a failure, so without this a
    // failed cuCtxGetCurrent would leave an indeterminate pointer to be tested.
    CUcontext pctx = nullptr;
    CUDA_CHECK(cuCtxGetCurrent(&pctx));
    if (!pctx) {
        // Ensure device context.
        CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
        CUDA_CHECK(cuCtxSetCurrent(pctx));
    }
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, CUmemGenericAllocationHandle* p_memHandle)
{
ensure_context(device);
// Define memory allocation properties
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
// Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = device;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
// std::cout << "[vllm_allocator_adaptor_c] create_and_map: device=" << device << ", size=" << size << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
}
void unmap_and_release(unsigned long long device, ssize_t size, CUdeviceptr d_mem, CUmemGenericAllocationHandle* p_memHandle)
{
// std::cout << "[vllm_allocator_adaptor_c] unmap_and_release: device=" << device << ", size=" << size << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context(device);
CUDA_CHECK(cuMemUnmap(d_mem, size));
CUDA_CHECK(cuMemRelease(*p_memHandle));
}
// Builds a new Python tuple (a, b, c, d) of ints from four unsigned 64-bit values.
// Returns a new reference, or NULL with a Python exception set if the tuple or
// any of its items cannot be created. The caller must hold the GIL.
PyObject* create_tuple_from_c_integers(unsigned long long a, unsigned long long b, unsigned long long c, unsigned long long d) {
    const unsigned long long values[4] = {a, b, c, d};

    PyObject* tuple = PyTuple_New(4);
    if (!tuple) {
        return NULL;
    }

    for (Py_ssize_t i = 0; i < 4; ++i) {
        // Every slot uses the unsigned long long conversion: PyLong_FromLong
        // would truncate or sign-flip values that do not fit in a C long.
        PyObject* item = PyLong_FromUnsignedLongLong(values[i]);
        if (!item) {
            // Dropping the tuple also drops the items already stored in it.
            Py_DECREF(tuple);
            return NULL;
        }
        // PyTuple_SetItem steals the reference to item, so no DECREF is needed.
        PyTuple_SetItem(tuple, i, item);
    }
    return tuple;
}
void* my_malloc(ssize_t size, int device, cudaStream_t stream)
{
ensure_context(device);
// first allocation, align the size, and reserve an address, and also allocate a CUmemGenericAllocationHandle
// Define memory allocation properties
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
// Check if the allocation is supported
size_t granularity;
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
CUdeviceptr d_mem;
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)malloc(sizeof(CUmemGenericAllocationHandle));
if (!g_python_malloc_callback) {
std::cerr << "[vllm_allocator_adaptor_c] ERROR: g_python_malloc_callback not set.\n";
return nullptr;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* arg_tuple = create_tuple_from_c_integers(device, alignedSize, (unsigned long long)d_mem, (unsigned long long)p_memHandle);
// Call g_python_malloc_callback
PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
Py_DECREF(arg_tuple);
if (!py_result) {
PyErr_Print();
PyGILState_Release(gstate);
return nullptr;
}
PyGILState_Release(gstate);
// do the final mapping
create_and_map(device, alignedSize, d_mem, p_memHandle);
return (void*)d_mem;
}
void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream)
{
// get memory handle from the pointer
if (!g_python_free_callback) {
std::cerr << "[vllm_allocator_adaptor_c] ERROR: g_python_free_callback not set.\n";
return;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* py_ptr = PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return;
}
PyGILState_Release(gstate);
// recv_size == size
// recv_device == device
// Free memory
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, size));
free(p_memHandle);
}
} // extern "C"
// ---------------------------------------------------------------------------
// Python extension boilerplate:
// Python-exposed function: init_module(python_malloc, python_free)
// Stores the two Python callables that my_malloc and my_free call.
// Expects exactly two positional arguments, both callable; otherwise returns
// nullptr with a Python exception set. Returns None on success.
static PyObject* py_init_module(PyObject* self, PyObject* args)
{
    PyObject* malloc_fn = nullptr;
    PyObject* free_fn = nullptr;
    if (!PyArg_ParseTuple(args, "OO", &malloc_fn, &free_fn)) {
        return nullptr;
    }

    const bool both_callable = PyCallable_Check(malloc_fn) && PyCallable_Check(free_fn);
    if (!both_callable) {
        PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
        return nullptr;
    }

    // The pointers are stored without taking a reference: this module does not
    // manage the lifetime of the callables, so the caller must keep them alive.
    g_python_free_callback = free_fn;
    g_python_malloc_callback = malloc_fn;
    Py_RETURN_NONE;
}
// Python wrapper around unmap_and_release.
// Takes four integer arguments (device, size, d_mem, p_memHandle), where the
// last two are an address and a handle pointer passed as integers.
// Returns None, or nullptr with a Python exception set on bad arguments.
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
    const bool is_four_tuple = args && PyTuple_Check(args) && PyTuple_Size(args) == 4;
    if (!is_four_tuple) {
        PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
        return nullptr;
    }

    // "K" converts each item to an unsigned long long.
    unsigned long long dev_id, n_bytes, addr_bits, handle_bits;
    if (!PyArg_ParseTuple(args, "KKKK", &dev_id, &n_bytes, &addr_bits, &handle_bits)) {
        // PyArg_ParseTuple has already set the exception.
        return nullptr;
    }

    unmap_and_release(dev_id,
                      n_bytes,
                      static_cast<CUdeviceptr>(addr_bits),
                      reinterpret_cast<CUmemGenericAllocationHandle*>(handle_bits));
    Py_RETURN_NONE;
}
// Python wrapper around create_and_map.
// Takes four integer arguments (device, size, d_mem, p_memHandle), where the
// last two are an address and a handle pointer passed as integers.
// Returns None, or nullptr with a Python exception set on bad arguments.
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
    const bool is_four_tuple = args && PyTuple_Check(args) && PyTuple_Size(args) == 4;
    if (!is_four_tuple) {
        PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
        return nullptr;
    }

    // "K" converts each item to an unsigned long long.
    unsigned long long dev_id, n_bytes, addr_bits, handle_bits;
    if (!PyArg_ParseTuple(args, "KKKK", &dev_id, &n_bytes, &addr_bits, &handle_bits)) {
        // PyArg_ParseTuple has already set the exception.
        return nullptr;
    }

    create_and_map(dev_id,
                   n_bytes,
                   static_cast<CUdeviceptr>(addr_bits),
                   reinterpret_cast<CUmemGenericAllocationHandle*>(handle_bits));
    Py_RETURN_NONE;
}
// Method table of the extension module: Python-visible name, C entry point,
// calling convention and docstring. The all-NULL entry terminates the table.
static PyMethodDef module_methods[] = {
    {"init_module", py_init_module, METH_VARARGS,
     "Initialize module with python_malloc and python_free callables."},
    {"python_create_and_map", python_create_and_map, METH_VARARGS,
     "Create and map memory on the device."},
    {"python_unmap_and_release", python_unmap_and_release, METH_VARARGS,
     "Unmap and release memory on the device."},
    {NULL, NULL, 0, NULL}  // sentinel
};
// Module definition passed to PyModule_Create in the init function.
static struct PyModuleDef vllm_allocator_adaptor_c_module = {
    PyModuleDef_HEAD_INIT,
    "vllm_allocator_adaptor_c",  // m_name
    "vLLM Allocator Adaptor",    // m_doc
    -1,                          // m_size
    module_methods               // m_methods
};
// Module initialization function looked up by the Python import machinery.
// Returns the new module object, or NULL if PyModule_Create fails.
PyMODINIT_FUNC
PyInit_vllm_allocator_adaptor_c(void)
{
    // PyModule_Create already returns NULL on failure, so pass its result through.
    return PyModule_Create(&vllm_allocator_adaptor_c_module);
}