Add http storage timeouts #1227

Merged
itamarst merged 21 commits from 3940-http-timeouts into master 2022-11-28 16:03:50 +00:00
6 changed files with 196 additions and 52 deletions

View File

@ -163,7 +163,9 @@ jobs:
matrix: matrix:
os: os:
- windows-latest - windows-latest
- ubuntu-latest # 22.04 has some issue with Tor at the moment:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3943
- ubuntu-20.04
python-version: python-version:
- 3.7 - 3.7
- 3.9 - 3.9
@ -175,7 +177,7 @@ jobs:
steps: steps:
- name: Install Tor [Ubuntu] - name: Install Tor [Ubuntu]
if: matrix.os == 'ubuntu-latest' if: ${{ contains(matrix.os, 'ubuntu') }}
run: sudo apt install tor run: sudo apt install tor
# TODO: See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3744. # TODO: See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3744.

0
newsfragments/3940.minor Normal file
View File

View File

@ -20,7 +20,11 @@ from twisted.web.http_headers import Headers
from twisted.web import http from twisted.web import http
from twisted.web.iweb import IPolicyForHTTPS from twisted.web.iweb import IPolicyForHTTPS
from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator, IReactorTime from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IReactorTime,
IDelayedCall,
)
from twisted.internet.ssl import CertificateOptions from twisted.internet.ssl import CertificateOptions
from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.client import Agent, HTTPConnectionPool
from zope.interface import implementer from zope.interface import implementer
@ -124,16 +128,22 @@ class _LengthLimitedCollector:
""" """
remaining_length: int remaining_length: int
timeout_on_silence: IDelayedCall
f: BytesIO = field(factory=BytesIO) f: BytesIO = field(factory=BytesIO)
def __call__(self, data: bytes): def __call__(self, data: bytes):
self.timeout_on_silence.reset(60)
self.remaining_length -= len(data) self.remaining_length -= len(data)
if self.remaining_length < 0: if self.remaining_length < 0:
raise ValueError("Response length was too long") raise ValueError("Response length was too long")
self.f.write(data) self.f.write(data)
def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[BinaryIO]: def limited_content(
response,
clock: IReactorTime,
max_length: int = 30 * 1024 * 1024,
) -> Deferred[BinaryIO]:
""" """
Like ``treq.content()``, but limit data read from the response to a set Like ``treq.content()``, but limit data read from the response to a set
length. If the response is longer than the max allowed length, the result length. If the response is longer than the max allowed length, the result
@ -142,39 +152,29 @@ def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[Bi
A potentially useful future improvement would be using a temporary file to A potentially useful future improvement would be using a temporary file to
store the content; since filesystem buffering means that would use memory store the content; since filesystem buffering means that would use memory
for small responses and disk for large responses. for small responses and disk for large responses.
This will time out if no data is received for 60 seconds; so long as a
trickle of data continues to arrive, it will continue to run.
""" """
collector = _LengthLimitedCollector(max_length) d = succeed(None)
timeout = clock.callLater(60, d.cancel)
collector = _LengthLimitedCollector(max_length, timeout)
# Make really sure everything gets called in Deferred context, treq might # Make really sure everything gets called in Deferred context, treq might
# call collector directly... # call collector directly...
d = succeed(None)
d.addCallback(lambda _: treq.collect(response, collector)) d.addCallback(lambda _: treq.collect(response, collector))
def done(_): def done(_):
timeout.cancel()
collector.f.seek(0) collector.f.seek(0)
return collector.f return collector.f
d.addCallback(done) def failed(f):
return d if timeout.active():
timeout.cancel()
return f
return d.addCallbacks(done, failed)
def _decode_cbor(response, schema: Schema):
"""Given HTTP response, return decoded CBOR body."""
def got_content(f: BinaryIO):
data = f.read()
schema.validate_cbor(data)
return loads(data)
if response.code > 199 and response.code < 300:
content_type = get_content_type(response.headers)
if content_type == CBOR_MIME_TYPE:
return limited_content(response).addCallback(got_content)
else:
raise ClientException(-1, "Server didn't send CBOR")
else:
return treq.content(response).addCallback(
lambda data: fail(ClientException(response.code, response.phrase, data))
)
@define @define
@ -362,6 +362,7 @@ class StorageClient(object):
write_enabler_secret=None, write_enabler_secret=None,
headers=None, headers=None,
message_to_serialize=None, message_to_serialize=None,
timeout: float = 60,
**kwargs, **kwargs,
): ):
""" """
@ -370,6 +371,8 @@ class StorageClient(object):
If ``message_to_serialize`` is set, it will be serialized (by default If ``message_to_serialize`` is set, it will be serialized (by default
with CBOR) and set as the request body. with CBOR) and set as the request body.
Default timeout is 60 seconds.
""" """
exarkun commented 2022-11-23 15:43:42 +00:00 (Migrated from github.com)
Review

