pytorch(6)–深度置信网络「建议收藏」

pytorch(6)–深度置信网络「建议收藏」一、前言本文主要使用pytorch实现的DBN网络,用于对数据做回归,单个数据维度为(N,21),其中N为不定长,输出则为(N,1),对应N个值DBN网络结构:首层神经元数量输入为变量长度21,中间为RBM网络,如本篇使用的网络结构诶[128,64,32,16],为一个4层的RBM网络结构,训练时RBM需要逐层做训练;在RBM训练后,再接上BP神经网络,再对BP网络做微调,回归损失函数使用MSEloss。二、深度置信网络实现代码#DBN.pyimpor…

大家好,欢迎来到IT知识分享网。

一、前言

    本文主要使用pytorch 实现的DBN网络,用于对数据做回归,单个数据维度为(N,21),其中N为不定长,输出则为(N,1),对应N个值

DBN网络结构:

    pytorch(6)--深度置信网络「建议收藏」

    首层神经元数量输入为变量长度21,中间为RBM网络,如本篇使用的网络结构诶[128,64,32,16],为一个4层的RBM网络结构,训练时RBM需要逐层做训练;在RBM训练后,再接上BP神经网络,再对BP网络做微调,回归损失函数使用MSE loss。

二、深度置信网络实现代码

#DBN.py
import torch
import warnings
import torch.nn as nn
import numpy as np

from RBM import RBM
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.optim import Adam, SGD

from genCsvData import indefDataSet,DataLoader


class DBN(nn.Module

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/26551.html

(1)

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

关注微信