import random
from itertools import chain
from sympy import sympify, Expr, Function, sinh, cosh, tanh, exp, log, Derivative, symbols, simplify, Eq, Symbol
from sympy.core.relational import Relational
from sympy.parsing.latex import parse_latex
from sympy import sin, cos, tan, cot, sinh, cosh, tanh, coth, exp, log
from multipledispatch import dispatch
from sympy import GreaterThan, Basic
from sympy.printing.latex import latex
[docs]def is_substricted(symb, tag=None):
"""
Check if symbol is substricted with tag.
:param symb: Union[datamodel.Parameter, datamodel.Phase]
:param tag: Union[string, None]
:return: bool
"""
# simple heuristics
if not tag:
return True if "_{" in symb.__str__() else False
else:
ret = True if f"_{{{tag}}}" in symb.__str__() else False
return ret
[docs]def remove_subscript(symb):
"""
Firstly check if symbol is substricted, then if true remove !ANY! substricted.
:param symb: Union[datamodel.Parameter, datamodel.Phase]
:return: Union[datamodel.Parameter, datamodel.Phase] -- unsubscripted.
"""
if is_substricted(symb):
if symb.args:
# case if function
return symbols(symb.name.split('_')[0], cls=Function)(*symb.args)
return symbols(symb.name.split('_')[0], cls=Symbol)
return symb
[docs]def add_subscript(symb, tag):
"""
Add subscript `tag` to symbol
:param symb: Union[datamodel.Parameter, datamodel.Phase] - untagged
:param tag: str
:return: Union[datamodel.Parameter, datamodel.Phase] -- tagged
"""
if not is_substricted(symb):
if symb.args:
# case if function
return symbols(f'{symb.name}_{{{tag}}}', cls=Function)(*symb.args)
return symbols(f'{symb.name}_{{{tag}}}', cls=Symbol)
return symb
[docs]def latexify(exprs: list, to_str=False):
"""
Provide latex codesnippets to expressions. Can be returned as joined string or as List[str]
:param exprs: List[Expr]
:param to_str: bool
:return: Union[List[Expr], str]
"""
ret = [latex(e) for e in exprs]
if not to_str:
return ret
return ',~'.join(ret)
[docs]def KKT_mask(dual: dict):
"""
Returns Dual Feasibility and Complementary Slackness for dual dict
:param dual: Dict[Union[datamodel.Parameter, datamodel.Phase] --> Expr]
:return: List[Expr] = [K*V]
"""
return [*chain(*[(GreaterThan(v, 0), Eq(k * v, 0)) for k, v in dual.items()])]
[docs]def euler_mask(L, x, t):
"""
Provide Euler-Lagrange equations for passed Lagrangian and variables [phase, time]
:param L: Expr -- Lagrangian
:param x: datamodel.Phase -- Phase variables
:param t: datamodel.Parameter -- Time variable
:return: List[Expr]
"""
if x.args and t in x.args:
x = x.func(*x.args)
else:
raise TypeError('Wrong time :)')
x_prime = x.diff(t)
L_x_prime = L.diff(x_prime)
L_x = L.diff(x)
return Eq(simplify(Derivative(L_x_prime, t) - L_x), 0)
[docs]def transversality_mask(L, x, t, l, t0, t1):
"""
Transveraslity condition for agent optimal control problem
:param L: Expr -- Lagrangian
:param x: datamodel.Phase -- phase variable
:param t: datamodel.Parameter -- time variable
:param l: Expr -- termination lagrangian
:param t0: datamodel.Parameter -- time horizon[0]
:param t1: datamodel.Parameter -- time horizon[1]
:return: List[Eq]
"""
if x.args and t in x.args:
x = x.func(*x.args)
else:
raise TypeError('Wrong time :)')
# lhs
x_prime = x.diff(t)
L_x_prime = L.diff(x_prime)
# rhs
l_x_t_0 = l.diff(x.subs({t: t0}))
l_x_t_1 = -l.diff(x.subs({t: t1}))
return Eq(L_x_prime.simplify().subs({t: t0}), l_x_t_0), Eq(L_x_prime.simplify().subs({t: t1}), l_x_t_1)
[docs]def generate_symbols(tag, count, cls):
"""
Generate symbols function. Used for dual variables and parameters generation in Agent.process
:param tag: str -- Alias for generated symbols
:param count: int -- count of symbols
:param cls: Class[Union[datamodel.Parameter, datamodel.Phase]]
:return: Iter[Symbol]
"""
query = [tag + "_" + str(i) + " " for i in range(count)]
query = ''.join(query)[:-1]
return symbols(query, cls=cls) if count != 1 else [symbols(query, cls=cls)]
[docs]def spec_funcs():
"""
Set of special analytical functions in sympy. Used to dimension check of its args.
:return: set
"""
return {sin, cos, tan, cot, sinh, cosh, tanh, coth, exp, log}
[docs]def is_spec_function(func):
"""
Bool version of spec_funcs. Return true if func is listed in spec_funcs, otherwise else.
:param func: function to be checked
:return: bool
"""
spec = {sin, cos, tan, cot, sinh, cosh, tanh, coth, exp, log}
return func.__class__ in spec
[docs]def eq2func(eq):
"""
Return Expr:= lhs - rhs from Eq
:param eq: Eq to be transformed
:return: Expr
"""
return eq.args[0] - eq.args[1]
[docs]def deriv_degree(bc):
"""
Returns maximum degree of differential operator inside `bc` expression.
:param bc: Expr
:return: int
"""
deg = 0
if eq2func(bc).class_key()[-1] == 'Derivative': ## catch derivative of x eqs 0
deg_ = sum(i[-1] for i in eq2func(bc).variable_count)
return deg_
for c in eq2func(bc).args:
# print(c.class_key()[-1])
deg_ = 0
if c.class_key()[-1] == 'Derivative':
deg_ = sum([i[-1] for i in c.variable_count])
if c.class_key()[-1] == 'Mul':
for m in c.args:
if m.class_key()[-1] == 'Derivative':
deg_ = sum(i[-1] for i in m.variable_count)
if deg_ > deg:
deg = deg_
return deg
#@dispatch(dict)
[docs]def span_dict(d: dict):
"""
Return linear span of KV-storage. ret = sum_0^N[K[i] * V[i]]
:param d: Dict[Expr -> Expr]
:return: Expr
"""
from sympy.core.numbers import Zero
ret = Zero()
for k, v in d.items():
ret += k * v
return ret
#@dispatch(set, set)
#@dispatch(list, list)
[docs]def span(coefs, variables):
"""
Return linear span of KV-storage. ret = sum_0^N[L_1[i] * L_2[i]]
:param coefs: List[Expr]
:param variables: List[Expr]
:return: Expr
"""
if len(coefs) != len(variables):
raise TypeError(f'Not equal number of arguments: {len(coefs)} != {len(variables)}')
ret = 0
for i, j in zip(coefs, variables):
ret += i * j
return ret
[docs]def gradient(expr, varss):
if issubclass(type(expr), Relational):
expr = eq2func(expr)
return [expr.diff(i) for i in varss]
[docs]def pi_theorem(vars, eq):
"""
Deprecated method for dimension checking in equations.
:param vars:
:param eq:
:return:
"""
ret = True # bool return
def _dim_subs(vars, eq):
for k, v in vars.items():
eq_ = eq.replace(sympify(k), sympify(v))
if 0 not in eq_.args:
eq = eq_
else:
coef = random.random()
eq = eq.replace(sympify(k), sympify(v + '*' + str(coef)))
# print(eq)
# eq = eq.subs(vars).simplify()
if eq.class_key()[-1] == 'Equality':
lhs = eq.args[0]
rhs = eq.args[1]
if (lhs * rhs ** (-1)).class_key()[-1] == 'Number':
return True
else:
return False
else:
if eq.simplify().class_key()[-1] == 'Number':
return True
else:
return False
# var correctance
vars_ = {k: Expr(v) for k, v in vars.items()}
eq = parse_latex(eq)
expr = eq2func(eq)
# parse functions
funcs = [*expr.atoms(Function)]
# check correctance
actual_vars = [str(v) for v in expr.free_symbols]
actual_vars.extend([f.class_key()[-1] for f in funcs])
# if set(actual_vars) != set(vars.keys()):
# raise Warning('check var_list isnt full', set(actual_vars)- set(vars.keys()))
# parse special functions
special_tags = [sinh, cosh, tanh, exp, log]
occ_funcs = []
occurencies = []
for tag in special_tags:
occ_funcs.extend([*expr.find(tag)])
occurencies.extend([e.args[0] for e in expr.find(tag)])
for expr_ in occurencies:
ret = _dim_subs(vars, expr_)
if ret is False:
return False
# swap special functions
for of in occ_funcs:
eq = eq.replace(of, 1)
# parse derivatives
derivatives = [*expr.find(Derivative)]
# swap derivas
for d in derivatives:
var, order = d.args
buf = var / order[0] ** order[1]
eq = eq.replace(d, buf)
# swap functions
for f in funcs:
eq = eq.replace(f, symbols(f.class_key()[-1]))
ret = _dim_subs(vars, eq)
if ret is False:
return False
return ret