Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This project adheres to [Semantic Versioning](https://semver.org/).

## Added
- [#3680](https://github.com/plotly/dash/pull/3680) Added `search_order` prop to `Dropdown` to allow users to preserve original option order during search
- Added `csrf_token_name` and `csrf_header_name` config options to allow configuring the CSRF cookie and header names. Fixes [#729](https://github.com/plotly/dash/issues/729)

## Added
- [#3523](https://github.com/plotly/dash/pull/3523) Fall back to background callback function names if source cannot be found
Expand Down
15 changes: 10 additions & 5 deletions dash/dash-renderer/src/actions/api.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ import {JWT_EXPIRED_MESSAGE, STATUS} from '../constants/constants';
/* eslint-disable-next-line no-console */
const logWarningOnce = once(console.warn);

function GET(path, fetchConfig) {
function GET(path, fetchConfig, _body, config) {
return fetch(
path,
mergeDeepRight(fetchConfig, {
method: 'GET',
headers: getCSRFHeader()
headers: getCSRFHeader(config)
})
);
}

function POST(path, fetchConfig, body = {}) {
function POST(path, fetchConfig, body = {}, config) {
return fetch(
path,
mergeDeepRight(fetchConfig, {
method: 'POST',
headers: getCSRFHeader(),
headers: getCSRFHeader(config),
body: body ? JSON.stringify(body) : null
})
);
Expand Down Expand Up @@ -55,7 +55,12 @@ export default function apiThunk(endpoint, method, store, id, body) {
let res;
for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) {
try {
res = await request[method](url, config.fetch, body);
res = await request[method](
url,
config.fetch,
body,
config
);
} catch (e) {
// fetch rejection - this means the request didn't return,
// we don't get here from 400/500 errors, only network
Expand Down
2 changes: 1 addition & 1 deletion dash/dash-renderer/src/actions/callbacks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ function handleServerside(
}

const fetchCallback = () => {
const headers = getCSRFHeader() as any;
const headers = getCSRFHeader(config) as any;
let url = `${urlBase(config)}_dash-update-component`;
let newBody = body;

Expand Down
13 changes: 9 additions & 4 deletions dash/dash-renderer/src/actions/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,16 @@ export function hydrateInitialOutputs() {
/* eslint-disable-next-line no-console */
const logWarningOnce = once(console.warn);

export function getCSRFHeader() {
export function getCSRFHeader(config) {
try {
return {
'X-CSRFToken': cookie.parse(document.cookie)._csrf_token
};
const tokenName = (config && config.csrf_token_name) || '_csrf_token';
const headerName = (config && config.csrf_header_name) || 'X-CSRFToken';
const cookies = cookie.parse(document.cookie);
const token = cookies[tokenName];
if (!token) {
return {};
}
return {[headerName]: token};
} catch (e) {
logWarningOnce(e);
return {};
Expand Down
2 changes: 2 additions & 0 deletions dash/dash-renderer/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ export type DashConfig = {
serve_locally?: boolean;
plotlyjs_url?: string;
validate_callbacks: boolean;
csrf_token_name?: string;
csrf_header_name?: string;
};

export default function getConfigFromDOM(): DashConfig {
Expand Down
22 changes: 22 additions & 0 deletions dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,15 @@ class Dash(ObsoleteChecker):
:param health_endpoint: Path for the health check endpoint. Set to None to
disable the health endpoint. Default is None.
:type health_endpoint: string or None

:param csrf_token_name: Name of the cookie to read the CSRF token from.
Default ``'_csrf_token'``. Set this to match the CSRF cookie name
used by your server framework (e.g. ``'csrftoken'`` for Django).
:type csrf_token_name: string

:param csrf_header_name: Name of the HTTP header to send the CSRF token in.
Default ``'X-CSRFToken'``.
:type csrf_header_name: string
"""

_plotlyjs_url: str
Expand Down Expand Up @@ -472,6 +481,8 @@ def __init__( # pylint: disable=too-many-statements
on_error: Optional[Callable[[Exception], Any]] = None,
use_async: Optional[bool] = None,
health_endpoint: Optional[str] = None,
csrf_token_name: str = "_csrf_token",
csrf_header_name: str = "X-CSRFToken",
**obsolete,
):

Expand All @@ -492,6 +503,11 @@ def __init__( # pylint: disable=too-many-statements

_validate.check_obsolete(obsolete)

if not csrf_token_name or not csrf_token_name.strip():
raise ValueError("csrf_token_name must be a non-empty string")
if not csrf_header_name or not csrf_header_name.strip():
raise ValueError("csrf_header_name must be a non-empty string")

caller_name: str = name if name is not None else get_caller_name()

# We have 3 cases: server is either True (we create the server), False
Expand Down Expand Up @@ -545,6 +561,8 @@ def __init__( # pylint: disable=too-many-statements
description=description,
health_endpoint=health_endpoint,
hide_all_callbacks=False,
csrf_token_name=csrf_token_name,
csrf_header_name=csrf_header_name,
)
self.config.set_read_only(
[
Expand All @@ -555,6 +573,8 @@ def __init__( # pylint: disable=too-many-statements
"serve_locally",
"compress",
"pages_folder",
"csrf_token_name",
"csrf_header_name",
],
"Read-only: can only be set in the Dash constructor",
)
Expand Down Expand Up @@ -938,6 +958,8 @@ def _config(self):
"ddk_version": ddk_version,
"plotly_version": plotly_version,
"validate_callbacks": self._dev_tools.validate_callbacks,
"csrf_token_name": self.config.csrf_token_name,
"csrf_header_name": self.config.csrf_header_name,
}
if self._plotly_cloud is None:
if os.getenv("DASH_ENTERPRISE_ENV") == "WORKSPACE":
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,47 @@ def test_debug_mode_enable_dev_tools(empty_environ, debug_env, debug, expected):
def test_missing_flask_compress_raises():
with pytest.raises(ImportError):
Dash(compress=True)


def test_csrf_config_defaults():
app = Dash()
assert app.config.csrf_token_name == "_csrf_token"
assert app.config.csrf_header_name == "X-CSRFToken"

config = app._config()
assert config["csrf_token_name"] == "_csrf_token"
assert config["csrf_header_name"] == "X-CSRFToken"


def test_csrf_config_custom():
app = Dash(csrf_token_name="csrftoken", csrf_header_name="X-CSRF-Token")
assert app.config.csrf_token_name == "csrftoken"
assert app.config.csrf_header_name == "X-CSRF-Token"

config = app._config()
assert config["csrf_token_name"] == "csrftoken"
assert config["csrf_header_name"] == "X-CSRF-Token"


def test_csrf_config_in_index():
app = Dash(csrf_token_name="csrftoken")
config_html = app._generate_config_html()
assert '"csrf_token_name":"csrftoken"' in config_html
assert '"csrf_header_name":"X-CSRFToken"' in config_html


@pytest.mark.parametrize(
"token_name, header_name",
[("", "X-CSRFToken"), ("csrftoken", ""), (" ", "X-CSRFToken")],
)
def test_csrf_config_validation(token_name, header_name):
with pytest.raises(ValueError):
Dash(csrf_token_name=token_name, csrf_header_name=header_name)


def test_csrf_config_read_only():
app = Dash()
with pytest.raises(AttributeError):
app.config.csrf_token_name = "something_else"
with pytest.raises(AttributeError):
app.config.csrf_header_name = "something_else"