From 8e1c6f881daca538913a966f7aa02bdf3959822a Mon Sep 17 00:00:00 2001 From: CRUISE LI Date: Fri, 11 Oct 2024 06:58:48 +0800 Subject: [PATCH] feat: Change Fabric Cog Service Token to Support Billing (#2291) * change fabric cogservice token to support billing * change mwc token * rename --------- Co-authored-by: cruise --- .../ml/services/CognitiveServiceBase.scala | 4 +- .../synapse/ml/services/openai/OpenAI.scala | 14 +--- .../openai/OpenAIChatCompletion.scala | 5 +- .../ml/services/openai/OpenAICompletion.scala | 4 +- .../ml/services/openai/OpenAIEmbedding.scala | 3 +- .../ml/services/openai/OpenAIPrompt.scala | 5 +- .../synapse/ml/fabric/FabricClient.scala | 8 ++- .../ml/fabric/OpenAITokenLibrary.scala | 72 ------------------- .../synapse/ml/fabric/TokenLibrary.scala | 30 ++++++-- 9 files changed, 41 insertions(+), 104 deletions(-) delete mode 100644 core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index 31c56dc80c..2d123edf56 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services import com.microsoft.azure.synapse.ml.codegen.Wrappable import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions -import com.microsoft.azure.synapse.ml.fabric.{FabricClient, TokenLibrary} +import com.microsoft.azure.synapse.ml.fabric.FabricClient import com.microsoft.azure.synapse.ml.io.http._ import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails @@ -330,7 +330,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader) if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) { logInfo("Using Default AAD Token On Fabric") - Option(TokenLibrary.getAuthHeader) + Option(FabricClient.getCognitiveMWCTokenAuthHeader) } else { providedCustomAuthHeader } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index b57f4d65da..8d40f49ee6 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -4,7 +4,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.codegen.GenerationUtils -import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting, OpenAITokenLibrary} +import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting} import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services._ @@ -277,18 +277,6 @@ trait HasOpenAITextParams extends HasOpenAISharedParams { } } -trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput { - override protected def getCustomAuthHeader(row: Row): Option[String] = { - val providedCustomHeader = getValueOpt(row, CustomAuthHeader) - if (providedCustomHeader.isEmpty && PlatformDetails.runningOnFabric()) { - logInfo("Using Default OpenAI Token On Fabric") - Option(OpenAITokenLibrary.getAuthHeader) - } else { - providedCustomHeader - } - } -} - abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) with HasOpenAISharedParams with OpenAIFabricSetting { setDefault(timeout -> 360.0) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index 703fc7f471..379d797766 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -5,10 +5,9 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat -import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser +import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser} import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity} import org.apache.spark.ml.ComplexParamsReadable -import org.apache.spark.ml.param.Param import org.apache.spark.ml.util._ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -20,7 +19,7 @@ import scala.language.existentials object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion] class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid) - with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput + with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput with HasInternalJsonOutputParser with SynapseMLLogging { logClass(FeatureNames.AiServices.OpenAI) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala index 953138bc36..219bc34d87 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala @@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat -import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser +import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser} import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity} import org.apache.spark.ml.ComplexParamsReadable import org.apache.spark.ml.util._ @@ -19,7 +19,7 @@ import scala.language.existentials object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion] class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid) - with HasOpenAITextParams with HasPromptInputs with HasOpenAICognitiveServiceInput + with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput with HasInternalJsonOutputParser with SynapseMLLogging { logClass(FeatureNames.AiServices.OpenAI) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala index 58f5c857d6..342254f8fb 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala @@ -7,6 +7,7 @@ import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.services.HasCognitiveServiceInput import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity} import org.apache.spark.ml.ComplexParamsReadable import org.apache.spark.ml.linalg.SQLDataTypes.VectorType @@ -22,7 +23,7 @@ import scala.language.existentials object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding] class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) - with HasOpenAIEmbeddingParams with HasOpenAICognitiveServiceInput with SynapseMLLogging { + with HasOpenAIEmbeddingParams with HasCognitiveServiceInput with SynapseMLLogging { logClass(FeatureNames.AiServices.OpenAI) def this() = this(Identifiable.randomUID("OpenAIEmbedding")) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 66b42833e1..a43f3ffe3a 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -3,17 +3,16 @@ package com.microsoft.azure.synapse.ml.services.openai -import com.microsoft.azure.synapse.ml.services._ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol import com.microsoft.azure.synapse.ml.core.spark.Functions import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL} import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} import com.microsoft.azure.synapse.ml.param.StringStringMapParam +import com.microsoft.azure.synapse.ml.services._ import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer} -import org.apache.spark.sql.Row.unapplySeq import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{DataType, StructType} @@ -28,7 +27,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer with HasErrorCol with HasOutputCol with HasURL with HasCustomCogServiceDomain with ConcurrencyParams with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader - with HasOpenAICognitiveServiceInput + with HasCognitiveServiceInput with ComplexParamsWritable with SynapseMLLogging { logClass(FeatureNames.AiServices.OpenAI) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala index 6382ce791d..236ffc0565 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala @@ -122,7 +122,7 @@ object FabricClient extends RESTUtils { private def getHeaders: Map[String, String] = { Map( - "Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}", + "Authorization" -> s"${getMLWorkloadAADAuthHeader}", "RequestId" -> UUID.randomUUID().toString, "Content-Type" -> "application/json", "x-ms-workload-resource-moniker" -> UUID.randomUUID().toString @@ -143,4 +143,10 @@ object FabricClient extends RESTUtils { def usagePost(url: String, body: String): JsValue = { usagePost(url, body, getHeaders); } + + def getMLWorkloadAADAuthHeader: String = TokenLibrary.getMLWorkloadAADAuthHeader + + def getCognitiveMWCTokenAuthHeader: String = { + TokenLibrary.getCognitiveMwcTokenAuthHeader(WorkspaceID.getOrElse(""), ArtifactID.getOrElse("")) + } } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala deleted file mode 100644 index ab8356601d..0000000000 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See LICENSE in project root for information. - -package com.microsoft.azure.synapse.ml.fabric - -import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging -import spray.json.DefaultJsonProtocol.StringJsonFormat - -object OpenAITokenLibrary extends SynapseMLLogging with AuthHeaderProvider { - private var MLToken: Option[String] = None; - private var IsMWCTokenEnabled: Boolean = true; - val BackgroundRefreshExpiryCushionInMillis: Long = 5 * 60 * 1000L - val OpenAIFeatureName = "SparkCodeFirst" - - private def buildAuthHeader: String = { - if(IsMWCTokenEnabled) { - "MwcToken " + MLToken.getOrElse("") - } else { - "Bearer " + TokenLibrary.getAccessToken; - } - } - - def getAuthHeader: String = { - if (isTokenExpired(MLToken, BackgroundRefreshExpiryCushionInMillis)) { - val artifactId = FabricClient.ArtifactID - val payload = - s"""{ - |"artifactObjectId": "${artifactId.getOrElse("")}", - |"openAIFeatureName": "$OpenAIFeatureName", - |}""".stripMargin - - val url: String = FabricClient.MLWorkloadEndpointML + "cognitive/openai/generatemwctoken"; - try { - val token = FabricClient.usagePost(url, payload).asJsObject.fields("Token").convertTo[String]; - MLToken = Some(token) - IsMWCTokenEnabled = true - } catch { - case e: Throwable => - IsMWCTokenEnabled = false - } - } - buildAuthHeader - } - - private def getExpiryTime(accessToken: String): Long = { - //Extract expiry time - val parser = new FabricTokenParser(accessToken); - parser.getExpiry - } - - private def isTokenExpired(accessToken: Option[String], expiryCushionInMillis: Long = 0): Boolean = { - accessToken match { - case Some(accessToken) => - try { - val expiry: Long = getExpiryTime(accessToken) - val currentTime: Long = System.currentTimeMillis() - currentTime > expiry - expiryCushionInMillis - } - catch { - case t: Throwable => - logInfo("Error while getting token expiry time", t) - true - } - case None => - true - } - - } - - // scalastyle:off - override val uid: String = "OpenAITokenLibrary"; -} diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/TokenLibrary.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/TokenLibrary.scala index 2ebe09f5a3..1ced3b3d7a 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/TokenLibrary.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/TokenLibrary.scala @@ -6,11 +6,7 @@ package com.microsoft.azure.synapse.ml.fabric import scala.reflect.runtime.currentMirror import scala.reflect.runtime.universe._ -trait AuthHeaderProvider { - def getAuthHeader: String -} - -object TokenLibrary extends AuthHeaderProvider { +object TokenLibrary { def getAccessToken: String = { val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary" val mirror = currentMirror @@ -27,9 +23,29 @@ object TokenLibrary extends AuthHeaderProvider { } }.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type $argType not found")) val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod) - methodMirror("pbi").asInstanceOf[String] + methodMirror("ml").asInstanceOf[String] } + def getSparkMwcToken(workspaceId: String, artifactId: String): String = { + val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary" + val mirror = currentMirror + val module = mirror.staticModule(objectName) + val obj = mirror.reflectModule(module).instance + val objType = mirror.reflect(obj).symbol.toType + val methodName = "getMwcToken" + val methodSymbols = objType.decl(TermName(methodName)).asTerm.alternatives + val argTypes = List(typeOf[String], typeOf[String], typeOf[Integer], typeOf[String]) + val selectedMethodSymbol = methodSymbols.find { m => + m.asMethod.paramLists.flatten.map(_.typeSignature).zip(argTypes).forall { case (a, b) => a =:= b } + }.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type not found")) + val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod) + methodMirror(workspaceId, artifactId, 2, "SparkCore") + .asInstanceOf[String] + } + + + def getMLWorkloadAADAuthHeader: String = "Bearer " + getAccessToken - def getAuthHeader: String = "Bearer " + getAccessToken + def getCognitiveMwcTokenAuthHeader(workspaceId: String, artifactId: String): String = "MwcToken " + + getSparkMwcToken(workspaceId, artifactId) }