diff --git a/src/allmydata/download.py b/src/allmydata/download.py index 660d6e354..a3c6ca5ee 100644 --- a/src/allmydata/download.py +++ b/src/allmydata/download.py @@ -7,7 +7,7 @@ from twisted.application import service from allmydata.util import idlib, mathutil, hashutil from allmydata.util.assertutil import _assert -from allmydata import codec, hashtree +from allmydata import codec, hashtree, storageserver from allmydata.Crypto.Cipher import AES from allmydata.uri import unpack_uri, unpack_extension from allmydata.interfaces import IDownloadTarget, IDownloader @@ -109,7 +109,7 @@ class ValidatedBucket: # of the share hash tree to validate it from our share hash up to the # hashroot. if not self._share_hash: - d1 = self.bucket.callRemote('get_share_hashes') + d1 = self.bucket.get_share_hashes() else: d1 = defer.succeed(None) @@ -117,12 +117,12 @@ class ValidatedBucket: # validate the requested block up to the share hash needed = self.block_hash_tree.needed_hashes(blocknum) if needed: - # TODO: get fewer hashes, callRemote('get_block_hashes', needed) - d2 = self.bucket.callRemote('get_block_hashes') + # TODO: get fewer hashes, use get_block_hashes(needed) + d2 = self.bucket.get_block_hashes() else: d2 = defer.succeed([]) - d3 = self.bucket.callRemote('get_block', blocknum) + d3 = self.bucket.get_block(blocknum) d = defer.gatherResults([d1, d2, d3]) d.addCallback(self._got_data, blocknum) @@ -321,8 +321,9 @@ class FileDownloader: def _got_response(self, buckets, connection): _assert(isinstance(buckets, dict), buckets) # soon foolscap will check this for us with its DictOf schema constraint for sharenum, bucket in buckets.iteritems(): - self.add_share_bucket(sharenum, bucket) - self._uri_extension_sources.append(bucket) + b = storageserver.ReadBucketProxy(bucket) + self.add_share_bucket(sharenum, b) + self._uri_extension_sources.append(b) def add_share_bucket(self, sharenum, bucket): # this is split out for the benefit of test_encode.py @@ -379,7 +380,8 @@ class FileDownloader: "%s" % name) bucket = sources[0] sources = sources[1:] - d = bucket.callRemote(methname, *args) + #d = bucket.callRemote(methname, *args) + d = getattr(bucket, methname)(*args) d.addCallback(validatorfunc, bucket) def _bad(f): log.msg("%s from vbucket %s failed: %s" % (name, bucket, f)) # WEIRD diff --git a/src/allmydata/encode.py b/src/allmydata/encode.py index e0838f732..feb78992f 100644 --- a/src/allmydata/encode.py +++ b/src/allmydata/encode.py @@ -9,7 +9,7 @@ from allmydata.Crypto.Cipher import AES from allmydata.util import mathutil, hashutil from allmydata.util.assertutil import _assert from allmydata.codec import CRSEncoder -from allmydata.interfaces import IEncoder +from allmydata.interfaces import IEncoder, IStorageBucketWriter """ @@ -158,6 +158,7 @@ class Encoder(object): for k in landlords: # it would be nice to: #assert RIBucketWriter.providedBy(landlords[k]) + assert IStorageBucketWriter(landlords[k]) pass self.landlords = landlords.copy() @@ -307,7 +308,7 @@ class Encoder(object): if shareid not in self.landlords: return defer.succeed(None) sh = self.landlords[shareid] - d = sh.callRemote("put_block", segment_num, subshare) + d = sh.put_block(segment_num, subshare) d.addErrback(self._remove_shareholder, shareid, "segnum=%d" % segment_num) return d @@ -356,7 +357,7 @@ class Encoder(object): if shareid not in self.landlords: return defer.succeed(None) sh = self.landlords[shareid] - d = sh.callRemote("put_plaintext_hashes", all_hashes) + d = sh.put_plaintext_hashes(all_hashes) d.addErrback(self._remove_shareholder, shareid, "put_plaintext_hashes") return d @@ -374,7 +375,7 @@ class Encoder(object): if shareid not in self.landlords: return defer.succeed(None) sh = self.landlords[shareid] - d = sh.callRemote("put_crypttext_hashes", all_hashes) + d = sh.put_crypttext_hashes(all_hashes) d.addErrback(self._remove_shareholder, shareid, "put_crypttext_hashes") return d @@ -397,7 +398,7 @@ class Encoder(object): if shareid not in self.landlords: return defer.succeed(None) sh = self.landlords[shareid] - d = sh.callRemote("put_block_hashes", all_hashes) + d = sh.put_block_hashes(all_hashes) d.addErrback(self._remove_shareholder, shareid, "put_block_hashes") return d @@ -427,7 +428,7 @@ class Encoder(object): if shareid not in self.landlords: return defer.succeed(None) sh = self.landlords[shareid] - d = sh.callRemote("put_share_hashes", needed_hashes) + d = sh.put_share_hashes(needed_hashes) d.addErrback(self._remove_shareholder, shareid, "put_share_hashes") return d @@ -442,7 +443,7 @@ class Encoder(object): def send_uri_extension(self, shareid, uri_extension): sh = self.landlords[shareid] - d = sh.callRemote("put_uri_extension", uri_extension) + d = sh.put_uri_extension(uri_extension) d.addErrback(self._remove_shareholder, shareid, "put_uri_extension") return d @@ -450,7 +451,7 @@ class Encoder(object): log.msg("%s: closing shareholders" % self) dl = [] for shareid in self.landlords: - d = self.landlords[shareid].callRemote("close") + d = self.landlords[shareid].close() d.addErrback(self._remove_shareholder, shareid, "close") dl.append(d) return self._gather_responses(dl) diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py index 2c2acfc15..a4d980e95 100644 --- a/src/allmydata/interfaces.py +++ b/src/allmydata/interfaces.py @@ -119,6 +119,42 @@ class RIStorageServer(RemoteInterface): def get_buckets(storage_index=StorageIndex): return DictOf(int, RIBucketReader, maxKeys=MAX_BUCKETS) + +class IStorageBucketWriter(Interface): + def put_block(segmentnum, data): + pass + + def put_plaintext_hashes(hashes): + pass + def put_crypttext_hashes(hashes): + pass + def put_block_hashes(blockhashes): + pass + def put_share_hashes(sharehashes): + pass + def put_uri_extension(data): + pass + def close(): + pass + +class IStorageBucketReader(Interface): + + def get_block(blocknum): + pass + + def get_plaintext_hashes(): + pass + def get_crypttext_hashes(): + pass + def get_block_hashes(): + pass + def get_share_hashes(): + pass + def get_uri_extension(): + pass + + + # hm, we need a solution for forward references in schemas from foolscap.schema import Any RIMutableDirectoryNode_ = Any() # TODO: how can we avoid this? diff --git a/src/allmydata/storageserver.py b/src/allmydata/storageserver.py index 5c5601f07..66cf33e06 100644 --- a/src/allmydata/storageserver.py +++ b/src/allmydata/storageserver.py @@ -5,7 +5,7 @@ from twisted.application import service from zope.interface import implements from allmydata.interfaces import RIStorageServer, RIBucketWriter, \ - RIBucketReader + RIBucketReader, IStorageBucketWriter, IStorageBucketReader from allmydata import interfaces from allmydata.util import bencode, fileutil, idlib from allmydata.util.assertutil import precondition @@ -203,3 +203,44 @@ class StorageServer(service.MultiService, Referenceable): pass return bucketreaders + +class WriteBucketProxy: + implements(IStorageBucketWriter) + def __init__(self, rref): + self._rref = rref + + def put_block(self, segmentnum, data): + return self._rref.callRemote("put_block", segmentnum, data) + + def put_plaintext_hashes(self, hashes): + return self._rref.callRemote("put_plaintext_hashes", hashes) + def put_crypttext_hashes(self, hashes): + return self._rref.callRemote("put_crypttext_hashes", hashes) + def put_block_hashes(self, blockhashes): + return self._rref.callRemote("put_block_hashes", blockhashes) + def put_share_hashes(self, sharehashes): + return self._rref.callRemote("put_share_hashes", sharehashes) + def put_uri_extension(self, data): + return self._rref.callRemote("put_uri_extension", data) + def close(self): + return self._rref.callRemote("close") + +class ReadBucketProxy: + implements(IStorageBucketReader) + def __init__(self, rref): + self._rref = rref + + def get_block(self, blocknum): + return self._rref.callRemote("get_block", blocknum) + + def get_plaintext_hashes(self): + return self._rref.callRemote("get_plaintext_hashes") + def get_crypttext_hashes(self): + return self._rref.callRemote("get_crypttext_hashes") + def get_block_hashes(self): + return self._rref.callRemote("get_block_hashes") + def get_share_hashes(self): + return self._rref.callRemote("get_share_hashes") + def get_uri_extension(self): + return self._rref.callRemote("get_uri_extension") + diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index dc7a64d59..59857d4cf 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -1,4 +1,5 @@ +from zope.interface import implements from twisted.trial import unittest from twisted.internet import defer from twisted.python.failure import Failure @@ -7,6 +8,7 @@ from allmydata import encode, download, hashtree from allmydata.util import hashutil from allmydata.uri import pack_uri from allmydata.Crypto.Cipher import AES +from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader from cStringIO import StringIO class FakePeer: @@ -48,6 +50,7 @@ def flip_bit(good): # flips the last bit return good[:-1] + chr(ord(good[-1]) ^ 0x01) class FakeBucketWriter: + implements(IStorageBucketWriter, IStorageBucketReader) # these are used for both reading and writing def __init__(self, mode="good"): self.mode = mode @@ -59,90 +62,123 @@ class FakeBucketWriter: self.closed = False def callRemote(self, methname, *args, **kwargs): + # this allows FakeBucketWriter to be used either as an + # IStorageBucketWriter or as the remote reference that it wraps. This + # should be cleaned up eventually when we change RIBucketWriter to + # have just write(offset, data) and close() def _call(): meth = getattr(self, methname) return meth(*args, **kwargs) - return defer.maybeDeferred(_call) + d = eventual.fireEventually() + d.addCallback(lambda res: _call()) + return d def put_block(self, segmentnum, data): - assert not self.closed - assert segmentnum not in self.blocks - if self.mode == "lost" and segmentnum >= 1: - raise LostPeerError("I'm going away now") - self.blocks[segmentnum] = data + def _try(): + assert not self.closed + assert segmentnum not in self.blocks + if self.mode == "lost" and segmentnum >= 1: + raise LostPeerError("I'm going away now") + self.blocks[segmentnum] = data + return defer.maybeDeferred(_try) def put_plaintext_hashes(self, hashes): - assert not self.closed - assert self.plaintext_hashes is None - self.plaintext_hashes = hashes + def _try(): + assert not self.closed + assert self.plaintext_hashes is None + self.plaintext_hashes = hashes + return defer.maybeDeferred(_try) def put_crypttext_hashes(self, hashes): - assert not self.closed - assert self.crypttext_hashes is None - self.crypttext_hashes = hashes + def _try(): + assert not self.closed + assert self.crypttext_hashes is None + self.crypttext_hashes = hashes + return defer.maybeDeferred(_try) def put_block_hashes(self, blockhashes): - assert not self.closed - assert self.block_hashes is None - self.block_hashes = blockhashes + def _try(): + assert not self.closed + assert self.block_hashes is None + self.block_hashes = blockhashes + return defer.maybeDeferred(_try) def put_share_hashes(self, sharehashes): - assert not self.closed - assert self.share_hashes is None - self.share_hashes = sharehashes + def _try(): + assert not self.closed + assert self.share_hashes is None + self.share_hashes = sharehashes + return defer.maybeDeferred(_try) def put_uri_extension(self, uri_extension): - assert not self.closed - self.uri_extension = uri_extension + def _try(): + assert not self.closed + self.uri_extension = uri_extension + return defer.maybeDeferred(_try) def close(self): - assert not self.closed - self.closed = True + def _try(): + assert not self.closed + self.closed = True + return defer.maybeDeferred(_try) def get_block(self, blocknum): - assert isinstance(blocknum, (int, long)) - if self.mode == "bad block": - return flip_bit(self.blocks[blocknum]) - return self.blocks[blocknum] + def _try(): + assert isinstance(blocknum, (int, long)) + if self.mode == "bad block": + return flip_bit(self.blocks[blocknum]) + return self.blocks[blocknum] + return defer.maybeDeferred(_try) def get_plaintext_hashes(self): - hashes = self.plaintext_hashes[:] - if self.mode == "bad plaintext hashroot": - hashes[0] = flip_bit(hashes[0]) - if self.mode == "bad plaintext hash": - hashes[1] = flip_bit(hashes[1]) - return hashes + def _try(): + hashes = self.plaintext_hashes[:] + if self.mode == "bad plaintext hashroot": + hashes[0] = flip_bit(hashes[0]) + if self.mode == "bad plaintext hash": + hashes[1] = flip_bit(hashes[1]) + return hashes + return defer.maybeDeferred(_try) def get_crypttext_hashes(self): - hashes = self.crypttext_hashes[:] - if self.mode == "bad crypttext hashroot": - hashes[0] = flip_bit(hashes[0]) - if self.mode == "bad crypttext hash": - hashes[1] = flip_bit(hashes[1]) - return hashes + def _try(): + hashes = self.crypttext_hashes[:] + if self.mode == "bad crypttext hashroot": + hashes[0] = flip_bit(hashes[0]) + if self.mode == "bad crypttext hash": + hashes[1] = flip_bit(hashes[1]) + return hashes + return defer.maybeDeferred(_try) def get_block_hashes(self): - if self.mode == "bad blockhash": - hashes = self.block_hashes[:] - hashes[1] = flip_bit(hashes[1]) - return hashes - return self.block_hashes + def _try(): + if self.mode == "bad blockhash": + hashes = self.block_hashes[:] + hashes[1] = flip_bit(hashes[1]) + return hashes + return self.block_hashes + return defer.maybeDeferred(_try) + def get_share_hashes(self): - if self.mode == "bad sharehash": - hashes = self.share_hashes[:] - hashes[1] = (hashes[1][0], flip_bit(hashes[1][1])) - return hashes - if self.mode == "missing sharehash": - # one sneaky attack would be to pretend we don't know our own - # sharehash, which could manage to frame someone else. - # download.py is supposed to guard against this case. - return [] - return self.share_hashes + def _try(): + if self.mode == "bad sharehash": + hashes = self.share_hashes[:] + hashes[1] = (hashes[1][0], flip_bit(hashes[1][1])) + return hashes + if self.mode == "missing sharehash": + # one sneaky attack would be to pretend we don't know our own + # sharehash, which could manage to frame someone else. + # download.py is supposed to guard against this case. + return [] + return self.share_hashes + return defer.maybeDeferred(_try) def get_uri_extension(self): - if self.mode == "bad uri_extension": - return flip_bit(self.uri_extension) - return self.uri_extension + def _try(): + if self.mode == "bad uri_extension": + return flip_bit(self.uri_extension) + return self.uri_extension + return defer.maybeDeferred(_try) def make_data(length): diff --git a/src/allmydata/upload.py b/src/allmydata/upload.py index 73e88abcb..6b42ac3ec 100644 --- a/src/allmydata/upload.py +++ b/src/allmydata/upload.py @@ -5,7 +5,7 @@ from twisted.application import service from foolscap import Referenceable from allmydata.util import idlib, hashutil -from allmydata import encode +from allmydata import encode, storageserver from allmydata.uri import pack_uri from allmydata.interfaces import IUploadable, IUploader from allmydata.Crypto.Cipher import AES @@ -53,8 +53,10 @@ class PeerTracker: def _got_reply(self, (alreadygot, buckets)): #log.msg("%s._got_reply(%s)" % (self, (alreadygot, buckets))) - self.buckets.update(buckets) - return (alreadygot, set(buckets.keys())) + b = dict( [ (sharenum, storageserver.WriteBucketProxy(rref)) + for sharenum, rref in buckets.iteritems() ] ) + self.buckets.update(b) + return (alreadygot, set(b.keys())) class FileUploader: