From feddc0c210c555fc58d9c9a7100545603f568cf5 Mon Sep 17 00:00:00 2001 From: Daishi Kato Date: Wed, 20 Aug 2025 08:08:35 +0900 Subject: [PATCH] feat(middleare/persist): return storage promise from setState (#3206) * feat(middleare/persist): return storage promise from setState * refactor types (technically breaking) * another breaking change in types * make public types not breaking --- src/middleware/persist.ts | 73 ++++++++++++++++++++++--------------- tests/persistAsync.test.tsx | 8 +++- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/middleware/persist.ts b/src/middleware/persist.ts index 357f3e62..5a2dfe5a 100644 --- a/src/middleware/persist.ts +++ b/src/middleware/persist.ts @@ -4,10 +4,10 @@ import type { StoreMutatorIdentifier, } from '../vanilla.ts' -export interface StateStorage { +export interface StateStorage { getItem: (name: string) => string | null | Promise - setItem: (name: string, value: string) => unknown | Promise - removeItem: (name: string) => unknown | Promise + setItem: (name: string, value: string) => R + removeItem: (name: string) => R } export type StorageValue = { @@ -15,12 +15,12 @@ export type StorageValue = { version?: number } -export interface PersistStorage { +export interface PersistStorage { getItem: ( name: string, ) => StorageValue | null | Promise | null> - setItem: (name: string, value: StorageValue) => unknown | Promise - removeItem: (name: string) => unknown | Promise + setItem: (name: string, value: StorageValue) => R + removeItem: (name: string) => R } type JsonStorageOptions = { @@ -28,18 +28,18 @@ type JsonStorageOptions = { replacer?: (key: string, value: unknown) => unknown } -export function createJSONStorage( - getStorage: () => StateStorage, +export function createJSONStorage( + getStorage: () => StateStorage, options?: JsonStorageOptions, -): PersistStorage | undefined { - let storage: StateStorage | undefined +): PersistStorage | undefined { + let storage: StateStorage | undefined try { storage = getStorage() } catch { // prevent error if the storage is not defined (e.g. when server side rendering a page) return } - const persistStorage: PersistStorage = { + const persistStorage: PersistStorage = { getItem: (name) => { const parse = (str: string | null) => { if (str === null) { @@ -60,7 +60,11 @@ export function createJSONStorage( return persistStorage } -export interface PersistOptions { +export interface PersistOptions< + S, + PersistedState = S, + PersistReturn = unknown, +> { /** Name of the storage (must be unique) */ name: string /** @@ -71,7 +75,7 @@ export interface PersistOptions { * * @default createJSONStorage(() => localStorage) */ - storage?: PersistStorage | undefined + storage?: PersistStorage | undefined /** * Filter the persisted value. * @@ -118,17 +122,28 @@ export interface PersistOptions { type PersistListener = (state: S) => void -type StorePersist = { - persist: { - setOptions: (options: Partial>) => void - clearStorage: () => void - rehydrate: () => Promise | void - hasHydrated: () => boolean - onHydrate: (fn: PersistListener) => () => void - onFinishHydration: (fn: PersistListener) => () => void - getOptions: () => Partial> +type StorePersist = S extends { + getState: () => infer T + setState: { + // capture both overloads of setState + (...args: infer Sa1): infer Sr1 + (...args: infer Sa2): infer Sr2 } } + ? { + setState(...args: Sa1): Sr1 | Pr + setState(...args: Sa2): Sr2 | Pr + persist: { + setOptions: (options: Partial>) => void + clearStorage: () => void + rehydrate: () => Promise | void + hasHydrated: () => boolean + onHydrate: (fn: PersistListener) => () => void + onFinishHydration: (fn: PersistListener) => () => void + getOptions: () => Partial> + } + } + : never type Thenable = { then( @@ -172,7 +187,7 @@ const toThenable = const persistImpl: PersistImpl = (config, baseOptions) => (set, get, api) => { type S = ReturnType let options = { - storage: createJSONStorage(() => localStorage), + storage: createJSONStorage(() => localStorage), partialize: (state: S) => state, version: 0, merge: (persistedState: unknown, currentState: S) => ({ @@ -202,7 +217,7 @@ const persistImpl: PersistImpl = (config, baseOptions) => (set, get, api) => { const setItem = () => { const state = options.partialize({ ...get() }) - return (storage as PersistStorage).setItem(options.name, { + return (storage as PersistStorage).setItem(options.name, { state, version: options.version, }) @@ -212,13 +227,13 @@ const persistImpl: PersistImpl = (config, baseOptions) => (set, get, api) => { api.setState = (state, replace) => { savedSetState(state, replace as any) - void setItem() + return setItem() } const configResult = config( (...args) => { set(...(args as Parameters)) - void setItem() + return setItem() }, get, api, @@ -307,7 +322,7 @@ const persistImpl: PersistImpl = (config, baseOptions) => (set, get, api) => { }) } - ;(api as StoreApi & StorePersist).persist = { + ;(api as StoreApi & StorePersist, S, unknown>).persist = { setOptions: (newOptions) => { options = { ...options, @@ -365,9 +380,7 @@ declare module '../vanilla' { type Write = Omit & U -type WithPersist = S extends { getState: () => infer T } - ? Write> - : never +type WithPersist = Write> type PersistImpl = ( storeInitializer: StateCreator, diff --git a/tests/persistAsync.test.tsx b/tests/persistAsync.test.tsx index 4e113389..37ded00b 100644 --- a/tests/persistAsync.test.tsx +++ b/tests/persistAsync.test.tsx @@ -165,7 +165,9 @@ describe('persist middleware with async configuration', () => { }) // Write something to the store - act(() => useBoundStore.setState({ count: 42 })) + act(() => { + useBoundStore.setState({ count: 42 }) + }) expect(await screen.findByText('count: 42')).toBeInTheDocument() expect(setItemSpy).toBeCalledWith( 'test-storage', @@ -788,7 +790,9 @@ describe('persist middleware with async configuration', () => { // Write something to the store const updatedMap = new Map(map).set('foo', 'bar') - act(() => useBoundStore.setState({ map: updatedMap })) + act(() => { + useBoundStore.setState({ map: updatedMap }) + }) expect(await screen.findByText('map-content: bar')).toBeInTheDocument() expect(setItemSpy).toBeCalledWith(