Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Optional unit rewriting #1252

Closed
wants to merge 11 commits into from
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ include("systems/pde/pdesystem.jl")

include("systems/discrete_system/discrete_system.jl")
include("systems/validation.jl")
include("systems/unitconversion.jl")
include("systems/dependency_graphs.jl")
include("systems/systemstructure.jl")
using .SystemStructures
Expand Down Expand Up @@ -174,7 +175,7 @@ export Equation, ConstrainedEquation
export Term, Sym
export SymScope, LocalScope, ParentScope, GlobalScope
export independent_variables, independent_variable, states, parameters, equations, controls, observed, structure
export structural_simplify
export structural_simplify, rewrite_units
export DiscreteSystem, DiscreteProblem

export calculate_jacobian, generate_jacobian, generate_function
Expand Down
180 changes: 180 additions & 0 deletions src/systems/unitconversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"Wrapper for Unitful.convfact that returns a Constant & throws ValidationError instead of DimensionError."
function unitfactor(u, t)
try
cf = Unitful.convfact(u, t)
return cf == 1 ? 1 : Constant(cf*u/t)
catch err
throw(ValidationError("Unable to convert [$t] to [$u]"))
end
end

"Turn an expression into a Julia function w/ correct units behavior." # mostly for testing
function functionize(pt)
syms = Symbolics.get_variables(pt)
eval(build_function(constructunit(pt), syms, expression = Val{false}))
end

"Represent a constant as a Symbolic (esp. for lifting units to metadata level)."
struct Constant{T, M} <: SymbolicUtils.Symbolic{T}
val::T
metadata::M
end

Constant(x) = Constant(x, Dict(VariableUnit => Unitful.unit(x)))
Base.:*(x::Num, y::Unitful.Quantity) = value(x) * y
Base.:*(x::Unitful.Quantity, y::Num) = x * value(y)
Base.show(io::IO, v::Constant) = Base.show(io, v.val)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably show unit if it's there. Good way to know it's not just a number.


Unitless = Union{typeof.([exp, log, sinh, asinh, asin,
cosh, acosh, acos,
tanh, atanh, atan,
coth, acoth, acot,
sech, asech, asec,
csch, acsch, acsc])...}
isunitless(f::Unitless) = true

#Should run this at the end of @variables and @parameters
set_unitless(x::Vector) = [_has_unit(y) ? y : SymbolicUtils.setmetadata(y,VariableUnit,unitless) for y in x]

"Convert symbolic expression `x` to have units `u` if possible."
function unitcoerce(u::Unitful.Unitlike, x::Symbolic)
st = _has_unit(x) ? x : constructunit(x)
tu = _get_unit(st)
output = unitfactor(u, tu) * st
return SymbolicUtils.setmetadata(output, VariableUnit, u)
end

"Convert a set of expressions to a common unit, defined by the first dimensional quantity encountered."
function uniformize(subterms)
newterms = Vector{Any}(undef, size(subterms))
firstunit = nothing
for (idx, st) in enumerate(subterms)
if !isequal(st, 0)
st = constructunit(st)
tu = _get_unit(st)
if firstunit === nothing
firstunit = tu
end
newterms[idx] = unitfactor(firstunit, tu) * st
else
newterms[idx] = 0
end
end
return newterms
end

constructunit(x::Num) = constructunit(value(x))
function constructunit(x::Unitful.Quantity)
return Constant(x.val, Dict(VariableUnit => Unitful.unit(x)))
end

function constructunit(x) #This is where it all starts
maybeunit = safe_get_unit(x,"")
if maybeunit !== nothing
return SymbolicUtils.setmetadata(x, VariableUnit, maybeunit)
else # Something needs to be rewritten
op = operation(x)
args = arguments(x)
constructunit(op, args)
end
end

