Skip to content

Commit

Permalink
[Rust] Refactor union case patterns (fable-compiler#3932)
Browse files Browse the repository at this point in the history
* [Rust] Refactor union case patterns

* [Rust] Match union case fields
  • Loading branch information
ncave authored Oct 20, 2024
1 parent 37fa7b2 commit 33b501f
Showing 1 changed file with 66 additions and 87 deletions.
153 changes: 66 additions & 87 deletions src/Fable.Transforms/Rust/Fable2Rust.fs
Original file line number Diff line number Diff line change
Expand Up @@ -1603,12 +1603,14 @@ module Util =
transformLeaveContext com ctx argType arg
)

let prepareRefForPatternMatch (com: IRustCompiler) ctx typ (name: string option) fableExpr =
let makeRefForPatternMatch (com: IRustCompiler) ctx typ (nameOpt: string option) fableExpr =
let expr = com.TransformExpr(ctx, fableExpr)

if isThisArgumentIdentExpr ctx fableExpr then
expr
elif (name.IsSome && isRefScoped ctx name.Value) || (isInRefType com typ) then
elif isInRefType com typ then
expr
elif nameOpt.IsSome && isRefScoped ctx nameOpt.Value then
expr
elif shouldBeRefCountWrapped com ctx typ |> Option.isSome then
expr |> makeAsRef
Expand Down Expand Up @@ -1856,13 +1858,13 @@ module Util =
entName + "::" + unionCase.Name
)

let getUnionCaseFields com ctx name tag (unionCase: Fable.UnionCase) =
unionCase.UnionCaseFields
|> List.mapi (fun i field ->
let fieldName = $"{name}_{tag}_{i}"
let fieldType = FableTransforms.uncurryType field.FieldType
makeTypedIdent fieldType fieldName
)
// let getUnionCaseFields com ctx name caseIndex (unionCase: Fable.UnionCase) =
// unionCase.UnionCaseFields
// |> List.mapi (fun i _field ->
// let fieldName = $"{name}_{caseIndex}_{i}"
// let fieldType = FableTransforms.uncurryType field.FieldType
// makeTypedIdent fieldType fieldName
// )

let makeUnion (com: IRustCompiler) ctx r values tag entRef genArgs =
let ent = com.GetEntity(entRef)
Expand Down Expand Up @@ -1991,7 +1993,7 @@ module Util =

let sourceIsRef =
match e with
| Fable.Get(Fable.IdentExpr ident, _, _, _)
| Fable.Get(Fable.IdentExpr ident, _, _, _) -> isArmScoped ctx ident.Name
| MaybeCasted(Fable.IdentExpr ident) -> isRefScoped ctx ident.Name
| _ -> false

Expand All @@ -2002,7 +2004,6 @@ module Util =
let mustClone =
match e with
| MaybeCasted(Fable.IdentExpr ident) ->
// isArmScoped ctx ident.Name ||
// clone non-mutable idents if used more than once
not (ident.IsMutable) && not (isUsedOnce ctx ident.Name) //&& not (isByRefType com ident.Type)
| Fable.Get(_, Fable.FieldGet _, _, _) -> true // always clone field get exprs
Expand Down Expand Up @@ -2558,7 +2559,7 @@ module Util =

let unionCaseName = getUnionCaseName com ctx info.Entity unionCase
let pat = makeUnionCasePat unionCaseName fields
let expr = fableExpr |> prepareRefForPatternMatch com ctx fableExpr.Type None
let expr = makeRefForPatternMatch com ctx fableExpr.Type None fableExpr
let thenExpr = mkGenericPathExpr [ fieldName ] None |> makeClone

let arms = [ mkArm [] pat None thenExpr ]
Expand Down Expand Up @@ -2774,7 +2775,7 @@ module Util =
| Fable.Test(Fable.IdentExpr ident, Fable.UnionCaseTest _, _) ->
// add scoped ident to ctx for thenBody
let usages = calcIdentUsages [ ident ] [ thenBody ]
getScopedIdentCtx com ctx ident true true false false usages
getScopedIdentCtx com ctx ident true false false false usages
| _ -> ctx

transformLeaveContext com ctx None thenBody
Expand Down Expand Up @@ -2888,32 +2889,60 @@ module Util =
mkLetExpr pat downcastExpr
| _ -> makeLibCall com ctx genArgsOpt "Native" "type_test" [ expr ]

let makeUnionCaseTest (com: IRustCompiler) ctx range tag (fableExpr: Fable.Expr) =
match fableExpr.Type with
| Fable.DeclaredType(entRef, genArgs) ->
let ent = com.GetEntity(entRef)
assert (ent.IsFSharpUnion)
// let genArgsOpt = transformGenArgs com ctx genArgs // TODO:
let unionCase = ent.UnionCases |> List.item tag
let makeUnionCasePatOpt (com: IRustCompiler) ctx typ nameOpt caseIndex =
match typ with
| Fable.Option(genArg, _) ->
// let genArgsOpt = transformGenArgs com ctx [genArg]
let unionCaseFullName = [ "Some"; "None" ] |> List.item caseIndex |> rawIdent

let fields =
match fableExpr with
| Fable.IdentExpr ident ->
let fieldIdents = getUnionCaseFields com ctx ident.Name tag unionCase
fieldIdents |> List.map (fun fi -> makeFullNameIdentPat fi.Name)
| _ ->
if List.isEmpty unionCase.UnionCaseFields then
[]
else
[ WILD_PAT ]
match caseIndex with
| 0 ->
match nameOpt with
| Some identName ->
let fieldName = $"{identName}_{caseIndex}_{0}"
[ makeFullNameIdentPat fieldName ]
| _ -> [ WILD_PAT ]
| _ -> []

