5. pytorch-项目流程

1. 数据准备

基本步骤: 生成Dataset(或其子类)对象->传入DataLoader(为可迭代对象,可以用for迭代)

1.1 Dataset类

Dataset为抽象类

  • 注意
    • 直接从Dataset中取出的数据是没有经过transform的,只有通过Dataloader加载才可以
      training_data=torchvision.datasets.MNIST(root="./mnist", train=True,
                                 transform=torchvision.transforms.ToTensor(), download=True)
      # 像素点的范围仍然是0-255, 不是0-1
      print(training_data.train_data[0])
      

1.1.1 Dataset子类:TensorDataset

  • 源码阅读
    class TensorDataset(Dataset):
      """Dataset wrapping tensors.
    
      Each sample will be retrieved by indexing tensors along the first dimension.
    
      Arguments:
          *tensors (Tensor): tensors that have the same size of the first dimension.
      代码示例:
          x = torch.linspace(1, 10, 10)
          y = torch.linspace(10, 1, 10)
          dataset = TensorDataset(x, y)
      """
      def __init__(self, *tensors):
          """
          &1
          tensors[0]为x; tensor[1]为y。因为x,y的batch_size要相同,所以要assert
          TensorDataset(x, y, z...)传入任意多参数都是可以的
          """
          assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
          self.tensors = tensors
          
    
      def __getitem__(self, index):
          """
          &2
          相当于重载[]运算符
          """
          return tuple(tensor[index] for tensor in self.tensors)
    
      def __len__(self):
          return self.tensors[0].size(0)
    
  • 示例代码
    import torch
    import torch.utils.data as Data
    
    if __name__ == "__main__":
      x = torch.linspace(1, 10, 10)
      y = torch.linspace(10, 1, 10)
      dataset = Data.TensorDataset(x, y)
      # &1
      # 当最后一个step不足5个(假设仅剩2个),则仅会返回2个
      # shuffle: 训练时为True则打乱数据集
      # num_workers为子进程数量
      dataloader = Data.DataLoader(dataset=dataset, batch_size=5,
                                   shuffle=True, num_workers=2)
      for epoch in range(3):
          for step,input_data in enumerate(dataloader):
              print(f"{epoch}-{step}:\n{input_data}")
    

2. 网络搭建

2.1 class模式

2.2 Sequential模式

net = torch.nn.Sequential(
        torch.nn.Linear(2, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 2)
    )
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 关于Mongodb的全面总结 MongoDB的内部构造《MongoDB The Definitive Guide》...
    中v中阅读 32,038评论 2 89
  • 1. Java基础部分 基础部分的顺序:基本语法,类相关的语法,内部类的语法,继承相关的语法,异常的语法,线程的语...
    子非鱼_t_阅读 31,767评论 18 399
  • 你有经历过这些吗? 从小到大,我们总会遇到一些“命运”眷顾的人。在学校,会有亲戚老师的格外帮助...
    浅若夏沫清阅读 648评论 0 1
  • 郑大主编那打了鸡血似的持续了一周的亢奋之心,被周五那场突如其来的热带暴风雨扑灭。我们从洛溪市新锐传媒大厦出来后,一...
    图革者阅读 940评论 2 45
  • 放纵自己、做个不喜欢自己的人。远比做个让别人认同的人简单的多。 何必委屈自己,何必为别人而活,即使当所有人都放弃了...
    以太_x阅读 138评论 1 0