11"""
22Classes for matching variables and constraints
33"""
4+ from __future__ import annotations
45from itertools import chain
56import numpy as np
6- from typing import Optional , Sequence , Callable , Tuple , Union
7+ from collections .abc import Sequence , Callable
8+ from typing import Optional , Union
79from scipy .optimize import least_squares
810from itertools import repeat
911from at .lattice import Lattice , Refpts , bool_refpts
@@ -32,23 +34,23 @@ class Variable(object):
3234 :py:class:`Variable` initialisation
3335 name: Name of the Variable; Default: ``''``
3436 bounds: Lower and upper bounds of the variable value
35- *args: Positional arguments transmitted to ``setfun`` and
37+ fun_args: Positional arguments transmitted to ``setfun`` and
3638 ``getfun`` functions
3739
3840 Keyword Args:
39- **kwargs : Keyword arguments transmitted to ``setfun``and
41+ **fun_kwargs : Keyword arguments transmitted to ``setfun``and
4042 ``getfun`` functions
4143 """
4244 def __init__ (self , setfun : Callable , getfun : Callable ,
4345 name : str = '' ,
44- bounds : Tuple [float , float ] = (- np .inf , np .inf ),
45- * args , ** kwargs ):
46+ bounds : tuple [float , float ] = (- np .inf , np .inf ),
47+ fun_args : tuple = () , ** fun_kwargs ):
4648 self .setfun = setfun
4749 self .getfun = getfun
4850 self .name = name
4951 self .bounds = bounds
50- self .args = args
51- self .kwargs = kwargs
52+ self .args = fun_args
53+ self .kwargs = fun_kwargs
5254 super (Variable , self ).__init__ ()
5355
5456 def set (self , ring : Lattice , value ):
@@ -90,7 +92,7 @@ class ElementVariable(Variable):
9092 def __init__ (self , refpts : Refpts , attname : str ,
9193 index : Optional [int ] = None ,
9294 name : str = '' ,
93- bounds : Tuple [float , float ] = (- np .inf , np .inf )):
95+ bounds : tuple [float , float ] = (- np .inf , np .inf )):
9496 setf , getf = self ._access (index )
9597
9698 def setfun (ring , value ):
@@ -357,7 +359,8 @@ class LinoptConstraints(ElementConstraints):
357359 transfer matrix. Can be :py:obj:`~.linear.linopt2`,
358360 :py:obj:`~.linear.linopt4`, :py:obj:`~.linear.linopt6`
359361
360- * :py:obj:`~.linear.linopt2`: No longitudinal motion, no H/V coupling,
362+ * :py:obj:`~.linear.linopt2`: No longitudinal motion,
363+ no H/V coupling,
361364 * :py:obj:`~.linear.linopt4`: No longitudinal motion, Sagan/Rubin
362365 4D-analysis of coupled motion,
363366 * :py:obj:`~.linear.linopt6` (default): With or without longitudinal
@@ -369,7 +372,8 @@ class LinoptConstraints(ElementConstraints):
369372
370373 Add a beta x (beta[0]) constraint at location ref_inj:
371374
372- >>> cnstrs.add('beta', 18.0, refpts=ref_inj, name='beta_x_inj', index=0)
375+ >>> cnstrs.add('beta', 18.0, refpts=ref_inj,
376+ name='beta_x_inj', index=0)
373377
374378 Add an horizontal tune (tunes[0]) constraint:
375379
@@ -435,6 +439,7 @@ def add(self, param, target, refpts: Optional[Refpts] = None,
435439 getf = self ._recordaccess (index )
436440 getv = self ._arrayaccess (index )
437441 use_integer = kwargs .pop ('UseInteger' , False )
442+ norm_mu = {'mu' : 1 , 'mun' : 2 * np .pi }
438443
439444 if name is None : # Generate the constraint name
440445 name = param .__name__ if callable (param ) else param
@@ -444,7 +449,6 @@ def add(self, param, target, refpts: Optional[Refpts] = None,
444449 if callable (param ):
445450 def fun (refdata , tune , chrom ):
446451 return getv (param (refdata , tune , chrom ))
447- # self.refpts[:] = True # necessary not to miss 2*pi jumps
448452 self .get_chrom = True # fun may use dispersion or chroma
449453 elif param == 'tunes' :
450454 # noinspection PyUnusedLocal
@@ -457,14 +461,20 @@ def fun(refdata, tune, chrom):
457461 return getv (chrom )
458462 refpts = []
459463 self .get_chrom = True # slower but necessary
464+ elif param == 'mu' or param == 'mun' :
465+ # noinspection PyUnusedLocal
466+ def fun (refdata , tune , chrom ):
467+ if use_integer :
468+ return getf (refdata , 'mu' ) / norm_mu [param ]
469+ else :
470+ return (getf (refdata , 'mu' ) % (2 * np .pi )) / norm_mu [param ]
471+ if use_integer :
472+ self .refpts [:] = True # necessary not to miss 2*pi jumps
473+ else :
474+ target = target % (2 * np .pi / norm_mu [param ])
460475 else :
461476 # noinspection PyUnusedLocal
462477 def fun (refdata , tune , chrom ):
463- if param == 'mu' :
464- return getf (refdata , param ) % (2 * np .pi )
465- elif param == 'mu' and use_integer :
466- # necessary not to miss 2*pi jumps
467- self .refpts [:] = True
468478 return getf (refdata , param )
469479
470480 super (LinoptConstraints , self ).add (fun , target , refpts , name = name ,
0 commit comments