Add support for rate-limiting and varying cache-time in cache.get().
This commit is contained in:
parent
c99da77155
commit
a46d260782
2 changed files with 57 additions and 45 deletions
|
@ -1,11 +1,12 @@
|
|||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta, datetime
|
||||
from http.client import HTTPResponse
|
||||
from typing import Mapping
|
||||
from typing import Mapping, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from .config import DATA_DIR, CONFIG
|
||||
|
@ -14,11 +15,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Response:
|
||||
def __init__(self, response: HTTPResponse = None, bytes: bytes = None) -> None:
|
||||
if response is not None:
|
||||
self.bytes = response.read()
|
||||
else:
|
||||
self.bytes = bytes
|
||||
def __init__(self, bytes: bytes = None) -> None:
|
||||
self.bytes = bytes
|
||||
self.text = self.bytes.decode() # TODO: Detect encoding
|
||||
|
||||
@property
|
||||
|
@ -30,34 +28,56 @@ connection = sqlite3.connect(DATA_DIR.joinpath("cache.sqlite"))
|
|||
connection.row_factory = sqlite3.Row # allow accessing rows by index and case-insensitively by name
|
||||
connection.text_factory = bytes # do not try to decode bytes as utf-8 strings
|
||||
|
||||
cache_time = timedelta(seconds=CONFIG["web"].getint("cache_time"))
|
||||
logger.info("Requests cache time is %s", cache_time)
|
||||
CACHE_TIME = timedelta(seconds=CONFIG["web"].getint("cache_time"))
|
||||
logger.info("Default cache time is %s", CACHE_TIME)
|
||||
|
||||
connection.executescript(
|
||||
f"""
|
||||
connection.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
requests (id INTEGER PRIMARY KEY,
|
||||
url TEXT UNIQUE NOT NULL,
|
||||
response BLOB NOT NULL,
|
||||
timestamp INTEGER NOT NULL);
|
||||
|
||||
DELETE FROM requests
|
||||
WHERE timestamp < {(datetime.utcnow() - cache_time).timestamp()};
|
||||
|
||||
VACUUM;
|
||||
expire INTEGER NOT NULL);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get(url: str, params: Mapping = None, *args, **kwargs) -> Response:
|
||||
def clean():
|
||||
connection.execute(
|
||||
"""
|
||||
DELETE FROM requests
|
||||
WHERE expire < :expire;
|
||||
""", {
|
||||
"expire": datetime.utcnow().timestamp(),
|
||||
}
|
||||
)
|
||||
connection.execute("VACUUM;")
|
||||
connection.commit()
|
||||
|
||||
|
||||
last_request = defaultdict(float)
|
||||
|
||||
|
||||
def get(url: str, params: Mapping = None, cache_time: timedelta = CACHE_TIME,
|
||||
ratelimit: Optional[float] = 1, *args, **kwargs) -> Response:
|
||||
"""
|
||||
Sends a GET request, caching the result for cache_time. If 'ratelimit' is supplied, requests are rate limited at the
|
||||
host-level to this number of requests per second.
|
||||
|
||||
We're saving requests' expire instead of the timestamp it was received to allow for varying cache times; if we were
|
||||
saving the timestamp, clean() wouldn't know when to delete unless the cache time was always the same. This, however,
|
||||
also means that the first call determines for how longer subsequent calls will consider a request fresh.
|
||||
"""
|
||||
if params is not None:
|
||||
url += "?" + urllib.parse.urlencode(params)
|
||||
request = Request(url, *args, **kwargs)
|
||||
request.add_header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:65.0) Gecko/20100101 Firefox/65.0")
|
||||
|
||||
#logger.debug("Get %s", url)
|
||||
|
||||
row = connection.execute(
|
||||
"""
|
||||
SELECT response, timestamp
|
||||
SELECT response, expire
|
||||
FROM requests
|
||||
WHERE url = :url;
|
||||
""", {
|
||||
|
@ -65,38 +85,29 @@ def get(url: str, params: Mapping = None, *args, **kwargs) -> Response:
|
|||
}
|
||||
).fetchone()
|
||||
|
||||
# Cache miss
|
||||
if row is None:
|
||||
response = Response(urlopen(request))
|
||||
connection.execute(
|
||||
"""
|
||||
INSERT INTO requests(url, response, timestamp)
|
||||
VALUES (:url, :response, :timestamp);
|
||||
""", {
|
||||
"url": url,
|
||||
"response": response.bytes,
|
||||
"timestamp": datetime.utcnow().timestamp()
|
||||
}
|
||||
)
|
||||
connection.commit()
|
||||
return response
|
||||
if row is not None and datetime.fromtimestamp(row["expire"]) > datetime.utcnow():
|
||||
#logger.debug("Cache hit: %s", url)
|
||||
return Response(row["response"])
|
||||
|
||||
# Cached and fresh
|
||||
if datetime.fromtimestamp(row["timestamp"]) > datetime.utcnow() - cache_time:
|
||||
return Response(bytes=row["response"])
|
||||
#logger.debug("Cache miss: %s", url)
|
||||
if ratelimit is not None:
|
||||
min_interval = 1 / ratelimit
|
||||
elapsed = time.time() - last_request[request.host]
|
||||
wait = min_interval - elapsed
|
||||
if wait > 0:
|
||||
#logger.debug("Rate-limited for %ss", round(wait, 2))
|
||||
time.sleep(wait)
|
||||
|
||||
# Cached but stale
|
||||
response = Response(urlopen(request))
|
||||
response = Response(urlopen(request).read())
|
||||
last_request[request.host] = time.time()
|
||||
connection.execute(
|
||||
"""
|
||||
UPDATE requests
|
||||
SET response = :response,
|
||||
timestamp = :timestamp
|
||||
WHERE url = :url;
|
||||
INSERT OR REPLACE INTO requests(url, response, expire)
|
||||
VALUES (:url, :response, :expire);
|
||||
""", {
|
||||
"url": url,
|
||||
"response": response.bytes,
|
||||
"timestamp": datetime.utcnow().timestamp()
|
||||
"expire": (datetime.utcnow() + cache_time).timestamp()
|
||||
}
|
||||
)
|
||||
connection.commit()
|
||||
|
|
|
@ -7,7 +7,7 @@ from collections import defaultdict
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Set
|
||||
|
||||
from . import util, reddit, predbs, parsing
|
||||
from . import cache, util, reddit, predbs, parsing
|
||||
from .config import CONFIG, DATA_DIR
|
||||
from .parsing import Releases, Release, ReleaseType
|
||||
|
||||
|
@ -144,6 +144,7 @@ def generate(post=False, pm_recipients=None) -> None:
|
|||
for recipient in pm_recipients:
|
||||
reddit.send_pm(recipient, title, msg)
|
||||
|
||||
cache.clean()
|
||||
logger.info("Execution took %s seconds", int(time.time() - start_time))
|
||||
logger.info("-------------------------------------------------------------------------------------------------")
|
||||
|
||||
|
|
Reference in a new issue