# NOTE: this listing deliberately has no `if __name__ == '__main__':` guard.
# The DataLoader workers (num_workers=2) are started by module-level code, so
# on platforms that do not use fork to start child processes it raises the
# RuntimeError quoted below. It is kept this way as the reproduction case.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

# Convert PIL images to tensors and normalize every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


def imgshow(img):
    """Display a normalized image tensor in a matplotlib window.

    Args:
        img: Rank-3 image tensor whose first axis is moved last for display
            (presumably channel-first, as produced by ``make_grid``).
    """
    img = img / 2 + 0.5  # undo Normalize(0.5, 0.5): x * std + mean
    npimg = img.numpy()
    # Move the first axis to the end, the layout plt.imshow expects.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


dataiter = iter(trainloader)
# Use the built-in next(); the iterator's .next() method is not available
# on current PyTorch releases.
images, labels = next(dataiter)
imgshow(torchvision.utils.make_grid(images))
# Iterate over the labels actually returned instead of hard-coding the
# batch size; tolist() yields plain ints usable as tuple indices.
print(' '.join(f'{classes[label]:>5}' for label in labels.tolist()))
抄的代码,运行报错:
RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase. This probably means that you are not using fork to start your child processes and you have forgotten to use the proper idiom in the main module: if __name__ == '__main__': freeze_support() ... The "freeze_support()" line can be omitted if the program is not going to be frozen to produce an executable.
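原因:DataLoader 的 num_workers 大于 0 时会启动子进程加载数据。在 Windows 等使用 spawn 方式(而不是 fork)启动子进程的平台上,子进程会重新导入主模块;如果创建并迭代 DataLoader 的代码直接写在模块顶层,子进程导入时会再次尝试启动新进程,于是报出上面的错误。因此这部分代码必须放在 if __name__ == '__main__': 的保护之下。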
解决方法1:
把数据加载和显示的代码放进 main 函数,并在 if __name__ == '__main__': 下调用 main:
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms


def imgshow(img):
    """Display a normalized image tensor in a matplotlib window.

    Args:
        img: Rank-3 image tensor whose first axis is moved last for display
            (presumably channel-first, as produced by ``make_grid``).
    """
    img = img / 2 + 0.5  # undo Normalize(0.5, 0.5): x * std + mean
    npimg = img.numpy()
    # Move the first axis to the end, the layout plt.imshow expects.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def main(argv=None):
    """Load CIFAR-10, show one training batch and print its class names.

    All DataLoader construction and iteration happens inside this function,
    which is only called under the ``__main__`` guard, so worker processes
    that re-import this module do not try to start workers themselves.

    Args:
        argv: Unused; accepted for the conventional ``main(argv)`` signature.

    Returns:
        None, which ``sys.exit`` reports as a successful exit status.
    """
    # Convert PIL images to tensors and normalize every channel with
    # mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=4, shuffle=True, num_workers=2)

    # The test split is prepared exactly as before, although this snippet
    # does not consume it.
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    dataiter = iter(trainloader)
    # Use the built-in next(); the iterator's .next() method is not available
    # on current PyTorch releases.
    images, labels = next(dataiter)
    imgshow(torchvision.utils.make_grid(images))
    # Iterate over the labels actually returned instead of hard-coding the
    # batch size; tolist() yields plain ints usable as tuple indices.
    print(' '.join(f'{classes[label]:>5}' for label in labels.tolist()))


if __name__ == '__main__':
    sys.exit(main())
解决方法2:
num_workers改为0,单进程加载:
num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
参考: