Skip to content

Commit 2e9d781

Browse files
reivilibreclokep
andauthored
Convert to async/await (#101)
Signed-off-by: Olivier Wilkinson (reivilibre) <[email protected]> Co-authored-by: Patrick Cloke <[email protected]>
1 parent 5a9b52c commit 2e9d781

File tree

4 files changed

+135
-118
lines changed

4 files changed

+135
-118
lines changed

ldap_auth_provider.py

Lines changed: 52 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from twisted.internet import defer, threads
16+
from twisted.internet import threads
1717

1818

1919
import ldap3
@@ -88,8 +88,7 @@ def __init__(self, config, account_handler):
8888
def get_supported_login_types(self):
8989
return {'m.login.password': ('password',)}
9090

91-
@defer.inlineCallbacks
92-
def check_auth(self, username, login_type, login_dict):
91+
async def check_auth(self, username, login_type, login_dict):
9392
""" Attempt to authenticate a user against an LDAP Server
9493
and register an account if none exists.
9594
@@ -103,7 +102,7 @@ def check_auth(self, username, login_type, login_dict):
103102
# an anonymous authorization state and not suitable for user
104103
# authentication.
105104
if not password:
106-
defer.returnValue(False)
105+
return False
107106

108107
if username.startswith("@") and ":" in username:
109108
# username is of the form @foo:bar.com
@@ -122,7 +121,7 @@ def check_auth(self, username, login_type, login_dict):
122121
uid_value = login + "@" + domain
123122
default_display_name = login
124123
except ActiveDirectoryUPNException:
125-
defer.returnValue(False)
124+
return False
126125

127126
try:
128127
tls = ldap3.Tls(validate=ssl.CERT_REQUIRED)
@@ -140,7 +139,7 @@ def check_auth(self, username, login_type, login_dict):
140139
value=uid_value,
141140
base=self.ldap_base
142141
)
143-
result, conn = yield self._ldap_simple_bind(
142+
result, conn = await self._ldap_simple_bind(
144143
server=server, bind_dn=bind_dn, password=password
145144
)
146145
logger.debug(
@@ -150,10 +149,10 @@ def check_auth(self, username, login_type, login_dict):
150149
conn
151150
)
152151
if not result:
153-
defer.returnValue(False)
152+
return False
154153
elif self.ldap_mode == LDAPMode.SEARCH:
155154
filters = [(self.ldap_attributes["uid"], uid_value)]
156-
result, conn, _ = yield self._ldap_authenticated_search(
155+
result, conn, _ = await self._ldap_authenticated_search(
157156
server=server, password=password, filters=filters
158157
)
159158
logger.debug(
@@ -163,7 +162,7 @@ def check_auth(self, username, login_type, login_dict):
163162
conn
164163
)
165164
if not result:
166-
defer.returnValue(False)
165+
return False
167166
else: # pragma: no cover
168167
raise RuntimeError(
169168
'Invalid LDAP mode specified: {mode}'.format(
@@ -181,17 +180,17 @@ def check_auth(self, username, login_type, login_dict):
181180
"Authentication method yielded no LDAP connection, "
182181
"aborting!"
183182
)
184-
defer.returnValue(False)
183+
return False
185184

186185
# Get full user id from localpart
187186
user_id = self.account_handler.get_qualified_user_id(localpart)
188187

189188
# check if user with user_id exists
190-
if (yield self.account_handler.check_user_exists(user_id)):
189+
if await self.account_handler.check_user_exists(user_id):
191190
# exists, authentication complete
192191
if hasattr(conn, "unbind"):
193-
yield threads.deferToThread(conn.unbind)
194-
defer.returnValue(user_id)
192+
await threads.deferToThread(conn.unbind)
193+
return user_id
195194

196195
else:
197196
# does not exist, register
@@ -200,7 +199,7 @@ def check_auth(self, username, login_type, login_dict):
200199
# existing ldap connection
201200
filters = [(self.ldap_attributes['uid'], uid_value)]
202201

203-
result, conn, response = yield self._ldap_authenticated_search(
202+
result, conn, response = await self._ldap_authenticated_search(
204203
server=server, password=password, filters=filters,
205204
)
206205

@@ -222,18 +221,17 @@ def check_auth(self, username, login_type, login_dict):
222221
mail = None
223222

224223
# Register the user
225-
user_id = yield self.register_user(localpart, display_name, mail)
224+
user_id = await self.register_user(localpart, display_name, mail)
226225

227-
defer.returnValue(user_id)
226+
return user_id
228227

229-
defer.returnValue(False)
228+
return False
230229

231230
except ldap3.core.exceptions.LDAPException as e:
232231
logger.warning("Error during ldap authentication: %s", e)
233-
defer.returnValue(False)
232+
return False
234233

235-
@defer.inlineCallbacks
236-
def check_3pid_auth(self, medium, address, password):
234+
async def check_3pid_auth(self, medium, address, password):
237235
""" Handle authentication against thirdparty login types, such as email
238236
239237
Args:
@@ -248,11 +246,11 @@ def check_3pid_auth(self, medium, address, password):
248246
if self.ldap_mode != LDAPMode.SEARCH:
249247
logger.debug("3PID LDAP login/register attempted but LDAP search mode "
250248
"not enabled. Bailing.")
251-
defer.returnValue(None)
249+
return None
252250

253251
# We currently only support email
254252
if medium != "email":
255-
defer.returnValue(None)
253+
return None
256254

257255
# Talk to LDAP and check if this email/password combo is correct
258256
try:
@@ -265,7 +263,7 @@ def check_3pid_auth(self, medium, address, password):
265263
)
266264

