Skip to content

Commit a3c7a9f

Browse files
authored
Active Directory: fetch root domain of forest from LDAP (#105)
This change allows to use short domain name in login field. As part of this PR ssl certificate validation in check_3pid_auth is fixed also. Signed-off-by: Yuri Konotopov <[email protected]>
1 parent aa57347 commit a3c7a9f

File tree

4 files changed

+139
-18
lines changed

4 files changed

+139
-18
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ Let's say you have several domains in the ``example.com`` forest:
116116
bind_dn: "cn=hacker,ou=svcaccts,dc=example,dc=com"
117117
bind_password: "ch33kym0nk3y"
118118
119-
With this configuration the user can log in with either ``main.example.com\someuser``,
120-
``someuser/main.example.com`` or ``someuser``.
119+
With this configuration the user can log in with either ``main\someuser``,
120+
``main.example.com\someuser``, ``someuser/main.example.com`` or ``someuser``.
121121

122122
Users of other domains in the ``example.com`` forest can log in with ``domain\login``
123123
or ``login/domain``.

ldap_auth_provider.py

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from typing import Optional
17+
1618
from twisted.internet import threads
1719

1820

@@ -61,6 +63,7 @@ class LDAPMode(object):
6163

6264
class LdapAuthProvider(object):
6365
__version__ = "0.1"
66+
_ldap_tls = ldap3.Tls(validate=ssl.CERT_REQUIRED)
6467

6568
def __init__(self, config, account_handler):
6669
self.account_handler = account_handler
@@ -84,6 +87,9 @@ def __init__(self, config, account_handler):
8487
self.ldap_active_directory = config.active_directory
8588
if self.ldap_active_directory:
8689
self.ldap_default_domain = config.default_domain
90+
# Either: the Active Directory root domain (type str); empty string in case
91+
# of error; or None if there was no attempt to fetch root domain yet
92+
self.ldap_root_domain = None # type: Optional[str]
8793

8894
def get_supported_login_types(self):
8995
return {'m.login.password': ('password',)}
@@ -117,17 +123,14 @@ async def check_auth(self, username, login_type, login_dict):
117123

118124
if self.ldap_active_directory:
119125
try:
120-
(login, domain, localpart) = self._map_login_to_upn(username)
126+
(login, domain, localpart) = await self._map_login_to_upn(username)
121127
uid_value = login + "@" + domain
122128
default_display_name = login
123129
except ActiveDirectoryUPNException:
124130
return False
125131

126132
try:
127-
tls = ldap3.Tls(validate=ssl.CERT_REQUIRED)
128-
server = ldap3.ServerPool(
129-
[ldap3.Server(uri, get_info=None, tls=tls) for uri in self.ldap_uris],
130-
)
133+
server = self._get_server()
131134
logger.debug(
132135
"Attempting LDAP connection with %s",
133136
self.ldap_uris
@@ -254,9 +257,7 @@ async def check_3pid_auth(self, medium, address, password):
254257

255258
# Talk to LDAP and check if this email/password combo is correct
256259
try:
257-
server = ldap3.ServerPool(
258-
[ldap3.Server(uri, get_info=None) for uri in self.ldap_uris],
259-
)
260+
server = self._get_server()
260261
logger.debug(
261262
"Attempting LDAP connection with %s",
262263
self.ldap_uris
@@ -404,6 +405,78 @@ class _LdapConfig(object):
404405

405406
return ldap_config
406407

408+
def _get_server(self, get_info: Optional[str] = None) -> ldap3.ServerPool:
409+
"""Constructs ServerPool from configured LDAP URIs
410+
411+
Args:
412+
get_info (str, optional): specifies if the server schema and server
413+
specific info must be read. Defaults to None.
414+
415+
Returns:
416+
Servers grouped in a ServerPool
417+
"""
418+
return ldap3.ServerPool(
419+
[
420+
ldap3.Server(
421+
uri,
422+
get_info=get_info,
423+
tls=self._ldap_tls
424+
)
425+
for uri in self.ldap_uris
426+
],
427+
)
428+
429+
async def _fetch_root_domain(self) -> str:
430+
"""Fetches root domain from LDAP and saves it to ``self.ldap_root_domain``
431+
432+
Returns:
433+
The root domain of Active Directory forest
434+
"""
435+
if self.ldap_root_domain is not None:
436+
return self.ldap_root_domain
437+
438+
self.ldap_root_domain = ""
439+
440+
if self.ldap_mode != LDAPMode.SEARCH:
441+
logger.info("Fetching root domain is supported in search mode only")
442+
return self.ldap_root_domain
443+
444+
server = self._get_server(get_info=ldap3.DSA)
445+
result, conn = await self._ldap_simple_bind(
446+
server=server,
447+
bind_dn=self.ldap_bind_dn,
448+
password=self.ldap_bind_password,
449+
)
450+
451+
if not result:
452+
logger.warning("Unable to get root domain due to failed LDAP bind")
453+
return self.ldap_root_domain
454+
455+
if (
456+
conn.server.info.other
457+
and conn.server.info.other.get("rootDomainNamingContext")
458+
):
459+
# conn.server.info.other["rootDomainNamingContext"][0]
460+
# is of the form DC=example,DC=org
461+
self.ldap_root_domain = ".".join(
462+
[
463+
dc.split("=")[1] for dc
464+
in conn.server.info.other["rootDomainNamingContext"][0].split(",")
465+
if "=" in dc
466+
]
467+
)
468+
logger.info('Obtained root domain "%s"', self.ldap_root_domain)
469+
470+
if not self.ldap_root_domain:
471+
logger.warning(
472+
"No valid `rootDomainNamingContext` attribute was found in the RootDSE. "
473+
"Logging in using short domain name will be unavailable."
474+
)
475+
476+
await threads.deferToThread(conn.unbind)
477+
478+
return self.ldap_root_domain
479+
407480
async def _ldap_simple_bind(self, server, bind_dn, password):
408481
""" Attempt a simple bind with the credentials
409482
given by the user against the LDAP server.
@@ -556,7 +629,7 @@ async def _ldap_authenticated_search(self, server, password, filters):
556629
logger.warning("Error during LDAP authentication: %s", e)
557630
raise
558631

559-
def _map_login_to_upn(self, username):
632+
async def _map_login_to_upn(self, username):
560633
"""Maps user provided login to Active Directory UPN and
561634
local part of Matrix ID.
562635
@@ -577,6 +650,9 @@ def _map_login_to_upn(self, username):
577650

578651
if '\\' in username:
579652
(domain, login) = username.lower().rsplit('\\', 1)
653+
ldap_root_domain = await self._fetch_root_domain()
654+
if ldap_root_domain and not domain.endswith(ldap_root_domain):
655+
domain += "." + ldap_root_domain
580656
elif "/" in username:
581657
(login, domain) = username.lower().rsplit("/", 1)
582658
else:

tests/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from asyncio.futures import Future
2-
from typing import Any, Awaitable
2+
from typing import Any, Awaitable, Type
33

44
from twisted.internet.endpoints import serverFromString
55
from twisted.internet.protocol import ServerFactory
@@ -117,10 +117,9 @@ async def _create_db():
117117

118118

119119
class _LDAPServerFactory(ServerFactory):
120-
protocol = LDAPServer
121-
122-
def __init__(self, root):
120+
def __init__(self, root, ldap_server_type: Type[LDAPServer] = LDAPServer):
123121
self.root = root
122+
self.protocol = ldap_server_type
124123

125124
def buildProtocol(self, addr):
126125
proto = self.protocol()
@@ -158,11 +157,11 @@ def close(self):
158157
)
159158

160159

161-
async def create_ldap_server():
160+
async def create_ldap_server(ldap_server_type: Type[LDAPServer] = LDAPServer):
162161
"Returns a context manager that represents the LDAP server."
163162

164163
db = await _create_db()
165-
factory = _LDAPServerFactory(db)
164+
factory = _LDAPServerFactory(db, ldap_server_type)
166165
factory.debug = True
167166

168167
# We just pick an arbitrary port to listen on.

tests/test_ad.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
from ldaptor import interfaces
16+
from ldaptor.protocols import pureldap
17+
from ldaptor.protocols.ldap import ldaperrors
18+
from ldaptor.protocols.ldap.ldapserver import LDAPServer
19+
1520
from twisted.internet.defer import ensureDeferred
1621
from twisted.trial import unittest
1722
from twisted.internet import defer
@@ -29,10 +34,30 @@
2934
logging.basicConfig()
3035

3136

37+
class _ActiveDirectoryLDAPServer(LDAPServer):
38+
"""Extends LDAPServer to return AD-specific attributes
39+
40+
Includes `rootDomainNamingContext` in bind responses.
41+
"""
42+
def getRootDSE(self, request, reply):
43+
root = interfaces.IConnectedLDAPEntry(self.factory)
44+
reply(pureldap.LDAPSearchResultEntry(
45+
objectName='',
46+
attributes=[('supportedLDAPVersion', ['3']),
47+
('namingContexts', [root.dn.getText()]),
48+
('supportedExtension', [
49+
pureldap.LDAPPasswordModifyRequest.oid, ]),
50+
('rootDomainNamingContext', ['DC=example,DC=org']), ], ))
51+
return pureldap.LDAPSearchResultDone(
52+
resultCode=ldaperrors.Success.resultCode)
53+
54+
3255
class AbstractLdapActiveDirectoryTestCase():
3356
@defer.inlineCallbacks
3457
def setUp(self):
35-
self.ldap_server = yield ensureDeferred(create_ldap_server())
58+
self.ldap_server = yield ensureDeferred(
59+
create_ldap_server(_ActiveDirectoryLDAPServer)
60+
)
3661
account_handler = Mock(spec_set=["check_user_exists", "get_qualified_user_id"])
3762
account_handler.check_user_exists.return_value = make_awaitable(True)
3863
account_handler.get_qualified_user_id = get_qualified_user_id
@@ -74,20 +99,41 @@ def test_correct_pwd(self):
7499
))
75100
self.assertEqual(result, "@mainuser/main.example.org:test")
76101

