From 94670f7f95201415f480950bc8a631b74a4f52fe Mon Sep 17 00:00:00 2001
From: echel0n
Date: Tue, 22 Apr 2014 23:24:08 -0700
Subject: [PATCH] Updated our cache code. Updated rsstorrents to not bother
 using requests sessions.

---
 lib/cachecontrol/__init__.py            |  12 +-
 lib/cachecontrol/adapter.py             |  79 ++++++-------
 lib/cachecontrol/caches/__init__.py     |   4 +-
 lib/cachecontrol/caches/file_cache.py   |  70 +++++++----
 lib/cachecontrol/caches/redis_cache.py  |  12 +-
 lib/cachecontrol/compat.py              |  14 +++
 lib/cachecontrol/controller.py          | 154 ++++++++++++-------
 lib/cachecontrol/patch_requests.py      |   1 +
 lib/cachecontrol/serialize.py           |  97 ++++++++++++++++
 lib/cachecontrol/wrapper.py             |  15 ++-
 sickbeard/providers/rsstorrent.py       |   6 +-
 sickbeard/providers/thepiratebay.py     |   2 -
 12 files changed, 292 insertions(+), 174 deletions(-)
 create mode 100644 lib/cachecontrol/serialize.py

diff --git a/lib/cachecontrol/__init__.py b/lib/cachecontrol/__init__.py
index 693e11f1..c18e70c0 100644
--- a/lib/cachecontrol/__init__.py
+++ b/lib/cachecontrol/__init__.py
@@ -2,12 +2,6 @@
 Make it easy to import from cachecontrol without long namespaces.
 """
-
-# patch our requests.models.Response to make them pickleable in older
-# versions of requests.
-
-import cachecontrol.patch_requests
-
-from cachecontrol.wrapper import CacheControl
-from cachecontrol.adapter import CacheControlAdapter
-from cachecontrol.controller import CacheController
+from .wrapper import CacheControl
+from .adapter import CacheControlAdapter
+from .controller import CacheController

diff --git a/lib/cachecontrol/adapter.py b/lib/cachecontrol/adapter.py
index 2e818c7b..d2ca7e87 100644
--- a/lib/cachecontrol/adapter.py
+++ b/lib/cachecontrol/adapter.py
@@ -1,43 +1,65 @@
 from requests.adapters import HTTPAdapter

-from cachecontrol.controller import CacheController
-from cachecontrol.cache import DictCache
+from .controller import CacheController
+from .cache import DictCache


 class CacheControlAdapter(HTTPAdapter):
     invalidating_methods = set(['PUT', 'DELETE'])

-    def __init__(self, cache=None, cache_etags=True, *args, **kw):
+    def __init__(self, cache=None, cache_etags=True, controller_class=None,
+                 serializer=None, *args, **kw):
         super(CacheControlAdapter, self).__init__(*args, **kw)
         self.cache = cache or DictCache()
-        self.controller = CacheController(self.cache, cache_etags=cache_etags)
+
+        controller_factory = controller_class or CacheController
+        self.controller = controller_factory(
+            self.cache,
+            cache_etags=cache_etags,
+            serializer=serializer,
+        )

     def send(self, request, **kw):
-        """Send a request. Use the request information to see if it
-        exists in the cache.
+        """
+        Send a request. Use the request information to see if it
+        exists in the cache and cache the response if we need to and can.
         """
         if request.method == 'GET':
-            cached_response = self.controller.cached_request(
-                request.url, request.headers
-            )
+            cached_response = self.controller.cached_request(request)
             if cached_response:
-                # Cached responses should not have a raw field since
-                # they *cannot* be created from some stream.
-                cached_response.raw = None
-                return cached_response
+                return self.build_response(request, cached_response, from_cache=True)

             # check for etags and add headers if appropriate
-            headers = self.controller.add_headers(request.url)
-            request.headers.update(headers)
+            request.headers.update(self.controller.conditional_headers(request))

         resp = super(CacheControlAdapter, self).send(request, **kw)