function constructunit(op, args) # Fallback
if isunitless(op)
try
args = unitcoerce.(unitless, args)
return SymbolicUtils.setmetadata(op(args...), VariableUnit, unitless)
catch err
if err isa Unitful.DimensionError
argunits = get_unit.(args)
throw(ValidationError("Unable to coerce $args to dimensionless from $argunits for function $op."))
else
rethrow(err)
end
end
else
throw(ValidationError("Unknown function $op supplied with $args with units $argunits"))
end
end

function constructunit(op::typeof(+), subterms)
newterms = uniformize(subterms)
output = +(newterms...)
return SymbolicUtils.setmetadata(output, VariableUnit, _get_unit(newterms[1]))
end

function constructunit(op::Conditional, subterms)
newterms = Vector{Any}(undef, 3)
firstunit = nothing
newterms[1] = constructunit(subterms[1])
newterms[2:3] = uniformize(subterms[2:3])
output = op(newterms...)
return SymbolicUtils.setmetadata(output, VariableUnit, _get_unit(newterms[2]))
end

function constructunit(op::Union{Differential,Difference}, subterms)
numerator = constructunit(only(subterms))
nu = _get_unit(numerator)
denominator = op isa Differential ? constructunit(op.x) : constructunit(op.t) #TODO: make consistent!
du = _get_unit(denominator)
output = op isa Differential ? Differential(denominator)(numerator) : Difference(denominator)(numerator)
return SymbolicUtils.setmetadata(output, VariableUnit, nu/du)
end

function constructunit(op::typeof(^), subterms)
base, exponent = subterms
base = constructunit(base)
bu = _get_unit(base)
exponent = constructunit(exponent)
exponent = unitfactor(unitless, _get_unit(exponent)) * exponent
output = base^exponent
output_unit = bu == unitless ? unitless : (exponent isa Real ? bu^exponent : (1*bu)^exponent)
return SymbolicUtils.setmetadata(output, VariableUnit, output_unit)
end

Root = Union{typeof(sqrt),typeof(cbrt)}
function constructunit(op::Root,args)
arg = constructunit(only(args))
argunit = _get_unit(arg)
return SymbolicUtils.setmetadata(op(arg), VariableUnit, op(argunit))
end

function constructunit(op::Comparison, subterms)
newterms = uniformize(subterms)
output = op(newterms...)
return SymbolicUtils.setmetadata(output, VariableUnit, unitless)
end

function constructunit(op::typeof(*), subterms)
newterms = Vector{Any}(undef, size(subterms))
pu = unitless
for (idx, st) in enumerate(subterms)
st = constructunit(st)
pu *= _get_unit(st)
newterms[idx] = st
end
output = op(newterms...)
return SymbolicUtils.setmetadata(output, VariableUnit, pu)
end

function constructunit(eq::ModelingToolkit.Equation)
newterms = uniformize([eq.lhs, eq.rhs])
return ~(newterms...)
#return SymbolicUtils.setmetadata(output,VariableUnit,firstunit) #Fix this once Symbolics.jl Equations accept units
end

"Rewrite a set of equations by inserting appropriate unit conversion factors."
function rewrite_units(eqs::Vector{Equation}; debug = false)
output = similar(eqs)
allgood = true
for (idx, eq) in enumerate(eqs)
try
output[idx] = constructunit(eq)
catch err
allgood = false
err isa ValidationError && !debug ? @warn("in eq [$idx], "*err.message) : rethrow(err)
end
end
allgood || throw(ValidationError("Some equations had invalid units. See warnings for details."))
return output
end
19 changes: 16 additions & 3 deletions src/systems/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ end

"Throw exception on invalid unit types, otherwise return argument."
function screen_unit(result)
result isa Symbolic && return result #For cases like P^γ where base is unitful, exponent is symbolic but dimensionless
result isa Unitful.Unitlike || throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result))."))
result isa Unitful.ScalarUnits || throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead."))
result == u"°" && throw(ValidationError("Degrees are not supported. Use radians instead."))
Expand All @@ -33,6 +34,16 @@ Literal = Union{Sym,Symbolics.ArrayOp,Symbolics.Arr,Symbolics.CallWithMetadata}
Conditional = Union{typeof(ifelse),typeof(IfElse.ifelse)}
Comparison = Union{typeof.([==, !=, ≠, <, <=, ≤, >, >=, ≥])...}

