diff --git a/dailyreleases/cache.py b/dailyreleases/cache.py index 5785924..103eb28 100644 --- a/dailyreleases/cache.py +++ b/dailyreleases/cache.py @@ -1,11 +1,12 @@ import json import logging import sqlite3 +import time import urllib.parse import urllib.request +from collections import defaultdict from datetime import timedelta, datetime -from http.client import HTTPResponse -from typing import Mapping +from typing import Mapping, Optional from urllib.request import Request, urlopen from .config import DATA_DIR, CONFIG @@ -14,11 +15,8 @@ logger = logging.getLogger(__name__) class Response: - def __init__(self, response: HTTPResponse = None, bytes: bytes = None) -> None: - if response is not None: - self.bytes = response.read() - else: - self.bytes = bytes + def __init__(self, bytes: bytes = None) -> None: + self.bytes = bytes self.text = self.bytes.decode() # TODO: Detect encoding @property @@ -30,34 +28,56 @@ connection = sqlite3.connect(DATA_DIR.joinpath("cache.sqlite")) connection.row_factory = sqlite3.Row # allow accessing rows by index and case-insensitively by name connection.text_factory = bytes # do not try to decode bytes as utf-8 strings -cache_time = timedelta(seconds=CONFIG["web"].getint("cache_time")) -logger.info("Requests cache time is %s", cache_time) +CACHE_TIME = timedelta(seconds=CONFIG["web"].getint("cache_time")) +logger.info("Default cache time is %s", CACHE_TIME) -connection.executescript( - f""" +connection.execute( + """ CREATE TABLE IF NOT EXISTS requests (id INTEGER PRIMARY KEY, url TEXT UNIQUE NOT NULL, response BLOB NOT NULL, - timestamp INTEGER NOT NULL); - - DELETE FROM requests - WHERE timestamp < {(datetime.utcnow() - cache_time).timestamp()}; - - VACUUM; + expire INTEGER NOT NULL); """ ) -def get(url: str, params: Mapping = None, *args, **kwargs) -> Response: +def clean(): + connection.execute( + """ + DELETE FROM requests + WHERE expire < :expire; + """, { + "expire": datetime.utcnow().timestamp(), + } + ) + 
connection.execute("VACUUM;") + connection.commit() + + + last_request = defaultdict(float) + + + def get(url: str, params: Mapping = None, cache_time: timedelta = CACHE_TIME, + ratelimit: Optional[float] = 1, *args, **kwargs) -> Response: + """ + Sends a GET request, caching the result for cache_time. If 'ratelimit' is supplied, requests are rate limited at the + host-level to this number of requests per second. + + We're saving requests' expire instead of the timestamp at which it was received to allow for varying cache times; if + we were saving the timestamp, clean() wouldn't know when to delete unless the cache time was always the same. This, + however, also means that the first call determines for how long subsequent calls will consider a request fresh. + """ if params is not None: url += "?" + urllib.parse.urlencode(params) request = Request(url, *args, **kwargs) request.add_header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:65.0) Gecko/20100101 Firefox/65.0") + #logger.debug("Get %s", url) + row = connection.execute( """ - SELECT response, timestamp + SELECT response, expire FROM requests WHERE url = :url; """, { @@ -65,38 +85,29 @@ def get(url: str, params: Mapping = None, *args, **kwargs) -> Response: } ).fetchone() - # Cache miss - if row is None: - response = Response(urlopen(request)) - connection.execute( - """ - INSERT INTO requests(url, response, timestamp) - VALUES (:url, :response, :timestamp); - """, { - "url": url, - "response": response.bytes, - "timestamp": datetime.utcnow().timestamp() - } - ) - connection.commit() - return response + if row is not None and datetime.fromtimestamp(row["expire"]) > datetime.utcnow(): + #logger.debug("Cache hit: %s", url) + return Response(row["response"]) - # Cached and fresh - if datetime.fromtimestamp(row["timestamp"]) > datetime.utcnow() - cache_time: - return Response(bytes=row["response"]) + #logger.debug("Cache miss: %s", url) + if ratelimit is not None: + min_interval = 1 / ratelimit + elapsed = 
time.time() - last_request[request.host] + wait = min_interval - elapsed + if wait > 0: + #logger.debug("Rate-limited for %ss", round(wait, 2)) + time.sleep(wait) - # Cached but stale - response = Response(urlopen(request)) + response = Response(urlopen(request).read()) + last_request[request.host] = time.time() connection.execute( """ - UPDATE requests - SET response = :response, - timestamp = :timestamp - WHERE url = :url; + INSERT OR REPLACE INTO requests(url, response, expire) + VALUES (:url, :response, :expire); """, { "url": url, "response": response.bytes, - "timestamp": datetime.utcnow().timestamp() + "expire": (datetime.utcnow() + cache_time).timestamp() } ) connection.commit() diff --git a/dailyreleases/generation.py b/dailyreleases/generation.py index e0a4dbf..29d4081 100644 --- a/dailyreleases/generation.py +++ b/dailyreleases/generation.py @@ -7,7 +7,7 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Set -from . import util, reddit, predbs, parsing +from . import cache, util, reddit, predbs, parsing from .config import CONFIG, DATA_DIR from .parsing import Releases, Release, ReleaseType @@ -144,6 +144,7 @@ def generate(post=False, pm_recipients=None) -> None: for recipient in pm_recipients: reddit.send_pm(recipient, title, msg) + cache.clean() logger.info("Execution took %s seconds", int(time.time() - start_time)) logger.info("-------------------------------------------------------------------------------------------------")