mutable/servermap: improve test coverage

This commit is contained in:
Brian Warner 2008-04-22 16:47:52 -07:00
parent 81ab1ec785
commit a7632a345e
2 changed files with 116 additions and 24 deletions

View File

@ -484,9 +484,9 @@ class ServermapUpdater:
def _got_results(self, datavs, peerid, readsize, stuff, started): def _got_results(self, datavs, peerid, readsize, stuff, started):
lp = self.log(format="got result from [%(peerid)s], %(numshares)d shares", lp = self.log(format="got result from [%(peerid)s], %(numshares)d shares",
peerid=idlib.shortnodeid_b2a(peerid), peerid=idlib.shortnodeid_b2a(peerid),
numshares=len(datavs), numshares=len(datavs),
level=log.NOISY) level=log.NOISY)
now = time.time() now = time.time()
elapsed = now - started elapsed = now - started
self._queries_outstanding.discard(peerid) self._queries_outstanding.discard(peerid)
@ -508,7 +508,7 @@ class ServermapUpdater:
for shnum,datav in datavs.items(): for shnum,datav in datavs.items():
data = datav[0] data = datav[0]
try: try:
verinfo = self._got_results_one_share(shnum, data, peerid) verinfo = self._got_results_one_share(shnum, data, peerid, lp)
last_verinfo = verinfo last_verinfo = verinfo
last_shnum = shnum last_shnum = shnum
self._node._cache.add(verinfo, shnum, 0, data, now) self._node._cache.add(verinfo, shnum, 0, data, now)
@ -527,6 +527,8 @@ class ServermapUpdater:
if self._need_privkey and last_verinfo: if self._need_privkey and last_verinfo:
# send them a request for the privkey. We send one request per # send them a request for the privkey. We send one request per
# server. # server.
lp2 = self.log("sending privkey request",
parent=lp, level=log.NOISY)
(seqnum, root_hash, IV, segsize, datalength, k, N, prefix, (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
offsets_tuple) = last_verinfo offsets_tuple) = last_verinfo
o = dict(offsets_tuple) o = dict(offsets_tuple)
@ -538,8 +540,8 @@ class ServermapUpdater:
d = self._do_read(ss, peerid, self._storage_index, d = self._do_read(ss, peerid, self._storage_index,
[last_shnum], readv) [last_shnum], readv)
d.addCallback(self._got_privkey_results, peerid, last_shnum, d.addCallback(self._got_privkey_results, peerid, last_shnum,
privkey_started) privkey_started, lp2)
d.addErrback(self._privkey_query_failed, peerid, last_shnum) d.addErrback(self._privkey_query_failed, peerid, last_shnum, lp2)
d.addErrback(log.err) d.addErrback(log.err)
d.addCallback(self._check_for_done) d.addCallback(self._check_for_done)
d.addErrback(self._fatal_error) d.addErrback(self._fatal_error)
@ -547,10 +549,11 @@ class ServermapUpdater:
# all done! # all done!
self.log("_got_results done", parent=lp) self.log("_got_results done", parent=lp)
def _got_results_one_share(self, shnum, data, peerid): def _got_results_one_share(self, shnum, data, peerid, lp):
lp = self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s", self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s",
shnum=shnum, shnum=shnum,
peerid=idlib.shortnodeid_b2a(peerid)) peerid=idlib.shortnodeid_b2a(peerid),
parent=lp)
# this might raise NeedMoreDataError, if the pubkey and signature # this might raise NeedMoreDataError, if the pubkey and signature
# live at some weird offset. That shouldn't happen, so I'm going to # live at some weird offset. That shouldn't happen, so I'm going to
@ -567,7 +570,7 @@ class ServermapUpdater:
self._node._populate_pubkey(self._deserialize_pubkey(pubkey_s)) self._node._populate_pubkey(self._deserialize_pubkey(pubkey_s))
if self._need_privkey: if self._need_privkey:
self._try_to_extract_privkey(data, peerid, shnum) self._try_to_extract_privkey(data, peerid, shnum, lp)
(ig_version, ig_seqnum, ig_root_hash, ig_IV, ig_k, ig_N, (ig_version, ig_seqnum, ig_root_hash, ig_IV, ig_k, ig_N,
ig_segsize, ig_datalen, offsets) = unpack_header(data) ig_segsize, ig_datalen, offsets) = unpack_header(data)
@ -610,7 +613,7 @@ class ServermapUpdater:
verifier = rsa.create_verifying_key_from_string(pubkey_s) verifier = rsa.create_verifying_key_from_string(pubkey_s)
return verifier return verifier
def _try_to_extract_privkey(self, data, peerid, shnum): def _try_to_extract_privkey(self, data, peerid, shnum, lp):
try: try:
r = unpack_share(data) r = unpack_share(data)
except NeedMoreDataError, e: except NeedMoreDataError, e:
@ -620,7 +623,8 @@ class ServermapUpdater:
self.log("shnum %d on peerid %s: share was too short (%dB) " self.log("shnum %d on peerid %s: share was too short (%dB) "
"to get the encprivkey; [%d:%d] ought to hold it" % "to get the encprivkey; [%d:%d] ought to hold it" %
(shnum, idlib.shortnodeid_b2a(peerid), len(data), (shnum, idlib.shortnodeid_b2a(peerid), len(data),
offset, offset+length)) offset, offset+length),
parent=lp)
# NOTE: if uncoordinated writes are taking place, someone might # NOTE: if uncoordinated writes are taking place, someone might
# change the share (and most probably move the encprivkey) before # change the share (and most probably move the encprivkey) before
# we get a chance to do one of these reads and fetch it. This # we get a chance to do one of these reads and fetch it. This
@ -636,20 +640,22 @@ class ServermapUpdater:
pubkey, signature, share_hash_chain, block_hash_tree, pubkey, signature, share_hash_chain, block_hash_tree,
share_data, enc_privkey) = r share_data, enc_privkey) = r
return self._try_to_validate_privkey(enc_privkey, peerid, shnum) return self._try_to_validate_privkey(enc_privkey, peerid, shnum, lp)
def _try_to_validate_privkey(self, enc_privkey, peerid, shnum): def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):
alleged_privkey_s = self._node._decrypt_privkey(enc_privkey) alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s) alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
if alleged_writekey != self._node.get_writekey(): if alleged_writekey != self._node.get_writekey():
self.log("invalid privkey from %s shnum %d" % self.log("invalid privkey from %s shnum %d" %
(idlib.nodeid_b2a(peerid)[:8], shnum), level=log.WEIRD) (idlib.nodeid_b2a(peerid)[:8], shnum),
parent=lp, level=log.WEIRD)
return return
# it's good # it's good
self.log("got valid privkey from shnum %d on peerid %s" % self.log("got valid privkey from shnum %d on peerid %s" %
(shnum, idlib.shortnodeid_b2a(peerid))) (shnum, idlib.shortnodeid_b2a(peerid)),
parent=lp)
privkey = rsa.create_signing_key_from_string(alleged_privkey_s) privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
self._node._populate_encprivkey(enc_privkey) self._node._populate_encprivkey(enc_privkey)
self._node._populate_privkey(privkey) self._node._populate_privkey(privkey)
@ -669,7 +675,7 @@ class ServermapUpdater:
self._queries_completed += 1 self._queries_completed += 1
self._last_failure = f self._last_failure = f
def _got_privkey_results(self, datavs, peerid, shnum, started): def _got_privkey_results(self, datavs, peerid, shnum, started, lp):
now = time.time() now = time.time()
elapsed = now - started elapsed = now - started
self._status.add_per_server_time(peerid, "privkey", started, elapsed) self._status.add_per_server_time(peerid, "privkey", started, elapsed)
@ -681,12 +687,12 @@ class ServermapUpdater:
return return
datav = datavs[shnum] datav = datavs[shnum]
enc_privkey = datav[0] enc_privkey = datav[0]
self._try_to_validate_privkey(enc_privkey, peerid, shnum) self._try_to_validate_privkey(enc_privkey, peerid, shnum, lp)
def _privkey_query_failed(self, f, peerid, shnum): def _privkey_query_failed(self, f, peerid, shnum, lp):
self._queries_outstanding.discard(peerid) self._queries_outstanding.discard(peerid)
self.log("error during privkey query: %s %s" % (f, f.value), self.log("error during privkey query: %s %s" % (f, f.value),
level=log.WEIRD) parent=lp, level=log.WEIRD)
if not self._running: if not self._running:
return return
self._queries_outstanding.discard(peerid) self._queries_outstanding.discard(peerid)
@ -702,12 +708,14 @@ class ServermapUpdater:
lp = self.log(format=("_check_for_done, mode is '%(mode)s', " lp = self.log(format=("_check_for_done, mode is '%(mode)s', "
"%(outstanding)d queries outstanding, " "%(outstanding)d queries outstanding, "
"%(extra)d extra peers available, " "%(extra)d extra peers available, "
"%(must)d 'must query' peers left" "%(must)d 'must query' peers left, "
"need_privkey=%(need_privkey)s"
), ),
mode=self.mode, mode=self.mode,
outstanding=len(self._queries_outstanding), outstanding=len(self._queries_outstanding),
extra=len(self.extra_peers), extra=len(self.extra_peers),
must=len(self._must_query), must=len(self._must_query),
need_privkey=self._need_privkey,
level=log.NOISY, level=log.NOISY,
) )