267265
search_filter = [(self.ldap_attributes["mail"], address)]
268-
result, conn, response = yield self._ldap_authenticated_search(
266+
result, conn, response = await self._ldap_authenticated_search(
269267
server=server, password=password, filters=search_filter,
270268
)
271269

@@ -279,10 +277,10 @@ def check_3pid_auth(self, medium, address, password):
279277

280278
# Close connection
281279
if hasattr(conn, "unbind"):
282-
yield threads.deferToThread(conn.unbind)
280+
await threads.deferToThread(conn.unbind)
283281

284282
if not result:
285-
defer.returnValue(None)
283+
return None
286284

287285
# Extract the username from the search response from the LDAP server
288286
localpart = response["attributes"].get(
@@ -306,16 +304,15 @@ def check_3pid_auth(self, medium, address, password):
306304
givenName = givenName[0] if len(givenName) == 1 else localpart
307305

308306
# Register the user
309-
user_id = yield self.register_user(localpart, givenName, address)
307+
user_id = await self.register_user(localpart, givenName, address)
310308

311-
defer.returnValue(user_id)
309+
return user_id
312310

313311
except ldap3.core.exceptions.LDAPException as e:
314312
logger.warning("Error during ldap authentication: %s", e)
315313
raise
316314

317-
@defer.inlineCallbacks
318-
def register_user(self, localpart, name, email_address):
315+
async def register_user(self, localpart, name, email_address):
319316
"""Register a Synapse user, first checking if they exist.
320317
321318
Args:
@@ -329,9 +326,9 @@ def register_user(self, localpart, name, email_address):
329326
# Get full user id from localpart
330327
user_id = self.account_handler.get_qualified_user_id(localpart)
331328

332-
if (yield self.account_handler.check_user_exists(user_id)):
329+
if await self.account_handler.check_user_exists(user_id):
333330
# exists, authentication complete
334-
defer.returnValue(user_id)
331+
return user_id
335332

336333
# register an email address if one exists
337334
emails = [email_address] if email_address is not None else []
@@ -341,14 +338,14 @@ def register_user(self, localpart, name, email_address):
341338
# from password providers
342339
if parse_version(synapse.__version__) <= parse_version("0.99.3"):
343340
user_id, access_token = (
344-
yield self.account_handler.register(
341+
await self.account_handler.register(
345342
localpart=localpart, displayname=name,
346343
)
347344
)
348345
else:
349346
# If Synapse has support, bind emails
350347
user_id, access_token = (
351-
yield self.account_handler.register(
348+
await self.account_handler.register(
352349
localpart=localpart, displayname=name, emails=emails,
353350
)
354351
)
@@ -358,7 +355,7 @@ def register_user(self, localpart, name, email_address):
358355
user_id,
359356
)
360357

361-
defer.returnValue(user_id)
358+
return user_id
362359

363360
@staticmethod
364361
def parse_config(config):
@@ -407,8 +404,7 @@ class _LdapConfig(object):
407404

408405
return ldap_config
409406

410-
@defer.inlineCallbacks
411-
def _ldap_simple_bind(self, server, bind_dn, password):
407+
async def _ldap_simple_bind(self, server, bind_dn, password):
412408
""" Attempt a simple bind with the credentials
413409
given by the user against the LDAP server.
414410
@@ -420,7 +416,7 @@ def _ldap_simple_bind(self, server, bind_dn, password):
420416

421417
try:
422418
# bind with the the local user's ldap credentials
423-
conn = yield threads.deferToThread(
419+
conn = await threads.deferToThread(
424420
ldap3.Connection,
425421
server, bind_dn, password,
426422
authentication=LDAP_AUTH_SIMPLE,
@@ -432,33 +428,32 @@ def _ldap_simple_bind(self, server, bind_dn, password):
432428
)
433429

434430
if self.ldap_start_tls:
435-
yield threads.deferToThread(conn.open)
436-
yield threads.deferToThread(conn.start_tls)
431+
await threads.deferToThread(conn.open)
432+
await threads.deferToThread(conn.start_tls)
437433
logger.debug(
438434
"Upgraded LDAP connection in simple bind mode through "
439435
"StartTLS: %s",
440436
conn
441437
)
442438

443-
if (yield threads.deferToThread(conn.bind)):
439+
if await threads.deferToThread(conn.bind):
444440
# GOOD: bind okay
445441
logger.debug("LDAP Bind successful in simple bind mode.")
446-
defer.returnValue((True, conn))
442+
return (True, conn)
447443

448444
# BAD: bind failed
449445
logger.info(
450446
"Binding against LDAP failed for '%s' failed: %s",
451447
bind_dn, conn.result['description']
452448
)
453-
yield threads.deferToThread(conn.unbind)
454-
defer.returnValue((False, None))
449+
await threads.deferToThread(conn.unbind)
450+
return (False, None)
455451

456452
except ldap3.core.exceptions.LDAPException as e:
457453
logger.warning("Error during LDAP authentication: %s", e)
458454
raise
459455

460-
@defer.inlineCallbacks
461-
def _ldap_authenticated_search(self, server, password, filters):
456+
async def _ldap_authenticated_search(self, server, password, filters):
462457
"""Attempt to login with the preconfigured bind_dn and then continue
463458
searching and filtering within the base_dn.
464459
@@ -480,7 +475,7 @@ def _ldap_authenticated_search(self, server, password, filters):
480475
"""
481476

482477
try:
483-
conn = yield threads.deferToThread(
478+
conn = await threads.deferToThread(
484479
ldap3.Connection,
485480
server,
486481
self.ldap_bind_dn,
@@ -493,21 +488,21 @@ def _ldap_authenticated_search(self, server, password, filters):
493488
)
494489

495490
if self.ldap_start_tls:
496-
yield threads.deferToThread(conn.open)
497-
yield threads.deferToThread(conn.start_tls)
491+
await threads.deferToThread(conn.open)
492+
await threads.deferToThread(conn.start_tls)
498493
logger.debug(
499494
"Upgraded LDAP connection in search mode through "
500495
"StartTLS: %s",
501496
conn
502497
)
503498

504-
if not (yield threads.deferToThread(conn.bind)):
499+
if not await threads.deferToThread(conn.bind):
505500
logger.warning(
506501
"Binding against LDAP with `bind_dn` failed: %s",
507502
conn.result['description']
508503
)
509-
yield threads.deferToThread(conn.unbind)
510-
defer.returnValue((False, None, None))
504+
await threads.deferToThread(conn.unbind)
505+
return (False, None, None)
511506

512507
# Construct search filter
513508
query = ""
@@ -529,7 +524,7 @@ def _ldap_authenticated_search(self, server, password, filters):
529524
"LDAP search filter: %s",
530525
query
531526
)
532-
yield threads.deferToThread(
527+
await threads.deferToThread(
533528
conn.search,
534529
search_base=self.ldap_base,
535530
search_filter=query,
@@ -555,12 +550,12 @@ def _ldap_authenticated_search(self, server, password, filters):
555550
# unbind and simple bind with user_dn to verify the password
556551
# Note: do not use rebind(), for some reason it did not verify
557552
# the password for me!
558-
yield threads.deferToThread(conn.unbind)
559-
result, conn = yield self._ldap_simple_bind(
553+
await threads.deferToThread(conn.unbind)
554+
result, conn = await self._ldap_simple_bind(
560555
server=server, bind_dn=user_dn, password=password
561556
)
562557

563-
defer.returnValue((result, conn, responses[0]))
558+
return (result, conn, responses[0])
564559
else:
565560
# BAD: found 0 or > 1 results, abort!
566561
if len(responses) == 0:
@@ -573,9 +568,9 @@ def _ldap_authenticated_search(self, server, password, filters):
573568
"LDAP search returned too many (%s) results for '%s'",
574569
len(responses), filters
575570
)
576-
yield threads.deferToThread(conn.unbind)
571+
await threads.deferToThread(conn.unbind)
577572

578-
defer.returnValue((False, None, None))
573+
return (False, None, None)
579574

580575
except ldap3.core.exceptions.LDAPException as e:
581576
logger.warning("Error during LDAP authentication: %s", e)

0 commit comments

Comments
 (0)