NIPS 2017
1. 网络结构
一个九层的DLN(Deep Lattice Networks)网络结构如下所示:
![[Pasted image 20231207230047.png|600]]
主要由 calibrators、linear embedding、calibrators、ensemble of lattices、calibrators、ensemble of lattices、calibrators、lattice、calibrator 组成。其中比较重要的层有 Linear Embedding Layer、Calibration Layer、Ensemble of Lattice Layer。
1.1 Linear Embedding Layer
1.1.1 结构
W_t^m[i,j]\geq 0 \ for\ all\ (i,j).
- 该网络的输出为:
x_{t+1}=[\begin{matrix}x_{t+1}^m \\ x_{t+1}^n \end{matrix}]=[\begin{matrix}W_t^m x_t^m \\ W_t^n x_t^n \end{matrix}]
1.1.2 单调性
只需要 W 大于零,则可以保证 m 部分为单调的。
1.2 Calibration Layer
1.2.1 结构
x_{t+1}:=[c_{t,1}(x_t[1])\ c_{t,2}(x_t[2]) \ ...\ c_{t,D_t}(x_t[D_t])]^T.
- 其中 c (x[d]; a, b) 为如下形式:
c (x[d]; a, b))=\sum_{k=1}^K \alpha[k]|ReLU(x-a[k])+b[1] - 其中的 \alpha[k] 表达为:
\alpha[k]:=\begin{cases} \frac{b[k+1]-b[k]}{a[k+1]-a[k]} - \frac{b[k]-b[k-1]}{a[k]-a[k-1]} & for\ k=2,...,K-1 \\ \frac{b[2]-b[1]}{a[2]-a[1]} & for\ k=1 \\ -\frac{b[K]-b[K-1]}{a[K]-a[K-1]} & for\ k=K \end{cases} - 为了保证单调,限制 b[k] 保持单调,添加约束:
b[k]\leq b[k+1]\ for\ k=1,...,K-1
1.2.1 单调性
a 与 b 分别是单调折线的横轴与纵轴,一共存在 K 折。当满足 a[1]=a[2]=…=a[K]时,可轻易证得单调性。令 x-a[k]=x_k,b[k+1]-b[k]=b_k,a[k+1]-a[k]=a_k,则有:
\begin{align}& \sum_{k=1}^K \alpha[k]|ReLU(x-a[k]) \\& =\frac{relu(x_1) b_1}{a_1} + \frac{relu(x_2) (b_2-b_1)}{a_1}+...+ \frac{relu(x_{k-1}) (b_{k-1}-b_{k-2})}{a_{1}} - \frac{relu(x_{k}) b_{k-1}}{a_{1}} \\&=\frac{[relu(x_2)-relu(x_1)]b_1+...-[relu(x_{k-1})-relu(x_k)]b_{k-1}}{a_1} \end{align}
因为 a_1 \geq relu (x_2)-relu (x_1)\geq 0 ,且 \sum_{k=1}^{K}b_k=0 ,因此得到上式大于 0 小于 1。又因 x 越大,relu(x_k)=relu(x-a[k]) 项大于 0 的越多,从而保证单调性。
如果 a[1]=a[2]=...=a[K] 条件不满足,可得到下式:
\begin{align}& \sum_{k=1}^K \alpha[k]|ReLU(x-a[k]) \\& =relu(x_1)\frac{b_1}{a_1} + relu(x_2)(\frac{b_2}{a_2}-\frac{b_1}{a_1}) + ... + relu(x_{k-1})(\frac{b_k}{a_k}-\frac{b_{k-1}}{a_{k-1}}) - relu(x_k)\frac{b_{k-1}}{a_{k-1}} \end{align}
relu 函数在大于 0 时为 x,小于 0 时为 0,假设 a[m-1] < x < a[m],得到:
\begin{align}& \sum_{k=1}^K \alpha[k]\ ReLU(x-a[k]) \\& =relu(x_1)\frac{b_1}{a_1} + relu(x_2)(\frac{b_2}{a_2}-\frac{b_1}{a_1}) + ... + relu(x_{m-1})(\frac{b_{m-1}}{a_{m-1}}-\frac{b_{m-2}}{a_{m-2}}) \\&=(x_1-x_2)\frac{b_1}{a_1} + ...+(x_{m-2}-x_{m-1})\frac{b_{m-2}}{a_{m-2}} +x_{m-1}\frac{b_{m-1}}{a_{m-1}} \end{align}
考虑到 x_1-x_2=x-a[1]-(x-a[2])=a[2]-a[1]=a_1,化简上式得到:
\begin{align}& \sum_{k=1}^K \alpha[k]\ ReLU(x-a[k]) \\& =b_1+b_2+...+b_{m-2}+x_{m-1}\frac{b_{m-1}}{a_{m-1}} \geq 0 \end{align}
并且 x 越大,m 越大,可知该函数单调。
1.3 Ensemble of Lattice Layer
1.3.1 结构
Monotonic Calibrated Interpolated Look-Up Tables
https://jmlr.org/papers/volume17/15-243/15-243.pdf
每一个 Ensemble of Lattice Layer 都包涵 G 个 lattices。每一个 Lattice 都是一个线性插值的多维 Lookup 表,作者采用多维非线性平滑插值方法进行实验,表示为 \phi(x)^T\theta ,其中 \phi(x) 是这样的变换:\phi(x):[0,1]^S->[0,1]^{2^S} 。如果将这个函数里的每一个元素表示出来,是这样的:
\phi(x)[j]=\prod_{d=0}^{S-1}x[d]^{v_j[d]}(1-x[d])^{1-v_j[d]}
其中 v_j[d] 是单位超立方体第 j 个顶点的坐标向量。来一个示例就清楚了,当 S=2 时,有 v_1=(0,0),v_2=(0,1),v_3=(1,0),v_4=(1,1) ,此时 \phi(x) 如下所示:
\phi(x)=((1-x[0])(1-x[1]),\ (1-x[0])x[1],\ (x[0](1-x[1])),\ (x[0]x[1]))
因为 S=2 ,所以其实是一个正方形,如图中间正方体所示,其坐标与示例相同:
![[Pasted image 20231208230142.png|600]]
每一个 Ensemble Layer 都会产生 M 个输出,创建 DLN 时,如果 t+1 个曾是一个 Ensemble Layer,将随机排列前一个 Layer 的输出,指派到 G_{t+1}S_{t+1} 的输入。如果一个 Lattice 存在一个单调的输入,那么这个 Lattice 的输出对于下一个 Layer 来说,也是限制为单调的。基于这样的方式,作者构建了整个 DLN 网络。
此时 \theta 为 (\theta_1,\ \theta_2,\ \theta_3,\ \theta_4) ,输出 size 为 1。
1.3.2 单调性
显然,\phi(x)[j] 可以写为下面这种形式,其中 bit(d,j)\in \{0, 1\}:
\phi_j(x)=\prod_{d=0}^{S-1}((1-bit[d,j](1-x[d])+bit(d,j)x[d])
令 f(x)=\theta^T\phi(x),其中 x\in [0,1]^D。
单调性约束如下:
- 当对于任意 k,\ k' 满足 v_k[d]=0,v_{k'}[d]=1 不相同,且 \theta_{k'}>\theta_k,则对于给定 d 有 \frac{\partial f(x)}{\partial x[d]}>0。
证明:
\begin{align}
f(x) &= \sum_{k,k'}\theta_k \phi_k(x)+\theta_{k'}\phi_{k'}(x) \\
&= \sum_{k,k'} \alpha_k (\theta_kx[d]^{v_k[d]}(1-x[d])^{1-v_k[d]} + \theta_{k'}x[d]^{v_{k'}[d]}(1-x[d])^{1-v_{k'}[d]}) \\
&= \sum_{k,k'} \alpha_k (\theta_k(1-x[d])+\theta_{k'}x[d])
\end{align}
对 x[d]求偏导得到:
\frac{\partial f(x)}{\partial x[d]}=\sum_{k,k'}\alpha_k(\theta_{k'}-\theta_k)
因为 \alpha_k \in [0,1],因此需要保证 \theta_{k'} > \theta_k ,才能保证对于 x[d] 而言单调。
1.4 超参设定
每一个模块的输入和输出应该保持一致,如果是 ensemble 模块,则 G_t \times S_t 应该和输入保持一致。
评论区