diff --git a/packages/spy/src/index.ts b/packages/spy/src/index.ts index e4a3e94c8..5ee596bec 100644 --- a/packages/spy/src/index.ts +++ b/packages/spy/src/index.ts @@ -154,7 +154,7 @@ export interface MockInstance { * const increment = vi.fn().mockImplementation(count => count + 1); * expect(increment(3)).toBe(4); */ - mockImplementation(fn: ((...args: TArgs) => TReturns) | (() => Promise)): this + mockImplementation(fn: ((...args: TArgs) => TReturns)): this /** * Accepts a function that will be used as a mock implementation during the next call. Can be chained so that multiple function calls produce different results. * @example @@ -162,7 +162,7 @@ export interface MockInstance { * expect(fn(3)).toBe(4); * expect(fn(3)).toBe(3); */ - mockImplementationOnce(fn: ((...args: TArgs) => TReturns) | (() => Promise)): this + mockImplementationOnce(fn: ((...args: TArgs) => TReturns)): this /** * Overrides the original mock implementation temporarily while the callback is being executed. * @example @@ -479,16 +479,16 @@ function enhanceSpy( stub.mockReturnValueOnce = (val: TReturns) => stub.mockImplementationOnce(() => val) stub.mockResolvedValue = (val: Awaited) => - stub.mockImplementation(() => Promise.resolve(val as TReturns)) + stub.mockImplementation(() => Promise.resolve(val as TReturns) as any) stub.mockResolvedValueOnce = (val: Awaited) => - stub.mockImplementationOnce(() => Promise.resolve(val as TReturns)) + stub.mockImplementationOnce(() => Promise.resolve(val as TReturns) as any) stub.mockRejectedValue = (val: unknown) => - stub.mockImplementation(() => Promise.reject(val)) + stub.mockImplementation(() => Promise.reject(val) as any) stub.mockRejectedValueOnce = (val: unknown) => - stub.mockImplementationOnce(() => Promise.reject(val)) + stub.mockImplementationOnce(() => Promise.reject(val) as any) Object.defineProperty(stub, 'mock', { get: () => mockContext, diff --git a/test/core/test/jest-mock.test.ts b/test/core/test/jest-mock.test.ts index e460a1a42..f6c2dd479 100644 --- a/test/core/test/jest-mock.test.ts +++ b/test/core/test/jest-mock.test.ts @@ -52,6 +52,27 @@ describe('jest mock compat layer', () => { expect(mock2.getMockImplementation()).toBeUndefined() }) + it('implementation types allow only function returned types', () => { + function fn() { + return 1 + } + + function asyncFn() { + return Promise.resolve(1) + } + + const mock1 = vi.fn(fn) + const mock2 = vi.fn(asyncFn) + + mock1.mockImplementation(() => 2) + // @ts-expect-error promise is not allowed + mock1.mockImplementation(() => Promise.resolve(2)) + + // @ts-expect-error non-promise is not allowed + mock2.mockImplementation(() => 2) + mock2.mockImplementation(() => Promise.resolve(2)) + }) + it('implementation sync fn', () => { const originalFn = function () { return 'original'