import {
  Component as ClassComponent,
  ReactNode,
  useEffect,
  useLayoutEffect,
  useState,
} from 'react'
import { act, fireEvent, render, waitFor } from '@testing-library/react'
import ReactDOM from 'react-dom'
import create, { StoreApi } from 'zustand'

/**
 * Basic behavior tests for zustand's `create`: the returned hook/api object,
 * selectors, equality functions, batched updates, errors thrown from
 * selectors/equality functions, getState/setState/destroy on the api, and
 * subscriber bookkeeping across mounts and unmounts.
 */

// The error tests replace console.error with jest.fn() to silence React's
// error logging; restore the real implementation after every test.
const consoleError = console.error
afterEach(() => {
  console.error = consoleError
})

/**
 * Minimal error boundary used by the error-propagation tests: renders
 * "errored" once any descendant throws during render.
 */
class ErrorBoundary extends ClassComponent<
  { children?: ReactNode | undefined },
  { hasError: boolean }
> {
  constructor(props: { children?: ReactNode | undefined }) {
    super(props)
    this.state = { hasError: false }
  }
  static getDerivedStateFromError() {
    return { hasError: true }
  }
  render() {
    return this.state.hasError ? <div>errored</div> : this.props.children
  }
}

it('creates a store hook and api object', () => {
  let params
  const result = create((...args) => {
    params = args
    return { value: null }
  })
  expect({ params, result }).toMatchInlineSnapshot(`
    Object {
      "params": Array [
        [Function],
        [Function],
        Object {
          "destroy": [Function],
          "getState": [Function],
          "setState": [Function],
          "subscribe": [Function],
        },
      ],
      "result": [Function],
    }
  `)
})

type CounterState = {
  count: number
  inc: () => void
}

it('uses the store with no args', async () => {
  const useBoundStore = create<CounterState>((set) => ({
    count: 0,
    inc: () => set((state) => ({ count: state.count + 1 })),
  }))

  function Counter() {
    const { count, inc } = useBoundStore()
    useEffect(inc, [inc])
    return <div>count: {count}</div>
  }

  const { findByText } = render(<Counter />)

  await findByText('count: 1')
})

it('uses the store with selectors', async () => {
  const useBoundStore = create<CounterState>((set) => ({
    count: 0,
    inc: () => set((state) => ({ count: state.count + 1 })),
  }))

  function Counter() {
    const count = useBoundStore((s) => s.count)
    const inc = useBoundStore((s) => s.inc)
    useEffect(inc, [inc])
    return <div>count: {count}</div>
  }

  const { findByText } = render(<Counter />)

  await findByText('count: 1')
})

it('uses the store with a selector and equality checker', async () => {
  const useBoundStore = create(() => ({ item: { value: 0 } }))
  const { setState } = useBoundStore
  let renderCount = 0

  function Component() {
    // Prevent re-render if new value === 1.
    const item = useBoundStore(
      (s) => s.item,
      (_, newItem) => newItem.value === 1
    )
    return (
      <div>
        renderCount: {++renderCount}, value: {item.value}
      </div>
    )
  }

  const { findByText } = render(<Component />)

  await findByText('renderCount: 1, value: 0')

  // This will not cause a re-render.
  act(() => setState({ item: { value: 1 } }))
  await findByText('renderCount: 1, value: 0')

  // This will cause a re-render.
  act(() => setState({ item: { value: 2 } }))
  await findByText('renderCount: 2, value: 2')
})

it('only re-renders if selected state has changed', async () => {
  const useBoundStore = create<CounterState>((set) => ({
    count: 0,
    inc: () => set((state) => ({ count: state.count + 1 })),
  }))
  let counterRenderCount = 0
  let controlRenderCount = 0

  function Counter() {
    const count = useBoundStore((state) => state.count)
    counterRenderCount++
    return <div>count: {count}</div>
  }

  function Control() {
    // `inc` is a stable function, so this component should never re-render.
    const inc = useBoundStore((state) => state.inc)
    controlRenderCount++
    return <button onClick={inc}>button</button>
  }

  const { getByText, findByText } = render(
    <>
      <Counter />
      <Control />
    </>
  )

  fireEvent.click(getByText('button'))

  await findByText('count: 1')

  expect(counterRenderCount).toBe(2)
  expect(controlRenderCount).toBe(1)
})

it('re-renders with useLayoutEffect', async () => {
  const useBoundStore = create(() => ({ state: false }))

  function Component() {
    const { state } = useBoundStore()
    useLayoutEffect(() => {
      useBoundStore.setState({ state: true })
    }, [])
    return <>{`${state}`}</>
  }

  const container = document.createElement('div')
  ReactDOM.render(<Component />, container)
  await waitFor(() => {
    expect(container.innerHTML).toBe('true')
  })
  ReactDOM.unmountComponentAtNode(container)
})

it('can batch updates', async () => {
  const useBoundStore = create<CounterState>((set) => ({
    count: 0,
    inc: () => set((state) => ({ count: state.count + 1 })),
  }))

  function Counter() {
    const { count, inc } = useBoundStore()
    useEffect(() => {
      ReactDOM.unstable_batchedUpdates(() => {
        inc()
        inc()
      })
    }, [inc])
    return <div>count: {count}</div>
  }

  const { findByText } = render(<Counter />)

  await findByText('count: 2')
})

