gh-90155: Fix bug in asyncio.Semaphore and strengthen FIFO guarantee (#93222)

The main problem was that an unluckily timed task cancellation could cause
the semaphore to be stuck. There were also doubts about strict FIFO ordering
of tasks allowed to pass.

The Semaphore implementation was rewritten to be more similar to Lock.
Many tests for edge cases (including cancellation) were added.
This commit is contained in:
Cyker Way 2022-09-22 12:34:45 -04:00 committed by GitHub
parent 8fd2c3b75b
commit 24e0379624
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 143 additions and 22 deletions

View File

@ -345,9 +345,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
def __init__(self, value=1): def __init__(self, value=1):
if value < 0: if value < 0:
raise ValueError("Semaphore initial value must be >= 0") raise ValueError("Semaphore initial value must be >= 0")
self._waiters = None
self._value = value self._value = value
self._waiters = collections.deque()
self._wakeup_scheduled = False
def __repr__(self): def __repr__(self):
res = super().__repr__() res = super().__repr__()
@ -356,16 +355,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
extra = f'{extra}, waiters:{len(self._waiters)}' extra = f'{extra}, waiters:{len(self._waiters)}'
return f'<{res[1:-1]} [{extra}]>' return f'<{res[1:-1]} [{extra}]>'
def _wake_up_next(self):
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.done():
waiter.set_result(None)
self._wakeup_scheduled = True
return
def locked(self): def locked(self):
"""Returns True if semaphore can not be acquired immediately.""" """Returns True if semaphore counter is zero."""
return self._value == 0 return self._value == 0
async def acquire(self): async def acquire(self):
@ -377,28 +368,57 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
called release() to make it larger than 0, and then return called release() to make it larger than 0, and then return
True. True.
""" """
# _wakeup_scheduled is set if *another* task is scheduled to wakeup if (not self.locked() and (self._waiters is None or
# but its acquire() is not resumed yet all(w.cancelled() for w in self._waiters))):
while self._wakeup_scheduled or self._value <= 0: self._value -= 1
fut = self._get_loop().create_future() return True
self._waiters.append(fut)
if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try: try:
await fut await fut
# reset _wakeup_scheduled *after* waiting for a future finally:
self._wakeup_scheduled = False self._waiters.remove(fut)
except exceptions.CancelledError: except exceptions.CancelledError:
self._wake_up_next() if not self.locked():
raise self._wake_up_first()
raise
self._value -= 1 self._value -= 1
if not self.locked():
self._wake_up_first()
return True return True
def release(self): def release(self):
"""Release a semaphore, incrementing the internal counter by one. """Release a semaphore, incrementing the internal counter by one.
When it was zero on entry and another coroutine is waiting for it to When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine. become larger than zero again, wake up that coroutine.
""" """
self._value += 1 self._value += 1
self._wake_up_next() self._wake_up_first()
def _wake_up_first(self):
"""Wake up the first waiter if it isn't done."""
if not self._waiters:
return
try:
fut = next(iter(self._waiters))
except StopIteration:
return
# .done() necessarily means that a waiter will wake up later on and
# either take the lock, or, if it was cancelled and lock wasn't
# taken already, will hit this again and wake up a new waiter.
if not fut.done():
fut.set_result(True)
class BoundedSemaphore(Semaphore): class BoundedSemaphore(Semaphore):

View File

@ -5,6 +5,7 @@ from unittest import mock
import re import re
import asyncio import asyncio
import collections
STR_RGX_REPR = ( STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)' r'^<(?P<class>.*?) object at (?P<address>.*?)'
@ -774,6 +775,9 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue('waiters' not in repr(sem)) self.assertTrue('waiters' not in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem))) self.assertTrue(RGX_REPR.match(repr(sem)))
if sem._waiters is None:
sem._waiters = collections.deque()
sem._waiters.append(mock.Mock()) sem._waiters.append(mock.Mock())
self.assertTrue('waiters:1' in repr(sem)) self.assertTrue('waiters:1' in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem))) self.assertTrue(RGX_REPR.match(repr(sem)))
@ -842,6 +846,7 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
sem.release() sem.release()
self.assertEqual(2, sem._value) self.assertEqual(2, sem._value)
await asyncio.sleep(0)
await asyncio.sleep(0) await asyncio.sleep(0)
self.assertEqual(0, sem._value) self.assertEqual(0, sem._value)
self.assertEqual(3, len(result)) self.assertEqual(3, len(result))
@ -884,6 +889,7 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
t2.cancel() t2.cancel()
sem.release() sem.release()
await asyncio.sleep(0)
await asyncio.sleep(0) await asyncio.sleep(0)
num_done = sum(t.done() for t in [t3, t4]) num_done = sum(t.done() for t in [t3, t4])
self.assertEqual(num_done, 1) self.assertEqual(num_done, 1)
@ -904,9 +910,32 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
t1.cancel() t1.cancel()
sem.release() sem.release()
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0)
self.assertTrue(sem.locked()) self.assertTrue(sem.locked())
self.assertTrue(t2.done()) self.assertTrue(t2.done())
async def test_acquire_no_hang(self):
sem = asyncio.Semaphore(1)
async def c1():
async with sem:
await asyncio.sleep(0)
t2.cancel()
async def c2():
async with sem:
self.assertFalse(True)
t1 = asyncio.create_task(c1())
t2 = asyncio.create_task(c2())
r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True)
self.assertTrue(r1 is None)
self.assertTrue(isinstance(r2, asyncio.CancelledError))
await asyncio.wait_for(sem.acquire(), timeout=1.0)
def test_release_not_acquired(self): def test_release_not_acquired(self):
sem = asyncio.BoundedSemaphore() sem = asyncio.BoundedSemaphore()
@ -945,6 +974,77 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
result result
) )
async def test_acquire_fifo_order_2(self):
sem = asyncio.Semaphore(1)
result = []
async def c1(result):
await sem.acquire()
result.append(1)
return True
async def c2(result):
await sem.acquire()
result.append(2)
sem.release()
await sem.acquire()
result.append(4)
return True
async def c3(result):
await sem.acquire()
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
sem.release()
sem.release()
tasks = [t1, t2, t3]
await asyncio.gather(*tasks)
self.assertEqual([1, 2, 3, 4], result)
async def test_acquire_fifo_order_3(self):
sem = asyncio.Semaphore(0)
result = []
async def c1(result):
await sem.acquire()
result.append(1)
return True
async def c2(result):
await sem.acquire()
result.append(2)
return True
async def c3(result):
await sem.acquire()
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
t1.cancel()
await asyncio.sleep(0)
sem.release()
sem.release()
tasks = [t1, t2, t3]
await asyncio.gather(*tasks, return_exceptions=True)
self.assertEqual([2, 3], result)
class BarrierTests(unittest.IsolatedAsyncioTestCase): class BarrierTests(unittest.IsolatedAsyncioTestCase):

View File

@ -0,0 +1 @@
Fix broken :class:`asyncio.Semaphore` when acquire is cancelled.