不灭的焱 (The Undying Flame)


Author: php-note.com  Published: 2020-09-19 23:51  Category: Python/Data Analysis

Source tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 and std 0.5 maps each RGB channel to [-1, 1]
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

# num_workers=2: each DataLoader loads batches in two worker subprocesses
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)

testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

def imgshow(img):
    img=img/2+0.5  # undo the normalization: x = y * 0.5 + 0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))  # CHW -> HWC, the layout matplotlib expects
    plt.show()

# Module-level code with no __main__ guard: iter() starts the worker processes here
dataiter=iter(trainloader)
images,labels=next(dataiter)  # recent PyTorch releases removed the iterator's .next() method
imgshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

I copied this code as-is, and running it fails with:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.
 
        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:
 
            if __name__ == '__main__':
                freeze_support()
                ...
 
        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

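Cause: with num_workers > 0, the DataLoader loads data in worker subprocesses. On Windows (and on macOS since Python 3.8), multiprocessing uses the spawn start method, which starts each child by re-importing the main script. Because the script above creates and iterates the DataLoader at module level, every child tries to start its own workers while it is still bootstrapping, which raises this error. Code that starts processes must therefore only run under if __name__ == '__main__':.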
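The same rule can be seen without PyTorch. The sketch below uses only the standard library and forces the spawn start method with multiprocessing.get_context("spawn"), so it behaves the same on every OS; the names square and run_with_spawn are illustrative only. Save it as a .py file and run it directly, since spawn has to re-import the script.

```python
import multiprocessing as mp


def square(x: int) -> int:
    # Defined at module top level so a spawned child can import it by name.
    return x * x


def run_with_spawn(values):
    # Force the "spawn" start method (the default on Windows) so the
    # behaviour is identical on every platform.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        return pool.map(square, values)


# Each spawned child re-imports this file under the name "__mp_main__",
# so this block runs only in the parent. Without the guard, every child
# would try to start its own pool and fail with the bootstrapping RuntimeError.
if __name__ == "__main__":
    print(run_with_spawn([1, 2, 3, 4]))  # [1, 4, 9, 16]
```

Moving the print(run_with_spawn(...)) call out of the guard to module level reproduces the same kind of RuntimeError shown above.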

Solution 1:

Move the code into a main function and call it from the if __name__ == '__main__': block:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
 
def imgshow(img):
    img=img/2+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()
 
def main(argv=None):
    transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
    trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
    trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
    testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
    testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)
    classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
    # The worker processes are started here, inside main(), not at import time
    dataiter=iter(trainloader)
    images,labels=next(dataiter)  # use the built-in next(); .next() is gone in recent PyTorch
    imgshow(torchvision.utils.make_grid(images))
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

# True only when this file is run directly, not when a spawned worker re-imports it
if __name__=='__main__':
    sys.exit(main())
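To check the guarded layout without downloading CIFAR-10, here is a minimal sketch that keeps num_workers=2 but swaps in a small synthetic TensorDataset (an assumption for illustration, not part of the tutorial); build_loader is a hypothetical helper name.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loader(num_workers: int = 2) -> DataLoader:
    # Synthetic stand-in for CIFAR-10 (assumption): 8 random 3x32x32 images
    # with integer labels 0-7.
    dataset = TensorDataset(torch.randn(8, 3, 32, 32), torch.arange(8))
    return DataLoader(dataset, batch_size=4, shuffle=False, num_workers=num_workers)


def main() -> int:
    loader = build_loader(num_workers=2)
    # Iterating starts the worker processes; this only happens inside main(),
    # so a worker that re-imports this file never reaches this line.
    total = sum(batch_labels.numel() for _, batch_labels in loader)
    print(total)  # 8
    return 0


if __name__ == "__main__":
    main()
```

The structure mirrors Solution 1: everything that starts processes lives inside a function that only the guarded block calls.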

Solution 2:

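Set num_workers to 0 (for both trainloader and testloader) so the data is loaded in the main process and no subprocesses are started. From the DataLoader documentation: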

num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
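A minimal sketch of this option, assuming a tiny synthetic TensorDataset in place of CIFAR-10 so it runs without a download. With num_workers=0 no worker processes exist, so this works even at module level with no __main__ guard, at the cost of loading batches serially in the main process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for CIFAR-10 (assumption): 10 random 3x32x32 images
# with integer labels 0-9.
images = torch.randn(10, 3, 32, 32)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# num_workers=0: batches are produced in the main process, so no
# subprocesses are started and no __main__ guard is needed.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([4, 3, 32, 32])
print(len(loader))  # 3 (batches of 4, 4 and 2)
```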

 

 
