diff --git a/@plotly/dash-websocket-worker/README.md b/@plotly/dash-websocket-worker/README.md new file mode 100644 index 0000000000..64e37a1987 --- /dev/null +++ b/@plotly/dash-websocket-worker/README.md @@ -0,0 +1,3 @@ +# Dash websocket worker + +Worker for websocket based callbacks. diff --git a/@plotly/dash-websocket-worker/package.json b/@plotly/dash-websocket-worker/package.json new file mode 100644 index 0000000000..619a842380 --- /dev/null +++ b/@plotly/dash-websocket-worker/package.json @@ -0,0 +1,29 @@ +{ + "name": "@plotly/dash-websocket-worker", + "version": "1.0.0", + "description": "SharedWorker for WebSocket-based Dash callbacks", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "scripts": { + "build": "webpack --mode production", + "build:dev": "webpack --mode development", + "watch": "webpack --mode development --watch", + "clean": "rm -rf dist" + }, + "files": [ + "dist" + ], + "keywords": [ + "dash", + "websocket", + "sharedworker" + ], + "author": "Plotly", + "license": "MIT", + "devDependencies": { + "typescript": "^5.0.0", + "webpack": "^5.0.0", + "webpack-cli": "^5.0.0", + "ts-loader": "^9.0.0" + } +} diff --git a/@plotly/dash-websocket-worker/src/MessageRouter.ts b/@plotly/dash-websocket-worker/src/MessageRouter.ts new file mode 100644 index 0000000000..68a9f4bfc2 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/MessageRouter.ts @@ -0,0 +1,207 @@ +import { + WorkerMessageType, + WorkerMessage, + CallbackRequestMessage, + GetPropsResponseMessage, + SetPropsMessage, + GetPropsRequestMessage, + CallbackResponseMessage +} from './types'; + +/** + * Routes messages between renderers (via MessagePorts) and the WebSocket server. + */ +export class MessageRouter { + /** Map of renderer IDs to their MessagePorts */ + private renderers: Map = new Map(); + + /** Callback to send messages to the WebSocket server */ + public sendToServer: ((message: unknown) => void) | null = null; + + /** + * Register a renderer with its MessagePort. + * @param rendererId Unique identifier for the renderer + * @param port The MessagePort for communication + */ + public registerRenderer(rendererId: string, port: MessagePort): void { + this.renderers.set(rendererId, port); + } + + /** + * Unregister a renderer. + * @param rendererId The renderer to unregister + */ + public unregisterRenderer(rendererId: string): void { + this.renderers.delete(rendererId); + } + + /** + * Get the number of connected renderers. + */ + public get rendererCount(): number { + return this.renderers.size; + } + + /** + * Handle a message from a renderer. + * @param rendererId The ID of the renderer that sent the message + * @param message The message from the renderer + */ + public handleRendererMessage(rendererId: string, message: WorkerMessage): void { + switch (message.type) { + case WorkerMessageType.CALLBACK_REQUEST: + this.forwardCallbackRequest(rendererId, message as CallbackRequestMessage); + break; + + case WorkerMessageType.GET_PROPS_RESPONSE: + this.forwardGetPropsResponse(rendererId, message as GetPropsResponseMessage); + break; + + default: + console.warn(`Unknown message type from renderer: ${message.type}`); + } + } + + /** + * Handle a message from the WebSocket server. + * @param message The message from the server + */ + public handleServerMessage(message: unknown): void { + const msg = message as WorkerMessage; + const rendererId = msg.rendererId; + + switch (msg.type) { + case WorkerMessageType.CALLBACK_RESPONSE: + this.forwardToRenderer(rendererId, msg as CallbackResponseMessage); + break; + + case WorkerMessageType.SET_PROPS: + this.forwardSetProps(rendererId, msg as SetPropsMessage); + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + this.forwardGetPropsRequest(rendererId, msg as GetPropsRequestMessage); + break; + + case WorkerMessageType.ERROR: + this.forwardToRenderer(rendererId, msg); + break; + + default: + console.warn(`Unknown message type from server: ${msg.type}`); + } + } + + /** + * Send a message to all connected renderers. + * @param message The message to broadcast + */ + public broadcastToRenderers(message: WorkerMessage): void { + for (const [rendererId, port] of this.renderers) { + try { + port.postMessage(message); + } catch (error) { + // Port may be closed if tab was closed + console.warn(`Failed to send to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + /** + * Send a connected notification to a specific renderer. + * @param rendererId The renderer to notify + */ + public notifyConnected(rendererId: string): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage({ + type: WorkerMessageType.CONNECTED, + rendererId + }); + } catch (error) { + console.warn(`Failed to notify renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + /** + * Send a disconnected notification to all renderers. + * @param reason Optional reason for disconnection + */ + public notifyDisconnected(reason?: string): void { + this.broadcastToRenderers({ + type: WorkerMessageType.DISCONNECTED, + rendererId: '', + payload: { reason } + }); + } + + /** + * Send an error notification to a specific renderer. + * @param rendererId The renderer to notify + * @param message Error message + * @param code Optional error code + */ + public notifyError(rendererId: string, message: string, code?: string): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage({ + type: WorkerMessageType.ERROR, + rendererId, + payload: { message, code } + }); + } catch (error) { + console.warn(`Failed to send error to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + private forwardCallbackRequest(rendererId: string, message: CallbackRequestMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardGetPropsResponse(rendererId: string, message: GetPropsResponseMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardToRenderer(rendererId: string, message: WorkerMessage): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage(message); + } catch (error) { + console.warn(`Failed to forward to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } else { + console.warn(`Renderer ${rendererId} not found for message`); + } + } + + private forwardSetProps(rendererId: string, message: SetPropsMessage): void { + this.forwardToRenderer(rendererId, message); + } + + private forwardGetPropsRequest(rendererId: string, message: GetPropsRequestMessage): void { + this.forwardToRenderer(rendererId, message); + } +} diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts new file mode 100644 index 0000000000..5f11086945 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -0,0 +1,299 @@ +/** + * Configuration options for WebSocket connection. + */ +interface WebSocketConfig { + /** Maximum number of reconnection attempts */ + maxRetries: number; + /** Initial delay between reconnection attempts (ms) */ + initialRetryDelay: number; + /** Maximum delay between reconnection attempts (ms) */ + maxRetryDelay: number; + /** Heartbeat interval (ms) */ + heartbeatInterval: number; + /** Heartbeat timeout (ms) */ + heartbeatTimeout: number; + /** Inactivity timeout (ms) - 0 to disable */ + inactivityTimeout: number; +} + +const DEFAULT_CONFIG: WebSocketConfig = { + maxRetries: 10, + initialRetryDelay: 1000, + maxRetryDelay: 30000, + heartbeatInterval: 30000, + heartbeatTimeout: 10000, + inactivityTimeout: 300000 // 5 minutes default +}; + +/** + * Manages WebSocket connection with automatic reconnection and heartbeat. + */ +export class WebSocketManager { + private ws: WebSocket | null = null; + private serverUrl: string | null = null; + private config: WebSocketConfig; + private retryCount = 0; + private retryTimeout: ReturnType | null = null; + private heartbeatInterval: ReturnType | null = null; + private heartbeatTimeout: ReturnType | null = null; + private lastActivityTime: number = Date.now(); + private messageQueue: string[] = []; + private isConnecting = false; + + /** Callback when connection is established */ + public onOpen: (() => void) | null = null; + /** Callback when connection is closed */ + public onClose: ((reason?: string) => void) | null = null; + /** Callback when a message is received */ + public onMessage: ((data: unknown) => void) | null = null; + /** Callback when an error occurs */ + public onError: ((error: Error) => void) | null = null; + + constructor(config: Partial = {}) { + this.config = { ...DEFAULT_CONFIG, ...config }; + } + + /** + * Update configuration options. + * Only updates the provided options, keeping others unchanged. + * @param config Partial configuration to merge + */ + public setConfig(config: Partial): void { + this.config = { ...this.config, ...config }; + } + + /** + * Connect to the WebSocket server. + * @param serverUrl The WebSocket server URL + */ + public connect(serverUrl: string): void { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + // Already connected + return; + } + + if (this.isConnecting) { + // Connection in progress + return; + } + + this.serverUrl = serverUrl; + this.isConnecting = true; + this.createConnection(); + } + + /** + * Disconnect from the WebSocket server. + */ + public disconnect(): void { + this.cleanup(); + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.close(1000, 'Client disconnect'); + } + this.ws = null; + this.serverUrl = null; + this.retryCount = 0; + } + + /** + * Send a message through the WebSocket connection. + * If not connected, queues the message and triggers reconnection. + * @param message The message to send + */ + public send(message: unknown): void { + const data = JSON.stringify(message); + + // Track activity for non-heartbeat messages + const msgObj = message as { type?: string }; + if (msgObj.type !== 'heartbeat') { + this.lastActivityTime = Date.now(); + } + + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(data); + } else { + // Queue message for when connection is established + this.messageQueue.push(data); + + // Trigger reconnect if we have a server URL but aren't connected/connecting + if (this.serverUrl && !this.isConnecting) { + this.isConnecting = true; + this.createConnection(); + } + } + } + + /** + * Check if the WebSocket is currently connected. + */ + public get isConnected(): boolean { + return this.ws !== null && this.ws.readyState === WebSocket.OPEN; + } + + private createConnection(): void { + if (!this.serverUrl) { + return; + } + + try { + this.ws = new WebSocket(this.serverUrl); + this.ws.onopen = this.handleOpen.bind(this); + this.ws.onclose = this.handleClose.bind(this); + this.ws.onmessage = this.handleMessage.bind(this); + this.ws.onerror = this.handleError.bind(this); + } catch (error) { + this.isConnecting = false; + this.scheduleReconnect(); + } + } + + private handleOpen(): void { + this.isConnecting = false; + this.retryCount = 0; + this.lastActivityTime = Date.now(); + + // Flush queued messages + while (this.messageQueue.length > 0) { + const message = this.messageQueue.shift(); + if (message && this.ws) { + this.ws.send(message); + } + } + + // Start heartbeat (also handles inactivity check) + this.startHeartbeat(); + + if (this.onOpen) { + this.onOpen(); + } + } + + private handleClose(event: CloseEvent): void { + this.isConnecting = false; + this.cleanup(); + + const reason = event.reason || 'Connection closed'; + + if (this.onClose) { + this.onClose(reason); + } + + // Only reconnect if: + // - We haven't explicitly disconnected (code 1000) + // - It's not an inactivity timeout (code 4001) + if (this.serverUrl && event.code !== 1000 && event.code !== 4001) { + this.scheduleReconnect(); + } + } + + private handleMessage(event: MessageEvent): void { + try { + const data = JSON.parse(event.data); + + // Handle heartbeat acknowledgment - does NOT count as activity + if (data.type === 'heartbeat_ack') { + this.clearHeartbeatTimeout(); + return; + } + + // Track activity for actual callback messages + this.lastActivityTime = Date.now(); + + if (this.onMessage) { + this.onMessage(data); + } + } catch (error) { + if (this.onError) { + this.onError(new Error('Failed to parse message')); + } + } + } + + private handleError(): void { + this.isConnecting = false; + // WebSocket error events don't contain useful information + // The close event will follow with more details + } + + private scheduleReconnect(): void { + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + } + + if (this.retryCount >= this.config.maxRetries) { + if (this.onError) { + this.onError(new Error('Max reconnection attempts reached')); + } + return; + } + + // Exponential backoff with jitter + const delay = Math.min( + this.config.initialRetryDelay * Math.pow(2, this.retryCount) + + Math.random() * 1000, + this.config.maxRetryDelay + ); + + this.retryCount++; + + this.retryTimeout = setTimeout(() => { + this.createConnection(); + }, delay); + } + + private startHeartbeat(): void { + this.stopHeartbeat(); + + this.heartbeatInterval = setInterval(() => { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + return; + } + + // Check for inactivity timeout + if (this.config.inactivityTimeout > 0) { + const timeSinceActivity = Date.now() - this.lastActivityTime; + if (timeSinceActivity >= this.config.inactivityTimeout) { + this.ws.close(4001, 'Inactivity timeout'); + return; + } + } + + this.ws.send(JSON.stringify({ type: 'heartbeat' })); + this.setHeartbeatTimeout(); + }, this.config.heartbeatInterval); + } + + private stopHeartbeat(): void { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval); + this.heartbeatInterval = null; + } + this.clearHeartbeatTimeout(); + } + + private setHeartbeatTimeout(): void { + this.clearHeartbeatTimeout(); + + this.heartbeatTimeout = setTimeout(() => { + // Heartbeat timeout - connection may be dead + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.close(4000, 'Heartbeat timeout'); + } + }, this.config.heartbeatTimeout); + } + + private clearHeartbeatTimeout(): void { + if (this.heartbeatTimeout) { + clearTimeout(this.heartbeatTimeout); + this.heartbeatTimeout = null; + } + } + + private cleanup(): void { + this.stopHeartbeat(); + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + this.retryTimeout = null; + } + } +} diff --git a/@plotly/dash-websocket-worker/src/index.ts b/@plotly/dash-websocket-worker/src/index.ts new file mode 100644 index 0000000000..e21b382d41 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/index.ts @@ -0,0 +1,18 @@ +/** + * Dash WebSocket Worker Package + * + * Provides a SharedWorker for WebSocket-based Dash callbacks. + */ + +export * from './types'; + +/** + * Get the URL for the WebSocket worker script. + * This should be used to instantiate the SharedWorker. + * + * @param baseUrl Base URL where the worker script is served + * @returns Full URL to the worker script + */ +export function getWorkerUrl(baseUrl: string): string { + return `${baseUrl}/dash-ws-worker.js`; +} diff --git a/@plotly/dash-websocket-worker/src/types.ts b/@plotly/dash-websocket-worker/src/types.ts new file mode 100644 index 0000000000..fac282b5e1 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/types.ts @@ -0,0 +1,151 @@ +/** + * Message types for communication between renderer and worker. + */ +export enum WorkerMessageType { + // Renderer -> Worker + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + + // Worker -> Renderer + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** + * Base message structure for worker communication. + */ +export interface WorkerMessage { + type: WorkerMessageType; + rendererId: string; + requestId?: string; + payload?: unknown; +} + +/** + * Message from renderer to worker requesting connection. + */ +export interface ConnectMessage extends WorkerMessage { + type: WorkerMessageType.CONNECT; + payload: { + serverUrl: string; + inactivityTimeout?: number; + }; +} + +/** + * Message from renderer to worker requesting disconnect. + */ +export interface DisconnectMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECT; +} + +/** + * Callback request payload structure. + */ +export interface CallbackPayload { + output: string; + outputs: unknown[]; + inputs: unknown[]; + state?: unknown[]; + changedPropIds: string[]; + parsedChangedPropsIds?: string[]; +} + +/** + * Message from renderer to worker with callback request. + */ +export interface CallbackRequestMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_REQUEST; + payload: CallbackPayload; +} + +/** + * Message from worker to renderer with callback response. + */ +export interface CallbackResponseMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_RESPONSE; + payload: { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; + }; +} + +/** + * Message from worker to renderer to set component props. + */ +export interface SetPropsMessage extends WorkerMessage { + type: WorkerMessageType.SET_PROPS; + payload: { + componentId: string; + props: Record; + }; +} + +/** + * Message from worker to renderer requesting prop values. + */ +export interface GetPropsRequestMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_REQUEST; + payload: { + componentId: string; + properties: string[]; + }; +} + +/** + * Message from renderer to worker with prop values. + */ +export interface GetPropsResponseMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_RESPONSE; + payload: Record; +} + +/** + * Error message from worker to renderer. + */ +export interface ErrorMessage extends WorkerMessage { + type: WorkerMessageType.ERROR; + payload: { + message: string; + code?: string; + }; +} + +/** + * Connected confirmation message from worker to renderer. + */ +export interface ConnectedMessage extends WorkerMessage { + type: WorkerMessageType.CONNECTED; +} + +/** + * Disconnected notification message from worker to renderer. + */ +export interface DisconnectedMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECTED; + payload?: { + reason?: string; + }; +} + +/** + * Union type of all possible worker messages. + */ +export type AnyWorkerMessage = + | ConnectMessage + | DisconnectMessage + | CallbackRequestMessage + | CallbackResponseMessage + | SetPropsMessage + | GetPropsRequestMessage + | GetPropsResponseMessage + | ErrorMessage + | ConnectedMessage + | DisconnectedMessage; diff --git a/@plotly/dash-websocket-worker/src/worker.ts b/@plotly/dash-websocket-worker/src/worker.ts new file mode 100644 index 0000000000..0e68f0b09a --- /dev/null +++ b/@plotly/dash-websocket-worker/src/worker.ts @@ -0,0 +1,135 @@ +/** + * Dash WebSocket Worker + * + * A SharedWorker that maintains a single WebSocket connection to the Dash server + * and routes messages between multiple renderer instances (browser tabs). + */ + +import { WebSocketManager } from './WebSocketManager'; +import { MessageRouter } from './MessageRouter'; +import { + WorkerMessageType, + WorkerMessage, + ConnectMessage +} from './types'; + +// SharedWorker global scope +declare const self: SharedWorkerGlobalScope; + +/** WebSocket connection manager */ +const wsManager = new WebSocketManager(); + +/** Message router for renderers */ +const router = new MessageRouter(); + +/** Current server URL */ +let serverUrl: string | null = null; + +/** + * Set up WebSocket manager callbacks. + */ +wsManager.onOpen = () => { + console.log('[DashWSWorker] WebSocket connected'); + // Notify all renderers that connection is established + for (const rendererId of getRendererIds()) { + router.notifyConnected(rendererId); + } +}; + +wsManager.onClose = (reason?: string) => { + console.log(`[DashWSWorker] WebSocket closed: ${reason}`); + router.notifyDisconnected(reason); +}; + +wsManager.onMessage = (data: unknown) => { + router.handleServerMessage(data); +}; + +wsManager.onError = (error: Error) => { + console.error('[DashWSWorker] WebSocket error:', error.message); +}; + +/** + * Set up router to send messages to WebSocket. + */ +router.sendToServer = (message: unknown) => { + wsManager.send(message); +}; + +// Track renderer IDs separately for iteration +const rendererIds = new Set(); + +/** + * Get all registered renderer IDs. + */ +function getRendererIds(): string[] { + return Array.from(rendererIds); +} + +/** + * Handle new connection from a renderer (browser tab). + */ +self.onconnect = (event: MessageEvent) => { + const port = event.ports[0]; + + port.onmessage = (e: MessageEvent) => { + const message = e.data as WorkerMessage; + + switch (message.type) { + case WorkerMessageType.CONNECT: { + const connectMsg = message as ConnectMessage; + const rendererId = connectMsg.rendererId; + const newServerUrl = connectMsg.payload.serverUrl; + const inactivityTimeout = connectMsg.payload.inactivityTimeout; + + // Register the renderer + router.registerRenderer(rendererId, port); + rendererIds.add(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} connected, inactivityTimeout: ${inactivityTimeout}`); + + // Update inactivity timeout if provided + if (typeof inactivityTimeout === 'number') { + wsManager.setConfig({ inactivityTimeout }); + } + + // Connect to server if not already connected + if (!wsManager.isConnected) { + if (serverUrl !== newServerUrl) { + serverUrl = newServerUrl; + } + wsManager.connect(serverUrl); + } else { + // Already connected, notify the renderer + router.notifyConnected(rendererId); + } + break; + } + + case WorkerMessageType.DISCONNECT: { + const rendererId = message.rendererId; + router.unregisterRenderer(rendererId); + rendererIds.delete(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} disconnected`); + + // If no more renderers, disconnect from server + if (router.rendererCount === 0) { + wsManager.disconnect(); + serverUrl = null; + console.log('[DashWSWorker] All renderers disconnected, closing WebSocket'); + } + break; + } + + default: + // Forward other messages through the router + router.handleRendererMessage(message.rendererId, message); + } + }; + + port.start(); +}; + +// Log worker startup +console.log('[DashWSWorker] SharedWorker initialized'); diff --git a/@plotly/dash-websocket-worker/tsconfig.json b/@plotly/dash-websocket-worker/tsconfig.json new file mode 100644 index 0000000000..0254db7f91 --- /dev/null +++ b/@plotly/dash-websocket-worker/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "target": "ES2020", + "module": "ESNext", + "lib": ["ES2020", "WebWorker"], + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist", + "rootDir": "./src", + "strict": true, + "moduleResolution": "node", + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist"] +} diff --git a/@plotly/dash-websocket-worker/webpack.config.js b/@plotly/dash-websocket-worker/webpack.config.js new file mode 100644 index 0000000000..efe7b59e89 --- /dev/null +++ b/@plotly/dash-websocket-worker/webpack.config.js @@ -0,0 +1,25 @@ +const path = require('path'); + +// This config is for standalone development/testing of the worker. +// The production build is handled by dash-renderer's webpack config. +module.exports = { + entry: './src/worker.ts', + output: { + filename: 'dash-ws-worker.js', + path: path.resolve(__dirname, 'dist'), + clean: true + }, + resolve: { + extensions: ['.ts', '.js'] + }, + module: { + rules: [ + { + test: /\.ts$/, + use: 'ts-loader', + exclude: /node_modules/ + } + ] + }, + target: 'webworker' +}; diff --git a/dash/_callback.py b/dash/_callback.py index 37a53d7ec5..b27e9a18b9 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -77,6 +77,7 @@ def callback( api_endpoint: Optional[str] = None, optional: Optional[bool] = False, hidden: Optional[bool] = None, + websocket: Optional[bool] = False, **_kwargs, ) -> Callable[..., Any]: """ @@ -228,6 +229,7 @@ def callback( api_endpoint=api_endpoint, optional=optional, hidden=hidden, + websocket=websocket, ) @@ -275,6 +277,7 @@ def insert_callback( no_output=False, optional=False, hidden=None, + websocket=False, ) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -300,6 +303,7 @@ def insert_callback( "no_output": no_output, "optional": optional, "hidden": hidden, + "websocket": websocket, } if running: callback_spec["running"] = running @@ -652,6 +656,7 @@ def register_callback( no_output=not has_output, optional=_kwargs.get("optional", False), hidden=_kwargs.get("hidden", None), + websocket=_kwargs.get("websocket", False), ) # pylint: disable=too-many-locals diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 646db990ab..4f296bde66 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,9 +1,12 @@ +import asyncio import functools import warnings import json import contextvars import typing +from dash.backends.base_server import DashWebsocketCallback + from . import exceptions from ._get_app import get_app from ._utils import AttributeDict, stringify_id @@ -323,6 +326,32 @@ def custom_data(self): """ return _get_from_context("custom_data", {}) + @property + @has_context + def get_websocket(self) -> typing.Optional[DashWebsocketCallback]: + """Get WebSocket interface if running in WebSocket context. + + Returns the DashWebsocketCallback instance if the callback is being + executed via WebSocket, otherwise returns None. + + Raises: + RuntimeError: If websocket_callbacks is requested but the backend + doesn't support WebSocket. + """ + ws = _get_from_context("dash_websocket", None) + if ws is None: + app = get_app() + if ( + hasattr(app, "_websocket_callbacks") + and app._websocket_callbacks # pylint: disable=protected-access + and not app.backend.websocket_capability + ): + raise RuntimeError( + f"WebSocket callbacks requested but backend " + f"'{app.backend.server_type}' doesn't support them." + ) + return ws + callback_context = CallbackContext() @@ -330,5 +359,26 @@ def custom_data(self): def set_props(component_id: typing.Union[str, dict], props: dict): """ Set the props for a component not included in the callback outputs. + + If running in a WebSocket context, props are streamed immediately to the + client. Otherwise, props are batched and sent with the callback response. """ - callback_context.set_props(component_id, props) + ws = _get_from_context("dash_websocket", None) + if ws is not None: + # Stream immediately via WebSocket + _id = stringify_id(component_id) + + async def _send_props(): + for prop_name, value in props.items(): + await ws.set_prop(_id, prop_name, value) + + # If we're in an async context, schedule the coroutine + try: + asyncio.get_running_loop() + asyncio.ensure_future(_send_props()) + except RuntimeError: + # No running event loop - run synchronously + asyncio.run(_send_props()) + else: + # Batch for response (existing behavior) + callback_context.set_props(component_id, props) diff --git a/dash/_dash_renderer.py b/dash/_dash_renderer.py index ee507ddb71..5574131d10 100644 --- a/dash/_dash_renderer.py +++ b/dash/_dash_renderer.py @@ -1,7 +1,7 @@ import os from typing import Any, List, Dict -__version__ = "3.0.0" +__version__ = "3.1.0" _available_react_versions = {"18.3.1", "18.2.0", "16.14.0"} _available_reactdom_versions = {"18.3.1", "18.2.0", "16.14.0"} @@ -65,7 +65,7 @@ def _set_react_version(v_react, v_reactdom=None): { "relative_package_path": "dash-renderer/build/dash_renderer.min.js", "dev_package_path": "dash-renderer/build/dash_renderer.dev.js", - "external_url": "https://unpkg.com/dash-renderer@3.0.0" + "external_url": "https://unpkg.com/dash-renderer@3.1.0" "/build/dash_renderer.min.js", "namespace": "dash", }, @@ -75,4 +75,9 @@ def _set_react_version(v_react, v_reactdom=None): "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/_hooks.py b/dash/_hooks.py index 1631b40ddc..f260b1fcb0 100644 --- a/dash/_hooks.py +++ b/dash/_hooks.py @@ -49,6 +49,8 @@ def __init__(self) -> None: "index": [], "custom_data": [], "dev_tools": [], + "websocket_connect": [], + "websocket_message": [], } self._js_dist: _t.List[_t.Any] = [] self._css_dist: _t.List[_t.Any] = [] @@ -244,6 +246,60 @@ def devtool( } ) + def websocket_connect(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket connection validation hook. + + The hook receives the WebSocket object and should return: + - True (or any truthy value): Allow the connection + - False: Reject with default code (4001) and reason + - tuple (code, reason): Reject with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_connect() + async def validate_session(websocket): + session_id = websocket.cookies.get("session_id") + if not session_id: + return (4001, "No session cookie") + if not await is_valid_session(session_id): + return (4002, "Invalid session") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_connect", func, priority=priority, final=final) + return func + + return decorator + + def websocket_message(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket message validation hook. + + The hook receives the WebSocket object and message dict, and should return: + - True (or any truthy value): Allow the message + - False: Disconnect with default code (4001) and reason + - tuple (code, reason): Disconnect with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_message() + async def validate_session(websocket, message): + session_id = websocket.cookies.get("session_id") + if not await is_session_active(session_id): + return (4002, "Session expired") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_message", func, priority=priority, final=final) + return func + + return decorator + hooks = _Hooks() diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 4e5bdad621..1e2186c384 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,7 +1,9 @@ from __future__ import annotations from contextvars import copy_context, ContextVar +import asyncio import json +import uuid from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -13,6 +15,7 @@ import subprocess import threading import traceback +from urllib.parse import urlparse try: from fastapi import FastAPI, Request, Response, Body @@ -21,6 +24,7 @@ from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders from starlette.types import ASGIApp, Scope, Receive, Send + from starlette.websockets import WebSocket, WebSocketDisconnect import uvicorn except ImportError as _err: raise ImportError( @@ -30,7 +34,12 @@ from dash.fingerprint import check_fingerprint from dash import _validate, get_app from dash.exceptions import PreventUpdate -from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter +from .base_server import ( + BaseDashServer, + RequestAdapter, + ResponseAdapter, + DashWebsocketCallback, +) from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only @@ -73,6 +82,86 @@ def set_response(self, **kwargs): return resp +class FastAPIWebsocketCallback(DashWebsocketCallback): + """WebSocket callback implementation for FastAPI backend. + + Provides real-time bidirectional communication for callback execution. + """ + + def __init__( + self, websocket: WebSocket, pending_get_props: Dict[str, asyncio.Future] + ): + """Initialize the WebSocket callback interface. + + Args: + websocket: The WebSocket connection + pending_get_props: Dict to track pending get_props requests + """ + self._websocket = websocket + self._pending_get_props = pending_get_props + + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to update + value: The new value to set + """ + await self._websocket.send_json( + { + "type": "set_props", + "payload": {"componentId": component_id, "props": {prop_name: value}}, + } + ) + + async def get_prop(self, component_id: str, prop_name: str) -> Any: + """Request current prop value from the client. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to retrieve + + Returns: + The current value of the property from the client's state + """ + request_id = str(uuid.uuid4()) + + # Create a future to wait for the response + future: asyncio.Future = asyncio.get_event_loop().create_future() + self._pending_get_props[request_id] = future + + # Send the request + await self._websocket.send_json( + { + "type": "get_props_request", + "requestId": request_id, + "payload": {"componentId": component_id, "properties": [prop_name]}, + } + ) + + # Wait for the response with timeout + try: + result = await asyncio.wait_for(future, timeout=30.0) + if result and prop_name in result: + return result[prop_name] + return None + except asyncio.TimeoutError as exc: + self._pending_get_props.pop(request_id, None) + raise TimeoutError( + f"Timeout waiting for get_prop response for {component_id}.{prop_name}" + ) from exc + + async def close(self, code: int = 1000, reason: str = "Connection closed") -> None: + """Close the WebSocket connection. + + Args: + code: WebSocket close code (default 1000 for normal closure) + reason: Human-readable reason for closing + """ + await self._websocket.close(code=code, reason=reason) + + _current_request_var = ContextVar("dash_current_request", default=None) @@ -224,6 +313,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class FastAPIDashServer(BaseDashServer[FastAPI]): + websocket_capability: bool = True + def __init__(self, server: FastAPI): super().__init__(server) self.server_type = "fastapi" @@ -609,6 +700,224 @@ async def timing_headers_middleware(request: Request, call_next): headers.append("Server-Timing", value) return response + async def _run_ws_hooks( + self, hooks, websocket: "WebSocket", *args, default_reason: str = "Rejected" + ) -> tuple | None: + """Run WebSocket hooks and return rejection tuple or None if all pass. + + Args: + hooks: List of hooks to run + websocket: The WebSocket connection + *args: Additional arguments to pass to hooks + default_reason: Default reason if hook returns False + + Returns: + None if all hooks pass, or (code, reason) tuple for rejection + """ + for hook in hooks: + try: + result = hook(websocket, *args) + if inspect.iscoroutine(result): + result = await result + if result is False: + return (4001, default_reason) + if isinstance(result, tuple) and len(result) == 2: + return result + except Exception: # pylint: disable=broad-exception-caught + return (4001, "Authentication error") + return None + + def serve_websocket_callback(self, dash_app: "Dash"): + """Set up the WebSocket endpoint for callback handling. + + Args: + dash_app: The Dash application instance + """ + # pylint: disable=too-many-statements + ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + + # Get allowed origins from dash app config + allowed_origins = getattr( + dash_app, "_allowed_websocket_origins", [] + ) # pylint: disable=protected-access + + def validate_origin(origin: str | None, host: str | None) -> str | None: + """Validate WebSocket origin. Returns error message or None if valid.""" + if not origin: + return "Origin header required" + if origin in allowed_origins: + return None # Explicitly allowed + if not host: + return "Origin not allowed" + # Check same-origin + origin_host = urlparse(origin).netloc + if origin_host != host: + return "Origin not allowed" + return None + + async def websocket_handler(websocket: WebSocket): + # Validate Origin header to prevent Cross-Site WebSocket Hijacking + origin = websocket.headers.get("origin") + host = websocket.headers.get("host") + error = validate_origin(origin, host) + if error: + await websocket.close(code=4003, reason=error) + return + + # Call websocket_connect hooks (before accept) + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_connect"), + websocket, + default_reason="Connection rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + + await websocket.accept() + + # Track pending get_props requests + pending_get_props: Dict[str, asyncio.Future] = {} + + try: + while True: + message = await websocket.receive_json() + + # Call websocket_message hooks + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + websocket, + message, + default_reason="Message rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + + msg_type = message.get("type") + renderer_id = message.get("rendererId") + + if msg_type == "callback_request": + response = await self._execute_ws_callback( + dash_app, websocket, message, pending_get_props + ) + await websocket.send_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": message.get("requestId"), + "payload": response, + } + ) + + elif msg_type == "get_props_response": + # Handle response for pending get_props request + request_id = message.get("requestId") + if request_id in pending_get_props: + future = pending_get_props.pop(request_id) + if not future.done(): + future.set_result(message.get("payload")) + + elif msg_type == "heartbeat": + await websocket.send_json({"type": "heartbeat_ack"}) + + except WebSocketDisconnect: + pass # Clean disconnect + finally: + # Cancel any pending futures + for future in pending_get_props.values(): + if not future.done(): + future.cancel() + + self.server.add_api_websocket_route(ws_path, websocket_handler) + + async def _execute_ws_callback( + self, + dash_app: "Dash", + websocket: WebSocket, + message: dict, + pending_get_props: Dict[str, asyncio.Future], + ) -> dict: + """Execute callback from WebSocket message. + + Args: + dash_app: The Dash application instance + websocket: The WebSocket connection + message: The callback request message + pending_get_props: Dict to track pending get_props requests + + Returns: + Response dict with status and data + """ + payload = message.get("payload", {}) + + # Create WebSocket callback context + cb_ctx = self._create_ws_context( + dash_app, websocket, payload, pending_get_props + ) + + try: + # Reuse existing callback machinery + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) + # pylint: enable=protected-access + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} + + def _create_ws_context( + self, + _dash_app: "Dash", # pylint: disable=unused-argument + websocket: WebSocket, + payload: dict, + pending_get_props: Dict[str, asyncio.Future], + ): + """Create callback context from WebSocket message. + + Args: + _dash_app: The Dash application instance (unused, kept for API consistency) + websocket: The WebSocket connection + payload: The callback payload + pending_get_props: Dict to track pending get_props requests + + Returns: + AttributeDict with callback context + """ + # pylint: disable=import-outside-toplevel + from dash._utils import AttributeDict, inputs_to_dict + + g = AttributeDict({}) + g.inputs_list = payload.get("inputs", []) + g.states_list = payload.get("state", []) + g.outputs_list = payload.get("outputs", []) + g.input_values = inputs_to_dict(g.inputs_list) + g.state_values = inputs_to_dict(g.states_list) + g.triggered_inputs = [ + {"prop_id": x, "value": g.input_values.get(x)} + for x in payload.get("changedPropIds", []) + ] + g.dash_response = FastAPIResponseAdapter() + g.updated_props = {} + + # Add WebSocket callback interface + g.dash_websocket = FastAPIWebsocketCallback(websocket, pending_get_props) + + return g + class FastAPIRequestAdapter(RequestAdapter): def __init__(self): diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index f7211f44a6..283a3414e2 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -169,6 +169,7 @@ class BaseDashServer(ABC, Generic[ServerType]): config: Dict[str, Any] request_adapter: Type[RequestAdapter] response_adapter: Type[ResponseAdapter] + websocket_capability: bool = False def __init__(self, server: ServerType) -> None: """Initialize the server wrapper. @@ -372,3 +373,54 @@ def setup_backend(self, dash_app: "dash.Dash"): Args: dash_app: The Dash application instance """ + + def serve_websocket_callback(self, dash_app: "dash.Dash"): + """Set up the WebSocket endpoint for callback handling. + + Override this method in backends that support WebSocket callbacks. + + Args: + dash_app: The Dash application instance + """ + + +class DashWebsocketCallback(ABC): + """Abstract interface for WebSocket-based callback communication. + + Provides methods for real-time bidirectional communication between + the server and renderer during callback execution. + """ + + @abstractmethod + async def get_prop(self, component_id: str, prop_name: str) -> Any: + """Request current prop value from the client. + + Args: + component_id: The component ID (string or stringified dict for pattern matching) + prop_name: The property name to retrieve + + Returns: + The current value of the property from the client's state + """ + + @abstractmethod + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Args: + component_id: The component ID (string or stringified dict for pattern matching) + prop_name: The property name to update + value: The new value to set + """ + + @abstractmethod + async def close(self, code: int = 1000, reason: str = "Connection closed") -> None: + """Close the WebSocket connection. + + Allows developers to forcibly disconnect a client, e.g., on suspicious + activity, session revocation, or policy violation. + + Args: + code: WebSocket close code (default 1000 for normal closure) + reason: Human-readable reason for closing + """ diff --git a/dash/dash-renderer/init.template b/dash/dash-renderer/init.template index 463cfa02aa..a6b84d3d70 100644 --- a/dash/dash-renderer/init.template +++ b/dash/dash-renderer/init.template @@ -75,4 +75,9 @@ _js_dist = [ "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/dash-renderer/package.json b/dash/dash-renderer/package.json index f92d22cfc5..a404fa2425 100644 --- a/dash/dash-renderer/package.json +++ b/dash/dash-renderer/package.json @@ -13,7 +13,7 @@ "build:dev": "webpack", "build:local": "renderer build local", "build": "renderer build && npm run prepublishOnly", - "postbuild": "es-check es2015 ../deps/*.js build/*.js", + "postbuild": "es-check es2015 ../deps/*.js build/dash_renderer.*.js", "test": "karma start karma.conf.js --single-run", "format": "run-s private::format.*", "lint": "run-s private::lint.* --continue-on-error" diff --git a/dash/dash-renderer/src/AppProvider.react.tsx b/dash/dash-renderer/src/AppProvider.react.tsx index 343789ca43..2a6b95240c 100644 --- a/dash/dash-renderer/src/AppProvider.react.tsx +++ b/dash/dash-renderer/src/AppProvider.react.tsx @@ -1,9 +1,14 @@ import PropTypes from 'prop-types'; -import React, {useState} from 'react'; +import React, {useState, useEffect} from 'react'; import {Provider} from 'react-redux'; import Store from './store'; import AppContainer from './AppContainer.react'; +import getConfigFromDOM from './config'; +import { + initializeWebSocket, + disconnectWebSocket +} from './observers/websocketObserver'; const AppProvider = ({ hooks = { @@ -16,6 +21,31 @@ const AppProvider = ({ } }: any) => { const [{store}] = useState(() => new Store()); + + // Initialize WebSocket connection if enabled + useEffect(() => { + const config = getConfigFromDOM(); + if (config.websocket?.enabled) { + // Add fetch config for consistency + const fullConfig = { + ...config, + fetch: { + credentials: 'same-origin', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json' + } + } + }; + initializeWebSocket(store, fullConfig); + } + + // Cleanup on unmount + return () => { + disconnectWebSocket(); + }; + }, [store]); + return ( diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 37aab3f194..d2e4a18d1c 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -52,6 +52,11 @@ import {parsePMCId} from './patternMatching'; import {replacePMC} from './patternMatching'; import {loaded, loading} from './loading'; import {getComponentLayout} from '../wrapper/wrapping'; +import { + getWorkerClient, + isWebSocketEnabled, + isWebSocketAvailable +} from '../utils/workerClient'; export const addBlockedCallbacks = createAction( CallbackActionType.AddBlocked @@ -685,6 +690,140 @@ function handleServerside( }); } +/** + * Handle serverside callback via WebSocket connection. + * + * Uses the SharedWorker to send the callback request through the persistent + * WebSocket connection instead of HTTP POST. + */ +async function handleWebsocketCallback( + dispatch: any, + hooks: any, + config: any, + payload: ICallbackPayload, + running: any +): Promise { + if (hooks.request_pre) { + hooks.request_pre(payload); + } + + const requestTime = Date.now(); + let runningOff: any; + + if (running) { + dispatch(sideUpdate(running.running, payload)); + runningOff = running.runningOff; + } + + const workerClient = getWorkerClient(); + + try { + // Ensure WebSocket connection is established + await workerClient.ensureConnected(config); + + const response = await workerClient.sendCallback(payload); + + // Handle running off state + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (response.status === 'prevent_update') { + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.PREVENT_UPDATE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + return {}; + } + + if (response.status === 'error') { + throw new Error(response.message || 'Callback error'); + } + + // Extract the callback data - structure is {multi: boolean, response: {...}} + const callbackData = response.data as CallbackResponseData; + + // Handle sideUpdate if present + if (callbackData?.sideUpdate) { + dispatch(sideUpdate(callbackData.sideUpdate, payload)); + } + + // Extract the actual outputs from the response + // Format is similar to HTTP path's finishLine function + let result: CallbackResponse; + const {multi, response: callbackResponse} = callbackData || {}; + + if (hooks.request_post) { + hooks.request_post(payload, callbackResponse); + } + + if (multi) { + result = callbackResponse as CallbackResponse; + } else { + // Single output - convert to the expected format + const {output} = payload; + const id = output.substr(0, output.lastIndexOf('.')); + result = {[id]: (callbackResponse as CallbackResponse)?.props}; + } + + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.OK, + result: result || {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + return result || {}; + } catch (error) { + // Handle running off state on error + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (config.ui) { + dispatch( + updateResourceUsage({ + id: payload.output, + status: STATUS.NO_RESPONSE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + throw error; + } +} + function inputsToDict(inputs_list: any) { // Ported directly from _utils.py, inputs_to_dict // takes an array of inputs (some inputs may be an array) @@ -890,18 +1029,44 @@ export function executeCallback( } ); + // Use WebSocket for callbacks when: + // 1. Global WebSocket is enabled, OR + // 2. Per-callback websocket flag is set (and WebSocket is available) + // (but never for background callbacks) + const useWebSocket = + !background && + (isWebSocketEnabled(config) || + (cb.callback.websocket && + isWebSocketAvailable(config))); + for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { try { - let data = await handleServerside( - dispatch, - hooks, - newConfig, - payload, - background, - additionalArgs.length ? additionalArgs : undefined, - getState, - cb.callback.running - ); + let data: CallbackResponse; + + if (useWebSocket) { + // Use WebSocket path for real-time callbacks + data = await handleWebsocketCallback( + dispatch, + hooks, + newConfig, + payload, + cb.callback.running + ); + } else { + // Use traditional HTTP path + data = await handleServerside( + dispatch, + hooks, + newConfig, + payload, + background, + additionalArgs.length + ? additionalArgs + : undefined, + getState, + cb.callback.running + ); + } if (newHeaders) { dispatch(addHttpHeaders(newHeaders)); diff --git a/dash/dash-renderer/src/config.ts b/dash/dash-renderer/src/config.ts index ac18678364..b9e68eae03 100644 --- a/dash/dash-renderer/src/config.ts +++ b/dash/dash-renderer/src/config.ts @@ -22,6 +22,12 @@ export type DashConfig = { serve_locally?: boolean; plotlyjs_url?: string; validate_callbacks: boolean; + websocket?: { + enabled: boolean; + url: string; + worker_url: string; + inactivity_timeout?: number; + }; }; export default function getConfigFromDOM(): DashConfig { diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts new file mode 100644 index 0000000000..ff5b9d099b --- /dev/null +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -0,0 +1,159 @@ +/** + * Observer for handling incoming WebSocket messages (SET_PROPS, GET_PROPS_REQUEST). + */ + +/* eslint-disable no-console */ + +import {Store} from 'redux'; +import {path} from 'ramda'; + +import {IStoreState} from '../store'; +import {updateProps, notifyObservers} from '../actions'; +import {getPath} from '../actions/paths'; +import { + getWorkerClient, + SetPropsPayload, + GetPropsRequestPayload +} from '../utils/workerClient'; +import {DashConfig} from '../config'; + +/** + * Initialize the WebSocket observer. + * + * Sets up handlers for: + * - SET_PROPS: Update component props when received from server + * - GET_PROPS_REQUEST: Send current prop values back to server + * + * @param store Redux store + * @param config Dash configuration + */ +export async function initializeWebSocket( + store: Store, + config: DashConfig +): Promise { + // Initialize WebSocket if: + // 1. Global websocket is enabled, OR + // 2. WebSocket config is available (for per-callback websocket=True) + const wsAvailable = !!( + config.websocket?.url && config.websocket?.worker_url + ); + if (!wsAvailable) { + return; + } + + // Check if SharedWorker is supported + if (typeof SharedWorker === 'undefined') { + console.warn( + 'SharedWorker not supported in this browser. ' + + 'WebSocket callbacks will fall back to HTTP.' + ); + return; + } + + const workerClient = getWorkerClient(); + + // Handle SET_PROPS messages + workerClient.onSetProps = (payload: SetPropsPayload) => { + const {componentId, props} = payload; + const state = store.getState(); + const componentPath = getPath(state.paths, componentId); + + if (!componentPath) { + console.warn( + `SET_PROPS: Component ${componentId} not found in layout` + ); + return; + } + + // Update the component props + store.dispatch( + updateProps({ + props, + itempath: componentPath, + renderType: 'websocket' + }) as any + ); + + // Notify observers + store.dispatch(notifyObservers({id: componentId, props}) as any); + }; + + // Handle GET_PROPS_REQUEST messages + workerClient.onGetPropsRequest = ( + requestId: string, + payload: GetPropsRequestPayload + ) => { + const {componentId, properties} = payload; + const state = store.getState(); + const componentPath = getPath(state.paths, componentId); + + const result: Record = {}; + + if (componentPath) { + const componentProps = path( + [...componentPath, 'props'], + state.layout + ) as Record | undefined; + + if (componentProps) { + for (const propName of properties) { + result[propName] = componentProps[propName]; + } + } + } + + // Send the response + workerClient.sendGetPropsResponse(requestId, result); + }; + + // Handle connection events + workerClient.onConnected = () => { + console.log('[Dash] WebSocket connected'); + }; + + workerClient.onDisconnected = (reason?: string) => { + console.log(`[Dash] WebSocket disconnected: ${reason}`); + }; + + workerClient.onError = (message: string, code?: string) => { + console.error(`[Dash] WebSocket error: ${message}`, code); + }; + + // Connect to the worker + const wsUrl = buildWebSocketUrl(config); + + try { + // config.websocket is guaranteed to exist due to wsAvailable check above + await workerClient.connect( + config.websocket!.worker_url, + wsUrl, + config.websocket!.inactivity_timeout + ); + } catch (error) { + console.error('[Dash] Failed to connect to WebSocket worker:', error); + } +} + +/** + * Build the WebSocket URL from config. + */ +function buildWebSocketUrl(config: DashConfig): string { + if (!config.websocket?.url) { + throw new Error('WebSocket URL not configured'); + } + + // Convert HTTP(S) URL to WS(S) + const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + + // The config.websocket.url is a path like "/_dash-ws-callback" + return `${wsProtocol}//${host}${config.websocket.url}`; +} + +/** + * Disconnect from the WebSocket. + */ +export function disconnectWebSocket(): void { + const workerClient = getWorkerClient(); + workerClient.disconnect(); +} diff --git a/dash/dash-renderer/src/types/callbacks.ts b/dash/dash-renderer/src/types/callbacks.ts index f1e1dc382c..38a5d7d82f 100644 --- a/dash/dash-renderer/src/types/callbacks.ts +++ b/dash/dash-renderer/src/types/callbacks.ts @@ -15,6 +15,7 @@ export interface ICallbackDefinition { dynamic_creator?: boolean; running: any; no_output?: boolean; + websocket?: boolean; } export interface ICallbackProperty { diff --git a/dash/dash-renderer/src/utils/rendererId.ts b/dash/dash-renderer/src/utils/rendererId.ts new file mode 100644 index 0000000000..b9bfcfd3af --- /dev/null +++ b/dash/dash-renderer/src/utils/rendererId.ts @@ -0,0 +1,22 @@ +/** Cached renderer ID for this page instance */ +let cachedRendererId: string | null = null; + +/** + * Generate a unique renderer ID for this page instance. + * + * Each page load gets a fresh ID to avoid conflicts with stale + * connections in the SharedWorker after page reloads. + */ +export function getRendererId(): string { + if (!cachedRendererId) { + if (typeof crypto !== 'undefined' && crypto.randomUUID) { + cachedRendererId = crypto.randomUUID(); + } else { + // Fallback for older browsers + cachedRendererId = `${Date.now()}-${Math.random() + .toString(36) + .slice(2)}`; + } + } + return cachedRendererId; +} diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts new file mode 100644 index 0000000000..6bc503b16b --- /dev/null +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -0,0 +1,345 @@ +/** + * Client for communicating with the Dash WebSocket SharedWorker. + */ + +import {getRendererId} from './rendererId'; + +/** Message types for worker communication */ +export enum WorkerMessageType { + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** Callback response structure */ +export interface CallbackResponse { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; +} + +/** Set props message payload */ +export interface SetPropsPayload { + componentId: string; + props: Record; +} + +/** Get props request payload */ +export interface GetPropsRequestPayload { + componentId: string; + properties: string[]; +} + +/** Pending callback request */ +interface PendingRequest { + resolve: (value: CallbackResponse) => void; + reject: (error: Error) => void; +} + +/** + * Client for the Dash WebSocket SharedWorker. + */ +class WorkerClient { + private worker: SharedWorker | null = null; + private rendererId: string; + private pendingCallbacks: Map = new Map(); + private requestCounter = 0; + private isConnected = false; + private connectionPromise: Promise | null = null; + private connectionResolve: (() => void) | null = null; + + /** Callback when SET_PROPS message is received */ + public onSetProps: ((payload: SetPropsPayload) => void) | null = null; + + /** Callback when GET_PROPS_REQUEST message is received */ + public onGetPropsRequest: + | ((requestId: string, payload: GetPropsRequestPayload) => void) + | null = null; + + /** Callback when connection is established */ + public onConnected: (() => void) | null = null; + + /** Callback when connection is lost */ + public onDisconnected: ((reason?: string) => void) | null = null; + + /** Callback when an error occurs */ + public onError: ((message: string, code?: string) => void) | null = null; + + constructor() { + this.rendererId = getRendererId(); + } + + /** + * Initialize the worker connection. + * @param workerUrl URL to the SharedWorker script + * @param serverUrl WebSocket server URL + * @param inactivityTimeout Optional inactivity timeout in ms + */ + public async connect( + workerUrl: string, + serverUrl: string, + inactivityTimeout?: number + ): Promise { + if (this.worker) { + // Already connected + return; + } + + // Create the SharedWorker + this.worker = new SharedWorker(workerUrl, { + name: 'dash-ws-worker' + }); + + // Set up message handling + this.worker.port.onmessage = this.handleMessage.bind(this); + + // Create promise for connection + this.connectionPromise = new Promise(resolve => { + this.connectionResolve = resolve; + }); + + // Start the port + this.worker.port.start(); + + // Send connect message + this.worker.port.postMessage({ + type: WorkerMessageType.CONNECT, + rendererId: this.rendererId, + payload: { + serverUrl, + inactivityTimeout + } + }); + + // Wait for connection + await this.connectionPromise; + } + + /** + * Disconnect from the worker. + */ + public disconnect(): void { + if (this.worker) { + this.worker.port.postMessage({ + type: WorkerMessageType.DISCONNECT, + rendererId: this.rendererId + }); + this.worker.port.close(); + this.worker = null; + } + this.isConnected = false; + this.connectionPromise = null; + this.connectionResolve = null; + + // Reject any pending callbacks + for (const [, pending] of this.pendingCallbacks) { + pending.reject(new Error('Worker disconnected')); + } + this.pendingCallbacks.clear(); + } + + /** + * Ensure the worker is connected, initiating connection if needed. + * @param config The Dash config with websocket settings + */ + public async ensureConnected(config: { + websocket?: { + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; + }): Promise { + // Already connected + if (this.isConnected) { + return; + } + + // Connection in progress, wait for it + if (this.connectionPromise) { + await this.connectionPromise; + return; + } + + // Need to initiate connection + if (!config.websocket?.url || !config.websocket?.worker_url) { + throw new Error('WebSocket config not available'); + } + + if (typeof SharedWorker === 'undefined') { + throw new Error('SharedWorker not supported'); + } + + // Build WebSocket URL + const wsProtocol = + window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + const wsUrl = `${wsProtocol}//${host}${config.websocket.url}`; + + await this.connect( + config.websocket.worker_url, + wsUrl, + config.websocket.inactivity_timeout + ); + } + + /** + * Send a callback request to the server via the worker. + * @param payload The callback payload + * @returns Promise that resolves with the callback response + */ + public async sendCallback(payload: unknown): Promise { + // Wait for initial connection if one is in progress + if (this.connectionPromise && !this.isConnected) { + await this.connectionPromise; + } + + if (!this.worker) { + throw new Error('Worker not connected'); + } + + const requestId = `${this.rendererId}-${++this.requestCounter}`; + + return new Promise((resolve, reject) => { + this.pendingCallbacks.set(requestId, {resolve, reject}); + + this.worker!.port.postMessage({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId: this.rendererId, + requestId, + payload + }); + }); + } + + /** + * Send a get_props response back to the server. + * @param requestId The request ID from the get_props request + * @param props The property values + */ + public sendGetPropsResponse( + requestId: string, + props: Record + ): void { + if (!this.worker || !this.isConnected) { + return; + } + + this.worker.port.postMessage({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId: this.rendererId, + requestId, + payload: props + }); + } + + /** + * Check if the worker is connected. + */ + public get connected(): boolean { + return this.isConnected; + } + + private handleMessage(event: MessageEvent): void { + const message = event.data; + + switch (message.type) { + case WorkerMessageType.CONNECTED: + this.isConnected = true; + if (this.connectionResolve) { + this.connectionResolve(); + this.connectionResolve = null; + } + if (this.onConnected) { + this.onConnected(); + } + break; + + case WorkerMessageType.DISCONNECTED: + this.isConnected = false; + if (this.onDisconnected) { + this.onDisconnected(message.payload?.reason); + } + break; + + case WorkerMessageType.CALLBACK_RESPONSE: { + const requestId = message.requestId; + const pending = this.pendingCallbacks.get(requestId); + if (pending) { + this.pendingCallbacks.delete(requestId); + pending.resolve(message.payload); + } + break; + } + + case WorkerMessageType.SET_PROPS: + if (this.onSetProps) { + this.onSetProps(message.payload); + } + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + if (this.onGetPropsRequest) { + this.onGetPropsRequest(message.requestId, message.payload); + } + break; + + case WorkerMessageType.ERROR: + if (this.onError) { + this.onError( + message.payload?.message || 'Unknown error', + message.payload?.code + ); + } + break; + } + } +} + +// Singleton instance +let workerClientInstance: WorkerClient | null = null; + +/** + * Get the singleton WorkerClient instance. + */ +export function getWorkerClient(): WorkerClient { + if (!workerClientInstance) { + workerClientInstance = new WorkerClient(); + } + return workerClientInstance; +} + +/** + * Check if WebSocket callbacks are globally enabled and supported. + * @param config The Dash config + */ +export function isWebSocketEnabled(config: { + websocket?: {enabled: boolean}; +}): boolean { + return !!(config.websocket?.enabled && typeof SharedWorker !== 'undefined'); +} + +/** + * Check if WebSocket infrastructure is available (for per-callback websocket). + * @param config The Dash config + */ +export function isWebSocketAvailable(config: { + websocket?: { + enabled?: boolean; + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; +}): boolean { + return !!( + config.websocket?.url && + config.websocket?.worker_url && + typeof SharedWorker !== 'undefined' + ); +} diff --git a/dash/dash-renderer/webpack.base.config.js b/dash/dash-renderer/webpack.base.config.js index ed95239f7d..e8a9d14596 100644 --- a/dash/dash-renderer/webpack.base.config.js +++ b/dash/dash-renderer/webpack.base.config.js @@ -72,6 +72,31 @@ const rendererOptions = { ...defaults }; +// WebSocket Worker configuration +const workerOptions = { + mode: 'production', + entry: { + 'dash-ws-worker': '../../@plotly/dash-websocket-worker/src/worker.ts', + }, + output: { + path: path.resolve(__dirname, "build"), + filename: '[name].js', + }, + target: 'webworker', + module: { + rules: [ + { + test: /\.ts$/, + exclude: /node_modules/, + use: ['ts-loader'], + }, + ] + }, + resolve: { + extensions: ['.ts', '.js'] + } +}; + module.exports = options => [ R.mergeAll([ options, @@ -109,5 +134,7 @@ module.exports = options => [ ] ), } - ]) + ]), + // WebSocket Worker build + workerOptions ]; diff --git a/dash/dash.py b/dash/dash.py index ea53edf341..9ae1adc1c4 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -472,6 +472,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, health_endpoint: Optional[str] = None, + websocket_callbacks: Optional[bool] = False, + allowed_websocket_origins: Optional[List[str]] = None, + websocket_inactivity_timeout: Optional[int] = 300000, **obsolete, ): @@ -619,6 +622,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._assets_files: list = [] self._background_manager = background_callback_manager + self._websocket_callbacks = websocket_callbacks + self._allowed_websocket_origins = allowed_websocket_origins or [] + self._websocket_inactivity_timeout = websocket_inactivity_timeout self.logger = logging.getLogger(__name__) @@ -761,6 +767,12 @@ def _setup_routes(self): ) if self.config.health_endpoint is not None: self._add_url(self.config.health_endpoint, self.serve_health) + + # Set up WebSocket callback route if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: + self.backend.serve_websocket_callback(self) + self.backend.setup_index(self) self.backend.setup_catchall(self) @@ -940,6 +952,16 @@ def _config(self): custom_dev_tools.append({**hook_dev_tools, "props": props}) config["dev_tools"] = custom_dev_tools + # Add websocket config if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: + config["websocket"] = { + "enabled": bool(self._websocket_callbacks), + "url": self.config.requests_pathname_prefix + "_dash-ws-callback", + "worker_url": self._get_worker_url(), + "inactivity_timeout": self._websocket_inactivity_timeout, + } + return config def serve_reload_hash(self): @@ -967,6 +989,33 @@ def serve_health(self): """ return self.backend.make_response("OK", status=200, mimetype="text/plain") + def _get_worker_url(self) -> str: + """Get the URL for the WebSocket worker script. + + Returns: + The fingerprinted URL for the worker script served via component suites. + """ + relative_path = "dash-renderer/build/dash-ws-worker.js" + namespace = "dash" + + # Register the path so it can be served + self.registered_paths[namespace].add(relative_path) + + # Build fingerprinted URL (same pattern as _collect_and_register_resources) + module_path = os.path.join( + os.path.dirname(sys.modules[namespace].__file__), # type: ignore + relative_path, + ) + + # Use a fallback if the file doesn't exist yet (during development) + try: + modified = int(os.stat(module_path).st_mtime) + except FileNotFoundError: + modified = 0 + + fingerprint = build_fingerprint(relative_path, __version__, modified) + return f"{self.config.requests_pathname_prefix}_dash-component-suites/{namespace}/{fingerprint}" + def get_dist(self, libraries: Sequence[str]) -> list: dists = [] for dist_type in ("_js_dist", "_css_dist"): diff --git a/wsapp.py b/wsapp.py new file mode 100644 index 0000000000..98b2db2f38 --- /dev/null +++ b/wsapp.py @@ -0,0 +1,106 @@ +""" +Test app for WebSocket-based callbacks. + +Run with: + python wsapp.py + +Then open http://127.0.0.1:8050 in your browser. +""" + +from dash import Dash, html, dcc, callback, Output, Input, ctx +import time + +# Create app with FastAPI backend and WebSocket callbacks enabled +app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, +) + +app.layout = html.Div([ + html.H1("WebSocket Callbacks Test"), + + html.Div([ + html.H3("Basic Callback Test"), + html.Button("Click me", id="btn-1", n_clicks=0), + html.Div(id="output-1"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("Input Test"), + dcc.Input(id="input-1", type="text", placeholder="Type something..."), + html.Div(id="output-2"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("Slider Test"), + dcc.Slider(id="slider-1", min=0, max=100, value=50), + html.Div(id="output-3"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("set_props Test"), + html.Button("Update via set_props", id="btn-2", n_clicks=0), + html.Div(id="output-4", children="Initial content"), + html.Div(id="output-5", children="Will be updated by set_props"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("WebSocket Context Test"), + html.Button("Check WebSocket Context", id="btn-3", n_clicks=0), + html.Div(id="output-6"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div(id="config-display", style={"marginTop": "20px", "fontSize": "12px", "color": "#666"}), +]) + + +@callback(Output("output-1", "children"), Input("btn-1", "n_clicks")) +def update_output_1(n_clicks): + return f"Button clicked {n_clicks} times" + + +@callback(Output("output-2", "children"), Input("input-1", "value")) +def update_output_2(value): + return f"You typed: {value}" + + +@callback(Output("output-3", "children"), Input("slider-1", "value")) +def update_output_3(value): + return f"Slider value: {value}" + + +@callback(Output("output-4", "children"), Input("btn-2", "n_clicks")) +def update_with_set_props(n_clicks): + if n_clicks > 0: + # Use set_props to update another component + from dash._callback_context import set_props + set_props("output-5", {"children": f"Updated via set_props at click {n_clicks}"}) + return f"set_props button clicked {n_clicks} times" + + +@callback(Output("output-6", "children"), Input("btn-3", "n_clicks")) +def check_websocket_context(n_clicks): + if n_clicks > 0: + ws = ctx.get_websocket + if ws is not None: + return f"WebSocket context is available! (click {n_clicks})" + else: + return f"WebSocket context is None (click {n_clicks}) - may be using HTTP fallback" + return "Click to check WebSocket context" + + +@callback(Output("config-display", "children"), Input("btn-1", "n_clicks")) +def show_config(n_clicks): + config = app._config() + ws_config = config.get("websocket", {}) + if ws_config: + return f"WebSocket enabled: {ws_config.get('enabled')}, URL: {ws_config.get('url')}" + return "WebSocket not configured" + + +if __name__ == "__main__": + print("Starting WebSocket callbacks test app...") + print(f"WebSocket callbacks enabled: {app._websocket_callbacks}") + print(f"Backend websocket capability: {app.backend.websocket_capability}") + app.run(debug=True, port=8050)