summaryrefslogtreecommitdiff
path: root/Lib/test/test_asyncgen.py
diff options
context:
space:
mode:
authorSteve Dower <steve.dower@microsoft.com>2017-02-04 15:05:40 -0800
committerSteve Dower <steve.dower@microsoft.com>2017-02-04 15:05:40 -0800
commitb2fa705fd3887c326e811c418469c784353027f4 (patch)
treeb3428f73de91453edbfd4df1a5d4a212d182eb44 /Lib/test/test_asyncgen.py
parent134e58fd3aaa2e91390041e143f3f0a21a60142b (diff)
parentb53654b6dbfce8318a7d4d1cdaddca7a7fec194b (diff)
downloadcpython-b2fa705fd3887c326e811c418469c784353027f4.tar.gz
Issue #29392: Prevent crash when passing invalid arguments into msvcrt module.
Diffstat (limited to 'Lib/test/test_asyncgen.py')
-rw-r--r--Lib/test/test_asyncgen.py1044
1 files changed, 1044 insertions, 0 deletions
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
new file mode 100644
index 0000000000..34ab8a04ee
--- /dev/null
+++ b/Lib/test/test_asyncgen.py
@@ -0,0 +1,1044 @@
+import inspect
+import sys
+import types
+import unittest
+
+from unittest import mock
+
+from test.support import import_module
+asyncio = import_module("asyncio")
+
+
+class AwaitException(Exception):
+ pass
+
+
+@types.coroutine
+def awaitable(*, throw=False):
+ if throw:
+ yield ('throw',)
+ else:
+ yield ('result',)
+
+
+def run_until_complete(coro):
+ exc = False
+ while True:
+ try:
+ if exc:
+ exc = False
+ fut = coro.throw(AwaitException)
+ else:
+ fut = coro.send(None)
+ except StopIteration as ex:
+ return ex.args[0]
+
+ if fut == ('throw',):
+ exc = True
+
+
+def to_list(gen):
+ async def iterate():
+ res = []
+ async for i in gen:
+ res.append(i)
+ return res
+
+ return run_until_complete(iterate())
+
+
+class AsyncGenSyntaxTest(unittest.TestCase):
+
+ def test_async_gen_syntax_01(self):
+ code = '''async def foo():
+ await abc
+ yield from 123
+ '''
+
+ with self.assertRaisesRegex(SyntaxError, 'yield from.*inside async'):
+ exec(code, {}, {})
+
+ def test_async_gen_syntax_02(self):
+ code = '''async def foo():
+ yield from 123
+ '''
+
+ with self.assertRaisesRegex(SyntaxError, 'yield from.*inside async'):
+ exec(code, {}, {})
+
+ def test_async_gen_syntax_03(self):
+ code = '''async def foo():
+ await abc
+ yield
+ return 123
+ '''
+
+ with self.assertRaisesRegex(SyntaxError, 'return.*value.*async gen'):
+ exec(code, {}, {})
+
+ def test_async_gen_syntax_04(self):
+ code = '''async def foo():
+ yield
+ return 123
+ '''
+
+ with self.assertRaisesRegex(SyntaxError, 'return.*value.*async gen'):
+ exec(code, {}, {})
+
+ def test_async_gen_syntax_05(self):
+ code = '''async def foo():
+ if 0:
+ yield
+ return 12
+ '''
+
+ with self.assertRaisesRegex(SyntaxError, 'return.*value.*async gen'):
+ exec(code, {}, {})
+
+
+class AsyncGenTest(unittest.TestCase):
+
+ def compare_generators(self, sync_gen, async_gen):
+ def sync_iterate(g):
+ res = []
+ while True:
+ try:
+ res.append(g.__next__())
+ except StopIteration:
+ res.append('STOP')
+ break
+ except Exception as ex:
+ res.append(str(type(ex)))
+ return res
+
+ def async_iterate(g):
+ res = []
+ while True:
+ try:
+ g.__anext__().__next__()
+ except StopAsyncIteration:
+ res.append('STOP')
+ break
+ except StopIteration as ex:
+ if ex.args:
+ res.append(ex.args[0])
+ else:
+ res.append('EMPTY StopIteration')
+ break
+ except Exception as ex:
+ res.append(str(type(ex)))
+ return res
+
+ sync_gen_result = sync_iterate(sync_gen)
+ async_gen_result = async_iterate(async_gen)
+ self.assertEqual(sync_gen_result, async_gen_result)
+ return async_gen_result
+
+ def test_async_gen_iteration_01(self):
+ async def gen():
+ await awaitable()
+ a = yield 123
+ self.assertIs(a, None)
+ await awaitable()
+ yield 456
+ await awaitable()
+ yield 789
+
+ self.assertEqual(to_list(gen()), [123, 456, 789])
+
+ def test_async_gen_iteration_02(self):
+ async def gen():
+ await awaitable()
+ yield 123
+ await awaitable()
+
+ g = gen()
+ ai = g.__aiter__()
+ self.assertEqual(ai.__anext__().__next__(), ('result',))
+
+ try:
+ ai.__anext__().__next__()
+ except StopIteration as ex:
+ self.assertEqual(ex.args[0], 123)
+ else:
+ self.fail('StopIteration was not raised')
+
+ self.assertEqual(ai.__anext__().__next__(), ('result',))
+
+ try:
+ ai.__anext__().__next__()
+ except StopAsyncIteration as ex:
+ self.assertFalse(ex.args)
+ else:
+ self.fail('StopAsyncIteration was not raised')
+
+ def test_async_gen_exception_03(self):
+ async def gen():
+ await awaitable()
+ yield 123
+ await awaitable(throw=True)
+ yield 456
+
+ with self.assertRaises(AwaitException):
+ to_list(gen())
+
+ def test_async_gen_exception_04(self):
+ async def gen():
+ await awaitable()
+ yield 123
+ 1 / 0
+
+ g = gen()
+ ai = g.__aiter__()
+ self.assertEqual(ai.__anext__().__next__(), ('result',))
+
+ try:
+ ai.__anext__().__next__()
+ except StopIteration as ex:
+ self.assertEqual(ex.args[0], 123)
+ else:
+ self.fail('StopIteration was not raised')
+
+ with self.assertRaises(ZeroDivisionError):
+ ai.__anext__().__next__()
+
+ def test_async_gen_exception_05(self):
+ async def gen():
+ yield 123
+ raise StopAsyncIteration
+
+ with self.assertRaisesRegex(RuntimeError,
+ 'async generator.*StopAsyncIteration'):
+ to_list(gen())
+
+ def test_async_gen_exception_06(self):
+ async def gen():
+ yield 123
+ raise StopIteration
+
+ with self.assertRaisesRegex(RuntimeError,
+ 'async generator.*StopIteration'):
+ to_list(gen())
+
+ def test_async_gen_exception_07(self):
+ def sync_gen():
+ try:
+ yield 1
+ 1 / 0
+ finally:
+ yield 2
+ yield 3
+
+ yield 100
+
+ async def async_gen():
+ try:
+ yield 1
+ 1 / 0
+ finally:
+ yield 2
+ yield 3
+
+ yield 100
+
+ self.compare_generators(sync_gen(), async_gen())
+
+ def test_async_gen_exception_08(self):
+ def sync_gen():
+ try:
+ yield 1
+ finally:
+ yield 2
+ 1 / 0
+ yield 3
+
+ yield 100
+
+ async def async_gen():
+ try:
+ yield 1
+ await awaitable()
+ finally:
+ await awaitable()
+ yield 2
+ 1 / 0
+ yield 3
+
+ yield 100
+
+ self.compare_generators(sync_gen(), async_gen())
+
+ def test_async_gen_exception_09(self):
+ def sync_gen():
+ try:
+ yield 1
+ 1 / 0
+ finally:
+ yield 2
+ yield 3
+
+ yield 100
+
+ async def async_gen():
+ try:
+ await awaitable()
+ yield 1
+ 1 / 0
+ finally:
+ yield 2
+ await awaitable()
+ yield 3
+
+ yield 100
+
+ self.compare_generators(sync_gen(), async_gen())
+
+ def test_async_gen_exception_10(self):
+ async def gen():
+ yield 123
+ with self.assertRaisesRegex(TypeError,
+ "non-None value .* async generator"):
+ gen().__anext__().send(100)
+
+ def test_async_gen_api_01(self):
+ async def gen():
+ yield 123
+
+ g = gen()
+
+ self.assertEqual(g.__name__, 'gen')
+ g.__name__ = '123'
+ self.assertEqual(g.__name__, '123')
+
+ self.assertIn('.gen', g.__qualname__)
+ g.__qualname__ = '123'
+ self.assertEqual(g.__qualname__, '123')
+
+ self.assertIsNone(g.ag_await)
+ self.assertIsInstance(g.ag_frame, types.FrameType)
+ self.assertFalse(g.ag_running)
+ self.assertIsInstance(g.ag_code, types.CodeType)
+
+ self.assertTrue(inspect.isawaitable(g.aclose()))
+
+
+class AsyncGenAsyncioTest(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ self.loop = None
+
+ async def to_list(self, gen):
+ res = []
+ async for i in gen:
+ res.append(i)
+ return res
+
+ def test_async_gen_asyncio_01(self):
+ async def gen():
+ yield 1
+ await asyncio.sleep(0.01, loop=self.loop)
+ yield 2
+ await asyncio.sleep(0.01, loop=self.loop)
+ return
+ yield 3
+
+ res = self.loop.run_until_complete(self.to_list(gen()))
+ self.assertEqual(res, [1, 2])
+
+ def test_async_gen_asyncio_02(self):
+ async def gen():
+ yield 1
+ await asyncio.sleep(0.01, loop=self.loop)
+ yield 2
+ 1 / 0
+ yield 3
+
+ with self.assertRaises(ZeroDivisionError):
+ self.loop.run_until_complete(self.to_list(gen()))
+
+ def test_async_gen_asyncio_03(self):
+ loop = self.loop
+
+ class Gen:
+ async def __aiter__(self):
+ yield 1
+ await asyncio.sleep(0.01, loop=loop)
+ yield 2
+
+ res = loop.run_until_complete(self.to_list(Gen()))
+ self.assertEqual(res, [1, 2])
+
+ def test_async_gen_asyncio_anext_04(self):
+ async def foo():
+ yield 1
+ await asyncio.sleep(0.01, loop=self.loop)
+ try:
+ yield 2
+ yield 3
+ except ZeroDivisionError:
+ yield 1000
+ await asyncio.sleep(0.01, loop=self.loop)
+ yield 4
+
+ async def run1():
+ it = foo().__aiter__()
+
+ self.assertEqual(await it.__anext__(), 1)
+ self.assertEqual(await it.__anext__(), 2)
+ self.assertEqual(await it.__anext__(), 3)
+ self.assertEqual(await it.__anext__(), 4)
+ with self.assertRaises(StopAsyncIteration):
+ await it.__anext__()
+ with self.assertRaises(StopAsyncIteration):
+ await it.__anext__()
+
+ async def run2():
+ it = foo().__aiter__()
+
+ self.assertEqual(await it.__anext__(), 1)
+ self.assertEqual(await it.__anext__(), 2)
+ try:
+ it.__anext__().throw(ZeroDivisionError)
+ except StopIteration as ex:
+ self.assertEqual(ex.args[0], 1000)
+ else:
+ self.fail('StopIteration was not raised')
+ self.assertEqual(await it.__anext__(), 4)
+ with self.assertRaises(StopAsyncIteration):
+ await it.__anext__()
+
+ self.loop.run_until_complete(run1())
+ self.loop.run_until_complete(run2())
+
+ def test_async_gen_asyncio_anext_05(self):
+ async def foo():
+ v = yield 1
+ v = yield v
+ yield v * 100
+
+ async def run():
+ it = foo().__aiter__()
+
+ try:
+ it.__anext__().send(None)
+ except StopIteration as ex:
+ self.assertEqual(ex.args[0], 1)
+ else:
+ self.fail('StopIteration was not raised')
+
+ try:
+ it.__anext__().send(10)
+ except StopIteration as ex:
+ self.assertEqual(ex.args[0], 10)
+ else:
+ self.fail('StopIteration was not raised')
+
+ try:
+ it.__anext__().send(12)
+ except StopIteration as ex:
+ self.assertEqual(ex.args[0], 1200)
+ else:
+ self.fail('StopIteration was not raised')
+
+ with self.assertRaises(StopAsyncIteration):
+ await it.__anext__()
+
+ self.loop.run_until_complete(run())
+
+ def test_async_gen_asyncio_anext_06(self):
+ DONE = 0
+
+ # test synchronous generators
+ def foo():
+ try:
+ yield
+ except:
+ pass
+ g = foo()
+ g.send(None)
+ with self.assertRaises(StopIteration):
+ g.send(None)
+
+ # now with asynchronous generators
+
+ async def gen():
+ nonlocal DONE
+ try:
+ yield
+ except:
+ pass
+ DONE = 1
+
+ async def run():
+ nonlocal DONE
+ g = gen()
+ await g.asend(None)
+ with self.assertRaises(StopAsyncIteration):
+ await g.asend(None)
+ DONE += 10
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 11)
+
+ def test_async_gen_asyncio_anext_tuple(self):
+ async def foo():
+ try:
+ yield (1,)
+ except ZeroDivisionError:
+ yield (2,)
+
+ async def run():
+ it = foo().__aiter__()
+
+ self.assertEqual(await it.__anext__(), (1,))
+ with self.assertRaises(StopIteration) as cm:
+ it.__anext__().throw(ZeroDivisionError)
+ self.assertEqual(cm.exception.args[0], (2,))
+ with self.assertRaises(StopAsyncIteration):
+ await it.__anext__()
+
+ self.loop.run_until_complete(run())
+
+ def test_async_gen_asyncio_anext_stopiteration(self):
+ async def foo():
+ try:
+ yield StopIteration(1)
+ except ZeroDivisionError:
+ yield StopIteration(3)
+
+ async def run():
+ it = foo().__aiter__()
+
+ v = await it.__anext__()
+ self.assertIsInstance(v, StopIteration)
+ self.assertEqual(v.value, 1)
+ with self.assertRaises(StopIteration) as cm:
+ it.__anext__().throw(ZeroDivisionError)
+ v = cm.exception.args[0]
+ self.assertIsInstance(v, StopIteration)
+ self.assertEqual(v.value, 3)
+ with self.assertRaises(StopAsyncIteration):
+ await it.__anext__()
+
+ self.loop.run_until_complete(run())
+
+ def test_async_gen_asyncio_aclose_06(self):
+ async def foo():
+ try:
+ yield 1
+ 1 / 0
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ yield 12
+
+ async def run():
+ gen = foo()
+ it = gen.__aiter__()
+ await it.__anext__()
+ await gen.aclose()
+
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "async generator ignored GeneratorExit"):
+ self.loop.run_until_complete(run())
+
+ def test_async_gen_asyncio_aclose_07(self):
+ DONE = 0
+
+ async def foo():
+ nonlocal DONE
+ try:
+ yield 1
+ 1 / 0
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE += 1
+ DONE += 1000
+
+ async def run():
+ gen = foo()
+ it = gen.__aiter__()
+ await it.__anext__()
+ await gen.aclose()
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ def test_async_gen_asyncio_aclose_08(self):
+ DONE = 0
+
+ fut = asyncio.Future(loop=self.loop)
+
+ async def foo():
+ nonlocal DONE
+ try:
+ yield 1
+ await fut
+ DONE += 1000
+ yield 2
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE += 1
+ DONE += 1000
+
+ async def run():
+ gen = foo()
+ it = gen.__aiter__()
+ self.assertEqual(await it.__anext__(), 1)
+ t = self.loop.create_task(it.__anext__())
+ await asyncio.sleep(0.01, loop=self.loop)
+ await gen.aclose()
+ return t
+
+ t = self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ # Silence ResourceWarnings
+ fut.cancel()
+ t.cancel()
+ self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))
+
+ def test_async_gen_asyncio_gc_aclose_09(self):
+ DONE = 0
+
+ async def gen():
+ nonlocal DONE
+ try:
+ while True:
+ yield 1
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE = 1
+
+ async def run():
+ g = gen()
+ await g.__anext__()
+ await g.__anext__()
+ del g
+
+ await asyncio.sleep(0.1, loop=self.loop)
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ def test_async_gen_asyncio_aclose_10(self):
+ DONE = 0
+
+ # test synchronous generators
+ def foo():
+ try:
+ yield
+ except:
+ pass
+ g = foo()
+ g.send(None)
+ g.close()
+
+ # now with asynchronous generators
+
+ async def gen():
+ nonlocal DONE
+ try:
+ yield
+ except:
+ pass
+ DONE = 1
+
+ async def run():
+ nonlocal DONE
+ g = gen()
+ await g.asend(None)
+ await g.aclose()
+ DONE += 10
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 11)
+
+ def test_async_gen_asyncio_aclose_11(self):
+ DONE = 0
+
+ # test synchronous generators
+ def foo():
+ try:
+ yield
+ except:
+ pass
+ yield
+ g = foo()
+ g.send(None)
+ with self.assertRaisesRegex(RuntimeError, 'ignored GeneratorExit'):
+ g.close()
+
+ # now with asynchronous generators
+
+ async def gen():
+ nonlocal DONE
+ try:
+ yield
+ except:
+ pass
+ yield
+ DONE += 1
+
+ async def run():
+ nonlocal DONE
+ g = gen()
+ await g.asend(None)
+ with self.assertRaisesRegex(RuntimeError, 'ignored GeneratorExit'):
+ await g.aclose()
+ DONE += 10
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 10)
+
+ def test_async_gen_asyncio_asend_01(self):
+ DONE = 0
+
+ # Sanity check:
+ def sgen():
+ v = yield 1
+ yield v * 2
+ sg = sgen()
+ v = sg.send(None)
+ self.assertEqual(v, 1)
+ v = sg.send(100)
+ self.assertEqual(v, 200)
+
+ async def gen():
+ nonlocal DONE
+ try:
+ await asyncio.sleep(0.01, loop=self.loop)
+ v = yield 1
+ await asyncio.sleep(0.01, loop=self.loop)
+ yield v * 2
+ await asyncio.sleep(0.01, loop=self.loop)
+ return
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE = 1
+
+ async def run():
+ g = gen()
+
+ v = await g.asend(None)
+ self.assertEqual(v, 1)
+
+ v = await g.asend(100)
+ self.assertEqual(v, 200)
+
+ with self.assertRaises(StopAsyncIteration):
+ await g.asend(None)
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ def test_async_gen_asyncio_asend_02(self):
+ DONE = 0
+
+ async def sleep_n_crash(delay):
+ await asyncio.sleep(delay, loop=self.loop)
+ 1 / 0
+
+ async def gen():
+ nonlocal DONE
+ try:
+ await asyncio.sleep(0.01, loop=self.loop)
+ v = yield 1
+ await sleep_n_crash(0.01)
+ DONE += 1000
+ yield v * 2
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE = 1
+
+ async def run():
+ g = gen()
+
+ v = await g.asend(None)
+ self.assertEqual(v, 1)
+
+ await g.asend(100)
+
+ with self.assertRaises(ZeroDivisionError):
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ def test_async_gen_asyncio_asend_03(self):
+ DONE = 0
+
+ async def sleep_n_crash(delay):
+ fut = asyncio.ensure_future(asyncio.sleep(delay, loop=self.loop),
+ loop=self.loop)
+ self.loop.call_later(delay / 2, lambda: fut.cancel())
+ return await fut
+
+ async def gen():
+ nonlocal DONE
+ try:
+ await asyncio.sleep(0.01, loop=self.loop)
+ v = yield 1
+ await sleep_n_crash(0.01)
+ DONE += 1000
+ yield v * 2
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE = 1
+
+ async def run():
+ g = gen()
+
+ v = await g.asend(None)
+ self.assertEqual(v, 1)
+
+ await g.asend(100)
+
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ def test_async_gen_asyncio_athrow_01(self):
+ DONE = 0
+
+ class FooEr(Exception):
+ pass
+
+ # Sanity check:
+ def sgen():
+ try:
+ v = yield 1
+ except FooEr:
+ v = 1000
+ yield v * 2
+ sg = sgen()
+ v = sg.send(None)
+ self.assertEqual(v, 1)
+ v = sg.throw(FooEr)
+ self.assertEqual(v, 2000)
+ with self.assertRaises(StopIteration):
+ sg.send(None)
+
+ async def gen():
+ nonlocal DONE
+ try:
+ await asyncio.sleep(0.01, loop=self.loop)
+ try:
+ v = yield 1
+ except FooEr:
+ v = 1000
+ await asyncio.sleep(0.01, loop=self.loop)
+ yield v * 2
+ await asyncio.sleep(0.01, loop=self.loop)
+ # return
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE = 1
+
+ async def run():
+ g = gen()
+
+ v = await g.asend(None)
+ self.assertEqual(v, 1)
+
+ v = await g.athrow(FooEr)
+ self.assertEqual(v, 2000)
+
+ with self.assertRaises(StopAsyncIteration):
+ await g.asend(None)
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ def test_async_gen_asyncio_athrow_02(self):
+ DONE = 0
+
+ class FooEr(Exception):
+ pass
+
+ async def sleep_n_crash(delay):
+ fut = asyncio.ensure_future(asyncio.sleep(delay, loop=self.loop),
+ loop=self.loop)
+ self.loop.call_later(delay / 2, lambda: fut.cancel())
+ return await fut
+
+ async def gen():
+ nonlocal DONE
+ try:
+ await asyncio.sleep(0.01, loop=self.loop)
+ try:
+ v = yield 1
+ except FooEr:
+ await sleep_n_crash(0.01)
+ yield v * 2
+ await asyncio.sleep(0.01, loop=self.loop)
+ # return
+ finally:
+ await asyncio.sleep(0.01, loop=self.loop)
+ await asyncio.sleep(0.01, loop=self.loop)
+ DONE = 1
+
+ async def run():
+ g = gen()
+
+ v = await g.asend(None)
+ self.assertEqual(v, 1)
+
+ try:
+ await g.athrow(FooEr)
+ except asyncio.CancelledError:
+ self.assertEqual(DONE, 1)
+ raise
+ else:
+ self.fail('CancelledError was not raised')
+
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 1)
+
+ def test_async_gen_asyncio_athrow_03(self):
+ DONE = 0
+
+ # test synchronous generators
+ def foo():
+ try:
+ yield
+ except:
+ pass
+ g = foo()
+ g.send(None)
+ with self.assertRaises(StopIteration):
+ g.throw(ValueError)
+
+ # now with asynchronous generators
+
+ async def gen():
+ nonlocal DONE
+ try:
+ yield
+ except:
+ pass
+ DONE = 1
+
+ async def run():
+ nonlocal DONE
+ g = gen()
+ await g.asend(None)
+ with self.assertRaises(StopAsyncIteration):
+ await g.athrow(ValueError)
+ DONE += 10
+
+ self.loop.run_until_complete(run())
+ self.assertEqual(DONE, 11)
+
+ def test_async_gen_asyncio_athrow_tuple(self):
+ async def gen():
+ try:
+ yield 1
+ except ZeroDivisionError:
+ yield (2,)
+
+ async def run():
+ g = gen()
+ v = await g.asend(None)
+ self.assertEqual(v, 1)
+ v = await g.athrow(ZeroDivisionError)
+ self.assertEqual(v, (2,))
+ with self.assertRaises(StopAsyncIteration):
+ await g.asend(None)
+
+ self.loop.run_until_complete(run())
+
+ def test_async_gen_asyncio_athrow_stopiteration(self):
+ async def gen():
+ try:
+ yield 1
+ except ZeroDivisionError:
+ yield StopIteration(2)
+
+ async def run():
+ g = gen()
+ v = await g.asend(None)
+ self.assertEqual(v, 1)
+ v = await g.athrow(ZeroDivisionError)
+ self.assertIsInstance(v, StopIteration)
+ self.assertEqual(v.value, 2)
+ with self.assertRaises(StopAsyncIteration):
+ await g.asend(None)
+
+ self.loop.run_until_complete(run())
+
+ def test_async_gen_asyncio_shutdown_01(self):
+ finalized = 0
+
+ async def waiter(timeout):
+ nonlocal finalized
+ try:
+ await asyncio.sleep(timeout, loop=self.loop)
+ yield 1
+ finally:
+ await asyncio.sleep(0, loop=self.loop)
+ finalized += 1
+
+ async def wait():
+ async for _ in waiter(1):
+ pass
+
+ t1 = self.loop.create_task(wait())
+ t2 = self.loop.create_task(wait())
+
+ self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop))
+
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.assertEqual(finalized, 2)
+
+ # Silence warnings
+ t1.cancel()
+ t2.cancel()
+ self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop))
+
+ def test_async_gen_asyncio_shutdown_02(self):
+ logged = 0
+
+ def logger(loop, context):
+ nonlocal logged
+ self.assertIn('asyncgen', context)
+ expected = 'an error occurred during closing of asynchronous'
+ if expected in context['message']:
+ logged += 1
+
+ async def waiter(timeout):
+ try:
+ await asyncio.sleep(timeout, loop=self.loop)
+ yield 1
+ finally:
+ 1 / 0
+
+ async def wait():
+ async for _ in waiter(1):
+ pass
+
+ t = self.loop.create_task(wait())
+ self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop))
+
+ self.loop.set_exception_handler(logger)
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+
+ self.assertEqual(logged, 1)
+
+ # Silence warnings
+ t.cancel()
+ self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop))
+
+if __name__ == "__main__":
+ unittest.main()