test.mutable : refactored roundtrip and servermap tests

Signed-off-by: fenn-cs <fenn25.fn@gmail.com>
This commit is contained in:
fenn-cs 2021-09-10 00:59:55 +01:00
parent bbbc8592f0
commit 61b9f15fd1
2 changed files with 52 additions and 49 deletions

View File

@ -11,7 +11,8 @@ if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from six.moves import cStringIO as StringIO from six.moves import cStringIO as StringIO
from twisted.trial import unittest from ..common import AsyncTestCase
from testtools.matchers import Equals, HasLength, Contains
from twisted.internet import defer from twisted.internet import defer
from allmydata.util import base32, consumer from allmydata.util import base32, consumer
@ -23,8 +24,9 @@ from allmydata.mutable.retrieve import Retrieve
from .util import PublishMixin, make_storagebroker, corrupt from .util import PublishMixin, make_storagebroker, corrupt
from .. import common_util as testutil from .. import common_util as testutil
class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin): class Roundtrip(AsyncTestCase, testutil.ShouldFailMixin, PublishMixin):
def setUp(self): def setUp(self):
super(Roundtrip, self).setUp()
return self.publish_one() return self.publish_one()
def make_servermap(self, mode=MODE_READ, oldmap=None, sb=None): def make_servermap(self, mode=MODE_READ, oldmap=None, sb=None):
@ -73,11 +75,11 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
def _do_retrieve(servermap): def _do_retrieve(servermap):
self._smap = servermap self._smap = servermap
#self.dump_servermap(servermap) #self.dump_servermap(servermap)
self.failUnlessEqual(len(servermap.recoverable_versions()), 1) self.assertThat(servermap.recoverable_versions(), HasLength(1))
return self.do_download(servermap) return self.do_download(servermap)
d.addCallback(_do_retrieve) d.addCallback(_do_retrieve)
def _retrieved(new_contents): def _retrieved(new_contents):
self.failUnlessEqual(new_contents, self.CONTENTS) self.assertThat(new_contents, Equals(self.CONTENTS))
d.addCallback(_retrieved) d.addCallback(_retrieved)
# we should be able to re-use the same servermap, both with and # we should be able to re-use the same servermap, both with and
# without updating it. # without updating it.
@ -132,10 +134,10 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
# back empty # back empty
d = self.make_servermap(sb=sb2) d = self.make_servermap(sb=sb2)
def _check_servermap(servermap): def _check_servermap(servermap):
self.failUnlessEqual(servermap.best_recoverable_version(), None) self.assertThat(servermap.best_recoverable_version(), Equals(None))
self.failIf(servermap.recoverable_versions()) self.assertFalse(servermap.recoverable_versions())
self.failIf(servermap.unrecoverable_versions()) self.assertFalse(servermap.unrecoverable_versions())
self.failIf(servermap.all_servers()) self.assertFalse(servermap.all_servers())
d.addCallback(_check_servermap) d.addCallback(_check_servermap)
return d return d
@ -154,7 +156,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
self._fn._storage_broker = self._storage_broker self._fn._storage_broker = self._storage_broker
return self._fn.download_best_version() return self._fn.download_best_version()
def _retrieved(new_contents): def _retrieved(new_contents):
self.failUnlessEqual(new_contents, self.CONTENTS) self.assertThat(new_contents, Equals(self.CONTENTS))
d.addCallback(_restore) d.addCallback(_restore)
d.addCallback(_retrieved) d.addCallback(_retrieved)
return d return d
@ -178,13 +180,13 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
# should be noted in the servermap's list of problems. # should be noted in the servermap's list of problems.
if substring: if substring:
allproblems = [str(f) for f in servermap.get_problems()] allproblems = [str(f) for f in servermap.get_problems()]
self.failUnlessIn(substring, "".join(allproblems)) self.assertThat("".join(allproblems), Contains(substring))
return servermap return servermap
if should_succeed: if should_succeed:
d1 = self._fn.download_version(servermap, ver, d1 = self._fn.download_version(servermap, ver,
fetch_privkey) fetch_privkey)
d1.addCallback(lambda new_contents: d1.addCallback(lambda new_contents:
self.failUnlessEqual(new_contents, self.CONTENTS)) self.assertThat(new_contents, Equals(self.CONTENTS)))
else: else:
d1 = self.shouldFail(NotEnoughSharesError, d1 = self.shouldFail(NotEnoughSharesError,
"_corrupt_all(offset=%s)" % (offset,), "_corrupt_all(offset=%s)" % (offset,),
@ -207,7 +209,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
# and the dump should mention the problems # and the dump should mention the problems
s = StringIO() s = StringIO()
dump = servermap.dump(s).getvalue() dump = servermap.dump(s).getvalue()
self.failUnless("30 PROBLEMS" in dump, dump) self.assertTrue("30 PROBLEMS" in dump, msg=dump)
d.addCallback(_check_servermap) d.addCallback(_check_servermap)
return d return d
@ -299,8 +301,8 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
# in NotEnoughSharesError, since each share will look invalid # in NotEnoughSharesError, since each share will look invalid
def _check(res): def _check(res):
f = res[0] f = res[0]
self.failUnless(f.check(NotEnoughSharesError)) self.assertThat(f.check(NotEnoughSharesError), HasLength(1))
self.failUnless("uncoordinated write" in str(f)) self.assertThat("uncoordinated write" in str(f), Equals(True))
return self._test_corrupt_all(1, "ran out of servers", return self._test_corrupt_all(1, "ran out of servers",
corrupt_early=False, corrupt_early=False,
failure_checker=_check) failure_checker=_check)
@ -309,7 +311,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
def test_corrupt_all_block_late(self): def test_corrupt_all_block_late(self):
def _check(res): def _check(res):
f = res[0] f = res[0]
self.failUnless(f.check(NotEnoughSharesError)) self.assertTrue(f.check(NotEnoughSharesError))
return self._test_corrupt_all("share_data", "block hash tree failure", return self._test_corrupt_all("share_data", "block hash tree failure",
corrupt_early=False, corrupt_early=False,
failure_checker=_check) failure_checker=_check)
@ -330,9 +332,9 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
shnums_to_corrupt=list(range(0, N-k))) shnums_to_corrupt=list(range(0, N-k)))
d.addCallback(lambda res: self.make_servermap()) d.addCallback(lambda res: self.make_servermap())
def _do_retrieve(servermap): def _do_retrieve(servermap):
self.failUnless(servermap.get_problems()) self.assertTrue(servermap.get_problems())
self.failUnless("pubkey doesn't match fingerprint" self.assertThat("pubkey doesn't match fingerprint"
in str(servermap.get_problems()[0])) in str(servermap.get_problems()[0]), Equals(True))
ver = servermap.best_recoverable_version() ver = servermap.best_recoverable_version()
r = Retrieve(self._fn, self._storage_broker, servermap, ver) r = Retrieve(self._fn, self._storage_broker, servermap, ver)
c = consumer.MemoryConsumer() c = consumer.MemoryConsumer()
@ -340,7 +342,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
d.addCallback(_do_retrieve) d.addCallback(_do_retrieve)
d.addCallback(lambda mc: b"".join(mc.chunks)) d.addCallback(lambda mc: b"".join(mc.chunks))
d.addCallback(lambda new_contents: d.addCallback(lambda new_contents:
self.failUnlessEqual(new_contents, self.CONTENTS)) self.assertThat(new_contents, Equals(self.CONTENTS)))
return d return d
@ -355,11 +357,11 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
self.make_servermap()) self.make_servermap())
def _do_retrieve(servermap): def _do_retrieve(servermap):
ver = servermap.best_recoverable_version() ver = servermap.best_recoverable_version()
self.failUnless(ver) self.assertTrue(ver)
return self._fn.download_best_version() return self._fn.download_best_version()
d.addCallback(_do_retrieve) d.addCallback(_do_retrieve)
d.addCallback(lambda new_contents: d.addCallback(lambda new_contents:
self.failUnlessEqual(new_contents, self.CONTENTS)) self.assertThat(new_contents, Equals(self.CONTENTS)))
return d return d

