Add http storage timeouts #1227
|
@ -163,7 +163,9 @@ jobs:
|
|||
matrix:
|
||||
os:
|
||||
- windows-latest
|
||||
- ubuntu-latest
|
||||
# 22.04 has some issue with Tor at the moment:
|
||||
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3943
|
||||
- ubuntu-20.04
|
||||
python-version:
|
||||
- 3.7
|
||||
- 3.9
|
||||
|
@ -175,7 +177,7 @@ jobs:
|
|||
steps:
|
||||
|
||||
- name: Install Tor [Ubuntu]
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
if: ${{ contains(matrix.os, 'ubuntu') }}
|
||||
run: sudo apt install tor
|
||||
|
||||
# TODO: See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3744.
|
||||
|
|
|
@ -20,7 +20,11 @@ from twisted.web.http_headers import Headers
|
|||
from twisted.web import http
|
||||
from twisted.web.iweb import IPolicyForHTTPS
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed
|
||||
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator, IReactorTime
|
||||
from twisted.internet.interfaces import (
|
||||
IOpenSSLClientConnectionCreator,
|
||||
IReactorTime,
|
||||
IDelayedCall,
|
||||
)
|
||||
from twisted.internet.ssl import CertificateOptions
|
||||
from twisted.web.client import Agent, HTTPConnectionPool
|
||||
from zope.interface import implementer
|
||||
|
@ -124,16 +128,22 @@ class _LengthLimitedCollector:
|
|||
"""
|
||||
|
||||
remaining_length: int
|
||||
timeout_on_silence: IDelayedCall
|
||||
f: BytesIO = field(factory=BytesIO)
|
||||
|
||||
def __call__(self, data: bytes):
|
||||
self.timeout_on_silence.reset(60)
|
||||
self.remaining_length -= len(data)
|
||||
if self.remaining_length < 0:
|
||||
raise ValueError("Response length was too long")
|
||||
self.f.write(data)
|
||||
|
||||
|
||||
def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[BinaryIO]:
|
||||
def limited_content(
|
||||
response,
|
||||
clock: IReactorTime,
|
||||
max_length: int = 30 * 1024 * 1024,
|
||||
) -> Deferred[BinaryIO]:
|
||||
"""
|
||||
Like ``treq.content()``, but limit data read from the response to a set
|
||||
length. If the response is longer than the max allowed length, the result
|
||||
|
@ -142,39 +152,29 @@ def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[Bi
|
|||
A potentially useful future improvement would be using a temporary file to
|
||||
store the content; since filesystem buffering means that would use memory
|
||||
for small responses and disk for large responses.
|
||||
|
||||
This will time out if no data is received for 60 seconds; so long as a
|
||||
trickle of data continues to arrive, it will continue to run.
|
||||
"""
|
||||
collector = _LengthLimitedCollector(max_length)
|
||||
d = succeed(None)
|
||||
timeout = clock.callLater(60, d.cancel)
|
||||
collector = _LengthLimitedCollector(max_length, timeout)
|
||||
|
||||
# Make really sure everything gets called in Deferred context, treq might
|
||||
# call collector directly...
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: treq.collect(response, collector))
|
||||
|
||||
def done(_):
|
||||
timeout.cancel()
|
||||
collector.f.seek(0)
|
||||
return collector.f
|
||||
|
||||
d.addCallback(done)
|
||||
return d
|
||||
def failed(f):
|
||||
if timeout.active():
|
||||
timeout.cancel()
|
||||
return f
|
||||
|
||||
|
||||
def _decode_cbor(response, schema: Schema):
|
||||
"""Given HTTP response, return decoded CBOR body."""
|
||||
|
||||
def got_content(f: BinaryIO):
|
||||
data = f.read()
|
||||
schema.validate_cbor(data)
|
||||
return loads(data)
|
||||
|
||||
if response.code > 199 and response.code < 300:
|
||||
content_type = get_content_type(response.headers)
|
||||
if content_type == CBOR_MIME_TYPE:
|
||||
return limited_content(response).addCallback(got_content)
|
||||
else:
|
||||
raise ClientException(-1, "Server didn't send CBOR")
|
||||
else:
|
||||
return treq.content(response).addCallback(
|
||||
lambda data: fail(ClientException(response.code, response.phrase, data))
|
||||
)
|
||||
return d.addCallbacks(done, failed)
|
||||
|
||||
|
||||
@define
|
||||
|
@ -362,6 +362,7 @@ class StorageClient(object):
|
|||
write_enabler_secret=None,
|
||||
headers=None,
|
||||
message_to_serialize=None,
|
||||
timeout: float = 60,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -370,6 +371,8 @@ class StorageClient(object):
|
|||
|
||||
If ``message_to_serialize`` is set, it will be serialized (by default
|
||||
with CBOR) and set as the request body.
|
||||
|
||||
Default timeout is 60 seconds.
|
||||
"""
|
||||
|
||||
headers = self._get_headers(headers)
|
||||
|
||||
|
@ -401,7 +404,28 @@ class StorageClient(object):
|
|||
kwargs["data"] = dumps(message_to_serialize)
|
||||
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
|
||||
|
||||
return self._treq.request(method, url, headers=headers, **kwargs)
|
||||
return self._treq.request(
|
||||
method, url, headers=headers, timeout=timeout, **kwargs
|
||||
)
|
||||
|
||||
def decode_cbor(self, response, schema: Schema):
|
||||
"""Given HTTP response, return decoded CBOR body."""
|
||||
|
||||
def got_content(f: BinaryIO):
|
||||
data = f.read()
|
||||
schema.validate_cbor(data)
|
||||
return loads(data)
|
||||
|
||||
if response.code > 199 and response.code < 300:
|
||||
content_type = get_content_type(response.headers)
|
||||
if content_type == CBOR_MIME_TYPE:
|
||||
return limited_content(response, self._clock).addCallback(got_content)
|
||||
else:
|
||||
raise ClientException(-1, "Server didn't send CBOR")
|
||||
else:
|
||||
return treq.content(response).addCallback(
|
||||
lambda data: fail(ClientException(response.code, response.phrase, data))
|
||||
)
|
||||
|
||||
|
||||
@define(hash=True)
|
||||
|
@ -419,7 +443,9 @@ class StorageClientGeneral(object):
|
|||
"""
|
||||
url = self._client.relative_url("/storage/v1/version")
|
||||
response = yield self._client.request("GET", url)
|
||||
decoded_response = yield _decode_cbor(response, _SCHEMAS["get_version"])
|
||||
decoded_response = yield self._client.decode_cbor(
|
||||
response, _SCHEMAS["get_version"]
|
||||
)
|
||||
returnValue(decoded_response)
|
||||
|
||||
@inlineCallbacks
|
||||
|
@ -486,6 +512,9 @@ def read_share_chunk(
|
|||
share_type, _encode_si(storage_index), share_number
|
||||
)
|
||||
)
|
||||
# The default 60 second timeout is for getting the response, so it doesn't
|
||||
# include the time it takes to download the body... so we will will deal
|
||||
# with that later, via limited_content().
|
||||
response = yield client.request(
|
||||
"GET",
|
||||
url,
|
||||
|
@ -494,6 +523,7 @@ def read_share_chunk(
|
|||
# but Range constructor does that the conversion for us.
|
||||
{"range": [Range("bytes", [(offset, offset + length)]).to_header()]}
|
||||
),
|
||||
unbuffered=True, # Don't buffer the response in memory.
|
||||
)
|
||||
|
||||
if response.code == http.NO_CONTENT:
|
||||
|
@ -516,7 +546,7 @@ def read_share_chunk(
|
|||
raise ValueError("Server sent more than we asked for?!")
|
||||
# It might also send less than we asked for. That's (probably) OK, e.g.
|
||||
# if we went past the end of the file.
|
||||
body = yield limited_content(response, supposed_length)
|
||||
body = yield limited_content(response, client._clock, supposed_length)
|
||||
body.seek(0, SEEK_END)
|
||||
actual_length = body.tell()
|
||||
if actual_length != supposed_length:
|
||||
|
@ -603,7 +633,9 @@ class StorageClientImmutables(object):
|
|||
upload_secret=upload_secret,
|
||||
message_to_serialize=message,
|
||||
)
|
||||
decoded_response = yield _decode_cbor(response, _SCHEMAS["allocate_buckets"])
|
||||
decoded_response = yield self._client.decode_cbor(
|
||||
response, _SCHEMAS["allocate_buckets"]
|
||||
)
|
||||
returnValue(
|
||||
ImmutableCreateResult(
|
||||
already_have=decoded_response["already-have"],
|
||||
|
@ -679,7 +711,9 @@ class StorageClientImmutables(object):
|
|||
raise ClientException(
|
||||
response.code,
|
||||
)
|
||||
body = yield _decode_cbor(response, _SCHEMAS["immutable_write_share_chunk"])
|
||||
body = yield self._client.decode_cbor(
|
||||
response, _SCHEMAS["immutable_write_share_chunk"]
|
||||
)
|
||||
remaining = RangeMap()
|
||||
for chunk in body["required"]:
|
||||
remaining.set(True, chunk["begin"], chunk["end"])
|
||||
|
@ -708,7 +742,7 @@ class StorageClientImmutables(object):
|
|||
url,
|
||||
)
|
||||
if response.code == http.OK:
|
||||
body = yield _decode_cbor(response, _SCHEMAS["list_shares"])
|
||||
body = yield self._client.decode_cbor(response, _SCHEMAS["list_shares"])
|
||||
returnValue(set(body))
|
||||
else:
|
||||
raise ClientException(response.code)
|
||||
|
@ -825,7 +859,9 @@ class StorageClientMutables:
|
|||
message_to_serialize=message,
|
||||
)
|
||||
if response.code == http.OK:
|
||||
result = await _decode_cbor(response, _SCHEMAS["mutable_read_test_write"])
|
||||
result = await self._client.decode_cbor(
|
||||
response, _SCHEMAS["mutable_read_test_write"]
|
||||
)
|
||||
return ReadTestWriteResult(success=result["success"], reads=result["data"])
|
||||
else:
|
||||
raise ClientException(response.code, (await response.content()))
|
||||
|
@ -854,7 +890,9 @@ class StorageClientMutables:
|
|||
)
|
||||
response = await self._client.request("GET", url)
|
||||
if response.code == http.OK:
|
||||
return await _decode_cbor(response, _SCHEMAS["mutable_list_shares"])
|
||||
return await self._client.decode_cbor(
|
||||
response, _SCHEMAS["mutable_list_shares"]
|
||||
)
|
||||
else:
|
||||
raise ClientException(response.code)
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from foolscap.api import flushEventualQueue
|
|||
from allmydata import client
|
||||
from allmydata.introducer.server import create_introducer
|
||||
from allmydata.util import fileutil, log, pollmixin
|
||||
from allmydata.util.deferredutil import async_to_deferred
|
||||
from allmydata.storage import http_client
|
||||
from allmydata.storage_client import (
|
||||
NativeStorageServer,
|
||||
|
@ -639,6 +640,40 @@ def _render_section_values(values):
|
|||
))
|
||||
|
||||
|
||||
@async_to_deferred
|
||||
async def spin_until_cleanup_done(value=None, timeout=10):
|
||||
"""
|
||||
At the end of the test, spin until the reactor has no more DelayedCalls
|
||||
and file descriptors (or equivalents) registered. This prevents dirty
|
||||
reactor errors, while also not hard-coding a fixed amount of time, so it
|
||||
can finish faster on faster computers.
|
||||
|
||||
There is also a timeout: if it takes more than 10 seconds (by default) for
|
||||
the remaining reactor state to clean itself up, the presumption is that it
|
||||
will never get cleaned up and the spinning stops.
|
||||
|
||||
Make sure to run as last thing in tearDown.
|
||||
"""
|
||||
def num_fds():
|
||||
if hasattr(reactor, "handles"):
|
||||
# IOCP!
|
||||
return len(reactor.handles)
|
||||
else:
|
||||
# Normal reactor; having internal readers still registered is fine,
|
||||
# that's not our code.
|
||||
return len(
|
||||
set(reactor.getReaders()) - set(reactor._internalReaders)
|
||||
) + len(reactor.getWriters())
|
||||
|
||||
for i in range(timeout * 1000):
|
||||
# There's a single DelayedCall for AsynchronousDeferredRunTest's
|
||||
# timeout...
|
||||
if (len(reactor.getDelayedCalls()) < 2 and num_fds() == 0):
|
||||
break
|
||||
await deferLater(reactor, 0.001)
|
||||
return value
|
||||
|
||||
|
||||
class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
|
||||
|
||||
# If set to True, use Foolscap for storage protocol. If set to False, HTTP
|
||||
|
@ -685,7 +720,7 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
|
|||
d = self.sparent.stopService()
|
||||
d.addBoth(flush_but_dont_ignore)
|
||||
d.addBoth(lambda x: self.close_idle_http_connections().addCallback(lambda _: x))
|
||||
d.addBoth(lambda x: deferLater(reactor, 2, lambda: x))
|
||||
d.addBoth(spin_until_cleanup_done)
|
||||
return d
|
||||
|
||||
def getdir(self, subdir):
|
||||
|
|
|
@ -31,6 +31,8 @@ from klein import Klein
|
|||
from hyperlink import DecodedURL
|
||||
from collections_extended import RangeMap
|
||||
from twisted.internet.task import Clock, Cooperator
|
||||
from twisted.internet.interfaces import IReactorTime
|
||||
from twisted.internet.defer import CancelledError, Deferred
|
||||
from twisted.web import http
|
||||
from twisted.web.http_headers import Headers
|
||||
from werkzeug import routing
|
||||
|
@ -245,6 +247,7 @@ def gen_bytes(length: int) -> bytes:
|
|||
class TestApp(object):
|
||||
"""HTTP API for testing purposes."""
|
||||
|
||||
clock: IReactorTime
|
||||
_app = Klein()
|
||||
_swissnum = SWISSNUM_FOR_TEST # Match what the test client is using
|
||||
|
||||
|
@ -266,6 +269,25 @@ class TestApp(object):
|
|||
"""Return bytes to the given length using ``gen_bytes()``."""
|
||||
return gen_bytes(length)
|
||||
|
||||
@_authorized_route(_app, set(), "/slowly_never_finish_result", methods=["GET"])
|
||||
def slowly_never_finish_result(self, request, authorization):
|
||||
"""
|
||||
Send data immediately, after 59 seconds, after another 59 seconds, and then
|
||||
never again, without finishing the response.
|
||||
"""
|
||||
request.write(b"a")
|
||||
self.clock.callLater(59, request.write, b"b")
|
||||
self.clock.callLater(59 + 59, request.write, b"c")
|
||||
return Deferred()
|
||||
|
||||
@_authorized_route(_app, set(), "/die_unfinished", methods=["GET"])
|
||||
def die(self, request, authorization):
|
||||
"""
|
||||
Dies half-way.
|
||||
"""
|
||||
request.transport.loseConnection()
|
||||
return Deferred()
|
||||
|
||||
|
||||
def result_of(d):
|
||||
"""
|
||||
|
@ -298,12 +320,18 @@ class CustomHTTPServerTests(SyncTestCase):
|
|||
# Could be a fixture, but will only be used in this test class so not
|
||||
# going to bother:
|
||||
self._http_server = TestApp()
|
||||
treq = StubTreq(self._http_server._app.resource())
|
||||
self.client = StorageClient(
|
||||
DecodedURL.from_text("http://127.0.0.1"),
|
||||
SWISSNUM_FOR_TEST,
|
||||
treq=StubTreq(self._http_server._app.resource()),
|
||||
clock=Clock(),
|
||||
treq=treq,
|
||||
# We're using a Treq private API to get the reactor, alas, but only
|
||||
# in a test, so not going to worry about it too much. This would be
|
||||
# fixed if https://github.com/twisted/treq/issues/226 were ever
|
||||
# fixed.
|
||||
clock=treq._agent._memoryReactor,
|
||||
)
|
||||
self._http_server.clock = self.client._clock
|
||||
|
||||
def test_authorization_enforcement(self):
|
||||
"""
|
||||
|
@ -351,7 +379,9 @@ class CustomHTTPServerTests(SyncTestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(
|
||||
result_of(limited_content(response, at_least_length)).read(),
|
||||
result_of(
|
||||
limited_content(response, self._http_server.clock, at_least_length)
|
||||
).read(),
|
||||
gen_bytes(length),
|
||||
)
|
||||
|
||||
|
@ -370,7 +400,52 @@ class CustomHTTPServerTests(SyncTestCase):
|
|||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
result_of(limited_content(response, too_short))
|
||||
result_of(limited_content(response, self._http_server.clock, too_short))
|
||||
|
||||
def test_limited_content_silence_causes_timeout(self):
|
||||
"""
|
||||
``http_client.limited_content() times out if it receives no data for 60
|
||||
seconds.
|
||||
"""
|
||||
response = result_of(
|
||||
self.client.request(
|
||||
"GET",
|
||||
"http://127.0.0.1/slowly_never_finish_result",
|
||||
)
|
||||
)
|
||||
|
||||
body_deferred = limited_content(response, self._http_server.clock, 4)
|
||||
result = []
|
||||
error = []
|
||||
body_deferred.addCallbacks(result.append, error.append)
|
||||
|
||||
for i in range(59 + 59 + 60):
|
||||
self.assertEqual((result, error), ([], []))
|
||||
self._http_server.clock.advance(1)
|
||||
# Push data between in-memory client and in-memory server:
|
||||
self.client._treq._agent.flush()
|
||||
|
||||
# After 59 (second write) + 59 (third write) + 60 seconds (quiescent
|
||||
# timeout) the limited_content() response times out.
|
||||
self.assertTrue(error)
|
||||
with self.assertRaises(CancelledError):
|
||||
error[0].raiseException()
|
||||
|
||||
def test_limited_content_cancels_timeout_on_failed_response(self):
|
||||
"""
|
||||
If the response fails somehow, the timeout is still cancelled.
|
||||
"""
|
||||
response = result_of(
|
||||
self.client.request(
|
||||
"GET",
|
||||
"http://127.0.0.1/die",
|
||||
)
|
||||
)
|
||||
|
||||
d = limited_content(response, self._http_server.clock, 4)
|
||||
with self.assertRaises(ValueError):
|
||||
result_of(d)
|
||||
self.assertEqual(len(self._http_server.clock.getDelayedCalls()), 0)
|
||||
|
||||
|
||||
class HttpTestFixture(Fixture):
|
||||
|
|
|
@ -12,7 +12,7 @@ from cryptography import x509
|
|||
|
||||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.task import deferLater
|
||||
from twisted.internet.defer import maybeDeferred
|
||||
from twisted.web.server import Site
|
||||
from twisted.web.static import Data
|
||||
from twisted.web.client import Agent, HTTPConnectionPool, ResponseNeverReceived
|
||||
|
@ -30,6 +30,7 @@ from ..storage.http_common import get_spki_hash
|
|||
from ..storage.http_client import _StorageClientHTTPSPolicy
|
||||
from ..storage.http_server import _TLSEndpointWrapper
|
||||
from ..util.deferredutil import async_to_deferred
|
||||
from .common_system import spin_until_cleanup_done
|
||||
|
||||
|
||||
class HTTPSNurlTests(SyncTestCase):
|
||||
|
@ -87,6 +88,10 @@ class PinningHTTPSValidation(AsyncTestCase):
|
|||
self.addCleanup(self._port_assigner.tearDown)
|
||||
return AsyncTestCase.setUp(self)
|
||||
|
||||
def tearDown(self):
|
||||
d = maybeDeferred(AsyncTestCase.tearDown, self)
|
||||
return d.addCallback(lambda _: spin_until_cleanup_done())
|
||||
|
||||
@asynccontextmanager
|
||||
async def listen(self, private_key_path: FilePath, cert_path: FilePath):
|
||||
"""
|
||||
|
@ -107,9 +112,6 @@ class PinningHTTPSValidation(AsyncTestCase):
|
|||
yield f"https://127.0.0.1:{listening_port.getHost().port}/"
|
||||
finally:
|
||||
await listening_port.stopListening()
|
||||
# Make sure all server connections are closed :( No idea why this
|
||||
# is necessary when it's not for IStorageServer HTTPS tests.
|
||||
await deferLater(reactor, 0.01)
|
||||
|
||||
def request(self, url: str, expected_certificate: x509.Certificate):
|
||||
"""
|
||||
|
@ -144,10 +146,6 @@ class PinningHTTPSValidation(AsyncTestCase):
|
|||
response = await self.request(url, certificate)
|
||||
self.assertEqual(await response.content(), b"YOYODYNE")
|
||||
|
||||
# We keep getting TLSMemoryBIOProtocol being left around, so try harder
|
||||
# to wait for it to finish.
|
||||
await deferLater(reactor, 0.01)
|
||||
|
||||
@async_to_deferred
|
||||
async def test_server_certificate_has_wrong_hash(self):
|
||||
"""
|
||||
|
@ -202,10 +200,6 @@ class PinningHTTPSValidation(AsyncTestCase):
|
|||
response = await self.request(url, certificate)
|
||||
self.assertEqual(await response.content(), b"YOYODYNE")
|
||||
|
||||
# We keep getting TLSMemoryBIOProtocol being left around, so try harder
|
||||
# to wait for it to finish.
|
||||
await deferLater(reactor, 0.001)
|
||||
|
||||
# A potential attack to test is a private key that doesn't match the
|
||||
# certificate... but OpenSSL (quite rightly) won't let you listen with that
|
||||
# so I don't know how to test that! See
|
||||
|
|
Loading…
Reference in New Issue
mypy considers int to be a subclass of float, so
timeout: float = 60
is fine