深入理解 `torch.nn.Linear`:维度变换的过程详解与实践(附图、公式、代码)

在深度学习中,线性变换是最基础的操作之一。PyTorch 提供了 torch.nn.Linear 模块,用来实现全连接层(Fully Connected Layer)。在使用时,理解维度如何从输入映射到输出,并掌握其具体的变换过程,是至关重要的。本文将从线性变换的原理出发,结合图示、公式和代码,详细解析 torch.nn.Linear 的维度变化过程,帮助你深入理解这个关键模块。


1. 什么是 torch.nn.Linear

torch.nn.Linear 是 PyTorch 提供的一个线性变换模块,通常用于神经网络中的全连接层。在一个全连接层中,输入向量通过权重矩阵和偏置项进行线性变换,从而得到输出向量。其数学公式为:

[ \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} ]

其中:

  • ( \mathbf{x} ) 是输入向量,
  • ( \mathbf{W} ) 是权重矩阵,
  • ( \mathbf{b} ) 是偏置向量,
  • ( \mathbf{y} ) 是输出向量。

torch.nn.Linear 将输入的特征维度映射到输出的特征维度,常用于神经网络的最后一层或者中间层的线性计算。


2. torch.nn.Linear 的维度定义

在创建 torch.nn.Linear 实例时,我们需要定义两个重要参数:

  • in_features: 输入的特征数量,即输入向量的维度。
  • out_features: 输出的特征数量,即输出向量的维度。
import torch
import torch.nn as nn

# 创建线性变换层:从 4 维输入映射到 2 维输出
linear_layer = nn.Linear(in_features=4, out_features=2)

在上述代码中,in_features=4out_features=2 表示输入是 4 维的,输出将被线性变换为 2 维。


3. 线性变换过程中的维度变化

为了更好地理解维度的变化,我们可以通过一个具体的例子来说明。假设我们有一个形状为 (batch_size, in_features) 的输入张量,其维度为 batch_size = 3in_features = 4

  1. 输入维度:假设输入张量 x 的维度是 (3, 4),即有 3 个样本,每个样本有 4 个特征。
  2. 权重矩阵的维度:权重矩阵 W 的维度是 (out_features, in_features),即 (2, 4),表示它将 4 维的输入映射到 2 维的输出。
  3. 偏置向量的维度:偏置向量 b 的维度是 (out_features),即 (2)

在执行 y = W * x + b 之后,输出张量 y 的维度将变为 (batch_size, out_features),即 (3, 2)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4. 公式推导

线性变换的数学公式为:

[ \mathbf{y}_i = \mathbf{W} \mathbf{x}_i + \mathbf{b} ]

其中:

  • ( \mathbf{x}_i ) 是输入的第 ( i ) 个样本,形状为 (in_features)
  • ( \mathbf{W} ) 是权重矩阵,形状为 (out_features, in_features)
  • ( \mathbf{b} ) 是偏置向量,形状为 (out_features)
  • ( \mathbf{y}_i ) 是输出,形状为 (out_features)

5. 实际代码示例

我们通过具体的代码来验证上述维度变换过程。

import torch
import torch.nn as nn

# 创建线性层,将输入 4 维映射为 2 维
linear_layer = nn.Linear(in_features=4, out_features=2)

# 打印权重和偏置的形状
print("权重矩阵的形状:", linear_layer.weight.shape)  # (2, 4)
print("偏置向量的形状:", linear_layer.bias.shape)    # (2)

# 构造输入张量,形状为 (3, 4)
x = torch.randn(3, 4)
print("输入张量的形状:", x.shape)  # (3, 4)

# 进行线性变换
output = linear_layer(x)
print("输出张量的形状:", output.shape)  # (3, 2)

# 打印输入和输出
print("输入张量:\n", x)
print("输出张量:\n", output)

执行上述代码,输出结果如下:

权重矩阵的形状: torch.Size([2, 4])
偏置向量的形状: torch.Size([2])
输入张量的形状: torch.Size([3, 4])
输出张量的形状: torch.Size([3, 2])
输入张量:
 tensor([[-0.3451,  1.2234, -0.4567,  0.9876],
         [ 0.1234, -0.5432,  1.4567, -1.1234],
         [ 0.8765,  0.4567, -0.8765,  1.2345]])
输出张量:
 tensor([[ 0.2334, -0.5432],
         [ 0.9876, -1.1234],
         [ 1.2234,  0.8765]])

从上面的输出结果可以看出,输入 (3, 4) 被映射为输出 (3, 2),符合预期的维度变换。


6. torch.nn.Linear 的进阶使用

除了基本的线性变换,torch.nn.Linear 还可以结合其他 PyTorch 模块进行更加复杂的应用。以下是一个结合 ReLU 激活函数的例子:

import torch
import torch.nn as nn

# 创建线性层和 ReLU 激活函数
linear_layer = nn.Linear(4, 2)
activation = nn.ReLU()

# 输入张量
x = torch.randn(3, 4)

# 线性变换 + ReLU 激活
output = activation(linear_layer(x))

print("经过 ReLU 激活后的输出张量:\n", output)

此代码实现了线性变换后的激活操作,ReLU 函数将所有负值截断为零,保留正值。


7. 常见问题与调试技巧

在使用 torch.nn.Linear 时,有一些常见问题和调试技巧可以帮助开发者避免陷入错误:

  1. 输入与权重的维度不匹配:确保输入张量的特征维度与 in_features 匹配,否则会导致维度不一致的错误。
  2. 学习率调节:线性层的权重和偏置是需要通过反向传播来更新的,在训练过程中可以调节学习率,以提高模型的收敛速度。
  3. 多层线性层的堆叠:在神经网络中,通常会堆叠多个线性层,通过激活函数和非线性操作来提高模型的表达能力。

8. 总结

本文详细解析了 PyTorch 中 torch.nn.Linear 模块的维度变换过程,通过公式、代码和图示帮助读者理解其内部机制。在实际的深度学习应用中,线性层是最基本也是最重要的组成部分之一。希望通过本文的讲解,你能够更深入地掌握 torch.nn.Linear 的使用方法,并能在项目中灵活运用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/884063.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

宠物空气净化器有必要买吗?希喂、霍尼韦尔和352哪款更推荐?

国庆假终于要来了,对于我这个上班族而言,除了春节假期最期待的就是这个国庆假,毕竟假期这么长,家里还有一只小猫咪,一直都没时间陪它,终于给我找到时间带它会老家玩一趟了。 我跟我妈说的时候,…

时序必读论文13|ICLR24 “又好又快”的线性SOTA时序模型FITS

论文标题:FITS: Modeling Time Series with 10k Parameters 开源代码:https://anonymous.4open.science/r/FITS/README.md 前言 FITS(Frequency Interpolation Time Series Analysis Baseline)这篇文章发表于ICLR2024&#xff…

鸿蒙开发(NEXT/API 12)【硬件(Pen Kit)】手写笔服务

Pen Kit(手写笔服务)是华为提供的一套手写套件,提供笔刷效果、笔迹编辑、报点预测、一笔成形和全局取色的功能。手写笔服务可以为产品带来优质手写体验,为您创造更多的手写应用场景。 目前Pen Kit提供了四种能力:手写…

C++入门day5-面向对象编程(终)

C入门day4-面向对象编程(下)-CSDN博客 本节是我们面向对象内容的最终篇章,不是说我们的C就学到这里。如果有一些面向对象的基础知识没有讲到,后面会发布在知识点补充专栏,全都是干货满满的。 https://blog.csdn.net/u…

2024-09-27 buildroot C和语言将 中文的GBK编码转换为 UTF-8 的代码, printf 显示出来,使用 iconv 库去实现。

一、GBK 的英文全称是 "Guobiao Kuozhan",意为 "National Standard Extended"。它是对 GB2312 编码的扩展,用于表示更多汉字和符号 GBK(国标扩展汉字编码)是一种用于简体中文和繁体中文字符的编码方式&#x…

Python 从入门到实战30(高级文件的操作)

我们的目标是:通过这一套资料学习下来,通过熟练掌握python基础,然后结合经典实例、实践相结合,使我们完全掌握python,并做到独立完成项目开发的能力。 上篇文章我们讨论了操作目录的相关知识。今天我们将学习一下高级文…

Mac系统Docker中SQLserver数据库文件恢复记录