+
         return resp

-    def build_response(self, request, response):
-        """Build a response by making a request or using the cache.
+    def build_response(self, request, response, from_cache=False):
+        """
+        Build a response by making a request or using the cache.

         This will end up calling send and returning a potentially
         cached response
         """
+        if not from_cache and request.method == 'GET':
+            if response.status == 304:
+                # We must have sent an ETag request. This could mean
+                # that we've been expired already or that we simply
+                # have an etag. In either case, we want to try and
+                # update the cache if that is the case.
+                cached_response = self.controller.update_cached_response(
+                    request, response
+                )
+
+                if cached_response is not response:
+                    from_cache = True
+
+                response = cached_response
+            else:
+                # try to cache the response
+                self.controller.cache_response(request, response)
+
         resp = super(CacheControlAdapter, self).build_response(
             request, response
         )
@@ -47,28 +69,7 @@ class CacheControlAdapter(HTTPAdapter):
             cache_url = self.controller.cache_url(request.url)
             self.cache.delete(cache_url)

-        # Try to store the response if it is a GET
-        elif request.method == 'GET':
-            if response.status == 304:
-                # We must have sent an ETag request. This could mean
-                # that we've been expired already or that we simply
-                # have an etag. In either case, we want to try and
-                # update the cache if that is the case.
-                resp = self.controller.update_cached_response(
-                    request, response
-                )
-                # Fix possible exception when using missing `raw` field in
-                # requests
-                # TODO: remove when requests will be bump to 2.2.2 or 2.3
-                # version
-                resp.raw = None
-            else:
-                # try to cache the response
-                self.controller.cache_response(request, resp)
-
         # Give the request a from_cache attr to let people use it
         # rather than testing for hasattr.
-        if not hasattr(resp, 'from_cache'):
-            resp.from_cache = False
+        resp.from_cache = from_cache

         return resp

diff --git a/lib/cachecontrol/caches/__init__.py b/lib/cachecontrol/caches/__init__.py
index 5e851b03..f9e66a1f 100644
--- a/lib/cachecontrol/caches/__init__.py
+++ b/lib/cachecontrol/caches/__init__.py
@@ -1,7 +1,7 @@
 from textwrap import dedent

 try:
-    from cachecontrol.caches.file_cache import FileCache
+    from .file_cache import FileCache
 except ImportError:
     notice = dedent('''
     NOTE: In order to use the FileCache you must have
@@ -13,6 +13,6 @@ except ImportError:

 try:
     import redis
-    from cachecontrol.caches.redis_cache import RedisCache
+    from .redis_cache import RedisCache
 except ImportError:
     pass

diff --git a/lib/cachecontrol/caches/file_cache.py b/lib/cachecontrol/caches/file_cache.py
index a75f700e..711687ca 100644
--- a/lib/cachecontrol/caches/file_cache.py
+++ b/lib/cachecontrol/caches/file_cache.py
@@ -1,26 +1,62 @@
+import hashlib
 import os
-import sys
-from hashlib import md5
-
-try:
-    from pickle import load, dump, HIGHEST_PROTOCOL
-except ImportError:
-    from cPickle import load, dump, HIGHEST_PROTOCOL

 from lockfile import FileLock


+def _secure_open_write(filename, fmode):
+    # We only want to write to this file, so open it in write only mode
+    flags = os.O_WRONLY
+
+    # os.O_CREAT | os.O_EXCL will fail if the file already exists, so we only
+    # will open *new* files.
+    # We specify this because we want to ensure that the mode we pass is the
+    # mode of the file.
+    flags |= os.O_CREAT | os.O_EXCL
+
+    # Do not follow symlinks to prevent someone from making a symlink that
+    # we follow and insecurely open a cache file.
+    if hasattr(os, "O_NOFOLLOW"):
+        flags |= os.O_NOFOLLOW
+
+    # On Windows we'll mark this file as binary
+    if hasattr(os, "O_BINARY"):
+        flags |= os.O_BINARY
+
+    # Before we open our file, we want to delete any existing file that is
+    # there
+    try:
+        os.remove(filename)
+    except (IOError, OSError):
+        # The file must not exist already, so we can just skip ahead to opening
+        pass
+
+    # Open our file, the use of os.O_CREAT | os.O_EXCL will ensure that if a
+    # race condition happens between the os.remove and this line, that an
+    # error will be raised. Because we utilize a lockfile this should only
+    # happen if someone is attempting to attack us.
+    fd = os.open(filename, flags, fmode)
+    try:
+        return os.fdopen(fd, "wb")
+    except:
+        # An error occurred wrapping our FD in a file object
+        os.close(fd)
+        raise
+
+
 class FileCache(object):
-    def __init__(self, directory, forever=False):
+    def __init__(self, directory, forever=False, filemode=0o0600,
+                 dirmode=0o0700):
         self.directory = directory
         self.forever = forever
+        self.filemode = filemode

         if not os.path.isdir(self.directory):
-            os.mkdir(self.directory)
+            os.makedirs(self.directory, dirmode)

     @staticmethod
     def encode(x):
-        return md5(x.encode()).hexdigest()
+        return hashlib.sha224(x.encode()).hexdigest()

     def _fn(self, name):
         return os.path.join(self.directory, self.encode(name))
@@ -31,21 +67,15 @@ class FileCache(object):
             return None

         with open(name, 'rb') as fh:
-            try:
-                if sys.version < '3':
-                    return load(fh)
-                else:
-                    return load(fh, encoding='latin1')
-            except ValueError:
-                return None
+            return fh.read()

     def set(self, key, value):
         name = self._fn(key)

         with FileLock(name) as lock:
-            with open(lock.path, 'wb') as fh:
-                dump(value, fh, HIGHEST_PROTOCOL)
+            with _secure_open_write(lock.path, self.filemode) as fh:
+                fh.write(value)

     def delete(self, key):
         name = self._fn(key)
         if not self.forever:
-            os.remove(name)
\ No newline at end of file
+            os.remove(name)

diff --git a/lib/cachecontrol/caches/redis_cache.py b/lib/cachecontrol/caches/redis_cache.py
index d3814ebc..72b8ca31 100644
--- a/lib/cachecontrol/caches/redis_cache.py
+++ b/lib/cachecontrol/caches/redis_cache.py
@@ -2,11 +2,6 @@ from __future__ import division

 from datetime import datetime

-try:
-    from cPickle import loads, dumps
-except ImportError:  # Python 3.x
-    from pickle import loads, dumps
-

 def total_seconds(td):
     """Python 2.6 compatability"""
@@ -24,14 +19,11 @@ class RedisCache(object):
         self.conn = conn

     def get(self, key):
-        val = self.conn.get(key)
-        if val:
-            return loads(val)
-        return None
+        return self.conn.get(key)

     def set(self, key, value, expires=None):
         if not expires:
-            self.conn.set(key, dumps(value))
+            self.conn.set(key, value)
         else:
             expires = expires - datetime.now()
             self.conn.setex(key, total_seconds(expires), value)

diff --git a/lib/cachecontrol/compat.py b/lib/cachecontrol/compat.py
index 1b6e596e..cb6e1b0b 100644
--- a/lib/cachecontrol/compat.py
+++ b/lib/cachecontrol/compat.py
@@ -10,3 +10,17 @@ try:
 except ImportError:
     import email.Utils
     parsedate_tz = email.Utils.parsedate_tz
+
+
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+
+
+# Handle the case where the requests has been patched to not have urllib3
+# bundled as part of it's source.
+try:
+    from requests.packages.urllib3.response import HTTPResponse
+except ImportError:
+    from urllib3.response import HTTPResponse

diff --git a/lib/cachecontrol/controller.py b/lib/cachecontrol/controller.py
index e0b2bf54..9bf9186c 100644
--- a/lib/cachecontrol/controller.py
+++ b/lib/cachecontrol/controller.py
@@ -6,9 +6,11 @@ import calendar
 import time
 import datetime

-from cachecontrol.cache import DictCache
-from cachecontrol.compat import parsedate_tz
-from cachecontrol.session import CacheControlSession
+from requests.structures import CaseInsensitiveDict
+
+from .cache import DictCache
+from .compat import parsedate_tz
+from .serialize import Serializer

 URI = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?")

@@ -25,9 +27,10 @@ def parse_uri(uri):

 class CacheController(object):
     """An interface to see if request should cached or not.
