-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathcache_engine.rs
176 lines (153 loc) · 6.05 KB
/
cache_engine.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// based on https://github.com/vllm-project/vllm/blob/b9fe4616f98b77b4b9458bce203aa6544cb31ef2/vllm/worker/cache_engine.py
use super::super::{config::TchRllmConfig, kernels, tmodel::TModel};
use super::CacheIface;
use rllm::{config::RllmConfig, CacheSize, HashMap};
use std::sync::Arc;
use tch::{Device, Tensor};
#[cfg(not(feature = "cuda"))]
use super::cuda_stub::{CudaEvent, CudaStream};
#[cfg(feature = "cuda")]
use tch_cuda::{CudaEvent, CudaStream};
/// One layer's cache as a `(key, value)` pair of tensors.
type KVCache = (Tensor, Tensor);
/// Owns the per-layer key/value caches: one set on the model device and one
/// set on the CPU (used as the other side of `swap_in` / `swap_out`).
pub struct CacheEngine {
    /// Per-layer `(key, value)` caches allocated on `config.model.device`;
    /// shared via `Arc` with the awaiters returned by `get_cache_iface`.
    gpu_cache: Arc<Vec<KVCache>>,
    /// Per-layer `(key, value)` caches allocated on `Device::Cpu`
    /// (not pinned yet — see the TODO in `allocate_caches`).
    cpu_cache: Vec<KVCache>,
    /// Stream created in `new`; in CUDA builds swap copies are issued on it.
    cache_stream: CudaStream,
    /// One event per layer; the CUDA `swap` records event `i` on
    /// `cache_stream` after issuing layer `i`'s copies.
    events: Arc<Vec<CudaEvent>>,
    /// True when a swap happened since the last `new_round`; only then does
    /// `get_cache_iface` hand the events to the awaiter.
    used_events: bool,
}
/// `CacheIface` implementation produced by `CacheEngine::get_cache_iface`:
/// hands out per-layer GPU cache tensors, optionally waiting on swap events.
struct MyCacheAwaiter {
    /// Shared handle to the engine's per-layer GPU `(key, value)` caches.
    gpu_cache: Arc<Vec<KVCache>>,
    /// Per-layer events to wait on, or `None` when no swap was issued in the
    /// current round.
    events: Option<Arc<Vec<CudaEvent>>>,
    /// The device's current stream at the time the awaiter was created; this
    /// is the stream passed to `CudaEvent::wait`.
    stream: CudaStream,
}
impl CacheIface for MyCacheAwaiter {
    /// Returns shallow clones of layer `layer_no`'s key and value cache tensors.
    ///
    /// When swap events are attached, first calls `wait` on this layer's event
    /// with the awaiter's stream (presumably making that stream wait for the
    /// swap copies recorded on the cache stream — semantics of `CudaEvent::wait`
    /// not visible here).
    ///
    /// # Panics
    /// Panics if `layer_no` is out of range for the cache (or event) vector.
    fn get(&self, layer_no: usize) -> (Tensor, Tensor) {
        let layer = &self.gpu_cache[layer_no];
        if let Some(pending) = self.events.as_deref() {
            pending[layer_no].wait(&self.stream);
        }
        (layer.0.shallow_clone(), layer.1.shallow_clone())
    }
}
impl CacheEngine {
    /// Creates a cache engine and allocates its GPU and CPU key/value caches.
    ///
    /// `num_blocks.gpu` / `num_blocks.cpu` are the number of cache blocks
    /// allocated per layer on the model device and on the CPU respectively.
    /// One `CudaEvent` is created per layer so readers can wait on the swap of
    /// an individual layer.
    pub fn new(config: Arc<RllmConfig<TModel>>, num_blocks: &CacheSize) -> Self {
        let num_layers = config.get_num_layers_parallel();
        let (gpu_cache, cpu_cache) = Self::allocate_caches(&config, num_blocks);
        Self {
            gpu_cache: Arc::new(gpu_cache),
            cpu_cache,
            cache_stream: CudaStream::new(config.model.device),
            events: Arc::new((0..num_layers).map(|_| CudaEvent::new()).collect()),
            used_events: false,
        }
    }

