immutable: refactor download to do only download-and-decode, not decryption

FileDownloader takes a verify cap and produces ciphertext, instead of taking a read cap and producing plaintext.
FileDownloader does all integrity checking including the mandatory ciphertext hash tree and the optional ciphertext flat hash, rather than expecting its target to do some of that checking.
Rename immutable.download.Output to immutable.download.DecryptingTarget. An instance of DecryptingTarget can be passed to FileDownloader to use as the latter's target.  Text pushed to the DecryptingTarget is decrypted and then pushed to *its* target.
DecryptingTarget satisfies the IConsumer interface, and if its target also satisfies IConsumer, then it forwards any pause/unpause signals to its producer (which is the FileDownloader).
This patch also changes some logging code to use the new logging mixin class.
Check integrity of a segment and decrypt the segment one block-sized buffer at a time instead of copying the buffers together into one segment-sized buffer (reduces peak memory usage, I think, and is probably a tad faster/less CPU, depending on your encoding parameters).
Refactor FileDownloader so that processing of segments and of tail-segment share as much code as possible.
FileDownloader and FileNode take caps as instances of URI (Python objects), not as strings.
This commit is contained in:
Zooko O'Whielacronx 2009-01-08 11:53:49 -07:00
parent 9bba578776
commit 600196f571
5 changed files with 103 additions and 161 deletions

View File

@ -361,7 +361,7 @@ class Client(node.Node, pollmixin.PollMixin):
else: else:
key = base32.b2a(u.storage_index) key = base32.b2a(u.storage_index)
cachefile = self.download_cache.get_file(key) cachefile = self.download_cache.get_file(key)
node = FileNode(u.to_string(), self, cachefile) # CHK node = FileNode(u, self, cachefile) # CHK
else: else:
assert IMutableFileURI.providedBy(u), u assert IMutableFileURI.providedBy(u), u
node = MutableFileNode(self).init_from_uri(u) node = MutableFileNode(self).init_from_uri(u)

View File

