Flatten outputs for chained boolean operations where possible. Add special-case handling to flatten the data structure when chaining multiple AND/OR operations together; for example, make f_and(f_and(x, y), z) work in approximately the same way as f_and(x, y, z). The goal here is to avoid a stack overflow when using a pattern that can chain a large number of futures. Previously, an overflow could occur once the chain was a few hundred futures deep, since each future's completion callback would set the result on the next future, which would invoke another completion callback, and so on throughout the chain. Fixes #320
diff --git a/more_executors/_impl/futures/bool.py b/more_executors/_impl/futures/bool.py index 422d38b..8d3a8d4 100644 --- a/more_executors/_impl/futures/bool.py +++ b/more_executors/_impl/futures/bool.py
@@ -13,15 +13,77 @@ LOG = LogWrapper(logging.getLogger("more_executors.futures")) +# Our futures here know whether they are an AND or OR future and +# keep a reference back to the underlying boolean oper. +# This allows us to be more efficient in certain scenarios. +class BoolFuture(Future): + def __init__(self, oper): + self.oper = oper + super(BoolFuture, self).__init__() + + +class AndFuture(BoolFuture): + pass + + +class OrFuture(BoolFuture): + pass + + class BoolOperation(object): + FUTURE_CLASS = BoolFuture + def __init__(self, fs): + future_types = {} + remainder = [] + self.fs = {} for f in fs: self.fs[f] = True + if isinstance(f, BoolFuture): + future_types.setdefault(type(f), []).append(f.oper) + else: + remainder.append(f) + + if list(future_types.keys()) == [self.FUTURE_CLASS]: + # Special case: we are being called with input futures + # which are themselves the output of f_and/f_or, and all + # of the same type as our current operation, for example: + # f = f_and(...) + # f = f_and(f, ...) + # f = f_and(f, ...) + # f = f_and(f, ...) + # ...and so on. + # + # In that case: rather than keeping the input futures as + # we normally do, we can look inside of them and pull out + # their constituent futures. This allows us to return a more + # flat data structure, reducing the number of chained futures. + # It's useful to do this because a large number of chained + # futures can lead to a stack overflow as completion callbacks + # are invoked. + self.fs = {} + for oper in future_types[self.FUTURE_CLASS]: + with oper.lock: + if oper.done: + # FIXME: seems like we're still introducing a (hopefully + # very small) chance of stack overflow here. + # We could be in the small window of time where oper.done + # is true but oper.out.set_result hasn't been called. + # If so, then we're unflattening the structure again. + # In theory it could happen for enough futures to cause + # an overflow. Is there a practical way to rule it out + # entirely? 
+ self.fs[oper.out] = True + else: + self.fs.update(oper.fs) + for f in remainder: + self.fs[f] = True + fs = list(self.fs.keys()) self.done = False self.lock = Lock() - self.out = Future() + self.out = self.FUTURE_CLASS(self) for f in fs: chain_cancel(self.out, f) @@ -53,6 +115,8 @@ class OrOperation(BoolOperation): + FUTURE_CLASS = OrFuture + def get_state_update(self, f): set_result = False set_exception = False @@ -111,6 +175,8 @@ class AndOperation(BoolOperation): + FUTURE_CLASS = AndFuture + def get_state_update(self, f): set_result = False set_exception = False
diff --git a/tests/futures/test_and.py b/tests/futures/test_and.py index 44f2c7b..2418000 100644 --- a/tests/futures/test_and.py +++ b/tests/futures/test_and.py
@@ -31,6 +31,11 @@ ] +def delay_then(delay, value): + time.sleep(delay) + return value + + @pytest.mark.parametrize("inputs, expected_result", cases) def test_and(inputs, expected_result, falsey, truthy): inputs = resolve_inputs(inputs, falsey, truthy) @@ -51,10 +56,6 @@ def test_and_order_async(): executor = Executors.thread_pool(max_workers=2) - def delay_then(delay, value): - time.sleep(delay) - return value - f_inputs = [ executor.submit(delay_then, 0.1, 123), executor.submit(delay_then, 0.05, 456), @@ -168,3 +169,23 @@ inputs = [f_return(True) for _ in range(0, 100000)] assert f_and(*inputs).result() is True + + +def test_and_chain_true(): + with Executors.thread_pool() as exec: + f = exec.submit(delay_then, 0.4, 123) + + for _ in range(0, 10000): + f = f_and(f_return(True), f) + + assert f.result() in (123, True) + + +def test_and_chain_false(): + with Executors.thread_pool() as exec: + f = exec.submit(delay_then, 0.4, 123) + + for _ in range(0, 10000): + f = f_and(f_return(False), f) + + assert f.result() == False
diff --git a/tests/futures/test_or.py b/tests/futures/test_or.py index c61894d..ae04870 100644 --- a/tests/futures/test_or.py +++ b/tests/futures/test_or.py
@@ -19,6 +19,11 @@ LOG = logging.getLogger("test_or") +def delay_then(delay, value): + time.sleep(delay) + return value + + cases = [ [(falsey,), falsey], [(truthy,), truthy], @@ -178,3 +183,23 @@ inputs = [f_return(0) for _ in range(0, 100000)] assert f_or(*inputs).result() == 0 + + +def test_or_chain_true(): + with Executors.thread_pool() as exec: + f = exec.submit(delay_then, 0.4, 123) + + for _ in range(0, 10000): + f = f_or(f_return(True), f) + + assert f.result() is True + + +def test_or_chain_false(): + with Executors.thread_pool() as exec: + f = exec.submit(delay_then, 0.4, False) + + for _ in range(0, 10000): + f = f_or(f_return(False), f) + + assert f.result() is False