By Adam Ring
Using a pre-trained deep learning model from a framework such as PyTorch has myriad applications in robotics, from computer vision to speech recognition and many places in between. Sometimes you want to train a model on a system with more powerful hardware and then deploy it on a less powerful system. For this task, it is extremely useful to be able to transfer the weights of your trained model onto another system, such as a virtual machine running Ubuntu 18.04. The methods for model transfer described here will run on any machine with PyTorch installed.
Mixing PyTorch versions between training and deployment is strongly discouraged. If you train your model on PyTorch 1.8.0 and then try to load it using PyTorch 1.4.0, you may encounter errors due to differences in the modules between versions. For this reason, you should load your PyTorch model using the same version that it was trained on.
Let's assume that you have your model fully trained and loaded with all of the necessary weights.
model = MyModel()   # instantiate your model architecture
model.train()       # put the model in training mode, then run your training loop
For instructions on how to train a machine learning model, see this section on training a model in the lab notebook. There are multiple ways to save this model, and I will be covering just a few in this tutorial.
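To keep the snippets in this tutorial concrete, here is a minimal, hypothetical MyModel definition (a tiny fully connected network invented for illustration; substitute your own architecture):
import torch
import torch.nn as nn

class MyModel(nn.Module):
    # A small placeholder network: 4 input features -> 2 outputs.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        )

    def forward(self, x):
        return self.net(x)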
Saving with state_dict
Saving the state_dict is the recommended way to save the weights of your model; however, it does require the model definition and its dependencies to be available wherever you load it. Once you have your model, you must specify a PATH to the directory in which you want to save it. This is also where you choose the name of the file used to store your model.
PATH = "path/to/directory/my_model_state_dict.pt"
or
PATH = "path/to/directory/my_model_state_dict.pth"
You can specify that the state_dict be saved with either the .pt or the .pth extension.
Then, to save the model to a path, simply call this line of code.
torch.save(model.state_dict(), PATH)
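If you are curious about what actually gets written to disk: the state_dict is simply a dictionary mapping each parameter name to its weight tensor, and you can inspect it before saving (a quick sketch using the hypothetical MyModel above):
# The state_dict maps parameter names to weight tensors; only the weights are saved, not the model class.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
torch.save(model.state_dict(), PATH)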
Loading from the state_dict
Download the my_model_state_dict.pt/pth file into the environment in which you plan on deploying it, and note the path where the state dict is placed. In order to load the model weights from the state_dict file, you must first initialize an untrained instance of your model.
loaded_model = MyModel()
Keep in mind that this step requires you to have your model architecture defined in the environment in which you are deploying your model.
Next, you can simply load your model weights from the state dict using this line of code.
loaded_model.load_state_dict(torch.load("path/to/state/dict/my_model_state_dict.pt/pth"))
The trained weights are now loaded into the previously untrained model, and you are ready to use it as if it were pre-trained.
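Putting the loading steps together, here is a minimal sketch for the common case where the model was trained on a GPU machine but is deployed on a CPU-only system (the map_location argument of torch.load remaps the saved tensors to the CPU); remember to switch to evaluation mode before inference. The paths and input shape are assumptions carried over from the examples above.
import torch

loaded_model = MyModel()   # architecture must be defined in the deployment environment
state_dict = torch.load("path/to/state/dict/my_model_state_dict.pt",
                        map_location=torch.device("cpu"))  # remap GPU-saved tensors to CPU
loaded_model.load_state_dict(state_dict)
loaded_model.eval()        # disable dropout/batch-norm training behavior

with torch.no_grad():      # no gradients needed for deployment
    output = loaded_model(torch.ones(1, 4))   # hypothetical input matching MyModel above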
TorchScript is a framework built into PyTorch that is used to deploy models in many different types of environments without having the model class defined in the deployment environment. The effect of this is that you can save a model by tracing it and then load it from the file that the trace generates.
Tracing follows the operations performed on an example input tensor as it is run through your model. Note that if your model has conditionals such as if statements or external dependencies, the trace will not capture them; it records only the operations that were actually executed for the example input. Your model must also operate only on tensors.
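As a small illustration of this limitation (a hypothetical module invented for this example, not part of the tutorial's model), tracing bakes in the branch that was taken for the example input:
import torch
import torch.nn as nn

class Gate(nn.Module):
    # Hypothetical module with a data-dependent branch.
    def forward(self, x):
        if x.sum() > 0:
            return x + 1
        return x - 1

traced = torch.jit.trace(Gate(), torch.ones(3))   # the positive dummy input takes the "+ 1" branch
print(traced(-torch.ones(3)))                     # prints zeros: the trace always adds 1; the "else" branch was never recorded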
In order to trace your trained model and save the trace to a file, you may run the following lines of code.
PATH = "path/to/traced/model/traced_model.pt/pth"
dummy_input = torch.ones(typical_input_size, dtype=dtype_of_typical_input)
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, PATH)
The dummy_input can simply be a bare tensor that is the same size as a typical input for your model. You may also use one of the training or test inputs. The content of the dummy input does not matter, as long as it has the correct size and dtype.
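For example, for the hypothetical MyModel above (which expects batches of 4 features), or for a typical image classifier that takes 224x224 RGB images, the dummy input might look like this (the sizes here are assumptions for illustration only):
dummy_input = torch.ones(1, 4, dtype=torch.float32)             # matches the hypothetical MyModel above
# or, for an image model that expects a 1 x 3 x 224 x 224 batch:
# dummy_input = torch.ones(1, 3, 224, 224, dtype=torch.float32)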
In order to load the trace of a model, you must download the traced model's .pt or .pth file into your deployment environment and note the path to it.
All you need to do to load a traced model for deployment in PyTorch is use the following line of code.
loaded_model = torch.jit.load("path/to/traced/model/traced_model.pt/pth")
Keep in mind that the traced version of your model will only work on torch tensors and will not mimic the behavior of any conditional statements in your model; it always replays the branches that were taken during tracing.
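As a minimal usage sketch (the path and input shape are assumptions carried over from the examples above), calling the loaded trace looks the same as calling the original model:
import torch

loaded_model = torch.jit.load("path/to/traced/model/traced_model.pt")
with torch.no_grad():
    output = loaded_model(torch.ones(1, 4))   # input must match the shape used when tracing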
Please see the full tutorial in the repo: https://github.com/campusrover/Robotics_Computer_Vision/tree/master/utils/labelImg