1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ from typing import Optional
17+
1618from twisted .internet import threads
1719
1820
@@ -61,6 +63,7 @@ class LDAPMode(object):
6163
6264class LdapAuthProvider (object ):
6365 __version__ = "0.1"
66+ _ldap_tls = ldap3 .Tls (validate = ssl .CERT_REQUIRED )
6467
6568 def __init__ (self , config , account_handler ):
6669 self .account_handler = account_handler
@@ -84,6 +87,9 @@ def __init__(self, config, account_handler):
8487 self .ldap_active_directory = config .active_directory
8588 if self .ldap_active_directory :
8689 self .ldap_default_domain = config .default_domain
90+ # Either: the Active Directory root domain (type str); empty string in case
91+ # of error; or None if there was no attempt to fetch root domain yet
92+ self .ldap_root_domain = None # type: Optional[str]
8793
8894 def get_supported_login_types (self ):
8995 return {'m.login.password' : ('password' ,)}
@@ -117,17 +123,14 @@ async def check_auth(self, username, login_type, login_dict):
117123
118124 if self .ldap_active_directory :
119125 try :
120- (login , domain , localpart ) = self ._map_login_to_upn (username )
126+ (login , domain , localpart ) = await self ._map_login_to_upn (username )
121127 uid_value = login + "@" + domain
122128 default_display_name = login
123129 except ActiveDirectoryUPNException :
124130 return False
125131
126132 try :
127- tls = ldap3 .Tls (validate = ssl .CERT_REQUIRED )
128- server = ldap3 .ServerPool (
129- [ldap3 .Server (uri , get_info = None , tls = tls ) for uri in self .ldap_uris ],
130- )
133+ server = self ._get_server ()
131134 logger .debug (
132135 "Attempting LDAP connection with %s" ,
133136 self .ldap_uris
@@ -254,9 +257,7 @@ async def check_3pid_auth(self, medium, address, password):
254257
255258 # Talk to LDAP and check if this email/password combo is correct
256259 try :
257- server = ldap3 .ServerPool (
258- [ldap3 .Server (uri , get_info = None ) for uri in self .ldap_uris ],
259- )
260+ server = self ._get_server ()
260261 logger .debug (
261262 "Attempting LDAP connection with %s" ,
262263 self .ldap_uris
@@ -404,6 +405,78 @@ class _LdapConfig(object):
404405
405406 return ldap_config
406407
408+ def _get_server (self , get_info : Optional [str ] = None ) -> ldap3 .ServerPool :
409+ """Constructs ServerPool from configured LDAP URIs
410+
411+ Args:
412+ get_info (str, optional): specifies if the server schema and server
413+ specific info must be read. Defaults to None.
414+
415+ Returns:
416+ Servers grouped in a ServerPool
417+ """
418+ return ldap3 .ServerPool (
419+ [
420+ ldap3 .Server (
421+ uri ,
422+ get_info = get_info ,
423+ tls = self ._ldap_tls
424+ )
425+ for uri in self .ldap_uris
426+ ],
427+ )
428+
429+ async def _fetch_root_domain (self ) -> str :
430+ """Fetches root domain from LDAP and saves it to ``self.ldap_root_domain``
431+
432+ Returns:
433+ The root domain of Active Directory forest
434+ """
435+ if self .ldap_root_domain is not None :
436+ return self .ldap_root_domain
437+
438+ self .ldap_root_domain = ""
439+
440+ if self .ldap_mode != LDAPMode .SEARCH :
441+ logger .info ("Fetching root domain is supported in search mode only" )
442+ return self .ldap_root_domain
443+
444+ server = self ._get_server (get_info = ldap3 .DSA )
445+ result , conn = await self ._ldap_simple_bind (
446+ server = server ,
447+ bind_dn = self .ldap_bind_dn ,
448+ password = self .ldap_bind_password ,
449+ )
450+
451+ if not result :
452+ logger .warning ("Unable to get root domain due to failed LDAP bind" )
453+ return self .ldap_root_domain
454+
455+ if (
456+ conn .server .info .other
457+ and conn .server .info .other .get ("rootDomainNamingContext" )
458+ ):
459+ # conn.server.info.other["rootDomainNamingContext"][0]
460+ # is of the form DC=example,DC=org
461+ self .ldap_root_domain = "." .join (
462+ [
463+ dc .split ("=" )[1 ] for dc
464+ in conn .server .info .other ["rootDomainNamingContext" ][0 ].split ("," )
465+ if "=" in dc
466+ ]
467+ )
468+ logger .info ('Obtained root domain "%s"' , self .ldap_root_domain )
469+
470+ if not self .ldap_root_domain :
471+ logger .warning (
472+ "No valid `rootDomainNamingContext` attribute was found in the RootDSE. "
473+ "Logging in using short domain name will be unavailable."
474+ )
475+
476+ await threads .deferToThread (conn .unbind )
477+
478+ return self .ldap_root_domain
479+
407480 async def _ldap_simple_bind (self , server , bind_dn , password ):
408481 """ Attempt a simple bind with the credentials
409482 given by the user against the LDAP server.
@@ -556,7 +629,7 @@ async def _ldap_authenticated_search(self, server, password, filters):
556629 logger .warning ("Error during LDAP authentication: %s" , e )
557630 raise
558631
559- def _map_login_to_upn (self , username ):
632+ async def _map_login_to_upn (self , username ):
560633 """Maps user provided login to Active Directory UPN and
561634 local part of Matrix ID.
562635
@@ -577,6 +650,9 @@ def _map_login_to_upn(self, username):
577650
578651 if '\\ ' in username :
579652 (domain , login ) = username .lower ().rsplit ('\\ ' , 1 )
653+ ldap_root_domain = await self ._fetch_root_domain ()
654+ if ldap_root_domain and not domain .endswith (ldap_root_domain ):
655+ domain += "." + ldap_root_domain
580656 elif "/" in username :
581657 (login , domain ) = username .lower ().rsplit ("/" , 1 )
582658 else :
0 commit comments