11import os
22
3- import cv2
43import numpy as np
54import torch
6- import torchvision . transforms .functional as TF
5+ import torch . nn .functional as F
76from loguru import logger
87from PIL import Image , ImageDraw
9- from torchvision .transforms import ToPILImage
108
119try :
1210 import matplotlib .pyplot as plt
11+ import seaborn as sns
1312except Exception :
1413 logger .warning (
1514 'Can not import matplotlib. '
@@ -30,22 +29,32 @@ def to_pil_image(
3029
def save_image(image_tensor, mean, std, save_path):
    """Convert ``image_tensor`` with ``to_pil_image`` and write it to disk.

    Args:
        image_tensor: Image handed unchanged to ``to_pil_image``.
        mean: Not used by this function; kept so existing callers still work.
        std: Not used by this function; kept so existing callers still work.
        save_path: Either a file path ending in ``.png``, ``.jpg``, ``.jpeg``
            or ``.pdf`` (case-insensitive), or a directory. For a directory
            the image is written as ``NNNN_visprunerP.png`` using the first
            zero-padded index that does not exist yet.
    """
    img = to_pil_image(image_tensor)

    if save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
        # A bare filename has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create the parent when there is one.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    else:
        # Treat the path as an output directory and pick the next free name.
        out_dir = save_path
        os.makedirs(out_dir, exist_ok=True)
        idx = 0
        # Build each candidate with an f-string rather than str.format on the
        # joined path, so braces in the directory name cannot break it.
        save_path = os.path.join(out_dir, f'{idx:04d}_visprunerP.png')
        while os.path.exists(save_path):
            idx += 1
            save_path = os.path.join(out_dir, f'{idx:04d}_visprunerP.png')

    img.save(save_path)
3445
3546
3647def visualize_kept_patches (
3748 image ,
38- keep_idx ,
49+ keep_idx = None ,
3950 mean = [0.48145466 , 0.4578275 , 0.40821073 ],
4051 std = [0.26862954 , 0.26130258 , 0.27577711 ],
4152 patch_size = 14 ,
42- darken_ratio = 0.3 ,
53+ darken_ratio = 0.8 ,
4354 save_path = None ,
4455):
45- assert image .ndim == 3 and image .shape [0 ] == 3 , \
46- f'Expected image of shape [3, H, W], got { image .shape } '
47- # save_image(image,mean,std,save_path)
48-
56+ # save_image(image, mean, std, save_path)
57+ # return
4958 _ , H , W = image .shape # 3 336 336
5059 device = image .device
5160 num_patches_h = H // patch_size # 24
@@ -59,8 +68,11 @@ def visualize_kept_patches(
5968 mask = patch_mask .repeat_interleave (patch_size , dim = 0 ).repeat_interleave (patch_size , dim = 1 )
6069 mask = mask .unsqueeze (0 ) # shape [1, H, W]
6170
62- # Darken image
63- masked_image = image * (mask + (~ mask ) * darken_ratio )
71+ # white
72+ prune_mask = ~ mask
73+ white_tensor = torch .ones_like (image )
74+ masked_image = image * (1 - darken_ratio * prune_mask .float ()) + \
75+ white_tensor * (darken_ratio * prune_mask .float ())
6476
6577 save_image (masked_image , mean , std , save_path )
6678
@@ -85,14 +97,6 @@ def grid_show(to_shows, cols, save_path=None, dpi=100):
8597 plt .savefig (save_path , bbox_inches = 'tight' , dpi = dpi )
8698 plt .close ()
8799
88- # def visualize_head(att_map):
89- # ax = plt.gca()
90- # # Plot the heatmap
91- # im = ax.imshow(att_map)
92- # # Create colorbar
93- # cbar = ax.figure.colorbar(im, ax=ax)
94- # plt.show()
95-
96100
97101def visualize_heads (att_map , cols , save_path ):
98102 to_shows = []
@@ -215,6 +219,169 @@ def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6,
215219 plt .close ()
216220
217221
def visualize_attention(attention, grid_size=24, save_path=None):
    """Plot block-averaged attention as a lower-triangular heatmap.

    The 2-D ``attention`` map is cropped to a multiple of ``grid_size`` in
    both dimensions and averaged over non-overlapping
    ``grid_size`` x ``grid_size`` tiles. The result is drawn with seaborn,
    with cells above the main diagonal masked out.

    Args:
        attention: 2-D array-like, or an object exposing ``detach`` (handled
            as a tensor and converted to a NumPy array).
        grid_size: Side length of each averaging tile.
        save_path: Destination handed to ``plt.savefig``.

    Raises:
        ValueError: If ``attention`` is not 2-D or is smaller than one tile.
    """
    if hasattr(attention, 'detach'):
        # NumPy has no bfloat16 dtype, so ``.numpy()`` fails on bf16 tensors;
        # casting to float32 first also keeps the tile means from being
        # accumulated in half precision.
        attention = attention.detach().float().cpu().numpy()
    attention = np.asarray(attention)

    if attention.ndim != 2:
        raise ValueError(
            f'Expected a 2-D attention map, got shape {attention.shape}'
        )

    rows, cols = attention.shape
    tiles_h = rows // grid_size
    tiles_w = cols // grid_size
    if tiles_h == 0 or tiles_w == 0:
        raise ValueError(
            f'Attention map of shape {attention.shape} is smaller than one '
            f'{grid_size}x{grid_size} tile'
        )

    # Drop the ragged remainder so the map splits evenly into tiles.
    cropped = attention[:tiles_h * grid_size, :tiles_w * grid_size]
    tiles = cropped.reshape(tiles_h, grid_size, tiles_w, grid_size)
    block_means = tiles.mean(axis=(1, 3))

    # Hide everything strictly above the main diagonal.
    mask = np.triu(np.ones_like(block_means, dtype=bool), k=1)

    plt.figure(figsize=(10, 10))
    try:
        sns.heatmap(
            block_means, mask=mask, cmap='viridis', square=True, cbar=True
        )

        # One unlabeled tick per tile row.
        ticks = np.arange(0, block_means.shape[0], 1)
        labels = [''] * len(ticks)
        plt.xticks(ticks=ticks, labels=labels, rotation=90)
        plt.yticks(ticks=ticks, labels=labels)

        plt.title('Attention Map')
        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close()
249+
250+
def visualize_attention_v2(attention, grid_size=24, save_path=None):
    """Plot attention averaged over a fixed SYS / IMG / INS block layout.

    The sequence is partitioned into 32 hard-coded token ranges: 2 SYS
    blocks covering ``[0, 35)``, 24 IMG blocks of 24 tokens covering
    ``[35, 611)`` and 6 INS blocks covering ``[611, 1155)``. Each pair of
    ranges is averaged into one cell, rows are stored in reverse block order,
    every row is divided by its own maximum, and the result is drawn as a
    seaborn heatmap with the cells above the main diagonal masked.

    Args:
        attention: 2-D array indexed as ``[query, key]``, or an object
            exposing ``detach`` (handled as a tensor and converted to NumPy).
        grid_size: Not used by this function; the block layout is fixed.
        save_path: Destination handed to ``plt.savefig``.
    """
    if hasattr(attention, 'detach'):
        # NumPy has no bfloat16 dtype, so ``.numpy()`` fails on bf16 tensors;
        # cast to float32 before converting.
        attention = attention.detach().float().cpu().numpy()

    # Partition the sequence into blocks.
    # SYS: 2 blocks.
    sys_splits = [0, 17, 35]
    block_ranges = list(zip(sys_splits[:-1], sys_splits[1:]))

    # IMG: 24 blocks of size 24.
    block_ranges.extend(
        (35 + i * 24, 35 + (i + 1) * 24) for i in range(24)
    )

    # INS: 6 blocks; 611 + 6 * 91 = 1157, so the last edge is cropped to 1155.
    ins_splits = [611 + i * 91 for i in range(7)]
    ins_splits[-1] = 1155
    block_ranges.extend(zip(ins_splits[:-1], ins_splits[1:]))

    # Average the attention over every pair of blocks.
    num_blocks = len(block_ranges)
    last_row = num_blocks - 1
    block_attention = np.zeros((num_blocks, num_blocks))
    for i, (i_start, i_end) in enumerate(block_ranges):
        for j, (j_start, j_end) in enumerate(block_ranges):
            block = attention[i_start:i_end, j_start:j_end]
            # Rows are filled in reverse block order (block 0 -> last row).
            block_attention[last_row - i, j] = block.mean()

    mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)
    # Scale every row by its own maximum.
    block_attention = block_attention / block_attention.max(
        axis=1, keepdims=True
    )

    section_labels = ['SYS', 'IMG', 'INS']
    # block_ranges holds 2 SYS, 24 IMG and 6 INS blocks respectively.
    section_boundaries = [2, 26, 32]

    plt.figure(figsize=(10, 10))
    try:
        sns.heatmap(
            block_attention, mask=mask, cmap='viridis', square=True, cbar=True
        )
        # Only the section boundaries carry labelled ticks.
        plt.xticks(
            ticks=section_boundaries, labels=section_labels, fontsize=12
        )
        plt.yticks(
            ticks=section_boundaries, labels=section_labels, fontsize=12
        )
        plt.title('Attention Map')
        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close()
303+
304+
def visualize_cosin_token(token_embedding, save_path=None):
    """Save a cosine-similarity heatmap for a fixed slice of token embeddings.

    Tokens ``14 : 14 + 196 * 4`` of the first item in ``token_embedding``
    are L2-normalised along the feature axis and multiplied with their own
    transpose. The heatmap's lower colour bound is the 90th percentile of the
    positive similarities strictly above the diagonal; the upper bound is 1.
    The figure is written as a PDF.

    Args:
        token_embedding: Tensor indexed as ``[batch, token, feature]``.
        save_path: Destination handed to ``plt.savefig`` (PDF format).
    """
    # Detach and cast to float32 before leaving torch: ``.numpy()`` rejects
    # tensors that require grad and bfloat16 tensors.
    tokens = token_embedding[0, 14:14 + 196 * 4, :].detach().float()
    tokens = F.normalize(tokens, p=2, dim=1)
    sim_np = (tokens @ tokens.T).cpu().numpy()

    # Colour floor: 90th percentile of the positive off-diagonal similarities
    # (upper triangle only, so each pair is counted once).
    upper = np.triu(sim_np, k=1)
    vmin = np.percentile(upper[upper > 0], 90)

    step = 196
    ticks = np.arange(0, 196 * 5, step)

    # rc_context restores the caller's rcParams on exit (including on error),
    # instead of resetting every setting to matplotlib's defaults.
    with plt.rc_context({'font.size': 15}):
        plt.subplots(figsize=(10, 10))
        try:
            sns.heatmap(sim_np, cmap='Reds', vmin=vmin, vmax=1)

            plt.xticks(ticks, ticks)
            plt.yticks(ticks, ticks)

            plt.title('')
            plt.xlabel('')
            plt.ylabel('')
            plt.tight_layout()
            plt.savefig(save_path, format='pdf')
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close()
334+
335+
def visualize_cosin_token_32p(token_embedding, save_path=None):
    """Save a 2x4 grid of cosine-similarity heatmaps over 8 token groups.

    Tokens ``14 : 14 + 196 * 32`` of the first item in ``token_embedding``
    are L2-normalised along the feature axis. Each of the 8 groups spans
    ``196 * 4`` consecutive tokens and gets one panel showing the similarity
    among its own tokens. All panels share one colour scale: the lower bound
    is the 90th percentile of the positive similarities strictly above the
    diagonal of the full matrix, and the upper bound is 1.

    Args:
        token_embedding: Tensor indexed as ``[batch, token, feature]``.
        save_path: Destination handed to ``plt.savefig`` (saved at 300 dpi).
    """
    group_size = 4
    num_groups = 8
    step = 196
    tokens_per_group = step * group_size

    # Detach and cast to float32 before leaving torch: ``.numpy()`` rejects
    # tensors that require grad and bfloat16 tensors.
    all_tokens = token_embedding[0, 14:14 + 196 * 32, :].detach().float()
    all_tokens = F.normalize(all_tokens, p=2, dim=1)
    full_sim = (all_tokens @ all_tokens.T).cpu().numpy()

    # Shared colour floor: 90th percentile of the positive off-diagonal
    # similarities (upper triangle only, so each pair is counted once).
    upper = np.triu(full_sim, k=1)
    vmin = np.percentile(upper[upper > 0], 90)

    # rc_context restores the caller's rcParams on exit (including on error),
    # instead of resetting every setting to matplotlib's defaults.
    with plt.rc_context({'font.size': 20}):
        _, axs = plt.subplots(2, 4, figsize=(22, 10))  # 2x4 layout
        try:
            for i, ax in zip(range(num_groups), axs.flatten()):
                lo = i * tokens_per_group
                hi = lo + tokens_per_group
                # Normalisation is per row, so a group's similarity matrix is
                # exactly the matching diagonal block of the full matrix.
                sns.heatmap(
                    full_sim[lo:hi, lo:hi], cmap='Reds',
                    vmin=vmin, vmax=1, ax=ax, cbar=False
                )

                # Tick positions are local to the panel; the labels show the
                # token offsets relative to the start of the sliced tokens.
                ticks = np.arange(0, tokens_per_group, step)
                labels = np.arange(lo, hi, step)
                ax.set_xticks(ticks)
                ax.set_yticks(ticks)
                ax.set_xticklabels(labels, rotation=0)
                ax.set_yticklabels(labels)

                start_frame = i * group_size
                end_frame = (i + 1) * group_size - 1
                ax.set_xlabel(
                    f'Frame {start_frame}-{end_frame}',
                    fontsize=17, labelpad=10
                )

            plt.tight_layout()
            plt.savefig(save_path, dpi=300)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close()
383+
384+
218385def highlight_grid (image , grid_indexes , grid_size = 14 ):
219386 if not isinstance (grid_size , tuple ):
220387 grid_size = (grid_size , grid_size )
0 commit comments