Make cancellation more likely to happen.

This commit is contained in:
Itamar Turner-Trauring 2023-02-27 11:37:18 -05:00
parent e09d19463d
commit 3d0b17bc1c
2 changed files with 72 additions and 44 deletions

View File

@ -47,7 +47,7 @@ from zope.interface import (
) )
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web import http from twisted.web import http
from twisted.internet.task import LoopingCall, deferLater from twisted.internet.task import LoopingCall
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.application import service from twisted.application import service
from twisted.plugin import ( from twisted.plugin import (
@ -935,42 +935,52 @@ class NativeStorageServer(service.MultiService):
self._reconnector.reset() self._reconnector.reset()
async def _pick_a_http_server( def _pick_a_http_server(
reactor, reactor,
nurls: list[DecodedURL], nurls: list[DecodedURL],
request: Callable[[Any, DecodedURL], defer.Deferred[Any]] request: Callable[[Any, DecodedURL], defer.Deferred[Any]]
) -> DecodedURL: ) -> defer.Deferred[Optional[DecodedURL]]:
"""Pick the first server we successfully send a request to.""" """Pick the first server we successfully send a request to.
while True:
result : defer.Deferred[Optional[DecodedURL]] = defer.Deferred()
def succeeded(nurl: DecodedURL, result=result): Fires with ``None`` if no server was found, or with the ``DecodedURL`` of
# Only need the first successful NURL: the first successfully-connected server.
if result.called: """
return
result.callback(nurl)
def failed(failure, failures=[], result=result): to_cancel : list[defer.Deferred] = []
# Logging errors breaks a bunch of tests, and it's not a _bug_ to
# have a failed connection, it's often expected and transient. More
# of a warning, really?
log.msg("Failed to connect to NURL: {}".format(failure))
failures.append(None)
if len(failures) == len(nurls):
# All our potential NURLs failed...
result.callback(None)
for index, nurl in enumerate(nurls): def cancel(result: Optional[defer.Deferred]):
request(reactor, nurl).addCallback( for d in to_cancel:
lambda _, nurl=nurl: nurl).addCallbacks(succeeded, failed) if not d.called:
d.cancel()
if result is not None:
result.errback(defer.CancelledError())
first_nurl = await result result : defer.Deferred[Optional[DecodedURL]] = defer.Deferred(canceller=cancel)
if first_nurl is None:
# Failed to connect to any of the NURLs, try again in a few def succeeded(nurl: DecodedURL, result=result):
# seconds: # Only need the first successful NURL:
await deferLater(reactor, 5, lambda: None) if result.called:
else: return
return first_nurl result.callback(nurl)
# No point in continuing other requests if we're connected:
cancel(None)
def failed(failure, failures=[], result=result):
# Logging errors breaks a bunch of tests, and it's not a _bug_ to
# have a failed connection, it's often expected and transient. More
# of a warning, really?
log.msg("Failed to connect to NURL: {}".format(failure))
failures.append(None)
if len(failures) == len(nurls):
# All our potential NURLs failed...
result.callback(None)
for index, nurl in enumerate(nurls):
d = request(reactor, nurl)
to_cancel.append(d)
d.addCallback(lambda _, nurl=nurl: nurl).addCallbacks(succeeded, failed)
return result
@implementer(IServer) @implementer(IServer)
@ -1117,8 +1127,22 @@ class HTTPNativeStorageServer(service.MultiService):
StorageClient.from_nurl(nurl, reactor) StorageClient.from_nurl(nurl, reactor)
).get_version() ).get_version()
nurl = await _pick_a_http_server(reactor, self._nurls, request) # LoopingCall.stop() doesn't cancel Deferreds, unfortunately:
self._istorage_server = _HTTPStorageServer.from_http_client( # https://github.com/twisted/twisted/issues/11814 Thus we want
# store the Deferred so it gets cancelled.
picking = _pick_a_http_server(reactor, self._nurls, request)
self._connecting_deferred = picking
try:
nurl = await picking
finally:
self._connecting_deferred = None
if nurl is None:
# We failed to find a server to connect to. Perhaps the next
# iteration of the loop will succeed.
return
else:
self._istorage_server = _HTTPStorageServer.from_http_client(
StorageClient.from_nurl(nurl, reactor) StorageClient.from_nurl(nurl, reactor)
) )

