Skip to content

Commit

Permalink
feat: Change Fabric Cog Service Token to Support Billing (#2291)
Browse files Browse the repository at this point in the history
* change fabric cogservice token to support billing

* change mwc token

* rename

---------

Co-authored-by: cruise <[email protected]>
  • Loading branch information
lhrotk and mslhrotk authored Oct 10, 2024
1 parent 6854b5f commit 8e1c6f8
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, TokenLibrary}
import com.microsoft.azure.synapse.ml.fabric.FabricClient
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
Expand Down Expand Up @@ -330,7 +330,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) {
logInfo("Using Default AAD Token On Fabric")
Option(TokenLibrary.getAuthHeader)
Option(FabricClient.getCognitiveMWCTokenAuthHeader)
} else {
providedCustomAuthHeader
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting, OpenAITokenLibrary}
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services._
Expand Down Expand Up @@ -277,18 +277,6 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
}
}

trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput {
override protected def getCustomAuthHeader(row: Row): Option[String] = {
val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomHeader.isEmpty && PlatformDetails.runningOnFabric()) {
logInfo("Using Default OpenAI Token On Fabric")
Option(OpenAITokenLibrary.getAuthHeader)
} else {
providedCustomHeader
}
}
}

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String)
with HasOpenAISharedParams with OpenAIFabricSetting {
setDefault(timeout -> 360.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
Expand All @@ -20,7 +19,7 @@ import scala.language.existentials
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput
with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.util._
Expand All @@ -19,7 +19,7 @@ import scala.language.existentials
object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion]

class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasPromptInputs with HasOpenAICognitiveServiceInput
with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services.HasCognitiveServiceInput
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
Expand All @@ -22,7 +23,7 @@ import scala.language.existentials
object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]

class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAIEmbeddingParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
with HasOpenAIEmbeddingParams with HasCognitiveServiceInput with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services._
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.spark.Functions
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import com.microsoft.azure.synapse.ml.services._
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.Row.unapplySeq
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
Expand All @@ -28,7 +27,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with HasOpenAICognitiveServiceInput
with HasCognitiveServiceInput
with ComplexParamsWritable with SynapseMLLogging {

logClass(FeatureNames.AiServices.OpenAI)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ object FabricClient extends RESTUtils {

private def getHeaders: Map[String, String] = {
Map(
"Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}",
"Authorization" -> s"${getMLWorkloadAADAuthHeader}",
"RequestId" -> UUID.randomUUID().toString,
"Content-Type" -> "application/json",
"x-ms-workload-resource-moniker" -> UUID.randomUUID().toString
Expand All @@ -143,4 +143,10 @@ object FabricClient extends RESTUtils {
def usagePost(url: String, body: String): JsValue = {
usagePost(url, body, getHeaders);
}

def getMLWorkloadAADAuthHeader: String = TokenLibrary.getMLWorkloadAADAuthHeader

def getCognitiveMWCTokenAuthHeader: String = {
TokenLibrary.getCognitiveMwcTokenAuthHeader(WorkspaceID.getOrElse(""), ArtifactID.getOrElse(""))
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ package com.microsoft.azure.synapse.ml.fabric
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._

trait AuthHeaderProvider {
def getAuthHeader: String
}

object TokenLibrary extends AuthHeaderProvider {
object TokenLibrary {
def getAccessToken: String = {
val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary"
val mirror = currentMirror
Expand All @@ -27,9 +23,29 @@ object TokenLibrary extends AuthHeaderProvider {
}
}.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type $argType not found"))
val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod)
methodMirror("pbi").asInstanceOf[String]
methodMirror("ml").asInstanceOf[String]
}

def getSparkMwcToken(workspaceId: String, artifactId: String): String = {
val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary"
val mirror = currentMirror
val module = mirror.staticModule(objectName)
val obj = mirror.reflectModule(module).instance
val objType = mirror.reflect(obj).symbol.toType
val methodName = "getMwcToken"
val methodSymbols = objType.decl(TermName(methodName)).asTerm.alternatives
val argTypes = List(typeOf[String], typeOf[String], typeOf[Integer], typeOf[String])
val selectedMethodSymbol = methodSymbols.find { m =>
m.asMethod.paramLists.flatten.map(_.typeSignature).zip(argTypes).forall { case (a, b) => a =:= b }
}.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type not found"))
val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod)
methodMirror(workspaceId, artifactId, 2, "SparkCore")
.asInstanceOf[String]
}


def getMLWorkloadAADAuthHeader: String = "Bearer " + getAccessToken

def getAuthHeader: String = "Bearer " + getAccessToken
def getCognitiveMwcTokenAuthHeader(workspaceId: String, artifactId: String): String = "MwcToken " +
getSparkMwcToken(workspaceId, artifactId)
}

0 comments on commit 8e1c6f8

Please sign in to comment.