diff --git a/CHANGELOG.md b/CHANGELOG.md index a78eabf786..edaccf29c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). ## Added - [#3680](https://github.com/plotly/dash/pull/3680) Added `search_order` prop to `Dropdown` to allow users to preserve original option order during search +- Added `csrf_token_name` and `csrf_header_name` config options to allow configuring the CSRF cookie and header names. Fixes [#729](https://github.com/plotly/dash/issues/729) ## Added - [#3523](https://github.com/plotly/dash/pull/3523) Fall back to background callback function names if source cannot be found diff --git a/dash/dash-renderer/src/actions/api.js b/dash/dash-renderer/src/actions/api.js index 12fa3e84bb..507f33c6f0 100644 --- a/dash/dash-renderer/src/actions/api.js +++ b/dash/dash-renderer/src/actions/api.js @@ -7,22 +7,22 @@ import {JWT_EXPIRED_MESSAGE, STATUS} from '../constants/constants'; /* eslint-disable-next-line no-console */ const logWarningOnce = once(console.warn); -function GET(path, fetchConfig) { +function GET(path, fetchConfig, _body, config) { return fetch( path, mergeDeepRight(fetchConfig, { method: 'GET', - headers: getCSRFHeader() + headers: getCSRFHeader(config) }) ); } -function POST(path, fetchConfig, body = {}) { +function POST(path, fetchConfig, body = {}, config) { return fetch( path, mergeDeepRight(fetchConfig, { method: 'POST', - headers: getCSRFHeader(), + headers: getCSRFHeader(config), body: body ? JSON.stringify(body) : null }) ); @@ -55,7 +55,12 @@ export default function apiThunk(endpoint, method, store, id, body) { let res; for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { try { - res = await request[method](url, config.fetch, body); + res = await request[method]( + url, + config.fetch, + body, + config + ); } catch (e) { // fetch rejection - this means the request didn't return, // we don't get here from 400/500 errors, only network diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 37aab3f194..206839cd07 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -483,7 +483,7 @@ function handleServerside( } const fetchCallback = () => { - const headers = getCSRFHeader() as any; + const headers = getCSRFHeader(config) as any; let url = `${urlBase(config)}_dash-update-component`; let newBody = body; diff --git a/dash/dash-renderer/src/actions/index.js b/dash/dash-renderer/src/actions/index.js index 7c92d17afc..6169c4f65e 100644 --- a/dash/dash-renderer/src/actions/index.js +++ b/dash/dash-renderer/src/actions/index.js @@ -61,11 +61,16 @@ export function hydrateInitialOutputs() { /* eslint-disable-next-line no-console */ const logWarningOnce = once(console.warn); -export function getCSRFHeader() { +export function getCSRFHeader(config) { try { - return { - 'X-CSRFToken': cookie.parse(document.cookie)._csrf_token - }; + const tokenName = (config && config.csrf_token_name) || '_csrf_token'; + const headerName = (config && config.csrf_header_name) || 'X-CSRFToken'; + const cookies = cookie.parse(document.cookie); + const token = cookies[tokenName]; + if (!token) { + return {}; + } + return {[headerName]: token}; } catch (e) { logWarningOnce(e); return {}; diff --git a/dash/dash-renderer/src/config.ts b/dash/dash-renderer/src/config.ts index ac18678364..d7f16beda8 100644 --- a/dash/dash-renderer/src/config.ts +++ b/dash/dash-renderer/src/config.ts @@ -22,6 +22,8 @@ export type DashConfig = { serve_locally?: boolean; plotlyjs_url?: string; validate_callbacks: boolean; + csrf_token_name?: string; + csrf_header_name?: string; }; export default function getConfigFromDOM(): DashConfig { diff --git a/dash/dash.py b/dash/dash.py index 122cf54dd6..340c112569 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -421,6 +421,15 @@ class Dash(ObsoleteChecker): :param health_endpoint: Path for the health check endpoint. Set to None to disable the health endpoint. Default is None. :type health_endpoint: string or None + + :param csrf_token_name: Name of the cookie to read the CSRF token from. + Default ``'_csrf_token'``. Set this to match the CSRF cookie name + used by your server framework (e.g. ``'csrftoken'`` for Django). + :type csrf_token_name: string + + :param csrf_header_name: Name of the HTTP header to send the CSRF token in. + Default ``'X-CSRFToken'``. + :type csrf_header_name: string """ _plotlyjs_url: str @@ -472,6 +481,8 @@ def __init__( # pylint: disable=too-many-statements on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, health_endpoint: Optional[str] = None, + csrf_token_name: str = "_csrf_token", + csrf_header_name: str = "X-CSRFToken", **obsolete, ): @@ -492,6 +503,11 @@ def __init__( # pylint: disable=too-many-statements _validate.check_obsolete(obsolete) + if not csrf_token_name or not csrf_token_name.strip(): + raise ValueError("csrf_token_name must be a non-empty string") + if not csrf_header_name or not csrf_header_name.strip(): + raise ValueError("csrf_header_name must be a non-empty string") + caller_name: str = name if name is not None else get_caller_name() # We have 3 cases: server is either True (we create the server), False @@ -545,6 +561,8 @@ def __init__( # pylint: disable=too-many-statements description=description, health_endpoint=health_endpoint, hide_all_callbacks=False, + csrf_token_name=csrf_token_name, + csrf_header_name=csrf_header_name, ) self.config.set_read_only( [ @@ -555,6 +573,8 @@ def __init__( # pylint: disable=too-many-statements "serve_locally", "compress", "pages_folder", + "csrf_token_name", + "csrf_header_name", ], "Read-only: can only be set in the Dash constructor", ) @@ -938,6 +958,8 @@ def _config(self): "ddk_version": ddk_version, "plotly_version": plotly_version, "validate_callbacks": self._dev_tools.validate_callbacks, + "csrf_token_name": self.config.csrf_token_name, + "csrf_header_name": self.config.csrf_header_name, } if self._plotly_cloud is None: if os.getenv("DASH_ENTERPRISE_ENV") == "WORKSPACE": diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index ca026d2211..415eb61540 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -480,3 +480,47 @@ def test_debug_mode_enable_dev_tools(empty_environ, debug_env, debug, expected): def test_missing_flask_compress_raises(): with pytest.raises(ImportError): Dash(compress=True) + + +def test_csrf_config_defaults(): + app = Dash() + assert app.config.csrf_token_name == "_csrf_token" + assert app.config.csrf_header_name == "X-CSRFToken" + + config = app._config() + assert config["csrf_token_name"] == "_csrf_token" + assert config["csrf_header_name"] == "X-CSRFToken" + + +def test_csrf_config_custom(): + app = Dash(csrf_token_name="csrftoken", csrf_header_name="X-CSRF-Token") + assert app.config.csrf_token_name == "csrftoken" + assert app.config.csrf_header_name == "X-CSRF-Token" + + config = app._config() + assert config["csrf_token_name"] == "csrftoken" + assert config["csrf_header_name"] == "X-CSRF-Token" + + +def test_csrf_config_in_index(): + app = Dash(csrf_token_name="csrftoken") + config_html = app._generate_config_html() + assert '"csrf_token_name":"csrftoken"' in config_html + assert '"csrf_header_name":"X-CSRFToken"' in config_html + + +@pytest.mark.parametrize( + "token_name, header_name", + [("", "X-CSRFToken"), ("csrftoken", ""), (" ", "X-CSRFToken")], +) +def test_csrf_config_validation(token_name, header_name): + with pytest.raises(ValueError): + Dash(csrf_token_name=token_name, csrf_header_name=header_name) + + +def test_csrf_config_read_only(): + app = Dash() + with pytest.raises(AttributeError): + app.config.csrf_token_name = "something_else" + with pytest.raises(AttributeError): + app.config.csrf_header_name = "something_else"