Fix Immer type inference for setState (#2696)

* fix(immer): tweak type inference to base `setState` type off of store `setState` instead of `getState`

* fix(immer): instead, infer type directly from StoreApi<T>["setState"]

* fix(immer): instead of using `StoreApi`, extract from A2 the non-functional component of state

* docs: add comment describing why it is not derived from `A1`

* test: add example middleware that modifies getState w/o setState

* fix: add assertion for inner `set` and `get` types

---------

Co-authored-by: Daishi Kato <dai-shi@users.noreply.github.com>
This commit is contained in:
Christian van der Loo 2024-08-26 21:10:47 -04:00 committed by GitHub
parent 42bbfcfb6b
commit d7345da7cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 5 deletions

View File

@ -32,10 +32,11 @@ type SkipTwo<T> = T extends { length: 0 }
? A
: never
type SetStateType<T extends unknown[]> = Exclude<T[0], (...args: any[]) => any>
type WithImmer<S> = Write<S, StoreImmer<S>>
type StoreImmer<S> = S extends {
getState: () => infer T
setState: infer SetState
}
? SetState extends {
@ -43,13 +44,21 @@ type StoreImmer<S> = S extends {
(...a: infer A2): infer Sr2
}
? {
// Ideally, we would want to infer the `nextStateOrUpdater` `T` type from the
// `A1` type, but this is infeasible since it is an intersection with
// a partial type.
setState(
nextStateOrUpdater: T | Partial<T> | ((state: Draft<T>) => void),
nextStateOrUpdater:
| SetStateType<A2>
| Partial<SetStateType<A2>>
| ((state: Draft<SetStateType<A2>>) => void),
shouldReplace?: false,
...a: SkipTwo<A1>
): Sr1
setState(
nextStateOrUpdater: T | ((state: Draft<T>) => void),
nextStateOrUpdater:
| SetStateType<A2>
| ((state: Draft<SetStateType<A2>>) => void),
shouldReplace: true,
...a: SkipTwo<A2>
): Sr2

View File

@ -1,9 +1,9 @@
/* eslint @typescript-eslint/no-unused-expressions: off */ // FIXME
/* eslint react-compiler/react-compiler: off */
import { describe, expect, it } from 'vitest'
import { describe, expect, expectTypeOf, it } from 'vitest'
import { create } from 'zustand'
import type { StoreApi } from 'zustand'
import type { StateCreator, StoreApi, StoreMutatorIdentifier } from 'zustand'
import {
combine,
devtools,
@ -19,6 +19,27 @@ type CounterState = {
inc: () => void
}
type ExampleStateCreator<T, A> = <
Mps extends [StoreMutatorIdentifier, unknown][] = [],
Mcs extends [StoreMutatorIdentifier, unknown][] = [],
U = T,
>(
f: StateCreator<T, [...Mps, ['org/example', A]], Mcs>,
) => StateCreator<T, Mps, [['org/example', A], ...Mcs], U & A>
type Write<T, U> = Omit<T, keyof U> & U
type StoreModifyAllButSetState<S, A> = S extends {
getState: () => infer T
}
? Omit<StoreApi<T & A>, 'setState'>
: never
declare module 'zustand/vanilla' {
interface StoreMutators<S, A> {
'org/example': Write<S, StoreModifyAllButSetState<S, A>>
}
}
describe('counter state spec (no middleware)', () => {
it('no middleware', () => {
const useBoundStore = create<CounterState>((set, get) => ({
@ -64,6 +85,39 @@ describe('counter state spec (single middleware)', () => {
immer(() => ({ count: 0 })),
)
expect(testSubtyping).toBeDefined()
const exampleMiddleware = ((initializer) =>
initializer) as ExampleStateCreator<CounterState, { additional: number }>
const testDerivedSetStateType = create<CounterState>()(
exampleMiddleware(
immer((set, get) => ({
count: 0,
inc: () =>
set((state) => {
state.count = get().count + 1
type OmitFn<T> = Exclude<T, (...args: any[]) => any>
expectTypeOf<
OmitFn<Parameters<typeof set>[0]>
>().not.toMatchTypeOf<{ additional: number }>()
expectTypeOf<ReturnType<typeof get>>().toMatchTypeOf<{
additional: number
}>()
}),
})),
),
)
expect(testDerivedSetStateType).toBeDefined()
// the type of the `getState` should include our new property
expectTypeOf(testDerivedSetStateType.getState()).toMatchTypeOf<{
additional: number
}>()
// the type of the `setState` should not include our new property
expectTypeOf<
Parameters<typeof testDerivedSetStateType.setState>[0]
>().not.toMatchTypeOf<{
additional: number
}>()
})
it('redux', () => {