mypy considers int to be a subclass of float, so timeout: float = 60 is fine

mypy considers int to be a subclass of float, so `timeout: float = 60` is fine
headers = self._get_headers(headers) headers = self._get_headers(headers)
@ -401,7 +404,28 @@ class StorageClient(object):
kwargs["data"] = dumps(message_to_serialize) kwargs["data"] = dumps(message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE) headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
return self._treq.request(method, url, headers=headers, **kwargs) return self._treq.request(
method, url, headers=headers, timeout=timeout, **kwargs
)
def decode_cbor(self, response, schema: Schema):
    """
    Decode the CBOR body of an HTTP response.

    For a 2xx response whose content type is CBOR, the body is read with
    ``limited_content()`` (driven by this client's clock), validated
    against ``schema``, and the decoded object becomes the result of the
    returned Deferred.

    A 2xx response with any other content type raises ``ClientException``
    immediately.  For any other status code the body is read and the
    returned Deferred fails with a ``ClientException`` carrying the
    response code, phrase and body.
    """
    if not (200 <= response.code < 300):
        # Error response: surface the status and whatever body was sent.
        def as_failure(body):
            return fail(ClientException(response.code, response.phrase, body))

        return treq.content(response).addCallback(as_failure)

    if get_content_type(response.headers) != CBOR_MIME_TYPE:
        raise ClientException(-1, "Server didn't send CBOR")

    def validate_and_load(f: BinaryIO):
        # Validate the raw bytes against the schema before decoding them.
        raw = f.read()
        schema.validate_cbor(raw)
        return loads(raw)

    return limited_content(response, self._clock).addCallback(validate_and_load)
