Refactor to check HTTP content-type of request body.

This commit is contained in:
Itamar Turner-Trauring 2022-03-14 11:16:09 -04:00
parent fef332754b
commit b6073b11c2
1 changed files with 19 additions and 5 deletions

View File

@ -2,7 +2,7 @@
HTTP server for storage.
"""
from typing import Dict, List, Set, Tuple
from typing import Dict, List, Set, Tuple, Any
from functools import wraps
from base64 import b64decode
@ -23,7 +23,7 @@ from werkzeug.datastructures import ContentRange
from cbor2 import dumps, loads
from .server import StorageServer
from .http_common import swissnum_auth_header, Secrets
from .http_common import swissnum_auth_header, Secrets, get_content_type
from .common import si_a2b
from .immutable import BucketWriter, ConflictingWriteError
from ..util.hashutil import timing_safe_compare
@ -248,7 +248,9 @@ class HTTPServer(object):
return self._app.resource()
def _send_encoded(self, request, data):
"""Return encoded data, by default using CBOR."""
"""
Return encoded data as the HTTP body response, by default using CBOR.
"""
cbor_mime = "application/cbor"
accept_headers = request.requestHeaders.getRawHeaders("accept") or [cbor_mime]
accept = parse_accept_header(accept_headers[0])
@ -262,6 +264,18 @@ class HTTPServer(object):
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861
raise _HTTPError(http.NOT_ACCEPTABLE)
def _read_encoded(self, request) -> Any:
"""
Read encoded request body data, decoding it with CBOR by default.
"""
content_type = get_content_type(request.requestHeaders)
if content_type == "application/cbor":
# TODO limit memory usage, client could send arbitrarily large data...
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
return loads(request.content.read())
else:
raise _HTTPError(http.UNSUPPORTED_MEDIA_TYPE)
##### Generic APIs #####
@_authorized_route(_app, set(), "/v1/version", methods=["GET"])
@ -280,7 +294,7 @@ class HTTPServer(object):
def allocate_buckets(self, request, authorization, storage_index):
"""Allocate buckets."""
upload_secret = authorization[Secrets.UPLOAD]
info = loads(request.content.read())
info = self._read_encoded(request)
# We do NOT validate the upload secret for existing bucket uploads.
# Another upload may be happening in parallel, with a different upload
@ -480,6 +494,6 @@ class HTTPServer(object):
except KeyError:
raise _HTTPError(http.NOT_FOUND)
info = loads(request.content.read())
info = self._read_encoded(request)
bucket.advise_corrupt_share(info["reason"].encode("utf-8"))
return b""