zustand/src/middleware/devtools.ts
Hong-Kuan Wu b4177b3172
feat: add devtools.cleanup() method (#3111)
* feat(devtools): add cleanup method

* docs(devtools): add cleanup section

* reduce lines

* test(devtools): test if the connection removed after cleanup

---------

Co-authored-by: daishi <daishi@axlight.com>
Co-authored-by: Daishi Kato <dai-shi@users.noreply.github.com>
2025-05-21 09:49:59 +09:00

414 lines
12 KiB
TypeScript

import type {} from '@redux-devtools/extension'
import type {
StateCreator,
StoreApi,
StoreMutatorIdentifier,
} from '../vanilla.ts'
type Config = Parameters<
(Window extends { __REDUX_DEVTOOLS_EXTENSION__?: infer T }
? T
: { connect: (param: any) => any })['connect']
>[0]
declare module '../vanilla' {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
interface StoreMutators<S, A> {
'zustand/devtools': WithDevtools<S>
}
}
// FIXME https://github.com/reduxjs/redux-devtools/issues/1097
type Message = {
type: string
payload?: any
state?: any
}
type Cast<T, U> = T extends U ? T : U
type Write<T, U> = Omit<T, keyof U> & U
type TakeTwo<T> = T extends { length: 0 }
? [undefined, undefined]
: T extends { length: 1 }
? [...args0: Cast<T, unknown[]>, arg1: undefined]
: T extends { length: 0 | 1 }
? [...args0: Cast<T, unknown[]>, arg1: undefined]
: T extends { length: 2 }
? T
: T extends { length: 1 | 2 }
? T
: T extends { length: 0 | 1 | 2 }
? T
: T extends [infer A0, infer A1, ...unknown[]]
? [A0, A1]
: T extends [infer A0, (infer A1)?, ...unknown[]]
? [A0, A1?]
: T extends [(infer A0)?, (infer A1)?, ...unknown[]]
? [A0?, A1?]
: never
type WithDevtools<S> = Write<S, StoreDevtools<S>>
type Action =
| string
| {
type: string
[x: string | number | symbol]: unknown
}
type StoreDevtools<S> = S extends {
setState: {
// capture both overloads of setState
(...args: infer Sa1): infer Sr1
(...args: infer Sa2): infer Sr2
}
}
? {
setState(...args: [...args: TakeTwo<Sa1>, action?: Action]): Sr1
setState(...args: [...args: TakeTwo<Sa2>, action?: Action]): Sr2
devtools: {
cleanup: () => void
}
}
: never
export interface DevtoolsOptions extends Config {
name?: string
enabled?: boolean
anonymousActionType?: string
store?: string
}
type Devtools = <
T,
Mps extends [StoreMutatorIdentifier, unknown][] = [],
Mcs extends [StoreMutatorIdentifier, unknown][] = [],
U = T,
>(
initializer: StateCreator<T, [...Mps, ['zustand/devtools', never]], Mcs, U>,
devtoolsOptions?: DevtoolsOptions,
) => StateCreator<T, Mps, [['zustand/devtools', never], ...Mcs]>
declare module '../vanilla' {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
interface StoreMutators<S, A> {
'zustand/devtools': WithDevtools<S>
}
}
type DevtoolsImpl = <T>(
storeInitializer: StateCreator<T, [], []>,
devtoolsOptions?: DevtoolsOptions,
) => StateCreator<T, [], []>
export type NamedSet<T> = WithDevtools<StoreApi<T>>['setState']
type Connection = ReturnType<
NonNullable<Window['__REDUX_DEVTOOLS_EXTENSION__']>['connect']
>
type ConnectionName = string | undefined
type StoreName = string
type StoreInformation = StoreApi<unknown>
type ConnectionInformation = {
connection: Connection
stores: Record<StoreName, StoreInformation>
}
const trackedConnections: Map<ConnectionName, ConnectionInformation> = new Map()
const getTrackedConnectionState = (
name: string | undefined,
): Record<string, any> => {
const api = trackedConnections.get(name)
if (!api) return {}
return Object.fromEntries(
Object.entries(api.stores).map(([key, api]) => [key, api.getState()]),
)
}
const extractConnectionInformation = (
store: string | undefined,
extensionConnector: NonNullable<
(typeof window)['__REDUX_DEVTOOLS_EXTENSION__']
>,
options: Omit<DevtoolsOptions, 'enabled' | 'anonymousActionType' | 'store'>,
) => {
if (store === undefined) {
return {
type: 'untracked' as const,
connection: extensionConnector.connect(options),
}
}
const existingConnection = trackedConnections.get(options.name)
if (existingConnection) {
return { type: 'tracked' as const, store, ...existingConnection }
}
const newConnection: ConnectionInformation = {
connection: extensionConnector.connect(options),
stores: {},
}
trackedConnections.set(options.name, newConnection)
return { type: 'tracked' as const, store, ...newConnection }
}
const removeStoreFromTrackedConnections = (
name: string | undefined,
store: string | undefined,
) => {
if (store === undefined) return
const connectionInfo = trackedConnections.get(name)
if (!connectionInfo) return
delete connectionInfo.stores[store]
if (Object.keys(connectionInfo.stores).length === 0) {
trackedConnections.delete(name)
}
}
const devtoolsImpl: DevtoolsImpl =
(fn, devtoolsOptions = {}) =>
(set, get, api) => {
const { enabled, anonymousActionType, store, ...options } = devtoolsOptions
type S = ReturnType<typeof fn> & {
[store: string]: ReturnType<typeof fn>
}
type PartialState = Partial<S> | ((s: S) => Partial<S>)
let extensionConnector:
| (typeof window)['__REDUX_DEVTOOLS_EXTENSION__']
| false
try {
extensionConnector =
(enabled ?? import.meta.env?.MODE !== 'production') &&
window.__REDUX_DEVTOOLS_EXTENSION__
} catch {
// ignored
}
if (!extensionConnector) {
return fn(set, get, api)
}
const { connection, ...connectionInformation } =
extractConnectionInformation(store, extensionConnector, options)
let isRecording = true
;(api.setState as any) = ((state, replace, nameOrAction: Action) => {
const r = set(state, replace as any)
if (!isRecording) return r
const action: { type: string } =
nameOrAction === undefined
? { type: anonymousActionType || 'anonymous' }
: typeof nameOrAction === 'string'
? { type: nameOrAction }
: nameOrAction
if (store === undefined) {
connection?.send(action, get())
return r
}
connection?.send(
{
...action,
type: `${store}/${action.type}`,
},
{
...getTrackedConnectionState(options.name),
[store]: api.getState(),
},
)
return r
}) as NamedSet<S>
;(api as StoreApi<S> & StoreDevtools<S>).devtools = {
cleanup: () => {
if (
connection &&
typeof (connection as any).unsubscribe === 'function'
) {
;(connection as any).unsubscribe()
}
removeStoreFromTrackedConnections(options.name, store)
},
}
const setStateFromDevtools: StoreApi<S>['setState'] = (...a) => {
const originalIsRecording = isRecording
isRecording = false
set(...(a as Parameters<typeof set>))
isRecording = originalIsRecording
}
const initialState = fn(api.setState, get, api)
if (connectionInformation.type === 'untracked') {
connection?.init(initialState)
} else {
connectionInformation.stores[connectionInformation.store] = api
connection?.init(
Object.fromEntries(
Object.entries(connectionInformation.stores).map(([key, store]) => [
key,
key === connectionInformation.store
? initialState
: store.getState(),
]),
),
)
}
if (
(api as any).dispatchFromDevtools &&
typeof (api as any).dispatch === 'function'
) {
let didWarnAboutReservedActionType = false
const originalDispatch = (api as any).dispatch
;(api as any).dispatch = (...args: any[]) => {
if (
import.meta.env?.MODE !== 'production' &&
args[0].type === '__setState' &&
!didWarnAboutReservedActionType
) {
console.warn(
'[zustand devtools middleware] "__setState" action type is reserved ' +
'to set state from the devtools. Avoid using it.',
)
didWarnAboutReservedActionType = true
}
;(originalDispatch as any)(...args)
}
}
;(
connection as unknown as {
// FIXME https://github.com/reduxjs/redux-devtools/issues/1097
subscribe: (
listener: (message: Message) => void,
) => (() => void) | undefined
}
).subscribe((message: any) => {
switch (message.type) {
case 'ACTION':
if (typeof message.payload !== 'string') {
console.error(
'[zustand devtools middleware] Unsupported action format',
)
return
}
return parseJsonThen<{ type: unknown; state?: PartialState }>(
message.payload,
(action) => {
if (action.type === '__setState') {
if (store === undefined) {
setStateFromDevtools(action.state as PartialState)
return
}
if (Object.keys(action.state as S).length !== 1) {
console.error(
`
[zustand devtools middleware] Unsupported __setState action format.
When using 'store' option in devtools(), the 'state' should have only one key, which is a value of 'store' that was passed in devtools(),
and value of this only key should be a state object. Example: { "type": "__setState", "state": { "abc123Store": { "foo": "bar" } } }
`,
)
}
const stateFromDevtools = (action.state as S)[store]
if (
stateFromDevtools === undefined ||
stateFromDevtools === null
) {
return
}
if (
JSON.stringify(api.getState()) !==
JSON.stringify(stateFromDevtools)
) {
setStateFromDevtools(stateFromDevtools)
}
return
}
if (!(api as any).dispatchFromDevtools) return
if (typeof (api as any).dispatch !== 'function') return
;(api as any).dispatch(action)
},
)
case 'DISPATCH':
switch (message.payload.type) {
case 'RESET':
setStateFromDevtools(initialState as S)
if (store === undefined) {
return connection?.init(api.getState())
}
return connection?.init(getTrackedConnectionState(options.name))
case 'COMMIT':
if (store === undefined) {
connection?.init(api.getState())
return
}
return connection?.init(getTrackedConnectionState(options.name))
case 'ROLLBACK':
return parseJsonThen<S>(message.state, (state) => {
if (store === undefined) {
setStateFromDevtools(state)
connection?.init(api.getState())
return
}
setStateFromDevtools(state[store] as S)
connection?.init(getTrackedConnectionState(options.name))
})
case 'JUMP_TO_STATE':
case 'JUMP_TO_ACTION':
return parseJsonThen<S>(message.state, (state) => {
if (store === undefined) {
setStateFromDevtools(state)
return
}
if (
JSON.stringify(api.getState()) !==
JSON.stringify(state[store])
) {
setStateFromDevtools(state[store] as S)
}
})
case 'IMPORT_STATE': {
const { nextLiftedState } = message.payload
const lastComputedState =
nextLiftedState.computedStates.slice(-1)[0]?.state
if (!lastComputedState) return
if (store === undefined) {
setStateFromDevtools(lastComputedState)
} else {
setStateFromDevtools(lastComputedState[store])
}
connection?.send(
null as any, // FIXME no-any
nextLiftedState,
)
return
}
case 'PAUSE_RECORDING':
return (isRecording = !isRecording)
}
return
}
})
return initialState
}
export const devtools = devtoolsImpl as unknown as Devtools
const parseJsonThen = <T>(stringified: string, fn: (parsed: T) => void) => {
let parsed: T | undefined
try {
parsed = JSON.parse(stringified)
} catch (e) {
console.error(
'[zustand devtools middleware] Could not parse the received json',
e,
)
}
if (parsed !== undefined) fn(parsed as T)
}