重参数化(Reparameterization)的原理

重参数化(Reparameterization)的原理

重参数化是变分自编码器(VAE)中用来解决可微分性问题的一种技术。在VAE中,我们的目标是最大化观测数据的边缘对数似然,这涉及到一个隐含变量 z z z的积分或求和。因为隐含变量是从某个分布中采样的,这直接导致了当我们尝试使用梯度下降方法优化VAE的参数时,由于采样操作的随机性,无法直接对其求导。

重参数化技巧通过将随机采样过程转换为确定性的操作来解决这一问题。具体来说,它将随机变量 z z z的采样过程分解为两步:

  1. 从一个固定的分布(通常是标准正态分布)中采样一个辅助噪声变量 ϵ \epsilon ϵ
  2. 通过一个可微的变换将 ϵ \epsilon ϵ映射到隐变量 z z z

这样,原本依赖于随机采样的模型输出现在变成了依赖于确定性函数的输出,使得整个模型关于其参数可微,从而可以通过标准的反向传播算法进行优化。

功能

  • 允许反向传播:通过使用重参数化技巧,VAE的训练过程可以利用基于梯度的优化算法,如SGD或Adam,因为所有操作都是可微的。
  • 改善训练稳定性:将随机性限制在输入端(噪声 ϵ \epsilon ϵ),而不是模型的中间,有助于提高模型训练的稳定性和收敛速度。
  • 支持更复杂的概率模型:这种技巧使得模型可以学习复杂的数据分布,同时保持模型的可训练性。

Python 示例

下面是使用PyTorch实现的VAE中应用重参数化技巧的简单示例:

import torch
from torch import nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)  # 输入特征到隐层
        self.fc21 = nn.Linear(400, 20)  # 隐层到均值
        self.fc22 = nn.Linear(400, 20)  # 隐层到log方差
        self.fc3 = nn.Linear(20, 400)   # 隐层到输出
        self.fc4 = nn.Linear(400, 784)  # 输出层

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# 损失函数和训练代码在这里省略,只关注模型结构和重参数化部分。

在这个示例中,reparameterize 函数接收从编码器生成的均值和对数方差,然后生成一个随机样本 z,该样本符合由均值 mu 和方差 exp(logvar) 定义的正态分布。这个过程使得模型在训练过程中能够通过梯度下

降法进行优化。

其他参考:

漫谈重参数:从正态分布到Gumbel Softmax。
Categorical Reparameterization with Gumbel-Softmax

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/568240.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DSPy入门:告别指令提示,拥抱编程之旅!

原文:intro-to-dspy-goodbye-prompting-hello-programming 2024 年 2 月 27 日 DSPy框架如何通过用编程和编译代替提示来解决基于LLM的应用程序中的脆弱性问题。 目前,使用大型语言模型(LLMs)构建应用程序不仅复杂而且脆弱。典型的pipelines通常使用pr…

解决“找不到MSVCP120.dll”或“MSVCP120.dll丢失”的错误方法

在计算机使用过程中,遇到诸如“找不到MSVCP120.dll”或“MSVCP120.dll丢失”的错误提示并不罕见。这类问题往往会导致某些应用程序无法正常运行,给用户带来困扰。本文旨在详细阐述MSVCP120.dll文件的重要性、其丢失的可能原因,以及解决方法&a…

nginx开启basic认证