    /// Returns a `CacheIface` sharing this engine's GPU caches.
    ///
    /// If a swap was issued since the last `new_round`, the returned object
    /// also carries the per-layer events and waits on them, using the device's
    /// current stream, before handing out a layer.
    ///
    /// # Panics
    /// Panics if the engine has no layers (the device is read from layer 0).
    pub fn get_cache_iface(&mut self) -> Box<dyn CacheIface> {
        let device = self.gpu_cache[0].0.device();
        Box::new(MyCacheAwaiter {
            events: self.used_events.then(|| Arc::clone(&self.events)),
            stream: CudaStream::current(device),
            gpu_cache: Arc::clone(&self.gpu_cache),
        })
    }

    /// Starts a new round: awaiters created afterwards skip event waits until
    /// another swap is issued.
    pub fn new_round(&mut self) {
        self.used_events = false;
    }

    /// Copies blocks from the CPU cache into the GPU cache
    /// (`src_to_dst` maps CPU block index -> GPU block index).
    ///
    /// An empty mapping is a no-op: no events are recorded and later awaiters
    /// are not forced to wait.
    pub fn swap_in(&mut self, src_to_dst: &HashMap<usize, usize>) {
        if src_to_dst.is_empty() {
            return;
        }
        self.swap(&self.cpu_cache, &self.gpu_cache, src_to_dst);
        self.used_events = true;
    }

    /// Copies blocks from the GPU cache out to the CPU cache
    /// (`src_to_dst` maps GPU block index -> CPU block index).
    ///
    /// An empty mapping is a no-op: no events are recorded and later awaiters
    /// are not forced to wait.
    pub fn swap_out(&mut self, src_to_dst: &HashMap<usize, usize>) {
        if src_to_dst.is_empty() {
            return;
        }
        self.swap(&self.gpu_cache, &self.cpu_cache, src_to_dst);
        self.used_events = true;
    }

    /// Allocates an uninitialized key-cache tensor of `num_bl` blocks.
    ///
    /// Shape is `[num_bl, num_heads, head_size / x, block_size, x]` with
    /// `x = 16 / element size in bytes`, so the innermost dimension spans
    /// 16 bytes (the layout of the vLLM cache engine this file is based on).
    fn alloc_key_block(config: &RllmConfig<TModel>, num_bl: i64, device: Device) -> Tensor {
        let head_size = config.get_head_size() as i64;
        let num_heads = config.get_num_heads_parallel() as i64;
        let block_size = config.model.cache.block_size as i64;
        let x = 16 / (config.model.dtype.elt_size_in_bytes() as i64);
        Tensor::empty(
            &[num_bl, num_heads, head_size / x, block_size, x],
            (config.model.dtype, device),
        )
    }

    /// Allocates an uninitialized value-cache tensor of `num_bl` blocks with
    /// shape `[num_bl, num_heads, head_size, block_size]`.
    fn alloc_value_block(config: &RllmConfig<TModel>, num_bl: i64, device: Device) -> Tensor {
        let head_size = config.get_head_size() as i64;
        let num_heads = config.get_num_heads_parallel() as i64;
        let block_size = config.model.cache.block_size as i64;
        Tensor::empty(
            &[num_bl, num_heads, head_size, block_size],
            (config.model.dtype, device),
        )
    }

    /// Allocates one layer's `(key, value)` cache of `num_bl` blocks on `device`.
    fn alloc_cache_layer(config: &RllmConfig<TModel>, num_bl: i64, device: Device) -> KVCache {
        (
            Self::alloc_key_block(config, num_bl, device),
            Self::alloc_value_block(config, num_bl, device),
        )
    }