View File

@ -11,7 +11,8 @@ from future.utils import PY2
if PY2: if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from twisted.trial import unittest from ..common import AsyncTestCase
from testtools.matchers import Equals, NotEquals, HasLength
from twisted.internet import defer from twisted.internet import defer
from allmydata.monitor import Monitor from allmydata.monitor import Monitor
from allmydata.mutable.common import \ from allmydata.mutable.common import \
@ -20,8 +21,9 @@ from allmydata.mutable.publish import MutableData
from allmydata.mutable.servermap import ServerMap, ServermapUpdater from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from .util import PublishMixin from .util import PublishMixin
class Servermap(unittest.TestCase, PublishMixin): class Servermap(AsyncTestCase, PublishMixin):
def setUp(self): def setUp(self):
super(Servermap, self).setUp()
return self.publish_one() return self.publish_one()
def make_servermap(self, mode=MODE_CHECK, fn=None, sb=None, def make_servermap(self, mode=MODE_CHECK, fn=None, sb=None,
@ -42,17 +44,17 @@ class Servermap(unittest.TestCase, PublishMixin):
return d return d
def failUnlessOneRecoverable(self, sm, num_shares): def failUnlessOneRecoverable(self, sm, num_shares):
self.failUnlessEqual(len(sm.recoverable_versions()), 1) self.assertThat(sm.recoverable_versions(), HasLength(1))
self.failUnlessEqual(len(sm.unrecoverable_versions()), 0) self.assertThat(sm.unrecoverable_versions(), HasLength(0))
best = sm.best_recoverable_version() best = sm.best_recoverable_version()
self.failIfEqual(best, None) self.assertThat(best, NotEquals(None))
self.failUnlessEqual(sm.recoverable_versions(), set([best])) self.assertThat(sm.recoverable_versions(), Equals(set([best])))
self.failUnlessEqual(len(sm.shares_available()), 1) self.assertThat(sm.shares_available(), HasLength(1))
self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3, 10)) self.assertThat(sm.shares_available()[best], Equals((num_shares, 3, 10)))
shnum, servers = list(sm.make_sharemap().items())[0] shnum, servers = list(sm.make_sharemap().items())[0]
server = list(servers)[0] server = list(servers)[0]
self.failUnlessEqual(sm.version_on_server(server, shnum), best) self.assertThat(sm.version_on_server(server, shnum), Equals(best))
self.failUnlessEqual(sm.version_on_server(server, 666), None) self.assertThat(sm.version_on_server(server, 666), Equals(None))
return sm return sm
def test_basic(self): def test_basic(self):
@ -117,7 +119,7 @@ class Servermap(unittest.TestCase, PublishMixin):
v = sm.best_recoverable_version() v = sm.best_recoverable_version()
vm = sm.make_versionmap() vm = sm.make_versionmap()
shares = list(vm[v]) shares = list(vm[v])
self.failUnlessEqual(len(shares), 6) self.assertThat(shares, HasLength(6))
self._corrupted = set() self._corrupted = set()
# mark the first 5 shares as corrupt, then update the servermap. # mark the first 5 shares as corrupt, then update the servermap.
# The map should not have the marked shares it in any more, and # The map should not have the marked shares it in any more, and
@ -135,18 +137,17 @@ class Servermap(unittest.TestCase, PublishMixin):
shares = list(vm[v]) shares = list(vm[v])
for (server, shnum) in self._corrupted: for (server, shnum) in self._corrupted:
server_shares = sm.debug_shares_on_server(server) server_shares = sm.debug_shares_on_server(server)
self.failIf(shnum in server_shares, self.assertFalse(shnum in server_shares, "%d was in %s" % (shnum, server_shares))
"%d was in %s" % (shnum, server_shares)) self.assertThat(shares, HasLength(5))
self.failUnlessEqual(len(shares), 5)
d.addCallback(_check_map) d.addCallback(_check_map)
return d return d
def failUnlessNoneRecoverable(self, sm): def failUnlessNoneRecoverable(self, sm):
self.failUnlessEqual(len(sm.recoverable_versions()), 0) self.assertThat(sm.recoverable_versions(), HasLength(0))
self.failUnlessEqual(len(sm.unrecoverable_versions()), 0) self.assertThat(sm.unrecoverable_versions(), HasLength(0))
best = sm.best_recoverable_version() best = sm.best_recoverable_version()
self.failUnlessEqual(best, None) self.assertThat(best, Equals(None))
self.failUnlessEqual(len(sm.shares_available()), 0) self.assertThat(sm.shares_available(), HasLength(0))
def test_no_shares(self): def test_no_shares(self):
self._storage._peers = {} # delete all shares self._storage._peers = {} # delete all shares
@ -168,12 +169,12 @@ class Servermap(unittest.TestCase, PublishMixin):
return d return d
def failUnlessNotQuiteEnough(self, sm): def failUnlessNotQuiteEnough(self, sm):
self.failUnlessEqual(len(sm.recoverable_versions()), 0) self.assertThat(sm.recoverable_versions(), HasLength(0))
self.failUnlessEqual(len(sm.unrecoverable_versions()), 1) self.assertThat(sm.unrecoverable_versions(), HasLength(1))
best = sm.best_recoverable_version() best = sm.best_recoverable_version()
self.failUnlessEqual(best, None) self.assertThat(best, Equals(None))
self.failUnlessEqual(len(sm.shares_available()), 1) self.assertThat(sm.shares_available(), HasLength(1))
self.failUnlessEqual(list(sm.shares_available().values())[0], (2,3,10) ) self.assertThat(list(sm.shares_available().values())[0], Equals((2,3,10)))
return sm return sm
def test_not_quite_enough_shares(self): def test_not_quite_enough_shares(self):
@ -193,7 +194,7 @@ class Servermap(unittest.TestCase, PublishMixin):
d.addCallback(lambda res: ms(mode=MODE_CHECK)) d.addCallback(lambda res: ms(mode=MODE_CHECK))
d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm)) d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
d.addCallback(lambda sm: d.addCallback(lambda sm:
self.failUnlessEqual(len(sm.make_sharemap()), 2)) self.assertThat(sm.make_sharemap(), HasLength(2)))
d.addCallback(lambda res: ms(mode=MODE_ANYTHING)) d.addCallback(lambda res: ms(mode=MODE_ANYTHING))
d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm)) d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm))
d.addCallback(lambda res: ms(mode=MODE_WRITE)) d.addCallback(lambda res: ms(mode=MODE_WRITE))
@ -216,7 +217,7 @@ class Servermap(unittest.TestCase, PublishMixin):
# Calling make_servermap also updates the servermap in the mode # Calling make_servermap also updates the servermap in the mode
# that we specify, so we just need to see what it says. # that we specify, so we just need to see what it says.
def _check_servermap(sm): def _check_servermap(sm):
self.failUnlessEqual(len(sm.recoverable_versions()), 1) self.assertThat(sm.recoverable_versions(), HasLength(1))
d.addCallback(_check_servermap) d.addCallback(_check_servermap)
return d return d
@ -229,10 +230,10 @@ class Servermap(unittest.TestCase, PublishMixin):
self.make_servermap(mode=MODE_WRITE, update_range=(1, 2))) self.make_servermap(mode=MODE_WRITE, update_range=(1, 2)))
def _check_servermap(sm): def _check_servermap(sm):
# 10 shares # 10 shares
self.failUnlessEqual(len(sm.update_data), 10) self.assertThat(sm.update_data, HasLength(10))
# one version # one version
for data in sm.update_data.values(): for data in sm.update_data.values():
self.failUnlessEqual(len(data), 1) self.assertThat(data, HasLength(1))
d.addCallback(_check_servermap) d.addCallback(_check_servermap)
return d return d
@ -244,5 +245,5 @@ class Servermap(unittest.TestCase, PublishMixin):
d.addCallback(lambda ignored: d.addCallback(lambda ignored:
self.make_servermap(mode=MODE_CHECK)) self.make_servermap(mode=MODE_CHECK))
d.addCallback(lambda servermap: d.addCallback(lambda servermap:
self.failUnlessEqual(len(servermap.recoverable_versions()), 1)) self.assertThat(servermap.recoverable_versions(), HasLength(1)))
return d return d