Given a batch of distributions represented as a PyTorch tensor:
A = torch.tensor([[0., 0., 0., 0., 1., 5., 1., 2.],
[0., 0., 1., 0., 4., 2., 1., 1.],
[0., 0., 1., 1., 0., 5., 1., 1.],
[0., 1., 1., 0., 2., 3., 1., 1.],
[0., 0., 2., 1., 3., 1., 1., 0.],
[0., 0., 2., 0., 5., 0., 1., 0.],
[0., 2., 1., 4., 0., 0., 1., 0.],
[0., 0., 2., 4., 1., 0., 1., 0.]], device='cuda:0')
A is a batch of eight distributions. Now, given another batch of distributions B:
B = torch.tensor([[0., 0., 1., 4., 2., 1., 1., 0.],
[0., 0., 0., 5., 1., 2., 1., 0.],
[0., 0., 0., 4., 2., 3., 0., 0.],
[0., 0., 1., 7., 0., 0., 1., 0.],
[0., 0., 1., 2., 4., 0., 1., 1.],
[0., 0., 1., 3., 1., 3., 0., 0.],
[0., 0., 1., 4., 1., 0., 2., 0.],
[1., 0., 1., 5., 0., 1., 0., 0.],
[0., 1., 5., 1., 0., 0., 1., 0.],
[0., 0., 3., 2., 2., 0., 1., 0.],
[0., 2., 4., 0., 1., 0., 1., 0.],
[1., 0., 4., 1., 1., 1., 0., 0.]], device='cuda:0')
B has 12 distributions. I want to calculate the KL divergence between each distribution in A and each distribution in B, obtaining a KL distance matrix of shape 8×12 (one row per distribution in A, one column per distribution in B). I know how to do this with a loop structure and torch.nn.functional.kl_div(). Is there a way to implement it in PyTorch without a for-loop?
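For reference, the quantity computed for a row a of A and a row b of B is KL(b ‖ a) = Σ_k b_k · (log b_k − log a_k); this is exactly the elementwise term that torch.nn.functional.kl_div(a.log(), b) accumulates (averaged over the 8 bins under its default reduction='mean', rather than summed).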
Here is my implementation using a for-loop:
import torch
import torch.nn.functional as F

p_1 = F.softmax(A, dim=-1)
p_2 = F.softmax(B, dim=-1)
# Allocate the result on the same device as the inputs (the printed output is on cuda:0).
C = torch.empty(size=(A.shape[0], B.shape[0]), dtype=torch.float, device=A.device)
for i, a in enumerate(p_1):
    for j, b in enumerate(p_2):
        # kl_div expects log-probabilities as its first argument; the default
        # reduction='mean' averages over the 8 bins instead of summing.
        C[i, j] = F.kl_div(a.log(), b)
print(C)
Output is:
tensor([[0.4704, 0.5431, 0.3422, 0.6284, 0.3985, 0.2003, 0.4925, 0.5739, 0.5793,
0.3992, 0.5007, 0.4934],
[0.3416, 0.4518, 0.2950, 0.5263, 0.0218, 0.2254, 0.3786, 0.4747, 0.3626,
0.1823, 0.2960, 0.2937],
[0.3845, 0.4306, 0.2722, 0.5022, 0.4769, 0.1500, 0.3964, 0.4556, 0.4609,
0.3396, 0.4076, 0.3933],
[0.2862, 0.3752, 0.2116, 0.4520, 0.1307, 0.1116, 0.3102, 0.3990, 0.2869,
0.1464, 0.2164, 0.2225],
[0.1829, 0.2674, 0.1763, 0.3227, 0.0244, 0.1481, 0.2067, 0.2809, 0.1675,
0.0482, 0.1271, 0.1210],
[0.4359, 0.5615, 0.4427, 0.6268, 0.0325, 0.4160, 0.4749, 0.5774, 0.3492,
0.2093, 0.3015, 0.3014],
[0.0235, 0.0184, 0.0772, 0.0286, 0.3462, 0.1461, 0.0142, 0.0162, 0.3524,
0.1824, 0.2844, 0.2988],
[0.0097, 0.0171, 0.0680, 0.0284, 0.2517, 0.1374, 0.0082, 0.0148, 0.2403,
0.1058, 0.2100, 0.1978]], device='cuda:0')
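For reference, here is a minimal broadcasting sketch of what a loop-free version might look like. This is an assumption on my part rather than a vetted answer: it lines up A's log-probabilities of shape (8, 1, 8) against B's probabilities of shape (1, 12, 8) and reduces the last dimension; the name C_vec is mine, and .mean(dim=-1) is used only to mimic kl_div's default reduction='mean' from the loop above.

log_p_1 = F.log_softmax(A, dim=-1)   # (8, 8) log-probs, the "input" role of kl_div
p_2 = F.softmax(B, dim=-1)           # (12, 8) probs, the "target" role of kl_div
# Broadcast to (8, 12, 8): the pointwise term target * (log(target) - input),
# which is what kl_div computes elementwise.
pointwise = p_2.unsqueeze(0) * (p_2.log().unsqueeze(0) - log_p_1.unsqueeze(1))
C_vec = pointwise.mean(dim=-1)       # (8, 12), should match C from the loop

Using .sum(dim=-1) instead of .mean(dim=-1) would give the true KL divergence; the mean is kept here only so the numbers agree with the loop output printed above.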