Managing Models & Releases (ResNet-50 Example with PyTorch)
The Bailo Python client enables intuitive interaction with the Bailo service from within a Python environment. This example notebook runs through the following concepts:
Creating a new model on Bailo.
Creating and populating a model card.
Retrieving models from the service.
Making changes to the model, and model card.
Creating and managing specific releases, with files.
Prerequisites:
Python 3.8.1 or higher (including a notebook environment for this demo).
A local or remote Bailo service (see https://github.com/gchq/Bailo).
Introduction
The Bailo Python client is split into two sub-packages: core and helper.
Core: For direct interactions with the service endpoints.
Helper: For more intuitive interactions with the service, using classes (e.g. Model) to handle operations.
In order to create helper classes, you will first need to instantiate a Client()
object from the core. By default, this object will not support any authentication. However, Bailo also supports PKI authentication, which you can use from Python by passing a PkiAgent()
object into the Client()
object when you instantiate it.
IMPORTANT: Select the relevant pip install command based on your environment.
[ ]:
# LINUX - CPU
! pip install bailo torch torchvision --index-url https://download.pytorch.org/whl/cpu
# MAC & WINDOWS - CPU
#! pip install bailo torch torchvision
[ ]:
# Necessary import statements
from bailo import Model, Client
import torch
from torchvision.models import resnet50, ResNet50_Weights
# Instantiating the PkiAgent(), if using (this also requires `from bailo import PkiAgent`).
# agent = PkiAgent(cert='', key='', auth='')
# Instantiating the Bailo client
client = Client("http://127.0.0.1:8080") # <- INSERT BAILO URL (if not hosting locally)
Creating a new ResNet-50 model in Bailo
Creating and updating the base model
In this section, we’ll create a new model representing ResNet-50 using the Model.create()
classmethod. On the Bailo service, a model must consist of at least 4 parameters upon creation. These are name, description, visibility and team_id. Other attributes like model cards, files, or releases are added later on. Below, we use the Client()
object created before when instantiating the new Model()
object.
NOTE: This creates the model on your Bailo service too! The model_id
is assigned by the backend, and we will use this later to retrieve the model.
[ ]:
model = Model.create(client=client, name="ResNet-50", description="ResNet-50 model for image classification.", team_id="uncategorised")
model_id = model.model_id
You may make changes to these attributes and then call the update()
method to relay the changes to the service, as below:
model.name = "New Name"
model.update()
Creating and populating a model card
When creating a model card, first we need to generate an empty one using the card_from_schema()
method. In this instance, we will use minimal-general-v10. You can manage custom schemas using the Schema()
helper class, but this is out of scope for this demo.
[ ]:
model.card_from_schema(schema_id='minimal-general-v10')
print(f"Model card version is {model.model_card_version}.")
Creating and populating a new model card with a template
When creating a model card from a template, we use a pre-existing model card as the template. First we create a new model; then, to create its model card, we call the card_from_template
method and pass the ID of our chosen template model.
[ ]:
model2 = Model.create(
client=client, name="ResNet-50", description="ResNet-50 model for image classification.", team_id="uncategorised"
)
model2_id = model2.model_id
model2.card_from_template(model.model_id)
print(f"Model name: {model2.name}")
If successful, the above will have created a new model card, and the model_card_version
attribute should be set to 1.
Next, we can populate the model card using the update_model_card()
method. This can be used any time you want to make changes, and the backend will create a new model card version each time. We’ll learn how to retrieve model cards later (either the latest, or a specific release).
NOTE: Your model card must match the schema, otherwise an error will be thrown.
[ ]:
new_card = {
'overview': {
'tags': [],
'modelSummary': 'ResNet-50 model for image classification.',
}
}
model.update_model_card(model_card=new_card)
print(f"Model card version is {model.model_card_version}.")
If successful, the model_card_version
will now be 2!
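Because a card that does not match the schema is rejected, it can help to check the dictionary locally before calling update_model_card(). The following is a minimal sketch only: REQUIRED is a hypothetical layout that mirrors the keys used in new_card above, not the real minimal-general-v10 schema, and missing_fields() is an illustrative helper rather than part of the Bailo client.

```python
def missing_fields(card, required, prefix=""):
    """Return dotted paths that appear in `required` but not in `card`."""
    missing = []
    for key, nested in required.items():
        path = f"{prefix}{key}"
        if not isinstance(card, dict) or key not in card:
            missing.append(path)
        elif nested:
            # Recurse into nested sections such as 'overview'.
            missing.extend(missing_fields(card[key], nested, prefix=f"{path}."))
    return missing


# Hypothetical required layout (assumption for illustration only).
REQUIRED = {"overview": {"modelSummary": {}, "tags": {}}}

good_card = {
    "overview": {
        "tags": [],
        "modelSummary": "ResNet-50 model for image classification.",
    }
}
bad_card = {"overview": {"tags": []}}

good_result = missing_fields(good_card, REQUIRED)
bad_result = missing_fields(bad_card, REQUIRED)
```

An empty list means every expected key is present; the service remains the authority on whether a card is valid.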
Retrieving an existing model
Using the .from_id() method
In this section, we’ll retrieve our previous model using the Model.from_id()
classmethod. This will create your Model()
object as before, but using existing information retrieved from the service.
[ ]:
model = Model.from_id(client=client, model_id=model_id)
print(f"Model description: {model.description}")
If successful, the model description we set earlier should be displayed above.
Creating and managing releases for models
On the Bailo service, different versions of the same model are managed using releases. Generally, this is for code changes and minor adjustments that don’t drastically change the behaviour of a model. In this section we will create a release and upload a file.
Creating a release
Release()
is a separate helper class in itself, but we can use our Model()
object to create and retrieve releases. Running the code below will create a new release of the model and return an instantiated Release()
object, which we will use to upload files.
[ ]:
release_one = model.create_release(version='1.0.0', notes='Initial model weights.')
Preparing the model weights using PyTorch
In order to upload the ResNet50 model to Bailo, we must first retrieve the weights from PyTorch and save them. The torch.save()
method accepts either a file path or a file-like object such as BytesIO. Here we write the weights to a local file, resnet50_weights.pth,
which we will pass to the release in the next step.
[ ]:
torch_model = resnet50(weights=ResNet50_Weights.DEFAULT)
torch.save(torch_model.state_dict(), 'resnet50_weights.pth')
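If you would rather not use local disk space, the same serialisation can target an in-memory buffer. The sketch below uses pickle and a small dictionary as stand-ins for torch.save() and the real state dict (both are assumptions, chosen so the example runs without PyTorch). torch.save() accepts a file-like object in the same position as the buffer here.

```python
import io
import pickle

# Stand-in for torch_model.state_dict(); hypothetical values for illustration.
fake_state_dict = {"layer1.weight": [0.1, 0.2], "layer1.bias": [0.0]}

buffer = io.BytesIO()
pickle.dump(fake_state_dict, buffer)  # with PyTorch: torch.save(state_dict, buffer)
size_written = buffer.tell()  # number of bytes now held in memory

# Rewind so that the next reader starts at the first byte.
buffer.seek(0)

restored = pickle.load(buffer)
```

Forgetting the seek(0) is the usual pitfall: a consumer that reads from the current position would otherwise see an empty stream.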
Uploading weights to the release
To upload files for a release, we can use the release upload()
method, which takes the path of the file to upload.
In this case, we’re using the resnet50_weights.pth we prepared in the last step.
NOTE: The upload()
method can also take a BytesIO
type containing the file contents, to allow for other integrations, such as with S3.
[ ]:
release_one.upload(path="resnet50_weights.pth")
Retrieving a release
We can retrieve the latest release for our ResNet-50 model using the model get_latest_release()
method. Alternatively, we can retrieve a specific release using the model get_release()
method. Both of these will return an instantiated Release()
object.
[ ]:
release_latest = model.get_latest_release()
release_one = model.get_release(version='1.0.0')
# To demonstrate this is the same release:
if release_latest == release_one:
    print("Successfully retrieved identical releases!")
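This notebook does not spell out how the "latest" release is chosen. Assuming releases are ordered by semantic version (an assumption, not confirmed here), the sketch below shows why versions should be compared numerically rather than as plain strings. parse_version() is an illustrative helper, not part of the Bailo client.

```python
def parse_version(version):
    """Split 'MAJOR.MINOR.PATCH' into a tuple of ints for numeric comparison."""
    major, minor, patch = version.split(".")
    return int(major), int(minor), int(patch)


versions = ["1.0.0", "1.2.0", "1.10.0"]

# Numeric comparison treats 10 as greater than 2.
latest_numeric = max(versions, key=parse_version)

# Plain string comparison goes character by character, so "2" beats "1".
latest_as_text = max(versions)
```

Here latest_numeric is "1.10.0" while latest_as_text is "1.2.0", so consistent MAJOR.MINOR.PATCH version strings matter when creating releases.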
Downloading weights from the release
Similarly, you can download specific files from a release using the download()
method. In this case, we’ll write them to a new file: bailo_resnet50_weights.pth. NOTE: filename
refers to the filename on Bailo, and path
is the local destination for your download.
In addition to this, you can also use the download_all()
method by providing a local directory path as path
. By default, this will download all files, but you can provide include
and exclude
lists, e.g. include=["*.txt", "*.json"]
to only include TXT or JSON files.
[ ]:
#release_latest.download(filename="resnet50_weights.pth", path="bailo_resnet50_weights.pth")
release_latest.download_all(path="downloads")
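To get a feel for how include and exclude lists narrow a download, the sketch below filters a list of file names with shell-style patterns from Python's fnmatch module. How Bailo applies these patterns internally is not shown in this notebook, so treat select_files() as a hypothetical illustration of the idea rather than the client's implementation.

```python
from fnmatch import fnmatch


def select_files(names, include=None, exclude=None):
    """Keep names matching any include pattern and no exclude pattern."""
    selected = []
    for name in names:
        if include and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        selected.append(name)
    return selected


# Hypothetical file names attached to a release.
files = ["resnet50_weights.pth", "README.txt", "config.json", "notes.txt"]

only_text = select_files(files, include=["*.txt", "*.json"])
no_weights = select_files(files, exclude=["*.pth"])
```

With no patterns supplied, every file is kept, which mirrors the default behaviour described above.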
Loading the model using PyTorch
Finally, now that we’ve retrieved the ResNet-50 weights from our Bailo release, we can load them using the torch library.
[ ]:
weights = torch.load("downloads/resnet50_weights.pth")
torch_model = resnet50()
torch_model.load_state_dict(weights)
If the message “All keys matched successfully” is displayed, we have successfully initialised our model.
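As an extra sanity check on the upload and download round trip, you can compare checksums of the original and downloaded files. This is an optional sketch and not part of the Bailo client: sha256_of() is an illustrative helper, and the two temporary files stand in for resnet50_weights.pth and its downloaded copy.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large weight files are not read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Temporary files standing in for the uploaded and downloaded weights.
with tempfile.TemporaryDirectory() as tmp:
    original = Path(tmp) / "resnet50_weights.pth"
    downloaded = Path(tmp) / "bailo_resnet50_weights.pth"
    original.write_bytes(b"example weights")
    downloaded.write_bytes(b"example weights")
    files_match = sha256_of(original) == sha256_of(downloaded)
```

In practice you would point sha256_of() at the real local file and the file written by download() or download_all().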
Searching for models
In addition to fetching specific models, you can also use the Model.search()
method to return a list of Model()
objects that match your parameters. These parameters can be:
Task of the model (e.g. image classification).
Libraries used for the model (e.g. PyTorch).
Model card search (string to be found in model cards).
In the below example, we’ll just search for all models with no filters.
[ ]:
models = Model.search(client=client)
print(models)
We should now have a list of Model()
objects.