diff --git a/src/_nebari/provider/cloud/google_cloud.py b/src/_nebari/provider/cloud/google_cloud.py index 6b54e40e9..81f86d8b1 100644 --- a/src/_nebari/provider/cloud/google_cloud.py +++ b/src/_nebari/provider/cloud/google_cloud.py @@ -57,13 +57,23 @@ def kubernetes_versions(region: str) -> List[str]: credentials, project_id = load_credentials() client = container_v1.ClusterManagerClient(credentials=credentials) response = client.get_server_config( - name=f"projects/{project_id}/locations/{region}" + name=f"projects/{project_id}/locations/{region}", timeout=300 ) supported_kubernetes_versions = response.valid_master_versions return filter_by_highest_supported_k8s_version(supported_kubernetes_versions) +def get_patch_version(full_version: str) -> str: + return full_version.split("-")[0] + + +def get_minor_version(full_version: str) -> str: + patch_version = get_patch_version(full_version) + parts = patch_version.split(".") + return f"{parts[0]}.{parts[1]}" + + def cluster_exists(cluster_name: str, region: str) -> bool: """Check if a GKE cluster exists.""" credentials, project_id = load_credentials() diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index b716bbd5d..41e5d8241 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -8,7 +8,7 @@ import tempfile from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union -from pydantic import Field, field_validator, model_validator +from pydantic import ConfigDict, Field, field_validator, model_validator from _nebari import constants from _nebari.provider import terraform @@ -359,6 +359,9 @@ class GCPNodeGroup(schema.Base): class GoogleCloudPlatformProvider(schema.Base): + # If you pass a major and minor version without a patch version + # yaml will pass it as a float, so we need to coerce it to a string + model_config = ConfigDict(coerce_numbers_to_str=True) region: str project: str kubernetes_version: str @@ -373,6 +376,12 @@ class GoogleCloudPlatformProvider(schema.Base): master_authorized_networks_config: Optional[Union[GCPCIDRBlock, None]] = None private_cluster_config: Optional[Union[GCPPrivateClusterConfig, None]] = None + @field_validator("kubernetes_version", mode="before") + @classmethod + def transform_version_to_str(cls, value) -> str: + """Transforms the version to a string if it is not already.""" + return str(value) + @model_validator(mode="before") @classmethod def _check_input(cls, data: Any) -> Any: @@ -383,8 +392,10 @@ def _check_input(cls, data: Any) -> Any: ) available_kubernetes_versions = google_cloud.kubernetes_versions(data["region"]) - print(available_kubernetes_versions) - if data["kubernetes_version"] not in available_kubernetes_versions: + if not any( + v.startswith(str(data["kubernetes_version"])) + for v in available_kubernetes_versions + ): raise ValueError( f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available." ) diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py index 743d30cb4..0b5f36bc7 100644 --- a/src/_nebari/subcommands/init.py +++ b/src/_nebari/subcommands/init.py @@ -409,13 +409,15 @@ def check_cloud_provider_kubernetes_version( versions = google_cloud.kubernetes_versions(region) if not kubernetes_version or kubernetes_version == LATEST: - kubernetes_version = get_latest_kubernetes_version(versions) + kubernetes_version = google_cloud.get_patch_version( + get_latest_kubernetes_version(versions) + ) rich.print( DEFAULT_KUBERNETES_VERSION_MSG.format( kubernetes_version=kubernetes_version ) ) - if kubernetes_version not in versions: + if not any(v.startswith(kubernetes_version) for v in versions): raise ValueError( f"Invalid Kubernetes version `{kubernetes_version}`. Please refer to the GCP docs for a list of valid versions: {versions}" )