mutable/retrieve.py: clean up control flow to avoid dropping errors

* replace DeferredList with gatherResults, simplify result handling
* use BadShareError to signal recoverable problems in either fetch or
  validate, catch after _validate_block
* _validate_block is thus not responsible for noticing fetch problems
* rename _validation_or_decoding_failed() to _handle_bad_share()
* _get_needed_hashes() returns two Deferreds, instead of a hard-to-unpack
  DeferredList
This commit is contained in:
Brian Warner 2012-01-07 18:12:51 -08:00
parent c56839478e
commit 893eea849b
1 changed file with 34 additions and 39 deletions

View File

@ -5,7 +5,8 @@ from zope.interface import implements
from twisted.internet import defer from twisted.internet import defer
from twisted.python import failure from twisted.python import failure
from twisted.internet.interfaces import IPushProducer, IConsumer from twisted.internet.interfaces import IPushProducer, IConsumer
from foolscap.api import eventually, fireEventually from foolscap.api import eventually, fireEventually, DeadReferenceError, \
RemoteException
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \ from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
DownloadStopped, MDMF_VERSION, SDMF_VERSION DownloadStopped, MDMF_VERSION, SDMF_VERSION
from allmydata.util import hashutil, log, mathutil, deferredutil from allmydata.util import hashutil, log, mathutil, deferredutil
@ -15,7 +16,8 @@ from allmydata.storage.server import si_b2a
from pycryptopp.cipher.aes import AES from pycryptopp.cipher.aes import AES
from pycryptopp.publickey import rsa from pycryptopp.publickey import rsa
from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError from allmydata.mutable.common import CorruptShareError, BadShareError, \
UncoordinatedWriteError
from allmydata.mutable.layout import MDMFSlotReadProxy from allmydata.mutable.layout import MDMFSlotReadProxy
class RetrieveStatus: class RetrieveStatus:
@ -591,7 +593,7 @@ class Retrieve:
self._bad_shares.add((server, shnum, f)) self._bad_shares.add((server, shnum, f))
self._status.add_problem(server, f) self._status.add_problem(server, f)
self._last_failure = f self._last_failure = f
if f.check(CorruptShareError): if f.check(BadShareError):
self.notify_server_corruption(server, shnum, str(f.value)) self.notify_server_corruption(server, shnum, str(f.value))
@ -631,16 +633,15 @@ class Retrieve:
ds = [] ds = []
for reader in self._active_readers: for reader in self._active_readers:
started = time.time() started = time.time()
d = reader.get_block_and_salt(segnum) d1 = reader.get_block_and_salt(segnum)
d2 = self._get_needed_hashes(reader, segnum) d2,d3 = self._get_needed_hashes(reader, segnum)
dl = defer.DeferredList([d, d2], consumeErrors=True) d = deferredutil.gatherResults([d1,d2,d3])
dl.addCallback(self._validate_block, segnum, reader, reader.server, started) d.addCallback(self._validate_block, segnum, reader, reader.server, started)
dl.addErrback(self._validation_or_decoding_failed, [reader]) # _handle_bad_share takes care of recoverable errors (by dropping
ds.append(dl) # that share and returning None). Any other errors (i.e. code
# _validation_or_decoding_failed is supposed to eat any recoverable # bugs) are passed through and cause the retrieve to fail.
# errors (like corrupt shares), returning a None when that happens. d.addErrback(self._handle_bad_share, [reader])
# If it raises an exception itself, or if it can't handle the error, ds.append(d)
# the download should fail. So we can use gatherResults() here.
dl = deferredutil.gatherResults(ds) dl = deferredutil.gatherResults(ds)
if self._verify: if self._verify:
dl.addCallback(lambda ignored: "") dl.addCallback(lambda ignored: "")
@ -672,8 +673,6 @@ class Retrieve:
self.log("everything looks ok, building segment %d" % segnum) self.log("everything looks ok, building segment %d" % segnum)
d = self._decode_blocks(results, segnum) d = self._decode_blocks(results, segnum)
d.addCallback(self._decrypt_segment) d.addCallback(self._decrypt_segment)
d.addErrback(self._validation_or_decoding_failed,
self._active_readers)
# check to see whether we've been paused before writing # check to see whether we've been paused before writing
# anything. # anything.
d.addCallback(self._check_for_paused) d.addCallback(self._check_for_paused)
@ -724,13 +723,25 @@ class Retrieve:
self._current_segment += 1 self._current_segment += 1
def _validation_or_decoding_failed(self, f, readers): def _handle_bad_share(self, f, readers):
""" """
I am called when a block or a salt fails to correctly validate, or when I am called when a block or a salt fails to correctly validate, or when
the decryption or decoding operation fails for some reason. I react to the decryption or decoding operation fails for some reason. I react to
this failure by notifying the remote server of corruption, and then this failure by notifying the remote server of corruption, and then
removing the remote server from further activity. removing the remote server from further activity.
""" """
# these are the errors we can tolerate: by giving up on this share
# and finding others to replace it. Any other errors (i.e. coding
# bugs) are re-raised, causing the download to fail.
f.trap(DeadReferenceError, RemoteException, BadShareError)
# DeadReferenceError happens when we try to fetch data from a server
# that has gone away. RemoteException happens if the server had an
# internal error. BadShareError encompasses: (UnknownVersionError,
# LayoutInvalid, struct.error) which happen when we get obviously
# wrong data, and CorruptShareError which happens later, when we
# perform integrity checks on the data.
assert isinstance(readers, list) assert isinstance(readers, list)
bad_shnums = [reader.shnum for reader in readers] bad_shnums = [reader.shnum for reader in readers]
@ -739,7 +750,7 @@ class Retrieve:
(bad_shnums, readers, self._current_segment, str(f))) (bad_shnums, readers, self._current_segment, str(f)))
for reader in readers: for reader in readers:
self._mark_bad_share(reader.server, reader.shnum, reader, f) self._mark_bad_share(reader.server, reader.shnum, reader, f)
return return None
def _validate_block(self, results, segnum, reader, server, started): def _validate_block(self, results, segnum, reader, server, started):
@ -753,30 +764,15 @@ class Retrieve:
elapsed = time.time() - started elapsed = time.time() - started
self._status.add_fetch_timing(server, elapsed) self._status.add_fetch_timing(server, elapsed)
self._set_current_status("validating blocks") self._set_current_status("validating blocks")
# Did we fail to fetch either of the things that we were
# supposed to? Fail if so.
if not results[0][0] and results[1][0]:
# handled by the errback handler.
# These all get batched into one query, so the resulting block_and_salt, blockhashes, sharehashes = results
# failure should be the same for all of them, so we can just block, salt = block_and_salt
# use the first one.
assert isinstance(results[0][1], failure.Failure)
f = results[0][1] blockhashes = dict(enumerate(blockhashes))
raise CorruptShareError(server,
reader.shnum,
"Connection error: %s" % str(f))
block_and_salt, block_and_sharehashes = results
block, salt = block_and_salt[1]
blockhashes, sharehashes = block_and_sharehashes[1]
blockhashes = dict(enumerate(blockhashes[1]))
self.log("the reader gave me the following blockhashes: %s" % \ self.log("the reader gave me the following blockhashes: %s" % \
blockhashes.keys()) blockhashes.keys())
self.log("the reader gave me the following sharehashes: %s" % \ self.log("the reader gave me the following sharehashes: %s" % \
sharehashes[1].keys()) sharehashes.keys())
bht = self._block_hash_trees[reader.shnum] bht = self._block_hash_trees[reader.shnum]
if bht.needed_hashes(segnum, include_leaf=True): if bht.needed_hashes(segnum, include_leaf=True):
@ -815,7 +811,7 @@ class Retrieve:
include_leaf=True) or \ include_leaf=True) or \
self._verify: self._verify:
try: try:
self.share_hash_tree.set_hashes(hashes=sharehashes[1], self.share_hash_tree.set_hashes(hashes=sharehashes,
leaves={reader.shnum: bht[0]}) leaves={reader.shnum: bht[0]})
except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \ except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
IndexError), e: IndexError), e:
@ -855,8 +851,7 @@ class Retrieve:
else: else:
d2 = defer.succeed({}) # the logic in the next method d2 = defer.succeed({}) # the logic in the next method
# expects a dict # expects a dict
dl = defer.DeferredList([d1, d2], consumeErrors=True) return d1,d2
return dl
def _decode_blocks(self, results, segnum): def _decode_blocks(self, results, segnum):