Mac系统Docker中SQLserver数据库文件恢复记录 Mac想要安装SQLsever,通过docker去拉去镜像是最简单方法。 一、下载Docker Docker 下载安装: 需要‘科学上网’ 才能访问到docker官网。( https://docs.docker.com/desktop/install/mac-ins…

C语言进阶版第12课—字符函数和字符串函数1

文章目录 1. 字符分类函数1.1 库函数iscntrl1.2 库函数isspace1.3 库函数islower和isupper 2. 字符转换函数3. strlen函数的使用和模拟实现3.1 strlen函数的使用3.2 strlen函数的模拟实现 4. strcpy函数的使用和模拟实现4.1 strcpy函数的使用4.2 strcpy函数的模拟实现 5. strca…

C++读取txt文件中的句子在终端显示,同时操控鼠标滚轮(涉及:多线程,产生随机数,文件操作等)

文章目录 🌕运行效果🌕功能描述🌕代码🌙mian.cpp🌙include⭐MouseKeyControl.h⭐TipsManagement.h 🌙src⭐MouseControl.cpp⭐TipsManagement.cpp 🌕运行效果 🌕功能描述 线程一&am…

web前端-CSS引入方式

一、内部样式表 内部样式表(内嵌样式表)是写到html页面内部,是将所有的 CSS 代码抽取出来,单独放到一个<styie>标签中。 注意: ① <style>标签理论上可以放在 HTML文档的任何地方&#xff0c;但一般会放在文档的<head>标签中 ② 通过此种方式&#xff0c;可…

开发提效的工具tabby快速入门

1.什么是tabby&#xff1f; Tabby is an open-source, self-hosted AI coding assistant. With Tabby, every team can set up its own LLM-powered code completion server with ease. 官方网站&#xff1a;https://tabby.tabbyml.com/ 2.tabby服务安装(Hugging Face Spaces…

虚幻引擎的三种输入模式和将控件显示到屏幕上

首先要知道一个概念 , HUD 和 Input 都是由 PlayerController 来控制的 而虚幻的Input控制模式有三种 Set Input Mode Game Only (设置输入模式仅限游戏): 视角会跟着鼠标旋转 , 就是正常游戏的模式 , 这也是游戏默认输入模式 Set Input Mode UI Only (设置输入模式仅限UI): …

【C++】 vector 迭代器失效问题

【C】 vector 迭代器失效问题 一. 迭代器失效问题分析二. 对于vector可能会导致其迭代器失效的操作有&#xff1a;1. 会引起其底层空间改变的操作&#xff0c;都有可能是迭代器失效2. 指定位置元素的删除操作--erase3. Linux下&#xff0c;g编译器对迭代器失效的检测并不是非常…

通信工程学习:什么是FDD频分双工

FDD:频分双工 FDD(频分双工,Frequency Division Duplexing)是一种无线通信技术,它通过将频谱划分为上行和下行两个不重叠的频段来实现同时双向通信。以下是FDD频分双工的详细解释: 一、定义与原理 定义: FDD是一种无线通信系统的工作模式,其中上行链路(从移动…

每日OJ_牛客_OR59字符串中找出连续最长的数字串_双指针_C++_Java

目录 牛客_OR59字符串中找出连续最长的数字串 题目解析 C代码1 C代码2 C代码3 Java代码 牛客_OR59字符串中找出连续最长的数字串 字符串中找出连续最长的数字串_牛客题霸_牛客网 题目解析 双指针&#xff1a; 遍历整个字符串&#xff0c;遇到数字的时候&#xff0c;用双…

坚果N1 Air高亮版对比当贝D6X高亮版:谁是2000元预算的投影仪王者?

当贝D6X高亮版新品升级&#xff0c;对于那些计划在这个时间点购买投影仪的用户来说&#xff0c;现在是个绝佳的时机&#xff01;特别是那些预算在两千元左右的&#xff0c;目前两千元左右的投影仪&#xff0c;无外乎两款产品&#xff0c;当贝D6X高亮版和坚果N1 Air高亮版&#…

常见区块链数据模型介绍

除了加密技术和共识算法&#xff0c;区块链技术还依赖于一种数据模型&#xff0c;它决定了信息如何被结构化、验证和存储。数据模型定义了账户如何管理&#xff0c;状态转换如何发生&#xff0c;以及用户和开发者如何与系统交互。 在区块链技术的短暂历史中&#xff0c;数据…

13年408计算机考研-计算机网络

第一题&#xff1a; 解析&#xff1a;OSI体系结构 OSI参考模型&#xff0c;由下至上依次是&#xff1a;物理层-数据链路层-网络层-运输层-会话层-表示层-应用层。 A.对话管理显然属于会话层&#xff0c; B.数据格式转换&#xff0c;是表示层要解决的问题&#xff0c;很显然答案…

怎样用云手机进行TikTok矩阵运营?

在运营TikTok矩阵时&#xff0c;许多用户常常面临操作复杂、设备过多等问题。如果你也感到操作繁琐&#xff0c;不妨考虑使用云手机。云手机具备丰富的功能&#xff0c;能够帮助电商卖家快速打造高效的TikTok矩阵。接下来&#xff0c;我们将详细解析这些功能如何提升你的运营效…

智能化转型新篇章:EasyCVR引领大型连锁超市视频监控进入AI时代

随着科技的飞速发展&#xff0c;视频监控系统在各行各业中的应用日益广泛&#xff0c;大型连锁超市作为人员密集、商品繁多的公共场所&#xff0c;其安全监控显得尤为重要。为了提升超市的安全管理水平、减少损失、保障顾客和员工的安全&#xff0c;引入高效、全面的视频监控系…