PyTorch框架:(4)如何去构建数据

news/2024/9/20 15:00:55

接PyTorch框架:(3)

1、最基本的方法

(1)使用模块

模块1:TensorDataset、模块2:DataLoader

自己去构造数据集,然后一个batch一个batch的取数据,自己去写构造数据太麻烦,可以自动让其把数据源给我们构建好,这两个模块就是来帮我们完成这个事的。

第一步把x_train和y_train传进去,使用TensorDataset自动的帮我们组件dataset即(train_ds);

DataLoader是得搭配一下,先把数据转化为TensorDataset所支持的格式,然后采用DataLoader读进来,DataLoader的意思就是你把数据交给我,然后你告诉我一个batch_size有多少,然后你要取数据的时候我就帮你一个batch一个batch的取数据,这样方便一些。shuffle=True表示要不要重新洗牌;

 (2)定义一个get_data方法,需要传进来当前的数据集,后边做了一个return,就是按照一个Batch取数据就完事了;

 (3)训练函数

 自己定义一个训练方法,def fit方法,实际的去执行训练的操作。传进来的参数:

steps:一共迭代多少次。

model:就是定义的model,就是自己写个类,把model传进来。

loss_func:使用的f.中的损失。

opt:优化器是什么。

train_dl:实际数据传进来。

valid_dl:实际数据传进来。

Batch Normalization和Dropout在训练的时候一般都会加这两项,让模型过拟合的更低;在测试的时候一般就不加这两个东西了。所以为了有这两个区分,如果此时是训练,那么在训练的时候加上model.train();下边不是训练就是走一次前向传播,看一下对于当前模型来说他的一个效果,他的损失等于多少,把损失拿过来,我也不需要进行参数更新,不需要计算梯度,也不需要训练的过程,所以这一块我再额外的指定一下,这块不需要加Batch Normalization和Dropout,他不是一个训练的过程,所以在前边加上model.eval()。

所以见到这两个就是表示:model.train()强调的是你的训练过程,把该加的加进去;model.eval()强调的是测试过程,只需要得到结果,不需要把没用的都加进去。

 loss_batch做的事情:如果你传进来一个优化器,优化器求梯度,求完梯度更新,更新完之后置0,然后返回结果。这里不光计算一个loss值还要去计算他实际的梯度值是多少,要进行参数的更新。

上述相当于把每个模块都准备好了,实际训练模型的时候不用把每个函数都也在一个sell当中,下面三行就搞定了:

 第一步:拿到数据getdata。

第二步:拿到模型和优化器。(模型就是自己的类Mnist_NN)

第三步:执行fit函数。(fit函数的第三个参数表示损失函数是如何计算的,在损失函数计算当中还加入了梯度的更新,第四个使用什么样的优化器去更新我当前的结果)

2、复杂的方法

暂定

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pgtn.cn/news/17759.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

使用Python,OpenCV进行银行支票数字和符号的OCR

使用Python,OpenCV进行银行支票数字和符号的OCR(第一部分)1. 效果图2. 原理2.1 MICR E-13B字体2.2 从MICR E-13B参考图像中提取数字和符号3. 源码3.1 MICR E-13B符号和数字提取3.2 银行支票数字和符号OCR参考上一篇介绍了:使用Pyh…

Jquery php 点击td变成input,修改后失去焦点发送数据

html部分 <Td><?php echo $row[bigclassid]?></Td> <td height"25" width"241" class"bigclassname"><?php echo $row[bigclassname]?></a></td> Js部分 <script> /**//* * 说明&#xff1…

PyTorch框架:(5)使用PyTorch框架构建卷积神经网络

基于pytorch构建一个非常简单的卷积神经网络&#xff0c;以Mnist数据集为例演示基本的流程 1、导工具包 2、读取数据 &#xff08;把该写的超参数全部写出来&#xff09; PS&#xff1a;当前输入图像的大小&#xff0c;注意这里使用卷积网络处理Mnist数据他就不是一个一个像素…

使用Python,OpenCV进行基本的图像处理——提取红色圆圈轮廓并绘制

使用Python&#xff0c;OpenCV进行基本的图像处理——提取红色圆圈轮廓并绘制1. 效果图1.1 形态学图像处理效果图1.2 转换HSV色彩空间提取2. 源码2.1 形态学图像处理提取源码2.2 转换HSV色彩空间提取源码写这篇博客源于博友的提问&#xff0c;想提取图片中的红色圆圈坐标&#…

PyTorch框架:(6)图像识别实战常用模块解读

1、TorchVision 官网&#xff1a;torchvision — Torchvision 0.10.0 documentation 在torchvision这个模块当中&#xff0c;包含了很多后续需要的功能&#xff1a; 需要自己安装这个模块pip install torchvision。安装完之后我们就可以使用这里边的三大核心模块了。 &…

Java 使用itextPdf7操作pdf,写入照片这一篇就够了

Java 使用itextPdf7操作pdf&#xff0c;写入照片这一篇就够了1. 效果图1.1 M*N列图片&#xff08;无边界&有边界&#xff09;1.2 图片重叠1.3 文字背景图片1.4 图片与文字相邻 & 图片文字Rowspan样式1.5 一个单元格多图片 & 多图片文本内容1.6 单元格中文本图片位置…

Flash气泡回弹效果

好久没有碰过Flash了&#xff0c;今天温习一下AS3.0&#xff0c;做了一个回弹效果&#xff0c;气泡回弹本想着怎么可以定义气泡的不同颜色&#xff0c;这样可以做出更绚丽的效果&#xff0c;或者更进步一&#xff0c;气泡和气泡直接回弹&#xff0c;想了老半天没有想出来&#…

Computer Vision Tasks

Computer Vision Tasks: 图像分类、目标检测、语义分割、实例分割&#xff1b; 只有目标检测和实例分割是实现了实例级别的识别的&#xff0c;就是把每一个单独的物体拎出来识别的&#xff1b;目标检测是画框框&#xff0c;而实例分割是抠图。 实例识别&#xff1a;就是把图片…