# Copyright 2020 The Merlin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from os import listdir
from os.path import isdir
[docs]def validate_model_dir(input_model_type, target_model_type, model_dir):
"""
Validates user-provided model directory based on file structure.
For tensorflow models, checking is only done on the subdirectory
with the largest version number.
:param input_model_type: type of given model
:param target_model_type: type of supposed model, dependent on log_<model type>(...)
:param model_dir: directory containing serialised model file
"""
from merlin.model import ModelType
if target_model_type == None and input_model_type == ModelType.TENSORFLOW:
path_isdir = [isdir(f'{model_dir}/{path}') for path in listdir(model_dir)]
if len(listdir(model_dir)) > 0 and all(path_isdir):
model_dir = f'{model_dir}/{sorted(listdir(model_dir))[-1]}'
if input_model_type != ModelType.PYFUNC and input_model_type != ModelType.PYFUNC_V2:
file_structure_reqs_map = {
ModelType.XGBOOST: ['model.bst'],
ModelType.TENSORFLOW: ['saved_model.pb', 'variables'],
ModelType.SKLEARN: ['model.joblib'],
ModelType.PYTORCH: ['model.pt'],
ModelType.ONNX: ['model.onnx']
}
input_structure = listdir(model_dir)
file_structure_req = file_structure_reqs_map[input_model_type]
if not all([req in input_structure for req in file_structure_req]):
raise Exception(
f"Provided {input_model_type.name} model directory should contain all of the following: {file_structure_req}")