102+
result = yield ensureDeferred(self.auth_provider.check_auth(
103+
"main\\mainuser",
104+
'm.login.password',
105+
{"password": "abracadabra"}
106+
))
107+
self.assertEqual(result, "@mainuser/main.example.org:test")
108+
77109
result = yield ensureDeferred(self.auth_provider.check_auth(
78110
"subsidiary.example.org\\nonmainuser",
79111
'm.login.password',
80112
{"password": "simsalabim"}
81113
))
82114
self.assertEqual(result, "@nonmainuser/subsidiary.example.org:test")
83115

116+
result = yield ensureDeferred(self.auth_provider.check_auth(
117+
"subsidiary\\nonmainuser",
118+
'm.login.password',
119+
{"password": "simsalabim"}
120+
))
121+
self.assertEqual(result, "@nonmainuser/subsidiary.example.org:test")
122+
84123
result = yield ensureDeferred(self.auth_provider.check_auth(
85124
"subsidiary.example.org\\mainuser",
86125
'm.login.password',
87126
{"password": "changeit"}
88127
))
89128
self.assertEqual(result, "@mainuser/subsidiary.example.org:test")
90129

130+
result = yield ensureDeferred(self.auth_provider.check_auth(
131+
"subsidiary\\mainuser",
132+
'm.login.password',
133+
{"password": "changeit"}
134+
))
135+
self.assertEqual(result, "@mainuser/subsidiary.example.org:test")
136+
91137
@defer.inlineCallbacks
92138
def test_single_email(self):
93139
result = yield ensureDeferred(self.auth_provider.check_3pid_auth(

0 commit comments

Comments
 (0)