Refactor to check HTTP content-type of request body.

This commit is contained in:
Itamar Turner-Trauring 2022-03-14 11:16:09 -04:00
parent fef332754b
commit b6073b11c2
1 changed files with 19 additions and 5 deletions

View File

@ -2,7 +2,7 @@
HTTP server for storage. HTTP server for storage.
""" """
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple, Any
from functools import wraps from functools import wraps
from base64 import b64decode from base64 import b64decode
@ -23,7 +23,7 @@ from werkzeug.datastructures import ContentRange
from cbor2 import dumps, loads from cbor2 import dumps, loads
from .server import StorageServer from .server import StorageServer
from .http_common import swissnum_auth_header, Secrets from .http_common import swissnum_auth_header, Secrets, get_content_type
from .common import si_a2b from .common import si_a2b
from .immutable import BucketWriter, ConflictingWriteError from .immutable import BucketWriter, ConflictingWriteError
from ..util.hashutil import timing_safe_compare from ..util.hashutil import timing_safe_compare
@ -248,7 +248,9 @@ class HTTPServer(object):
return self._app.resource() return self._app.resource()
def _send_encoded(self, request, data): def _send_encoded(self, request, data):
"""Return encoded data, by default using CBOR.""" """
Return encoded data as the HTTP body response, by default using CBOR.
"""
cbor_mime = "application/cbor" cbor_mime = "application/cbor"
accept_headers = request.requestHeaders.getRawHeaders("accept") or [cbor_mime] accept_headers = request.requestHeaders.getRawHeaders("accept") or [cbor_mime]
accept = parse_accept_header(accept_headers[0]) accept = parse_accept_header(accept_headers[0])
@ -262,6 +264,18 @@ class HTTPServer(object):
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861 # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861
raise _HTTPError(http.NOT_ACCEPTABLE) raise _HTTPError(http.NOT_ACCEPTABLE)
def _read_encoded(self, request) -> Any:
"""
Read encoded request body data, decoding it with CBOR by default.
"""
content_type = get_content_type(request.requestHeaders)
if content_type == "application/cbor":
# TODO limit memory usage, client could send arbitrarily large data...
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
return loads(request.content.read())
else:
raise _HTTPError(http.UNSUPPORTED_MEDIA_TYPE)
##### Generic APIs ##### ##### Generic APIs #####
@_authorized_route(_app, set(), "/v1/version", methods=["GET"]) @_authorized_route(_app, set(), "/v1/version", methods=["GET"])
@ -280,7 +294,7 @@ class HTTPServer(object):
def allocate_buckets(self, request, authorization, storage_index): def allocate_buckets(self, request, authorization, storage_index):
"""Allocate buckets.""" """Allocate buckets."""
upload_secret = authorization[Secrets.UPLOAD] upload_secret = authorization[Secrets.UPLOAD]
info = loads(request.content.read()) info = self._read_encoded(request)
# We do NOT validate the upload secret for existing bucket uploads. # We do NOT validate the upload secret for existing bucket uploads.
# Another upload may be happening in parallel, with a different upload # Another upload may be happening in parallel, with a different upload
@ -480,6 +494,6 @@ class HTTPServer(object):
except KeyError: except KeyError:
raise _HTTPError(http.NOT_FOUND) raise _HTTPError(http.NOT_FOUND)
info = loads(request.content.read()) info = self._read_encoded(request)
bucket.advise_corrupt_share(info["reason"].encode("utf-8")) bucket.advise_corrupt_share(info["reason"].encode("utf-8"))
return b"" return b""