Fix DelayedCall leak in tests.

This commit is contained in:
Itamar Turner-Trauring 2022-11-18 13:56:54 -05:00
parent 30a9877236
commit 4c0c75a034
2 changed files with 30 additions and 16 deletions

View File

@ -26,7 +26,6 @@ from twisted.internet.interfaces import (
IDelayedCall, IDelayedCall,
) )
from twisted.internet.ssl import CertificateOptions from twisted.internet.ssl import CertificateOptions
from twisted.internet import reactor
from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.client import Agent, HTTPConnectionPool
from zope.interface import implementer from zope.interface import implementer
from hyperlink import DecodedURL from hyperlink import DecodedURL
@ -141,7 +140,9 @@ class _LengthLimitedCollector:
def limited_content( def limited_content(
response, max_length: int = 30 * 1024 * 1024, clock: IReactorTime = reactor response,
clock: IReactorTime,
max_length: int = 30 * 1024 * 1024,
) -> Deferred[BinaryIO]: ) -> Deferred[BinaryIO]:
""" """
Like ``treq.content()``, but limit data read from the response to a set Like ``treq.content()``, but limit data read from the response to a set
@ -168,11 +169,10 @@ def limited_content(
collector.f.seek(0) collector.f.seek(0)
return collector.f return collector.f
d.addCallback(done) return d.addCallback(done)
return d
def _decode_cbor(response, schema: Schema): def _decode_cbor(response, schema: Schema, clock: IReactorTime):
"""Given HTTP response, return decoded CBOR body.""" """Given HTTP response, return decoded CBOR body."""
def got_content(f: BinaryIO): def got_content(f: BinaryIO):
@ -183,7 +183,7 @@ def _decode_cbor(response, schema: Schema):
if response.code > 199 and response.code < 300: if response.code > 199 and response.code < 300:
content_type = get_content_type(response.headers) content_type = get_content_type(response.headers)
if content_type == CBOR_MIME_TYPE: if content_type == CBOR_MIME_TYPE:
return limited_content(response).addCallback(got_content) return limited_content(response, clock).addCallback(got_content)
else: else:
raise ClientException(-1, "Server didn't send CBOR") raise ClientException(-1, "Server didn't send CBOR")
else: else:
@ -439,7 +439,9 @@ class StorageClientGeneral(object):
""" """
url = self._client.relative_url("/storage/v1/version") url = self._client.relative_url("/storage/v1/version")
response = yield self._client.request("GET", url) response = yield self._client.request("GET", url)
decoded_response = yield _decode_cbor(response, _SCHEMAS["get_version"]) decoded_response = yield _decode_cbor(
response, _SCHEMAS["get_version"], self._client._clock
)
returnValue(decoded_response) returnValue(decoded_response)
@inlineCallbacks @inlineCallbacks
@ -540,7 +542,7 @@ def read_share_chunk(
raise ValueError("Server sent more than we asked for?!") raise ValueError("Server sent more than we asked for?!")
# It might also send less than we asked for. That's (probably) OK, e.g. # It might also send less than we asked for. That's (probably) OK, e.g.
# if we went past the end of the file. # if we went past the end of the file.
body = yield limited_content(response, supposed_length, client._clock) body = yield limited_content(response, client._clock, supposed_length)
body.seek(0, SEEK_END) body.seek(0, SEEK_END)
actual_length = body.tell() actual_length = body.tell()
if actual_length != supposed_length: if actual_length != supposed_length:
@ -627,7 +629,9 @@ class StorageClientImmutables(object):
upload_secret=upload_secret, upload_secret=upload_secret,
message_to_serialize=message, message_to_serialize=message,
) )
decoded_response = yield _decode_cbor(response, _SCHEMAS["allocate_buckets"]) decoded_response = yield _decode_cbor(
response, _SCHEMAS["allocate_buckets"], self._client._clock
)
returnValue( returnValue(
ImmutableCreateResult( ImmutableCreateResult(
already_have=decoded_response["already-have"], already_have=decoded_response["already-have"],
@ -703,7 +707,9 @@ class StorageClientImmutables(object):
raise ClientException( raise ClientException(
response.code, response.code,
) )
body = yield _decode_cbor(response, _SCHEMAS["immutable_write_share_chunk"]) body = yield _decode_cbor(
response, _SCHEMAS["immutable_write_share_chunk"], self._client._clock
)
remaining = RangeMap() remaining = RangeMap()
for chunk in body["required"]: for chunk in body["required"]:
remaining.set(True, chunk["begin"], chunk["end"]) remaining.set(True, chunk["begin"], chunk["end"])
@ -732,7 +738,9 @@ class StorageClientImmutables(object):
url, url,
) )
if response.code == http.OK: if response.code == http.OK:
body = yield _decode_cbor(response, _SCHEMAS["list_shares"]) body = yield _decode_cbor(
response, _SCHEMAS["list_shares"], self._client._clock
)
returnValue(set(body)) returnValue(set(body))
else: else:
raise ClientException(response.code) raise ClientException(response.code)
@ -849,7 +857,9 @@ class StorageClientMutables:
message_to_serialize=message, message_to_serialize=message,
) )
if response.code == http.OK: if response.code == http.OK:
result = await _decode_cbor(response, _SCHEMAS["mutable_read_test_write"]) result = await _decode_cbor(
response, _SCHEMAS["mutable_read_test_write"], self._client._clock
)
return ReadTestWriteResult(success=result["success"], reads=result["data"]) return ReadTestWriteResult(success=result["success"], reads=result["data"])
else: else:
raise ClientException(response.code, (await response.content())) raise ClientException(response.code, (await response.content()))
@ -878,7 +888,9 @@ class StorageClientMutables:
) )
response = await self._client.request("GET", url) response = await self._client.request("GET", url)
if response.code == http.OK: if response.code == http.OK:
return await _decode_cbor(response, _SCHEMAS["mutable_list_shares"]) return await _decode_cbor(
response, _SCHEMAS["mutable_list_shares"], self._client._clock
)
else: else:
raise ClientException(response.code) raise ClientException(response.code)

View File

@ -371,7 +371,9 @@ class CustomHTTPServerTests(SyncTestCase):
) )
self.assertEqual( self.assertEqual(
result_of(limited_content(response, at_least_length)).read(), result_of(
limited_content(response, self._http_server.clock, at_least_length)
).read(),
gen_bytes(length), gen_bytes(length),
) )
@ -390,7 +392,7 @@ class CustomHTTPServerTests(SyncTestCase):
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
result_of(limited_content(response, too_short)) result_of(limited_content(response, self._http_server.clock, too_short))
def test_limited_content_silence_causes_timeout(self): def test_limited_content_silence_causes_timeout(self):
""" """
@ -404,7 +406,7 @@ class CustomHTTPServerTests(SyncTestCase):
) )
) )
body_deferred = limited_content(response, 4, self._http_server.clock) body_deferred = limited_content(response, self._http_server.clock, 4)
result = [] result = []
error = [] error = []
body_deferred.addCallbacks(result.append, error.append) body_deferred.addCallbacks(result.append, error.append)