View File

@ -83,7 +83,6 @@ from allmydata.webish import (
WebishServer, WebishServer,
) )
from allmydata.util import base32, yamlutil from allmydata.util import base32, yamlutil
from allmydata.util.deferredutil import async_to_deferred
from allmydata.storage_client import ( from allmydata.storage_client import (
IFoolscapStorageServer, IFoolscapStorageServer,
NativeStorageServer, NativeStorageServer,
@ -741,7 +740,7 @@ storage:
class PickHTTPServerTests(unittest.SynchronousTestCase): class PickHTTPServerTests(unittest.SynchronousTestCase):
"""Tests for ``_pick_a_http_server``.""" """Tests for ``_pick_a_http_server``."""
def loop_until_result(self, url_to_results: dict[DecodedURL, list[tuple[float, Union[Exception, Any]]]]) -> Deferred[DecodedURL]: def loop_until_result(self, url_to_results: dict[DecodedURL, list[tuple[float, Union[Exception, Any]]]]) -> tuple[int, DecodedURL]:
""" """
Given mapping of URLs to list of (delay, result), return the URL of the Given mapping of URLs to list of (delay, result), return the URL of the
first selected server. first selected server.
@ -759,12 +758,15 @@ class PickHTTPServerTests(unittest.SynchronousTestCase):
reactor.callLater(delay, add_result_value) reactor.callLater(delay, add_result_value)
return result return result
d = async_to_deferred(_pick_a_http_server)( iterations = 0
clock, list(url_to_results.keys()), request while True:
) iterations += 1
for i in range(1000): d = _pick_a_http_server(clock, list(url_to_results.keys()), request)
clock.advance(0.1) for i in range(100):
return d clock.advance(0.1)
result = self.successResultOf(d)
if result is not None:
return iterations, result
def test_first_successful_connect_is_picked(self): def test_first_successful_connect_is_picked(self):
""" """
@ -772,11 +774,12 @@ class PickHTTPServerTests(unittest.SynchronousTestCase):
""" """
earliest_url = DecodedURL.from_text("http://a") earliest_url = DecodedURL.from_text("http://a")
latest_url = DecodedURL.from_text("http://b") latest_url = DecodedURL.from_text("http://b")
d = self.loop_until_result({ iterations, result = self.loop_until_result({
latest_url: [(2, None)], latest_url: [(2, None)],
earliest_url: [(1, None)] earliest_url: [(1, None)]
}) })
self.assertEqual(self.successResultOf(d), earliest_url) self.assertEqual(iterations, 1)
self.assertEqual(result, earliest_url)
def test_failures_are_retried(self): def test_failures_are_retried(self):
""" """
@ -785,10 +788,11 @@ class PickHTTPServerTests(unittest.SynchronousTestCase):
""" """
eventually_good_url = DecodedURL.from_text("http://good") eventually_good_url = DecodedURL.from_text("http://good")
bad_url = DecodedURL.from_text("http://bad") bad_url = DecodedURL.from_text("http://bad")
d = self.loop_until_result({ iterations, result = self.loop_until_result({
eventually_good_url: [ eventually_good_url: [
(1, ZeroDivisionError()), (0.1, ZeroDivisionError()), (1, None) (1, ZeroDivisionError()), (0.1, ZeroDivisionError()), (1, None)
], ],
bad_url: [(0.1, RuntimeError()), (0.1, RuntimeError()), (0.1, RuntimeError())] bad_url: [(0.1, RuntimeError()), (0.1, RuntimeError()), (0.1, RuntimeError())]
}) })
self.assertEqual(self.successResultOf(d), eventually_good_url) self.assertEqual(iterations, 3)
self.assertEqual(result, eventually_good_url)