    /// Allocates one layer's `(key, value)` cache of `num_bl` blocks on the
    /// model device (`config.model.device`).
    pub fn alloc_gpu_cache_layer(config: &RllmConfig<TModel>, num_bl: i64) -> (Tensor, Tensor) {
        Self::alloc_cache_layer(config, num_bl, config.model.device)
    }

    /// Allocates the per-layer GPU and CPU caches, returned as
    /// `(gpu_cache, cpu_cache)`.
    fn allocate_caches(
        config: &RllmConfig<TModel>,
        num_blocks: &CacheSize,
    ) -> (Vec<KVCache>, Vec<KVCache>) {
        let num_layers = config.get_num_layers_parallel();
        let gpu_blocks = num_blocks.gpu as i64;
        let cpu_blocks = num_blocks.cpu as i64;
        let gpu_cache = (0..num_layers)
            .map(|_| Self::alloc_gpu_cache_layer(config, gpu_blocks))
            .collect();
        // TODO: vllm sets pin_memory=True here
        let cpu_cache = (0..num_layers)
            .map(|_| Self::alloc_cache_layer(config, cpu_blocks, Device::Cpu))
            .collect();
        (gpu_cache, cpu_cache)
    }

    /// Swapping is only available in CUDA builds.
    ///
    /// # Panics
    /// Always panics.
    #[cfg(not(feature = "cuda"))]
    fn swap(&self, _src: &[KVCache], _dst: &[KVCache], _src_to_dst: &HashMap<usize, usize>) {
        // Touch the field so non-CUDA builds don't warn about it being unused.
        let _ = self.cache_stream;
        panic!("swap not implemented for CPU");
    }

    /// Issues block copies `src -> dst` for every layer on `cache_stream`, and
    /// records that layer's event right after its key and value copies.
    ///
    /// # Panics
    /// Panics if `src` and `dst` have a different number of layers.
    #[cfg(feature = "cuda")]
    fn swap(&self, src: &[KVCache], dst: &[KVCache], src_to_dst: &HashMap<usize, usize>) {
        assert_eq!(
            src.len(),
            dst.len(),
            "source and destination caches must have the same number of layers"
        );
        let stream = &self.cache_stream;
        for (layer, ((src_k, src_v), (dst_k, dst_v))) in src.iter().zip(dst).enumerate() {
            kernels::swap_blocks(src_k, dst_k, src_to_dst, stream);
            kernels::swap_blocks(src_v, dst_v, src_to_dst, stream);
            self.events[layer].record(stream);
        }
    }

    /// Copies GPU cache blocks within every layer
    /// (`src_to_dsts` maps a source block to all of its destination blocks).
    ///
    /// An empty mapping is a no-op and does not call into the kernel.
    pub fn copy(&mut self, src_to_dsts: &HashMap<usize, Vec<usize>>) {
        if src_to_dsts.is_empty() {
            return;
        }
        let (mut key_caches, mut value_caches): (Vec<_>, Vec<_>) = self
            .gpu_cache
            .iter()
            .map(|(key, value)| (key.shallow_clone(), value.shallow_clone()))
            .unzip();
        kernels::copy_blocks(&mut key_caches, &mut value_caches, src_to_dsts);
    }

    /// Size in bytes of one cache block summed over all layers, counting both
    /// the key and the value cache.
    ///
    /// Note: the key tensor's `head_size / x * x` equals `head_size` only when
    /// `head_size` is a multiple of `x` (see `alloc_key_block`); otherwise this
    /// slightly overestimates the key allocation.
    pub fn get_cache_block_size(config: &RllmConfig<TModel>) -> usize {
        let block_size = config.model.cache.block_size;
        let head_size = config.get_head_size();
        let num_heads = config.get_num_heads_parallel();
        let num_layers = config.get_num_layers_parallel();
        let key_cache_block = block_size * num_heads * head_size;
        let value_cache_block = key_cache_block;
        let total = num_layers * (key_cache_block + value_cache_block);
        config.model.dtype.elt_size_in_bytes() * total
    }
}