Relative Positional Bias — [Swin-transformer]

2023-09-25 6 0

论文中对于这一块的描述不是很清楚,特意记录一下学习过程。
这篇博客讲解的很清楚,请参考阅读https://blog.csdn.net/qq_37541097/article/details/121119988
在这里插入图片描述

以下通过代码形式运行一个demo帮助理解。

1.假设window的H,W均为2,首先构造一个二维坐标
x= torch.arange(2)
y= torch.arange(2)#输入为一维序列,输出两个二维网格,常用来生成坐标
ox,oy = torch.meshgrid([x,y])#按照某个维度拼接,输入序列shape必须一致,默认按照dim0
o2 = torch.stack((ox,oy))print(ox,oy)
print(o2,o2.shape)coords = torch.flatten(o2,1)
print(coords,coords.shape)

输出

tensor([[0, 0],[1, 1]]) tensor([[0, 1],[0, 1]])
tensor([[[0, 0],[1, 1]],[[0, 1],[0, 1]]]) torch.Size([2, 2, 2])
#得到2行序列,对应x,y轴的坐标    
tensor([[0, 0, 1, 1],[0, 1, 0, 1]]) torch.Size([2, 4])
计算相对坐标索引时,采用了一种我之前没见过的扩张维度的方法,简介高效
print(coords[:,:,None].shape) #相当于增加一个维度
print(coords[:,None,:],coords[:,None,:].shape)
print(coords[:,None,:,None].shape)
#作用与unsqueeze()相同
coords.unsqueeze(1)==coords[:,None,:]

输出

torch.Size([2, 4, 1])
tensor([[[0, 0, 1, 1]],[[0, 1, 0, 1]]]) torch.Size([2, 1, 4])torch.Size([2, 1, 4, 1])tensor([[[True, True, True, True]],[[True, True, True, True]]])
print(coords[:,:,None]) #相当于增加一个维度
print(coords[:,None,:])

输出

tensor([[[0],[0],[1],[1]],[[0],[1],[0],[1]]])
tensor([[[0, 0, 1, 1]],[[0, 1, 0, 1]]])
tensor([[[True, True, True, True]],[[True, True, True, True]]])
2.计算相对索引
relative_coords=coords[:,:,None]-coords[:,None,:]  #(2,16,1)-(2,1,16)  #广播机制相减
print(f"relative_coords:{relative_coords.shape}={coords[:,:,None].shape}-{coords[:,None,:].shape }","\n",{relative_coords})

输出

#这里相减,应该是使用了广播机制,先扩展到相同shape后,再进行元素相减运算
relative_coords:torch.Size([2, 4, 4])=torch.Size([2, 4, 1])-torch.Size([2, 1, 4]) {tensor([[[ 0,  0, -1, -1],[ 0,  0, -1, -1],[ 1,  1,  0,  0],[ 1,  1,  0,  0]],[[ 0, -1,  0, -1],[ 1,  0,  1,  0],[ 0, -1,  0, -1],[ 1,  0,  1,  0]]])}

转换为[4,4,2],相当于得到4个4*2的坐标对,一行横坐标,一行纵坐标

relative_coords=relative_coords.permute(1,2,0).contiguous()
print(relative_coords)

输出

torch.Size([4, 4, 2])
tensor([[[ 0,  0],[ 0, -1],[-1,  0],[-1, -1]],[[ 0,  1],[ 0,  0],[-1,  1],[-1,  0]],[[ 1,  0],[ 1, -1],[ 0,  0],[ 0, -1]],[[ 1,  1],[ 1,  0],[ 0,  1],[ 0,  0]]])
print(relative_coords[:,:,0])  #输出第一列元素对应输入中第一列的第1个元素集合 ,第二列对应输入第一列的第2个元素集合
print(relative_coords[:,:,1])

输出

tensor([[ 0,  0, -1, -1],[ 0,  0, -1, -1],[ 1,  1,  0,  0],[ 1,  1,  0,  0]])
tensor([[ 0, -1,  0, -1],[ 1,  0,  1,  0],[ 0, -1,  0, -1],[ 1,  0,  1,  0]])
window_size=(2,2)#行、列元素都加上M-1 ,这里M=2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
print(relative_coords)
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords)relative_coords[:, :, 0] *= 2 * window_size[1] - 1
print(relative_coords)
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)

输出

