Commit

fix modelcard local testing

thorrester committed Jul 27, 2024
1 parent cb1af6e commit e9b6232
Showing 16 changed files with 160 additions and 99 deletions.
6 changes: 2 additions & 4 deletions opsml/model/interfaces/catboost_.py
@@ -57,7 +57,7 @@ class CatBoostModel(ModelInterface):
"""

model: Optional[CatBoost] = None
sample_data: Optional[Union[List[Any], NDArray[Any]]] = None
sample_data: Optional[Union[List[Any], NDArray[Any], DataInterface]] = None
preprocessor: Optional[Any] = None
preprocessor_name: str = CommonKwargs.UNDEFINED.value

@@ -82,9 +82,7 @@ def _get_sample_data(cls, sample_data: NDArray[Any]) -> Union[List[Any], DataInt
sample_data,
get_class_name(sample_data),
)
assert isinstance(
sample_data, NumpyData
), "Sample data should be a numpy array if using an interface"
assert isinstance(sample_data, NumpyData), "Sample data should be a numpy array if using an interface"

# validate data
assert isinstance(sample_data.data, np.ndarray), "Data should be a numpy array if using an interface"
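The widened sample_data union and the NumpyData assertion above suggest usage along these lines. This is an editorial sketch, not part of the commit; the NumpyData(data=...) constructor and the import paths are assumptions inferred from the diff rather than verified against a release.

# Editorial sketch: passing sample data to CatBoostModel either as a raw numpy array
# or wrapped in a NumpyData interface, which is the only DataInterface the validator
# above accepts. NumpyData(data=...) and the import paths are assumed from the diff.
import numpy as np
from catboost import CatBoostRegressor

from opsml.data.interfaces import NumpyData
from opsml.model.interfaces.catboost_ import CatBoostModel

X = np.random.rand(50, 4)
y = np.random.rand(50)

model = CatBoostRegressor(iterations=5, verbose=False)
model.fit(X, y)

# plain array, as before
interface_plain = CatBoostModel(model=model, sample_data=X[:1])

# with this change, a NumpyData interface wrapping the same array is also accepted
interface_wrapped = CatBoostModel(model=model, sample_data=NumpyData(data=X[:1]))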
3 changes: 2 additions & 1 deletion opsml/model/interfaces/lgbm.py
@@ -12,6 +12,7 @@
get_model_args,
get_processor_name,
)
from opsml.data.interfaces import DataInterface
from opsml.types import CommonKwargs, TrainedModelType
from opsml.types.extra import Suffix

@@ -46,7 +47,7 @@ class LightGBMModel(ModelInterface):
"""

model: Optional[Union[Booster, LGBMModel]] = None
sample_data: Optional[Union[pd.DataFrame, NDArray[Any]]] = None
sample_data: Optional[Union[pd.DataFrame, NDArray[Any], DataInterface]] = None
preprocessor: Optional[Any] = None
preprocessor_name: str = CommonKwargs.UNDEFINED.value

20 changes: 18 additions & 2 deletions opsml/model/interfaces/pytorch.py
@@ -6,7 +6,14 @@
from pydantic import ConfigDict, model_validator

from opsml.helpers.utils import OpsmlImportExceptions, get_class_name
from opsml.model.interfaces.base import ModelInterface, SamplePrediction, get_model_args, get_processor_name, _set_data_args
from opsml.model.interfaces.base import (
ModelInterface,
SamplePrediction,
_set_data_args,
get_model_args,
get_processor_name,
)
from opsml.data.interfaces import DataInterface
from opsml.types import (
CommonKwargs,
ModelReturn,
@@ -19,6 +26,7 @@

try:
import torch

from opsml.data.interfaces import TorchData

ValidData = Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]
@@ -51,7 +59,15 @@ class TorchModel(ModelInterface):
"""

model: Optional[torch.nn.Module] = None
sample_data: Optional[Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]] = None
sample_data: Optional[
Union[
torch.Tensor,
Dict[str, torch.Tensor],
List[torch.Tensor],
Tuple[torch.Tensor],
DataInterface,
]
] = None
onnx_args: Optional[TorchOnnxArgs] = None
save_args: TorchSaveArgs = TorchSaveArgs()
preprocessor: Optional[Any] = None
10 changes: 8 additions & 2 deletions opsml/model/interfaces/sklearn.py
@@ -6,7 +6,13 @@
from numpy.typing import NDArray
from pydantic import ConfigDict, model_validator

from opsml.model.interfaces.base import ModelInterface, get_model_args, get_processor_name, _set_data_args
from opsml.model.interfaces.base import (
ModelInterface,
_set_data_args,
get_model_args,
get_processor_name,
)
from opsml.data.interfaces import DataInterface
from opsml.types import CommonKwargs, Suffix, TrainedModelType

try:
@@ -37,7 +43,7 @@ class SklearnModel(ModelInterface):
"""

model: Optional[BaseEstimator] = None
sample_data: Optional[Union[pd.DataFrame, NDArray[Any]]] = None
sample_data: Optional[Union[pd.DataFrame, NDArray[Any], DataInterface]] = None
preprocessor: Optional[Any] = None
preprocessor_name: str = CommonKwargs.UNDEFINED.value

21 changes: 18 additions & 3 deletions opsml/model/interfaces/tf.py
@@ -5,10 +5,17 @@
import numpy as np
from numpy.typing import NDArray
from pydantic import ConfigDict, model_validator

from opsml.data.interfaces import NumpyData
from opsml.helpers.utils import get_class_name
from opsml.model.interfaces.base import ModelInterface, get_model_args, get_processor_name, _set_data_args
from opsml.model.interfaces.base import (
ModelInterface,
_set_data_args,
get_model_args,
get_processor_name,
)
from opsml.types import CommonKwargs, Suffix, TrainedModelType
from opsml.data.interfaces import DataInterface


try:
import tensorflow as tf
@@ -39,7 +46,15 @@ class TensorFlowModel(ModelInterface):
"""

model: Optional[tf.keras.Model] = None # pylint: disable=no-member
sample_data: Optional[Union[ArrayType, Dict[str, ArrayType], List[ArrayType], Tuple[ArrayType]]] = None
sample_data: Optional[
Union[
ArrayType,
Dict[str, ArrayType],
List[ArrayType],
Tuple[ArrayType],
DataInterface,
]
] = None
preprocessor: Optional[Any] = None
preprocessor_name: str = CommonKwargs.UNDEFINED.value

6 changes: 2 additions & 4 deletions opsml/model/interfaces/vowpal.py
@@ -3,8 +3,7 @@

from pydantic import ConfigDict, model_validator

from opsml.helpers.utils import get_class_name
from opsml.model.interfaces.base import ModelInterface
from opsml.model.interfaces.base import ModelInterface, _set_data_args
from opsml.types import CommonKwargs, ModelReturn, Suffix, TrainedModelType

try:
@@ -59,8 +58,7 @@ def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
assert model is not None, "Model must not be None"

sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
model_args = _set_data_args(sample_data, model_args)
model_args[CommonKwargs.VOWPAL_ARGS.value] = model.get_arguments()

return model_args
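This file and xgb.py below collapse the two model_args assignments into a single _set_data_args call. The helper itself lives in opsml/model/interfaces/base.py and is not part of this diff; the sketch below is an editorial guess at what it consolidates, inferred from the two lines it replaces. The DataInterface branch and its key name are hypothetical.

# Editorial sketch of the shared helper, inferred from the assignments it replaces here;
# the real implementation in opsml/model/interfaces/base.py may differ.
from typing import Any, Dict

from opsml.data.interfaces import DataInterface
from opsml.helpers.utils import get_class_name
from opsml.types import CommonKwargs


def _set_data_args(sample_data: Any, model_args: Dict[str, Any]) -> Dict[str, Any]:
    """Records sample data and its type on the model args dict."""
    model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
    model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)

    # hypothetical: when sample data arrives wrapped in a DataInterface, also record the
    # interface class so card_loader.py (below) can rebuild it from
    # sample_data_interface_type at load time; the exact key name is a guess
    if isinstance(sample_data, DataInterface):
        model_args["sample_data_interface_type"] = sample_data.__class__.__name__

    return model_args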
26 changes: 18 additions & 8 deletions opsml/model/interfaces/xgb.py
@@ -7,9 +7,13 @@
from pydantic import ConfigDict, model_validator

from opsml.helpers.logging import ArtifactLogger
from opsml.helpers.utils import get_class_name
from opsml.model import ModelInterface
from opsml.model.interfaces.base import get_model_args, get_processor_name
from opsml.model.interfaces.base import (
_set_data_args,
get_model_args,
get_processor_name,
)
from opsml.data.interfaces import DataInterface
from opsml.types import CommonKwargs, ModelReturn, Suffix, TrainedModelType

logger = ArtifactLogger.get_logger()
@@ -42,7 +46,14 @@ class XGBoostModel(ModelInterface):
"""

model: Optional[Union[Booster, XGBModel]] = None
sample_data: Optional[Union[pd.DataFrame, NDArray[Any], DMatrix]] = None
sample_data: Optional[
Union[
pd.DataFrame,
NDArray[Any],
DMatrix,
DataInterface,
]
] = None
preprocessor: Optional[Any] = None
preprocessor_name: str = CommonKwargs.UNDEFINED.value

@@ -55,7 +66,7 @@ def model_class(self) -> str:
return TrainedModelType.SKLEARN_ESTIMATOR.value

@classmethod
def _get_sample_data(cls, sample_data: Any) -> Union[pd.DataFrame, NDArray[Any], DMatrix]:
def _get_sample_data(cls, sample_data: Any) -> Any:
"""Check sample data and returns one record to be used
during type inference and ONNX conversion/validation.
@@ -88,8 +99,7 @@ def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
model_args[CommonKwargs.MODEL_TYPE.value] = "subclass"

sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
model_args = _set_data_args(sample_data, model_args)
model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
model_args.get(CommonKwargs.PREPROCESSOR.value),
)
@@ -173,7 +183,7 @@ def save_sample_data(self, path: Path) -> None:
self.sample_data.save_binary(path)

else:
joblib.dump(self.sample_data, path)
super().save_sample_data(path)

def load_sample_data(self, path: Path) -> None:
"""Serialized and save sample data to path.
@@ -185,7 +195,7 @@ def load_sample_data(self, path: Path) -> None:
if self.model_class == TrainedModelType.XGB_BOOSTER.value:
self.sample_data = DMatrix(path)
else:
self.sample_data = joblib.load(path)
super().load_sample_data(path)

@property
def model_suffix(self) -> str:
2 changes: 1 addition & 1 deletion opsml/model/metadata_creator.py
@@ -47,7 +47,7 @@ def _get_input_schema(self) -> Dict[str, Feature]:
try:
model_data = get_model_data(
data_type=self.interface.data_type,
input_data=self.interface.sample_data,
input_data=self.interface._prediction_data,
)

return model_data.feature_dict
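The hunk above, onnx/__init__.py, and the torch converters below all switch from self.interface.sample_data to self.interface._prediction_data. The property is defined on ModelInterface outside this diff; presumably it unwraps a DataInterface back to the raw array or frame so the downstream schema and ONNX helpers stay unchanged, roughly as in this editorial sketch.

# Editorial sketch only: a plausible _prediction_data property. The real definition in
# opsml/model/interfaces/base.py is not shown in this commit and may differ.
from typing import Any

from opsml.data.interfaces import DataInterface


class _SampleDataMixin:
    sample_data: Any = None

    @property
    def _prediction_data(self) -> Any:
        # when sample_data is a DataInterface, hand its underlying array/frame to the
        # metadata and ONNX helpers; otherwise pass sample_data through untouched
        if isinstance(self.sample_data, DataInterface):
            return self.sample_data.data
        return self.sample_data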
2 changes: 1 addition & 1 deletion opsml/model/onnx/__init__.py
@@ -75,7 +75,7 @@ def convert_model(self) -> ModelReturn:

model_data = get_model_data(
data_type=self.interface.data_type,
input_data=self.interface.sample_data,
input_data=self.interface._prediction_data,
)

onnx_model_return = _OnnxConverterHelper.convert_model(model_interface=self.interface, data_helper=model_data)
23 changes: 12 additions & 11 deletions opsml/model/onnx/torch_converter.py
@@ -49,20 +49,21 @@ def __init__(self, model_interface: TorchModel):

def _get_additional_model_args(self) -> TorchOnnxArgs:
"""Passes or creates TorchOnnxArgs needed for Onnx model conversion"""
prediction_data = self.interface._prediction_data

if self.interface.onnx_args is None:
assert self.interface.sample_data is not None, "Sample data must be provided"
return _PytorchArgBuilder(input_data=self.interface.sample_data).get_args()
assert prediction_data is not None, "Sample data must be provided"
return _PytorchArgBuilder(input_data=prediction_data).get_args()
return self.interface.onnx_args

def _coerce_data_for_onnx(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
assert self.interface.sample_data is not None, "Sample data must not be None"
assert self.interface._prediction_data is not None, "Sample data must not be None"

if isinstance(self.interface.sample_data, dict):
return tuple(self.interface.sample_data.values())
if isinstance(self.interface.sample_data, torch.Tensor):
return self.interface.sample_data
return tuple(self.interface.sample_data)
if isinstance(self.interface._prediction_data, dict):
return tuple(self.interface._prediction_data.values())
if isinstance(self.interface._prediction_data, torch.Tensor):
return self.interface._prediction_data
return tuple(self.interface._prediction_data)

def _load_onnx_model(self, path: Path) -> rt.InferenceSession:
return rt.InferenceSession(
@@ -105,8 +106,8 @@ def _get_additional_model_args(self) -> TorchOnnxArgs:
"""Passes or creates TorchOnnxArgs needed for Onnx model conversion"""

if self.interface.onnx_args is None:
assert self.interface.sample_data is not None, "No sample data provided"
return _PytorchArgBuilder(input_data=cast(ValidData, self.interface.sample_data)).get_args()
assert self.interface._prediction_data is not None, "No sample data provided"
return _PytorchArgBuilder(input_data=cast(ValidData, self.interface._prediction_data)).get_args()
return self.interface.onnx_args

def _load_onnx_model(self, path: Path) -> rt.InferenceSession:
@@ -127,7 +128,7 @@ def convert_to_onnx(self, path: Path) -> OnnxModel:

self.interface.model.model.to_onnx(
path.as_posix(),
self.interface.sample_data,
self.interface._prediction_data,
**onnx_args.model_dump(exclude={"options"}),
)

29 changes: 18 additions & 11 deletions opsml/storage/card_loader.py
@@ -29,14 +29,7 @@
from opsml.model.interfaces.huggingface import HuggingFaceModel
from opsml.settings.config import config
from opsml.storage import client
from opsml.types import (
AllowedDataType,
CardType,
RegistryTableNames,
RegistryType,
SaveName,
Suffix,
)
from opsml.types import AllowedDataType, CardType, RegistryTableNames, RegistryType, SaveName, Suffix, CommonKwargs
from opsml.types.model import ModelMetadata, OnnxModel

logger = ArtifactLogger.get_logger()
@@ -400,8 +393,20 @@ def onnx_suffix(self) -> str:
return Suffix.ONNX.value
return ""

def _load_data_interface(self) -> str:
# load sample data interface if it exists
if self.card.interface.sample_data_interface_type != CommonKwargs.UNDEFINED.value:
interface: DataInterface = get_interface(RegistryType.DATA, self.card.interface.sample_data_interface_type)() # type: ignore
interface.feature_map = self.card.interface.feature_map
self.card.interface.sample_data = interface

return interface.data_suffix

return self.card.interface.data_suffix

def _load_sample_data(self, lpath: Path, rpath: Path) -> None:
"""Load sample data for model interface. Sample data is always saved via joblib
"""Load sample data for model interface. Sample data is is either saved in
a DataInterface or joblib
Args:
lpath:
@@ -413,11 +418,13 @@ def _load_sample_data(self, lpath: Path, rpath: Path) -> None:
logger.info("Sample data already loaded")
return None

load_rpath = Path(self.card.uri, SaveName.SAMPLE_MODEL_DATA.value).with_suffix(self.card.interface.data_suffix)
data_suffix = self._load_data_interface()

load_rpath = Path(self.card.uri, SaveName.SAMPLE_MODEL_DATA.value).with_suffix(data_suffix)
if not self.storage_client.exists(load_rpath):
return None

lpath = self.download(lpath, rpath, SaveName.SAMPLE_MODEL_DATA.value, self.card.interface.data_suffix)
lpath = self.download(lpath, rpath, SaveName.SAMPLE_MODEL_DATA.value, data_suffix)

return self.card.interface.load_sample_data(lpath)

4 changes: 1 addition & 3 deletions opsml/storage/card_saver.py
@@ -551,9 +551,7 @@ def save_card_artifacts(card: Card) -> None:
"""

card_saver = next(
card_saver for card_saver in CardSaver.__subclasses__() if card_saver.validate(card_type=card.card_type)
)
card_saver = next(card_saver for card_saver in CardSaver.__subclasses__() if card_saver.validate(card_type=card.card_type))

saver = card_saver(card=card)

2 changes: 2 additions & 0 deletions opsml/storage/schemas/modelcard.yaml
@@ -22,4 +22,6 @@ keys:
- modelcard_uid
- save_args
- arguments
- sample_data_interface_type
- feature_map