@ -42,75 +42,26 @@ class DownloadResults:
self.timings = {} self.timings = {}
self.file_size = None self.file_size = None
class Output: class DecryptingTarget(log.PrefixingLogMixin):
def __init__(self, downloadable, key, total_length, log_parent, implements(IDownloadTarget, IConsumer)
download_status): def __init__(self, downloadable, key, _log_msg_id=None):
precondition(IDownloadTarget.providedBy(downloadable), downloadable)
self.downloadable = downloadable self.downloadable = downloadable
self._decryptor = AES(key) self._decryptor = AES(key)
self._crypttext_hasher = hashutil.crypttext_hasher() prefix = str(downloadable)
self.length = 0 log.PrefixingLogMixin.__init__(self, "allmydata.immutable.download", _log_msg_id, prefix=prefix)
self.total_length = total_length def registerProducer(self, producer, streaming):
self._segment_number = 0 if IConsumer.providedBy(self.downloadable):
self._crypttext_hash_tree = None self.downloadable.registerProducer(producer, streaming)
self._opened = False def unregisterProducer(self):
self._log_parent = log_parent if IConsumer.providedBy(self.downloadable):
self._status = download_status self.downloadable.unregisterProducer()
self._status.set_progress(0.0) def write(self, ciphertext):
plaintext = self._decryptor.process(ciphertext)
def log(self, *args, **kwargs):
if "parent" not in kwargs:
kwargs["parent"] = self._log_parent
if "facility" not in kwargs:
kwargs["facility"] = "download.output"
return log.msg(*args, **kwargs)
def got_crypttext_hash_tree(self, crypttext_hash_tree):
self._crypttext_hash_tree = crypttext_hash_tree
def write_segment(self, crypttext):
self.length += len(crypttext)
self._status.set_progress( float(self.length) / self.total_length )
# memory footprint: 'crypttext' is the only segment_size usage
# outstanding. While we decrypt it into 'plaintext', we hit
# 2*segment_size.
self._crypttext_hasher.update(crypttext)
if self._crypttext_hash_tree:
ch = hashutil.crypttext_segment_hasher()
ch.update(crypttext)
crypttext_leaves = {self._segment_number: ch.digest()}
self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
bytes=len(crypttext),
segnum=self._segment_number, hash=base32.b2a(ch.digest()),
level=log.NOISY)
self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
plaintext = self._decryptor.process(crypttext)
del crypttext
# now we're back down to 1*segment_size.
self._segment_number += 1
# We're still at 1*segment_size. The Downloadable is responsible for
# any memory usage beyond this.
if not self._opened:
self._opened = True
self.downloadable.open(self.total_length)
self.downloadable.write(plaintext) self.downloadable.write(plaintext)
def open(self, size):
def fail(self, why): self.downloadable.open(size)
# this is really unusual, and deserves maximum forensics
if why.check(DownloadStopped):
# except DownloadStopped just means the consumer aborted the
# download, not so scary
self.log("download stopped", level=log.UNUSUAL)
else:
self.log("download failed!", failure=why,
level=log.SCARY, umid="lp1vaQ")
self.downloadable.fail(why)
def close(self): def close(self):
self.crypttext_hash = self._crypttext_hasher.digest()
self.log("download finished, closing IDownloadable", level=log.NOISY)
self.downloadable.close() self.downloadable.close()
def finish(self): def finish(self):
return self.downloadable.finish() return self.downloadable.finish()
@ -653,11 +604,14 @@ class DownloadStatus:
self.results = value self.results = value
class FileDownloader(log.PrefixingLogMixin): class FileDownloader(log.PrefixingLogMixin):
""" I download shares, check their integrity, then decode them, check the integrity of the
resulting ciphertext, then and write it to my target. """
implements(IPushProducer) implements(IPushProducer)
_status = None _status = None
def __init__(self, client, u, downloadable): def __init__(self, client, u, downloadable):
precondition(isinstance(u, uri.CHKFileURI), u) precondition(IVerifierURI.providedBy(u), u)
precondition(IDownloadTarget.providedBy(downloadable), downloadable)
prefix=base32.b2a_l(u.get_storage_index()[:8], 60) prefix=base32.b2a_l(u.get_storage_index()[:8], 60)
log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix) log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
@ -691,8 +645,7 @@ class FileDownloader(log.PrefixingLogMixin):
if IConsumer.providedBy(downloadable): if IConsumer.providedBy(downloadable):
downloadable.registerProducer(self, True) downloadable.registerProducer(self, True)
self._downloadable = downloadable self._downloadable = downloadable
self._output = Output(downloadable, u.key, self._uri.size, self._parentmsgid, self._opened = False
self._status)
self.active_buckets = {} # k: shnum, v: bucket self.active_buckets = {} # k: shnum, v: bucket
self._share_buckets = [] # list of (sharenum, bucket) tuples self._share_buckets = [] # list of (sharenum, bucket) tuples
@ -700,8 +653,15 @@ class FileDownloader(log.PrefixingLogMixin):
self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, } self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, }
self._share_hash_tree = None self._ciphertext_hasher = hashutil.crypttext_hasher()
self._crypttext_hash_tree = None
self._bytes_done = 0
self._status.set_progress(float(self._bytes_done)/self._uri.size)
# _got_uri_extension() will create the following:
# self._crypttext_hash_tree
# self._share_hash_tree
# self._current_segnum = 0
def pauseProducing(self): def pauseProducing(self):
if self._paused: if self._paused:
@ -730,7 +690,6 @@ class FileDownloader(log.PrefixingLogMixin):
self._status.set_active(False) self._status.set_active(False)
def start(self): def start(self):
assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
self.log("starting download") self.log("starting download")
# first step: who should we download from? # first step: who should we download from?
@ -754,7 +713,12 @@ class FileDownloader(log.PrefixingLogMixin):
if self._status: if self._status:
self._status.set_status("Failed") self._status.set_status("Failed")
self._status.set_active(False) self._status.set_active(False)
self._output.fail(why) if why.check(DownloadStopped):
# DownloadStopped just means the consumer aborted the download; not so scary.
self.log("download stopped", level=log.UNUSUAL)
else:
# This is really unusual, and deserves maximum forensics.
self.log("download failed!", failure=why, level=log.SCARY, umid="lp1vaQ")
return why return why
d.addErrback(_failed) d.addErrback(_failed)
d.addCallback(self._done) d.addCallback(self._done)
@ -818,7 +782,6 @@ class FileDownloader(log.PrefixingLogMixin):
del self._share_vbuckets[shnum] del self._share_vbuckets[shnum]
def _got_all_shareholders(self, res): def _got_all_shareholders(self, res):
assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
if self._results: if self._results:
now = time.time() now = time.time()
self._results.timings["peer_selection"] = now - self._started self._results.timings["peer_selection"] = now - self._started
@ -832,7 +795,6 @@ class FileDownloader(log.PrefixingLogMixin):
# "vb is %s but should be a ValidatedReadBucketProxy" % (vb,) # "vb is %s but should be a ValidatedReadBucketProxy" % (vb,)
def _obtain_uri_extension(self, ignored): def _obtain_uri_extension(self, ignored):
assert isinstance(self._uri, uri.CHKFileURI), self._uri
# all shareholders are supposed to have a copy of uri_extension, and # all shareholders are supposed to have a copy of uri_extension, and
# all are supposed to be identical. We compute the hash of the data # all are supposed to be identical. We compute the hash of the data
# that comes back, and compare it against the version in our URI. If # that comes back, and compare it against the version in our URI. If
@ -844,7 +806,7 @@ class FileDownloader(log.PrefixingLogMixin):
vups = [] vups = []
for sharenum, bucket in self._share_buckets: for sharenum, bucket in self._share_buckets:
vups.append(ValidatedExtendedURIProxy(bucket, self._uri.get_verify_cap(), self._fetch_failures)) vups.append(ValidatedExtendedURIProxy(bucket, self._uri, self._fetch_failures))
vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._parentmsgid) vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._parentmsgid)
d = vto.start() d = vto.start()
@ -886,7 +848,6 @@ class FileDownloader(log.PrefixingLogMixin):
def _got_crypttext_hash_tree(res): def _got_crypttext_hash_tree(res):
# Good -- the self._crypttext_hash_tree that we passed to vchtp is now populated # Good -- the self._crypttext_hash_tree that we passed to vchtp is now populated
# with hashes. # with hashes.
self._output.got_crypttext_hash_tree(self._crypttext_hash_tree)
if self._results: if self._results:
elapsed = time.time() - _get_crypttext_hash_tree_started elapsed = time.time() - _get_crypttext_hash_tree_started
self._results.timings["hashtrees"] = elapsed self._results.timings["hashtrees"] = elapsed
@ -896,7 +857,6 @@ class FileDownloader(log.PrefixingLogMixin):
def _activate_enough_buckets(self): def _activate_enough_buckets(self):
"""either return a mapping from shnum to a ValidatedReadBucketProxy that can """either return a mapping from shnum to a ValidatedReadBucketProxy that can
provide data for that share, or raise NotEnoughSharesError""" provide data for that share, or raise NotEnoughSharesError"""
assert isinstance(self._uri, uri.CHKFileURI), self._uri
while len(self.active_buckets) < self._uri.needed_shares: while len(self.active_buckets) < self._uri.needed_shares:
# need some more # need some more
@ -934,12 +894,11 @@ class FileDownloader(log.PrefixingLogMixin):
self._started_fetching = time.time() self._started_fetching = time.time()
d = defer.succeed(None) d = defer.succeed(None)
for segnum in range(self._vup.num_segments-1): for segnum in range(self._vup.num_segments):
d.addCallback(self._download_segment, segnum) d.addCallback(self._download_segment, segnum)
# this pause, at the end of write, prevents pre-fetch from # this pause, at the end of write, prevents pre-fetch from
# happening until the consumer is ready for more data. # happening until the consumer is ready for more data.
d.addCallback(self._check_for_pause) d.addCallback(self._check_for_pause)
d.addCallback(self._download_tail_segment, self._vup.num_segments-1)
return d return d
def _check_for_pause(self, res): def _check_for_pause(self, res):
@ -952,7 +911,6 @@ class FileDownloader(log.PrefixingLogMixin):
return res return res
def _download_segment(self, res, segnum): def _download_segment(self, res, segnum):
assert isinstance(self._uri, uri.CHKFileURI), self._uri
if self._status: if self._status:
self._status.set_status("Downloading segment %d of %d" % self._status.set_status("Downloading segment %d of %d" %
(segnum+1, self._vup.num_segments)) (segnum+1, self._vup.num_segments))
@ -979,8 +937,11 @@ class FileDownloader(log.PrefixingLogMixin):
return res return res
if self._results: if self._results:
d.addCallback(_started_decode) d.addCallback(_started_decode)
d.addCallback(lambda (shares, shareids): if segnum + 1 == self._vup.num_segments:
self._codec.decode(shares, shareids)) codec = self._tail_codec
else:
codec = self._codec
d.addCallback(lambda (shares, shareids): codec.decode(shares, shareids))
# once the codec is done, we drop back to 1*segment_size, because # once the codec is done, we drop back to 1*segment_size, because
# 'shares' goes out of scope. The memory usage is all in the # 'shares' goes out of scope. The memory usage is all in the
# plaintext now, spread out into a bunch of tiny buffers. # plaintext now, spread out into a bunch of tiny buffers.
@ -993,91 +954,66 @@ class FileDownloader(log.PrefixingLogMixin):
# pause/check-for-stop just before writing, to honor stopProducing # pause/check-for-stop just before writing, to honor stopProducing
d.addCallback(self._check_for_pause) d.addCallback(self._check_for_pause)
def _done(buffers): d.addCallback(self._got_segment)
# we start by joining all these buffers together into a single
# string. This makes Output.write easier, since it wants to hash
# data one segment at a time anyways, and doesn't impact our
# memory footprint since we're already peaking at 2*segment_size
# inside the codec a moment ago.
segment = "".join(buffers)
del buffers
# we're down to 1*segment_size right now, but write_segment()
# will decrypt a copy of the segment internally, which will push
# us up to 2*segment_size while it runs.
started_decrypt = time.time()
self._output.write_segment(segment)
if self._results:
elapsed = time.time() - started_decrypt
self._results.timings["cumulative_decrypt"] += elapsed
d.addCallback(_done)
return d return d
def _download_tail_segment(self, res, segnum): def _got_segment(self, buffers):
assert isinstance(self._uri, uri.CHKFileURI), self._uri precondition(self._crypttext_hash_tree)
self.log("downloading seg#%d of %d (%d%%)"
% (segnum, self._vup.num_segments,
100.0 * segnum / self._vup.num_segments))
segmentdler = SegmentDownloader(self, segnum, self._uri.needed_shares,
self._results)
started = time.time()
d = segmentdler.start()
def _finished_fetching(res):
elapsed = time.time() - started
self._results.timings["cumulative_fetch"] += elapsed
return res
if self._results:
d.addCallback(_finished_fetching)
# pause before using more memory
d.addCallback(self._check_for_pause)
def _started_decode(res):
self._started_decode = time.time()
return res
if self._results:
d.addCallback(_started_decode)
d.addCallback(lambda (shares, shareids):
self._tail_codec.decode(shares, shareids))
def _finished_decode(res):
elapsed = time.time() - self._started_decode
self._results.timings["cumulative_decode"] += elapsed
return res
if self._results:
d.addCallback(_finished_decode)
# pause/check-for-stop just before writing, to honor stopProducing
d.addCallback(self._check_for_pause)
def _done(buffers):
# trim off any padding added by the upload side
segment = "".join(buffers)
del buffers
# we never send empty segments. If the data was an exact multiple
# of the segment size, the last segment will be full.
pad_size = mathutil.pad_size(self._uri.size, self._vup.segment_size)
tail_size = self._vup.segment_size - pad_size
segment = segment[:tail_size]
started_decrypt = time.time() started_decrypt = time.time()
self._output.write_segment(segment) self._status.set_progress(float(self._current_segnum)/self._uri.size)
if self._current_segnum + 1 == self._vup.num_segments:
# This is the last segment.
# Trim off any padding added by the upload side. We never send empty segments. If
# the data was an exact multiple of the segment size, the last segment will be full.
tail_buf_size = mathutil.div_ceil(self._vup.tail_segment_size, self._uri.needed_shares)
num_buffers_used = mathutil.div_ceil(self._vup.tail_data_size, tail_buf_size)
# Remove buffers which don't contain any part of the tail.
del buffers[num_buffers_used:]
# Remove the past-the-tail-part of the last buffer.
tail_in_last_buf = self._vup.tail_data_size % tail_buf_size
if tail_in_last_buf == 0:
tail_in_last_buf = tail_buf_size
buffers[-1] = buffers[-1][:tail_in_last_buf]
# First compute the hash of this segment and check that it fits.
ch = hashutil.crypttext_segment_hasher()
for buffer in buffers:
self._ciphertext_hasher.update(buffer)
ch.update(buffer)
self._crypttext_hash_tree.set_hashes(leaves={self._current_segnum: ch.digest()})
# Then write this segment to the target.
if not self._opened:
self._opened = True
self._downloadable.open(self._uri.size)
for buffer in buffers:
self._downloadable.write(buffer)
self._bytes_done += len(buffer)
self._status.set_progress(float(self._bytes_done)/self._uri.size)
self._current_segnum += 1
if self._results: if self._results:
elapsed = time.time() - started_decrypt elapsed = time.time() - started_decrypt
self._results.timings["cumulative_decrypt"] += elapsed self._results.timings["cumulative_decrypt"] += elapsed
d.addCallback(_done)
return d
def _done(self, res): def _done(self, res):
assert isinstance(self._uri, uri.CHKFileURI), self._uri
self.log("download done") self.log("download done")
if self._results: if self._results:
now = time.time() now = time.time()
self._results.timings["total"] = now - self._started self._results.timings["total"] = now - self._started
self._results.timings["segments"] = now - self._started_fetching self._results.timings["segments"] = now - self._started_fetching
self._output.close()
if self._vup.crypttext_hash: if self._vup.crypttext_hash:
_assert(self._vup.crypttext_hash == self._output.crypttext_hash, _assert(self._vup.crypttext_hash == self._ciphertext_hasher.digest(),
"bad crypttext_hash: computed=%s, expected=%s" % "bad crypttext_hash: computed=%s, expected=%s" %
(base32.b2a(self._output.crypttext_hash), (base32.b2a(self._ciphertext_hasher.digest()),
base32.b2a(self._vup.crypttext_hash))) base32.b2a(self._vup.crypttext_hash)))
_assert(self._output.length == self._uri.size, _assert(self._bytes_done == self._uri.size, self._bytes_done, self._uri.size)
got=self._output.length, expected=self._uri.size) self._status.set_progress(1)
return self._output.finish() self._downloadable.close()
return self._downloadable.finish()
def get_download_status(self): def get_download_status(self):
return self._status return self._status
@ -1200,7 +1136,9 @@ class Downloader(service.MultiService):
# include LIT files # include LIT files
self.stats_provider.count('downloader.files_downloaded', 1) self.stats_provider.count('downloader.files_downloaded', 1)
self.stats_provider.count('downloader.bytes_downloaded', u.get_size()) self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
dl = FileDownloader(self.parent, u, t)
target = DecryptingTarget(t, u.key, _log_msg_id=_log_msg_id)
dl = FileDownloader(self.parent, u.get_verify_cap(), target)
self._add_download(dl) self._add_download(dl)
d = dl.start() d = dl.start()
return d return d

View File

@ -9,6 +9,7 @@ from foolscap.eventual import eventually
from allmydata.interfaces import IFileNode, IFileURI, ICheckable, \ from allmydata.interfaces import IFileNode, IFileURI, ICheckable, \
IDownloadTarget IDownloadTarget
from allmydata.util import log, base32 from allmydata.util import log, base32
from allmydata.util.assertutil import precondition
from allmydata import uri as urimodule from allmydata import uri as urimodule
from allmydata.immutable.checker import Checker from allmydata.immutable.checker import Checker
from allmydata.check_results import CheckAndRepairResults from allmydata.check_results import CheckAndRepairResults
@ -19,6 +20,7 @@ class _ImmutableFileNodeBase(object):
implements(IFileNode, ICheckable) implements(IFileNode, ICheckable)
def __init__(self, uri, client): def __init__(self, uri, client):
precondition(urimodule.IImmutableFileURI.providedBy(uri), uri)
self.u = IFileURI(uri) self.u = IFileURI(uri)
self._client = client self._client = client
@ -172,7 +174,7 @@ class FileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin):
def __init__(self, uri, client, cachefile): def __init__(self, uri, client, cachefile):
_ImmutableFileNodeBase.__init__(self, uri, client) _ImmutableFileNodeBase.__init__(self, uri, client)
self.download_cache = DownloadCache(self, cachefile) self.download_cache = DownloadCache(self, cachefile)
prefix = urimodule.from_string(uri).get_verify_cap().to_string() prefix = uri.get_verify_cap().to_string()
log.PrefixingLogMixin.__init__(self, "allmydata.immutable.filenode", prefix=prefix) log.PrefixingLogMixin.__init__(self, "allmydata.immutable.filenode", prefix=prefix)
self.log("starting", level=log.OPERATIONAL) self.log("starting", level=log.OPERATIONAL)
@ -250,6 +252,7 @@ class LiteralProducer:
class LiteralFileNode(_ImmutableFileNodeBase): class LiteralFileNode(_ImmutableFileNodeBase):
def __init__(self, uri, client): def __init__(self, uri, client):
precondition(urimodule.IImmutableFileURI.providedBy(uri), uri)
_ImmutableFileNodeBase.__init__(self, uri, client) _ImmutableFileNodeBase.__init__(self, uri, client)
def get_uri(self): def get_uri(self):

View File

@ -493,7 +493,8 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
client = FakeClient() client = FakeClient()
if not target: if not target:
target = download.Data() target = download.Data()
fd = download.FileDownloader(client, u, target) target = download.DecryptingTarget(target, u.key)
fd = download.FileDownloader(client, u.get_verify_cap(), target)
# we manually cycle the FileDownloader through a number of steps that # we manually cycle the FileDownloader through a number of steps that
# would normally be sequenced by a Deferred chain in # would normally be sequenced by a Deferred chain in

View File

@ -27,8 +27,8 @@ class Node(unittest.TestCase):
size=1000) size=1000)
c = FakeClient() c = FakeClient()
cf = cachedir.CacheFile("none") cf = cachedir.CacheFile("none")
fn1 = filenode.FileNode(u.to_string(), c, cf) fn1 = filenode.FileNode(u, c, cf)
fn2 = filenode.FileNode(u.to_string(), c, cf) fn2 = filenode.FileNode(u, c, cf)
self.failUnlessEqual(fn1, fn2) self.failUnlessEqual(fn1, fn2)
self.failIfEqual(fn1, "I am not a filenode") self.failIfEqual(fn1, "I am not a filenode")
self.failIfEqual(fn1, NotANode()) self.failIfEqual(fn1, NotANode())
@ -49,7 +49,7 @@ class Node(unittest.TestCase):
u = uri.LiteralFileURI(data=DATA) u = uri.LiteralFileURI(data=DATA)
c = None c = None
fn1 = filenode.LiteralFileNode(u, c) fn1 = filenode.LiteralFileNode(u, c)
fn2 = filenode.LiteralFileNode(u.to_string(), c) fn2 = filenode.LiteralFileNode(u, c)
self.failUnlessEqual(fn1, fn2) self.failUnlessEqual(fn1, fn2)
self.failIfEqual(fn1, "I am not a filenode") self.failIfEqual(fn1, "I am not a filenode")
self.failIfEqual(fn1, NotANode()) self.failIfEqual(fn1, NotANode())