View File

@ -3,6 +3,7 @@ import os, struct
from cStringIO import StringIO from cStringIO import StringIO
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata import uri, download, storage from allmydata import uri, download, storage
from allmydata.util import base32, testutil, idlib from allmydata.util import base32, testutil, idlib
from allmydata.util.idlib import shortnodeid_b2a from allmydata.util.idlib import shortnodeid_b2a
@ -54,18 +55,29 @@ class FakeStorage:
# order). # order).
self._sequence = None self._sequence = None
self._pending = {} self._pending = {}
self._pending_timer = None
self._special_answers = {}
def read(self, peerid, storage_index): def read(self, peerid, storage_index):
shares = self._peers.get(peerid, {}) shares = self._peers.get(peerid, {})
if self._special_answers.get(peerid, []):
mode = self._special_answers[peerid].pop(0)
if mode == "fail":
shares = failure.Failure(IntentionalError())
elif mode == "none":
shares = {}
elif mode == "normal":
pass
if self._sequence is None: if self._sequence is None:
return defer.succeed(shares) return defer.succeed(shares)
d = defer.Deferred() d = defer.Deferred()
if not self._pending: if not self._pending:
reactor.callLater(1.0, self._fire_readers) self._pending_timer = reactor.callLater(1.0, self._fire_readers)
self._pending[peerid] = (d, shares) self._pending[peerid] = (d, shares)
return d return d
def _fire_readers(self): def _fire_readers(self):
self._pending_timer = None
pending = self._pending pending = self._pending
self._pending = {} self._pending = {}
extra = [] extra = []
@ -654,7 +666,7 @@ class Servermap(unittest.TestCase):
d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10)) d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
# create a new file, which is large enough to knock the privkey out # create a new file, which is large enough to knock the privkey out
# of the early part of the fil # of the early part of the file
LARGE = "These are Larger contents" * 200 # about 5KB LARGE = "These are Larger contents" * 200 # about 5KB
d.addCallback(lambda res: self._client.create_mutable_file(LARGE)) d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
def _created(large_fn): def _created(large_fn):
@ -1342,6 +1354,7 @@ class LocalWrapper:
def __init__(self, original): def __init__(self, original):
self.original = original self.original = original
self.broken = False self.broken = False
self.post_call_notifier = None
def callRemote(self, methname, *args, **kwargs): def callRemote(self, methname, *args, **kwargs):
def _call(): def _call():
if self.broken: if self.broken:
@ -1350,6 +1363,8 @@ class LocalWrapper:
return meth(*args, **kwargs) return meth(*args, **kwargs)
d = fireEventually() d = fireEventually()
d.addCallback(lambda res: _call()) d.addCallback(lambda res: _call())
if self.post_call_notifier:
d.addCallback(self.post_call_notifier, methname)
return d return d
class LessFakeClient(FakeClient): class LessFakeClient(FakeClient):
@ -1469,3 +1484,72 @@ class Problems(unittest.TestCase, testutil.ShouldFailMixin):
d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2")) d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
return d return d
def test_privkey_query_error(self):
# when a servermap is updated with MODE_WRITE, it tries to get the
# privkey. Something might go wrong during this query attempt.
self.client = FakeClient(20)
# we need some contents that are large enough to push the privkey out
# of the early part of the file
LARGE = "These are Larger contents" * 200 # about 5KB
d = self.client.create_mutable_file(LARGE)
def _created(n):
self.uri = n.get_uri()
self.n2 = self.client.create_node_from_uri(self.uri)
# we start by doing a map update to figure out which is the first
# server.
return n.get_servermap(MODE_WRITE)
d.addCallback(_created)
d.addCallback(lambda res: fireEventually(res))
def _got_smap1(smap):
peer0 = list(smap.make_sharemap()[0])[0]
# we tell the server to respond to this peer first, so that it
# will be asked for the privkey first
self.client._storage._sequence = [peer0]
# now we make the peer fail their second query
self.client._storage._special_answers[peer0] = ["normal", "fail"]
d.addCallback(_got_smap1)
# now we update a servermap from a new node (which doesn't have the
# privkey yet, forcing it to use a separate privkey query). Each
# query response will trigger a privkey query, and since we're using
# _sequence to make the peer0 response come back first, we'll send it
# a privkey query first, and _sequence will again ensure that the
# peer0 query will also come back before the others, and then
# _special_answers will make sure that the query raises an exception.
# The whole point of these hijinks is to exercise the code in
# _privkey_query_failed. Note that the map-update will succeed, since
# we'll just get a copy from one of the other shares.
d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
# Using FakeStorage._sequence means there will be read requests still
# floating around.. wait for them to retire
def _cancel_timer(res):
if self.client._storage._pending_timer:
self.client._storage._pending_timer.cancel()
return res
d.addBoth(_cancel_timer)
return d
def test_privkey_query_missing(self):
# like test_privkey_query_error, but the shares are deleted by the
# second query, instead of raising an exception.
self.client = FakeClient(20)
LARGE = "These are Larger contents" * 200 # about 5KB
d = self.client.create_mutable_file(LARGE)
def _created(n):
self.uri = n.get_uri()
self.n2 = self.client.create_node_from_uri(self.uri)
return n.get_servermap(MODE_WRITE)
d.addCallback(_created)
d.addCallback(lambda res: fireEventually(res))
def _got_smap1(smap):
peer0 = list(smap.make_sharemap()[0])[0]
self.client._storage._sequence = [peer0]
self.client._storage._special_answers[peer0] = ["normal", "none"]
d.addCallback(_got_smap1)
d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
def _cancel_timer(res):
if self.client._storage._pending_timer:
self.client._storage._pending_timer.cancel()
return res
d.addBoth(_cancel_timer)
return d