@define(hash=True) @define(hash=True)
@ -419,7 +443,9 @@ class StorageClientGeneral(object):
""" """
url = self._client.relative_url("/storage/v1/version") url = self._client.relative_url("/storage/v1/version")
response = yield self._client.request("GET", url) response = yield self._client.request("GET", url)
decoded_response = yield _decode_cbor(response, _SCHEMAS["get_version"]) decoded_response = yield self._client.decode_cbor(
response, _SCHEMAS["get_version"]
)
returnValue(decoded_response) returnValue(decoded_response)
@inlineCallbacks @inlineCallbacks
@ -486,6 +512,9 @@ def read_share_chunk(
share_type, _encode_si(storage_index), share_number share_type, _encode_si(storage_index), share_number
) )
) )
# The default 60 second timeout is for getting the response, so it doesn't
# include the time it takes to download the body... so we will deal
# with that later, via limited_content().
response = yield client.request( response = yield client.request(
"GET", "GET",
url, url,
@ -494,6 +523,7 @@ def read_share_chunk(
# but Range constructor does that conversion for us. # but Range constructor does that conversion for us.
{"range": [Range("bytes", [(offset, offset + length)]).to_header()]} {"range": [Range("bytes", [(offset, offset + length)]).to_header()]}
), ),
unbuffered=True, # Don't buffer the response in memory.
) )
if response.code == http.NO_CONTENT: if response.code == http.NO_CONTENT:
@ -516,7 +546,7 @@ def read_share_chunk(
raise ValueError("Server sent more than we asked for?!") raise ValueError("Server sent more than we asked for?!")
# It might also send less than we asked for. That's (probably) OK, e.g. # It might also send less than we asked for. That's (probably) OK, e.g.
# if we went past the end of the file. # if we went past the end of the file.
body = yield limited_content(response, supposed_length) body = yield limited_content(response, client._clock, supposed_length)
body.seek(0, SEEK_END) body.seek(0, SEEK_END)
actual_length = body.tell() actual_length = body.tell()
if actual_length != supposed_length: if actual_length != supposed_length:
@ -603,7 +633,9 @@ class StorageClientImmutables(object):
upload_secret=upload_secret, upload_secret=upload_secret,
message_to_serialize=message, message_to_serialize=message,
) )
decoded_response = yield _decode_cbor(response, _SCHEMAS["allocate_buckets"]) decoded_response = yield self._client.decode_cbor(
response, _SCHEMAS["allocate_buckets"]
)
returnValue( returnValue(
ImmutableCreateResult( ImmutableCreateResult(
already_have=decoded_response["already-have"], already_have=decoded_response["already-have"],
@ -679,7 +711,9 @@ class StorageClientImmutables(object):
raise ClientException( raise ClientException(
response.code, response.code,
) )
body = yield _decode_cbor(response, _SCHEMAS["immutable_write_share_chunk"]) body = yield self._client.decode_cbor(
response, _SCHEMAS["immutable_write_share_chunk"]
)
remaining = RangeMap() remaining = RangeMap()
for chunk in body["required"]: for chunk in body["required"]:
remaining.set(True, chunk["begin"], chunk["end"]) remaining.set(True, chunk["begin"], chunk["end"])
@ -708,7 +742,7 @@ class StorageClientImmutables(object):
url, url,
) )
if response.code == http.OK: if response.code == http.OK:
body = yield _decode_cbor(response, _SCHEMAS["list_shares"]) body = yield self._client.decode_cbor(response, _SCHEMAS["list_shares"])
returnValue(set(body)) returnValue(set(body))
else: else:
raise ClientException(response.code) raise ClientException(response.code)
@ -825,7 +859,9 @@ class StorageClientMutables:
message_to_serialize=message, message_to_serialize=message,
) )
if response.code == http.OK: if response.code == http.OK:
result = await _decode_cbor(response, _SCHEMAS["mutable_read_test_write"]) result = await self._client.decode_cbor(
response, _SCHEMAS["mutable_read_test_write"]
)
return ReadTestWriteResult(success=result["success"], reads=result["data"]) return ReadTestWriteResult(success=result["success"], reads=result["data"])
else: else:
raise ClientException(response.code, (await response.content())) raise ClientException(response.code, (await response.content()))
@ -854,7 +890,9 @@ class StorageClientMutables:
) )
response = await self._client.request("GET", url) response = await self._client.request("GET", url)
if response.code == http.OK: if response.code == http.OK:
return await _decode_cbor(response, _SCHEMAS["mutable_list_shares"]) return await self._client.decode_cbor(
response, _SCHEMAS["mutable_list_shares"]
)
else: else:
raise ClientException(response.code) raise ClientException(response.code)

View File

@ -20,6 +20,7 @@ from foolscap.api import flushEventualQueue
from allmydata import client from allmydata import client
from allmydata.introducer.server import create_introducer from allmydata.introducer.server import create_introducer
from allmydata.util import fileutil, log, pollmixin from allmydata.util import fileutil, log, pollmixin
from allmydata.util.deferredutil import async_to_deferred
from allmydata.storage import http_client from allmydata.storage import http_client
from allmydata.storage_client import ( from allmydata.storage_client import (
NativeStorageServer, NativeStorageServer,
@ -639,6 +640,40 @@ def _render_section_values(values):
)) ))
@async_to_deferred
async def spin_until_cleanup_done(value=None, timeout=10):
    """
    At the end of the test, spin until the reactor has no more DelayedCalls
    and file descriptors (or equivalents) registered.  This prevents dirty
    reactor errors, while also not hard-coding a fixed amount of time, so it
    can finish faster on faster computers.

    There is also a timeout: if it takes more than 10 seconds (by default) for
    the remaining reactor state to clean itself up, the presumption is that it
    will never get cleaned up and the spinning stops.

    Make sure to run as last thing in tearDown.

    :param value: Returned unchanged once spinning stops, so this function
        can be added directly to a Deferred chain (e.g. with ``addBoth``)
        without replacing the previous result.
    :param timeout: Upper bound on how long to spin, in seconds.  Fractional
        values are accepted.

    :return: ``value``.
    """
    def num_fds():
        if hasattr(reactor, "handles"):
            # IOCP!
            return len(reactor.handles)
        else:
            # Normal reactor; having internal readers still registered is fine,
            # that's not our code.
            return len(
                set(reactor.getReaders()) - set(reactor._internalReaders)
            ) + len(reactor.getWriters())

    # One iteration per millisecond of budget.  int() keeps a fractional
    # timeout (e.g. 0.5) from making range() raise TypeError.
    for _ in range(int(timeout * 1000)):
        # There's a single DelayedCall for AsynchronousDeferredRunTest's
        # timeout...
        if len(reactor.getDelayedCalls()) < 2 and num_fds() == 0:
            break
        await deferLater(reactor, 0.001)
    return value
class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin): class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
# If set to True, use Foolscap for storage protocol. If set to False, HTTP # If set to True, use Foolscap for storage protocol. If set to False, HTTP
@ -685,7 +720,7 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
d = self.sparent.stopService() d = self.sparent.stopService()
d.addBoth(flush_but_dont_ignore) d.addBoth(flush_but_dont_ignore)
d.addBoth(lambda x: self.close_idle_http_connections().addCallback(lambda _: x)) d.addBoth(lambda x: self.close_idle_http_connections().addCallback(lambda _: x))
d.addBoth(lambda x: deferLater(reactor, 2, lambda: x)) d.addBoth(spin_until_cleanup_done)
return d return d
def getdir(self, subdir): def getdir(self, subdir):

View File

@ -31,6 +31,8 @@ from klein import Klein
from hyperlink import DecodedURL from hyperlink import DecodedURL
from collections_extended import RangeMap from collections_extended import RangeMap
from twisted.internet.task import Clock, Cooperator from twisted.internet.task import Clock, Cooperator
from twisted.internet.interfaces import IReactorTime
from twisted.internet.defer import CancelledError, Deferred
from twisted.web import http from twisted.web import http
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from werkzeug import routing from werkzeug import routing
@ -245,6 +247,7 @@ def gen_bytes(length: int) -> bytes:
class TestApp(object): class TestApp(object):
"""HTTP API for testing purposes.""" """HTTP API for testing purposes."""
clock: IReactorTime
_app = Klein() _app = Klein()
_swissnum = SWISSNUM_FOR_TEST # Match what the test client is using _swissnum = SWISSNUM_FOR_TEST # Match what the test client is using
@ -266,6 +269,25 @@ class TestApp(object):
"""Return bytes to the given length using ``gen_bytes()``.""" """Return bytes to the given length using ``gen_bytes()``."""
return gen_bytes(length) return gen_bytes(length)
@_authorized_route(_app, set(), "/slowly_never_finish_result", methods=["GET"])
def slowly_never_finish_result(self, request, authorization):
    """
    Write one byte right away, schedule one more at 59 seconds and another
    at 118 seconds on ``self.clock``, and never finish the response.
    """
    request.write(b"a")
    # Each later write lands 59 seconds after the previous one.
    for delay, payload in ((59, b"b"), (59 + 59, b"c")):
        self.clock.callLater(delay, request.write, payload)
    # A Deferred that is never fired: the response is left unfinished.
    return Deferred()
