immutable: refactor download to do only download-and-decode, not decryption

FileDownloader takes a verify cap and produces ciphertext, instead of taking a read cap and producing plaintext.
FileDownloader does all integrity checking including the mandatory ciphertext hash tree and the optional ciphertext flat hash, rather than expecting its target to do some of that checking.
Rename immutable.download.Output to immutable.download.DecryptingTarget. An instance of DecryptingTarget can be passed to FileDownloader to use as the latter's target.  Ciphertext pushed to the DecryptingTarget is decrypted and then pushed to *its* target.
DecryptingTarget satisfies the IConsumer interface, and if its target also satisfies IConsumer, then it forwards any pause/unpause signals to its producer (which is the FileDownloader).
This patch also changes some logging code to use the new logging mixin class.
Check integrity of a segment and decrypt the segment one block-sized buffer at a time instead of copying the buffers together into one segment-sized buffer (reduces peak memory usage, I think, and is probably a tad faster/less CPU, depending on your encoding parameters).
Refactor FileDownloader so that processing of segments and of the tail segment share as much code as possible.
FileDownloader and FileNode take caps as instances of URI (Python objects), not as strings.
This commit is contained in:
Zooko O'Whielacronx 2009-01-08 11:53:49 -07:00
parent 9bba578776
commit 600196f571
5 changed files with 103 additions and 161 deletions

View File

@ -361,7 +361,7 @@ class Client(node.Node, pollmixin.PollMixin):
else:
key = base32.b2a(u.storage_index)
cachefile = self.download_cache.get_file(key)
node = FileNode(u.to_string(), self, cachefile) # CHK
node = FileNode(u, self, cachefile) # CHK
else:
assert IMutableFileURI.providedBy(u), u
node = MutableFileNode(self).init_from_uri(u)

View File

@ -42,75 +42,26 @@ class DownloadResults:
self.timings = {}
self.file_size = None
class Output:
def __init__(self, downloadable, key, total_length, log_parent,
download_status):
class DecryptingTarget(log.PrefixingLogMixin):
implements(IDownloadTarget, IConsumer)
def __init__(self, downloadable, key, _log_msg_id=None):
precondition(IDownloadTarget.providedBy(downloadable), downloadable)
self.downloadable = downloadable
self._decryptor = AES(key)
self._crypttext_hasher = hashutil.crypttext_hasher()
self.length = 0
self.total_length = total_length
self._segment_number = 0
self._crypttext_hash_tree = None
self._opened = False
self._log_parent = log_parent
self._status = download_status
self._status.set_progress(0.0)
def log(self, *args, **kwargs):
if "parent" not in kwargs:
kwargs["parent"] = self._log_parent
if "facility" not in kwargs:
kwargs["facility"] = "download.output"
return log.msg(*args, **kwargs)
def got_crypttext_hash_tree(self, crypttext_hash_tree):
self._crypttext_hash_tree = crypttext_hash_tree
def write_segment(self, crypttext):
self.length += len(crypttext)
self._status.set_progress( float(self.length) / self.total_length )
# memory footprint: 'crypttext' is the only segment_size usage
# outstanding. While we decrypt it into 'plaintext', we hit
# 2*segment_size.
self._crypttext_hasher.update(crypttext)
if self._crypttext_hash_tree:
ch = hashutil.crypttext_segment_hasher()
ch.update(crypttext)
crypttext_leaves = {self._segment_number: ch.digest()}
self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
bytes=len(crypttext),
segnum=self._segment_number, hash=base32.b2a(ch.digest()),
level=log.NOISY)
self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
plaintext = self._decryptor.process(crypttext)
del crypttext
# now we're back down to 1*segment_size.
self._segment_number += 1
# We're still at 1*segment_size. The Downloadable is responsible for
# any memory usage beyond this.
if not self._opened:
self._opened = True
self.downloadable.open(self.total_length)
prefix = str(downloadable)
log.PrefixingLogMixin.__init__(self, "allmydata.immutable.download", _log_msg_id, prefix=prefix)
def registerProducer(self, producer, streaming):
if IConsumer.providedBy(self.downloadable):
self.downloadable.registerProducer(producer, streaming)
def unregisterProducer(self):
if IConsumer.providedBy(self.downloadable):
self.downloadable.unregisterProducer()
def write(self, ciphertext):
plaintext = self._decryptor.process(ciphertext)
self.downloadable.write(plaintext)
def fail(self, why):
# this is really unusual, and deserves maximum forensics
if why.check(DownloadStopped):
# except DownloadStopped just means the consumer aborted the
# download, not so scary
self.log("download stopped", level=log.UNUSUAL)
else:
self.log("download failed!", failure=why,
level=log.SCARY, umid="lp1vaQ")
self.downloadable.fail(why)
def open(self, size):
self.downloadable.open(size)
def close(self):
self.crypttext_hash = self._crypttext_hasher.digest()
self.log("download finished, closing IDownloadable", level=log.NOISY)
self.downloadable.close()
def finish(self):
return self.downloadable.finish()
@ -653,11 +604,14 @@ class DownloadStatus:
self.results = value
class FileDownloader(log.PrefixingLogMixin):
""" I download shares, check their integrity, then decode them, check the integrity of the
resulting ciphertext, and then write it to my target. """
implements(IPushProducer)
_status = None
def __init__(self, client, u, downloadable):
precondition(isinstance(u, uri.CHKFileURI), u)
precondition(IVerifierURI.providedBy(u), u)
precondition(IDownloadTarget.providedBy(downloadable), downloadable)
prefix=base32.b2a_l(u.get_storage_index()[:8], 60)
log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
@ -691,8 +645,7 @@ class FileDownloader(log.PrefixingLogMixin):
if IConsumer.providedBy(downloadable):
downloadable.registerProducer(self, True)
self._downloadable = downloadable
self._output = Output(downloadable, u.key, self._uri.size, self._parentmsgid,
self._status)
self._opened = False
self.active_buckets = {} # k: shnum, v: bucket
self._share_buckets = [] # list of (sharenum, bucket) tuples
@ -700,8 +653,15 @@ class FileDownloader(log.PrefixingLogMixin):
self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, }
self._share_hash_tree = None
self._crypttext_hash_tree = None
self._ciphertext_hasher = hashutil.crypttext_hasher()
self._bytes_done = 0
self._status.set_progress(float(self._bytes_done)/self._uri.size)
# _got_uri_extension() will create the following:
# self._crypttext_hash_tree
# self._share_hash_tree
# self._current_segnum = 0
def pauseProducing(self):
if self._paused:
@ -730,7 +690,6 @@ class FileDownloader(log.PrefixingLogMixin):
self._status.set_active(False)
def start(self):
assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
self.log("starting download")
# first step: who should we download from?
@ -754,7 +713,12 @@ class FileDownloader(log.PrefixingLogMixin):
if self._status:
self._status.set_status("Failed")
self._status.set_active(False)
self._output.fail(why)
if why.check(DownloadStopped):
# DownloadStopped just means the consumer aborted the download; not so scary.
self.log("download stopped", level=log.UNUSUAL)
else:
# This is really unusual, and deserves maximum forensics.
self.log("download failed!", failure=why, level=log.SCARY, umid="lp1vaQ")
return why
d.addErrback(_failed)
d.addCallback(self._done)
@ -818,7 +782,6 @@ class FileDownloader(log.PrefixingLogMixin):
del self._share_vbuckets[shnum]
def _got_all_shareholders(self, res):
assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
if self._results:
now = time.time()
self._results.timings["peer_selection"] = now - self._started
@ -832,7 +795,6 @@ class FileDownloader(log.PrefixingLogMixin):
# "vb is %s but should be a ValidatedReadBucketProxy" % (vb,)
def _obtain_uri_extension(self, ignored):
assert isinstance(self._uri, uri.CHKFileURI), self._uri
# all shareholders are supposed to have a copy of uri_extension, and
# all are supposed to be identical. We compute the hash of the data
# that comes back, and compare it against the version in our URI. If
@ -844,7 +806,7 @@ class FileDownloader(log.PrefixingLogMixin):
vups = []
for sharenum, bucket in self._share_buckets:
vups.append(ValidatedExtendedURIProxy(bucket, self._uri.get_verify_cap(), self._fetch_failures))
vups.append(ValidatedExtendedURIProxy(bucket, self._uri, self._fetch_failures))
vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._parentmsgid)
d = vto.start()
@ -886,7 +848,6 @@ class FileDownloader(log.PrefixingLogMixin):
def _got_crypttext_hash_tree(res):
# Good -- the self._crypttext_hash_tree that we passed to vchtp is now populated
# with hashes.
self._output.got_crypttext_hash_tree(self._crypttext_hash_tree)
if self._results:
elapsed = time.time() - _get_crypttext_hash_tree_started
self._results.timings["hashtrees"] = elapsed
@ -896,7 +857,6 @@ class FileDownloader(log.PrefixingLogMixin):
def _activate_enough_buckets(self):
"""either return a mapping from shnum to a ValidatedReadBucketProxy that can
provide data for that share, or raise NotEnoughSharesError"""
assert isinstance(self._uri, uri.CHKFileURI), self._uri
while len(self.active_buckets) < self._uri.needed_shares:
# need some more
@ -934,12 +894,11 @@ class FileDownloader(log.PrefixingLogMixin):
self._started_fetching = time.time()
d = defer.succeed(None)
for segnum in range(self._vup.num_segments-1):
for segnum in range(self._vup.num_segments):
d.addCallback(self._download_segment, segnum)
# this pause, at the end of write, prevents pre-fetch from
# happening until the consumer is ready for more data.
d.addCallback(self._check_for_pause)
d.addCallback(self._download_tail_segment, self._vup.num_segments-1)
return d
def _check_for_pause(self, res):
@ -952,7 +911,6 @@ class FileDownloader(log.PrefixingLogMixin):
return res
def _download_segment(self, res, segnum):
assert isinstance(self._uri, uri.CHKFileURI), self._uri
if self._status:
self._status.set_status("Downloading segment %d of %d" %
(segnum+1, self._vup.num_segments))
@ -979,8 +937,11 @@ class FileDownloader(log.PrefixingLogMixin):
return res
if self._results:
d.addCallback(_started_decode)
d.addCallback(lambda (shares, shareids):
self._codec.decode(shares, shareids))
if segnum + 1 == self._vup.num_segments:
codec = self._tail_codec
else:
codec = self._codec
d.addCallback(lambda (shares, shareids): codec.decode(shares, shareids))
# once the codec is done, we drop back to 1*segment_size, because
# 'shares' goes out of scope. The memory usage is all in the
# plaintext now, spread out into a bunch of tiny buffers.
@ -993,91 +954,66 @@ class FileDownloader(log.PrefixingLogMixin):
# pause/check-for-stop just before writing, to honor stopProducing
d.addCallback(self._check_for_pause)
def _done(buffers):
# we start by joining all these buffers together into a single
# string. This makes Output.write easier, since it wants to hash
# data one segment at a time anyways, and doesn't impact our
# memory footprint since we're already peaking at 2*segment_size
# inside the codec a moment ago.
segment = "".join(buffers)
del buffers
# we're down to 1*segment_size right now, but write_segment()
# will decrypt a copy of the segment internally, which will push
# us up to 2*segment_size while it runs.
started_decrypt = time.time()
self._output.write_segment(segment)
if self._results:
elapsed = time.time() - started_decrypt
self._results.timings["cumulative_decrypt"] += elapsed
d.addCallback(_done)
d.addCallback(self._got_segment)
return d
def _download_tail_segment(self, res, segnum):
assert isinstance(self._uri, uri.CHKFileURI), self._uri
self.log("downloading seg#%d of %d (%d%%)"
% (segnum, self._vup.num_segments,
100.0 * segnum / self._vup.num_segments))
segmentdler = SegmentDownloader(self, segnum, self._uri.needed_shares,
self._results)
started = time.time()
d = segmentdler.start()
def _finished_fetching(res):
elapsed = time.time() - started
self._results.timings["cumulative_fetch"] += elapsed
return res
if self._results:
d.addCallback(_finished_fetching)
# pause before using more memory
d.addCallback(self._check_for_pause)
def _started_decode(res):
self._started_decode = time.time()
return res
if self._results:
d.addCallback(_started_decode)
d.addCallback(lambda (shares, shareids):
self._tail_codec.decode(shares, shareids))
def _finished_decode(res):
elapsed = time.time() - self._started_decode
self._results.timings["cumulative_decode"] += elapsed
return res
if self._results:
d.addCallback(_finished_decode)
# pause/check-for-stop just before writing, to honor stopProducing
d.addCallback(self._check_for_pause)
def _done(buffers):
# trim off any padding added by the upload side
segment = "".join(buffers)
del buffers
# we never send empty segments. If the data was an exact multiple
# of the segment size, the last segment will be full.
pad_size = mathutil.pad_size(self._uri.size, self._vup.segment_size)
tail_size = self._vup.segment_size - pad_size
segment = segment[:tail_size]
def _got_segment(self, buffers):
precondition(self._crypttext_hash_tree)
started_decrypt = time.time()
self._output.write_segment(segment)
self._status.set_progress(float(self._current_segnum)/self._uri.size)
if self._current_segnum + 1 == self._vup.num_segments:
# This is the last segment.
# Trim off any padding added by the upload side. We never send empty segments. If
# the data was an exact multiple of the segment size, the last segment will be full.
tail_buf_size = mathutil.div_ceil(self._vup.tail_segment_size, self._uri.needed_shares)
num_buffers_used = mathutil.div_ceil(self._vup.tail_data_size, tail_buf_size)
# Remove buffers which don't contain any part of the tail.
del buffers[num_buffers_used:]
# Remove the past-the-tail-part of the last buffer.
tail_in_last_buf = self._vup.tail_data_size % tail_buf_size
if tail_in_last_buf == 0:
tail_in_last_buf = tail_buf_size
buffers[-1] = buffers[-1][:tail_in_last_buf]
# First compute the hash of this segment and check that it fits.
ch = hashutil.crypttext_segment_hasher()
for buffer in buffers:
self._ciphertext_hasher.update(buffer)
ch.update(buffer)
self._crypttext_hash_tree.set_hashes(leaves={self._current_segnum: ch.digest()})
# Then write this segment to the target.
if not self._opened:
self._opened = True
self._downloadable.open(self._uri.size)
for buffer in buffers:
self._downloadable.write(buffer)
self._bytes_done += len(buffer)
self._status.set_progress(float(self._bytes_done)/self._uri.size)
self._current_segnum += 1
if self._results:
elapsed = time.time() - started_decrypt
self._results.timings["cumulative_decrypt"] += elapsed
d.addCallback(_done)
return d
def _done(self, res):
assert isinstance(self._uri, uri.CHKFileURI), self._uri
self.log("download done")
if self._results:
now = time.time()
self._results.timings["total"] = now - self._started
self._results.timings["segments"] = now - self._started_fetching
self._output.close()
if self._vup.crypttext_hash:
_assert(self._vup.crypttext_hash == self._output.crypttext_hash,
_assert(self._vup.crypttext_hash == self._ciphertext_hasher.digest(),
"bad crypttext_hash: computed=%s, expected=%s" %
(base32.b2a(self._output.crypttext_hash),
(base32.b2a(self._ciphertext_hasher.digest()),
base32.b2a(self._vup.crypttext_hash)))
_assert(self._output.length == self._uri.size,
got=self._output.length, expected=self._uri.size)
return self._output.finish()
_assert(self._bytes_done == self._uri.size, self._bytes_done, self._uri.size)
self._status.set_progress(1)
self._downloadable.close()
return self._downloadable.finish()
def get_download_status(self):
return self._status
@ -1200,7 +1136,9 @@ class Downloader(service.MultiService):
# include LIT files
self.stats_provider.count('downloader.files_downloaded', 1)
self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
dl = FileDownloader(self.parent, u, t)
target = DecryptingTarget(t, u.key, _log_msg_id=_log_msg_id)
dl = FileDownloader(self.parent, u.get_verify_cap(), target)
self._add_download(dl)
d = dl.start()
return d

View File

@ -9,6 +9,7 @@ from foolscap.eventual import eventually
from allmydata.interfaces import IFileNode, IFileURI, ICheckable, \
IDownloadTarget
from allmydata.util import log, base32
from allmydata.util.assertutil import precondition
from allmydata import uri as urimodule
from allmydata.immutable.checker import Checker
from allmydata.check_results import CheckAndRepairResults
@ -19,6 +20,7 @@ class _ImmutableFileNodeBase(object):
implements(IFileNode, ICheckable)
def __init__(self, uri, client):
precondition(urimodule.IImmutableFileURI.providedBy(uri), uri)
self.u = IFileURI(uri)
self._client = client
@ -172,7 +174,7 @@ class FileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin):
def __init__(self, uri, client, cachefile):
_ImmutableFileNodeBase.__init__(self, uri, client)
self.download_cache = DownloadCache(self, cachefile)
prefix = urimodule.from_string(uri).get_verify_cap().to_string()
prefix = uri.get_verify_cap().to_string()
log.PrefixingLogMixin.__init__(self, "allmydata.immutable.filenode", prefix=prefix)
self.log("starting", level=log.OPERATIONAL)
@ -250,6 +252,7 @@ class LiteralProducer:
class LiteralFileNode(_ImmutableFileNodeBase):
def __init__(self, uri, client):
precondition(urimodule.IImmutableFileURI.providedBy(uri), uri)
_ImmutableFileNodeBase.__init__(self, uri, client)
def get_uri(self):

View File

@ -493,7 +493,8 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
client = FakeClient()
if not target:
target = download.Data()
fd = download.FileDownloader(client, u, target)
target = download.DecryptingTarget(target, u.key)
fd = download.FileDownloader(client, u.get_verify_cap(), target)
# we manually cycle the FileDownloader through a number of steps that
# would normally be sequenced by a Deferred chain in

View File

@ -27,8 +27,8 @@ class Node(unittest.TestCase):
size=1000)
c = FakeClient()
cf = cachedir.CacheFile("none")
fn1 = filenode.FileNode(u.to_string(), c, cf)
fn2 = filenode.FileNode(u.to_string(), c, cf)
fn1 = filenode.FileNode(u, c, cf)
fn2 = filenode.FileNode(u, c, cf)
self.failUnlessEqual(fn1, fn2)
self.failIfEqual(fn1, "I am not a filenode")
self.failIfEqual(fn1, NotANode())
@ -49,7 +49,7 @@ class Node(unittest.TestCase):
u = uri.LiteralFileURI(data=DATA)
c = None
fn1 = filenode.LiteralFileNode(u, c)
fn2 = filenode.LiteralFileNode(u.to_string(), c)
fn2 = filenode.LiteralFileNode(u, c)
self.failUnlessEqual(fn1, fn2)
self.failIfEqual(fn1, "I am not a filenode")
self.failIfEqual(fn1, NotANode())