Skip to content

Commit e219a71

Browse files
authored
update visualizer (#434)
1 parent fb8f3c1 commit e219a71

File tree

4 files changed

+191
-23
lines changed

4 files changed

+191
-23
lines changed

configs/sparsification/methods/DART/dart.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ sparse:
1717
special:
1818
method: DART
1919
pruning_loc: 5
20-
reduction_ratio: 0.778
20+
reduction_ratio: 0.7778
2121
pivot_image_token: 4
2222
pivot_text_token : 4
2323
save:

configs/sparsification/methods/VisPruner/vispruner.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ sparse:
1818
method: TokenReduction
1919
special:
2020
method: VisPruner
21-
prune_ratio: 0.778 # 0.667 0.778 0.889
21+
prune_ratio: 0.7778 # 0.6667 0.7778 0.8889
2222
important_ratio: 0.5
2323
save:
2424
save_trans: False

llmc/compression/token_reduction/visualizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras):
4848
save_path=''
4949
)
5050
visualize_grid_to_grid(
51-
visual_attention_maps[0, 4, :, :],
51+
visual_attention_maps[0, 31, :, :],
5252
300,
5353
image,
5454
grid_size=24,
@@ -72,6 +72,7 @@ def visualizer_hook(module, inps, layer_outs, pruning_paras):
7272
functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras),
7373
)
7474
if idx == (len(self.blocks) - 1):
75+
# self.model.language_model.layers[-1]
7576
blk.register_forward_hook(
7677
functools.partial(visualizer_hook, pruning_paras=self.pruning_paras),
7778
)

llmc/utils/visualizer.py

Lines changed: 187 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import os
22

3-
import cv2
43
import numpy as np
54
import torch
6-
import torchvision.transforms.functional as TF
5+
import torch.nn.functional as F
76
from loguru import logger
87
from PIL import Image, ImageDraw
9-
from torchvision.transforms import ToPILImage
108

