# Source code for merlin.util

# Copyright 2020 The Merlin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import os
from urllib.parse import urlparse
from google.cloud import storage
from os.path import basename, exists, dirname
from os import makedirs


def guess_mlp_ui_url(mlp_api_url: str) -> str:
    """Derive the MLP UI URL from the MLP API URL.

    Every occurrence of the substring ``"/api"`` is removed from
    ``mlp_api_url``. If the result does not already start with an
    ``http://`` or ``https://`` scheme, ``http://`` is prepended.

    :param mlp_api_url: URL of the MLP API, with or without a scheme.
    :return: the guessed UI URL, always carrying an explicit scheme.
    """
    url = mlp_api_url.replace("/api", "")
    # Accept both schemes: checking only "http://" would turn
    # "https://host" into the malformed "http://https://host".
    if url.startswith(("http://", "https://")):
        return url
    return f"http://{url}"
def autostr(cls):
    """Class decorator that installs a generated ``__str__`` and ``__repr__``.

    The generated text has the form ``ClassName(attr=value, ...)`` and is
    built from ``vars(self)``, i.e. the instance attributes, each value
    rendered with ``%s`` formatting.

    :param cls: class to decorate; it is modified in place.
    :return: the same class object, so the function works as a decorator.
    """
    def __str__(self):
        fields = ', '.join('%s=%s' % item for item in vars(self).items())
        return '%s(%s)' % (type(self).__name__, fields)

    # The same function backs both protocols, so str() and repr() agree.
    cls.__str__ = __str__
    cls.__repr__ = __str__
    return cls
def valid_name_check(input_name: str) -> bool:
    """
    Checks if inputted name for project and model is url-friendly
        - has to be lower case
        - can only contain characters (a-z), digits (0-9) and the hyphen (-)

    :param input_name: candidate project or model name.
    :return: True when the whole name is non-empty and consists only of the
        allowed characters, False otherwise.
    """
    # fullmatch requires the pattern to cover the entire string, which is
    # equivalent to searching for the first run of allowed characters and
    # comparing that run with the input.
    return re.fullmatch(r'[-a-z0-9]+', input_name) is not None
def get_bucket_name(gcs_uri: str) -> str:
    """Return the bucket part of a GCS URI.

    The URI is parsed with ``urlparse`` and its network-location component
    is returned, e.g. ``"gs://my-bucket/some/path"`` gives ``"my-bucket"``.

    :param gcs_uri: URI of the form ``<scheme>://<bucket>/<path>``.
    :return: the network-location component; empty when the URI has none.
    """
    return urlparse(gcs_uri).netloc
def get_gcs_path(gcs_uri: str) -> str:
    """Return the object path part of a GCS URI without surrounding slashes.

    The URI is parsed with ``urlparse``; leading and trailing ``"/"`` are
    stripped from its path component, e.g. ``"gs://my-bucket/a/b/"`` gives
    ``"a/b"``.

    :param gcs_uri: URI of the form ``<scheme>://<bucket>/<path>``.
    :return: the path component with leading/trailing slashes removed; empty
        when the URI has no path.
    """
    return urlparse(gcs_uri).path.strip("/")
def download_files_from_gcs(gcs_uri: str, destination_path: str):
    """Download the blobs listed under ``gcs_uri`` into ``destination_path``.

    The bucket and prefix are taken from ``gcs_uri``; every blob whose name
    starts with that prefix is written below ``destination_path``, keeping
    only the part of the blob name after its first five path components.

    :param gcs_uri: URI of the artifact location (bucket plus prefix).
    :param destination_path: local directory to download into; created if
        it does not exist.
    """
    # Get only the path after .../artifacts/model
    # E.g.
    # Some blob looks like this
    #   mlflow/3/ad8f15a4023f461796955f71e1152bac/artifacts/model/1/saved_model.pb
    # we only want to extract 1/saved_model.pb, i.e. drop the first five
    # components of the blob name.
    prefix_depth = 5

    makedirs(destination_path, exist_ok=True)

    client = storage.Client()
    bucket = client.get_bucket(get_bucket_name(gcs_uri))
    blobs = bucket.list_blobs(prefix=get_gcs_path(gcs_uri))

    for blob in blobs:
        parts = blob.name.split("/")[prefix_depth:]
        if not parts:
            # Nothing is left after dropping the prefix components;
            # os.path.join() with no arguments would raise TypeError.
            continue

        artifact_path = os.path.join(*parts)
        target_dir = os.path.join(destination_path, dirname(artifact_path))
        makedirs(target_dir, exist_ok=True)

        if not parts[-1]:
            # The blob name ends with "/", so there is no file name to
            # write to; the directory has been created above, and passing
            # a directory path to download_to_filename would fail.
            continue

        blob.download_to_filename(os.path.join(destination_path, artifact_path))