from __future__ import annotations
import logging
import os
import shutil
import tempfile
import warnings
from typing import Any
from semantic_version import Version
# isort: split
from bailo.core.client import Client
from bailo.core.enums import EntryKind, MinimalSchema, ModelVisibility
from bailo.core.exceptions import BailoException
from bailo.core.utils import NestedDict
from bailo.helper.entry import Entry
from bailo.helper.release import Release
# MLFlow is an optional dependency: ML_FLOW records whether it could be
# imported so that MLFlow-specific methods can fail with a clear message.
ML_FLOW = True
try:
    import mlflow
except ImportError:
    ML_FLOW = False

# Module-level logger used by the helpers defined in this file.
logger = logging.getLogger(__name__)
[docs]
class Model(Entry):
    """Represent a model within Bailo.

    :param client: A client object used to interact with Bailo
    :param model_id: A unique ID for the model
    :param name: Name of model
    :param description: Description of model
    :param organisation: Organisation responsible for the model, defaults to None
    :param state: Development readiness of the model, defaults to None
    :param visibility: Visibility of model, using ModelVisibility enum (e.g Public or Private), defaults to None
    """

    def __init__(
        self,
        client: Client,
        model_id: str,
        name: str,
        description: str,
        organisation: str | None = None,
        state: str | None = None,
        visibility: ModelVisibility | None = None,
    ) -> None:
        super().__init__(
            client=client,
            id=model_id,
            name=name,
            description=description,
            kind=EntryKind.MODEL,
            visibility=visibility,
            organisation=organisation,
            state=state,
        )
        self.model_id = model_id

    @classmethod
    def create(
        cls,
        client: Client,
        name: str,
        description: str,
        organisation: str | None = None,
        state: str | None = None,
        visibility: ModelVisibility | None = None,
    ) -> Model:
        """Build a model from Bailo and upload it.

        :param client: A client object used to interact with Bailo
        :param name: Name of model
        :param description: Description of model
        :param organisation: Organisation responsible for the model, defaults to None
        :param state: Development readiness of the model, defaults to None
        :param visibility: Visibility of model, using ModelVisibility enum (e.g Public or Private), defaults to None
        :return: Model object
        """
        res = client.post_model(
            name=name,
            kind=EntryKind.MODEL,
            description=description,
            visibility=visibility,
            organisation=organisation,
            state=state,
        )
        model_id = res["model"]["id"]
        logger.info("Model successfully created on server with ID %s.", model_id)

        model = cls(
            client=client,
            model_id=model_id,
            name=name,
            description=description,
            visibility=visibility,
            organisation=organisation,
            state=state,
        )
        model._unpack(res["model"])

        return model

    @classmethod
    def from_id(cls, client: Client, model_id: str) -> Model:
        """Return an existing model from Bailo.

        :param client: A client object used to interact with Bailo
        :param model_id: A unique model ID
        :raises BailoException: If the ID belongs to an entry that is not a model
        :return: A model object
        """
        res = client.get_model(model_id=model_id)["model"]
        if res["kind"] != "model":
            raise BailoException(f"ID {model_id} does not belong to a model. Did you mean to use Datacard.from_id()?")

        logger.info("Model %s successfully retrieved from server.", model_id)

        model = cls(
            client=client,
            model_id=model_id,
            name=res["name"],
            description=res["description"],
            organisation=res.get("organisation"),
            state=res.get("state"),
        )
        model._unpack(res)
        model.get_card_latest()

        return model

    @classmethod
    def search(
        cls,
        client: Client,
        task: str | None = None,
        libraries: list[str] | None = None,
        filters: list[str] | None = None,
        search: str = "",
    ) -> list[Model]:
        """Return a list of model objects from Bailo, based on search parameters.

        :param client: A client object used to interact with Bailo
        :param task: Model task (e.g. image classification), defaults to None
        :param libraries: Model library (e.g. TensorFlow), defaults to None
        :param filters: Custom filters, defaults to None
        :param search: String to be located in model cards, defaults to ""
        :return: List of model objects
        """
        res = client.get_models(task=task, libraries=libraries, filters=filters, search=search)

        models = []
        for model in res["models"]:
            # The search hit is only a summary, so the full entry is fetched per model.
            res_model = client.get_model(model_id=model["id"])["model"]
            model_obj = cls(
                client=client,
                model_id=model["id"],
                name=model["name"],
                description=model["description"],
                # Read from the individual model payload, not the search response
                # wrapper (which only holds the "models" list).
                organisation=res_model.get("organisation"),
                state=res_model.get("state"),
            )
            model_obj._unpack(res_model)
            model_obj.get_card_latest()
            models.append(model_obj)

        return models

    @classmethod
    def from_mlflow(
        cls,
        client: Client,
        mlflow_uri: str,
        name: str,
        schema_id: str = MinimalSchema.MODEL,
        version: str | None = None,
        files: bool = True,
        visibility: ModelVisibility | None = None,
        organisation: str | None = None,
        state: str | None = None,
    ) -> Model:
        """Import an MLFlow Model into Bailo.

        :param client: A client object used to interact with Bailo
        :param mlflow_uri: MLFlow server URI
        :param name: Name of model (on MLFlow). Same name will be used on Bailo
        :param schema_id: A unique schema ID, only used when files is True, defaults to MinimalSchema.MODEL
        :param version: Specific MLFlow model version to import, defaults to None (first result, ordered by
            version number descending)
        :param files: Import files?, defaults to True
        :param visibility: Visibility of model on Bailo, using ModelVisibility enum (e.g Public or Private), defaults to None
        :param organisation: Organisation responsible for the model, defaults to None
        :param state: Development readiness of the model, defaults to None
        :raises ImportError: If the optional MLFlow dependency is not installed
        :raises BailoException: If no matching MLFlow model version is found, or (when files is True) the
            model has no run ID or artifact URI
        :return: A model object
        """
        if not ML_FLOW:
            raise ImportError("Optional MLFlow dependencies (needed for this method) are not installed.")

        mlflow_client = mlflow.tracking.MlflowClient(tracking_uri=mlflow_uri)  # type: ignore[reportPrivateImportUsage]
        mlflow.set_tracking_uri(mlflow_uri)
        filter_string = f"name = '{name}'"

        res = mlflow_client.search_model_versions(filter_string=filter_string, order_by=["version_number DESC"])
        if not res:
            raise BailoException("No MLFlow models found. Are you sure the name/alias/version provided is correct?")

        if version:
            # Compare as strings so that a version reported as an int still
            # matches a version requested as a str (and vice versa).
            requested = str(version)
            sel_model = next((candidate for candidate in res if str(candidate.version) == requested), None)
        else:
            sel_model = res[0]

        if sel_model is None:
            raise BailoException("No MLFlow model found. Are you sure the name/alias/version provided is correct?")

        name = sel_model.name
        # Avoid producing the literal text "None ..." when MLFlow holds no description.
        source_description = sel_model.description
        if source_description:
            description = f"{source_description} Imported from MLFlow."
        else:
            description = "Imported from MLFlow."

        bailo_res = client.post_model(
            name=name,
            kind=EntryKind.MODEL,
            description=description,
            visibility=visibility,
            organisation=organisation,
            state=state,
        )
        model_id = bailo_res["model"]["id"]
        logger.info("MLFlow model successfully imported to Bailo with ID %s", model_id)

        model = cls(
            client=client,
            model_id=model_id,
            name=name,
            description=description,
            visibility=visibility,
            organisation=organisation,
            state=state,
        )
        model._unpack(bailo_res["model"])

        if files:
            # A release requires a model card, so one is created from the schema first.
            model.card_from_schema(schema_id=schema_id)
            release = model.create_release(version=Version.coerce(str(sel_model.version)), notes=" ")

            run_id = sel_model.run_id
            if run_id is None:
                raise BailoException(
                    "MLFlow model does not have an associated run_id, therefore artifacts cannot be transferred."
                )

            mlflow_run = mlflow_client.get_run(run_id)
            # Check for a missing URI *before* converting to str; str(None) is
            # the truthy string "None" and would hide the problem.
            raw_artifact_uri = mlflow_run.info.artifact_uri
            if raw_artifact_uri is None:
                raise BailoException("Artifact URI could not be found, therefore artifacts cannot be transferred.")
            artifact_uri: str = str(raw_artifact_uri)

            if mlflow.artifacts.list_artifacts(artifact_uri=artifact_uri) is not None:  # type: ignore[reportPrivateImportUsage]
                temp_dir = os.path.join(tempfile.gettempdir(), "mlflow_model")
                mlflow_dir = os.path.join(temp_dir, f"mlflow_{run_id}")
                try:
                    mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri, dst_path=mlflow_dir)  # type: ignore[reportPrivateImportUsage]
                    release.upload(mlflow_dir)
                finally:
                    # The download is only a staging copy; do not leave it behind.
                    shutil.rmtree(mlflow_dir, ignore_errors=True)

        return model

    def update_model_card(self, model_card: dict[str, Any] | None = None) -> None:
        """Upload and retrieve any changes to the model card on Bailo.

        :param model_card: Model card dictionary, defaults to None

        ..note:: If a model card is not provided, the current model card attribute value is used
        """
        self._update_card(card=model_card)

    def create_experiment(
        self,
    ) -> Experiment:
        """Create an experiment locally

        :return: An experiment object
        """
        return Experiment.create(model=self)

    def create_release(
        self,
        version: Version | str,
        notes: str,
        files: list[str] | None = None,
        images: list[str] | None = None,
        minor: bool = False,
        draft: bool = True,
    ) -> Release:
        """Call the Release.create method to build a release from Bailo and upload it.

        :param version: A semantic version for the release
        :param notes: Notes on release
        :param files: A list of files for release, defaults to None
        :param images: A list of images for release, defaults to None
        :param minor: Is a minor release?, defaults to False
        :param draft: Is a draft release?, defaults to True
        :raises BailoException: If the model has no model card schema yet
        :return: Release object
        """
        if not self.model_card_schema:
            raise BailoException("Create a model card before creating a release")

        return Release.create(
            client=self.client,
            model_id=self.model_id,
            version=version,
            notes=notes,
            model_card_version=self.model_card_version,
            files=files,
            images=images,
            minor=minor,
            draft=draft,
        )

    def get_releases(self) -> list[Release]:
        """Get all releases for the model.

        :return: List of Release objects
        """
        res = self.client.get_all_releases(model_id=self.model_id)
        releases = [self.get_release(version=release["semver"]) for release in res["releases"]]

        logger.info("Successfully retrieved all releases for model %s.", self.model_id)

        return releases

    def get_release(self, version: Version | str) -> Release:
        """Call the Release.from_version method to return an existing release from Bailo.

        :param version: A semantic version for the release
        :return: Release object
        """
        return Release.from_version(self.client, self.model_id, version)

    def get_latest_release(self) -> Release:
        """Get the latest release for the model from Bailo.

        :raises BailoException: If the model has no releases
        :return: Release object
        """
        releases = self.get_releases()
        if not releases:
            raise BailoException("This model has no releases.")

        # Compute the maximum once and return that same object.
        latest_release = max(releases)
        logger.info(
            "latest_release (%s) for %s retrieved successfully.",
            str(latest_release.version),
            self.model_id,
        )

        return latest_release

    def get_images(self):
        """Get all model image references for the model.

        :return: List of images
        """
        res = self.client.get_all_images(model_id=self.model_id)

        logger.info("Images for %s retrieved successfully.", self.model_id)

        return res["images"]

    def get_image(self):
        """Get a model image reference.

        :raises NotImplementedError: Not implemented error.
        """
        raise NotImplementedError

    # The properties below expose the card attributes (``_card``,
    # ``_card_version``, ``_card_schema``) under model-specific names.
    @property
    def model_card(self):
        return self._card

    @model_card.setter
    def model_card(self, value):
        self._card = value

    @property
    def model_card_version(self):
        return self._card_version

    @model_card_version.setter
    def model_card_version(self, value):
        self._card_version = value

    @property
    def model_card_schema(self):
        return self._card_schema

    @model_card_schema.setter
    def model_card_schema(self, value):
        self._card_schema = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)})"

    def __str__(self) -> str:
        return f"{self.model_id}"