it('can update the selector', async () => {
  type State = { one: string; two: string }
  type Props = { selector: (state: State) => string }
  const useBoundStore = create<State>(() => ({
    one: 'one',
    two: 'two',
  }))

  function Component({ selector }: Props) {
    return <div>{useBoundStore(selector)}</div>
  }

  const { findByText, rerender } = render(<Component selector={(s) => s.one} />)
  await findByText('one')

  rerender(<Component selector={(s) => s.two} />)
  await findByText('two')
})

it('can update the equality checker', async () => {
  type State = { value: number }
  type Props = { equalityFn: (a: State, b: State) => boolean }
  const useBoundStore = create<State>(() => ({ value: 0 }))
  const { setState } = useBoundStore
  const selector = (s: State) => s

  let renderCount = 0
  function Component({ equalityFn }: Props) {
    const { value } = useBoundStore(selector, equalityFn)
    return (
      <div>
        renderCount: {++renderCount}, value: {value}
      </div>
    )
  }

  // Set an equality checker that always returns false to always re-render.
  const { findByText, rerender } = render(
    <Component equalityFn={() => false} />
  )

  // This will cause a re-render due to the equality checker.
  act(() => setState({ value: 0 }))
  await findByText('renderCount: 2, value: 0')

  // Set an equality checker that always returns true to never re-render.
  rerender(<Component equalityFn={() => true} />)

  // This will NOT cause a re-render due to the equality checker.
  act(() => setState({ value: 1 }))
  await findByText('renderCount: 3, value: 0')
})

it('can call useBoundStore with progressively more arguments', async () => {
  type State = { value: number }
  type Props = {
    selector?: (state: State) => number
    equalityFn?: (a: number, b: number) => boolean
  }

  const useBoundStore = create<State>(() => ({ value: 0 }))
  const { setState } = useBoundStore

  let renderCount = 0
  function Component({ selector, equalityFn }: Props) {
    // `as any` lets the test pass an undefined selector (hook called with no args).
    const value = useBoundStore(selector as any, equalityFn)
    return (
      <div>
        renderCount: {++renderCount}, value: {JSON.stringify(value)}
      </div>
    )
  }

  // Render with no args.
  const { findByText, rerender } = render(<Component />)
  await findByText('renderCount: 1, value: {"value":0}')

  // Render with selector.
  rerender(<Component selector={(s) => s.value} />)
  await findByText('renderCount: 2, value: 0')

  // Render with selector and equality checker.
  rerender(
    <Component
      selector={(s) => s.value}
      equalityFn={(oldV, newV) => oldV > newV}
    />
  )

  // Should not cause a re-render because new value is less than previous.
  act(() => setState({ value: -1 }))
  await findByText('renderCount: 3, value: 0')

  act(() => setState({ value: 1 }))
  await findByText('renderCount: 4, value: 1')
})

it('can throw an error in selector', async () => {
  console.error = jest.fn()
  type State = { value: string | number }

  const initialState: State = { value: 'foo' }
  const useBoundStore = create<State>(() => initialState)
  const { setState } = useBoundStore
  const selector = (s: State) =>
    // @ts-expect-error This function is supposed to throw an error
    s.value.toUpperCase()

  function Component() {
    useBoundStore(selector)
    return <div>no error</div>
  }

  const { findByText } = render(
    <ErrorBoundary>
      <Component />
    </ErrorBoundary>
  )
  await findByText('no error')

  act(() => {
    setState({ value: 123 })
  })
  await findByText('errored')
})

it('can throw an error in equality checker', async () => {
  console.error = jest.fn()
  type State = { value: string | number }

  const initialState: State = { value: 'foo' }
  const useBoundStore = create<State>(() => initialState)
  const { setState } = useBoundStore
  const selector = (s: State) => s
  const equalityFn = (a: State, b: State) =>
    // @ts-expect-error This function is supposed to throw an error
    a.value.trim() === b.value.trim()

  function Component() {
    useBoundStore(selector, equalityFn)
    return <div>no error</div>
  }

  const { findByText } = render(
    <ErrorBoundary>
      <Component />
    </ErrorBoundary>
  )
  await findByText('no error')

  act(() => {
    setState({ value: 123 })
  })
  await findByText('errored')
})

it('can get the store', () => {
  type State = {
    value: number
    getState1: () => State
    getState2: () => State
  }

  const { getState } = create<State>((_, get) => ({
    value: 1,
    getState1: () => get(),
    getState2: (): State => getState(),
  }))

  expect(getState().getState1().value).toBe(1)
  expect(getState().getState2().value).toBe(1)
})