#第一列(行)加M-1
tensor([[[ 1,  0],[ 1, -1],[ 0,  0],[ 0, -1]],[[ 1,  1],[ 1,  0],[ 0,  1],[ 0,  0]],[[ 2,  0],[ 2, -1],[ 1,  0],[ 1, -1]],[[ 2,  1],[ 2,  0],[ 1,  1],[ 1,  0]]])
# 继续第2列 (列) 加M-1
tensor([[[1, 1],[1, 0],[0, 1],[0, 0]],[[1, 2],[1, 1],[0, 2],[0, 1]],[[2, 1],[2, 0],[1, 1],[1, 0]],[[2, 2],[2, 1],[1, 2],[1, 1]]])
#第一列 (行) 乘 2M-1(3)
tensor([[[3, 1],[3, 0],[0, 1],[0, 0]],[[3, 2],[3, 1],[0, 2],[0, 1]],[[6, 1],[6, 0],[3, 1],[3, 0]],[[6, 2],[6, 1],[3, 2],[3, 1]]])
#行列元素相加
tensor([[4, 3, 1, 0],[5, 4, 2, 1],[7, 6, 4, 3],[8, 7, 5, 4]])

这里就得到相对位置索引,这里对应的值需要到relative positional bias Table 中获取,一开始程序中就定一个了一个可学习的table,长度为[2M-1]*[2M-1], 这里M=2,也就是长度为9,正对应上边索引0-8

# define a parameter table of relative position bias#构造可学习的相对位置偏置table,长度为 (2H-1)*(2W-1)*(num_head)  self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

这里假设有两个attention头

from torch import nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 2))  # 2*Wh-1 * 2*Ww-1, nH 假设有两个attn头
print(relative_position_bias_table.shape,"\n",relative_position_bias_table)
trunc_normal_(relative_position_bias_table, std=.02) #初始化bias_table

输出

torch.Size([9, 2])  #两个attn头,每个头(2M-1)*(2M-1)个数Parameter containing:
tensor([[0., 0.],[0., 0.],[0., 0.],[0., 0.],[0., 0.],[0., 0.],[0., 0.],[0., 0.],[0., 0.]], requires_grad=True)
Parameter containing:  #初始化后的数据
tensor([[-0.0340,  0.0181],[-0.0033, -0.0055],[ 0.0045,  0.0193],[ 0.0412, -0.0031],[ 0.0004, -0.0032],[ 0.0201, -0.0161],[ 0.0067,  0.0079],[ 0.0241, -0.0279],[-0.0125, -0.0291]], requires_grad=True)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]print("index :\n",relative_position_index.view(-1).shape,"\n",relative_position_index.view(-1))
print("bias table 根据索引取值后的数据:\n",relative_position_bias.shape,"\n",relative_position_bias)relative_position_bias=relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
print("维度变换:\n",relative_position_bias.shape,"\n",relative_position_bias) #转换为与attention shape一致
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 
index :torch.Size([16]) tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])  #索引展开成一维
bias table 根据索引取值后的数据:torch.Size([16, 2]) tensor([[ 0.0004, -0.0032],[ 0.0412, -0.0031],[-0.0033, -0.0055],[-0.0340,  0.0181],[ 0.0201, -0.0161],[ 0.0004, -0.0032],[ 0.0045,  0.0193],[-0.0033, -0.0055],[ 0.0241, -0.0279],[ 0.0067,  0.0079],[ 0.0004, -0.0032],[ 0.0412, -0.0031],[-0.0125, -0.0291],[ 0.0241, -0.0279],[ 0.0201, -0.0161],[ 0.0004, -0.0032]], grad_fn=<IndexBackward>)
维度变换:torch.Size([4, 4, 2]) tensor([[[ 0.0004, -0.0032],[ 0.0412, -0.0031],[-0.0033, -0.0055],[-0.0340,  0.0181]],[[ 0.0201, -0.0161],[ 0.0004, -0.0032],[ 0.0045,  0.0193],[-0.0033, -0.0055]],[[ 0.0241, -0.0279],[ 0.0067,  0.0079],[ 0.0004, -0.0032],[ 0.0412, -0.0031]],[[-0.0125, -0.0291],[ 0.0241, -0.0279],[ 0.0201, -0.0161],[ 0.0004, -0.0032]]], grad_fn=<ViewBackward>)

在这里插入图片描述

以上代码就是有关相对位置偏置的全部内容了。

代码编程
赞赏

相关文章

数据结构与算法 二维迷宫问题
数据结构与算法 约瑟夫环问题 Josephus问题
数据结构与算法 哈希表的特点
数据结构与算法 Farmer John 问题 农夫锯木板问题
数据结构与算法 前缀、中缀、后缀表达式求值和相互转换
链表(图文详解)