This chapter uses the Dense (fully connected) layer briefly introduced in the previous chapter to start building models in JAX, and implements a multilayer perceptron example.
A multilayer perceptron (MLP) is a feed-forward artificial neural network that maps a set of input data onto a single set of outputs.
An MLP is generally divided into three kinds of layers:
- the input layer
- the hidden layer(s)
- the output layer
Once nonlinear hidden layers are introduced, the network can in theory fit any function as long as it has enough hidden nodes; and the more hidden layers there are, the more easily it fits complex functions.
A hidden layer has two attributes:
- the number of nodes
- the number of layers
The more layers a network has, the fewer nodes each layer needs. The sketch below makes this three-layer structure concrete.
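The following is a minimal sketch of such a forward pass in JAX; the layer sizes (2 inputs, 4 hidden nodes, 1 output), the ReLU activation, and the function name mlp_forward are illustrative assumptions rather than code from this chapter:

import jax

def mlp_forward(x, params):
    # Unpack the hidden-layer and output-layer parameters.
    (w_hidden, b_hidden), (w_out, b_out) = params
    # Hidden layer: a linear transform followed by a nonlinear activation.
    hidden = jax.nn.relu(jax.numpy.dot(x, w_hidden) + b_hidden)
    # Output layer: another linear transform.
    return jax.numpy.dot(hidden, w_out) + b_out

if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    hidden_key, output_key = jax.random.split(key)
    # Assumed sizes: 2 input features, 4 hidden nodes, 1 output.
    params = [
        (jax.random.normal(hidden_key, (2, 4)), jax.numpy.zeros((4,))),
        (jax.random.normal(output_key, (4, 1)), jax.numpy.zeros((1,))),
    ]
    x = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])
    print(mlp_forward(x, params))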
The fully connected layer: the hidden layer of the multilayer perceptron
The core of the MLP is the hidden layer, which is in fact a fully connected layer. Every node in a fully connected layer is connected to all nodes of the previous layer and combines the features extracted before it, which is also why the fully connected layer holds the most parameters.
Note that here x₁, x₂, x₃ are the input values, while w₁₁, w₁₂, w₁₃, … are the weight parameters. The derivation is as follows:
w₁₁ × x₁ + w₁₂ × x₂ + w₁₃ × x₃ = a₁
w₂₁ × x₁ + w₂₂ × x₂ + w₂₃ × x₃ = a₂
w₃₁ × x₁ + w₃₂ × x₂ + w₃₃ × x₃ = a₃
Using matrix multiplication, this can be written compactly as

    [a₁, a₂, a₃]ᵀ = w @ [x₁, x₂, x₃]ᵀ

where @ denotes matrix multiplication and w is the 3×3 weight matrix. In the code below the input matrix appears on the left and the weight matrix on the right, i.e. f(x) = x @ w. The following Python code implements this simple matrix calculation.
Following the formula, the manual calculation is:
1.1 × 3 + 1.8 × 2 + 0.4 = 7.3
1.2 × 3 + 1.7 × 2 + 0.4 = 7.4
The result is a new matrix, [7.3, 7.4].
The same calculation in Python:
import jax

def multiplify(matrix, weights, bias):
    # Matrix multiplication followed by adding the bias.
    result = jax.numpy.matmul(matrix, weights) + bias
    return result

if __name__ == "__main__":
    # The input matrix, the weight column vector, and a scalar bias.
    matrix = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])
    weights = jax.numpy.array([[3], [2]])
    bias = 0.4
    result = multiplify(matrix, weights, bias)
    print(result)
The printed output is:
[[7.3]
[7.4]]
The result is a 2-dimensional matrix with shape (2, 1), as sketched below.
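As a quick check, the @ operator mentioned earlier gives the same result as jax.numpy.matmul, and the array's ndim and shape attributes confirm the dimensions. This is only a sketch mirroring the example above:

import jax

matrix = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])
weights = jax.numpy.array([[3.], [2.]])
bias = 0.4

# For 2-D arrays the @ operator is equivalent to jax.numpy.matmul.
result = matrix @ weights + bias
print(result)        # approximately [[7.3], [7.4]]
print(result.ndim)   # 2
print(result.shape)  # (2, 1)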
Implementing a simple fully connected layer in JAX
As the calculation in the previous section shows, a fully connected layer is essentially a linear transformation from one feature space to another: every dimension of the target space is influenced by every dimension of the source space, and the target vector is a weighted sum of the source vector.
A fully connected layer is usually attached after a feature-extraction network and acts as a classifier on the extracted features; it typically appears in the last few layers of a network, computing weighted sums of the features produced earlier.
A fully connected layer can be implemented in JAX as follows:
import jax

def Dense(input_shape = (2, 1)):
    # Generate the weights and biases with a fixed random key.
    key = jax.random.PRNGKey(10)
    weights = jax.random.normal(key = key, shape = input_shape)
    biases = jax.random.normal(key = key, shape = (input_shape[-1],))
    params = [weights, biases]

    def apply_function(inputs):
        # Linear transform of the inputs plus the bias.
        weights, biases = params
        dotted = jax.numpy.dot(inputs, weights) + biases
        return dotted

    return apply_function

def test():
    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)
    # Dense() returns apply_function, which is called on the inputs directly.
    dense = Dense()(inputs)
    print(dense)

if __name__ == "__main__":
    test()
The Dense function first initializes the parameters and then returns the inner function apply_function, which performs the computation on the input matrix. The printed output is:
[[-3.601719]
[-4.189919]]
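Note that this Dense reuses the same PRNG key for the weights and the biases, so the two draws are not independent. A common JAX pattern is to split the key first; a minimal variant along those lines (which would of course print different numbers than the output above) might look like this:

import jax

def Dense(input_shape = (2, 1)):
    key = jax.random.PRNGKey(10)
    # Split the key so the weights and biases use independent random streams.
    w_key, b_key = jax.random.split(key)
    weights = jax.random.normal(key = w_key, shape = input_shape)
    biases = jax.random.normal(key = b_key, shape = (input_shape[-1],))
    params = [weights, biases]

    def apply_function(inputs):
        weights, biases = params
        return jax.numpy.dot(inputs, weights) + biases

    return apply_function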
A fully connected function with more features
A fully connected function using external parameters. In the previous subsection, the apply_function built into Dense used parameters generated by a random function inside Dense. To use external parameters instead, the code can be modified as follows:
import jax

def Dense(inputs_shape = (2, 1)):
    # Internally generated default parameters.
    key = jax.random.PRNGKey(10)
    weights = jax.random.normal(key = key, shape = inputs_shape)
    biases = jax.random.normal(key = key, shape = (inputs_shape[-1],))
    params = [weights, biases]

    def init_params_function():
        # Return the internally generated parameters.
        return params

    def apply_function(inputs, params = params):
        # If no parameters are passed in, fall back to the internal defaults.
        weights, biases = params
        dotted = jax.numpy.dot(inputs, weights) + biases
        return dotted

    return init_params_function, apply_function

def test():
    # Build external parameters with a different random key.
    key = jax.random.PRNGKey(15)
    inputs_shape = (2, 1)
    weights = jax.random.normal(key = key, shape = inputs_shape)
    biases = jax.random.normal(key = key, shape = (inputs_shape[-1],))
    params = [weights, biases]

    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)
    init_params_function, apply_function = Dense()
    # Pass the external parameters instead of Dense's internal ones.
    dense = apply_function(inputs, params)
    print(dense)

if __name__ == "__main__":
    test()
Here external parameters are used instead of the ones generated inside Dense. The printed output is:
[[1.5110686 ]
[0.74590844]]
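Because apply_function now receives its parameters explicitly, it behaves as a pure function of its arguments when called this way, which means it can, for instance, be compiled with jax.jit. The standalone sketch below is an optional illustration, not part of the original example:

import jax

def apply_function(inputs, params):
    # The same computation as above: a linear transform plus a bias.
    weights, biases = params
    return jax.numpy.dot(inputs, weights) + biases

# Compile the pure function once; subsequent calls reuse the compiled version.
fast_apply = jax.jit(apply_function)

key = jax.random.PRNGKey(15)
params = [jax.random.normal(key, (2, 1)), jax.random.normal(key, (1,))]
inputs = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])
print(fast_apply(inputs, params))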
A fully connected function that returns its parameters
We have now seen the basic fully connected function and the version that uses external parameters, but sometimes the generated parameters themselves need to be returned. The code is as follows:
import jax

def Dense(inputs_shape = (2, 1)):
    def init_function(shape = inputs_shape):
        # Generate the parameters and return them to the caller.
        key = jax.random.PRNGKey(10)
        weights = jax.random.normal(key = key, shape = shape)
        biases = jax.random.normal(key = key, shape = (shape[-1],))
        return (weights, biases)

    def apply_function(inputs, params):
        # Compute the output from the inputs and the passed-in parameters.
        weights, biases = params
        dotted = jax.numpy.dot(inputs, weights) + biases
        return dotted

    return init_function, apply_function

def test():
    init_function, apply_function = Dense()
    init_params = init_function()
    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)
    result = apply_function(inputs, init_params)
    print(f"init_params = {init_params}, result = {result}")

if __name__ == "__main__":
    test()
The printed output is:
init_params = (Array([[-0.62187684],
[-1.2754321 ]], dtype=float32), Array([-1.3445405], dtype=float32)), result = [[-4.324383 ]
[-4.2590275]]
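Returning the parameters is what makes training possible: once the parameters are explicit values, a loss can be differentiated with respect to them using jax.grad. The loss function, the target values, and the names below are illustrative assumptions built on top of the Dense above, not part of the original example:

import jax

def loss_function(params, inputs, targets):
    # Mean squared error between the layer output and the targets.
    weights, biases = params
    predictions = jax.numpy.dot(inputs, weights) + biases
    return jax.numpy.mean((predictions - targets) ** 2)

key = jax.random.PRNGKey(10)
params = (jax.random.normal(key, (2, 1)), jax.random.normal(key, (1,)))
inputs = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])
targets = jax.numpy.array([[1.0], [0.0]])   # assumed target values

# jax.grad differentiates with respect to the first argument (params) by default.
grads = jax.grad(loss_function)(params, inputs, targets)
print(grads)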
A fully connected function that generates different random parameters from different seeds
By passing a seed into Dense, different PRNG keys and therefore different random parameters can be generated. The code is as follows:
import jax

def Dense(input_shape = (2, 1), seed = 10):
    def init_function(shape = input_shape):
        # The seed determines the PRNG key and therefore the parameters.
        key = jax.random.PRNGKey(seed)
        weights = jax.random.normal(key = key, shape = shape)
        biases = jax.random.normal(key = key, shape = (shape[-1],))
        return (weights, biases)

    def apply_function(inputs, params):
        weights, biases = params
        dotted = jax.numpy.dot(inputs, weights) + biases
        return dotted

    return init_function, apply_function

def test():
    array = [[1.1, 1.8], [1.2, 1.7]]
    inputs = jax.numpy.array(array)

    # Same layer, seed 10.
    init_function, apply_function = Dense(seed = 10)
    init_params = init_function()
    dense = apply_function(inputs, init_params)
    print(f"dense1 = {dense}")
    print("----------------------------------------")

    # Same layer, seed 20: a different key gives different parameters and output.
    init_function, apply_function = Dense(seed = 20)
    init_params = init_function()
    dense = apply_function(inputs, init_params)
    print(f"dense2 = {dense}")

if __name__ == "__main__":
    test()
The printed output is:
dense1 = [[-4.324383 ]
[-4.2590275]]
----------------------------------------
dense2 = [[2.524846]
[2.44795 ]]
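With an init_function/apply_function pair available, the layers can be stacked into the small multilayer perceptron this chapter is building toward. The sketch below chains two Dense layers with a jax.nn.relu activation in between; the layer sizes and the choice of seeds are illustrative assumptions:

import jax

def Dense(input_shape = (2, 1), seed = 10):
    def init_function(shape = input_shape):
        key = jax.random.PRNGKey(seed)
        weights = jax.random.normal(key = key, shape = shape)
        biases = jax.random.normal(key = key, shape = (shape[-1],))
        return (weights, biases)

    def apply_function(inputs, params):
        weights, biases = params
        return jax.numpy.dot(inputs, weights) + biases

    return init_function, apply_function

def test():
    inputs = jax.numpy.array([[1.1, 1.8], [1.2, 1.7]])

    # Hidden layer: 2 inputs -> 4 hidden nodes; output layer: 4 -> 1 output.
    hidden_init, hidden_apply = Dense(input_shape = (2, 4), seed = 10)
    output_init, output_apply = Dense(input_shape = (4, 1), seed = 20)
    hidden_params = hidden_init()
    output_params = output_init()

    # Forward pass: linear transform -> nonlinear activation -> linear transform.
    hidden = jax.nn.relu(hidden_apply(inputs, hidden_params))
    result = output_apply(hidden, output_params)
    print(result)

if __name__ == "__main__":
    test()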