it('can set the store', () => {
  type State = {
    value: number
    setState1: StoreApi<State>['setState']
    setState2: StoreApi<State>['setState']
  }

  const { setState, getState } = create<State>((set) => ({
    value: 1,
    setState1: (v) => set(v),
    setState2: (v) => setState(v),
  }))

  getState().setState1({ value: 2 })
  expect(getState().value).toBe(2)
  getState().setState2({ value: 3 })
  expect(getState().value).toBe(3)
  // Updaters return a new partial instead of mutating the previous state.
  getState().setState1((s) => ({ value: s.value + 1 }))
  expect(getState().value).toBe(4)
  getState().setState2((s) => ({ value: s.value + 1 }))
  expect(getState().value).toBe(5)
})

it('can set the store without merging', () => {
  const { setState, getState } = create<{ a: number } | { b: number }>(
    (_set) => ({
      a: 1,
    })
  )

  // Should override the state instead of merging.
  setState({ b: 2 }, true)
  expect(getState()).toEqual({ b: 2 })
})

it('can destroy the store', () => {
  const { destroy, getState, setState, subscribe } = create(() => ({
    value: 1,
  }))

  subscribe(() => {
    throw new Error('did not clear listener on destroy')
  })
  destroy()

  // If destroy() left the listener registered, this setState would throw.
  setState({ value: 2 })
  expect(getState().value).toEqual(2)
})

it('only calls selectors when necessary', async () => {
  type State = { a: number; b: number }
  const useBoundStore = create<State>(() => ({ a: 0, b: 0 }))
  const { setState } = useBoundStore
  let inlineSelectorCallCount = 0
  let staticSelectorCallCount = 0

  function staticSelector(s: State) {
    staticSelectorCallCount++
    return s.a
  }

  function Component() {
    useBoundStore((s) => (inlineSelectorCallCount++, s.b))
    useBoundStore(staticSelector)
    return (
      <>
        <div>inline: {inlineSelectorCallCount}</div>
        <div>static: {staticSelectorCallCount}</div>
      </>
    )
  }

  const { rerender, findByText } = render(<Component />)
  await findByText('inline: 1')
  await findByText('static: 1')

  rerender(<Component />)
  await findByText('inline: 2')
  await findByText('static: 1')

  act(() => setState({ a: 1, b: 1 }))
  await findByText('inline: 4')
  await findByText('static: 2')
})

it('ensures parent components subscribe before children', async () => {
  type State = {
    children: { [key: string]: { text: string } }
  }
  type Props = { id: string }
  const useBoundStore = create<State>(() => ({
    children: {
      '1': { text: 'child 1' },
      '2': { text: 'child 2' },
    },
  }))
  const api = useBoundStore

  function changeState() {
    api.setState({
      children: {
        '3': { text: 'child 3' },
      },
    })
  }

  function Child({ id }: Props) {
    const text = useBoundStore((s) => s.children[id]?.text)
    return <div>{text}</div>
  }

  function Parent() {
    const childStates = useBoundStore((s) => s.children)
    return (
      <>
        <button onClick={changeState}>change state</button>
        {Object.keys(childStates).map((id) => (
          <Child id={id} key={id} />
        ))}
      </>
    )
  }

  const { getByText, findByText } = render(<Parent />)

  fireEvent.click(getByText('change state'))

  await findByText('child 3')
})

// https://github.com/pmndrs/zustand/issues/84
it('ensures the correct subscriber is removed on unmount', async () => {
  const useBoundStore = create(() => ({ count: 0 }))
  const api = useBoundStore

  function increment() {
    api.setState(({ count }) => ({ count: count + 1 }))
  }

  function Count() {
    const c = useBoundStore((s) => s.count)
    return <div>count: {c}</div>
  }

  function CountWithInitialIncrement() {
    useLayoutEffect(increment, [])
    return <Count />
  }

  function Component() {
    // Swap the first child's component type after mount so its subscriber
    // unmounts while the sibling <Count /> stays subscribed.
    const [Counter, setCounter] = useState(() => CountWithInitialIncrement)
    useLayoutEffect(() => {
      setCounter(() => Count)
    }, [])
    return (
      <>
        <Counter />
        <Count />
      </>
    )
  }

  const { findAllByText } = render(<Component />)

  expect((await findAllByText('count: 1')).length).toBe(2)

  act(increment)

  expect((await findAllByText('count: 2')).length).toBe(2)
})

// https://github.com/pmndrs/zustand/issues/86
it('ensures a subscriber is not mistakenly overwritten', async () => {
  const useBoundStore = create(() => ({ count: 0 }))
  const { setState } = useBoundStore

  function Count1() {
    const c = useBoundStore((s) => s.count)
    return <div>count1: {c}</div>
  }

  function Count2() {
    const c = useBoundStore((s) => s.count)
    return <div>count2: {c}</div>
  }

  // Add 1st subscriber.
  const { findAllByText, rerender } = render(<Count1 />)

  // Replace 1st subscriber with another.
  rerender(<Count2 />)

  // Add 2 additional subscribers.
  rerender(
    <>
      <Count2 />
      <Count1 />
      <Count1 />
    </>
  )

  // Call all subscribers
  act(() => setState({ count: 1 }))

  expect((await findAllByText('count1: 1')).length).toBe(2)
  expect((await findAllByText('count2: 1')).length).toBe(1)
})