mutable/servermap: improve test coverage

This commit is contained in:
Brian Warner 2008-04-22 16:47:52 -07:00
parent 81ab1ec785
commit a7632a345e
2 changed files with 116 additions and 24 deletions

View File

@ -508,7 +508,7 @@ class ServermapUpdater:
for shnum,datav in datavs.items():
data = datav[0]
try:
verinfo = self._got_results_one_share(shnum, data, peerid)
verinfo = self._got_results_one_share(shnum, data, peerid, lp)
last_verinfo = verinfo
last_shnum = shnum
self._node._cache.add(verinfo, shnum, 0, data, now)
@ -527,6 +527,8 @@ class ServermapUpdater:
if self._need_privkey and last_verinfo:
# send them a request for the privkey. We send one request per
# server.
lp2 = self.log("sending privkey request",
parent=lp, level=log.NOISY)
(seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
offsets_tuple) = last_verinfo
o = dict(offsets_tuple)
@ -538,8 +540,8 @@ class ServermapUpdater:
d = self._do_read(ss, peerid, self._storage_index,
[last_shnum], readv)
d.addCallback(self._got_privkey_results, peerid, last_shnum,
privkey_started)
d.addErrback(self._privkey_query_failed, peerid, last_shnum)
privkey_started, lp2)
d.addErrback(self._privkey_query_failed, peerid, last_shnum, lp2)
d.addErrback(log.err)
d.addCallback(self._check_for_done)
d.addErrback(self._fatal_error)
@ -547,10 +549,11 @@ class ServermapUpdater:
# all done!
self.log("_got_results done", parent=lp)
def _got_results_one_share(self, shnum, data, peerid):
lp = self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s",
def _got_results_one_share(self, shnum, data, peerid, lp):
self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s",
shnum=shnum,
peerid=idlib.shortnodeid_b2a(peerid))
peerid=idlib.shortnodeid_b2a(peerid),
parent=lp)
# this might raise NeedMoreDataError, if the pubkey and signature
# live at some weird offset. That shouldn't happen, so I'm going to
@ -567,7 +570,7 @@ class ServermapUpdater:
self._node._populate_pubkey(self._deserialize_pubkey(pubkey_s))
if self._need_privkey:
self._try_to_extract_privkey(data, peerid, shnum)
self._try_to_extract_privkey(data, peerid, shnum, lp)
(ig_version, ig_seqnum, ig_root_hash, ig_IV, ig_k, ig_N,
ig_segsize, ig_datalen, offsets) = unpack_header(data)
@ -610,7 +613,7 @@ class ServermapUpdater:
verifier = rsa.create_verifying_key_from_string(pubkey_s)
return verifier
def _try_to_extract_privkey(self, data, peerid, shnum):
def _try_to_extract_privkey(self, data, peerid, shnum, lp):
try:
r = unpack_share(data)
except NeedMoreDataError, e:
@ -620,7 +623,8 @@ class ServermapUpdater:
self.log("shnum %d on peerid %s: share was too short (%dB) "
"to get the encprivkey; [%d:%d] ought to hold it" %
(shnum, idlib.shortnodeid_b2a(peerid), len(data),
offset, offset+length))
offset, offset+length),
parent=lp)
# NOTE: if uncoordinated writes are taking place, someone might
# change the share (and most probably move the encprivkey) before
# we get a chance to do one of these reads and fetch it. This
@ -636,20 +640,22 @@ class ServermapUpdater:
pubkey, signature, share_hash_chain, block_hash_tree,
share_data, enc_privkey) = r
return self._try_to_validate_privkey(enc_privkey, peerid, shnum)
return self._try_to_validate_privkey(enc_privkey, peerid, shnum, lp)
def _try_to_validate_privkey(self, enc_privkey, peerid, shnum):
def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):
alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
if alleged_writekey != self._node.get_writekey():
self.log("invalid privkey from %s shnum %d" %
(idlib.nodeid_b2a(peerid)[:8], shnum), level=log.WEIRD)
(idlib.nodeid_b2a(peerid)[:8], shnum),
parent=lp, level=log.WEIRD)
return
# it's good
self.log("got valid privkey from shnum %d on peerid %s" %
(shnum, idlib.shortnodeid_b2a(peerid)))
(shnum, idlib.shortnodeid_b2a(peerid)),
parent=lp)
privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
self._node._populate_encprivkey(enc_privkey)
self._node._populate_privkey(privkey)
@ -669,7 +675,7 @@ class ServermapUpdater:
self._queries_completed += 1
self._last_failure = f
def _got_privkey_results(self, datavs, peerid, shnum, started):
def _got_privkey_results(self, datavs, peerid, shnum, started, lp):
now = time.time()
elapsed = now - started
self._status.add_per_server_time(peerid, "privkey", started, elapsed)
@ -681,12 +687,12 @@ class ServermapUpdater:
return
datav = datavs[shnum]
enc_privkey = datav[0]
self._try_to_validate_privkey(enc_privkey, peerid, shnum)
self._try_to_validate_privkey(enc_privkey, peerid, shnum, lp)
def _privkey_query_failed(self, f, peerid, shnum):
def _privkey_query_failed(self, f, peerid, shnum, lp):
self._queries_outstanding.discard(peerid)
self.log("error during privkey query: %s %s" % (f, f.value),
level=log.WEIRD)
parent=lp, level=log.WEIRD)
if not self._running:
return
self._queries_outstanding.discard(peerid)
@ -702,12 +708,14 @@ class ServermapUpdater:
lp = self.log(format=("_check_for_done, mode is '%(mode)s', "
"%(outstanding)d queries outstanding, "
"%(extra)d extra peers available, "
"%(must)d 'must query' peers left"
"%(must)d 'must query' peers left, "
"need_privkey=%(need_privkey)s"
),
mode=self.mode,
outstanding=len(self._queries_outstanding),
extra=len(self.extra_peers),
must=len(self._must_query),
need_privkey=self._need_privkey,
level=log.NOISY,
)