119
try:
1210
import matplotlib.pyplot as plt
11+
import seaborn as sns
1312
except Exception:
1413
logger.warning(
1514
'Can not import matplotlib. '
@@ -30,22 +29,32 @@ def to_pil_image(
3029

3130
def save_image(image_tensor, mean, std, save_path):
    """Convert ``image_tensor`` to an image and write it to disk.

    ``save_path`` is treated as an output *file* when it ends with ``.png``,
    ``.jpg``, ``.jpeg`` or ``.pdf`` (case-insensitive). Any other value is
    treated as an output *directory*: it is created if needed and the image is
    stored inside it as ``NNNN_visprunerP.png``, where ``NNNN`` is the lowest
    zero-padded index whose file does not exist yet.

    Args:
        image_tensor: Image passed to ``to_pil_image`` for conversion.
        mean: Not used by this function; kept for interface compatibility.
        std: Not used by this function; kept for interface compatibility.
        save_path: Target file path or target directory (see above).
    """
    # NOTE(review): to_pil_image is presumed to return an object with a
    # PIL-style ``save`` method -- confirm against its definition.
    img = to_pil_image(image_tensor)

    if save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
        parent_dir = os.path.dirname(save_path)
        # A bare file name has an empty dirname and os.makedirs('') raises
        # FileNotFoundError, so only create the parent when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    else:
        os.makedirs(save_path, exist_ok=True)
        out_dir = save_path
        idx = 0
        # Format only the file name so that braces in the directory name are
        # never interpreted as format fields.
        save_path = os.path.join(out_dir, f'{idx:04d}_visprunerP.png')
        while os.path.exists(save_path):
            idx += 1
            save_path = os.path.join(out_dir, f'{idx:04d}_visprunerP.png')

    img.save(save_path)
3445

3546

3647
def visualize_kept_patches(
3748
image,
38-
keep_idx,
49+
keep_idx=None,
3950
mean=[0.48145466, 0.4578275, 0.40821073],
4051
std=[0.26862954, 0.26130258, 0.27577711],
4152
patch_size=14,
42-
darken_ratio=0.3,
53+
darken_ratio=0.8,
4354
save_path=None,
4455
):
45-
assert image.ndim == 3 and image.shape[0] == 3, \
46-
f'Expected image of shape [3, H, W], got {image.shape}'
47-
# save_image(image,mean,std,save_path)
48-
56+
# save_image(image, mean, std, save_path)
57+
# return
4958
_, H, W = image.shape # 3 336 336
5059
device = image.device
5160
num_patches_h = H // patch_size # 24
@@ -59,8 +68,11 @@ def visualize_kept_patches(
5968
mask = patch_mask.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)
6069
mask = mask.unsqueeze(0) # shape [1, H, W]
6170

62-
# Darken image
63-
masked_image = image * (mask + (~mask) * darken_ratio)
71+
# white
72+
prune_mask = ~mask
73+
white_tensor = torch.ones_like(image)
74+
masked_image = image * (1 - darken_ratio * prune_mask.float()) + \
75+
white_tensor * (darken_ratio * prune_mask.float())
6476

6577
save_image(masked_image, mean, std, save_path)
6678

@@ -85,14 +97,6 @@ def grid_show(to_shows, cols, save_path=None, dpi=100):
8597
plt.savefig(save_path, bbox_inches='tight', dpi=dpi)
8698
plt.close()
8799

88-
# def visualize_head(att_map):
89-
# ax = plt.gca()
90-
# # Plot the heatmap
91-
# im = ax.imshow(att_map)
92-
# # Create colorbar
93-
# cbar = ax.figure.colorbar(im, ax=ax)
94-
# plt.show()
95-
96100

97101
def visualize_heads(att_map, cols, save_path):
98102
to_shows = []
@@ -215,6 +219,169 @@ def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6,
215219
plt.close()
216220

217221

222+
def visualize_attention(attention, grid_size=24, save_path=None):
    """Save a heatmap of a 2-D attention map averaged over square tiles.

    The map is cropped so both sides are multiples of ``grid_size``, each
    ``grid_size x grid_size`` tile is replaced by its mean, and the strictly
    upper triangle of the pooled matrix is masked out of the plot.

    Args:
        attention: 2-D array. Objects exposing ``detach`` are detached, moved
            to CPU and converted to NumPy first.
        grid_size: Side length of the averaging tile.
        save_path: Destination passed to ``plt.savefig``.
    """
    if hasattr(attention, 'detach'):
        attention = attention.detach().cpu().numpy()

    # Number of whole tiles that fit along each axis; the remainder is cropped.
    n_rows, n_cols = (dim // grid_size for dim in attention.shape)
    cropped = attention[:n_rows * grid_size, :n_cols * grid_size]

    pooled = cropped.reshape(n_rows, grid_size, n_cols, grid_size).mean(axis=(1, 3))
    upper_triangle = np.triu(np.ones_like(pooled, dtype=bool), k=1)

    plt.figure(figsize=(10, 10))
    sns.heatmap(pooled, mask=upper_triangle, cmap='viridis', square=True, cbar=True)

    # One tick per pooled row, with every label left blank.
    positions = np.arange(0, n_rows, 1)
    blank_labels = [''] * len(positions)
    plt.xticks(ticks=positions, labels=blank_labels, rotation=90)
    plt.yticks(ticks=positions, labels=blank_labels)

    plt.title('Attention Map')
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
249+
250+
251+
def visualize_attention_v2(attention, grid_size=24, save_path=None):
    """Save a block-averaged attention heatmap labelled SYS / IMG / INS.

    The token axis is split with a fixed layout: 2 SYS blocks covering
    ``[0, 35)``, 24 IMG blocks of 24 tokens covering ``[35, 611)`` and 6 INS
    blocks covering ``[611, 1155)``. Each (row block, column block) pair is
    reduced to its mean, rows are stored in reversed block order, every row is
    scaled by its own maximum, and the strictly upper triangle is masked.

    Args:
        attention: 2-D attention map. Objects exposing ``detach`` are
            detached, moved to CPU and converted to NumPy first.
        grid_size: Not used; the block layout is fixed as described above.
        save_path: Destination passed to ``plt.savefig``.
    """
    if hasattr(attention, 'detach'):
        attention = attention.detach().cpu().numpy()

    # Partition the token axis into (start, end) ranges.
    # SYS: 2 blocks.
    sys_splits = [0, 17, 35]
    block_ranges = list(zip(sys_splits[:-1], sys_splits[1:]))

    # IMG: 24 blocks of size 24.
    block_ranges += [(35 + i * 24, 35 + (i + 1) * 24) for i in range(24)]

    # INS: 6 blocks; 611 + 6 * 91 = 1157, so the last edge is cropped to 1155.
    ins_splits = [611 + i * 91 for i in range(7)]
    ins_splits[-1] = 1155
    block_ranges += list(zip(ins_splits[:-1], ins_splits[1:]))

    # Average the attention inside every block pair.
    num_blocks = len(block_ranges)
    last_row = num_blocks - 1  # derived, instead of the hard-coded 31
    block_attention = np.zeros((num_blocks, num_blocks))
    for i, (i_start, i_end) in enumerate(block_ranges):
        for j, (j_start, j_end) in enumerate(block_ranges):
            block = attention[i_start:i_end, j_start:j_end]
            # Rows are written in reversed block order.
            block_attention[last_row - i, j] = block.mean()

    # NOTE(review): the rows are flipped above but the mask below is a plain
    # upper-triangular mask; confirm this combination shows the intended
    # region for causal attention.
    mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)
    plt.figure(figsize=(10, 10))
    block_attention = block_attention / block_attention.max(axis=1, keepdims=True)
    sns.heatmap(block_attention, mask=mask, cmap='viridis', square=True, cbar=True)

    section_labels = ['SYS', 'IMG', 'INS']
    # Cumulative block counts: 2 SYS, 24 IMG and 6 INS blocks.
    section_boundaries = [2, 26, 32]
    # Only the section boundaries are ticked; a per-block blank tick set would
    # be replaced by these calls anyway.
    plt.xticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.yticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.title('Attention Map')
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
303+
304+
305+
def visualize_cosin_token(token_embedding, save_path=None):
    """Save a PDF heatmap of pairwise cosine similarity between tokens.

    Tokens ``14`` up to ``14 + 196 * 4`` of the first batch entry are
    L2-normalised and compared with each other. The colour scale runs from the
    90th percentile of the positive off-diagonal similarities up to 1, so only
    the most similar pairs stand out.

    Args:
        token_embedding: Tensor indexed as ``[batch, token, feature]``.
        save_path: Destination passed to ``plt.savefig`` (written as PDF).
    """
    x = token_embedding[0, 14: 14 + 196 * 4, :]
    x_norm = F.normalize(x, p=2, dim=1)
    # detach(): Tensor.numpy() refuses tensors that require grad.
    # float(): NumPy has no bfloat16, so half-precision inputs are upcast.
    # The device-to-host copy is done once and reused for stats and plotting.
    similarity = (x_norm @ x_norm.T).detach().float().cpu().numpy()

    upper = np.triu(similarity, k=1)
    valid_sim = upper[upper > 0]
    # Lower colour bound = 90th percentile, i.e. the top 10% are highlighted.
    # With no positive pair at all, let seaborn pick the lower bound itself.
    vmin = np.percentile(valid_sim, 90) if valid_sim.size else None

    # rc_context restores the caller's rcParams afterwards (also on error)
    # instead of resetting every setting to the matplotlib defaults.
    with plt.rc_context({'font.size': 15}):
        plt.subplots(figsize=(10, 10))
        sns.heatmap(similarity, cmap='Reds', vmin=vmin, vmax=1)

        start = 0
        step = 196
        ticks = np.arange(start, 196 * 5, step)
        plt.xticks(ticks, ticks)
        plt.yticks(ticks, ticks)

        plt.title('')
        plt.xlabel('')
        plt.ylabel('')
        plt.tight_layout()
        plt.savefig(save_path, format='pdf')
        plt.close()
334+
335+
336+
def visualize_cosin_token_32p(token_embedding, save_path=None):
    """Save a 2x4 grid of cosine-similarity heatmaps, four frames per panel.

    Tokens ``14`` up to ``14 + 196 * 32`` of the first batch entry are
    L2-normalised and compared. Each of the 8 panels shows the similarity
    inside one group of ``196 * 4`` consecutive tokens. All panels share one
    colour scale, running from the 90th percentile of the positive
    off-diagonal similarities over all selected tokens up to 1.

    Args:
        token_embedding: Tensor indexed as ``[batch, token, feature]``.
        save_path: Destination passed to ``plt.savefig`` (saved at 300 dpi).
    """
    group_size = 4
    num_groups = 8
    step = 196
    tokens_per_group = step * group_size

    all_tokens = token_embedding[0, 14:14 + 196 * 32, :]
    x_norm = F.normalize(all_tokens, p=2, dim=1)
    # detach(): Tensor.numpy() refuses tensors that require grad.
    # float(): NumPy has no bfloat16, so half-precision inputs are upcast.
    similarity = (x_norm @ x_norm.T).detach().float().cpu().numpy()

    upper = np.triu(similarity, k=1)
    valid_sim = upper[upper > 0]
    # Lower colour bound = 90th percentile, i.e. the top 10% are highlighted.
    # With no positive pair at all, let seaborn pick the lower bound itself.
    vmin = np.percentile(valid_sim, 90) if valid_sim.size else None

    # Tick positions are identical for every panel; only the labels shift.
    ticks = np.arange(0, tokens_per_group, step)

    # rc_context restores the caller's rcParams afterwards (also on error)
    # instead of resetting every setting to the matplotlib defaults.
    with plt.rc_context({'font.size': 20}):
        _, axs = plt.subplots(2, 4, figsize=(22, 10))  # 2x4 layout
        axs = axs.flatten()

        for i in range(num_groups):
            lo = i * tokens_per_group
            hi = lo + tokens_per_group
            ax = axs[i]
            # Normalisation is per token, so the diagonal block of the full
            # matrix equals the similarity computed within this group alone.
            sns.heatmap(
                similarity[lo:hi, lo:hi], cmap='Reds',
                vmin=vmin, vmax=1, ax=ax, cbar=False
            )

            labels = np.arange(lo, hi, step)
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
            ax.set_xticklabels(labels, rotation=0)
            ax.set_yticklabels(labels)
            start_frame = i * group_size
            end_frame = (i + 1) * group_size - 1
            ax.set_xlabel(f'Frame {start_frame}-{end_frame}', fontsize=17, labelpad=10)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.close()
383+
384+
218385
def highlight_grid(image, grid_indexes, grid_size=14):
219386
if not isinstance(grid_size, tuple):
220387
grid_size = (grid_size, grid_size)

0 commit comments

Comments
 (0)