# Copyright 2020 The Merlin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os
from urllib.parse import urlparse
from google.cloud import storage
from os.path import basename, exists, dirname
from os import makedirs
def guess_mlp_ui_url(mlp_api_url: str) -> str:
    """
    Derive the MLP UI URL from the MLP API URL.

    Every ``/api`` substring is removed from ``mlp_api_url``. If the result
    does not already start with an ``http://`` or ``https://`` scheme,
    ``http://`` is prepended.

    :param mlp_api_url: URL (or bare host) of the MLP API
    :return: guessed URL of the MLP UI, always carrying a scheme
    """
    url = mlp_api_url.replace("/api", "")
    # Accept https as well; otherwise a secure URL would be turned into the
    # malformed "http://https://...".
    if url.startswith(("http://", "https://")):
        return url
    return f"http://{url}"
def autostr(cls):
    """
    Class decorator that installs a generic ``__str__`` / ``__repr__``.

    The generated text has the form ``ClassName(attr1=value1, attr2=value2)``
    and is built from the instance attributes returned by ``vars(self)``.

    :param cls: class to decorate
    :return: the same class object, with ``__str__`` and ``__repr__`` replaced
    """
    def __str__(self):
        class_name = type(self).__name__
        # "!s" forces str() on both key and value, mirroring "%s" formatting.
        fields = ', '.join(
            f'{key!s}={value!s}' for key, value in vars(self).items()
        )
        return f'{class_name}({fields})'

    # The same function backs both string conversions.
    cls.__repr__ = cls.__str__ = __str__
    return cls
def valid_name_check(input_name: str) -> bool:
    """
    Checks if inputted name for project and model is url-friendly
        - has to be lower case
        - can only contain characters (a-z), digits (0-9) and the hyphen (-)
        - must not be empty

    :param input_name: name to validate
    :return: True if the whole name consists of allowed characters only
    """
    # fullmatch requires the entire string to match, which replaces the
    # former "search, then compare the first match with the input" logic
    # (that version ran the regex twice and had an unreachable None branch).
    return re.fullmatch(r'[-a-z0-9]+', input_name) is not None
def get_bucket_name(gcs_uri: str) -> str:
    """
    Return the network-location component of ``gcs_uri`` as parsed by
    ``urlparse``, e.g. ``my-bucket`` for ``gs://my-bucket/some/path``.

    :param gcs_uri: URI to parse
    :return: the ``netloc`` part of the parsed URI
    """
    return urlparse(gcs_uri).netloc
def get_gcs_path(gcs_uri: str) -> str:
    """
    Return the path component of ``gcs_uri`` with leading and trailing
    slashes removed, e.g. ``some/path`` for ``gs://my-bucket/some/path/``.

    :param gcs_uri: URI to parse
    :return: the ``path`` part of the parsed URI, stripped of ``/`` at both ends
    """
    raw_path = urlparse(gcs_uri).path
    trimmed = raw_path.strip("/")
    return trimmed
def download_files_from_gcs(gcs_uri: str, destination_path: str):
    """
    Download every blob found under ``gcs_uri`` into ``destination_path``.

    The bucket and prefix are taken from ``gcs_uri``. For each blob, the
    first five ``/``-separated components of its name are dropped and the
    remainder is used as the path relative to ``destination_path``.
    Intermediate directories are created as needed.

    Blobs that have nothing left after dropping those components, or whose
    name ends with ``/`` (directory placeholders), are skipped because they
    do not map to a downloadable file.

    :param gcs_uri: URI of the location to download from
    :param destination_path: local directory to download into; created if missing
    """
    # Get only the path after .../artifacts/model
    # E.g.
    # Some blob looks like this mlflow/3/ad8f15a4023f461796955f71e1152bac/artifacts/model/1/saved_model.pb
    # we only want to extract 1/saved_model.pb
    leading_components_to_drop = 5

    makedirs(destination_path, exist_ok=True)
    client = storage.Client()
    bucket_name = get_bucket_name(gcs_uri)
    path = get_gcs_path(gcs_uri)
    bucket = client.get_bucket(bucket_name)
    for blob in bucket.list_blobs(prefix=path):
        relative_parts = blob.name.split("/")[leading_components_to_drop:]
        # An empty list would make os.path.join(*...) raise TypeError, and a
        # trailing empty component means the blob name ends with "/", which
        # would make the download target a directory instead of a file.
        if not relative_parts or not relative_parts[-1]:
            continue
        artifact_path = os.path.join(*relative_parts)
        target_dir = os.path.join(destination_path, dirname(artifact_path))
        makedirs(target_dir, exist_ok=True)
        blob.download_to_filename(os.path.join(destination_path, artifact_path))