不灭的火

Author: AlbertWen  Added: 2020-09-19 23:55:43  Modified: 2025-11-10 19:30:53  Category: 22.Python Programming

The code below is copied from the PyTorch CIFAR-10 tutorial:

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
 
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
 
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
 
testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)
 
classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
 
def imgshow(img):
    img=img/2+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()
 
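# iter() starts the num_workers=2 worker processes. Under the spawn start method each
# worker re-runs this unguarded top-level code, which triggers the RuntimeError below.
dataiter=iter(trainloader)
images,labels=next(dataiter)  # recent PyTorch iterators have no .next() method; use the built-in next()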
imgshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
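A note on imgshow: img/2+0.5 undoes the normalization in transform. ToTensor() scales each pixel channel to a value x in [0, 1], and Normalize with mean 0.5 and std 0.5 maps it to y in [-1, 1]. Solving for x gives the inverse that imgshow applies before plotting:

$$
y = \frac{x - 0.5}{0.5} = 2x - 1 \qquad\Longrightarrow\qquad x = \frac{y}{2} + 0.5
$$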

Running the copied code fails with:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.
 
        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:
 
            if __name__ == '__main__':
                freeze_support()
                ...
 
        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

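Cause: with num_workers=2, the DataLoader loads data in worker subprocesses. Where multiprocessing uses the spawn start method (the default on Windows, and on macOS since Python 3.8), every child process starts a fresh interpreter and re-runs the main script's top-level code. In the script above, that top-level code creates the DataLoader iterator again, so each child tries to start its own workers before it has finished starting up. Code that starts worker processes must therefore run only in the main process, behind an if __name__ == '__main__': guard.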
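To see why the guard matters without starting any processes, the sketch below imitates what a spawned child does. It assumes CPython's current behavior, where the spawn start method re-runs the parent's main script under the module name __mp_main__ (an implementation detail). Top-level code runs again, while code under if __name__ == '__main__': is skipped. SCRIPT and run_as are hypothetical names made up for this illustration.

```python
import os
import runpy
import tempfile
from typing import List

# Hypothetical stand-in for the tutorial script: one top-level statement
# and one statement behind the __main__ guard.
SCRIPT = """
events.append("top-level code ran")
if __name__ == "__main__":
    events.append("guarded code ran")
"""


def run_as(module_name: str) -> List[str]:
    """Execute SCRIPT from a real file under the given __name__ and report what ran."""
    events: List[str] = []
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo_script.py")
        with open(path, "w") as f:
            f.write(SCRIPT)
        # init_globals injects the shared list; run_name sets the script's __name__
        runpy.run_path(path, init_globals={"events": events}, run_name=module_name)
    return events


if __name__ == "__main__":
    print(run_as("__main__"))     # ['top-level code ran', 'guarded code ran']
    print(run_as("__mp_main__"))  # ['top-level code ran']
```

In the original script, the DataLoader code sits at top level, so it falls into the "top-level code ran" case and executes again inside every worker.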

Solution 1:

Move the code into a main() function and call it only from the if __name__ == '__main__': block:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
 
def imgshow(img):
    img=img/2+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()
 
def main(argv=None):
    transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
    trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
    trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
    testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
    testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)
    classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
    dataiter=iter(trainloader)
    images,labels=next(dataiter)  # recent PyTorch iterators have no .next() method; use the built-in next()
    imgshow(torchvision.utils.make_grid(images))
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
 
if __name__=='__main__':
    sys.exit(main())
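The same idiom applies to any code that starts processes, not only DataLoader. The sketch below uses only the standard library. It requests the spawn start method explicitly, so it behaves like Windows on every platform. It keeps the worker function at module top level so the children can import it, and it starts the pool only inside main(), behind the guard. square and main are illustrative names, not part of the tutorial.

```python
import multiprocessing as mp
from typing import List


def square(x: int) -> int:
    # Worker function: defined at module top level so spawned children can import it by name
    return x * x


def main() -> List[int]:
    # Explicitly use "spawn", the default start method on Windows
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        return pool.map(square, range(5))


if __name__ == "__main__":
    # Only the parent reaches this line; spawned children re-import the module as __mp_main__
    print(main())  # [0, 1, 4, 9, 16]
```

Moving the print(main()) call out of the guard should make each child print the same RuntimeError as above, because every child would try to create its own pool while still bootstrapping.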

Solution 2:

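Set num_workers to 0 in both DataLoader calls (trainloader and testloader). The data is then loaded in the main process and no worker subprocesses are started, so the guard is no longer required. The DataLoader documentation describes the parameter as: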

num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
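The trade-off behind Solution 2 can be shown with a toy loader that follows the same num_workers convention as the docstring above. load_all and fake_sample are hypothetical helpers written for illustration; this is not PyTorch's implementation. With num_workers=0, everything runs in the calling process, so no guard is needed. The cost is that loading is no longer done in parallel by background workers.

```python
import multiprocessing as mp
from typing import Callable, List, Sequence


def load_all(load_one: Callable[[int], int], indices: Sequence[int], num_workers: int = 0) -> List[int]:
    """Hypothetical toy loader mirroring DataLoader's num_workers convention."""
    if num_workers < 0:
        raise ValueError("num_workers must be >= 0")
    if num_workers == 0:
        # Load in the calling process: no child processes start, so no __main__ guard is required
        return [load_one(i) for i in indices]
    # num_workers > 0: spawn worker processes, which requires the caller to sit behind an
    # if __name__ == '__main__': guard and load_one to be a picklable top-level function
    with mp.get_context("spawn").Pool(processes=num_workers) as pool:
        return pool.map(load_one, indices)


def fake_sample(i: int) -> int:
    # Stand-in for reading and decoding one sample from disk
    return i * 10


# Safe at module top level, because num_workers=0 never starts a process
print(load_all(fake_sample, range(4)))  # [0, 10, 20, 30]
```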

 

 

References:

https://blog.csdn.net/yaoyutian/article/details/85086129

https://blog.csdn.net/jacke121/article/details/81456842