Notes on Pitfalls When Saving and Loading Models with PyTorch on Multiple GPUs

  • 2021-09-20 21:00:35
  • OfStack

Recently, while training models with PyTorch on a single machine with multiple GPUs, I ran into a number of problems. This article summarizes an approach that has proved practical for running experiments.

When training on multiple GPUs, the model is usually created as follows:


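import os
import torch

os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda  # comma-separated GPU ids, e.g. "0,1"
model = MyModel(args)
if torch.cuda.is_available() and args.use_gpu:
  # Replicate the model across all visible GPUs
  model = torch.nn.DataParallel(model).cuda()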

The officially recommended way to save a model is to save only its parameters:


torch.save(model.module.state_dict(), "model.pkl")
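
To make the parameters-only route concrete, here is a minimal sketch of a full round trip: save from the `DataParallel` wrapper, then reload into a plain model. `TinyNet` and `save_parameters` are hypothetical names used only for illustration (they stand in for `MyModel` and are not part of PyTorch), and the checkpoint is written to a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for MyModel."""

    def __init__(self, in_dim=4, out_dim=2):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(x)


def save_parameters(model, path):
    """Save only the parameters, unwrapping DataParallel if present."""
    inner = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(inner.state_dict(), path)


# Wrap the model the same way as in the training code.
model = nn.DataParallel(TinyNet())
if torch.cuda.is_available():
    model = model.cuda()

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.pkl")
save_parameters(model, ckpt_path)

# Reload into a plain, unwrapped model on the CPU.
restored = TinyNet()
state = torch.load(ckpt_path, map_location="cpu")
restored.load_state_dict(state)
saved_keys = sorted(state.keys())
```

Because the parameters come from `model.module`, the saved keys carry no `module.` prefix, so the file can be loaded on a machine with a different number of GPUs, or none at all.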

In practice this is cumbersome, so I suggest saving the whole model directly (parameters + structure):


torch.save(model, "model.pkl")

This is very practical, especially when we need to rebuild and debug the model repeatedly. Loading also becomes very convenient: because the model structure and the parameters are saved together, we do not need to set the corresponding hyperparameters again or swap in the matching network structure for each model. Loading looks like this:


if args.pretrained_model_path is not None:
  print('load model from %s ...' % args.pretrained_model_path)
  model = torch.load(args.pretrained_model_path)
  print('success!')

Note, however, that a model trained on multiple GPUs and loaded this way only works reliably if the server environment has not changed much, or is the same GPU environment that was used for training.

If the system environment has changed, or we only want to load the model parameters, or we run into one of the following errors:

AttributeError: 'model' object has no attribute 'copy'

Or

AttributeError: 'DataParallel' object has no attribute 'copy'

Or

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found

In these cases we can load the model as follows: first build the model, then load the parameters.


os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
# Build the model
model = MyModel(args)

if torch.cuda.is_available() and args.use_gpu:
  model = torch.nn.DataParallel(model).cuda()

if args.pretrained_model_path is not None:
  print('load model from %s ...' % args.pretrained_model_path)
  # Load the saved (whole) model and extract its parameters
  model_dict = torch.load(args.pretrained_model_path).module.state_dict()
  # Copy the parameters into the freshly built model
  model.module.load_state_dict(model_dict)
  print('success!')
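
A related case is worth sketching: a checkpoint that holds only parameters but was saved from the wrapper itself (`model.state_dict()` rather than `model.module.state_dict()`), so every key starts with `module.`. The sketch below assumes such a file and loads it into an unwrapped model on the CPU. `strip_module_prefix` and `build_model` are hypothetical helpers written for this example, not PyTorch functions.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with the DataParallel prefix removed from its keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


def build_model():
    """Hypothetical small network standing in for MyModel."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


# Simulate a checkpoint saved from the wrapper, i.e. wrapped.state_dict().
wrapped = nn.DataParallel(build_model())
ckpt_path = os.path.join(tempfile.mkdtemp(), "wrapped_params.pkl")
torch.save(wrapped.state_dict(), ckpt_path)

# map_location="cpu" remaps tensors that were stored on a GPU at save time.
raw_state = torch.load(ckpt_path, map_location="cpu")

# Load the cleaned parameters into a plain, unwrapped model.
plain = build_model()
plain.load_state_dict(strip_module_prefix(raw_state))
```

Passing `map_location="cpu"` lets the file be read on a machine without the original GPUs; the model can be moved to a device afterwards if one is available.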