View File

@ -3,6 +3,7 @@ import os, struct
from cStringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, download, storage
from allmydata.util import base32, testutil, idlib
from allmydata.util.idlib import shortnodeid_b2a
@ -54,18 +55,29 @@ class FakeStorage:
# order).
self._sequence = None
self._pending = {}
self._pending_timer = None
self._special_answers = {}
def read(self, peerid, storage_index):
shares = self._peers.get(peerid, {})
if self._special_answers.get(peerid, []):
mode = self._special_answers[peerid].pop(0)
if mode == "fail":
shares = failure.Failure(IntentionalError())
elif mode == "none":
shares = {}
elif mode == "normal":
pass
if self._sequence is None:
return defer.succeed(shares)
d = defer.Deferred()
if not self._pending:
reactor.callLater(1.0, self._fire_readers)
self._pending_timer = reactor.callLater(1.0, self._fire_readers)
self._pending[peerid] = (d, shares)
return d
def _fire_readers(self):
self._pending_timer = None
pending = self._pending
self._pending = {}
extra = []
@ -654,7 +666,7 @@ class Servermap(unittest.TestCase):
d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
# create a new file, which is large enough to knock the privkey out
# of the early part of the fil
# of the early part of the file
LARGE = "These are Larger contents" * 200 # about 5KB
d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
def _created(large_fn):
@ -1342,6 +1354,7 @@ class LocalWrapper:
def __init__(self, original):
self.original = original
self.broken = False
self.post_call_notifier = None
def callRemote(self, methname, *args, **kwargs):
def _call():
if self.broken:
@ -1350,6 +1363,8 @@ class LocalWrapper:
return meth(*args, **kwargs)
d = fireEventually()
d.addCallback(lambda res: _call())
if self.post_call_notifier:
d.addCallback(self.post_call_notifier, methname)
return d
class LessFakeClient(FakeClient):
@ -1469,3 +1484,72 @@ class Problems(unittest.TestCase, testutil.ShouldFailMixin):
d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
return d
def test_privkey_query_error(self):
# when a servermap is updated with MODE_WRITE, it tries to get the
# privkey. Something might go wrong during this query attempt.
self.client = FakeClient(20)
# we need some contents that are large enough to push the privkey out
# of the early part of the file
LARGE = "These are Larger contents" * 200 # about 5KB
d = self.client.create_mutable_file(LARGE)
def _created(n):
self.uri = n.get_uri()
self.n2 = self.client.create_node_from_uri(self.uri)
# we start by doing a map update to figure out which is the first
# server.
return n.get_servermap(MODE_WRITE)
d.addCallback(_created)
d.addCallback(lambda res: fireEventually(res))
def _got_smap1(smap):
peer0 = list(smap.make_sharemap()[0])[0]
# we tell the server to respond to this peer first, so that it
# will be asked for the privkey first
self.client._storage._sequence = [peer0]
# now we make the peer fail their second query
self.client._storage._special_answers[peer0] = ["normal", "fail"]
d.addCallback(_got_smap1)
# now we update a servermap from a new node (which doesn't have the
# privkey yet, forcing it to use a separate privkey query). Each
# query response will trigger a privkey query, and since we're using
# _sequence to make the peer0 response come back first, we'll send it
# a privkey query first, and _sequence will again ensure that the
# peer0 query will also come back before the others, and then
# _special_answers will make sure that the query raises an exception.
# The whole point of these hijinks is to exercise the code in
# _privkey_query_failed. Note that the map-update will succeed, since
# we'll just get a copy from one of the other shares.
d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
# Using FakeStorage._sequence means there will be read requests still
# floating around.. wait for them to retire
def _cancel_timer(res):
if self.client._storage._pending_timer:
self.client._storage._pending_timer.cancel()
return res
d.addBoth(_cancel_timer)
return d
def test_privkey_query_missing(self):
# like test_privkey_query_error, but the shares are deleted by the
# second query, instead of raising an exception.
self.client = FakeClient(20)
LARGE = "These are Larger contents" * 200 # about 5KB
d = self.client.create_mutable_file(LARGE)
def _created(n):
self.uri = n.get_uri()
self.n2 = self.client.create_node_from_uri(self.uri)
return n.get_servermap(MODE_WRITE)
d.addCallback(_created)
d.addCallback(lambda res: fireEventually(res))
def _got_smap1(smap):
peer0 = list(smap.make_sharemap()[0])[0]
self.client._storage._sequence = [peer0]
self.client._storage._special_answers[peer0] = ["normal", "none"]
d.addCallback(_got_smap1)
d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
def _cancel_timer(res):
if self.client._storage._pending_timer:
self.client._storage._pending_timer.cancel()
return res
d.addBoth(_cancel_timer)
return d