let unionCaseName =
tryUseKnownUnionCaseNames unionCaseFullName
|> Option.defaultValue unionCaseFullName

let unionCaseName = getUnionCaseName com ctx entRef unionCase
let pat = makeUnionCasePat unionCaseName fields
Some(pat)
| Fable.DeclaredType(entRef, genArgs) ->
let ent = com.GetEntity(entRef)

let expr =
fableExpr
|> prepareRefForPatternMatch com ctx fableExpr.Type (tryGetIdentName fableExpr)
if ent.IsFSharpUnion then
// let genArgsOpt = transformGenArgs com ctx genArgs // TODO:
let unionCase = ent.UnionCases |> List.item caseIndex

let fields =
match nameOpt with
| Some identName ->
unionCase.UnionCaseFields
|> List.mapi (fun i _field ->
let fieldName = $"{identName}_{caseIndex}_{i}"
makeFullNameIdentPat fieldName
)
| _ -> unionCase.UnionCaseFields |> List.map (fun _field -> WILD_PAT)

let unionCaseName = getUnionCaseName com ctx entRef unionCase
let pat = makeUnionCasePat unionCaseName fields
Some(pat)
else
None
| _ -> None

let makeUnionCaseTest (com: IRustCompiler) ctx range tag (fableExpr: Fable.Expr) =
let typ = fableExpr.Type
let nameOpt = tryGetIdentName fableExpr
let patOpt = makeUnionCasePatOpt com ctx typ nameOpt tag

match patOpt with
| Some pat ->
let expr = makeRefForPatternMatch com ctx typ nameOpt fableExpr
let letExpr = mkLetExpr pat expr
letExpr
| _ -> failwith "unreachable"
Expand Down Expand Up @@ -2997,55 +3026,6 @@ module Util =

mkArm attrs pat guard body

let makeUnionCasePatOpt evalType evalName caseIndex =
match evalType with
| Fable.Option(genArg, _) ->
// let genArgsOpt = transformGenArgs com ctx [genArg]
let unionCaseFullName = [ "Some"; "None" ] |> List.item caseIndex |> rawIdent

let fields =
match evalName with
| Some idName ->
match caseIndex with
| 0 ->
let fieldName = $"{idName}_{caseIndex}_{0}"
[ makeFullNameIdentPat fieldName ]
| _ -> []
| _ -> [ WILD_PAT ]

let unionCaseName =
tryUseKnownUnionCaseNames unionCaseFullName
|> Option.defaultValue unionCaseFullName

Some(makeUnionCasePat unionCaseName fields)
| Fable.DeclaredType(entRef, genArgs) ->
let ent = com.GetEntity(entRef)

if ent.IsFSharpUnion then
// let genArgsOpt = transformGenArgs com ctx genArgs
let unionCase = ent.UnionCases |> List.item caseIndex

let fields =
match evalName with
| Some idName ->
unionCase.UnionCaseFields
|> List.mapi (fun i _field ->
let fieldName = $"{idName}_{caseIndex}_{i}"
makeFullNameIdentPat fieldName
)
| _ ->
if List.isEmpty unionCase.UnionCaseFields then
[]
else
[ WILD_PAT ]

let unionCaseName = getUnionCaseName com ctx entRef unionCase

Some(makeUnionCasePat unionCaseName fields)
else
None
| _ -> None

let evalType, evalName =
match evalExpr with
| Fable.Get(Fable.IdentExpr ident, Fable.UnionTag, _, _) -> ident.Type, Some ident.Name
Expand All @@ -3057,7 +3037,7 @@ module Util =
let patOpt =
match caseExpr with
| Fable.Value(Fable.NumberConstant(Fable.NumberValue.Int32 tag, Fable.NumberInfo.Empty), r) ->
makeUnionCasePatOpt evalType evalName tag
makeUnionCasePatOpt com ctx evalType evalName tag
| _ -> None

let pat =
Expand All @@ -3082,11 +3062,11 @@ module Util =
| Fable.Get(Fable.IdentExpr ident, Fable.OptionValue, _, _) when
Some ident.Name = evalName && ident.Type = evalType
->
makeUnionCasePatOpt evalType evalName 0
makeUnionCasePatOpt com ctx evalType evalName 0
| Fable.Get(Fable.IdentExpr ident, Fable.UnionField info, _, _) when
Some ident.Name = evalName && ident.Type = evalType
->
makeUnionCasePatOpt evalType evalName info.CaseIndex
makeUnionCasePatOpt com ctx evalType evalName info.CaseIndex
| _ ->
//need to recurse or this only works for trivial expressions
let subExprs = getSubExpressions expr
Expand All @@ -3098,8 +3078,7 @@ module Util =
let extraVals = namesForIndex evalType evalName targetIndex
makeArm pat targetIndex boundValues extraVals

let expr = evalExpr |> prepareRefForPatternMatch com ctx evalType evalName

let expr = makeRefForPatternMatch com ctx evalType evalName evalExpr
mkMatchExpr expr (arms @ [ defaultArm ])

let matchTargetIdentAndValues idents values =
Expand Down

0 comments on commit 33b501f

Please sign in to comment.