diff --git a/bookwyrm/tests/test_utils.py b/bookwyrm/tests/test_utils.py index 60f3185e6..61ed2262c 100644 --- a/bookwyrm/tests/test_utils.py +++ b/bookwyrm/tests/test_utils.py @@ -22,8 +22,8 @@ class TestUtils(TestCase): def test_invalid_url_domain(self): """Check with an invalid URL""" - self.assertEqual( - validate_url_domain("https://up-to-no-good.tld/bad-actor.exe"), "/" + self.assertIsNone( + validate_url_domain("https://up-to-no-good.tld/bad-actor.exe") ) def test_default_url_domain(self): diff --git a/bookwyrm/utils/validate.py b/bookwyrm/utils/validate.py index 89aee4782..b91add3ad 100644 --- a/bookwyrm/utils/validate.py +++ b/bookwyrm/utils/validate.py @@ -2,12 +2,12 @@ from bookwyrm.settings import DOMAIN, USE_HTTPS -def validate_url_domain(url, default="/"): +def validate_url_domain(url): """Basic check that the URL starts with the instance domain name""" if not url: - return default + return None - if url in ("/", default): + if url == "/": return url protocol = "https://" if USE_HTTPS else "http://" @@ -16,4 +16,4 @@ def validate_url_domain(url, default="/"): if url.startswith(origin): return url - return default + return None diff --git a/bookwyrm/views/helpers.py b/bookwyrm/views/helpers.py index f89ea0dfe..4f5e00e41 100644 --- a/bookwyrm/views/helpers.py +++ b/bookwyrm/views/helpers.py @@ -16,6 +16,7 @@ from bookwyrm import activitypub, models, settings from bookwyrm.connectors import ConnectorException, get_data from bookwyrm.status import create_generated_note from bookwyrm.utils import regex +from bookwyrm.utils.validate import validate_url_domain # pylint: disable=unnecessary-pass @@ -219,3 +220,15 @@ def maybe_redirect_local_path(request, model): new_path = f"{model.local_path}?{request.GET.urlencode()}" return redirect(new_path, permanent=True) + + +def redirect_to_referer(request, *args): + """Redirect to the referrer, if it's in our domain, with get params""" + # make sure the refer is part of this instance + validated = validate_url_domain(request.META.get("HTTP_REFERER")) + + if validated: + return redirect(validated) + + # if not, use the args passed you'd normally pass to redirect() + return redirect(*args or "/")