[docs]
class Experiment:
    """Represent an experiment locally.

    :param model: A Bailo model object which the experiment is being run on

    The ``raw`` attribute holds the raw information about the experiment runs
    (one dictionary per run).

    .. code-block:: python

       experiment = model.create_experiment()
       for x in range(5):
           experiment.start_run()
           experiment.log_params({"lr": 0.01})
           ### INSERT MODEL TRAINING HERE ###
           experiment.log_metrics({"accuracy": 0.86})
           experiment.log_artifacts(["weights.pth"])

       experiment.publish(mc_loc="performance.performanceMetrics", run_id=1)
    """

    def __init__(
        self,
        model: Model,
    ):
        self.model = model
        # One dictionary per run, in the order the runs were started.
        self.raw: list[dict[str, Any]] = []
        # Index of the current local run; -1 means no run has been started yet.
        self.run = -1
        # Staging directory for artifacts downloaded from MLFlow; removed on publish.
        self.temp_dir = os.path.join(tempfile.gettempdir(), "bailo_runs")
        self.published = False
        # The dictionary of the run currently being logged to (also stored in ``raw``).
        self.run_data: dict[str, Any] = {}

    @classmethod
    def create(
        cls,
        model: Model,
    ) -> Experiment:
        """Create an experiment locally.

        :param model: A Bailo model object which the experiment is being run on
        :return: Experiment object
        """
        return cls(model=model)

    def start_run(self, is_mlflow: bool = False):
        """Starts a new experiment run.

        :param is_mlflow: Marks a run as MLFlow
        """
        self.run += 1

        self.run_data = {
            "run": self.run,
            "params": [],
            "metrics": [],
            "artifacts": [],
            "dataset": "",
        }
        self.raw.append(self.run_data)

        if not is_mlflow:
            logger.info("Bailo tracking run %s.", self.run)

    def log_params(self, params: dict[str, Any]):
        """Logs parameters to the current run.

        :param params: Dictionary of parameters to be logged
        """
        self.run_data["params"].append(params)

    def log_metrics(self, metrics: dict[str, Any]):
        """Logs metrics to the current run.

        :param metrics: Dictionary of metrics to be logged
        """
        self.run_data["metrics"].append(metrics)

    def log_artifacts(self, artifacts: list):
        """Logs artifacts to the current run.

        :param artifacts: A list of artifact paths to be logged
        """
        self.run_data["artifacts"].extend(artifacts)

    def log_dataset(self, dataset: str):
        """Logs a dataset to the current run.

        :param dataset: Arbitrary title of dataset
        """
        self.run_data["dataset"] = dataset

    def from_mlflow(self, tracking_uri: str, experiment_id: str):
        """Imports information from an MLFlow Tracking experiment.

        Only runs whose status is ``FINISHED`` are imported. Each imported run is
        recorded under its MLFlow run ID rather than a local run index.

        :param tracking_uri: MLFlow Tracking server URI
        :param experiment_id: MLFlow Tracking experiment ID
        :raises ImportError: Import error if MLFlow not installed
        """
        if not ML_FLOW:
            raise ImportError("Optional MLFlow dependencies (needed for this method) are not installed.")

        client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)  # type: ignore[reportPrivateImportUsage]
        runs = client.search_runs([experiment_id])
        if runs:
            logger.info(
                "Successfully retrieved MLFlow experiment %s from tracking server. %d were found.",
                experiment_id,
                len(runs),
            )
        else:
            warnings.warn(
                f"MLFlow experiment {experiment_id} does not have any runs and publishing requires at least one valid run. Are you sure the ID is correct?"
            )

        for run in runs:
            info = run.info

            # MLFlow run must be status FINISHED; skip before touching anything else.
            if info.status != "FINISHED":
                continue

            data = run.data
            artifact_uri: str = str(info.artifact_uri)
            run_id = info.run_id
            datasets_str = [dataset.dataset.name for dataset in run.inputs.dataset_inputs]
            artifacts = []

            if mlflow.artifacts.list_artifacts(artifact_uri=artifact_uri) is not None:  # type: ignore[reportPrivateImportUsage]
                mlflow_dir = os.path.join(self.temp_dir, f"mlflow_{run_id}")
                mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri, dst_path=mlflow_dir)  # type: ignore[reportPrivateImportUsage]
                artifacts.append(mlflow_dir)
                logger.info(
                    "Successfully downloaded artifacts for MLFlow experiment %s to %s.",
                    experiment_id,
                    mlflow_dir,
                )

            self.start_run(is_mlflow=True)
            self.log_params(data.params)
            self.log_metrics(data.metrics)
            self.log_artifacts(artifacts)
            self.log_dataset("".join(datasets_str))
            # Replace the local run index with the MLFlow run ID.
            self.run_data["run"] = run_id

        logger.info("Successfully imported MLFlow experiment %s.", experiment_id)

    def publish(
        self,
        mc_loc: str,
        semver: str = "0.1.0",
        notes: str = "",
        run_id: int | str | None = None,
        select_by: str | None = None,
    ):
        """Publishes a given experiments results to the model card.

        :param mc_loc: Location of metrics in the model card (e.g. performance.performanceMetrics)
        :param semver: Semantic version of release to create (if artifacts present), defaults to 0.1.0 or next
        :param notes: Notes for release, defaults to ""
        :param run_id: Experiment run ID to be selected (an integer for local runs, the MLFlow run ID for
            imported runs), defaults to None
        :param select_by: String describing experiment to be selected (e.g. "accuracy MIN|MAX"), defaults to None
        :raises BailoException: If the experiment was already published, the model card is not populated,
            there are no runs, or neither run_id nor select_by is given
        :raises NameError: If run_id does not match any run

        ..note:: mc_loc is dependent on the model card schema being used
        ..note:: If both run_id and select_by are given, run_id takes precedence
        ..warning:: User must specify either run_id or select_by, otherwise the code will error
        """
        # Check if already published, can only be published once
        if self.published:
            raise BailoException("This experiment has already been published.")

        mc = self.model.model_card
        if mc is None:
            raise BailoException("Model card needs to be populated before publishing an experiment.")
        mc = NestedDict(mc)

        if not self.raw:
            raise BailoException("This experiment has no runs to publish.")

        if (select_by is None) and (run_id is None):
            raise BailoException(
                "Either select_by (e.g. 'accuracy MIN|MAX') or run_id is required to publish an experiment run."
            )

        sel_run: dict[Any, Any]
        if run_id is not None:
            for run in self.raw:
                if run["run"] == run_id:
                    sel_run = run
                    break
            else:
                raise NameError(f"Run {run_id} does not exist.")
        else:
            # select_by cannot be None here because of the guard above.
            sel_run = self.__select_run(select_by=select_by)

        # Flatten every logged metrics dictionary into name/value pairs
        values = [{"name": k, "value": v} for metric in sel_run["metrics"] for k, v in metric.items()]

        # Updating the model card
        parsed_values = [{"dataset": sel_run["dataset"], "datasetMetrics": values}]
        mc[tuple(mc_loc.split("."))] = parsed_values
        self.model.update_model_card(model_card=mc)

        # Always report the run that was actually published, even when it was
        # chosen through select_by rather than run_id.
        published_run_id = sel_run["run"]

        # Creating a release and uploading artifacts (if artifacts present)
        artifacts = sel_run["artifacts"]
        if artifacts:
            # Best effort: bump the latest release if one can be retrieved,
            # otherwise fall back to the caller-supplied semver.
            try:
                release_latest_version = self.model.get_latest_release().version
                release_new_version = release_latest_version.next_minor()
            except Exception:
                release_new_version = semver

            notes = f"{notes} (Run ID: {published_run_id})"
            release_new = self.model.create_release(version=release_new_version, minor=True, notes=notes)

            logger.info(
                "Uploading %d artifacts to version %s of model %s.",
                len(artifacts),
                str(release_new_version),
                self.model.model_id,
            )
            for artifact in artifacts:
                release_new.upload(path=artifact)

        self.published = True

        # Remove any artifacts staged locally by from_mlflow
        if os.path.isdir(self.temp_dir):
            shutil.rmtree(self.temp_dir)

        logger.info(
            "Successfully published experiment run %s to model %s.",
            str(published_run_id),
            self.model.model_id,
        )

    def __select_run(self, select_by: str) -> dict:
        """Pick the run with the smallest or largest value of a metric.

        :param select_by: String of the form "metric_name MIN|MAX" (order is case-insensitive)
        :raises BailoException: If select_by is malformed, or the metric is missing from a run
        :return: The selected run dictionary
        """
        # Parse target and order from select_by string (any run of whitespace separates them)
        select_by_split = select_by.split()
        if len(select_by_split) != 2:
            raise BailoException("Invalid select_by string. Expected format is 'metric_name MIN|MAX'.")

        target_str, order_raw = select_by_split
        order_str = order_raw.upper()
        # Index into the ascending-sorted runs: first for MIN, last for MAX
        order_opt = {"MIN": 0, "MAX": -1}
        if order_str not in order_opt:
            raise BailoException(f"Runs can only be ordered by MIN or MAX, not {order_str}.")

        # Extract target value for each run. Metrics may be logged across several
        # log_metrics calls, so every logged dictionary is searched and the most
        # recently logged value wins. The run dictionaries are left unmodified.
        scored_runs = []
        for run in self.raw:
            target_value = None
            for metric in run["metrics"]:
                candidate = metric.get(target_str)
                if candidate is not None:
                    target_value = candidate

            if target_value is None:
                raise BailoException(
                    f"Target '{target_str}' could not be found in at least one experiment run, or is not an integer. Therefore ordering cannot take place."
                )
            scored_runs.append((target_value, run))

        # Sort experiment runs by target value into ascending order, and select first or last depending on order_str
        ordered_runs = sorted(scored_runs, key=lambda scored: scored[0])
        return ordered_runs[order_opt[order_str]][1]