Purpose of this memo
- Record the answers to the following questions:
    - In $D_{\text{KL}}(P||Q)$, which of $P$ and $Q$ is the predicted distribution and which is the ground-truth distribution?
    - What should be passed as the arguments of torch.nn.functional.kl_div?
General formula
\[D_{KL}(P || Q) = \sum_{i} P(i) \log\frac{P(i)}{Q(i)}\]
- "a measure of how much an approximating probability distribution Q is different from a true probability distribution P" [Wikipedia (2025-12-15)]
    - (I should really cite a proper reference instead, but oh well.)
- That is, $P$ is the true (ground-truth) distribution and $Q$ is the predicted distribution.
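- A worked special case (my own illustration, reusing the numbers from the verification code below): when $P$ is one-hot at class $i^*$, every term with $P(i) = 0$ vanishes (by the $0 \log 0 = 0$ convention) and the divergence reduces to the negative log-probability that $Q$ assigns to the correct class.
\[D_{KL}(P || Q) = 1 \cdot \log\frac{1}{Q(i^*)} = -\log Q(i^*)\]
- For example, with $P = (1, 0, 0)$ and $Q = (0.5, 0.4, 0.1)$, we get $D_{KL}(P || Q) = -\log 0.5 \approx 0.6931$.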
Arguments of torch.nn.functional.kl_div
- Conclusion
    - input: pass the log of the predicted distribution
    - target: pass the ground-truth distribution
- From the official documentation (2025-12-15, v2.9.1):
    - torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_target=False)
    - input: Tensor of arbitrary shape in log-probabilities
    - target: Tensor of the same shape as input.
    - size_average: Deprecated
    - reduce: Deprecated
    - reduction: Specifies the reduction to apply to the output. Default: 'mean'
        - 'none': no reduction will be applied
        - 'batchmean': the sum of the output will be divided by batchsize
        - 'sum': the output will be summed
        - 'mean': the output will be divided by the number of elements in the output
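- A minimal usage sketch based on the docs above (the logits and true_probs tensors are hypothetical placeholders, not from this memo): the model output goes through log_softmax to become input, and the ground-truth probabilities go to target; if the target is itself in log-space, set log_target=True.
import torch
import torch.nn.functional as F

# Hypothetical model outputs and one-hot ground truth, for illustration only
logits = torch.randn(4, 3)                                                 # batch of 4 samples, 3 classes
true_probs = F.one_hot(torch.tensor([0, 2, 1, 0]), num_classes=3).float()  # ground truth as probabilities

# input = log of the predicted distribution, target = ground-truth distribution
loss = F.kl_div(F.log_softmax(logits, dim=-1), true_probs, reduction='batchmean')

# If the target is already in log-space, pass log_target=True instead
loss_log_target = F.kl_div(
    F.log_softmax(logits, dim=-1),
    true_probs.clamp_min(1e-12).log(),
    reduction='batchmean',
    log_target=True,
)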
Verification code behind the conclusion above
import torch

# p: ground-truth (one-hot) distributions, q: predicted distributions
p = torch.tensor([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
])
q = torch.tensor([
    [0.5, 0.4, 0.1],
    [0.1, 0.3, 0.6],
])
# Expected elementwise KL terms (what reduction='none' should return)
no_reduction = torch.tensor([
    [-torch.log(torch.tensor(0.5)), 0.0, 0.0],
    [0.0, -torch.log(torch.tensor(0.3)), 0.0],
])
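- Why the off-correct-class entries are 0: per the docs, the pointwise term is $y_{\text{true}} \cdot (\log y_{\text{true}} - x)$ with $x$ the log-probability passed as input, so entries where the target is 0 contribute nothing (the $0 \log 0 = 0$ convention again). For the first sample's correct class the term is $1 \cdot (\log 1 - \log 0.5) = -\log 0.5$.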
Checking the arguments and the behavior of reduction='batchmean'
def test_kl_div(p, q):
    # batchmean: sum of all elementwise terms divided by the batch size (here 2)
    kl_div = torch.nn.functional.kl_div(input=q.log(), target=p, reduction='batchmean')
    should_equal = no_reduction.sum(dim=1).mean(dim=0)
    assert torch.allclose(kl_div, should_equal), f"{kl_div} != {should_equal}"

test_kl_div(p, q)
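- With these tensors this works out to $(-\log 0.5 - \log 0.3) / 2 \approx (0.6931 + 1.2040) / 2 \approx 0.9486$ (my arithmetic, for orientation).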
Checking the behavior of reduction='none'
def test_kl_div_none(p, q):
    # none: return the elementwise terms with no reduction applied
    kl_div = torch.nn.functional.kl_div(input=q.log(), target=p, reduction='none')
    assert torch.allclose(kl_div, no_reduction), f"{kl_div} != {no_reduction}"

test_kl_div_none(p, q)
Checking the behavior of reduction='sum'
def test_kl_div_sum(p, q):
    # sum: add up every elementwise term
    kl_div = torch.nn.functional.kl_div(input=q.log(), target=p, reduction='sum')
    should_equal = no_reduction.sum(dim=1).sum(dim=0)
    assert torch.allclose(kl_div, should_equal), f"{kl_div} != {should_equal}"

test_kl_div_sum(p, q)
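- Here the sum is $-\log 0.5 - \log 0.3 \approx 1.8971$, i.e. the total KL over the batch rather than a per-sample average.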
Checking the behavior of reduction='mean'
- WARNING (from the official docs): "doesn't return the true kl divergence value, please use reduction = 'batchmean' which aligns with KL math definition."
def test_kl_div_mean(p, q):
    # mean: divide the sum by the number of output elements (here 2 x 3 = 6), not by the batch size
    kl_div = torch.nn.functional.kl_div(input=q.log(), target=p, reduction='mean')
    should_equal = no_reduction.mean(dim=1).mean(dim=0)
    assert torch.allclose(kl_div, should_equal), f"{kl_div} != {should_equal}"

test_kl_div_mean(p, q)
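- With these tensors that is $1.8971 / 6 \approx 0.3162$, which is neither sample's KL divergence nor the batch mean ($\approx 0.9486$), which is exactly what the warning above is about.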