#Underscore methods are 'dumb': they only look at the outermost object to see if it has units, they don't traverse the expression tree.
#_has_unit(x::Equation) = getmetadata(x,VariableUnit) Doesn't work yet, equations don't have metadata.
_has_unit(x::Real) = true
_has_unit(x::Num) = _has_unit(value(x))
_has_unit(x::Symbolic) = hasmetadata(x,VariableUnit)

_get_unit(x::Real) = unitless
_get_unit(x::Num) = _get_unit(value(x))
_get_unit(x::Symbolic) = screen_unit(getmetadata(x,VariableUnit,unitless))

"Find the unit of a symbolic item."
get_unit(x::Real) = unitless
get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x))
Expand All @@ -42,6 +53,7 @@ get_unit(x::Literal) = screen_unit(getmetadata(x,VariableUnit, unitless))
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t)
get_unit(op::typeof(getindex),args) = get_unit(args[1])

function get_unit(op,args) # Fallback
result = op(1 .* get_unit.(args)...)
try
Expand Down Expand Up @@ -86,8 +98,8 @@ end

function get_unit(op::Conditional, args)
terms = get_unit.(args)
terms[1] == unitless || throw(ValidationError(", in $x, [$(terms[1])] is not dimensionless."))
equivalent(terms[2], terms[3]) || throw(ValidationError(", in $x, units [$(terms[2])] and [$(terms[3])] do not match."))
terms[1] == unitless || throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless."))
equivalent(terms[2], terms[3]) || throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match."))
return terms[2]
end

Expand All @@ -106,6 +118,7 @@ function get_unit(op::Comparison, args)
end

function get_unit(x::Symbolic)
_has_unit(x) && return _get_unit(x) #Easy out, if the tree has already been traversed by constructunit
if SymbolicUtils.istree(x)
op = operation(x)
if op isa Sym || (op isa Term && operation(op) isa Term) # Dependent variables, not function calls
Expand All @@ -116,7 +129,7 @@ function get_unit(x::Symbolic)
end # Actual function calls:
args = arguments(x)
return get_unit(op, args)
else # This function should only be reached by Terms, for which `istree` is true
else # This method should only be reached by Terms, for which `istree` is true, so this branch should never happen:
throw(ArgumentError("Unsupported value $x."))
end
end
Expand Down
20 changes: 20 additions & 0 deletions test/units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,23 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
maj2 = MassActionJump(γ, [S => 1], [S => -1])
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])

# Rewriting
@variables t [unit = u"ms"] P(t) [unit = u"MW"] E(t) [unit = u"J"]
@parameters τ [unit = u"ms"] γ
D = Differential(t)
eqs = [D(E) ~ P - E/τ]
@test_throws MT.ValidationError MT.get_unit(eqs[1].rhs)
neweqs = MT.rewrite_units(eqs)
@named sys = ODESystem(neweqs)
equations(sys)

@test MT.get_unit(t/τ) == MT._get_unit(MT.constructunit(t/τ))
@test MT.get_unit(2^(t/τ)) == MT._get_unit(MT.constructunit(2^(t/τ)))
@test MT.equivalent(MT.get_unit(t^γ), MT._get_unit(MT.constructunit(t^γ)))
@test MT.get_unit(sin(γ)) == MT._get_unit(MT.sin(γ))
@test MT.get_unit(sqrt(E)) == MT._get_unit(MT.constructunit(sqrt(E)))
@test MT.get_unit(exp(γ)) == MT._get_unit(MT.exp(γ))

@variables E(t) [unit = u"kJ"]
@test MT.get_unit(IfElse.ifelse(t<τ,E/τ,P)) == MT._get_unit(MT.constructunit(IfElse.ifelse(t<τ,E/τ,P)))
@test_throws MT.ValidationError MT.constructunit(E+τ)