diff --git a/packages/vitest/src/node/hoistMocks.ts b/packages/vitest/src/node/hoistMocks.ts index 61caeecda..0abdb0f2a 100644 --- a/packages/vitest/src/node/hoistMocks.ts +++ b/packages/vitest/src/node/hoistMocks.ts @@ -1,5 +1,7 @@ import MagicString from 'magic-string' -import type { CallExpression, Identifier, ImportDeclaration, VariableDeclaration, Node as _Node } from 'estree' +import type { AwaitExpression, CallExpression, Identifier, ImportDeclaration, VariableDeclaration, Node as _Node } from 'estree' + +// TODO: should use findNodeBefore, but it's not typed import { findNodeAround } from 'acorn-walk' import type { PluginContext } from 'rollup' import { esmWalker } from '@vitest/utils/ast' @@ -211,8 +213,9 @@ export function hoistMocks(code: string, id: string, parse: PluginContext['parse hoistedNodes.push(declarationNode) } else { - // hoist "vi.hoisted(() => {})" - hoistedNodes.push(node) + const awaitedExpression = findNodeAround(ast, node.start, 'AwaitExpression')?.node as Positioned | undefined + // hoist "await vi.hoisted(async () => {})" or "vi.hoisted(() => {})" + hoistedNodes.push(awaitedExpression?.argument === node ? awaitedExpression : node) } } } diff --git a/test/core/test/hoisted-async-simple.test.ts b/test/core/test/hoisted-async-simple.test.ts index 712be2351..8f2141e1d 100644 --- a/test/core/test/hoisted-async-simple.test.ts +++ b/test/core/test/hoisted-async-simple.test.ts @@ -6,6 +6,8 @@ import { value } from '../src/rely-on-hoisted' const globalValue = await vi.hoisted(async () => { // @ts-expect-error not typed global globalThis.someGlobalValue = 'globalValue' + // @ts-expect-error not typed global + ;(globalThis._order ??= []).push(1) return 'globalValue' }) @@ -14,6 +16,26 @@ afterAll(() => { delete globalThis.someGlobalValue }) +// _order is set in the hoisted function before tests are collected +// @ts-expect-error not typed global +expect(globalThis._order).toEqual([1, 2, 3]) + it('imported value is equal to returned from hoisted', () => { expect(value).toBe(globalValue) }) + +it('hoists async "vi.hoisted", but leaves the wrapper alone', async () => { + expect.assertions(1) + await (async () => { + expect(1).toBe(1) + vi.hoisted(() => { + // @ts-expect-error not typed global + ;(globalThis._order ??= []).push(2) + }) + })() +}) + +await vi.hoisted(async () => { + // @ts-expect-error not typed global + ;(globalThis._order ??= []).push(3) +})