@_authorized_route(_app, set(), "/die_unfinished", methods=["GET"])
def die(self, request, authorization):
    """
    Drop the connection without writing anything, and never finish the
    response.
    """
    never_fires = Deferred()
    request.transport.loseConnection()
    return never_fires
def result_of(d): def result_of(d):
""" """
@ -298,12 +320,18 @@ class CustomHTTPServerTests(SyncTestCase):
# Could be a fixture, but will only be used in this test class so not # Could be a fixture, but will only be used in this test class so not
# going to bother: # going to bother:
self._http_server = TestApp() self._http_server = TestApp()
treq = StubTreq(self._http_server._app.resource())
self.client = StorageClient( self.client = StorageClient(
DecodedURL.from_text("http://127.0.0.1"), DecodedURL.from_text("http://127.0.0.1"),
SWISSNUM_FOR_TEST, SWISSNUM_FOR_TEST,
treq=StubTreq(self._http_server._app.resource()), treq=treq,
clock=Clock(), # We're using a Treq private API to get the reactor, alas, but only
# in a test, so not going to worry about it too much. This would be
# fixed if https://github.com/twisted/treq/issues/226 were ever
# fixed.
clock=treq._agent._memoryReactor,
) )
self._http_server.clock = self.client._clock
def test_authorization_enforcement(self): def test_authorization_enforcement(self):
""" """
@ -351,7 +379,9 @@ class CustomHTTPServerTests(SyncTestCase):
) )
self.assertEqual( self.assertEqual(
result_of(limited_content(response, at_least_length)).read(), result_of(
limited_content(response, self._http_server.clock, at_least_length)
).read(),
gen_bytes(length), gen_bytes(length),
) )
@ -370,7 +400,52 @@ class CustomHTTPServerTests(SyncTestCase):
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
result_of(limited_content(response, too_short)) result_of(limited_content(response, self._http_server.clock, too_short))
def test_limited_content_silence_causes_timeout(self):
    """
    ``http_client.limited_content()`` times out if it receives no data for
    60 seconds.
    """
    response = result_of(
        self.client.request(
            "GET",
            "http://127.0.0.1/slowly_never_finish_result",
        )
    )
    clock = self._http_server.clock

    successes = []
    failures = []
    pending_body = limited_content(response, clock, 4)
    pending_body.addCallbacks(successes.append, failures.append)

    # 59 seconds until the second write, 59 more until the third, then 60
    # seconds of silence.  Nothing may fire before the last tick.
    for _ in range(59 + 59 + 60):
        self.assertEqual((successes, failures), ([], []))
        clock.advance(1)
        # Push data between in-memory client and in-memory server:
        self.client._treq._agent.flush()

    # Once the quiet period has fully elapsed the read has been cancelled.
    self.assertTrue(failures)
    with self.assertRaises(CancelledError):
        failures[0].raiseException()
def test_limited_content_cancels_timeout_on_failed_response(self):
    """
    If the response fails somehow, the timeout is still cancelled.
    """
    # NOTE(review): this requests ``/die``, but the connection-dropping
    # handler in this file is registered at ``/die_unfinished``.  The
    # request therefore presumably never reaches that handler -- confirm
    # where the ValueError asserted below actually comes from.
    response = result_of(
        self.client.request(
            "GET",
            "http://127.0.0.1/die",
        )
    )
    clock = self._http_server.clock

    pending_body = limited_content(response, clock, 4)
    with self.assertRaises(ValueError):
        result_of(pending_body)

    # No delayed call may be left scheduled after the failure.
    self.assertEqual(len(clock.getDelayedCalls()), 0)
