IRM(Invariant Risk Minimization)原理与最小实现

2024-07-06

Invariant Risk Minimization原理与最小实现

1、Invariant Risk Minimization原理1.1提出问题1.2提出模型

2、IRM最小实现参考文献

IRM(Invariant Risk Minimization)是2019年Martin Arjovsky等人提出的一种用于跨域图像分类的新方法,其提出的背景是当我们使用机器学习方法完成图片分类任务时,训练模型所使用的数据集与真实情况的数据集可能存在差别(数据集分布偏移),造成这种分布偏移的原因有很多,比如:数据选择的偏差(单一环境)、混淆因素等,该问题被称为跨域分类问题(注:跨域分类可能在其他的地方有其他的意思),目前大部分解决的方法是减小跨域分布偏差或者提取不变特征。而Martin提出的方法与之前很多跨域分类方法不同之处在于:为了提高机器学习的可解释性,并从根本上解决跨域分类问题,Martin考虑从数学方面推导出特征与标签预测的内在因果关系,即特征与标签之前存在与域无关的内在因果关系。

1、Invariant Risk Minimization原理

1.1提出问题

首先作者提出了一个问题,假设有一个SEM模型: 如上式所示,

X

1

X_1

X1​是一组服从正态分布的数据,

Y

Y

Y是由

X

1

X_1

X1​加上一个服从正态分布的白噪声构成,

X

2

X_2

X2​是

Y

Y

Y加上一个服从正态分布的白噪声构成。

当使用最小二乘方法由

(

X

1

,

X

2

)

(X_1,X_2)

(X1​,X2​)对

Y

Y

Y进行预测时,设其预测模型为:

Y

^

e

=

X

1

e

α

1

^

+

X

2

e

α

2

^

\hat{Y}^e=X_1^e\hat{\alpha_1}+X_2^e\hat{\alpha_2}

Y^e=X1e​α1​^​+X2e​α2​^​,因此若对

X

1

X_1

X1​与

Y

Y

Y的噪声乘以一个与环境有关的系数,那么当使用

X

1

X_1

X1​与

X

2

X_2

X2​预测

Y

Y

Y时,其根据算法是否能够识别出不变特征,回归系数会出现以下三种情况,因此作者的目标是得到第一种情况。

1.2提出模型

根据所总结的问题,作者做出如下定义,将模型分为两个部分,即数据表示

Φ

\Phi

Φ与分类器

ω

^

\hat{\omega}

ω^。 将定义转化为数学模型得IRM表达式, 但是由于上式是一个两层优化问题,因此将上式简化为单变量优化问题, 其中

Φ

\Phi

Φ成为不变预测器,其由两项组成,即经验风险最小项和不变风险最小项,而

λ

\lambda

λ作为平衡两项的一个超参数;由IRM到IRMv1的转变过程,作者还考虑了其他的因素,详细推导可看其论文第三章。

最终作者根据所提出的模型得到训练的损失函数表达式:

2、IRM最小实现

参照论文附录的基于Pytorch的IRM最小实现

import torch

from torch.autograd import grad

import numpy as np

import torchvision

def compute_penalty(losses, dummy_w):

# print(np.shape(losses[0::2]))

# print(dummy_w)

g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]

g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]

# print(g1*g2)

return (g1*g2).sum()

def example_1(n=10000, d=2, env=1):

x = torch.randn(n, d)*env

y = x + torch.randn(n, d)*env

z = y + torch.randn(n, d)

# z = y

# print(np.shape(torch.cat((x, z), 1))) # torch.Size([10000, 4])

return torch.cat((x, z), 1), y.sum(1, keepdim=True)

phi = torch.nn.Parameter(torch.ones(4, 1))

# print(phi)

dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))

# print(dummy_w)

opt = torch.optim.SGD([phi], lr=1e-3)

mse = torch.nn.MSELoss(reduction="none")

environments = [example_1(env=0.1), example_1(env=1.0)]

# s = [[1, 2], [3, 4]]

# for i, j in s:

# print(i)

# print(j)

for iteration in range(50000):

error = 0

penalty = 0

for x_e, y_e in environments:

# print(np.shape(x_e))

# print(np.shape(y_e))

p = torch.randperm(len(x_e))

error_e = mse(x_e[p]@phi*dummy_w, y_e[p])

# error_e = mse(torch.matmul(x_e[p], phi) * dummy_w, y_e[p])

# print(np.shape(error_e))

penalty += compute_penalty(error_e, dummy_w)

error += error_e.mean()

# print(iteration)

# print(error_e.mean())

# print(error)

opt.zero_grad()

(1e-5 * error + penalty).backward()

opt.step()

if iteration % 1000 == 0:

print(phi)

参考文献

Arjovsky, M., et al. (2019). “Invariant Risk Minimization.”