basic认证也叫做http基本认证,防止恶意访问 首先用在线网站生成一个叫做htpasswd的账号密码文件。 将生成结果复制到/etc/nginx/htpasswd文件中 在server的location中配置 server { listen 80; server_name a.com;location / { root html;index index.…

2001-2021年上市公司制造业智能制造词频统计数据

2001-2021年上市公司制造业智能制造词频统计数据 1、时间:2001-2021年 2、来源:上市公司年报 3、指标:年份、股票代码、行业名称、行业代码、所属省份、所属城市、智能制造词频、智能制造占比(%) 4、范围:上市公司 5、样本量…

基于TSM模块的打架斗殴识别技术

目 录 1 引言.... 4 1.1 研究背景与意义.... 4 1.2 研究现状综述.... 5 1.3 研究内容.... 6 1.3.1 图像预处理的优化.... 6 1.3.2 TSM模块的应用.... 6 1.3.3 视频分类的设计与实现.... 6 2 关键技术与方法.... 8 2.1 TSM算法与模型选择.... 8 2.1.1 TSM算法原理.... 8 2.1.2 …

深度学习-数据预处理

目录 创建一个人工数据集处理缺失的数据插入对inputs中的类别值或离散值,将NaN视为一个类别对inputs和outputs中的数值类型转换为张量格式 创建一个人工数据集 import os import pandas as pd os.makedirs(os.path.join(.., data), exist_okTrue) data_file os.p…

基于Vue+ElementPlus自定义带历史记录的搜索框组件

前言 基于Vue2.5ElementPlus实现的一个自定义带历史记录的搜索框组件 效果如图: 基本样式: 获取焦点后: 这里的历史记录默认最大存储10条,同时右侧的清空按钮可以清空所有历史记录。 同时搜索记录也支持点击搜索,按…

.NET(C#)连接达梦数据库GUID字段被自动加横线的修复方法

因信创的原因项目需要兼容达梦数据库,今天遇到个比较坑爹的问题,简单记录下解决方案。 数据库存的是这样: 通过DataAdapter.Fill拿出来以后变成了这样 纳尼?谁让你加上这些横杠的?(掀桌)导致了…

100个实用电气知识

在当今社会,电力作为日常生活和工作中不可或缺的能源,扮演着越来越重要的角色。为了更好地利用电力资源,了解电气知识成为了越来越多人的需求。在电气领域,有很多实用的知识,这些知识对于从事电气工作的人来说是非常重…

Linux系统安全:从面临的攻击和风险到安全加固、安全维护策略(文末有福利)

1. Linux面临的攻击与风险 1.1. Linux系统架构 Linux系统架构解读: 用户之间隔离内核态与用户态之间隔离用户进程一般以低权限用户运行系统服务一般以特权服务运行用户态通过系统调用进入内核态内核对系统资源进行管理和分配 1.2. Linux系统常见安全威胁 1.2.1.…

OSPF认证方式,ISIS简介,ISIS路由器类型

OSPF:转发,泛洪,丢弃

Docker搭建代码托管Gitlab

文章目录 一、简介二、Docker部署三、管理员使用四、用户使用五、用户客户端 一、简介 GitLab是一个基于Git的代码托管和协作平台,类似于GitHub。 它提供了一个完整的工具集,包括代码仓库管理、问题跟踪、CI/CD集成、代码审查等功能。 GitLab的开源版本…

Go语言并发赋值的安全性

struct并发赋值 type Test struct {X intY int }func main() {var g Testfor i : 0; i < 1000000; i {var wg sync.WaitGroup// 协程 1wg.Add(1)go func() {defer wg.Done()g Test{1, 2}}()// 协程 2wg.Add(1)go func() {defer wg.Done()g Test{3, 4}}()wg.Wait()// 赋值…

2024新算法角蜥优化算法(HLOA)和经典灰狼优化器(GWO)进行无人机三维路径规划设计实验

简介&#xff1a; 2024新算法角蜥优化算法&#xff08;HLOA&#xff09;和经典灰狼优化器&#xff08;GWO&#xff09;进行无人机三维路径规划设计实验。 无人机三维路径规划的重要意义在于确保飞行安全、优化飞行路线以节省时间和能源消耗&#xff0c;并使无人机能够适应复杂…

国内首个48小时大模型极限挑战赛落幕,四位“天才程序员”共同夺冠

4月21日晚&#xff0c;第四届ATEC科技精英赛&#xff08;ATEC2023&#xff09;线下赛落幕。本届赛事以大模型为技术基座&#xff0c;围绕“科技助老”命题&#xff0c;是国内首个基于真实场景的大模型全链路应用竞赛。ATEC2023线下赛采用48小时极限挑战的形式&#xff0c;来自东…

Ts支持哪些类型和类型运算(上)

目录 1、元组 2、接口&#xff08;interface&#xff09; 3、枚举&#xff08;Enum&#xff09; 4、字面量类型 5、keyof 6、in keyof 7、类型的装饰 静态类型系统 就是把 类型检查从运行时提前到了编译时&#xff0c;所以ts类型系统中的许多类型与js并无区别 例如&am…

概率图模型在机器学习中的应用:贝叶斯网络与马尔可夫随机场

&#x1f9d1; 作者简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

go语言并发实战——日志收集系统(七) etcd的介绍与简单使用

什么是etcd etcd是基于Go语言开发的一个开源且高可用的分布式key-value存储系统&#xff0c;我们可以在上面实现配置共享与服务的注册与发现。 和它比较相似的还有我们之间所提到的Zookeeper以及consul.(注:后面我们学习微服务的时候etcd和consul会有广泛的使用) etcd有以下几…

网络中其他协议

目录 DNS协议 域名简介 ICMP协议 ICMP功能 ICMP协议格式 ping命令 NAT技术 NATP NAT技术的限制 代理服务器 DNS协议 DNS&#xff08;Domain Name System&#xff0c;域名系统&#xff09;协议&#xff0c;是一个用来将域名转化为IP地址的应用层协议。 为什么有这个协…

W801学习笔记十二:掌机进阶V3版本之驱动(PSRAM/SD卡)

本次升级添加了两个模块&#xff0c;现在要把他们驱动起来。 一&#xff1a;PSRAM 使用SDK自带的驱动&#xff0c;我们只需要写一个初始化函数&#xff0c;并在其中添加一些自检代码。 void psram_heap_init(){wm_psram_config(0);//实际使用的psram管脚选择0或者1&#xff…