class HttpTestFixture(Fixture): class HttpTestFixture(Fixture):

View File

@ -12,7 +12,7 @@ from cryptography import x509
from twisted.internet.endpoints import serverFromString from twisted.internet.endpoints import serverFromString
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.task import deferLater from twisted.internet.defer import maybeDeferred
from twisted.web.server import Site from twisted.web.server import Site
from twisted.web.static import Data from twisted.web.static import Data
from twisted.web.client import Agent, HTTPConnectionPool, ResponseNeverReceived from twisted.web.client import Agent, HTTPConnectionPool, ResponseNeverReceived
@ -30,6 +30,7 @@ from ..storage.http_common import get_spki_hash
from ..storage.http_client import _StorageClientHTTPSPolicy from ..storage.http_client import _StorageClientHTTPSPolicy
from ..storage.http_server import _TLSEndpointWrapper from ..storage.http_server import _TLSEndpointWrapper
from ..util.deferredutil import async_to_deferred from ..util.deferredutil import async_to_deferred
from .common_system import spin_until_cleanup_done
class HTTPSNurlTests(SyncTestCase): class HTTPSNurlTests(SyncTestCase):
@ -87,6 +88,10 @@ class PinningHTTPSValidation(AsyncTestCase):
self.addCleanup(self._port_assigner.tearDown) self.addCleanup(self._port_assigner.tearDown)
return AsyncTestCase.setUp(self) return AsyncTestCase.setUp(self)
def tearDown(self):
    """
    Run the base class tear-down, then spin the reactor until leftover
    delayed calls and connections are gone.
    """
    def wait_for_reactor(_ignored):
        return spin_until_cleanup_done()

    parent_done = maybeDeferred(AsyncTestCase.tearDown, self)
    return parent_done.addCallback(wait_for_reactor)
@asynccontextmanager @asynccontextmanager
async def listen(self, private_key_path: FilePath, cert_path: FilePath): async def listen(self, private_key_path: FilePath, cert_path: FilePath):
""" """
@ -107,9 +112,6 @@ class PinningHTTPSValidation(AsyncTestCase):
yield f"https://127.0.0.1:{listening_port.getHost().port}/" yield f"https://127.0.0.1:{listening_port.getHost().port}/"
finally: finally:
await listening_port.stopListening() await listening_port.stopListening()
# Make sure all server connections are closed :( No idea why this
# is necessary when it's not for IStorageServer HTTPS tests.
await deferLater(reactor, 0.01)
def request(self, url: str, expected_certificate: x509.Certificate): def request(self, url: str, expected_certificate: x509.Certificate):
""" """
@ -144,10 +146,6 @@ class PinningHTTPSValidation(AsyncTestCase):
response = await self.request(url, certificate) response = await self.request(url, certificate)
self.assertEqual(await response.content(), b"YOYODYNE") self.assertEqual(await response.content(), b"YOYODYNE")
# We keep getting TLSMemoryBIOProtocol being left around, so try harder
# to wait for it to finish.
await deferLater(reactor, 0.01)
@async_to_deferred @async_to_deferred
async def test_server_certificate_has_wrong_hash(self): async def test_server_certificate_has_wrong_hash(self):
""" """
@ -202,10 +200,6 @@ class PinningHTTPSValidation(AsyncTestCase):
response = await self.request(url, certificate) response = await self.request(url, certificate)
self.assertEqual(await response.content(), b"YOYODYNE") self.assertEqual(await response.content(), b"YOYODYNE")
# We keep getting TLSMemoryBIOProtocol being left around, so try harder
# to wait for it to finish.
await deferLater(reactor, 0.001)
# A potential attack to test is a private key that doesn't match the # A potential attack to test is a private key that doesn't match the
# certificate... but OpenSSL (quite rightly) won't let you listen with that # certificate... but OpenSSL (quite rightly) won't let you listen with that
# so I don't know how to test that! See # so I don't know how to test that! See