Client-side schema validation.

This commit is contained in:
Itamar Turner-Trauring 2022-04-11 14:03:48 -04:00
parent dfad50b1c2
commit 4b20b67ce6
2 changed files with 78 additions and 13 deletions

View File

@ -11,6 +11,7 @@ import attr
# TODO Make sure to import Python version? # TODO Make sure to import Python version?
from cbor2 import loads, dumps from cbor2 import loads, dumps
from pycddl import Schema
from collections_extended import RangeMap from collections_extended import RangeMap
from werkzeug.datastructures import Range, ContentRange from werkzeug.datastructures import Range, ContentRange
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
@ -36,14 +37,62 @@ class ClientException(Exception):
self.code = code self.code = code
def _decode_cbor(response): # Schemas for server responses.
#
# TODO usage of sets is inconsistent. Either use everywhere (and document in
# spec document) or use nowhere.
_SCHEMAS = {
"get_version": Schema(
"""
message = {'http://allmydata.org/tahoe/protocols/storage/v1' => {
'maximum-immutable-share-size' => uint
'maximum-mutable-share-size' => uint
'available-space' => uint
'tolerates-immutable-read-overrun' => bool
'delete-mutable-shares-with-zero-length-writev' => bool
'fills-holes-with-zero-bytes' => bool
'prevents-read-past-end-of-share-data' => bool
}
'application-version' => bstr
}
"""
),
"allocate_buckets": Schema(
"""
message = {
already-have: #6.258([* uint])
allocated: #6.258([* uint])
}
"""
),
"immutable_write_share_chunk": Schema(
"""
message = {
required: [* {begin: uint, end: uint}]
}
"""
),
"list_shares": Schema(
"""
message = [* uint]
"""
),
}
def _decode_cbor(response, schema: Schema):
"""Given HTTP response, return decoded CBOR body.""" """Given HTTP response, return decoded CBOR body."""
def got_content(data):
schema.validate_cbor(data)
return loads(data)
if response.code > 199 and response.code < 300: if response.code > 199 and response.code < 300:
content_type = get_content_type(response.headers) content_type = get_content_type(response.headers)
if content_type == CBOR_MIME_TYPE: if content_type == CBOR_MIME_TYPE:
# TODO limit memory usage # TODO limit memory usage
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872 # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
return treq.content(response).addCallback(loads) return treq.content(response).addCallback(got_content)
else: else:
raise ClientException(-1, "Server didn't send CBOR") raise ClientException(-1, "Server didn't send CBOR")
else: else:
@ -151,7 +200,7 @@ class StorageClientGeneral(object):
""" """
url = self._client.relative_url("/v1/version") url = self._client.relative_url("/v1/version")
response = yield self._client.request("GET", url) response = yield self._client.request("GET", url)
decoded_response = yield _decode_cbor(response) decoded_response = yield _decode_cbor(response, _SCHEMAS["get_version"])
returnValue(decoded_response) returnValue(decoded_response)
@ -209,7 +258,7 @@ class StorageClientImmutables(object):
upload_secret=upload_secret, upload_secret=upload_secret,
message_to_serialize=message, message_to_serialize=message,
) )
decoded_response = yield _decode_cbor(response) decoded_response = yield _decode_cbor(response, _SCHEMAS["allocate_buckets"])
returnValue( returnValue(
ImmutableCreateResult( ImmutableCreateResult(
already_have=decoded_response["already-have"], already_have=decoded_response["already-have"],
@ -281,7 +330,7 @@ class StorageClientImmutables(object):
raise ClientException( raise ClientException(
response.code, response.code,
) )
body = yield _decode_cbor(response) body = yield _decode_cbor(response, _SCHEMAS["immutable_write_share_chunk"])
remaining = RangeMap() remaining = RangeMap()
for chunk in body["required"]: for chunk in body["required"]:
remaining.set(True, chunk["begin"], chunk["end"]) remaining.set(True, chunk["begin"], chunk["end"])
@ -334,7 +383,7 @@ class StorageClientImmutables(object):
url, url,
) )
if response.code == http.OK: if response.code == http.OK:
body = yield _decode_cbor(response) body = yield _decode_cbor(response, _SCHEMAS["list_shares"])
returnValue(set(body)) returnValue(set(body))
else: else:
raise ClientException(response.code) raise ClientException(response.code)

View File

@ -18,6 +18,8 @@ from base64 import b64encode
from contextlib import contextmanager from contextlib import contextmanager
from os import urandom from os import urandom
from cbor2 import dumps
from pycddl import ValidationError as CDDLValidationError
from hypothesis import assume, given, strategies as st from hypothesis import assume, given, strategies as st
from fixtures import Fixture, TempDir from fixtures import Fixture, TempDir
from treq.testing import StubTreq from treq.testing import StubTreq
@ -49,7 +51,7 @@ from ..storage.http_client import (
StorageClientGeneral, StorageClientGeneral,
_encode_si, _encode_si,
) )
from ..storage.http_common import get_content_type from ..storage.http_common import get_content_type, CBOR_MIME_TYPE
from ..storage.common import si_b2a from ..storage.common import si_b2a
@ -239,6 +241,12 @@ class TestApp(object):
else: else:
return "BAD: {}".format(authorization) return "BAD: {}".format(authorization)
@_authorized_route(_app, set(), "/v1/version", methods=["GET"])
def bad_version(self, request, authorization):
"""Return version result that violates the expected schema."""
request.setHeader("content-type", CBOR_MIME_TYPE)
return dumps({"garbage": 123})
def result_of(d): def result_of(d):
""" """
@ -257,15 +265,15 @@ def result_of(d):
) )
class RoutingTests(SyncTestCase): class CustomHTTPServerTests(SyncTestCase):
""" """
Tests for the HTTP routing infrastructure. Tests that use a custom HTTP server.
""" """
def setUp(self): def setUp(self):
if PY2: if PY2:
self.skipTest("Not going to bother supporting Python 2") self.skipTest("Not going to bother supporting Python 2")
super(RoutingTests, self).setUp() super(CustomHTTPServerTests, self).setUp()
# Could be a fixture, but will only be used in this test class so not # Could be a fixture, but will only be used in this test class so not
# going to bother: # going to bother:
self._http_server = TestApp() self._http_server = TestApp()
@ -277,8 +285,8 @@ class RoutingTests(SyncTestCase):
def test_authorization_enforcement(self): def test_authorization_enforcement(self):
""" """
The requirement for secrets is enforced; if they are not given, a 400 The requirement for secrets is enforced by the ``_authorized_route``
response code is returned. decorator; if they are not given, a 400 response code is returned.
""" """
# Without secret, get a 400 error. # Without secret, get a 400 error.
response = result_of( response = result_of(
@ -298,6 +306,14 @@ class RoutingTests(SyncTestCase):
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
self.assertEqual(result_of(response.content()), b"GOOD SECRET") self.assertEqual(result_of(response.content()), b"GOOD SECRET")
def test_client_side_schema_validation(self):
"""
The client validates returned CBOR message against a schema.
"""
client = StorageClientGeneral(self.client)
with self.assertRaises(CDDLValidationError):
result_of(client.get_version())
class HttpTestFixture(Fixture): class HttpTestFixture(Fixture):
""" """
@ -413,7 +429,7 @@ class GenericHTTPAPITests(SyncTestCase):
) )
self.assertEqual(version, expected_version) self.assertEqual(version, expected_version)
def test_schema_validation(self): def test_server_side_schema_validation(self):
""" """
Ensure that schema validation is happening: invalid CBOR should result Ensure that schema validation is happening: invalid CBOR should result
in bad request response code (error 400). in bad request response code (error 400).