""" - def __init__(self, cache=None, sess=None, cache_etags=True): + def __init__(self, cache=None, cache_etags=True, serializer=None): self.cache = cache or DictCache() self.cache_etags = cache_etags + self.serializer = serializer or Serializer() def _urlnorm(self, uri): """Normalize the URL to create a safe key for the cache""" @@ -71,62 +74,47 @@ class CacheController(object): retval = dict(parts_with_args + parts_wo_args) return retval - def cached_request(self, url, headers): - cache_url = self.cache_url(url) - cc = self.parse_cache_control(headers) + def cached_request(self, request): + cache_url = self.cache_url(request.url) + cc = self.parse_cache_control(request.headers) # non-caching states no_cache = True if 'no-cache' in cc else False if 'max-age' in cc and cc['max-age'] == 0: no_cache = True - # see if it is in the cache anyways - in_cache = self.cache.get(cache_url) - if no_cache or not in_cache: + # Bail out if no-cache was set + if no_cache: return False # It is in the cache, so lets see if it is going to be # fresh enough - resp = self.cache.get(cache_url) + resp = self.serializer.loads(request, self.cache.get(cache_url)) - # Check our Vary header to make sure our request headers match - # up. We don't delete it from the though, we just don't return - # our cached value. - # - # NOTE: Because httplib2 stores raw content, it denotes - # headers that were sent in the original response by - # adding -varied-$name. We don't have to do that b/c we - # are storing the object which has a reference to the - # original request. If that changes, then I'd propose - # using the varied headers in the cache key to avoid the - # situation all together. - if 'vary' in resp.headers: - varied_headers = resp.headers['vary'].replace(' ', '').split(',') - original_headers = resp.request.headers - for header in varied_headers: - # If our headers don't match for the headers listed in - # the vary header, then don't use the cached response - if headers.get(header, None) != original_headers.get(header): - return False + # Check to see if we have a cached object + if not resp: + return False + + headers = CaseInsensitiveDict(resp.headers) now = time.time() date = calendar.timegm( - parsedate_tz(resp.headers['date']) + parsedate_tz(headers['date']) ) current_age = max(0, now - date) # TODO: There is an assumption that the result will be a - # requests response object. This may not be best since we + # urllib3 response object. This may not be best since we # could probably avoid instantiating or constructing the # response until we know we need it. - resp_cc = self.parse_cache_control(resp.headers) + resp_cc = self.parse_cache_control(headers) # determine freshness freshness_lifetime = 0 if 'max-age' in resp_cc and resp_cc['max-age'].isdigit(): freshness_lifetime = int(resp_cc['max-age']) - elif 'expires' in resp.headers: - expires = parsedate_tz(resp.headers['expires']) + elif 'expires' in headers: + expires = parsedate_tz(headers['expires']) if expires is not None: expire_time = calendar.timegm(expires) - date freshness_lifetime = max(0, expire_time) @@ -150,30 +138,32 @@ class CacheController(object): fresh = (freshness_lifetime > current_age) if fresh: - # make sure we set the from_cache to true - resp.from_cache = True return resp # we're not fresh. 
-        if 'etag' not in resp.headers:
+        if 'etag' not in headers:
             self.cache.delete(cache_url)

-        if 'etag' in resp.headers:
-            headers['If-None-Match'] = resp.headers['ETag']
-
-        if 'last-modified' in resp.headers:
-            headers['If-Modified-Since'] = resp.headers['Last-Modified']
-
         # return the original handler
         return False

-    def add_headers(self, url):
-        resp = self.cache.get(url)
-        if resp and 'etag' in resp.headers:
-            return {'If-None-Match': resp.headers['etag']}
-        return {}
+    def conditional_headers(self, request):
+        cache_url = self.cache_url(request.url)
+        resp = self.serializer.loads(request, self.cache.get(cache_url))
+        new_headers = {}

-    def cache_response(self, request, resp):
+        if resp:
+            headers = CaseInsensitiveDict(resp.headers)
+
+            if 'etag' in headers:
+                new_headers['If-None-Match'] = headers['ETag']
+
+            if 'last-modified' in headers:
+                new_headers['If-Modified-Since'] = headers['Last-Modified']
+
+        return new_headers
+
+    def cache_response(self, request, response):
         """
         Algorithm for caching requests.
@@ -181,7 +171,7 @@ class CacheController(object):
         """
         # From httplib2: Don't cache 206's since we aren't going to
         # handle byte range requests
-        if resp.status_code not in [200, 203]:
+        if response.status not in [200, 203]:
             return

         # Cache Session Params
         cache_auto = getattr(request, 'cache_auto', False)
         cache_urls = getattr(request, 'cache_urls', [])
         cache_max_age = getattr(request, 'cache_max_age', None)

+        response_headers = CaseInsensitiveDict(response.headers)
+
         # Check if we are wanting to cache responses from specific urls only
         cache_url = self.cache_url(request.url)
         if len(cache_urls) > 0 and not any(s in cache_url for s in cache_urls):
             return

         cc_req = self.parse_cache_control(request.headers)
-        cc = self.parse_cache_control(resp.headers)
+        cc = self.parse_cache_control(response_headers)

         # Delete it from the cache if we happen to have it stored there
         no_store = cc.get('no-store') or cc_req.get('no-store')
         if no_store and self.cache.get(cache_url):
             self.cache.delete(cache_url)

         # If we've been given an etag, then keep the response
-        if self.cache_etags and 'etag' in resp.headers:
-            self.cache.set(cache_url, resp)
+        if self.cache_etags and 'etag' in response_headers:
+            self.cache.set(cache_url, self.serializer.dumps(request, response))

         # If we want to cache sites not setup with cache headers then add the proper headers and keep the response
-        elif cache_auto and not cc and resp.headers:
+        elif cache_auto and not cc and response_headers:
             headers = {'Cache-Control': 'public,max-age=%d' % int(cache_max_age or 900)}
-            resp.headers.update(headers)
+            response.headers.update(headers)

-            if 'expires' not in resp.headers:
-                if getattr(resp.headers, 'expires', None) is None:
-                    expires = datetime.datetime.utcnow() + datetime.timedelta(days=(1))
+            if 'expires' not in response_headers:
+                if getattr(response_headers, 'expires', None) is None:
+                    expires = datetime.datetime.utcnow() + datetime.timedelta(days=1)
                     expires = expires.strftime("%a, %d %b %Y %H:%M:%S GMT")
                     headers = {'Expires': expires}
-                    resp.headers.update(headers)
+                    response.headers.update(headers)

-            self.cache.set(cache_url, resp)
+            self.cache.set(cache_url, self.serializer.dumps(request, response))

         # Add to the cache if the response headers demand it. If there
         # is no date header then we can't do anything about expiring
         # the cache.
-        elif 'date' in resp.headers:
+        elif 'date' in response_headers:
             # cache when there is a max-age > 0
             if cc and cc.get('max-age'):
                 if int(cc['max-age']) > 0:
-                    if isinstance(cache_max_age, (int)):
+                    if isinstance(cache_max_age, int):
                         cc['max-age'] = int(cache_max_age)
-                        resp.headers['cache-control'] = ''.join(['%s=%s' % (key, value) for (key, value) in cc.items()])
-                    self.cache.set(cache_url, resp)
+                        response.headers['cache-control'] = ''.join(['%s=%s' % (key, value) for (key, value) in cc.items()])
+                    self.cache.set(cache_url, self.serializer.dumps(request, response))

             # If the request can expire, it means we should cache it
             # in the meantime.
-            elif 'expires' in resp.headers:
-                if getattr(resp.headers, 'expires', None) is not None:
-                    self.cache.set(cache_url, resp)
+            elif 'expires' in response_headers:
+                if response_headers['expires']:
+                    self.cache.set(
+                        cache_url,
+                        self.serializer.dumps(request, response),
+                    )

     def update_cached_response(self, request, response):
         """On a 304 we will get a new set of headers that we want to
@@ -247,27 +242,22 @@ class CacheController(object):
         """
         cache_url = self.cache_url(request.url)

-        resp = self.cache.get(cache_url)
+        cached_response = self.serializer.loads(request, self.cache.get(cache_url))

-        if not resp:
+        if not cached_response:
             # we didn't have a cached response
             return response

         # did so lets update our headers
-        resp.headers.update(resp.headers)
+        cached_response.headers.update(response.headers)

         # we want a 200 b/c we have content via the cache
-        request.status_code = 200
-
-        # update the request as it has the if-none-match header + any
-        # other headers that the server might have updated (ie Date,
-        # Cache-Control, Expires, etc.)
-        resp.request = request
+        cached_response.status = 200

         # update our cache
-        self.cache.set(cache_url, resp)
+        self.cache.set(
+            cache_url,
+            self.serializer.dumps(request, cached_response),
+        )

-        # Let everyone know this was from the cache.
-        resp.from_cache = True
-
-        return resp
+        return cached_response

diff --git a/lib/cachecontrol/patch_requests.py b/lib/cachecontrol/patch_requests.py
index a5563531..3399223a 100644
--- a/lib/cachecontrol/patch_requests.py
+++ b/lib/cachecontrol/patch_requests.py
@@ -52,4 +52,5 @@ def make_responses_pickleable():
         raise
     pass

+
 make_responses_pickleable()
\ No newline at end of file

diff --git a/lib/cachecontrol/serialize.py b/lib/cachecontrol/serialize.py
new file mode 100644
index 00000000..5316fa1c
--- /dev/null
+++ b/lib/cachecontrol/serialize.py
@@ -0,0 +1,97 @@
+import io
+
+from requests.structures import CaseInsensitiveDict
+
+from .compat import HTTPResponse, pickle
+
+
+class Serializer(object):
+    def dumps(self, request, response, body=None):
+        response_headers = CaseInsensitiveDict(response.headers)
+
+        if body is None:
+            # TODO: Figure out a way to handle this which doesn't break
+            #       streaming
+            body = response.read(decode_content=False)
+            response._fp = io.BytesIO(body)
+
+        data = {
+            "response": {
+                "body": body,
+                "headers": response.headers,
+                "status": response.status,
+                "version": response.version,
+                "reason": response.reason,
+                "strict": response.strict,
+                "decode_content": response.decode_content,
+            },
+        }
+
+        # Construct our vary headers
+        data["vary"] = {}
+        if "vary" in response_headers:
+            varied_headers = response_headers['vary'].split(',')
+            for header in varied_headers:
+                header = header.strip()
+                data["vary"][header] = request.headers.get(header, None)
+
+        return b"cc=1," + pickle.dumps(data, pickle.HIGHEST_PROTOCOL)
+
+    def loads(self, request, data):
+        # Short circuit if we've been given an empty set of data
+        if not data:
+            return
+
+        # Determine what version of the serializer the data was serialized
+        # with
+        try:
+            ver, data = data.split(b",", 1)
+        except ValueError:
+            ver = b"cc=0"
+
+        # Make sure that our "ver" is actually a version and isn't a false
+        # positive from a , being in the data stream.
+        if ver[:3] != b"cc=":
+            data = ver + data
+            ver = b"cc=0"
+
+        # Get the version number out of the cc=N
+        ver = ver.split(b"=", 1)[-1].decode("ascii")
+
+        # Dispatch to the actual load method for the given version
+        try:
+            return getattr(self, "_loads_v{0}".format(ver))(request, data)
+        except AttributeError:
+            # This is a version we don't have a loads function for, so we'll
+            # just treat it as a miss and return None
+            return
+
+    def _loads_v0(self, request, data):
+        # The original legacy cache data. This doesn't contain enough
+        # information to construct everything we need, so we'll treat this as
+        # a miss.
+        return
+
+    def _loads_v1(self, request, data):
+        try:
+            cached = pickle.loads(data)
+        except ValueError:
+            return
+
+        # Special case the '*' Vary value as it means we cannot actually
+        # determine if the cached response is suitable for this request.
+ if "*" in cached.get("vary", {}): + return + + # Ensure that the Vary headers for the cached response match our + # request + for header, value in cached.get("vary", {}).items(): + if request.headers.get(header, None) != value: + return + + body = io.BytesIO(cached["response"].pop("body")) + return HTTPResponse( + body=body, + preload_content=False, + **cached["response"] + ) diff --git a/lib/cachecontrol/wrapper.py b/lib/cachecontrol/wrapper.py index 88dc2c97..0dc608a0 100644 --- a/lib/cachecontrol/wrapper.py +++ b/lib/cachecontrol/wrapper.py @@ -1,11 +1,16 @@ -from cachecontrol.adapter import CacheControlAdapter -from cachecontrol.cache import DictCache -from cachecontrol.session import CacheControlSession +from .adapter import CacheControlAdapter +from .cache import DictCache +from .session import CacheControlSession -def CacheControl(sess=None, cache=None, cache_etags=True): +def CacheControl(sess=None, cache=None, cache_etags=True, serializer=None): sess = sess or CacheControlSession() cache = cache or DictCache() - adapter = CacheControlAdapter(cache, cache_etags=cache_etags) + adapter = CacheControlAdapter( + cache, + cache_etags=cache_etags, + serializer=serializer, + ) sess.mount('http://', adapter) + sess.mount('https://', adapter) return sess diff --git a/sickbeard/providers/rsstorrent.py b/sickbeard/providers/rsstorrent.py index edfb267f..89ce9ea1 100644 --- a/sickbeard/providers/rsstorrent.py +++ b/sickbeard/providers/rsstorrent.py @@ -45,7 +45,6 @@ class TorrentRssProvider(generic.TorrentProvider): self.url = re.sub('\/$', '', url) self.enabled = True self.supportsBacklog = False - self.session = None def configStr(self): return self.name + '|' + self.url + '|' + str(int(self.enabled)) @@ -127,12 +126,9 @@ class TorrentRssProvider(generic.TorrentProvider): def getURL(self, url, post_data=None, headers=None): - if not self.session: - self.session = requests.Session() - try: url = urljoin(url, urlparse(url).path.replace('//', '/')) - response = self.session.get(url, verify=False) + response = requests.get(url, verify=False) except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError), e: logger.log(u"Error loading " + self.name + " URL: " + ex(e), logger.ERROR) return None diff --git a/sickbeard/providers/thepiratebay.py b/sickbeard/providers/thepiratebay.py index 3f6c1548..5c540e40 100644 --- a/sickbeard/providers/thepiratebay.py +++ b/sickbeard/providers/thepiratebay.py @@ -307,8 +307,6 @@ class ThePirateBayProvider(generic.TorrentProvider): if self.proxy.isEnabled(): headers.update({'referer': self.proxy.getProxyURL()}) - result = None - try: r = requests.get(url, headers=headers) except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError), e: