@classmethod
def _make_cache_request(cls, path: str, domain: str,
                        scheme: str = 'https', port: str = '443') -> HttpRequest:
    """ Build a synthetic GET request that Django's cache utilities should
    key the same way the cache middleware keyed the original request.

    :param path: URL path, optionally followed by ``?query-string``.
    :param domain: Host for HTTP_HOST / SERVER_NAME (may carry a port).
        HTTP_HOST is set because SERVER_NAME alone makes get_host()
        append the port.
    :param scheme: Scheme the original request was served under.
    :param port: Value for SERVER_PORT.
    :return: The populated request; its absolute URI is meant to be
        ``<scheme>://<domain><path>``.
    """

    class _SchemeAwareRequest(HttpRequest):
        """ HttpRequest whose scheme follows META['wsgi.url_scheme']. """

        def _get_scheme(self):
            # A plain HttpRequest always answers 'http' here; only
            # WSGIRequest consults wsgi.url_scheme. Without this override
            # every synthetic key would be built for http:// no matter
            # what is stored in META. Read lazily so callers that adjust
            # META after construction are honoured too.
            return self.META.get('wsgi.url_scheme', 'http')

    request = _SchemeAwareRequest()
    request.method = 'GET'
    # Everything after the first '?' is the query string ('' when absent).
    request.path, _, query_string = path.partition('?')
    request.META.update({
        'QUERY_STRING': query_string,
        'HTTP_HOST': domain,
        'wsgi.url_scheme': scheme,
        'SERVER_NAME': domain,
        'SERVER_PORT': port,
    })
    return request

@classmethod
def _get_purge_domains(cls) -> list[dict]:
    """ Return the host/scheme/port combinations whose cache entries
    should be purged.

    Each item is a dict with the keys 'host', 'scheme' and 'port'.
    The list holds the current request's host (if a request is in
    flight), then the Site domain and its www. variant. The current
    request's host covers dev servers (e.g. localhost:8000) where the
    Site domain would not match the cached keys.

    Site hosts are listed under both https and http: depending on where
    TLS is terminated, the application may have seen either scheme when
    the response was cached, and the key differs between the two.
    """
    default_ports = {'https': '443', 'http': '80'}
    domains = []
    seen = set()

    def add(host, scheme, port):
        # De-duplicate on (host, scheme); the port does not take part in
        # the key while HTTP_HOST is set on the synthetic request.
        if (host, scheme) not in seen:
            seen.add((host, scheme))
            domains.append({'host': host, 'scheme': scheme, 'port': port})

    # Use the real request's host if available (set by
    # CurrentRequestMiddleware); this mirrors what the cache middleware saw.
    current_request = get_current_request()
    if current_request:
        scheme = current_request.scheme
        add(
            current_request.get_host(),  # includes port if non-standard
            scheme,
            current_request.META.get('SERVER_PORT', default_ports.get(scheme, '80')),
        )

    # Always also try the Site domain (production) and its www. variant.
    site_domain = Site.objects.get_current().domain
    site_hosts = [site_domain]
    if not site_domain.startswith("www."):
        site_hosts.append(f"www.{site_domain}")
    for host in site_hosts:
        for scheme in ('https', 'http'):
            add(host, scheme, default_ports[scheme])

    return domains

@classmethod
def _purge_django_cache(cls, paths: list[str]) -> None:
    """ Delete the cached responses for the given paths from Django's cache.

    For every path and every entry from _get_purge_domains() a synthetic
    request is built and its GET and HEAD cache keys are looked up; the
    keys that exist are deleted in one call.

    NOTE(review): entries whose key also depends on request headers named
    in the response's Vary header cannot be reproduced from a header-less
    synthetic request and would be missed; confirm whether the cached
    views set Vary.
    NOTE(review): get_cache_key() reads from the cache named by
    CACHE_MIDDLEWARE_ALIAS while deletion uses the default cache; this
    presumes they are the same backend.

    :param paths: URL paths (optionally with query strings), not full URLs.
    """
    domains = cls._get_purge_domains()
    stale_keys = set()

    for path in paths:
        for d in domains:
            request = cls._make_cache_request(
                path, d['host'], d['scheme'], d['port'])
            try:
                # The cache middleware stores HEAD responses under their
                # own key, so look up both methods.
                keys = [get_cache_key(request, method=method)
                        for method in ('GET', 'HEAD')]
            except DisallowedHost:
                # Host isn't in ALLOWED_HOSTS, so there can't be
                # cached responses for it; skip.
                continue
            stale_keys.update(key for key in keys if key)

    if stale_keys:
        cache.delete_many(list(stale_keys))
if not cls._is_api_enabled(): diff --git a/rcvis/settings.py b/rcvis/settings.py index 8c9ae643..6275da49 100644 --- a/rcvis/settings.py +++ b/rcvis/settings.py @@ -81,6 +81,9 @@ 'django.contrib.sessions.middleware.SessionMiddleware', + # Store current request in thread-local for cache key construction in model.save() + 'visualizer.middleware.CurrentRequestMiddleware', + # Order of the next 3 is important 'visualizer.middleware.UpdateCacheWithoutMaxAgeMiddleware', 'django.middleware.common.CommonMiddleware', diff --git a/visualizer/middleware.py b/visualizer/middleware.py index 4a650d72..e5191508 100644 --- a/visualizer/middleware.py +++ b/visualizer/middleware.py @@ -9,10 +9,37 @@ the browser not to revalidate for 10 minutes). This subclass calls super() to get the server-side caching, then strips max-age from any response that already has no-cache, so the browser always revalidates via If-Modified-Since. + +CurrentRequestMiddleware: stores the current request in a thread-local so +that model.save() can access the real request host for cache key construction +(needed by _purge_django_cache to match the keys created by the cache +middleware, which vary by host — e.g. localhost:8000 vs example.com). """ +import threading + from django.middleware.cache import UpdateCacheMiddleware from django.utils.cache import cc_delim_re +from django.utils.deprecation import MiddlewareMixin + +# Thread-local storage for the current request. 
# One slot per thread; holds the request that thread is currently serving.
_thread_locals = threading.local()


def _bind_request(value):
    """Point this thread's slot at ``value`` (a request, or None to clear)."""
    _thread_locals.request = value


def get_current_request():
    """Return the request bound to this thread, or None when there is none
    (never set on this thread, or already cleared after the response)."""
    try:
        return _thread_locals.request
    except AttributeError:
        return None


class CurrentRequestMiddleware(MiddlewareMixin):
    """Expose the in-flight request to model code through a thread-local."""

    def process_request(self, request):
        # Remember the request for the duration of this cycle.
        _bind_request(request)

    def process_response(self, request, response):
        # Drop the reference once a response is on its way out.
        _bind_request(None)
        return response
def test_purge_vis_cache_clears_all_cached_urls(self):
    """
    purge_vis_cache should delete all cached entries for a slug,
    including query-string variants like ?vistype=sankey.
    NOTE: purge_vis_cache has a hardcoded list of known URL patterns.
    If new vistypes or URL patterns are added, that list must be updated
    or those entries won't be purged (this is a known fragility of the
    surgical approach vs cache.clear()).
    """
    with open(filenames.ONE_ROUND, 'r', encoding='utf-8') as upload:
        self.client.post('/upload.html', {'jsonFile': upload})
    slug = TestHelpers.get_latest_upload().slug

    # Representative sample of what purge_vis_cache should clear:
    # base view, embedded view, and two query-string variants.
    embedded = reverse('visualizeEmbedded', args=(slug,))
    paths = [
        reverse('visualize', args=(slug,)),
        embedded,
        embedded + '?vistype=sankey',
        embedded + '?vistype=barchart-interactive',
    ]

    # Serve from example.com (the Site domain) so the stored keys are
    # ones _get_purge_domains will try to purge.
    allowed = ['example.com', 'www.example.com', 'testserver']
    with self.settings(ALLOWED_HOSTS=allowed):
        # Warm the cache for every path, remembering each key.
        cache_keys = []
        for path in paths:
            response = self.client.get(path, SERVER_NAME='example.com')
            self.assertEqual(response.status_code, 200)
            key = get_cache_key(response.wsgi_request)
            self.assertIsNotNone(key)
            self.assertIsNotNone(cache.get(key))
            cache_keys.append(key)

        # Purge everything for this slug
        CloudflareAPI.purge_vis_cache(slug)

        # Every warmed entry must now be gone.
        for path, key in zip(paths, cache_keys):
            self.assertIsNone(cache.get(key),
                              f"Cache entry for {path} was not purged")
def test_purge_django_cache_tries_www_variant(self):
    """
    _purge_django_cache should also try the www. variant of the Site
    domain, in case the cache was populated with that host.

    This is checked through _get_purge_domains, which supplies the hosts
    _purge_django_cache iterates over. No upload is needed: the host list
    does not depend on any visualization, and outside a request cycle it
    is derived from the Site domain alone.
    """
    hosts = {d['host'] for d in CloudflareAPI._get_purge_domains()}

    # Both the bare Site domain and its www. variant must be covered.
    self.assertIn('example.com', hosts)
    self.assertIn('www.example.com', hosts)