chore: always require curly braces (#5885)

Co-authored-by: Ari Perkkiö <ari.perkkio@gmail.com>
This commit is contained in:
Vladimir 2024-06-16 18:10:10 +02:00 committed by GitHub
parent 66e648ff88
commit 471cf97b0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
386 changed files with 13900 additions and 6917 deletions

View File

@ -29,17 +29,21 @@ export const contributors = (contributorNames).reduce<Contributor[]>((acc, name)
function createLinks(tm: CoreTeam): CoreTeam { function createLinks(tm: CoreTeam): CoreTeam {
tm.links = [{ icon: 'github', link: `https://github.com/${tm.github}` }] tm.links = [{ icon: 'github', link: `https://github.com/${tm.github}` }]
if (tm.mastodon) if (tm.mastodon) {
tm.links.push({ icon: 'mastodon', link: tm.mastodon }) tm.links.push({ icon: 'mastodon', link: tm.mastodon })
}
if (tm.discord) if (tm.discord) {
tm.links.push({ icon: 'discord', link: tm.discord }) tm.links.push({ icon: 'discord', link: tm.discord })
}
if (tm.youtube) if (tm.youtube) {
tm.links.push({ icon: 'youtube', link: `https://www.youtube.com/@${tm.youtube}` }) tm.links.push({ icon: 'youtube', link: `https://www.youtube.com/@${tm.youtube}` })
}
if (tm.twitter) if (tm.twitter) {
tm.links.push({ icon: 'x', link: `https://twitter.com/${tm.twitter}` }) tm.links.push({ icon: 'x', link: `https://twitter.com/${tm.twitter}` })
}
return tm return tm
} }

View File

@ -19,18 +19,22 @@ function resolveOptions(options: CLIOptions<any>, parentName?: string) {
} }
function resolveCommand(name: string, config: CLIOption<any> | null): any { function resolveCommand(name: string, config: CLIOption<any> | null): any {
if (!config) if (!config) {
return null return null
}
let title = '`' let title = '`'
if (config.shorthand) if (config.shorthand) {
title += `-${config.shorthand}, ` title += `-${config.shorthand}, `
}
title += `--${config.alias || name}` title += `--${config.alias || name}`
if ('argument' in config) if ('argument' in config) {
title += ` ${config.argument}` title += ` ${config.argument}`
}
title += '`' title += '`'
if ('subcommands' in config && config.subcommands) if ('subcommands' in config && config.subcommands) {
return resolveOptions(config.subcommands, name) return resolveOptions(config.subcommands, name)
}
return { return {
title, title,

View File

@ -8,8 +8,9 @@ const dirAvatars = resolve(docsDir, 'public/user-avatars/')
const dirSponsors = resolve(docsDir, 'public/sponsors/') const dirSponsors = resolve(docsDir, 'public/sponsors/')
async function download(url: string, fileName: string) { async function download(url: string, fileName: string) {
if (existsSync(fileName)) if (existsSync(fileName)) {
return return
}
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.log('downloading', fileName) console.log('downloading', fileName)
try { try {
@ -20,15 +21,17 @@ async function download(url: string, fileName: string) {
} }
async function fetchAvatars() { async function fetchAvatars() {
if (!existsSync(dirAvatars)) if (!existsSync(dirAvatars)) {
await fsp.mkdir(dirAvatars, { recursive: true }) await fsp.mkdir(dirAvatars, { recursive: true })
}
await Promise.all([...teamEmeritiMembers, ...teamMembers].map(c => c.github).map(name => download(`https://github.com/${name}.png?size=100`, join(dirAvatars, `${name}.png`)))) await Promise.all([...teamEmeritiMembers, ...teamMembers].map(c => c.github).map(name => download(`https://github.com/${name}.png?size=100`, join(dirAvatars, `${name}.png`))))
} }
async function fetchSponsors() { async function fetchSponsors() {
if (!existsSync(dirSponsors)) if (!existsSync(dirSponsors)) {
await fsp.mkdir(dirSponsors, { recursive: true }) await fsp.mkdir(dirSponsors, { recursive: true })
}
await Promise.all([ await Promise.all([
download('https://cdn.jsdelivr.net/gh/antfu/static/sponsors.svg', join(dirSponsors, 'antfu.svg')), download('https://cdn.jsdelivr.net/gh/antfu/static/sponsors.svg', join(dirSponsors, 'antfu.svg')),
download('https://cdn.jsdelivr.net/gh/patak-dev/static/sponsors.svg', join(dirSponsors, 'patak-dev.svg')), download('https://cdn.jsdelivr.net/gh/patak-dev/static/sponsors.svg', join(dirSponsors, 'patak-dev.svg')),

View File

@ -10,8 +10,9 @@ import HomePage from '../components/HomePage.vue'
import Version from '../components/Version.vue' import Version from '../components/Version.vue'
import '@shikijs/vitepress-twoslash/style.css' import '@shikijs/vitepress-twoslash/style.css'
if (inBrowser) if (inBrowser) {
import('./pwa') import('./pwa')
}
export default { export default {
...Theme, ...Theme,

View File

@ -263,8 +263,9 @@ This matcher extracts assert value (e.g., `assert v is number`), so you can perf
import { expectTypeOf } from 'vitest' import { expectTypeOf } from 'vitest'
function assertNumber(v: any): asserts v is number { function assertNumber(v: any): asserts v is number {
if (typeof v !== 'number') if (typeof v !== 'number') {
throw new TypeError('Nope !') throw new TypeError('Nope !')
}
} }
expectTypeOf(assertNumber).asserts.toBeNumber() expectTypeOf(assertNumber).asserts.toBeNumber()

View File

@ -192,8 +192,9 @@ Opposite of `toBeDefined`, `toBeUndefined` asserts that the value _is_ equal to
import { expect, test } from 'vitest' import { expect, test } from 'vitest'
function getApplesFromStock(stock: string) { function getApplesFromStock(stock: string) {
if (stock === 'Bill') if (stock === 'Bill') {
return 13 return 13
}
} }
test('mary doesn\'t have a stock', () => { test('mary doesn\'t have a stock', () => {
@ -214,8 +215,9 @@ import { Stocks } from './stocks.js'
const stocks = new Stocks() const stocks = new Stocks()
stocks.sync('Bill') stocks.sync('Bill')
if (stocks.getInfo('Bill')) if (stocks.getInfo('Bill')) {
stocks.sell('apples', 'Bill') stocks.sell('apples', 'Bill')
}
``` ```
So if you want to test that `stocks.getInfo` will be truthy, you could write: So if you want to test that `stocks.getInfo` will be truthy, you could write:
@ -247,8 +249,9 @@ import { Stocks } from './stocks.js'
const stocks = new Stocks() const stocks = new Stocks()
stocks.sync('Bill') stocks.sync('Bill')
if (!stocks.stockFailed('Bill')) if (!stocks.stockFailed('Bill')) {
stocks.sell('apples', 'Bill') stocks.sell('apples', 'Bill')
}
``` ```
So if you want to test that `stocks.stockFailed` will be falsy, you could write: So if you want to test that `stocks.stockFailed` will be falsy, you could write:
@ -660,8 +663,9 @@ For example, if we want to test that `getFruitStock('pineapples')` throws, we co
import { expect, test } from 'vitest' import { expect, test } from 'vitest'
function getFruitStock(type: string) { function getFruitStock(type: string) {
if (type === 'pineapples') if (type === 'pineapples') {
throw new Error('Pineapples are not in stock') throw new Error('Pineapples are not in stock')
}
// Do some other stuff // Do some other stuff
} }
@ -1203,8 +1207,9 @@ For example, if you have a function that fails when you call it, you may use thi
import { expect, test } from 'vitest' import { expect, test } from 'vitest'
async function buyApples(id) { async function buyApples(id) {
if (!id) if (!id) {
throw new Error('no id') throw new Error('no id')
}
} }
test('buyApples throws an error when no id provided', async () => { test('buyApples throws an error when no id provided', async () => {
@ -1301,8 +1306,9 @@ For example, if we want to test that `build()` throws due to receiving directori
import { expect, test } from 'vitest' import { expect, test } from 'vitest'
async function build(dir) { async function build(dir) {
if (dir.includes('no-src')) if (dir.includes('no-src')) {
throw new Error(`${dir}/src does not exist`) throw new Error(`${dir}/src does not exist`)
}
} }
const errorDirs = [ const errorDirs = [
@ -1594,14 +1600,15 @@ function areAnagramsEqual(a: unknown, b: unknown): boolean | undefined {
const isAAnagramComparator = isAnagramComparator(a) const isAAnagramComparator = isAnagramComparator(a)
const isBAnagramComparator = isAnagramComparator(b) const isBAnagramComparator = isAnagramComparator(b)
if (isAAnagramComparator && isBAnagramComparator) if (isAAnagramComparator && isBAnagramComparator) {
return a.equals(b) return a.equals(b)
}
else if (isAAnagramComparator === isBAnagramComparator) else if (isAAnagramComparator === isBAnagramComparator) {
return undefined return undefined
}
else else {
return false return false
}
} }
expect.addEqualityTesters([areAnagramsEqual]) expect.addEqualityTesters([areAnagramsEqual])

View File

@ -618,8 +618,9 @@ You can also nest describe blocks if you have a hierarchy of tests or benchmarks
import { describe, expect, test } from 'vitest' import { describe, expect, test } from 'vitest'
function numberToCurrency(value: number | string) { function numberToCurrency(value: number | string) {
if (typeof value !== 'number') if (typeof value !== 'number') {
throw new Error('Value must be a number') throw new TypeError('Value must be a number')
}
return value.toFixed(2).toString().replace(/\B(?=(\d{3})+(?!\d))/g, ',') return value.toFixed(2).toString().replace(/\B(?=(\d{3})+(?!\d))/g, ',')
} }

View File

@ -673,8 +673,9 @@ let i = 0
setTimeout(() => console.log(++i)) setTimeout(() => console.log(++i))
const interval = setInterval(() => { const interval = setInterval(() => {
console.log(++i) console.log(++i)
if (i === 3) if (i === 3) {
clearInterval(interval) clearInterval(interval)
}
}, 50) }, 50)
vi.runAllTimers() vi.runAllTimers()
@ -818,8 +819,9 @@ test('Server started successfully', async () => {
await vi.waitFor( await vi.waitFor(
() => { () => {
if (!server.isReady) if (!server.isReady) {
throw new Error('Server not started') throw new Error('Server not started')
}
console.log('Server started') console.log('Server started')
}, },

View File

@ -2131,12 +2131,14 @@ export default defineConfig({
test: { test: {
onStackTrace(error: Error, { file }: ParsedStack): boolean | void { onStackTrace(error: Error, { file }: ParsedStack): boolean | void {
// If we've encountered a ReferenceError, show the whole stack. // If we've encountered a ReferenceError, show the whole stack.
if (error.name === 'ReferenceError') if (error.name === 'ReferenceError') {
return return
}
// Reject all frames from third party libraries. // Reject all frames from third party libraries.
if (file.includes('node_modules')) if (file.includes('node_modules')) {
return false return false
}
}, },
}, },
}) })

View File

@ -27,8 +27,9 @@ function purchase() {
const currentHour = new Date().getHours() const currentHour = new Date().getHours()
const [open, close] = businessHours const [open, close] = businessHours
if (currentHour > open && currentHour < close) if (currentHour > open && currentHour < close) {
return { message: 'Success' } return { message: 'Success' }
}
return { message: 'Error' } return { message: 'Error' }
} }
@ -194,8 +195,9 @@ export default {
{ {
name: 'virtual-modules', name: 'virtual-modules',
resolveId(id) { resolveId(id) {
if (id === '$app/forms') if (id === '$app/forms') {
return 'virtual:$app/forms' return 'virtual:$app/forms'
}
} }
} }
] ]

View File

@ -45,6 +45,7 @@ export default antfu(
'no-undef': 'off', 'no-undef': 'off',
'ts/no-invalid-this': 'off', 'ts/no-invalid-this': 'off',
'eslint-comments/no-unlimited-disable': 'off', 'eslint-comments/no-unlimited-disable': 'off',
'curly': ['error', 'all'],
// TODO: migrate and turn it back on // TODO: migrate and turn it back on
'ts/ban-types': 'off', 'ts/ban-types': 'off',

View File

@ -1,8 +1,9 @@
import { dev } from '$app/environment' import { dev } from '$app/environment'
export function add(a: number, b: number) { export function add(a: number, b: number) {
if (dev) if (dev) {
console.warn(`Adding ${a} and ${b}`) console.warn(`Adding ${a} and ${b}`)
}
return a + b return a + b
} }

View File

@ -55,6 +55,7 @@
"lint-staged": "^15.2.5", "lint-staged": "^15.2.5",
"magic-string": "^0.30.10", "magic-string": "^0.30.10",
"pathe": "^1.1.2", "pathe": "^1.1.2",
"prettier": "^2.8.8",
"rimraf": "^5.0.7", "rimraf": "^5.0.7",
"rollup": "^4.18.0", "rollup": "^4.18.0",
"rollup-plugin-dts": "^6.1.1", "rollup-plugin-dts": "^6.1.1",

View File

@ -19,12 +19,24 @@ export interface FsOptions {
flag?: string | number flag?: string | number
} }
export interface TypePayload { type: string } export interface TypePayload {
export interface PressPayload { press: string } type: string
export interface DownPayload { down: string } }
export interface UpPayload { up: string } export interface PressPayload {
press: string
}
export interface DownPayload {
down: string
}
export interface UpPayload {
up: string
}
export type SendKeysPayload = TypePayload | PressPayload | DownPayload | UpPayload export type SendKeysPayload =
| TypePayload
| PressPayload
| DownPayload
| UpPayload
export interface ScreenshotOptions { export interface ScreenshotOptions {
element?: Element element?: Element
@ -35,8 +47,15 @@ export interface ScreenshotOptions {
} }
export interface BrowserCommands { export interface BrowserCommands {
readFile: (path: string, options?: BufferEncoding | FsOptions) => Promise<string> readFile: (
writeFile: (path: string, content: string, options?: BufferEncoding | FsOptions & { mode?: number | string }) => Promise<void> path: string,
options?: BufferEncoding | FsOptions
) => Promise<string>
writeFile: (
path: string,
content: string,
options?: BufferEncoding | (FsOptions & { mode?: number | string })
) => Promise<void>
removeFile: (path: string) => Promise<void> removeFile: (path: string) => Promise<void>
sendKeys: (payload: SendKeysPayload) => Promise<void> sendKeys: (payload: SendKeysPayload) => Promise<void>
} }

View File

@ -9,7 +9,10 @@ import type {
declare module 'vitest/node' { declare module 'vitest/node' {
interface BrowserProviderOptions { interface BrowserProviderOptions {
launch?: LaunchOptions launch?: LaunchOptions
context?: Omit<BrowserContextOptions, 'ignoreHTTPSErrors' | 'serviceWorkers'> context?: Omit<
BrowserContextOptions,
'ignoreHTTPSErrors' | 'serviceWorkers'
>
} }
export interface BrowserCommandContext { export interface BrowserCommandContext {

View File

@ -33,35 +33,36 @@ const input = {
providers: './src/node/providers/index.ts', providers: './src/node/providers/index.ts',
} }
export default () => defineConfig([ export default () =>
{ defineConfig([
input, {
output: { input,
dir: 'dist', output: {
format: 'esm', dir: 'dist',
format: 'esm',
},
external,
plugins,
}, },
external, {
plugins, input: './src/client/context.ts',
}, output: {
{ file: 'dist/context.js',
input: './src/client/context.ts', format: 'esm',
output: { },
file: 'dist/context.js', plugins: [
format: 'esm', esbuild({
target: 'node18',
}),
],
}, },
plugins: [ {
esbuild({ input: input.index,
target: 'node18', output: {
}), file: 'dist/index.d.ts',
], format: 'esm',
}, },
{ external,
input: input.index, plugins: [dts()],
output: {
file: 'dist/index.d.ts',
format: 'esm',
}, },
external, ])
plugins: [dts()],
},
])

View File

@ -78,13 +78,20 @@ export type IframeChannelEvent =
| IframeChannelIncomingEvent | IframeChannelIncomingEvent
| IframeChannelOutgoingEvent | IframeChannelOutgoingEvent
export const channel = new BroadcastChannel(`vitest:${getBrowserState().contextId}`) export const channel = new BroadcastChannel(
`vitest:${getBrowserState().contextId}`,
)
export function waitForChannel(event: IframeChannelOutgoingEvent['type']) { export function waitForChannel(event: IframeChannelOutgoingEvent['type']) {
return new Promise<void>((resolve) => { return new Promise<void>((resolve) => {
channel.addEventListener('message', (e) => { channel.addEventListener(
if (e.data?.type === event) 'message',
resolve() (e) => {
}, { once: true }) if (e.data?.type === event) {
resolve()
}
},
{ once: true },
)
}) })
} }

View File

@ -9,9 +9,10 @@ const PAGE_TYPE = getBrowserState().type
export const PORT = import.meta.hot ? '51204' : location.port export const PORT = import.meta.hot ? '51204' : location.port
export const HOST = [location.hostname, PORT].filter(Boolean).join(':') export const HOST = [location.hostname, PORT].filter(Boolean).join(':')
export const SESSION_ID = PAGE_TYPE === 'orchestrator' export const SESSION_ID
? getBrowserState().contextId = PAGE_TYPE === 'orchestrator'
: crypto.randomUUID() ? getBrowserState().contextId
: crypto.randomUUID()
export const ENTRY_URL = `${ export const ENTRY_URL = `${
location.protocol === 'https:' ? 'wss:' : 'ws:' location.protocol === 'https:' ? 'wss:' : 'ws:'
}//${HOST}/__vitest_browser_api__?type=${PAGE_TYPE}&sessionId=${SESSION_ID}` }//${HOST}/__vitest_browser_api__?type=${PAGE_TYPE}&sessionId=${SESSION_ID}`
@ -27,7 +28,10 @@ export interface VitestBrowserClient {
waitForConnection: () => Promise<void> waitForConnection: () => Promise<void>
} }
export type BrowserRPC = BirpcReturn<WebSocketBrowserHandlers, WebSocketBrowserEvents> export type BrowserRPC = BirpcReturn<
WebSocketBrowserHandlers,
WebSocketBrowserEvents
>
function createClient() { function createClient() {
const autoReconnect = true const autoReconnect = true
@ -44,46 +48,55 @@ function createClient() {
let onMessage: Function let onMessage: Function
ctx.rpc = createBirpc<WebSocketBrowserHandlers, WebSocketBrowserEvents>({ ctx.rpc = createBirpc<WebSocketBrowserHandlers, WebSocketBrowserEvents>(
onCancel: setCancel, {
async startMocking(id: string) { onCancel: setCancel,
// @ts-expect-error not typed global async startMocking(id: string) {
if (typeof __vitest_mocker__ === 'undefined') // @ts-expect-error not typed global
throw new Error(`Cannot mock modules in the orchestrator process`) if (typeof __vitest_mocker__ === 'undefined') {
// @ts-expect-error not typed global throw new TypeError(
const mocker = __vitest_mocker__ as VitestBrowserClientMocker `Cannot mock modules in the orchestrator process`,
const exports = await mocker.resolve(id) )
return Object.keys(exports)
},
async createTesters(files: string[]) {
if (PAGE_TYPE !== 'orchestrator')
return
getBrowserState().createTesters?.(files)
},
}, {
post: msg => ctx.ws.send(msg),
on: fn => (onMessage = fn),
serialize: e => stringify(e, (_, v) => {
if (v instanceof Error) {
return {
name: v.name,
message: v.message,
stack: v.stack,
} }
} // @ts-expect-error not typed global
return v const mocker = __vitest_mocker__ as VitestBrowserClientMocker
}), const exports = await mocker.resolve(id)
deserialize: parse, return Object.keys(exports)
onTimeoutError(functionName) { },
throw new Error(`[vitest-browser]: Timeout calling "${functionName}"`) async createTesters(files: string[]) {
if (PAGE_TYPE !== 'orchestrator') {
return
}
getBrowserState().createTesters?.(files)
},
}, },
}) {
post: msg => ctx.ws.send(msg),
on: fn => (onMessage = fn),
serialize: e =>
stringify(e, (_, v) => {
if (v instanceof Error) {
return {
name: v.name,
message: v.message,
stack: v.stack,
}
}
return v
}),
deserialize: parse,
onTimeoutError(functionName) {
throw new Error(`[vitest-browser]: Timeout calling "${functionName}"`)
},
},
)
let openPromise: Promise<void> let openPromise: Promise<void>
function reconnect(reset = false) { function reconnect(reset = false) {
if (reset) if (reset) {
tries = reconnectTries tries = reconnectTries
}
ctx.ws = new WebSocket(ENTRY_URL) ctx.ws = new WebSocket(ENTRY_URL)
registerWS() registerWS()
} }
@ -91,10 +104,15 @@ function createClient() {
function registerWS() { function registerWS() {
openPromise = new Promise((resolve, reject) => { openPromise = new Promise((resolve, reject) => {
const timeout = setTimeout(() => { const timeout = setTimeout(() => {
reject(new Error(`Cannot connect to the server in ${connectTimeout / 1000} seconds`)) reject(
new Error(
`Cannot connect to the server in ${connectTimeout / 1000} seconds`,
),
)
}, connectTimeout)?.unref?.() }, connectTimeout)?.unref?.()
if (ctx.ws.OPEN === ctx.ws.readyState) if (ctx.ws.OPEN === ctx.ws.readyState) {
resolve() resolve()
}
// still have a listener even if it's already open to update tries // still have a listener even if it's already open to update tries
ctx.ws.addEventListener('open', () => { ctx.ws.addEventListener('open', () => {
tries = reconnectTries tries = reconnectTries
@ -107,8 +125,9 @@ function createClient() {
}) })
ctx.ws.addEventListener('close', () => { ctx.ws.addEventListener('close', () => {
tries -= 1 tries -= 1
if (autoReconnect && tries > 0) if (autoReconnect && tries > 0) {
setTimeout(reconnect, reconnectInterval) setTimeout(reconnect, reconnectInterval)
}
}) })
} }

View File

@ -1,32 +1,48 @@
import type { Task, WorkerGlobalState } from 'vitest' import type { Task, WorkerGlobalState } from 'vitest'
import type { BrowserPage, UserEvent, UserEventClickOptions } from '../../context' import type {
BrowserPage,
UserEvent,
UserEventClickOptions,
} from '../../context'
import type { BrowserRPC } from './client' import type { BrowserRPC } from './client'
import type { BrowserRunnerState } from './utils' import type { BrowserRunnerState } from './utils'
// this file should not import anything directly, only types // this file should not import anything directly, only types
function convertElementToXPath(element: Element) { function convertElementToXPath(element: Element) {
if (!element || !(element instanceof Element)) if (!element || !(element instanceof Element)) {
throw new Error(`Expected DOM element to be an instance of Element, received ${typeof element}`) throw new Error(
`Expected DOM element to be an instance of Element, received ${typeof element}`,
)
}
return getPathTo(element) return getPathTo(element)
} }
function getPathTo(element: Element): string { function getPathTo(element: Element): string {
if (element.id !== '') if (element.id !== '') {
return `id("${element.id}")` return `id("${element.id}")`
}
if (!element.parentNode || element === document.documentElement) if (!element.parentNode || element === document.documentElement) {
return element.tagName return element.tagName
}
let ix = 0 let ix = 0
const siblings = element.parentNode.childNodes const siblings = element.parentNode.childNodes
for (let i = 0; i < siblings.length; i++) { for (let i = 0; i < siblings.length; i++) {
const sibling = siblings[i] const sibling = siblings[i]
if (sibling === element) if (sibling === element) {
return `${getPathTo(element.parentNode as Element)}/${element.tagName}[${ix + 1}]` return `${getPathTo(element.parentNode as Element)}/${element.tagName}[${
if (sibling.nodeType === 1 && (sibling as Element).tagName === element.tagName) ix + 1
}]`
}
if (
sibling.nodeType === 1
&& (sibling as Element).tagName === element.tagName
) {
ix++ ix++
}
} }
return 'invalid xpath' return 'invalid xpath'
} }
@ -35,7 +51,9 @@ function getPathTo(element: Element): string {
const state = (): WorkerGlobalState => __vitest_worker__ const state = (): WorkerGlobalState => __vitest_worker__
// @ts-expect-error not typed global // @ts-expect-error not typed global
const runner = (): BrowserRunnerState => __vitest_browser_runner__ const runner = (): BrowserRunnerState => __vitest_browser_runner__
const filepath = () => state().filepath || state().current?.file?.filepath || undefined function filepath() {
return state().filepath || state().current?.file?.filepath || undefined
}
const rpc = () => state().rpc as any as BrowserRPC const rpc = () => state().rpc as any as BrowserRPC
const contextId = runner().contextId const contextId = runner().contextId
const channel = new BroadcastChannel(`vitest:${contextId}`) const channel = new BroadcastChannel(`vitest:${contextId}`)
@ -74,8 +92,9 @@ export const page: BrowserPage = {
}, },
async screenshot(options = {}) { async screenshot(options = {}) {
const currentTest = state().current const currentTest = state().current
if (!currentTest) if (!currentTest) {
throw new Error('Cannot take a screenshot outside of a test.') throw new Error('Cannot take a screenshot outside of a test.')
}
if (currentTest.concurrent) { if (currentTest.concurrent) {
throw new Error( throw new Error(
@ -92,16 +111,15 @@ export const page: BrowserPage = {
screenshotIds[repeatCount] ??= {} screenshotIds[repeatCount] ??= {}
screenshotIds[repeatCount][taskName] = number + 1 screenshotIds[repeatCount][taskName] = number + 1
const name = options.path || `${taskName.replace(/[^a-z0-9]/g, '-')}-${number}.png` const name
= options.path || `${taskName.replace(/[^a-z0-9]/g, '-')}-${number}.png`
return triggerCommand( return triggerCommand('__vitest_screenshot', name, {
'__vitest_screenshot', ...options,
name, element: options.element
{ ? convertElementToXPath(options.element)
...options, : undefined,
element: options.element ? convertElementToXPath(options.element) : undefined, })
},
)
}, },
} }

View File

@ -8,7 +8,9 @@ If needed, mock the \`${name}\` call manually like:
\`\`\` \`\`\`
import { expect, vi } from "vitest" import { expect, vi } from "vitest"
vi.spyOn(window, "${name}")${defaultValue ? `.mockReturnValue(${JSON.stringify(defaultValue)})` : ''} vi.spyOn(window, "${name}")${
defaultValue ? `.mockReturnValue(${JSON.stringify(defaultValue)})` : ''
}
${name}(${formatedParams}) ${name}(${formatedParams})
expect(${name}).toHaveBeenCalledWith(${formatedParams}) expect(${name}).toHaveBeenCalledWith(${formatedParams})
\`\`\``) \`\`\``)

View File

@ -4,23 +4,46 @@ import { getConfig, importId } from './utils'
const { Date, console } = globalThis const { Date, console } = globalThis
export async function setupConsoleLogSpy() { export async function setupConsoleLogSpy() {
const { stringify, format, inspect } = await importId('vitest/utils') as typeof import('vitest/utils') const { stringify, format, inspect } = (await importId(
const { log, info, error, dir, dirxml, trace, time, timeEnd, timeLog, warn, debug, count, countReset } = console 'vitest/utils',
)) as typeof import('vitest/utils')
const {
log,
info,
error,
dir,
dirxml,
trace,
time,
timeEnd,
timeLog,
warn,
debug,
count,
countReset,
} = console
const formatInput = (input: unknown) => { const formatInput = (input: unknown) => {
if (input instanceof Node) if (input instanceof Node) {
return stringify(input) return stringify(input)
}
return format(input) return format(input)
} }
const processLog = (args: unknown[]) => args.map(formatInput).join(' ') const processLog = (args: unknown[]) => args.map(formatInput).join(' ')
const sendLog = (type: 'stdout' | 'stderr', content: string, disableStack?: boolean) => { const sendLog = (
if (content.startsWith('[vite]')) type: 'stdout' | 'stderr',
content: string,
disableStack?: boolean,
) => {
if (content.startsWith('[vite]')) {
return return
}
const unknownTestId = '__vitest__unknown_test__' const unknownTestId = '__vitest__unknown_test__'
// @ts-expect-error untyped global // @ts-expect-error untyped global
const taskId = globalThis.__vitest_worker__?.current?.id ?? unknownTestId const taskId = globalThis.__vitest_worker__?.current?.id ?? unknownTestId
const origin = getConfig().printConsoleTrace && !disableStack const origin
? new Error('STACK_TRACE').stack?.split('\n').slice(1).join('\n') = getConfig().printConsoleTrace && !disableStack
: undefined ? new Error('STACK_TRACE').stack?.split('\n').slice(1).join('\n')
: undefined
rpc().sendLog({ rpc().sendLog({
origin, origin,
content, content,
@ -31,14 +54,18 @@ export async function setupConsoleLogSpy() {
size: content.length, size: content.length,
}) })
} }
const stdout = (base: (...args: unknown[]) => void) => (...args: unknown[]) => { const stdout
sendLog('stdout', processLog(args)) = (base: (...args: unknown[]) => void) =>
return base(...args) (...args: unknown[]) => {
} sendLog('stdout', processLog(args))
const stderr = (base: (...args: unknown[]) => void) => (...args: unknown[]) => { return base(...args)
sendLog('stderr', processLog(args)) }
return base(...args) const stderr
} = (base: (...args: unknown[]) => void) =>
(...args: unknown[]) => {
sendLog('stderr', processLog(args))
return base(...args)
}
console.log = stdout(log) console.log = stdout(log)
console.debug = stdout(debug) console.debug = stdout(debug)
console.info = stdout(info) console.info = stdout(info)
@ -77,10 +104,12 @@ export async function setupConsoleLogSpy() {
console.timeLog = (label = 'default') => { console.timeLog = (label = 'default') => {
timeLog(label) timeLog(label)
if (!(label in timeLabels)) if (!(label in timeLabels)) {
sendLog('stderr', `Timer "${label}" does not exist`) sendLog('stderr', `Timer "${label}" does not exist`)
else }
else {
sendLog('stdout', `${label}: ${timeLabels[label]} ms`) sendLog('stdout', `${label}: ${timeLabels[label]} ms`)
}
} }
console.timeEnd = (label = 'default') => { console.timeEnd = (label = 'default') => {

View File

@ -21,25 +21,30 @@ export class VitestBrowserClientMocker {
private spyModule!: SpyModule private spyModule!: SpyModule
setupWorker() { setupWorker() {
channel.addEventListener('message', async (e: MessageEvent<IframeChannelOutgoingEvent>) => { channel.addEventListener(
if (e.data.type === 'mock-factory:request') { 'message',
try { async (e: MessageEvent<IframeChannelOutgoingEvent>) => {
const module = await this.resolve(e.data.id) if (e.data.type === 'mock-factory:request') {
const exports = Object.keys(module) try {
channel.postMessage({ const module = await this.resolve(e.data.id)
type: 'mock-factory:response', const exports = Object.keys(module)
exports, channel.postMessage({
}) type: 'mock-factory:response',
exports,
})
}
catch (err: any) {
const { processError } = (await importId(
'vitest/browser',
)) as typeof import('vitest/browser')
channel.postMessage({
type: 'mock-factory:error',
error: processError(err),
})
}
} }
catch (err: any) { },
const { processError } = await importId('vitest/browser') as typeof import('vitest/browser') )
channel.postMessage({
type: 'mock-factory:error',
error: processError(err),
})
}
}
})
} }
public setSpyModule(mod: SpyModule) { public setSpyModule(mod: SpyModule) {
@ -48,8 +53,11 @@ export class VitestBrowserClientMocker {
public async importActual(id: string, importer: string) { public async importActual(id: string, importer: string) {
const resolved = await rpc().resolveId(id, importer) const resolved = await rpc().resolveId(id, importer)
if (resolved == null) if (resolved == null) {
throw new Error(`[vitest] Cannot resolve ${id} imported from ${importer}`) throw new Error(
`[vitest] Cannot resolve ${id} imported from ${importer}`,
)
}
const ext = extname(resolved.id) const ext = extname(resolved.id)
const url = new URL(`/@id/${resolved.id}`, location.href) const url = new URL(`/@id/${resolved.id}`, location.href)
const query = `_vitest_original&ext.${ext}` const query = `_vitest_original&ext.${ext}`
@ -61,14 +69,20 @@ export class VitestBrowserClientMocker {
public async importMock(rawId: string, importer: string) { public async importMock(rawId: string, importer: string) {
await this.prepare() await this.prepare()
const { resolvedId, type, mockPath } = await rpc().resolveMock(rawId, importer, false) const { resolvedId, type, mockPath } = await rpc().resolveMock(
rawId,
importer,
false,
)
const factoryReturn = this.get(resolvedId) const factoryReturn = this.get(resolvedId)
if (factoryReturn) if (factoryReturn) {
return factoryReturn return factoryReturn
}
if (this.factories[resolvedId]) if (this.factories[resolvedId]) {
return await this.resolve(resolvedId) return await this.resolve(resolvedId)
}
if (type === 'redirect') { if (type === 'redirect') {
const url = new URL(`/@id/${mockPath}`, location.href) const url = new URL(`/@id/${mockPath}`, location.href)
@ -90,8 +104,9 @@ export class VitestBrowserClientMocker {
public async invalidate() { public async invalidate() {
const ids = Array.from(this.ids) const ids = Array.from(this.ids)
if (!ids.length) if (!ids.length) {
return return
}
await rpc().invalidate(ids) await rpc().invalidate(ids)
channel.postMessage({ type: 'mock:invalidate' }) channel.postMessage({ type: 'mock:invalidate' })
this.ids.clear() this.ids.clear()
@ -102,8 +117,9 @@ export class VitestBrowserClientMocker {
public async resolve(id: string) { public async resolve(id: string) {
const factory = this.factories[id] const factory = this.factories[id]
if (!factory) if (!factory) {
throw new Error(`Cannot resolve ${id} mock: no factory provided`) throw new Error(`Cannot resolve ${id} mock: no factory provided`)
}
try { try {
this.mockObjects[id] = await factory() this.mockObjects[id] = await factory()
return this.mockObjects[id] return this.mockObjects[id]
@ -120,13 +136,15 @@ export class VitestBrowserClientMocker {
} }
public queueMock(id: string, importer: string, factory?: () => any) { public queueMock(id: string, importer: string, factory?: () => any) {
const promise = rpc().resolveMock(id, importer, !!factory) const promise = rpc()
.resolveMock(id, importer, !!factory)
.then(async ({ mockPath, resolvedId }) => { .then(async ({ mockPath, resolvedId }) => {
this.ids.add(resolvedId) this.ids.add(resolvedId)
const urlPaths = resolveMockPaths(resolvedId) const urlPaths = resolveMockPaths(resolvedId)
const resolvedMock = typeof mockPath === 'string' const resolvedMock
? new URL(resolvedMockedPath(mockPath), location.href).toString() = typeof mockPath === 'string'
: mockPath ? new URL(resolvedMockedPath(mockPath), location.href).toString()
: mockPath
urlPaths.forEach((url) => { urlPaths.forEach((url) => {
this.mocks[url] = resolvedMock this.mocks[url] = resolvedMock
this.factories[url] = factory! this.factories[url] = factory!
@ -137,17 +155,20 @@ export class VitestBrowserClientMocker {
mock: resolvedMock, mock: resolvedMock,
}) })
await waitForChannel('mock:done') await waitForChannel('mock:done')
}).finally(() => { })
.finally(() => {
this.queue.delete(promise) this.queue.delete(promise)
}) })
this.queue.add(promise) this.queue.add(promise)
} }
public queueUnmock(id: string, importer: string) { public queueUnmock(id: string, importer: string) {
const promise = rpc().resolveId(id, importer) const promise = rpc()
.resolveId(id, importer)
.then(async (resolved) => { .then(async (resolved) => {
if (!resolved) if (!resolved) {
return return
}
this.ids.delete(resolved.id) this.ids.delete(resolved.id)
const urlPaths = resolveMockPaths(resolved.id) const urlPaths = resolveMockPaths(resolved.id)
urlPaths.forEach((url) => { urlPaths.forEach((url) => {
@ -168,15 +189,17 @@ export class VitestBrowserClientMocker {
} }
public async prepare() { public async prepare() {
if (!this.queue.size) if (!this.queue.size) {
return return
await Promise.all([ }
...this.queue.values(), await Promise.all([...this.queue.values()])
])
} }
// TODO: move this logic into a util(?) // TODO: move this logic into a util(?)
public mockObject(object: Record<Key, any>, mockExports: Record<Key, any> = {}) { public mockObject(
object: Record<Key, any>,
mockExports: Record<Key, any> = {},
) {
const finalizers = new Array<() => void>() const finalizers = new Array<() => void>()
const refs = new RefTracker() const refs = new RefTracker()
@ -190,10 +213,16 @@ export class VitestBrowserClientMocker {
} }
} }
const mockPropertiesOf = (container: Record<Key, any>, newContainer: Record<Key, any>) => { const mockPropertiesOf = (
container: Record<Key, any>,
newContainer: Record<Key, any>,
) => {
const containerType = /* #__PURE__ */ getType(container) const containerType = /* #__PURE__ */ getType(container)
const isModule = containerType === 'Module' || !!container.__esModule const isModule = containerType === 'Module' || !!container.__esModule
for (const { key: property, descriptor } of getAllMockableProperties(container, isModule)) { for (const { key: property, descriptor } of getAllMockableProperties(
container,
isModule,
)) {
// Modules define their exports as getters. We want to process those. // Modules define their exports as getters. We want to process those.
if (!isModule && descriptor.get) { if (!isModule && descriptor.get) {
try { try {
@ -206,8 +235,9 @@ export class VitestBrowserClientMocker {
} }
// Skip special read-only props, we don't want to mess with those. // Skip special read-only props, we don't want to mess with those.
if (isSpecialProp(property, containerType)) if (isSpecialProp(property, containerType)) {
continue continue
}
const value = container[property] const value = container[property]
@ -215,7 +245,9 @@ export class VitestBrowserClientMocker {
// recursion in circular objects. // recursion in circular objects.
const refId = refs.getId(value) const refId = refs.getId(value)
if (refId !== undefined) { if (refId !== undefined) {
finalizers.push(() => define(newContainer, property, refs.getMockedValue(refId))) finalizers.push(() =>
define(newContainer, property, refs.getMockedValue(refId)),
)
continue continue
} }
@ -226,38 +258,54 @@ export class VitestBrowserClientMocker {
continue continue
} }
const isFunction = type.includes('Function') && typeof value === 'function' const isFunction
if ((!isFunction || value.__isMockFunction) && type !== 'Object' && type !== 'Module') { = type.includes('Function') && typeof value === 'function'
if (
(!isFunction || value.__isMockFunction)
&& type !== 'Object'
&& type !== 'Module'
) {
define(newContainer, property, value) define(newContainer, property, value)
continue continue
} }
// Sometimes this assignment fails for some unknown reason. If it does, // Sometimes this assignment fails for some unknown reason. If it does,
// just move along. // just move along.
if (!define(newContainer, property, isFunction ? value : {})) if (!define(newContainer, property, isFunction ? value : {})) {
continue continue
}
if (isFunction) { if (isFunction) {
const spyModule = this.spyModule const spyModule = this.spyModule
if (!spyModule) if (!spyModule) {
throw new Error('[vitest] `spyModule` is not defined. This is Vitest error. Please open a new issue with reproduction.') throw new Error(
'[vitest] `spyModule` is not defined. This is Vitest error. Please open a new issue with reproduction.',
)
}
function mockFunction(this: any) { function mockFunction(this: any) {
// detect constructor call and mock each instance's methods // detect constructor call and mock each instance's methods
// so that mock states between prototype/instances don't affect each other // so that mock states between prototype/instances don't affect each other
// (jest reference https://github.com/jestjs/jest/blob/2c3d2409879952157433de215ae0eee5188a4384/packages/jest-mock/src/index.ts#L678-L691) // (jest reference https://github.com/jestjs/jest/blob/2c3d2409879952157433de215ae0eee5188a4384/packages/jest-mock/src/index.ts#L678-L691)
if (this instanceof newContainer[property]) { if (this instanceof newContainer[property]) {
for (const { key, descriptor } of getAllMockableProperties(this, false)) { for (const { key, descriptor } of getAllMockableProperties(
this,
false,
)) {
// skip getter since it's not mocked on prototype as well // skip getter since it's not mocked on prototype as well
if (descriptor.get) if (descriptor.get) {
continue continue
}
const value = this[key] const value = this[key]
const type = /* #__PURE__ */ getType(value) const type = /* #__PURE__ */ getType(value)
const isFunction = type.includes('Function') && typeof value === 'function' const isFunction
= type.includes('Function') && typeof value === 'function'
if (isFunction) { if (isFunction) {
// mock and delegate calls to original prototype method, which should be also mocked already // mock and delegate calls to original prototype method, which should be also mocked already
const original = this[key] const original = this[key]
const mock = spyModule.spyOn(this, key as string).mockImplementation(original) const mock = spyModule
.spyOn(this, key as string)
.mockImplementation(original)
mock.mockRestore = () => { mock.mockRestore = () => {
mock.mockReset() mock.mockReset()
mock.mockImplementation(original) mock.mockImplementation(original)
@ -267,7 +315,9 @@ export class VitestBrowserClientMocker {
} }
} }
} }
const mock = spyModule.spyOn(newContainer, property).mockImplementation(mockFunction) const mock = spyModule
.spyOn(newContainer, property)
.mockImplementation(mockFunction)
mock.mockRestore = () => { mock.mockRestore = () => {
mock.mockReset() mock.mockReset()
mock.mockImplementation(mockFunction) mock.mockImplementation(mockFunction)
@ -286,17 +336,20 @@ export class VitestBrowserClientMocker {
mockPropertiesOf(object, mockedObject) mockPropertiesOf(object, mockedObject)
// Plug together refs // Plug together refs
for (const finalizer of finalizers) for (const finalizer of finalizers) {
finalizer() finalizer()
}
return mockedObject return mockedObject
} }
} }
function isSpecialProp(prop: Key, parentType: string) { function isSpecialProp(prop: Key, parentType: string) {
return parentType.includes('Function') return (
parentType.includes('Function')
&& typeof prop === 'string' && typeof prop === 'string'
&& ['arguments', 'callee', 'caller', 'length', 'name'].includes(prop) && ['arguments', 'callee', 'caller', 'length', 'name'].includes(prop)
)
} }
class RefTracker { class RefTracker {
@ -322,39 +375,56 @@ class RefTracker {
type Key = string | symbol type Key = string | symbol
export function getAllMockableProperties(obj: any, isModule: boolean) { export function getAllMockableProperties(obj: any, isModule: boolean) {
const allProps = new Map<string | symbol, { key: string | symbol; descriptor: PropertyDescriptor }>() const allProps = new Map<
string | symbol,
{ key: string | symbol; descriptor: PropertyDescriptor }
>()
let curr = obj let curr = obj
do { do {
// we don't need properties from these // we don't need properties from these
if (curr === Object.prototype || curr === Function.prototype || curr === RegExp.prototype) if (
curr === Object.prototype
|| curr === Function.prototype
|| curr === RegExp.prototype
) {
break break
}
collectOwnProperties(curr, (key) => { collectOwnProperties(curr, (key) => {
const descriptor = Object.getOwnPropertyDescriptor(curr, key) const descriptor = Object.getOwnPropertyDescriptor(curr, key)
if (descriptor) if (descriptor) {
allProps.set(key, { key, descriptor }) allProps.set(key, { key, descriptor })
}
}) })
// eslint-disable-next-line no-cond-assign // eslint-disable-next-line no-cond-assign
} while (curr = Object.getPrototypeOf(curr)) } while ((curr = Object.getPrototypeOf(curr)))
// default is not specified in ownKeys, if module is interoped // default is not specified in ownKeys, if module is interoped
if (isModule && !allProps.has('default') && 'default' in obj) { if (isModule && !allProps.has('default') && 'default' in obj) {
const descriptor = Object.getOwnPropertyDescriptor(obj, 'default') const descriptor = Object.getOwnPropertyDescriptor(obj, 'default')
if (descriptor) if (descriptor) {
allProps.set('default', { key: 'default', descriptor }) allProps.set('default', { key: 'default', descriptor })
}
} }
return Array.from(allProps.values()) return Array.from(allProps.values())
} }
function collectOwnProperties(obj: any, collector: Set<string | symbol> | ((key: string | symbol) => void)) { function collectOwnProperties(
const collect = typeof collector === 'function' ? collector : (key: string | symbol) => collector.add(key) obj: any,
collector: Set<string | symbol> | ((key: string | symbol) => void),
) {
const collect
= typeof collector === 'function'
? collector
: (key: string | symbol) => collector.add(key)
Object.getOwnPropertyNames(obj).forEach(collect) Object.getOwnPropertyNames(obj).forEach(collect)
Object.getOwnPropertySymbols(obj).forEach(collect) Object.getOwnPropertySymbols(obj).forEach(collect)
} }
function resolvedMockedPath(path: string) { function resolvedMockedPath(path: string) {
const config = getBrowserState().viteConfig const config = getBrowserState().viteConfig
if (path.startsWith(config.root)) if (path.startsWith(config.root)) {
return path.slice(config.root.length) return path.slice(config.root.length)
}
return path return path
} }
@ -365,11 +435,13 @@ function resolveMockPaths(path: string) {
const paths = [path, join('/@fs/', path)] const paths = [path, join('/@fs/', path)]
// URL can be /file/path.js, but path is resolved to /file/path // URL can be /file/path.js, but path is resolved to /file/path
if (path.startsWith(config.root)) if (path.startsWith(config.root)) {
paths.push(path.slice(config.root.length)) paths.push(path.slice(config.root.length))
}
if (path.startsWith(fsRoot)) if (path.startsWith(fsRoot)) {
paths.push(path.slice(fsRoot.length)) paths.push(path.slice(fsRoot.length))
}
return paths return paths
} }

View File

@ -1,6 +1,11 @@
import { http } from 'msw/core/http' import { http } from 'msw/core/http'
import { setupWorker } from 'msw/browser' import { setupWorker } from 'msw/browser'
import type { IframeChannelEvent, IframeMockEvent, IframeMockingDoneEvent, IframeUnmockEvent } from './channel' import type {
IframeChannelEvent,
IframeMockEvent,
IframeMockingDoneEvent,
IframeUnmockEvent,
} from './channel'
import { channel } from './channel' import { channel } from './channel'
import { client } from './client' import { client } from './client'
@ -10,8 +15,9 @@ export function createModuleMocker() {
const worker = setupWorker( const worker = setupWorker(
http.get(/.+/, async ({ request }) => { http.get(/.+/, async ({ request }) => {
const path = removeTimestamp(request.url.slice(location.origin.length)) const path = removeTimestamp(request.url.slice(location.origin.length))
if (!mocks.has(path)) if (!mocks.has(path)) {
return passthrough() return passthrough()
}
const mock = mocks.get(path) const mock = mocks.get(path)
@ -20,11 +26,14 @@ export function createModuleMocker() {
// TODO: check how the error looks // TODO: check how the error looks
const exports = await getFactoryExports(path) const exports = await getFactoryExports(path)
const module = `const module = __vitest_mocker__.get('${path}');` const module = `const module = __vitest_mocker__.get('${path}');`
const keys = exports.map((name) => { const keys = exports
if (name === 'default') .map((name) => {
return `export default module['default'];` if (name === 'default') {
return `export const ${name} = module['${name}'];` return `export default module['default'];`
}).join('\n') }
return `export const ${name} = module['${name}'];`
})
.join('\n')
const text = `${module}\n${keys}` const text = `${module}\n${keys}`
return new Response(text, { return new Response(text, {
headers: { headers: {
@ -33,8 +42,9 @@ export function createModuleMocker() {
}) })
} }
if (typeof mock === 'string') if (typeof mock === 'string') {
return Response.redirect(mock) return Response.redirect(mock)
}
const content = await client.rpc.automock(path) const content = await client.rpc.automock(path)
return new Response(content, { return new Response(content, {
@ -49,19 +59,23 @@ export function createModuleMocker() {
let startPromise: undefined | Promise<unknown> let startPromise: undefined | Promise<unknown>
async function init() { async function init() {
if (started) if (started) {
return return
if (startPromise) }
if (startPromise) {
return startPromise return startPromise
startPromise = worker.start({ }
serviceWorker: { startPromise = worker
url: '/__virtual_vitest__:mocker-worker.js', .start({
}, serviceWorker: {
quiet: true, url: '/__virtual_vitest__:mocker-worker.js',
}).finally(() => { },
started = true quiet: true,
startPromise = undefined })
}) .finally(() => {
started = true
startPromise = undefined
})
await startPromise await startPromise
} }
@ -88,16 +102,19 @@ function getFactoryExports(id: string) {
id, id,
}) })
return new Promise<string[]>((resolve, reject) => { return new Promise<string[]>((resolve, reject) => {
channel.addEventListener('message', function onMessage(e: MessageEvent<IframeChannelEvent>) { channel.addEventListener(
if (e.data.type === 'mock-factory:response') { 'message',
resolve(e.data.exports) function onMessage(e: MessageEvent<IframeChannelEvent>) {
channel.removeEventListener('message', onMessage) if (e.data.type === 'mock-factory:response') {
} resolve(e.data.exports)
if (e.data.type === 'mock-factory:error') { channel.removeEventListener('message', onMessage)
reject(e.data.error) }
channel.removeEventListener('message', onMessage) if (e.data.type === 'mock-factory:error') {
} reject(e.data.error)
}) channel.removeEventListener('message', onMessage)
}
},
)
}) })
} }

View File

@ -25,8 +25,9 @@ getBrowserState().createTesters = async (files) => {
function debug(...args: unknown[]) { function debug(...args: unknown[]) {
const debug = getConfig().env.VITEST_BROWSER_DEBUG const debug = getConfig().env.VITEST_BROWSER_DEBUG
if (debug && debug !== 'false') if (debug && debug !== 'false') {
client.rpc.debug(...args.map(String)) client.rpc.debug(...args.map(String))
}
} }
function createIframe(container: HTMLDivElement, file: string) { function createIframe(container: HTMLDivElement, file: string) {
@ -37,7 +38,12 @@ function createIframe(container: HTMLDivElement, file: string) {
const iframe = document.createElement('iframe') const iframe = document.createElement('iframe')
iframe.setAttribute('loading', 'eager') iframe.setAttribute('loading', 'eager')
iframe.setAttribute('src', `${url.pathname}__vitest_test__/__test__/${getBrowserState().contextId}/${encodeURIComponent(file)}`) iframe.setAttribute(
'src',
`${url.pathname}__vitest_test__/__test__/${
getBrowserState().contextId
}/${encodeURIComponent(file)}`,
)
iframe.setAttribute('data-vitest', 'true') iframe.setAttribute('data-vitest', 'true')
iframe.style.display = 'block' iframe.style.display = 'block'
@ -84,89 +90,106 @@ client.ws.addEventListener('open', async () => {
const mocker = createModuleMocker() const mocker = createModuleMocker()
channel.addEventListener('message', async (e: MessageEvent<IframeChannelIncomingEvent>): Promise<void> => { channel.addEventListener(
debug('channel event', JSON.stringify(e.data)) 'message',
switch (e.data.type) { async (e: MessageEvent<IframeChannelIncomingEvent>): Promise<void> => {
case 'viewport': { debug('channel event', JSON.stringify(e.data))
const { width, height, id } = e.data switch (e.data.type) {
const iframe = iframes.get(id) case 'viewport': {
if (!iframe) { const { width, height, id } = e.data
const error = new Error(`Cannot find iframe with id ${id}`) const iframe = iframes.get(id)
channel.postMessage({ type: 'viewport:fail', id, error: error.message }) if (!iframe) {
await client.rpc.onUnhandledError({ const error = new Error(`Cannot find iframe with id ${id}`)
name: 'Teardown Error', channel.postMessage({
message: error.message, type: 'viewport:fail',
}, 'Teardown Error') id,
return error: error.message,
} })
await setIframeViewport(iframe, width, height) await client.rpc.onUnhandledError(
channel.postMessage({ type: 'viewport:done', id }) {
break name: 'Teardown Error',
} message: error.message,
case 'done': { },
const filenames = e.data.filenames 'Teardown Error',
filenames.forEach(filename => runningFiles.delete(filename)) )
return
if (!runningFiles.size) {
const ui = getUiAPI()
// in isolated mode we don't change UI because it will slow down tests,
// so we only select it when the run is done
if (ui && filenames.length > 1) {
const id = generateFileId(filenames[filenames.length - 1])
ui.setCurrentFileId(id)
} }
await done() await setIframeViewport(iframe, width, height)
channel.postMessage({ type: 'viewport:done', id })
break
} }
else { case 'done': {
// keep the last iframe const filenames = e.data.filenames
const iframeId = e.data.id filenames.forEach(filename => runningFiles.delete(filename))
iframes.get(iframeId)?.remove()
iframes.delete(iframeId)
}
break
}
// error happened at the top level, this should never happen in user code, but it can trigger during development
case 'error': {
const iframeId = e.data.id
iframes.delete(iframeId)
await client.rpc.onUnhandledError(e.data.error, e.data.errorType)
if (iframeId === ID_ALL)
runningFiles.clear()
else
runningFiles.delete(iframeId)
if (!runningFiles.size)
await done()
break
}
case 'mock:invalidate':
mocker.invalidate()
break
case 'unmock':
await mocker.unmock(e.data)
break
case 'mock':
await mocker.mock(e.data)
break
case 'mock-factory:error':
case 'mock-factory:response':
// handled manually
break
default: {
e.data satisfies never
await client.rpc.onUnhandledError({ if (!runningFiles.size) {
name: 'Unexpected Event', const ui = getUiAPI()
message: `Unexpected event: ${(e.data as any).type}`, // in isolated mode we don't change UI because it will slow down tests,
}, 'Unexpected Event') // so we only select it when the run is done
await done() if (ui && filenames.length > 1) {
const id = generateFileId(filenames[filenames.length - 1])
ui.setCurrentFileId(id)
}
await done()
}
else {
// keep the last iframe
const iframeId = e.data.id
iframes.get(iframeId)?.remove()
iframes.delete(iframeId)
}
break
}
// error happened at the top level, this should never happen in user code, but it can trigger during development
case 'error': {
const iframeId = e.data.id
iframes.delete(iframeId)
await client.rpc.onUnhandledError(e.data.error, e.data.errorType)
if (iframeId === ID_ALL) {
runningFiles.clear()
}
else {
runningFiles.delete(iframeId)
}
if (!runningFiles.size) {
await done()
}
break
}
case 'mock:invalidate':
mocker.invalidate()
break
case 'unmock':
await mocker.unmock(e.data)
break
case 'mock':
await mocker.mock(e.data)
break
case 'mock-factory:error':
case 'mock-factory:response':
// handled manually
break
default: {
e.data satisfies never
await client.rpc.onUnhandledError(
{
name: 'Unexpected Event',
message: `Unexpected event: ${(e.data as any).type}`,
},
'Unexpected Event',
)
await done()
}
} }
} },
}) )
// if page was refreshed, there will be no test files // if page was refreshed, there will be no test files
// createTesters will be called again when tests are running in the UI // createTesters will be called again when tests are running in the UI
if (testFiles.length) if (testFiles.length) {
await createTesters(testFiles) await createTesters(testFiles)
}
}) })
async function createTesters(testFiles: string[]) { async function createTesters(testFiles: string[]) {
@ -186,10 +209,7 @@ async function createTesters(testFiles: string[]) {
iframes.clear() iframes.clear()
if (config.isolate === false) { if (config.isolate === false) {
const iframe = createIframe( const iframe = createIframe(container, ID_ALL)
container,
ID_ALL,
)
await setIframeViewport(iframe, width, height) await setIframeViewport(iframe, width, height)
} }
@ -197,21 +217,21 @@ async function createTesters(testFiles: string[]) {
// otherwise, we need to wait for each iframe to finish before creating the next one // otherwise, we need to wait for each iframe to finish before creating the next one
// this is the most stable way to run tests in the browser // this is the most stable way to run tests in the browser
for (const file of testFiles) { for (const file of testFiles) {
const iframe = createIframe( const iframe = createIframe(container, file)
container,
file,
)
await setIframeViewport(iframe, width, height) await setIframeViewport(iframe, width, height)
await new Promise<void>((resolve) => { await new Promise<void>((resolve) => {
channel.addEventListener('message', function handler(e: MessageEvent<IframeChannelEvent>) { channel.addEventListener(
// done and error can only be triggered by the previous iframe 'message',
if (e.data.type === 'done' || e.data.type === 'error') { function handler(e: MessageEvent<IframeChannelEvent>) {
channel.removeEventListener('message', handler) // done and error can only be triggered by the previous iframe
resolve() if (e.data.type === 'done' || e.data.type === 'error') {
} channel.removeEventListener('message', handler)
}) resolve()
}
},
)
}) })
} }
} }
@ -224,7 +244,11 @@ function generateFileId(file: string) {
return generateHash(`${path}${project}`) return generateHash(`${path}${project}`)
} }
async function setIframeViewport(iframe: HTMLIFrameElement, width: number, height: number) { async function setIframeViewport(
iframe: HTMLIFrameElement,
width: number,
height: number,
) {
const ui = getUiAPI() const ui = getUiAPI()
if (ui) { if (ui) {
await ui.setIframeViewport(width, height) await ui.setIframeViewport(width, height)

View File

@ -1,18 +1,18 @@
const moduleCache = new Map() const moduleCache = new Map();
function wrapModule(module) { function wrapModule(module) {
if (typeof module === 'function') { if (typeof module === "function") {
const promise = new Promise((resolve, reject) => { const promise = new Promise((resolve, reject) => {
if (typeof __vitest_mocker__ === 'undefined') if (typeof __vitest_mocker__ === "undefined")
return module().then(resolve, reject) return module().then(resolve, reject);
__vitest_mocker__.prepare().finally(() => { __vitest_mocker__.prepare().finally(() => {
module().then(resolve, reject) module().then(resolve, reject);
}) });
}) });
moduleCache.set(promise, { promise, evaluated: false }) moduleCache.set(promise, { promise, evaluated: false });
return promise.finally(() => moduleCache.delete(promise)) return promise.finally(() => moduleCache.delete(promise));
} }
return module return module;
} }
window.__vitest_browser_runner__ = { window.__vitest_browser_runner__ = {
@ -23,25 +23,24 @@ window.__vitest_browser_runner__ = {
files: { __VITEST_FILES__ }, files: { __VITEST_FILES__ },
type: { __VITEST_TYPE__ }, type: { __VITEST_TYPE__ },
contextId: { __VITEST_CONTEXT_ID__ }, contextId: { __VITEST_CONTEXT_ID__ },
} };
const config = __vitest_browser_runner__.config const config = __vitest_browser_runner__.config;
if (config.testNamePattern) if (config.testNamePattern)
config.testNamePattern = parseRegexp(config.testNamePattern) config.testNamePattern = parseRegexp(config.testNamePattern);
function parseRegexp(input) { function parseRegexp(input) {
// Parse input // Parse input
const m = input.match(/(\/?)(.+)\1([a-z]*)/i) const m = input.match(/(\/?)(.+)\1([a-z]*)/i);
// match nothing // match nothing
if (!m) if (!m) return /$^/;
return /$^/
// Invalid flags // Invalid flags
if (m[3] && !/^(?!.*?(.).*?\1)[gmixXsuUAJ]+$/.test(m[3])) if (m[3] && !/^(?!.*?(.).*?\1)[gmixXsuUAJ]+$/.test(m[3]))
return RegExp(input) return RegExp(input);
// Create the regular expression // Create the regular expression
return new RegExp(m[2], m[3]) return new RegExp(m[2], m[3]);
} }

View File

@ -1,13 +1,12 @@
import type { import type { getSafeTimers } from '@vitest/utils'
getSafeTimers,
} from '@vitest/utils'
import { importId } from './utils' import { importId } from './utils'
import type { VitestBrowserClient } from './client' import type { VitestBrowserClient } from './client'
const { get } = Reflect const { get } = Reflect
function withSafeTimers(getTimers: typeof getSafeTimers, fn: () => void) { function withSafeTimers(getTimers: typeof getSafeTimers, fn: () => void) {
const { setTimeout, clearTimeout, setImmediate, clearImmediate } = getTimers() const { setTimeout, clearTimeout, setImmediate, clearImmediate }
= getTimers()
const currentSetTimeout = globalThis.setTimeout const currentSetTimeout = globalThis.setTimeout
const currentClearTimeout = globalThis.clearTimeout const currentClearTimeout = globalThis.clearTimeout
@ -34,28 +33,34 @@ function withSafeTimers(getTimers: typeof getSafeTimers, fn: () => void) {
const promises = new Set<Promise<unknown>>() const promises = new Set<Promise<unknown>>()
export async function rpcDone() { export async function rpcDone() {
if (!promises.size) if (!promises.size) {
return return
}
const awaitable = Array.from(promises) const awaitable = Array.from(promises)
return Promise.all(awaitable) return Promise.all(awaitable)
} }
export function createSafeRpc(client: VitestBrowserClient, getTimers: () => any): VitestBrowserClient['rpc'] { export function createSafeRpc(
client: VitestBrowserClient,
getTimers: () => any,
): VitestBrowserClient['rpc'] {
return new Proxy(client.rpc, { return new Proxy(client.rpc, {
get(target, p, handler) { get(target, p, handler) {
if (p === 'then') if (p === 'then') {
return return
}
const sendCall = get(target, p, handler) const sendCall = get(target, p, handler)
const safeSendCall = (...args: any[]) => withSafeTimers(getTimers, async () => { const safeSendCall = (...args: any[]) =>
const result = sendCall(...args) withSafeTimers(getTimers, async () => {
promises.add(result) const result = sendCall(...args)
try { promises.add(result)
return await result try {
} return await result
finally { }
promises.delete(result) finally {
} promises.delete(result)
}) }
})
safeSendCall.asEvent = sendCall.asEvent safeSendCall.asEvent = sendCall.asEvent
return safeSendCall return safeSendCall
}, },
@ -64,7 +69,9 @@ export function createSafeRpc(client: VitestBrowserClient, getTimers: () => any)
export async function loadSafeRpc(client: VitestBrowserClient) { export async function loadSafeRpc(client: VitestBrowserClient) {
// if importing /@id/ failed, we reload the page waiting until Vite prebundles it // if importing /@id/ failed, we reload the page waiting until Vite prebundles it
const { getSafeTimers } = await importId('vitest/utils') as typeof import('vitest/utils') const { getSafeTimers } = (await importId(
'vitest/utils',
)) as typeof import('vitest/utils')
return createSafeRpc(client, getSafeTimers) return createSafeRpc(client, getSafeTimers)
} }

View File

@ -10,18 +10,21 @@ interface BrowserRunnerOptions {
config: ResolvedConfig config: ResolvedConfig
} }
export const browserHashMap = new Map<string, [test: boolean, timstamp: string]>() export const browserHashMap = new Map<
string,
[test: boolean, timstamp: string]
>()
interface CoverageHandler { interface CoverageHandler {
takeCoverage: () => Promise<unknown> takeCoverage: () => Promise<unknown>
} }
export function createBrowserRunner( export function createBrowserRunner(
runnerClass: { new(config: ResolvedConfig): VitestRunner }, runnerClass: { new (config: ResolvedConfig): VitestRunner },
mocker: VitestBrowserClientMocker, mocker: VitestBrowserClientMocker,
state: WorkerGlobalState, state: WorkerGlobalState,
coverageModule: CoverageHandler | null, coverageModule: CoverageHandler | null,
): { new(options: BrowserRunnerOptions): VitestRunner } { ): { new (options: BrowserRunnerOptions): VitestRunner } {
return class BrowserTestRunner extends runnerClass implements VitestRunner { return class BrowserTestRunner extends runnerClass implements VitestRunner {
public config: ResolvedConfig public config: ResolvedConfig
hashMap = browserHashMap hashMap = browserHashMap
@ -101,9 +104,14 @@ export function createBrowserRunner(
let cachedRunner: VitestRunner | null = null let cachedRunner: VitestRunner | null = null
export async function initiateRunner(state: WorkerGlobalState, mocker: VitestBrowserClientMocker, config: ResolvedConfig) { export async function initiateRunner(
if (cachedRunner) state: WorkerGlobalState,
mocker: VitestBrowserClientMocker,
config: ResolvedConfig,
) {
if (cachedRunner) {
return cachedRunner return cachedRunner
}
const [ const [
{ VitestTestRunner, NodeBenchmarkRunner }, { VitestTestRunner, NodeBenchmarkRunner },
{ takeCoverageInsideWorker, loadDiffConfig, loadSnapshotSerializers }, { takeCoverageInsideWorker, loadDiffConfig, loadSnapshotSerializers },
@ -111,12 +119,16 @@ export async function initiateRunner(state: WorkerGlobalState, mocker: VitestBro
importId('vitest/runners') as Promise<typeof import('vitest/runners')>, importId('vitest/runners') as Promise<typeof import('vitest/runners')>,
importId('vitest/browser') as Promise<typeof import('vitest/browser')>, importId('vitest/browser') as Promise<typeof import('vitest/browser')>,
]) ])
const runnerClass = config.mode === 'test' ? VitestTestRunner : NodeBenchmarkRunner const runnerClass
= config.mode === 'test' ? VitestTestRunner : NodeBenchmarkRunner
const BrowserRunner = createBrowserRunner(runnerClass, mocker, state, { const BrowserRunner = createBrowserRunner(runnerClass, mocker, state, {
takeCoverage: () => takeCoverageInsideWorker(config.coverage, { executeId: importId }), takeCoverage: () =>
takeCoverageInsideWorker(config.coverage, { executeId: importId }),
}) })
if (!config.snapshotOptions.snapshotEnvironment) if (!config.snapshotOptions.snapshotEnvironment) {
config.snapshotOptions.snapshotEnvironment = new VitestBrowserSnapshotEnvironment() config.snapshotOptions.snapshotEnvironment
= new VitestBrowserSnapshotEnvironment()
}
const runner = new BrowserRunner({ const runner = new BrowserRunner({
config, config,
}) })
@ -131,22 +143,27 @@ export async function initiateRunner(state: WorkerGlobalState, mocker: VitestBro
} }
async function updateFilesLocations(files: File[]) { async function updateFilesLocations(files: File[]) {
const { loadSourceMapUtils } = await importId('vitest/utils') as typeof import('vitest/utils') const { loadSourceMapUtils } = (await importId(
'vitest/utils',
)) as typeof import('vitest/utils')
const { TraceMap, originalPositionFor } = await loadSourceMapUtils() const { TraceMap, originalPositionFor } = await loadSourceMapUtils()
const promises = files.map(async (file) => { const promises = files.map(async (file) => {
const result = await rpc().getBrowserFileSourceMap(file.filepath) const result = await rpc().getBrowserFileSourceMap(file.filepath)
if (!result) if (!result) {
return null return null
}
const traceMap = new TraceMap(result as any) const traceMap = new TraceMap(result as any)
function updateLocation(task: Task) { function updateLocation(task: Task) {
if (task.location) { if (task.location) {
const { line, column } = originalPositionFor(traceMap, task.location) const { line, column } = originalPositionFor(traceMap, task.location)
if (line != null && column != null) if (line != null && column != null) {
task.location = { line, column: task.each ? column : column + 1 } task.location = { line, column: task.each ? column : column + 1 }
}
} }
if ('tasks' in task) if ('tasks' in task) {
task.tasks.forEach(updateLocation) task.tasks.forEach(updateLocation)
}
} }
file.tasks.forEach(updateLocation) file.tasks.forEach(updateLocation)
return null return null

View File

@ -19,7 +19,14 @@
<script>{__VITEST_INJECTOR__}</script> <script>{__VITEST_INJECTOR__}</script>
{__VITEST_SCRIPTS__} {__VITEST_SCRIPTS__}
</head> </head>
<body style="width: 100%; height: 100%; transform: scale(1); transform-origin: left top;"> <body
style="
width: 100%;
height: 100%;
transform: scale(1);
transform-origin: left top;
"
>
<script type="module" src="/tester.ts"></script> <script type="module" src="/tester.ts"></script>
{__VITEST_APPEND__} {__VITEST_APPEND__}
</body> </body>

View File

@ -6,7 +6,11 @@ import { browserHashMap, initiateRunner } from './runner'
import { getBrowserState, getConfig, importId } from './utils' import { getBrowserState, getConfig, importId } from './utils'
import { loadSafeRpc } from './rpc' import { loadSafeRpc } from './rpc'
import { VitestBrowserClientMocker } from './mocker' import { VitestBrowserClientMocker } from './mocker'
import { registerUnexpectedErrors, registerUnhandledErrors, serializeError } from './unhandled' import {
registerUnexpectedErrors,
registerUnhandledErrors,
serializeError,
} from './unhandled'
const stopErrorHandler = registerUnhandledErrors() const stopErrorHandler = registerUnhandledErrors()
@ -15,26 +19,44 @@ const reloadStart = url.searchParams.get('__reloadStart')
function debug(...args: unknown[]) { function debug(...args: unknown[]) {
const debug = getConfig().env.VITEST_BROWSER_DEBUG const debug = getConfig().env.VITEST_BROWSER_DEBUG
if (debug && debug !== 'false') if (debug && debug !== 'false') {
client.rpc.debug(...args.map(String)) client.rpc.debug(...args.map(String))
}
} }
async function tryCall<T>(fn: () => Promise<T>): Promise<T | false | undefined> { async function tryCall<T>(
fn: () => Promise<T>,
): Promise<T | false | undefined> {
try { try {
return await fn() return await fn()
} }
catch (err: any) { catch (err: any) {
const now = Date.now() const now = Date.now()
// try for 30 seconds // try for 30 seconds
const canTry = !reloadStart || (now - Number(reloadStart) < 30_000) const canTry = !reloadStart || now - Number(reloadStart) < 30_000
const errorStack = (() => { const errorStack = (() => {
if (!err) if (!err) {
return null return null
return err.stack?.includes(err.message) ? err.stack : `${err.message}\n${err.stack}` }
return err.stack?.includes(err.message)
? err.stack
: `${err.message}\n${err.stack}`
})() })()
debug('failed to resolve runner', 'trying again:', canTry, 'time is', now, 'reloadStart is', reloadStart, ':\n', errorStack) debug(
'failed to resolve runner',
'trying again:',
canTry,
'time is',
now,
'reloadStart is',
reloadStart,
':\n',
errorStack,
)
if (!canTry) { if (!canTry) {
const error = serializeError(new Error('Vitest failed to load its runner after 30 seconds.')) const error = serializeError(
new Error('Vitest failed to load its runner after 30 seconds.'),
)
error.cause = serializeError(err) error.cause = serializeError(err)
await client.rpc.onUnhandledError(error, 'Preload Error') await client.rpc.onUnhandledError(error, 'Preload Error')
@ -114,14 +136,17 @@ async function prepareTestEnvironment(files: string[]) {
const version = url.searchParams.get('browserv') || '' const version = url.searchParams.get('browserv') || ''
files.forEach((filename) => { files.forEach((filename) => {
const currentVersion = browserHashMap.get(filename) const currentVersion = browserHashMap.get(filename)
if (!currentVersion || currentVersion[1] !== version) if (!currentVersion || currentVersion[1] !== version) {
browserHashMap.set(filename, [true, version]) browserHashMap.set(filename, [true, version])
}
}) })
const [runner, { startTests, setupCommonEnv, SpyModule }] = await Promise.all([ const [runner, { startTests, setupCommonEnv, SpyModule }] = await Promise.all(
initiateRunner(state, mocker, config), [
importId('vitest/browser') as Promise<typeof import('vitest/browser')>, initiateRunner(state, mocker, config),
]) importId('vitest/browser') as Promise<typeof import('vitest/browser')>,
],
)
mocker.setSpyModule(SpyModule) mocker.setSpyModule(SpyModule)
mocker.setupWorker() mocker.setupWorker()
@ -155,7 +180,10 @@ async function runTests(files: string[]) {
debug('client is connected to ws server') debug('client is connected to ws server')
let preparedData: Awaited<ReturnType<typeof prepareTestEnvironment>> | undefined | false let preparedData:
| Awaited<ReturnType<typeof prepareTestEnvironment>>
| undefined
| false
// if importing /@id/ failed, we reload the page waiting until Vite prebundles it // if importing /@id/ failed, we reload the page waiting until Vite prebundles it
try { try {
@ -191,8 +219,9 @@ async function runTests(files: string[]) {
try { try {
await setupCommonEnv(config) await setupCommonEnv(config)
for (const file of files) for (const file of files) {
await startTests([file], runner) await startTests([file], runner)
}
} }
finally { finally {
state.environmentTeardownRun = true state.environmentTeardownRun = true

View File

@ -31,22 +31,30 @@ async function defaultErrorReport(type: string, unhandledError: any) {
function catchWindowErrors(cb: (e: ErrorEvent) => void) { function catchWindowErrors(cb: (e: ErrorEvent) => void) {
let userErrorListenerCount = 0 let userErrorListenerCount = 0
function throwUnhandlerError(e: ErrorEvent) { function throwUnhandlerError(e: ErrorEvent) {
if (userErrorListenerCount === 0 && e.error != null) if (userErrorListenerCount === 0 && e.error != null) {
cb(e) cb(e)
else }
else {
console.error(e.error) console.error(e.error)
}
} }
const addEventListener = window.addEventListener.bind(window) const addEventListener = window.addEventListener.bind(window)
const removeEventListener = window.removeEventListener.bind(window) const removeEventListener = window.removeEventListener.bind(window)
window.addEventListener('error', throwUnhandlerError) window.addEventListener('error', throwUnhandlerError)
window.addEventListener = function (...args: Parameters<typeof addEventListener>) { window.addEventListener = function (
if (args[0] === 'error') ...args: Parameters<typeof addEventListener>
) {
if (args[0] === 'error') {
userErrorListenerCount++ userErrorListenerCount++
}
return addEventListener.apply(this, args) return addEventListener.apply(this, args)
} }
window.removeEventListener = function (...args: Parameters<typeof removeEventListener>) { window.removeEventListener = function (
if (args[0] === 'error' && userErrorListenerCount) ...args: Parameters<typeof removeEventListener>
) {
if (args[0] === 'error' && userErrorListenerCount) {
userErrorListenerCount-- userErrorListenerCount--
}
return removeEventListener.apply(this, args) return removeEventListener.apply(this, args)
} }
return function clearErrorHandlers() { return function clearErrorHandlers() {
@ -55,8 +63,11 @@ function catchWindowErrors(cb: (e: ErrorEvent) => void) {
} }
export function registerUnhandledErrors() { export function registerUnhandledErrors() {
const stopErrorHandler = catchWindowErrors(e => defaultErrorReport('Error', e.error)) const stopErrorHandler = catchWindowErrors(e =>
const stopRejectionHandler = on('unhandledrejection', e => defaultErrorReport('Unhandled Rejection', e.reason)) defaultErrorReport('Error', e.error),
)
const stopRejectionHandler = on('unhandledrejection', e =>
defaultErrorReport('Unhandled Rejection', e.reason))
return () => { return () => {
stopErrorHandler() stopErrorHandler()
stopRejectionHandler() stopRejectionHandler()
@ -64,12 +75,21 @@ export function registerUnhandledErrors() {
} }
export function registerUnexpectedErrors(rpc: typeof client.rpc) { export function registerUnexpectedErrors(rpc: typeof client.rpc) {
catchWindowErrors(event => reportUnexpectedError(rpc, 'Error', event.error)) catchWindowErrors(event =>
on('unhandledrejection', event => reportUnexpectedError(rpc, 'Unhandled Rejection', event.reason)) reportUnexpectedError(rpc, 'Error', event.error),
)
on('unhandledrejection', event =>
reportUnexpectedError(rpc, 'Unhandled Rejection', event.reason))
} }
async function reportUnexpectedError(rpc: typeof client.rpc, type: string, error: any) { async function reportUnexpectedError(
const { processError } = await importId('vitest/browser') as typeof import('vitest/browser') rpc: typeof client.rpc,
type: string,
error: any,
) {
const { processError } = (await importId(
'vitest/browser',
)) as typeof import('vitest/browser')
const processedError = processError(error) const processedError = processError(error)
await rpc.onUnhandledError(processedError, type) await rpc.onUnhandledError(processedError, type)
} }

View File

@ -27,27 +27,35 @@ export default defineConfig({
name: 'virtual:msw', name: 'virtual:msw',
enforce: 'pre', enforce: 'pre',
resolveId(id) { resolveId(id) {
if (id.startsWith('msw')) if (id.startsWith('msw')) {
return `/__virtual_vitest__:${id}` return `/__virtual_vitest__:${id}`
}
}, },
}, },
{ {
name: 'copy-ui-plugin', name: 'copy-ui-plugin',
/* eslint-disable no-console */ /* eslint-disable no-console */
closeBundle: async () => { closeBundle: async () => {
const root = resolve(fileURLToPath(import.meta.url), '../../../../../packages') const root = resolve(
fileURLToPath(import.meta.url),
'../../../../../packages',
)
const ui = resolve(root, 'ui/dist/client') const ui = resolve(root, 'ui/dist/client')
const browser = resolve(root, 'browser/dist/client/__vitest__/') const browser = resolve(root, 'browser/dist/client/__vitest__/')
const timeout = setTimeout(() => console.log('[copy-ui-plugin] Waiting for UI to be built...'), 1000) const timeout = setTimeout(
() => console.log('[copy-ui-plugin] Waiting for UI to be built...'),
1000,
)
await waitFor(() => fs.existsSync(ui)) await waitFor(() => fs.existsSync(ui))
clearTimeout(timeout) clearTimeout(timeout)
const files = fg.sync('**/*', { cwd: ui }) const files = fg.sync('**/*', { cwd: ui })
if (fs.existsSync(browser)) if (fs.existsSync(browser)) {
fs.rmSync(browser, { recursive: true }) fs.rmSync(browser, { recursive: true })
}
fs.mkdirSync(browser, { recursive: true }) fs.mkdirSync(browser, { recursive: true })
fs.mkdirSync(resolve(browser, 'assets')) fs.mkdirSync(resolve(browser, 'assets'))
@ -63,11 +71,13 @@ export default defineConfig({
}) })
async function waitFor(method: () => boolean, retries = 100): Promise<void> { async function waitFor(method: () => boolean, retries = 100): Promise<void> {
if (method()) if (method()) {
return return
}
if (retries === 0) if (retries === 0) {
throw new Error('Timeout in waitFor') throw new Error('Timeout in waitFor')
}
await new Promise(resolve => setTimeout(resolve, 500)) await new Promise(resolve => setTimeout(resolve, 500))

View File

@ -1,4 +1,14 @@
import type { Declaration, ExportDefaultDeclaration, ExportNamedDeclaration, Expression, Identifier, Literal, Pattern, Positioned, Program } from '@vitest/utils/ast' import type {
Declaration,
ExportDefaultDeclaration,
ExportNamedDeclaration,
Expression,
Identifier,
Literal,
Pattern,
Positioned,
Program,
} from '@vitest/utils/ast'
import MagicString from 'magic-string' import MagicString from 'magic-string'
// TODO: better source map replacement // TODO: better source map replacement
@ -28,21 +38,25 @@ export function automockModule(code: string, parse: (code: string) => Program) {
// export const [test, ...rest] = [1, 2, 3] // export const [test, ...rest] = [1, 2, 3]
else if (expression.type === 'ArrayPattern') { else if (expression.type === 'ArrayPattern') {
expression.elements.forEach((element) => { expression.elements.forEach((element) => {
if (!element) if (!element) {
return return
}
traversePattern(element) traversePattern(element)
}) })
} }
else if (expression.type === 'ObjectPattern') { else if (expression.type === 'ObjectPattern') {
expression.properties.forEach((property) => { expression.properties.forEach((property) => {
// export const { ...rest } = {} // export const { ...rest } = {}
if (property.type === 'RestElement') if (property.type === 'RestElement') {
traversePattern(property) traversePattern(property)
}
// export const { test, test2: alias } = {} // export const { test, test2: alias } = {}
else if (property.type === 'Property') else if (property.type === 'Property') {
traversePattern(property.value) traversePattern(property.value)
else }
else {
property satisfies never property satisfies never
}
}) })
} }
else if (expression.type === 'RestElement') { else if (expression.type === 'RestElement') {
@ -51,12 +65,16 @@ export function automockModule(code: string, parse: (code: string) => Program) {
// const [name[1], name[2]] = [] // const [name[1], name[2]] = []
// cannot be used in export // cannot be used in export
else if (expression.type === 'AssignmentPattern') { else if (expression.type === 'AssignmentPattern') {
throw new Error(`AssignmentPattern is not supported. Please open a new bug report.`) throw new Error(
`AssignmentPattern is not supported. Please open a new bug report.`,
)
} }
// const test = thing.func() // const test = thing.func()
// cannot be used in export // cannot be used in export
else if (expression.type === 'MemberExpression') { else if (expression.type === 'MemberExpression') {
throw new Error(`MemberExpression is not supported. Please open a new bug report.`) throw new Error(
`MemberExpression is not supported. Please open a new bug report.`,
)
} }
else { else {
expression satisfies never expression satisfies never
@ -89,9 +107,7 @@ export function automockModule(code: string, parse: (code: string) => Program) {
const exported = specifier.exported as Literal | Identifier const exported = specifier.exported as Literal | Identifier
allSpecifiers.push({ allSpecifiers.push({
alias: exported.type === 'Literal' alias: exported.type === 'Literal' ? exported.raw! : exported.name,
? exported.raw!
: exported.name,
name: specifier.local.name, name: specifier.local.name,
}) })
}) })
@ -106,13 +122,13 @@ export function automockModule(code: string, parse: (code: string) => Program) {
importNames.push([specifier.local.name, importedName]) importNames.push([specifier.local.name, importedName])
allSpecifiers.push({ allSpecifiers.push({
name: importedName, name: importedName,
alias: exported.type === 'Literal' alias: exported.type === 'Literal' ? exported.raw! : exported.name,
? exported.raw!
: exported.name,
}) })
}) })
const importString = `import { ${importNames.map(([name, alias]) => `${name} as ${alias}`).join(', ')} } from '${source.value}'` const importString = `import { ${importNames
.map(([name, alias]) => `${name} as ${alias}`)
.join(', ')} } from '${source.value}'`
m.overwrite(node.start, node.end, importString) m.overwrite(node.start, node.end, importString)
} }
@ -131,13 +147,17 @@ const __vitest_es_current_module__ = {
} }
const __vitest_mocked_module__ = __vitest_mocker__.mockObject(__vitest_es_current_module__) const __vitest_mocked_module__ = __vitest_mocker__.mockObject(__vitest_es_current_module__)
` `
const assigning = allSpecifiers.map(({ name }, index) => { const assigning = allSpecifiers
return `const __vitest_mocked_${index}__ = __vitest_mocked_module__["${name}"]` .map(({ name }, index) => {
}).join('\n') return `const __vitest_mocked_${index}__ = __vitest_mocked_module__["${name}"]`
})
.join('\n')
const redeclarations = allSpecifiers.map(({ name, alias }, index) => { const redeclarations = allSpecifiers
return ` __vitest_mocked_${index}__ as ${alias || name},` .map(({ name, alias }, index) => {
}).join('\n') return ` __vitest_mocked_${index}__ as ${alias || name},`
})
.join('\n')
const specifiersExports = ` const specifiersExports = `
export { export {
${redeclarations} ${redeclarations}

View File

@ -5,29 +5,43 @@ import type { BrowserCommand, WorkspaceProject } from 'vitest/node'
import type { BrowserCommands } from '../../../context' import type { BrowserCommands } from '../../../context'
function assertFileAccess(path: string, project: WorkspaceProject) { function assertFileAccess(path: string, project: WorkspaceProject) {
if (!isFileServingAllowed(path, project.server) && !isFileServingAllowed(path, project.ctx.server)) if (
throw new Error(`Access denied to "${path}". See Vite config documentation for "server.fs": https://vitejs.dev/config/server-options.html#server-fs-strict.`) !isFileServingAllowed(path, project.server)
&& !isFileServingAllowed(path, project.ctx.server)
) {
throw new Error(
`Access denied to "${path}". See Vite config documentation for "server.fs": https://vitejs.dev/config/server-options.html#server-fs-strict.`,
)
}
} }
export const readFile: BrowserCommand<Parameters<BrowserCommands['readFile']>> = async ({ project, testPath = process.cwd() }, path, options = {}) => { export const readFile: BrowserCommand<
Parameters<BrowserCommands['readFile']>
> = async ({ project, testPath = process.cwd() }, path, options = {}) => {
const filepath = resolve(dirname(testPath), path) const filepath = resolve(dirname(testPath), path)
assertFileAccess(filepath, project) assertFileAccess(filepath, project)
// never return a Buffer // never return a Buffer
if (typeof options === 'object' && !options.encoding) if (typeof options === 'object' && !options.encoding) {
options.encoding = 'utf-8' options.encoding = 'utf-8'
}
return fsp.readFile(filepath, options) return fsp.readFile(filepath, options)
} }
export const writeFile: BrowserCommand<Parameters<BrowserCommands['writeFile']>> = async ({ project, testPath = process.cwd() }, path, data, options) => { export const writeFile: BrowserCommand<
Parameters<BrowserCommands['writeFile']>
> = async ({ project, testPath = process.cwd() }, path, data, options) => {
const filepath = resolve(dirname(testPath), path) const filepath = resolve(dirname(testPath), path)
assertFileAccess(filepath, project) assertFileAccess(filepath, project)
const dir = dirname(filepath) const dir = dirname(filepath)
if (!fs.existsSync(dir)) if (!fs.existsSync(dir)) {
await fsp.mkdir(dir, { recursive: true }) await fsp.mkdir(dir, { recursive: true })
}
await fsp.writeFile(filepath, data, options) await fsp.writeFile(filepath, data, options)
} }
export const removeFile: BrowserCommand<Parameters<BrowserCommands['removeFile']>> = async ({ project, testPath = process.cwd() }, path) => { export const removeFile: BrowserCommand<
Parameters<BrowserCommands['removeFile']>
> = async ({ project, testPath = process.cwd() }, path) => {
const filepath = resolve(dirname(testPath), path) const filepath = resolve(dirname(testPath), path)
assertFileAccess(filepath, project) assertFileAccess(filepath, project)
await fsp.rm(filepath) await fsp.rm(filepath)

View File

@ -1,9 +1,5 @@
import { click } from './click' import { click } from './click'
import { import { readFile, removeFile, writeFile } from './fs'
readFile,
removeFile,
writeFile,
} from './fs'
import { sendKeys } from './keyboard' import { sendKeys } from './keyboard'
import { screenshot } from './screenshot' import { screenshot } from './screenshot'

View File

@ -19,13 +19,16 @@ function isObject(payload: unknown): payload is Record<string, unknown> {
function isSendKeysPayload(payload: unknown): boolean { function isSendKeysPayload(payload: unknown): boolean {
const validOptions = ['type', 'press', 'down', 'up'] const validOptions = ['type', 'press', 'down', 'up']
if (!isObject(payload)) if (!isObject(payload)) {
throw new Error('You must provide a `SendKeysPayload` object') throw new Error('You must provide a `SendKeysPayload` object')
}
const numberOfValidOptions = Object.keys(payload).filter(key => const numberOfValidOptions = Object.keys(payload).filter(key =>
validOptions.includes(key), validOptions.includes(key),
).length ).length
const unknownOptions = Object.keys(payload).filter(key => !validOptions.includes(key)) const unknownOptions = Object.keys(payload).filter(
key => !validOptions.includes(key),
)
if (numberOfValidOptions > 1) { if (numberOfValidOptions > 1) {
throw new Error( throw new Error(
@ -41,8 +44,11 @@ function isSendKeysPayload(payload: unknown): boolean {
)}.`, )}.`,
) )
} }
if (unknownOptions.length > 0) if (unknownOptions.length > 0) {
throw new Error(`Unknown options \`${unknownOptions.join(', ')}\` present.`) throw new Error(
`Unknown options \`${unknownOptions.join(', ')}\` present.`,
)
}
return true return true
} }
@ -63,31 +69,43 @@ function isUpPayload(payload: SendKeysPayload): payload is UpPayload {
return 'up' in payload return 'up' in payload
} }
export const sendKeys: BrowserCommand<Parameters<BrowserCommands['sendKeys']>> = async ({ provider, contextId }, payload) => { export const sendKeys: BrowserCommand<
if (!isSendKeysPayload(payload) || !payload) Parameters<BrowserCommands['sendKeys']>
> = async ({ provider, contextId }, payload) => {
if (!isSendKeysPayload(payload) || !payload) {
throw new Error('You must provide a `SendKeysPayload` object') throw new Error('You must provide a `SendKeysPayload` object')
}
if (provider instanceof PlaywrightBrowserProvider) { if (provider instanceof PlaywrightBrowserProvider) {
const page = provider.getPage(contextId) const page = provider.getPage(contextId)
if (isTypePayload(payload)) if (isTypePayload(payload)) {
await page.keyboard.type(payload.type) await page.keyboard.type(payload.type)
else if (isPressPayload(payload)) }
else if (isPressPayload(payload)) {
await page.keyboard.press(payload.press) await page.keyboard.press(payload.press)
else if (isDownPayload(payload)) }
else if (isDownPayload(payload)) {
await page.keyboard.down(payload.down) await page.keyboard.down(payload.down)
else if (isUpPayload(payload)) }
else if (isUpPayload(payload)) {
await page.keyboard.up(payload.up) await page.keyboard.up(payload.up)
}
} }
else if (provider instanceof WebdriverBrowserProvider) { else if (provider instanceof WebdriverBrowserProvider) {
const browser = provider.browser! const browser = provider.browser!
if (isTypePayload(payload)) if (isTypePayload(payload)) {
await browser.keys(payload.type.split('')) await browser.keys(payload.type.split(''))
else if (isPressPayload(payload)) }
else if (isPressPayload(payload)) {
await browser.keys([payload.press]) await browser.keys([payload.press])
else }
else {
throw new Error('Only "press" and "type" are supported by webdriverio.') throw new Error('Only "press" and "type" are supported by webdriverio.')
}
} }
else { else {
throw new TypeError(`"sendKeys" is not supported for ${provider.name} browser provider.`) throw new TypeError(
`"sendKeys" is not supported for ${provider.name} browser provider.`,
)
} }
} }

View File

@ -8,11 +8,20 @@ import { PlaywrightBrowserProvider } from '../providers/playwright'
import { WebdriverBrowserProvider } from '../providers/webdriver' import { WebdriverBrowserProvider } from '../providers/webdriver'
// TODO: expose provider specific options in types // TODO: expose provider specific options in types
export const screenshot: BrowserCommand<[string, ScreenshotOptions]> = async (context, name: string, options = {}) => { export const screenshot: BrowserCommand<[string, ScreenshotOptions]> = async (
if (!context.testPath) context,
name: string,
options = {},
) => {
if (!context.testPath) {
throw new Error(`Cannot take a screenshot without a test path`) throw new Error(`Cannot take a screenshot without a test path`)
}
const path = resolveScreenshotPath(context.testPath, name, context.project.config) const path = resolveScreenshotPath(
context.testPath,
name,
context.project.config,
)
const savePath = normalize(path) const savePath = normalize(path)
await mkdir(dirname(path), { recursive: true }) await mkdir(dirname(path), { recursive: true })
@ -42,10 +51,16 @@ export const screenshot: BrowserCommand<[string, ScreenshotOptions]> = async (co
return path return path
} }
throw new Error(`Provider "${context.provider.name}" does not support screenshots`) throw new Error(
`Provider "${context.provider.name}" does not support screenshots`,
)
} }
function resolveScreenshotPath(testPath: string, name: string, config: ResolvedConfig) { function resolveScreenshotPath(
testPath: string,
name: string,
config: ResolvedConfig,
) {
const dir = dirname(testPath) const dir = dirname(testPath)
const base = basename(testPath) const base = basename(testPath)
if (config.browser.screenshotDirectory) { if (config.browser.screenshotDirectory) {

View File

@ -6,7 +6,7 @@ export type UserEventCommand<T extends (...args: any) => any> = BrowserCommand<
type ConvertElementToLocator<T> = T extends Element ? string : T type ConvertElementToLocator<T> = T extends Element ? string : T
type ConvertUserEventParameters<T extends unknown[]> = { type ConvertUserEventParameters<T extends unknown[]> = {
[K in keyof T]: ConvertElementToLocator<T[K]> [K in keyof T]: ConvertElementToLocator<T[K]>;
} }
export function defineBrowserCommand<T extends unknown[]>( export function defineBrowserCommand<T extends unknown[]>(

View File

@ -3,7 +3,11 @@ import type { PluginContext } from 'rollup'
import { esmWalker } from '@vitest/utils/ast' import { esmWalker } from '@vitest/utils/ast'
import type { Expression, Positioned } from '@vitest/utils/ast' import type { Expression, Positioned } from '@vitest/utils/ast'
export function injectDynamicImport(code: string, id: string, parse: PluginContext['parse']) { export function injectDynamicImport(
code: string,
id: string,
parse: PluginContext['parse'],
) {
const s = new MagicString(code) const s = new MagicString(code)
let ast: any let ast: any
@ -23,7 +27,11 @@ export function injectDynamicImport(code: string, id: string, parse: PluginConte
}, },
onDynamicImport(node) { onDynamicImport(node) {
const replace = '__vitest_browser_runner__.wrapModule(() => import(' const replace = '__vitest_browser_runner__.wrapModule(() => import('
s.overwrite(node.start, (node.source as Positioned<Expression>).start, replace) s.overwrite(
node.start,
(node.source as Positioned<Expression>).start,
replace,
)
s.overwrite(node.end - 1, node.end, '))') s.overwrite(node.end - 1, node.end, '))')
}, },
}) })

View File

@ -31,39 +31,54 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
} }
}, },
async configureServer(server) { async configureServer(server) {
const testerHtml = readFile(resolve(distRoot, 'client/tester.html'), 'utf8') const testerHtml = readFile(
resolve(distRoot, 'client/tester.html'),
'utf8',
)
const orchestratorHtml = project.config.browser.ui const orchestratorHtml = project.config.browser.ui
? readFile(resolve(distRoot, 'client/__vitest__/index.html'), 'utf8') ? readFile(resolve(distRoot, 'client/__vitest__/index.html'), 'utf8')
: readFile(resolve(distRoot, 'client/orchestrator.html'), 'utf8') : readFile(resolve(distRoot, 'client/orchestrator.html'), 'utf8')
const injectorJs = readFile(resolve(distRoot, 'client/esm-client-injector.js'), 'utf8') const injectorJs = readFile(
resolve(distRoot, 'client/esm-client-injector.js'),
'utf8',
)
const manifest = (async () => { const manifest = (async () => {
return JSON.parse(await readFile(`${distRoot}/client/.vite/manifest.json`, 'utf8')) return JSON.parse(
await readFile(`${distRoot}/client/.vite/manifest.json`, 'utf8'),
)
})() })()
const favicon = `${base}favicon.svg` const favicon = `${base}favicon.svg`
const testerPrefix = `${base}__vitest_test__/__test__/` const testerPrefix = `${base}__vitest_test__/__test__/`
server.middlewares.use((_req, res, next) => { server.middlewares.use((_req, res, next) => {
const headers = server.config.server.headers const headers = server.config.server.headers
if (headers) { if (headers) {
for (const name in headers) for (const name in headers) {
res.setHeader(name, headers[name]!) res.setHeader(name, headers[name]!)
}
} }
next() next()
}) })
let orchestratorScripts: string | undefined let orchestratorScripts: string | undefined
let testerScripts: string | undefined let testerScripts: string | undefined
server.middlewares.use(async (req, res, next) => { server.middlewares.use(async (req, res, next) => {
if (!req.url) if (!req.url) {
return next() return next()
}
const url = new URL(req.url, 'http://localhost') const url = new URL(req.url, 'http://localhost')
if (!url.pathname.startsWith(testerPrefix) && url.pathname !== base) if (!url.pathname.startsWith(testerPrefix) && url.pathname !== base) {
return next() return next()
}
res.setHeader('Cache-Control', 'no-cache, max-age=0, must-revalidate') res.setHeader(
'Cache-Control',
'no-cache, max-age=0, must-revalidate',
)
res.setHeader('Content-Type', 'text/html; charset=utf-8') res.setHeader('Content-Type', 'text/html; charset=utf-8')
const config = wrapConfig(project.getSerializableConfig()) const config = wrapConfig(project.getSerializableConfig())
config.env ??= {} config.env ??= {}
config.env.VITEST_BROWSER_DEBUG = process.env.VITEST_BROWSER_DEBUG || '' config.env.VITEST_BROWSER_DEBUG
= process.env.VITEST_BROWSER_DEBUG || ''
// remove custom iframe related headers to allow the iframe to load // remove custom iframe related headers to allow the iframe to load
res.removeHeader('X-Frame-Options') res.removeHeader('X-Frame-Options')
@ -72,8 +87,9 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
let contextId = url.searchParams.get('contextId') let contextId = url.searchParams.get('contextId')
// it's possible to open the page without a context, // it's possible to open the page without a context,
// for now, let's assume it should be the first one // for now, let's assume it should be the first one
if (!contextId) if (!contextId) {
contextId = project.browserState.keys().next().value ?? 'none' contextId = project.browserState.keys().next().value ?? 'none'
}
const files = project.browserState.get(contextId!)?.files ?? [] const files = project.browserState.get(contextId!)?.files ?? []
@ -83,15 +99,20 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
root: project.browser!.config.root, root: project.browser!.config.root,
}), }),
__VITEST_FILES__: JSON.stringify(files), __VITEST_FILES__: JSON.stringify(files),
__VITEST_TYPE__: url.pathname === base ? '"orchestrator"' : '"tester"', __VITEST_TYPE__:
url.pathname === base ? '"orchestrator"' : '"tester"',
__VITEST_CONTEXT_ID__: JSON.stringify(contextId), __VITEST_CONTEXT_ID__: JSON.stringify(contextId),
}) })
// disable CSP for the orchestrator as we are the ones controlling it // disable CSP for the orchestrator as we are the ones controlling it
res.removeHeader('Content-Security-Policy') res.removeHeader('Content-Security-Policy')
if (!orchestratorScripts) if (!orchestratorScripts) {
orchestratorScripts = await formatScripts(project.config.browser.orchestratorScripts, server) orchestratorScripts = await formatScripts(
project.config.browser.orchestratorScripts,
server,
)
}
let baseHtml = await orchestratorHtml let baseHtml = await orchestratorHtml
@ -99,14 +120,16 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
if (project.config.browser.ui) { if (project.config.browser.ui) {
const manifestContent = await manifest const manifestContent = await manifest
const jsEntry = manifestContent['orchestrator.html'].file const jsEntry = manifestContent['orchestrator.html'].file
baseHtml = baseHtml.replaceAll('./assets/', `${base}__vitest__/assets/`).replace( baseHtml = baseHtml
'<!-- !LOAD_METADATA! -->', .replaceAll('./assets/', `${base}__vitest__/assets/`)
[ .replace(
'<script>{__VITEST_INJECTOR__}</script>', '<!-- !LOAD_METADATA! -->',
'{__VITEST_SCRIPTS__}', [
`<script type="module" crossorigin src="${jsEntry}"></script>`, '<script>{__VITEST_INJECTOR__}</script>',
].join('\n'), '{__VITEST_SCRIPTS__}',
) `<script type="module" crossorigin src="${jsEntry}"></script>`,
].join('\n'),
)
} }
const html = replacer(baseHtml, { const html = replacer(baseHtml, {
@ -125,14 +148,23 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
if (typeof csp === 'string') { if (typeof csp === 'string') {
// add frame-ancestors to allow the iframe to be loaded by Vitest, // add frame-ancestors to allow the iframe to be loaded by Vitest,
// but keep the rest of the CSP // but keep the rest of the CSP
res.setHeader('Content-Security-Policy', csp.replace(/frame-ancestors [^;]+/, 'frame-ancestors *')) res.setHeader(
'Content-Security-Policy',
csp.replace(/frame-ancestors [^;]+/, 'frame-ancestors *'),
)
} }
const [contextId, testFile] = url.pathname.slice(testerPrefix.length).split('/') const [contextId, testFile] = url.pathname
.slice(testerPrefix.length)
.split('/')
const decodedTestFile = decodeURIComponent(testFile) const decodedTestFile = decodeURIComponent(testFile)
const testFiles = await project.globTestFiles() const testFiles = await project.globTestFiles()
// if decoded test file is "__vitest_all__" or not in the list of known files, run all tests // if decoded test file is "__vitest_all__" or not in the list of known files, run all tests
const tests = decodedTestFile === '__vitest_all__' || !testFiles.includes(decodedTestFile) ? '__vitest_browser_runner__.files' : JSON.stringify([decodedTestFile]) const tests
= decodedTestFile === '__vitest_all__'
|| !testFiles.includes(decodedTestFile)
? '__vitest_browser_runner__.files'
: JSON.stringify([decodedTestFile])
const iframeId = JSON.stringify(decodedTestFile) const iframeId = JSON.stringify(decodedTestFile)
const files = project.browserState.get(contextId)?.files ?? [] const files = project.browserState.get(contextId)?.files ?? []
@ -142,12 +174,17 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
__VITEST_VITE_CONFIG__: JSON.stringify({ __VITEST_VITE_CONFIG__: JSON.stringify({
root: project.browser!.config.root, root: project.browser!.config.root,
}), }),
__VITEST_TYPE__: url.pathname === base ? '"orchestrator"' : '"tester"', __VITEST_TYPE__:
url.pathname === base ? '"orchestrator"' : '"tester"',
__VITEST_CONTEXT_ID__: JSON.stringify(contextId), __VITEST_CONTEXT_ID__: JSON.stringify(contextId),
}) })
if (!testerScripts) if (!testerScripts) {
testerScripts = await formatScripts(project.config.browser.testerScripts, server) testerScripts = await formatScripts(
project.config.browser.testerScripts,
server,
)
}
const html = replacer(await testerHtml, { const html = replacer(await testerHtml, {
__VITEST_FAVICON__: favicon, __VITEST_FAVICON__: favicon,
@ -155,8 +192,8 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
__VITEST_SCRIPTS__: testerScripts, __VITEST_SCRIPTS__: testerScripts,
__VITEST_INJECTOR__: injector, __VITEST_INJECTOR__: injector,
__VITEST_APPEND__: __VITEST_APPEND__:
// TODO: have only a single global variable to not pollute the global scope // TODO: have only a single global variable to not pollute the global scope
`<script type="module"> `<script type="module">
__vitest_browser_runner__.runningFiles = ${tests} __vitest_browser_runner__.runningFiles = ${tests}
__vitest_browser_runner__.iframeId = ${iframeId} __vitest_browser_runner__.iframeId = ${iframeId}
__vitest_browser_runner__.runTests(__vitest_browser_runner__.runningFiles) __vitest_browser_runner__.runTests(__vitest_browser_runner__.runningFiles)
@ -175,16 +212,26 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
const coverageFolder = resolveCoverageFolder(project) const coverageFolder = resolveCoverageFolder(project)
const coveragePath = coverageFolder ? coverageFolder[1] : undefined const coveragePath = coverageFolder ? coverageFolder[1] : undefined
if (coveragePath && base === coveragePath) if (coveragePath && base === coveragePath) {
throw new Error(`The ui base path and the coverage path cannot be the same: ${base}, change coverage.reportsDirectory`) throw new Error(
`The ui base path and the coverage path cannot be the same: ${base}, change coverage.reportsDirectory`,
)
}
coverageFolder && server.middlewares.use(coveragePath!, sirv(coverageFolder[0], { coverageFolder
single: true, && server.middlewares.use(
dev: true, coveragePath!,
setHeaders: (res) => { sirv(coverageFolder[0], {
res.setHeader('Cache-Control', 'public,max-age=0,must-revalidate') single: true,
}, dev: true,
})) setHeaders: (res) => {
res.setHeader(
'Cache-Control',
'public,max-age=0,must-revalidate',
)
},
}),
)
}, },
}, },
{ {
@ -192,7 +239,9 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
enforce: 'pre', enforce: 'pre',
async config() { async config() {
const allTestFiles = await project.globTestFiles() const allTestFiles = await project.globTestFiles()
const browserTestFiles = allTestFiles.filter(file => getFilePoolName(project, file) === 'browser') const browserTestFiles = allTestFiles.filter(
file => getFilePoolName(project, file) === 'browser',
)
const setupFiles = toArray(project.config.setupFiles) const setupFiles = toArray(project.config.setupFiles)
const vitestPaths = [ const vitestPaths = [
resolve(vitestDist, 'index.js'), resolve(vitestDist, 'index.js'),
@ -202,11 +251,7 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
] ]
return { return {
optimizeDeps: { optimizeDeps: {
entries: [ entries: [...browserTestFiles, ...setupFiles, ...vitestPaths],
...browserTestFiles,
...setupFiles,
...vitestPaths,
],
exclude: [ exclude: [
'vitest', 'vitest',
'vitest/utils', 'vitest/utils',
@ -243,15 +288,18 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
} }
}, },
async resolveId(id) { async resolveId(id) {
if (!/\?browserv=\w+$/.test(id)) if (!/\?browserv=\w+$/.test(id)) {
return return
}
let useId = id.slice(0, id.lastIndexOf('?')) let useId = id.slice(0, id.lastIndexOf('?'))
if (useId.startsWith('/@fs/')) if (useId.startsWith('/@fs/')) {
useId = useId.slice(5) useId = useId.slice(5)
}
if (/^\w:/.test(useId)) if (/^\w:/.test(useId)) {
useId = useId.replace(/\\/g, '/') useId = useId.replace(/\\/g, '/')
}
return useId return useId
}, },
@ -262,16 +310,13 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
if (rawId.startsWith('/__virtual_vitest__:')) { if (rawId.startsWith('/__virtual_vitest__:')) {
let id = rawId.slice('/__virtual_vitest__:'.length) let id = rawId.slice('/__virtual_vitest__:'.length)
// TODO: don't hardcode // TODO: don't hardcode
if (id === 'mocker-worker.js') if (id === 'mocker-worker.js') {
id = 'msw/mockServiceWorker.js' id = 'msw/mockServiceWorker.js'
}
const resolved = await this.resolve( const resolved = await this.resolve(id, distRoot, {
id, skipSelf: true,
distRoot, })
{
skipSelf: true,
},
)
return resolved return resolved
} }
}, },
@ -292,7 +337,9 @@ export default (project: WorkspaceProject, base = '/'): Plugin[] => {
const _require = createRequire(import.meta.url) const _require = createRequire(import.meta.url)
build.onResolve({ filter: /@vue\/test-utils/ }, (args) => { build.onResolve({ filter: /@vue\/test-utils/ }, (args) => {
// resolve to CJS instead of the browser because the browser version expects a global Vue object // resolve to CJS instead of the browser because the browser version expects a global Vue object
const resolved = _require.resolve(args.path, { paths: [args.importer] }) const resolved = _require.resolve(args.path, {
paths: [args.importer],
})
return { path: resolved } return { path: resolved }
}) })
}, },
@ -310,15 +357,17 @@ function resolveCoverageFolder(project: WorkspaceProject) {
const options = project.ctx.config const options = project.ctx.config
const htmlReporter = options.coverage?.enabled const htmlReporter = options.coverage?.enabled
? options.coverage.reporter.find((reporter) => { ? options.coverage.reporter.find((reporter) => {
if (typeof reporter === 'string') if (typeof reporter === 'string') {
return reporter === 'html' return reporter === 'html'
}
return reporter[0] === 'html' return reporter[0] === 'html'
}) })
: undefined : undefined
if (!htmlReporter) if (!htmlReporter) {
return undefined return undefined
}
// reportsDirectory not resolved yet // reportsDirectory not resolved yet
const root = resolve( const root = resolve(
@ -326,12 +375,16 @@ function resolveCoverageFolder(project: WorkspaceProject) {
options.coverage.reportsDirectory || coverageConfigDefaults.reportsDirectory, options.coverage.reportsDirectory || coverageConfigDefaults.reportsDirectory,
) )
const subdir = (Array.isArray(htmlReporter) && htmlReporter.length > 1 && 'subdir' in htmlReporter[1]) const subdir
? htmlReporter[1].subdir = Array.isArray(htmlReporter)
: undefined && htmlReporter.length > 1
&& 'subdir' in htmlReporter[1]
? htmlReporter[1].subdir
: undefined
if (!subdir || typeof subdir !== 'string') if (!subdir || typeof subdir !== 'string') {
return [root, `/${basename(root)}/`] return [root, `/${basename(root)}/`]
}
return [resolve(root, subdir), `/${basename(root)}/${subdir}/`] return [resolve(root, subdir), `/${basename(root)}/${subdir}/`]
} }
@ -340,10 +393,9 @@ function wrapConfig(config: ResolvedConfig): ResolvedConfig {
return { return {
...config, ...config,
// workaround RegExp serialization // workaround RegExp serialization
testNamePattern: testNamePattern: config.testNamePattern
config.testNamePattern ? (config.testNamePattern.toString() as any as RegExp)
? config.testNamePattern.toString() as any as RegExp : undefined,
: undefined,
} }
} }
@ -351,17 +403,30 @@ function replacer(code: string, values: Record<string, string>) {
return code.replace(/\{\s*(\w+)\s*\}/g, (_, key) => values[key] ?? '') return code.replace(/\{\s*(\w+)\s*\}/g, (_, key) => values[key] ?? '')
} }
async function formatScripts(scripts: BrowserScript[] | undefined, server: ViteDevServer) { async function formatScripts(
if (!scripts?.length) scripts: BrowserScript[] | undefined,
server: ViteDevServer,
) {
if (!scripts?.length) {
return '' return ''
const promises = scripts.map(async ({ content, src, async, id, type = 'module' }, index) => { }
const srcLink = (src ? (await server.pluginContainer.resolveId(src))?.id : undefined) || src const promises = scripts.map(
const transformId = srcLink || join(server.config.root, `virtual__${id || `injected-${index}.js`}`) async ({ content, src, async, id, type = 'module' }, index) => {
await server.moduleGraph.ensureEntryFromUrl(transformId) const srcLink
const contentProcessed = content && type === 'module' = (src ? (await server.pluginContainer.resolveId(src))?.id : undefined)
? (await server.pluginContainer.transform(content, transformId)).code || src
: content const transformId
return `<script type="${type}"${async ? ' async' : ''}${srcLink ? ` src="${slash(`/@fs/${srcLink}`)}"` : ''}>${contentProcessed || ''}</script>` = srcLink
}) || join(server.config.root, `virtual__${id || `injected-${index}.js`}`)
await server.moduleGraph.ensureEntryFromUrl(transformId)
const contentProcessed
= content && type === 'module'
? (await server.pluginContainer.transform(content, transformId)).code
: content
return `<script type="${type}"${async ? ' async' : ''}${
srcLink ? ` src="${slash(`/@fs/${srcLink}`)}"` : ''
}>${contentProcessed || ''}</script>`
},
)
return (await Promise.all(promises)).join('\n') return (await Promise.all(promises)).join('\n')
} }

View File

@ -13,39 +13,55 @@ const __dirname = dirname(fileURLToPath(import.meta.url))
export default function BrowserContext(project: WorkspaceProject): Plugin { export default function BrowserContext(project: WorkspaceProject): Plugin {
project.config.browser.commands ??= {} project.config.browser.commands ??= {}
for (const [name, command] of Object.entries(builtinCommands)) for (const [name, command] of Object.entries(builtinCommands)) {
project.config.browser.commands[name] ??= command project.config.browser.commands[name] ??= command
}
// validate names because they can't be used as identifiers // validate names because they can't be used as identifiers
for (const command in project.config.browser.commands) { for (const command in project.config.browser.commands) {
if (!/^[a-z_$][\w$]*$/i.test(command)) if (!/^[a-z_$][\w$]*$/i.test(command)) {
throw new Error(`Invalid command name "${command}". Only alphanumeric characters, $ and _ are allowed.`) throw new Error(
`Invalid command name "${command}". Only alphanumeric characters, $ and _ are allowed.`,
)
}
} }
return { return {
name: 'vitest:browser:virtual-module:context', name: 'vitest:browser:virtual-module:context',
enforce: 'pre', enforce: 'pre',
resolveId(id) { resolveId(id) {
if (id === ID_CONTEXT) if (id === ID_CONTEXT) {
return VIRTUAL_ID_CONTEXT return VIRTUAL_ID_CONTEXT
}
}, },
load(id) { load(id) {
if (id === VIRTUAL_ID_CONTEXT) if (id === VIRTUAL_ID_CONTEXT) {
return generateContextFile.call(this, project) return generateContextFile.call(this, project)
}
}, },
} }
} }
async function generateContextFile(this: PluginContext, project: WorkspaceProject) { async function generateContextFile(
this: PluginContext,
project: WorkspaceProject,
) {
const commands = Object.keys(project.config.browser.commands ?? {}) const commands = Object.keys(project.config.browser.commands ?? {})
const filepathCode = '__vitest_worker__.filepath || __vitest_worker__.current?.file?.filepath || undefined' const filepathCode
= '__vitest_worker__.filepath || __vitest_worker__.current?.file?.filepath || undefined'
const provider = project.browserProvider! const provider = project.browserProvider!
const commandsCode = commands.filter(command => !command.startsWith('__vitest')).map((command) => { const commandsCode = commands
return ` ["${command}"]: (...args) => rpc().triggerCommand(contextId, "${command}", filepath(), args),` .filter(command => !command.startsWith('__vitest'))
}).join('\n') .map((command) => {
return ` ["${command}"]: (...args) => rpc().triggerCommand(contextId, "${command}", filepath(), args),`
})
.join('\n')
const userEventNonProviderImport = await getUserEventImport(provider, this.resolve.bind(this)) const userEventNonProviderImport = await getUserEventImport(
provider,
this.resolve.bind(this),
)
const distContextPath = slash(`/@fs/${resolve(__dirname, 'context.js')}`) const distContextPath = slash(`/@fs/${resolve(__dirname, 'context.js')}`)
return ` return `
@ -65,16 +81,25 @@ export const server = {
} }
} }
export const commands = server.commands export const commands = server.commands
export const userEvent = ${provider.name === 'preview' ? '__vitest_user_event__' : '__userEvent_CDP__'} export const userEvent = ${
provider.name === 'preview' ? '__vitest_user_event__' : '__userEvent_CDP__'
}
export { page } export { page }
` `
} }
async function getUserEventImport(provider: BrowserProvider, resolve: (id: string, importer: string) => Promise<null | { id: string }>) { async function getUserEventImport(
if (provider.name !== 'preview') provider: BrowserProvider,
resolve: (id: string, importer: string) => Promise<null | { id: string }>,
) {
if (provider.name !== 'preview') {
return '' return ''
}
const resolved = await resolve('@testing-library/user-event', __dirname) const resolved = await resolve('@testing-library/user-event', __dirname)
if (!resolved) if (!resolved) {
throw new Error(`Failed to resolve user-event package from ${__dirname}`) throw new Error(`Failed to resolve user-event package from ${__dirname}`)
return `import { userEvent as __vitest_user_event__ } from '${slash(`/@fs/${resolved.id}`)}'` }
return `import { userEvent as __vitest_user_event__ } from '${slash(
`/@fs/${resolved.id}`,
)}'`
} }

View File

@ -9,8 +9,9 @@ export default (): Plugin => {
enforce: 'post', enforce: 'post',
transform(source, id) { transform(source, id) {
// TODO: test is not called for static imports // TODO: test is not called for static imports
if (!regexDynamicImport.test(source)) if (!regexDynamicImport.test(source)) {
return return
}
return injectDynamicImport(source, id, this.parse) return injectDynamicImport(source, id, this.parse)
}, },
} }

View File

@ -1,10 +1,21 @@
import type { Browser, BrowserContext, BrowserContextOptions, LaunchOptions, Page } from 'playwright' import type {
import type { BrowserProvider, BrowserProviderInitializationOptions, WorkspaceProject } from 'vitest/node' Browser,
BrowserContext,
BrowserContextOptions,
LaunchOptions,
Page,
} from 'playwright'
import type {
BrowserProvider,
BrowserProviderInitializationOptions,
WorkspaceProject,
} from 'vitest/node'
export const playwrightBrowsers = ['firefox', 'webkit', 'chromium'] as const export const playwrightBrowsers = ['firefox', 'webkit', 'chromium'] as const
export type PlaywrightBrowser = typeof playwrightBrowsers[number] export type PlaywrightBrowser = (typeof playwrightBrowsers)[number]
export interface PlaywrightProviderOptions extends BrowserProviderInitializationOptions { export interface PlaywrightProviderOptions
extends BrowserProviderInitializationOptions {
browser: PlaywrightBrowser browser: PlaywrightBrowser
} }
@ -31,18 +42,23 @@ export class PlaywrightBrowserProvider implements BrowserProvider {
return playwrightBrowsers return playwrightBrowsers
} }
initialize(project: WorkspaceProject, { browser, options }: PlaywrightProviderOptions) { initialize(
project: WorkspaceProject,
{ browser, options }: PlaywrightProviderOptions,
) {
this.ctx = project this.ctx = project
this.browserName = browser this.browserName = browser
this.options = options as any this.options = options as any
} }
private async openBrowser() { private async openBrowser() {
if (this.browserPromise) if (this.browserPromise) {
return this.browserPromise return this.browserPromise
}
if (this.browser) if (this.browser) {
return this.browser return this.browser
}
this.browserPromise = (async () => { this.browserPromise = (async () => {
const options = this.ctx.config.browser const options = this.ctx.config.browser
@ -62,8 +78,9 @@ export class PlaywrightBrowserProvider implements BrowserProvider {
} }
private async createContext(contextId: string) { private async createContext(contextId: string) {
if (this.contexts.has(contextId)) if (this.contexts.has(contextId)) {
return this.contexts.get(contextId)! return this.contexts.get(contextId)!
}
const browser = await this.openBrowser() const browser = await this.openBrowser()
const context = await browser.newContext({ const context = await browser.newContext({
@ -77,8 +94,9 @@ export class PlaywrightBrowserProvider implements BrowserProvider {
public getPage(contextId: string) { public getPage(contextId: string) {
const page = this.pages.get(contextId) const page = this.pages.get(contextId)
if (!page) if (!page) {
throw new Error(`Page "${contextId}" not found`) throw new Error(`Page "${contextId}" not found`)
}
return page return page
} }

View File

@ -22,14 +22,18 @@ export class PreviewBrowserProvider implements BrowserProvider {
async initialize(ctx: WorkspaceProject) { async initialize(ctx: WorkspaceProject) {
this.ctx = ctx this.ctx = ctx
this.open = false this.open = false
if (ctx.config.browser.headless) if (ctx.config.browser.headless) {
throw new Error('You\'ve enabled headless mode for "preview" provider but it doesn\'t support it. Use "playwright" or "webdriverio" instead: https://vitest.dev/guide/browser#configuration') throw new Error(
'You\'ve enabled headless mode for "preview" provider but it doesn\'t support it. Use "playwright" or "webdriverio" instead: https://vitest.dev/guide/browser#configuration',
)
}
} }
async openPage(_contextId: string, url: string) { async openPage(_contextId: string, url: string) {
this.open = true this.open = true
if (!this.ctx.browser) if (!this.ctx.browser) {
throw new Error('Browser is not initialized') throw new Error('Browser is not initialized')
}
const options = this.ctx.browser.config.server const options = this.ctx.browser.config.server
const _open = options.open const _open = options.open
options.open = url options.open = url
@ -37,6 +41,5 @@ export class PreviewBrowserProvider implements BrowserProvider {
options.open = _open options.open = _open
} }
async close() { async close() {}
}
} }

View File

@ -1,10 +1,15 @@
import type { BrowserProvider, BrowserProviderInitializationOptions, WorkspaceProject } from 'vitest/node' import type {
BrowserProvider,
BrowserProviderInitializationOptions,
WorkspaceProject,
} from 'vitest/node'
import type { RemoteOptions } from 'webdriverio' import type { RemoteOptions } from 'webdriverio'
const webdriverBrowsers = ['firefox', 'chrome', 'edge', 'safari'] as const const webdriverBrowsers = ['firefox', 'chrome', 'edge', 'safari'] as const
type WebdriverBrowser = typeof webdriverBrowsers[number] type WebdriverBrowser = (typeof webdriverBrowsers)[number]
interface WebdriverProviderOptions extends BrowserProviderInitializationOptions { interface WebdriverProviderOptions
extends BrowserProviderInitializationOptions {
browser: WebdriverBrowser browser: WebdriverBrowser
} }
@ -23,7 +28,10 @@ export class WebdriverBrowserProvider implements BrowserProvider {
return webdriverBrowsers return webdriverBrowsers
} }
async initialize(ctx: WorkspaceProject, { browser, options }: WebdriverProviderOptions) { async initialize(
ctx: WorkspaceProject,
{ browser, options }: WebdriverProviderOptions,
) {
this.ctx = ctx this.ctx = ctx
this.browserName = browser this.browserName = browser
this.options = options as RemoteOptions this.options = options as RemoteOptions
@ -31,7 +39,10 @@ export class WebdriverBrowserProvider implements BrowserProvider {
async beforeCommand() { async beforeCommand() {
const page = this.browser! const page = this.browser!
const iframe = await page.findElement('css selector', 'iframe[data-vitest]') const iframe = await page.findElement(
'css selector',
'iframe[data-vitest]',
)
await page.switchToFrame(iframe) await page.switchToFrame(iframe)
} }
@ -46,14 +57,18 @@ export class WebdriverBrowserProvider implements BrowserProvider {
} }
async openBrowser() { async openBrowser() {
if (this.browser) if (this.browser) {
return this.browser return this.browser
}
const options = this.ctx.config.browser const options = this.ctx.config.browser
if (this.browserName === 'safari') { if (this.browserName === 'safari') {
if (options.headless) if (options.headless) {
throw new Error('You\'ve enabled headless mode for Safari but it doesn\'t currently support it.') throw new Error(
'You\'ve enabled headless mode for Safari but it doesn\'t currently support it.',
)
}
} }
const { remote } = await import('webdriverio') const { remote } = await import('webdriverio')
@ -85,7 +100,7 @@ export class WebdriverBrowserProvider implements BrowserProvider {
if (browser !== 'safari' && options.headless) { if (browser !== 'safari' && options.headless) {
const [key, args] = headlessMap[browser] const [key, args] = headlessMap[browser]
const currentValues = (this.options?.capabilities as any)?.[key] || {} const currentValues = (this.options?.capabilities as any)?.[key] || {}
const newArgs = [...currentValues.args || [], ...args] const newArgs = [...(currentValues.args || []), ...args]
capabilities[key] = { ...currentValues, args: newArgs as any } capabilities[key] = { ...currentValues, args: newArgs as any }
} }

View File

@ -1,10 +1,7 @@
{ {
"extends": "../../tsconfig.base.json", "extends": "../../tsconfig.base.json",
"compilerOptions": { "compilerOptions": {
"types": [ "types": ["node", "vite/client"]
"node",
"vite/client"
]
}, },
"exclude": ["dist", "node_modules"] "exclude": ["dist", "node_modules"]
} }

View File

@ -48,8 +48,6 @@ export default () => [
format: 'esm', format: 'esm',
}, },
external, external,
plugins: [ plugins: [dts({ respectExternal: true })],
dts({ respectExternal: true }),
],
}, },
] ]

View File

@ -3,7 +3,9 @@ import { COVERAGE_STORE_KEY } from './constants'
export async function getProvider() { export async function getProvider() {
// to not bundle the provider // to not bundle the provider
const providerPath = './provider.js' const providerPath = './provider.js'
const { IstanbulCoverageProvider } = await import(providerPath) as typeof import('./provider') const { IstanbulCoverageProvider } = (await import(
providerPath
)) as typeof import('./provider')
return new IstanbulCoverageProvider() return new IstanbulCoverageProvider()
} }

View File

@ -1,7 +1,23 @@
import { existsSync, promises as fs, readdirSync, writeFileSync } from 'node:fs' import {
existsSync,
promises as fs,
readdirSync,
writeFileSync,
} from 'node:fs'
import { resolve } from 'pathe' import { resolve } from 'pathe'
import type { AfterSuiteRunMeta, CoverageIstanbulOptions, CoverageProvider, ReportContext, ResolvedCoverageOptions, Vitest } from 'vitest' import type {
import { coverageConfigDefaults, defaultExclude, defaultInclude } from 'vitest/config' AfterSuiteRunMeta,
CoverageIstanbulOptions,
CoverageProvider,
ReportContext,
ResolvedCoverageOptions,
Vitest,
} from 'vitest'
import {
coverageConfigDefaults,
defaultExclude,
defaultInclude,
} from 'vitest/config'
import { BaseCoverageProvider } from 'vitest/coverage' import { BaseCoverageProvider } from 'vitest/coverage'
import c from 'picocolors' import c from 'picocolors'
import { parseModule } from 'magicast' import { parseModule } from 'magicast'
@ -19,11 +35,16 @@ import { COVERAGE_STORE_KEY } from './constants'
type Options = ResolvedCoverageOptions<'istanbul'> type Options = ResolvedCoverageOptions<'istanbul'>
type Filename = string type Filename = string
type CoverageFilesByTransformMode = Record<AfterSuiteRunMeta['transformMode'], Filename[]> type CoverageFilesByTransformMode = Record<
type ProjectName = NonNullable<AfterSuiteRunMeta['projectName']> | typeof DEFAULT_PROJECT AfterSuiteRunMeta['transformMode'],
Filename[]
>
type ProjectName =
| NonNullable<AfterSuiteRunMeta['projectName']>
| typeof DEFAULT_PROJECT
interface TestExclude { interface TestExclude {
new(opts: { new (opts: {
cwd?: string | string[] cwd?: string | string[]
include?: string | string[] include?: string | string[]
exclude?: string | string[] exclude?: string | string[]
@ -40,7 +61,9 @@ const DEFAULT_PROJECT = Symbol.for('default-project')
const debug = createDebug('vitest:coverage') const debug = createDebug('vitest:coverage')
let uniqueId = 0 let uniqueId = 0
export class IstanbulCoverageProvider extends BaseCoverageProvider implements CoverageProvider { export class IstanbulCoverageProvider
extends BaseCoverageProvider
implements CoverageProvider {
name = 'istanbul' name = 'istanbul'
ctx!: Vitest ctx!: Vitest
@ -64,15 +87,22 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
// Resolved fields // Resolved fields
provider: 'istanbul', provider: 'istanbul',
reportsDirectory: resolve(ctx.config.root, config.reportsDirectory || coverageConfigDefaults.reportsDirectory), reportsDirectory: resolve(
reporter: this.resolveReporters(config.reporter || coverageConfigDefaults.reporter), ctx.config.root,
config.reportsDirectory || coverageConfigDefaults.reportsDirectory,
),
reporter: this.resolveReporters(
config.reporter || coverageConfigDefaults.reporter,
),
thresholds: config.thresholds && { thresholds: config.thresholds && {
...config.thresholds, ...config.thresholds,
lines: config.thresholds['100'] ? 100 : config.thresholds.lines, lines: config.thresholds['100'] ? 100 : config.thresholds.lines,
branches: config.thresholds['100'] ? 100 : config.thresholds.branches, branches: config.thresholds['100'] ? 100 : config.thresholds.branches,
functions: config.thresholds['100'] ? 100 : config.thresholds.functions, functions: config.thresholds['100'] ? 100 : config.thresholds.functions,
statements: config.thresholds['100'] ? 100 : config.thresholds.statements, statements: config.thresholds['100']
? 100
: config.thresholds.statements,
}, },
} }
@ -90,7 +120,10 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
this.testExclude = new _TestExclude({ this.testExclude = new _TestExclude({
cwd: ctx.config.root, cwd: ctx.config.root,
include: typeof this.options.include === 'undefined' ? undefined : [...this.options.include], include:
typeof this.options.include === 'undefined'
? undefined
: [...this.options.include],
exclude: [...defaultExclude, ...defaultInclude, ...this.options.exclude], exclude: [...defaultExclude, ...defaultInclude, ...this.options.exclude],
excludeNodeModules: true, excludeNodeModules: true,
extension: this.options.extension, extension: this.options.extension,
@ -98,9 +131,14 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
}) })
const shard = this.ctx.config.shard const shard = this.ctx.config.shard
const tempDirectory = `.tmp${shard ? `-${shard.index}-${shard.count}` : ''}` const tempDirectory = `.tmp${
shard ? `-${shard.index}-${shard.count}` : ''
}`
this.coverageFilesDirectory = resolve(this.options.reportsDirectory, tempDirectory) this.coverageFilesDirectory = resolve(
this.options.reportsDirectory,
tempDirectory,
)
} }
resolveOptions() { resolveOptions() {
@ -108,16 +146,24 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
} }
onFileTransform(sourceCode: string, id: string, pluginCtx: any) { onFileTransform(sourceCode: string, id: string, pluginCtx: any) {
if (!this.testExclude.shouldInstrument(id)) if (!this.testExclude.shouldInstrument(id)) {
return return
}
const sourceMap = pluginCtx.getCombinedSourcemap() const sourceMap = pluginCtx.getCombinedSourcemap()
sourceMap.sources = sourceMap.sources.map(removeQueryParameters) sourceMap.sources = sourceMap.sources.map(removeQueryParameters)
// Exclude SWC's decorators that are left in source maps // Exclude SWC's decorators that are left in source maps
sourceCode = sourceCode.replaceAll('_ts_decorate', '/* istanbul ignore next */_ts_decorate') sourceCode = sourceCode.replaceAll(
'_ts_decorate',
'/* istanbul ignore next */_ts_decorate',
)
const code = this.instrumenter.instrumentSync(sourceCode, id, sourceMap as any) const code = this.instrumenter.instrumentSync(
sourceCode,
id,
sourceMap as any,
)
const map = this.instrumenter.lastSourceMap() as any const map = this.instrumenter.lastSourceMap() as any
return { code, map } return { code, map }
@ -129,11 +175,13 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
* backwards compatibility is a breaking change. * backwards compatibility is a breaking change.
*/ */
onAfterSuiteRun({ coverage, transformMode, projectName }: AfterSuiteRunMeta) { onAfterSuiteRun({ coverage, transformMode, projectName }: AfterSuiteRunMeta) {
if (!coverage) if (!coverage) {
return return
}
if (transformMode !== 'web' && transformMode !== 'ssr') if (transformMode !== 'web' && transformMode !== 'ssr') {
throw new Error(`Invalid transform mode: ${transformMode}`) throw new Error(`Invalid transform mode: ${transformMode}`)
}
let entry = this.coverageFiles.get(projectName || DEFAULT_PROJECT) let entry = this.coverageFiles.get(projectName || DEFAULT_PROJECT)
@ -142,7 +190,10 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
this.coverageFiles.set(projectName || DEFAULT_PROJECT, entry) this.coverageFiles.set(projectName || DEFAULT_PROJECT, entry)
} }
const filename = resolve(this.coverageFilesDirectory, `coverage-${uniqueId++}.json`) const filename = resolve(
this.coverageFilesDirectory,
`coverage-${uniqueId++}.json`,
)
entry[transformMode].push(filename) entry[transformMode].push(filename)
const promise = fs.writeFile(filename, JSON.stringify(coverage), 'utf-8') const promise = fs.writeFile(filename, JSON.stringify(coverage), 'utf-8')
@ -150,11 +201,21 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
} }
async clean(clean = true) { async clean(clean = true) {
if (clean && existsSync(this.options.reportsDirectory)) if (clean && existsSync(this.options.reportsDirectory)) {
await fs.rm(this.options.reportsDirectory, { recursive: true, force: true, maxRetries: 10 }) await fs.rm(this.options.reportsDirectory, {
recursive: true,
force: true,
maxRetries: 10,
})
}
if (existsSync(this.coverageFilesDirectory)) if (existsSync(this.coverageFilesDirectory)) {
await fs.rm(this.coverageFilesDirectory, { recursive: true, force: true, maxRetries: 10 }) await fs.rm(this.coverageFilesDirectory, {
recursive: true,
force: true,
maxRetries: 10,
})
}
await fs.mkdir(this.coverageFilesDirectory, { recursive: true }) await fs.mkdir(this.coverageFilesDirectory, { recursive: true })
@ -171,33 +232,45 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
this.pendingPromises = [] this.pendingPromises = []
for (const coveragePerProject of this.coverageFiles.values()) { for (const coveragePerProject of this.coverageFiles.values()) {
for (const filenames of [coveragePerProject.ssr, coveragePerProject.web]) { for (const filenames of [
coveragePerProject.ssr,
coveragePerProject.web,
]) {
const coverageMapByTransformMode = libCoverage.createCoverageMap({}) const coverageMapByTransformMode = libCoverage.createCoverageMap({})
for (const chunk of this.toSlices(filenames, this.options.processingConcurrency)) { for (const chunk of this.toSlices(
filenames,
this.options.processingConcurrency,
)) {
if (debug.enabled) { if (debug.enabled) {
index += chunk.length index += chunk.length
debug('Covered files %d/%d', index, total) debug('Covered files %d/%d', index, total)
} }
await Promise.all(chunk.map(async (filename) => { await Promise.all(
const contents = await fs.readFile(filename, 'utf-8') chunk.map(async (filename) => {
const coverage = JSON.parse(contents) as CoverageMap const contents = await fs.readFile(filename, 'utf-8')
const coverage = JSON.parse(contents) as CoverageMap
coverageMapByTransformMode.merge(coverage) coverageMapByTransformMode.merge(coverage)
})) }),
)
} }
// Source maps can change based on projectName and transform mode. // Source maps can change based on projectName and transform mode.
// Coverage transform re-uses source maps so we need to separate transforms from each other. // Coverage transform re-uses source maps so we need to separate transforms from each other.
const transformedCoverage = await transformCoverage(coverageMapByTransformMode) const transformedCoverage = await transformCoverage(
coverageMapByTransformMode,
)
coverageMap.merge(transformedCoverage) coverageMap.merge(transformedCoverage)
} }
} }
if (this.options.all && allTestsRun) { if (this.options.all && allTestsRun) {
const coveredFiles = coverageMap.files() const coveredFiles = coverageMap.files()
const uncoveredCoverage = await this.getCoverageMapForUncoveredFiles(coveredFiles) const uncoveredCoverage = await this.getCoverageMapForUncoveredFiles(
coveredFiles,
)
coverageMap.merge(await transformCoverage(uncoveredCoverage)) coverageMap.merge(await transformCoverage(uncoveredCoverage))
} }
@ -207,7 +280,7 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
async reportCoverage(coverageMap: unknown, { allTestsRun }: ReportContext) { async reportCoverage(coverageMap: unknown, { allTestsRun }: ReportContext) {
await this.generateReports( await this.generateReports(
coverageMap as CoverageMap || libCoverage.createCoverageMap({}), (coverageMap as CoverageMap) || libCoverage.createCoverageMap({}),
allTestsRun, allTestsRun,
) )
@ -219,28 +292,37 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
await fs.rm(this.coverageFilesDirectory, { recursive: true }) await fs.rm(this.coverageFilesDirectory, { recursive: true })
// Remove empty reports directory, e.g. when only text-reporter is used // Remove empty reports directory, e.g. when only text-reporter is used
if (readdirSync(this.options.reportsDirectory).length === 0) if (readdirSync(this.options.reportsDirectory).length === 0) {
await fs.rm(this.options.reportsDirectory, { recursive: true }) await fs.rm(this.options.reportsDirectory, { recursive: true })
}
} }
} }
async generateReports(coverageMap: CoverageMap, allTestsRun: boolean | undefined) { async generateReports(
coverageMap: CoverageMap,
allTestsRun: boolean | undefined,
) {
const context = libReport.createContext({ const context = libReport.createContext({
dir: this.options.reportsDirectory, dir: this.options.reportsDirectory,
coverageMap, coverageMap,
watermarks: this.options.watermarks, watermarks: this.options.watermarks,
}) })
if (this.hasTerminalReporter(this.options.reporter)) if (this.hasTerminalReporter(this.options.reporter)) {
this.ctx.logger.log(c.blue(' % ') + c.dim('Coverage report from ') + c.yellow(this.name)) this.ctx.logger.log(
c.blue(' % ') + c.dim('Coverage report from ') + c.yellow(this.name),
)
}
for (const reporter of this.options.reporter) { for (const reporter of this.options.reporter) {
// Type assertion required for custom reporters // Type assertion required for custom reporters
reports.create(reporter[0] as Parameters<typeof reports.create>[0], { reports
skipFull: this.options.skipFull, .create(reporter[0] as Parameters<typeof reports.create>[0], {
projectRoot: this.ctx.config.root, skipFull: this.options.skipFull,
...reporter[1], projectRoot: this.ctx.config.root,
}).execute(context) ...reporter[1],
})
.execute(context)
} }
if (this.options.thresholds) { if (this.options.thresholds) {
@ -257,17 +339,27 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
}) })
if (this.options.thresholds.autoUpdate && allTestsRun) { if (this.options.thresholds.autoUpdate && allTestsRun) {
if (!this.ctx.server.config.configFile) if (!this.ctx.server.config.configFile) {
throw new Error('Missing configurationFile. The "coverage.thresholds.autoUpdate" can only be enabled when configuration file is used.') throw new Error(
'Missing configurationFile. The "coverage.thresholds.autoUpdate" can only be enabled when configuration file is used.',
)
}
const configFilePath = this.ctx.server.config.configFile const configFilePath = this.ctx.server.config.configFile
const configModule = parseModule(await fs.readFile(configFilePath, 'utf8')) const configModule = parseModule(
await fs.readFile(configFilePath, 'utf8'),
)
this.updateThresholds({ this.updateThresholds({
thresholds: resolvedThresholds, thresholds: resolvedThresholds,
perFile: this.options.thresholds.perFile, perFile: this.options.thresholds.perFile,
configurationFile: configModule, configurationFile: configModule,
onUpdate: () => writeFileSync(configFilePath, configModule.generate().code, 'utf-8'), onUpdate: () =>
writeFileSync(
configFilePath,
configModule.generate().code,
'utf-8',
),
}) })
} }
} }
@ -276,18 +368,24 @@ export class IstanbulCoverageProvider extends BaseCoverageProvider implements Co
async mergeReports(coverageMaps: unknown[]) { async mergeReports(coverageMaps: unknown[]) {
const coverageMap = libCoverage.createCoverageMap({}) const coverageMap = libCoverage.createCoverageMap({})
for (const coverage of coverageMaps) for (const coverage of coverageMaps) {
coverageMap.merge(coverage as CoverageMap) coverageMap.merge(coverage as CoverageMap)
}
await this.generateReports(coverageMap, true) await this.generateReports(coverageMap, true)
} }
private async getCoverageMapForUncoveredFiles(coveredFiles: string[]) { private async getCoverageMapForUncoveredFiles(coveredFiles: string[]) {
const allFiles = await this.testExclude.glob(this.ctx.config.root) const allFiles = await this.testExclude.glob(this.ctx.config.root)
let includedFiles = allFiles.map(file => resolve(this.ctx.config.root, file)) let includedFiles = allFiles.map(file =>
resolve(this.ctx.config.root, file),
)
if (this.ctx.config.changed) if (this.ctx.config.changed) {
includedFiles = (this.ctx.config.related || []).filter(file => includedFiles.includes(file)) includedFiles = (this.ctx.config.related || []).filter(file =>
includedFiles.includes(file),
)
}
const uncoveredFiles = includedFiles const uncoveredFiles = includedFiles
.filter(file => !coveredFiles.includes(file)) .filter(file => !coveredFiles.includes(file))

View File

@ -49,8 +49,6 @@ export default () => [
format: 'esm', format: 'esm',
}, },
external, external,
plugins: [ plugins: [dts({ respectExternal: true })],
dts({ respectExternal: true }),
],
}, },
] ]

View File

@ -5,7 +5,9 @@ export default {
async getProvider() { async getProvider() {
// to not bundle the provider // to not bundle the provider
const name = './provider.js' const name = './provider.js'
const { V8CoverageProvider } = await import(name) as typeof import('./provider') const { V8CoverageProvider } = (await import(
name
)) as typeof import('./provider')
return new V8CoverageProvider() return new V8CoverageProvider()
}, },
} }

View File

@ -1,4 +1,9 @@
import { existsSync, promises as fs, readdirSync, writeFileSync } from 'node:fs' import {
existsSync,
promises as fs,
readdirSync,
writeFileSync,
} from 'node:fs'
import type { Profiler } from 'node:inspector' import type { Profiler } from 'node:inspector'
import { fileURLToPath, pathToFileURL } from 'node:url' import { fileURLToPath, pathToFileURL } from 'node:url'
import v8ToIstanbul from 'v8-to-istanbul' import v8ToIstanbul from 'v8-to-istanbul'
@ -18,16 +23,26 @@ import { stripLiteral } from 'strip-literal'
import createDebug from 'debug' import createDebug from 'debug'
import { cleanUrl } from 'vite-node/utils' import { cleanUrl } from 'vite-node/utils'
import type { EncodedSourceMap, FetchResult } from 'vite-node' import type { EncodedSourceMap, FetchResult } from 'vite-node'
import { coverageConfigDefaults, defaultExclude, defaultInclude } from 'vitest/config' import {
coverageConfigDefaults,
defaultExclude,
defaultInclude,
} from 'vitest/config'
import { BaseCoverageProvider } from 'vitest/coverage' import { BaseCoverageProvider } from 'vitest/coverage'
import type { AfterSuiteRunMeta, CoverageProvider, CoverageV8Options, ReportContext, ResolvedCoverageOptions } from 'vitest' import type {
AfterSuiteRunMeta,
CoverageProvider,
CoverageV8Options,
ReportContext,
ResolvedCoverageOptions,
} from 'vitest'
import type { Vitest } from 'vitest/node' import type { Vitest } from 'vitest/node'
// @ts-expect-error missing types // @ts-expect-error missing types
import _TestExclude from 'test-exclude' import _TestExclude from 'test-exclude'
interface TestExclude { interface TestExclude {
new(opts: { new (opts: {
cwd?: string | string[] cwd?: string | string[]
include?: string | string[] include?: string | string[]
exclude?: string | string[] exclude?: string | string[]
@ -44,21 +59,30 @@ type Options = ResolvedCoverageOptions<'v8'>
type TransformResults = Map<string, FetchResult> type TransformResults = Map<string, FetchResult>
type Filename = string type Filename = string
type RawCoverage = Profiler.TakePreciseCoverageReturnType type RawCoverage = Profiler.TakePreciseCoverageReturnType
type CoverageFilesByTransformMode = Record<AfterSuiteRunMeta['transformMode'], Filename[]> type CoverageFilesByTransformMode = Record<
type ProjectName = NonNullable<AfterSuiteRunMeta['projectName']> | typeof DEFAULT_PROJECT AfterSuiteRunMeta['transformMode'],
Filename[]
>
type ProjectName =
| NonNullable<AfterSuiteRunMeta['projectName']>
| typeof DEFAULT_PROJECT
// TODO: vite-node should export this // TODO: vite-node should export this
const WRAPPER_LENGTH = 185 const WRAPPER_LENGTH = 185
// Note that this needs to match the line ending as well // Note that this needs to match the line ending as well
const VITE_EXPORTS_LINE_PATTERN = /Object\.defineProperty\(__vite_ssr_exports__.*\n/g const VITE_EXPORTS_LINE_PATTERN
const DECORATOR_METADATA_PATTERN = /_ts_metadata\("design:paramtypes", \[[^\]]*\]\),*/g = /Object\.defineProperty\(__vite_ssr_exports__.*\n/g
const DECORATOR_METADATA_PATTERN
= /_ts_metadata\("design:paramtypes", \[[^\]]*\]\),*/g
const DEFAULT_PROJECT = Symbol.for('default-project') const DEFAULT_PROJECT = Symbol.for('default-project')
const debug = createDebug('vitest:coverage') const debug = createDebug('vitest:coverage')
let uniqueId = 0 let uniqueId = 0
export class V8CoverageProvider extends BaseCoverageProvider implements CoverageProvider { export class V8CoverageProvider
extends BaseCoverageProvider
implements CoverageProvider {
name = 'v8' name = 'v8'
ctx!: Vitest ctx!: Vitest
@ -81,21 +105,31 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
// Resolved fields // Resolved fields
provider: 'v8', provider: 'v8',
reporter: this.resolveReporters(config.reporter || coverageConfigDefaults.reporter), reporter: this.resolveReporters(
reportsDirectory: resolve(ctx.config.root, config.reportsDirectory || coverageConfigDefaults.reportsDirectory), config.reporter || coverageConfigDefaults.reporter,
),
reportsDirectory: resolve(
ctx.config.root,
config.reportsDirectory || coverageConfigDefaults.reportsDirectory,
),
thresholds: config.thresholds && { thresholds: config.thresholds && {
...config.thresholds, ...config.thresholds,
lines: config.thresholds['100'] ? 100 : config.thresholds.lines, lines: config.thresholds['100'] ? 100 : config.thresholds.lines,
branches: config.thresholds['100'] ? 100 : config.thresholds.branches, branches: config.thresholds['100'] ? 100 : config.thresholds.branches,
functions: config.thresholds['100'] ? 100 : config.thresholds.functions, functions: config.thresholds['100'] ? 100 : config.thresholds.functions,
statements: config.thresholds['100'] ? 100 : config.thresholds.statements, statements: config.thresholds['100']
? 100
: config.thresholds.statements,
}, },
} }
this.testExclude = new _TestExclude({ this.testExclude = new _TestExclude({
cwd: ctx.config.root, cwd: ctx.config.root,
include: typeof this.options.include === 'undefined' ? undefined : [...this.options.include], include:
typeof this.options.include === 'undefined'
? undefined
: [...this.options.include],
exclude: [...defaultExclude, ...defaultInclude, ...this.options.exclude], exclude: [...defaultExclude, ...defaultInclude, ...this.options.exclude],
excludeNodeModules: true, excludeNodeModules: true,
extension: this.options.extension, extension: this.options.extension,
@ -103,9 +137,14 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
}) })
const shard = this.ctx.config.shard const shard = this.ctx.config.shard
const tempDirectory = `.tmp${shard ? `-${shard.index}-${shard.count}` : ''}` const tempDirectory = `.tmp${
shard ? `-${shard.index}-${shard.count}` : ''
}`
this.coverageFilesDirectory = resolve(this.options.reportsDirectory, tempDirectory) this.coverageFilesDirectory = resolve(
this.options.reportsDirectory,
tempDirectory,
)
} }
resolveOptions() { resolveOptions() {
@ -113,11 +152,21 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
} }
async clean(clean = true) { async clean(clean = true) {
if (clean && existsSync(this.options.reportsDirectory)) if (clean && existsSync(this.options.reportsDirectory)) {
await fs.rm(this.options.reportsDirectory, { recursive: true, force: true, maxRetries: 10 }) await fs.rm(this.options.reportsDirectory, {
recursive: true,
force: true,
maxRetries: 10,
})
}
if (existsSync(this.coverageFilesDirectory)) if (existsSync(this.coverageFilesDirectory)) {
await fs.rm(this.coverageFilesDirectory, { recursive: true, force: true, maxRetries: 10 }) await fs.rm(this.coverageFilesDirectory, {
recursive: true,
force: true,
maxRetries: 10,
})
}
await fs.mkdir(this.coverageFilesDirectory, { recursive: true }) await fs.mkdir(this.coverageFilesDirectory, { recursive: true })
@ -131,8 +180,9 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
* backwards compatibility is a breaking change. * backwards compatibility is a breaking change.
*/ */
onAfterSuiteRun({ coverage, transformMode, projectName }: AfterSuiteRunMeta) { onAfterSuiteRun({ coverage, transformMode, projectName }: AfterSuiteRunMeta) {
if (transformMode !== 'web' && transformMode !== 'ssr') if (transformMode !== 'web' && transformMode !== 'ssr') {
throw new Error(`Invalid transform mode: ${transformMode}`) throw new Error(`Invalid transform mode: ${transformMode}`)
}
let entry = this.coverageFiles.get(projectName || DEFAULT_PROJECT) let entry = this.coverageFiles.get(projectName || DEFAULT_PROJECT)
@ -141,7 +191,10 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
this.coverageFiles.set(projectName || DEFAULT_PROJECT, entry) this.coverageFiles.set(projectName || DEFAULT_PROJECT, entry)
} }
const filename = resolve(this.coverageFilesDirectory, `coverage-${uniqueId++}.json`) const filename = resolve(
this.coverageFilesDirectory,
`coverage-${uniqueId++}.json`,
)
entry[transformMode].push(filename) entry[transformMode].push(filename)
const promise = fs.writeFile(filename, JSON.stringify(coverage), 'utf-8') const promise = fs.writeFile(filename, JSON.stringify(coverage), 'utf-8')
@ -156,24 +209,38 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
await Promise.all(this.pendingPromises) await Promise.all(this.pendingPromises)
this.pendingPromises = [] this.pendingPromises = []
for (const [projectName, coveragePerProject] of this.coverageFiles.entries()) { for (const [
for (const [transformMode, filenames] of Object.entries(coveragePerProject) as [AfterSuiteRunMeta['transformMode'], Filename[]][]) { projectName,
coveragePerProject,
] of this.coverageFiles.entries()) {
for (const [transformMode, filenames] of Object.entries(
coveragePerProject,
) as [AfterSuiteRunMeta['transformMode'], Filename[]][]) {
let merged: RawCoverage = { result: [] } let merged: RawCoverage = { result: [] }
for (const chunk of this.toSlices(filenames, this.options.processingConcurrency)) { for (const chunk of this.toSlices(
filenames,
this.options.processingConcurrency,
)) {
if (debug.enabled) { if (debug.enabled) {
index += chunk.length index += chunk.length
debug('Covered files %d/%d', index, total) debug('Covered files %d/%d', index, total)
} }
await Promise.all(chunk.map(async (filename) => { await Promise.all(
const contents = await fs.readFile(filename, 'utf-8') chunk.map(async (filename) => {
const coverage = JSON.parse(contents) as RawCoverage const contents = await fs.readFile(filename, 'utf-8')
merged = mergeProcessCovs([merged, coverage]) const coverage = JSON.parse(contents) as RawCoverage
})) merged = mergeProcessCovs([merged, coverage])
}),
)
} }
const converted = await this.convertCoverage(merged, projectName, transformMode) const converted = await this.convertCoverage(
merged,
projectName,
transformMode,
)
// Source maps can change based on projectName and transform mode. // Source maps can change based on projectName and transform mode.
// Coverage transform re-uses source maps so we need to separate transforms from each other. // Coverage transform re-uses source maps so we need to separate transforms from each other.
@ -194,11 +261,17 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
} }
async reportCoverage(coverageMap: unknown, { allTestsRun }: ReportContext) { async reportCoverage(coverageMap: unknown, { allTestsRun }: ReportContext) {
if (provider === 'stackblitz') if (provider === 'stackblitz') {
this.ctx.logger.log(c.blue(' % ') + c.yellow('@vitest/coverage-v8 does not work on Stackblitz. Report will be empty.')) this.ctx.logger.log(
c.blue(' % ')
+ c.yellow(
'@vitest/coverage-v8 does not work on Stackblitz. Report will be empty.',
),
)
}
await this.generateReports( await this.generateReports(
coverageMap as CoverageMap || libCoverage.createCoverageMap({}), (coverageMap as CoverageMap) || libCoverage.createCoverageMap({}),
allTestsRun, allTestsRun,
) )
@ -210,8 +283,9 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
await fs.rm(this.coverageFilesDirectory, { recursive: true }) await fs.rm(this.coverageFilesDirectory, { recursive: true })
// Remove empty reports directory, e.g. when only text-reporter is used // Remove empty reports directory, e.g. when only text-reporter is used
if (readdirSync(this.options.reportsDirectory).length === 0) if (readdirSync(this.options.reportsDirectory).length === 0) {
await fs.rm(this.options.reportsDirectory, { recursive: true }) await fs.rm(this.options.reportsDirectory, { recursive: true })
}
} }
} }
@ -222,16 +296,21 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
watermarks: this.options.watermarks, watermarks: this.options.watermarks,
}) })
if (this.hasTerminalReporter(this.options.reporter)) if (this.hasTerminalReporter(this.options.reporter)) {
this.ctx.logger.log(c.blue(' % ') + c.dim('Coverage report from ') + c.yellow(this.name)) this.ctx.logger.log(
c.blue(' % ') + c.dim('Coverage report from ') + c.yellow(this.name),
)
}
for (const reporter of this.options.reporter) { for (const reporter of this.options.reporter) {
// Type assertion required for custom reporters // Type assertion required for custom reporters
reports.create(reporter[0] as Parameters<typeof reports.create>[0], { reports
skipFull: this.options.skipFull, .create(reporter[0] as Parameters<typeof reports.create>[0], {
projectRoot: this.ctx.config.root, skipFull: this.options.skipFull,
...reporter[1], projectRoot: this.ctx.config.root,
}).execute(context) ...reporter[1],
})
.execute(context)
} }
if (this.options.thresholds) { if (this.options.thresholds) {
@ -248,17 +327,27 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
}) })
if (this.options.thresholds.autoUpdate && allTestsRun) { if (this.options.thresholds.autoUpdate && allTestsRun) {
if (!this.ctx.server.config.configFile) if (!this.ctx.server.config.configFile) {
throw new Error('Missing configurationFile. The "coverage.thresholds.autoUpdate" can only be enabled when configuration file is used.') throw new Error(
'Missing configurationFile. The "coverage.thresholds.autoUpdate" can only be enabled when configuration file is used.',
)
}
const configFilePath = this.ctx.server.config.configFile const configFilePath = this.ctx.server.config.configFile
const configModule = parseModule(await fs.readFile(configFilePath, 'utf8')) const configModule = parseModule(
await fs.readFile(configFilePath, 'utf8'),
)
this.updateThresholds({ this.updateThresholds({
thresholds: resolvedThresholds, thresholds: resolvedThresholds,
perFile: this.options.thresholds.perFile, perFile: this.options.thresholds.perFile,
configurationFile: configModule, configurationFile: configModule,
onUpdate: () => writeFileSync(configFilePath, configModule.generate().code, 'utf-8'), onUpdate: () =>
writeFileSync(
configFilePath,
configModule.generate().code,
'utf-8',
),
}) })
} }
} }
@ -267,20 +356,28 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
async mergeReports(coverageMaps: unknown[]) { async mergeReports(coverageMaps: unknown[]) {
const coverageMap = libCoverage.createCoverageMap({}) const coverageMap = libCoverage.createCoverageMap({})
for (const coverage of coverageMaps) for (const coverage of coverageMaps) {
coverageMap.merge(coverage as CoverageMap) coverageMap.merge(coverage as CoverageMap)
}
await this.generateReports(coverageMap, true) await this.generateReports(coverageMap, true)
} }
private async getUntestedFiles(testedFiles: string[]): Promise<RawCoverage> { private async getUntestedFiles(testedFiles: string[]): Promise<RawCoverage> {
const transformResults = normalizeTransformResults(this.ctx.vitenode.fetchCache) const transformResults = normalizeTransformResults(
this.ctx.vitenode.fetchCache,
)
const allFiles = await this.testExclude.glob(this.ctx.config.root) const allFiles = await this.testExclude.glob(this.ctx.config.root)
let includedFiles = allFiles.map(file => resolve(this.ctx.config.root, file)) let includedFiles = allFiles.map(file =>
resolve(this.ctx.config.root, file),
)
if (this.ctx.config.changed) if (this.ctx.config.changed) {
includedFiles = (this.ctx.config.related || []).filter(file => includedFiles.includes(file)) includedFiles = (this.ctx.config.related || []).filter(file =>
includedFiles.includes(file),
)
}
const uncoveredFiles = includedFiles const uncoveredFiles = includedFiles
.map(file => pathToFileURL(file)) .map(file => pathToFileURL(file))
@ -289,48 +386,67 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
let merged: RawCoverage = { result: [] } let merged: RawCoverage = { result: [] }
let index = 0 let index = 0
for (const chunk of this.toSlices(uncoveredFiles, this.options.processingConcurrency)) { for (const chunk of this.toSlices(
uncoveredFiles,
this.options.processingConcurrency,
)) {
if (debug.enabled) { if (debug.enabled) {
index += chunk.length index += chunk.length
debug('Uncovered files %d/%d', index, uncoveredFiles.length) debug('Uncovered files %d/%d', index, uncoveredFiles.length)
} }
const coverages = await Promise.all(chunk.map(async (filename) => { const coverages = await Promise.all(
const { originalSource, source } = await this.getSources(filename.href, transformResults) chunk.map(async (filename) => {
const { originalSource, source } = await this.getSources(
filename.href,
transformResults,
)
// Ignore empty files, e.g. files that contain only typescript types and no runtime code // Ignore empty files, e.g. files that contain only typescript types and no runtime code
if (source && stripLiteral(source).trim() === '') if (source && stripLiteral(source).trim() === '') {
return null return null
}
const coverage = { const coverage = {
url: filename.href, url: filename.href,
scriptId: '0', scriptId: '0',
// Create a made up function to mark whole file as uncovered. Note that this does not exist in source maps. // Create a made up function to mark whole file as uncovered. Note that this does not exist in source maps.
functions: [{ functions: [
ranges: [{ {
startOffset: 0, ranges: [
endOffset: originalSource.length, {
count: 0, startOffset: 0,
}], endOffset: originalSource.length,
isBlockCoverage: true, count: 0,
// This is magical value that indicates an empty report: https://github.com/istanbuljs/v8-to-istanbul/blob/fca5e6a9e6ef38a9cdc3a178d5a6cf9ef82e6cab/lib/v8-to-istanbul.js#LL131C40-L131C40 },
functionName: '(empty-report)', ],
}], isBlockCoverage: true,
} // This is magical value that indicates an empty report: https://github.com/istanbuljs/v8-to-istanbul/blob/fca5e6a9e6ef38a9cdc3a178d5a6cf9ef82e6cab/lib/v8-to-istanbul.js#LL131C40-L131C40
functionName: '(empty-report)',
},
],
}
return { result: [coverage] } return { result: [coverage] }
})) }),
)
merged = mergeProcessCovs([ merged = mergeProcessCovs([
merged, merged,
...coverages.filter((cov): cov is NonNullable<typeof cov> => cov != null), ...coverages.filter(
(cov): cov is NonNullable<typeof cov> => cov != null,
),
]) ])
} }
return merged return merged
} }
private async getSources(url: string, transformResults: TransformResults, functions: Profiler.FunctionCoverage[] = []): Promise<{ private async getSources(
url: string,
transformResults: TransformResults,
functions: Profiler.FunctionCoverage[] = [],
): Promise<{
source: string source: string
originalSource: string originalSource: string
sourceMap?: { sourcemap: EncodedSourceMap } sourceMap?: { sourcemap: EncodedSourceMap }
@ -339,21 +455,28 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
const filePath = normalize(fileURLToPath(url)) const filePath = normalize(fileURLToPath(url))
let isExecuted = true let isExecuted = true
let transformResult: FetchResult | Awaited<ReturnType<typeof this.ctx.vitenode.transformRequest>> = transformResults.get(filePath) let transformResult:
| FetchResult
| Awaited<ReturnType<typeof this.ctx.vitenode.transformRequest>>
= transformResults.get(filePath)
if (!transformResult) { if (!transformResult) {
isExecuted = false isExecuted = false
transformResult = await this.ctx.vitenode.transformRequest(filePath).catch(() => null) transformResult = await this.ctx.vitenode
.transformRequest(filePath)
.catch(() => null)
} }
const map = transformResult?.map as (EncodedSourceMap | undefined) const map = transformResult?.map as EncodedSourceMap | undefined
const code = transformResult?.code const code = transformResult?.code
const sourcesContent = map?.sourcesContent?.[0] || await fs.readFile(filePath, 'utf-8').catch(() => { const sourcesContent
// If file does not exist construct a dummy source for it. = map?.sourcesContent?.[0]
// These can be files that were generated dynamically during the test run and were removed after it. || (await fs.readFile(filePath, 'utf-8').catch(() => {
const length = findLongestFunctionLength(functions) // If file does not exist construct a dummy source for it.
return '.'.repeat(length) // These can be files that were generated dynamically during the test run and were removed after it.
}) const length = findLongestFunctionLength(functions)
return '.'.repeat(length)
}))
// These can be uncovered files included by "all: true" or files that are loaded outside vite-node // These can be uncovered files included by "all: true" or files that are loaded outside vite-node
if (!map) { if (!map) {
@ -365,8 +488,9 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
} }
const sources = [url] const sources = [url]
if (map.sources && map.sources[0] && !url.endsWith(map.sources[0])) if (map.sources && map.sources[0] && !url.endsWith(map.sources[0])) {
sources[0] = new URL(map.sources[0], url).href sources[0] = new URL(map.sources[0], url).href
}
return { return {
isExecuted, isExecuted,
@ -383,33 +507,58 @@ export class V8CoverageProvider extends BaseCoverageProvider implements Coverage
} }
} }
private async convertCoverage(coverage: RawCoverage, projectName?: ProjectName, transformMode?: 'web' | 'ssr'): Promise<CoverageMap> { private async convertCoverage(
const viteNode = this.ctx.projects.find(project => project.getName() === projectName)?.vitenode || this.ctx.vitenode coverage: RawCoverage,
const fetchCache = transformMode ? viteNode.fetchCaches[transformMode] : viteNode.fetchCache projectName?: ProjectName,
transformMode?: 'web' | 'ssr',
): Promise<CoverageMap> {
const viteNode
= this.ctx.projects.find(project => project.getName() === projectName)
?.vitenode || this.ctx.vitenode
const fetchCache = transformMode
? viteNode.fetchCaches[transformMode]
: viteNode.fetchCache
const transformResults = normalizeTransformResults(fetchCache) const transformResults = normalizeTransformResults(fetchCache)
const scriptCoverages = coverage.result.filter(result => this.testExclude.shouldInstrument(fileURLToPath(result.url))) const scriptCoverages = coverage.result.filter(result =>
this.testExclude.shouldInstrument(fileURLToPath(result.url)),
)
const coverageMap = libCoverage.createCoverageMap({}) const coverageMap = libCoverage.createCoverageMap({})
let index = 0 let index = 0
for (const chunk of this.toSlices(scriptCoverages, this.options.processingConcurrency)) { for (const chunk of this.toSlices(
scriptCoverages,
this.options.processingConcurrency,
)) {
if (debug.enabled) { if (debug.enabled) {
index += chunk.length index += chunk.length
debug('Converting %d/%d', index, scriptCoverages.length) debug('Converting %d/%d', index, scriptCoverages.length)
} }
await Promise.all(chunk.map(async ({ url, functions }) => { await Promise.all(
const sources = await this.getSources(url, transformResults, functions) chunk.map(async ({ url, functions }) => {
const sources = await this.getSources(
url,
transformResults,
functions,
)
// If file was executed by vite-node we'll need to add its wrapper // If file was executed by vite-node we'll need to add its wrapper
const wrapperLength = sources.isExecuted ? WRAPPER_LENGTH : 0 const wrapperLength = sources.isExecuted ? WRAPPER_LENGTH : 0
const converter = v8ToIstanbul(url, wrapperLength, sources, undefined, this.options.ignoreEmptyLines) const converter = v8ToIstanbul(
await converter.load() url,
wrapperLength,
sources,
undefined,
this.options.ignoreEmptyLines,
)
await converter.load()
converter.applyCoverage(functions) converter.applyCoverage(functions)
coverageMap.merge(converter.toIstanbul()) coverageMap.merge(converter.toIstanbul())
})) }),
)
} }
return coverageMap return coverageMap
@ -426,16 +575,25 @@ async function transformCoverage(coverageMap: CoverageMap) {
* - Vite's export helpers: e.g. `Object.defineProperty(__vite_ssr_exports__, "sum", { enumerable: true, configurable: true, get(){ return sum }});` * - Vite's export helpers: e.g. `Object.defineProperty(__vite_ssr_exports__, "sum", { enumerable: true, configurable: true, get(){ return sum }});`
* - SWC's decorator metadata: e.g. `_ts_metadata("design:paramtypes", [\ntypeof Request === "undefined" ? Object : Request\n]),` * - SWC's decorator metadata: e.g. `_ts_metadata("design:paramtypes", [\ntypeof Request === "undefined" ? Object : Request\n]),`
*/ */
function excludeGeneratedCode(source: string | undefined, map: EncodedSourceMap) { function excludeGeneratedCode(
if (!source) source: string | undefined,
map: EncodedSourceMap,
) {
if (!source) {
return map return map
}
if (!source.match(VITE_EXPORTS_LINE_PATTERN) && !source.match(DECORATOR_METADATA_PATTERN)) if (
!source.match(VITE_EXPORTS_LINE_PATTERN)
&& !source.match(DECORATOR_METADATA_PATTERN)
) {
return map return map
}
const trimmed = new MagicString(source) const trimmed = new MagicString(source)
trimmed.replaceAll(VITE_EXPORTS_LINE_PATTERN, '\n') trimmed.replaceAll(VITE_EXPORTS_LINE_PATTERN, '\n')
trimmed.replaceAll(DECORATOR_METADATA_PATTERN, match => '\n'.repeat(match.split('\n').length - 1)) trimmed.replaceAll(DECORATOR_METADATA_PATTERN, match =>
'\n'.repeat(match.split('\n').length - 1))
const trimmedMap = trimmed.generateMap({ hires: 'boundary' }) const trimmedMap = trimmed.generateMap({ hires: 'boundary' })
@ -453,20 +611,26 @@ function excludeGeneratedCode(source: string | undefined, map: EncodedSourceMap)
*/ */
function findLongestFunctionLength(functions: Profiler.FunctionCoverage[]) { function findLongestFunctionLength(functions: Profiler.FunctionCoverage[]) {
return functions.reduce((previous, current) => { return functions.reduce((previous, current) => {
const maxEndOffset = current.ranges.reduce((endOffset, range) => Math.max(endOffset, range.endOffset), 0) const maxEndOffset = current.ranges.reduce(
(endOffset, range) => Math.max(endOffset, range.endOffset),
0,
)
return Math.max(previous, maxEndOffset) return Math.max(previous, maxEndOffset)
}, 0) }, 0)
} }
function normalizeTransformResults(fetchCache: Map<string, { result: FetchResult }>) { function normalizeTransformResults(
fetchCache: Map<string, { result: FetchResult }>,
) {
const normalized: TransformResults = new Map() const normalized: TransformResults = new Map()
for (const [key, value] of fetchCache.entries()) { for (const [key, value] of fetchCache.entries()) {
const cleanEntry = cleanUrl(key) const cleanEntry = cleanUrl(key)
if (!normalized.has(cleanEntry)) if (!normalized.has(cleanEntry)) {
normalized.set(cleanEntry, value.result) normalized.set(cleanEntry, value.result)
}
} }
return normalized return normalized

View File

@ -1,6 +1,6 @@
/* /*
* For details about the Profiler.* messages see https://chromedevtools.github.io/devtools-protocol/v8/Profiler/ * For details about the Profiler.* messages see https://chromedevtools.github.io/devtools-protocol/v8/Profiler/
*/ */
import inspector from 'node:inspector' import inspector from 'node:inspector'
import type { Profiler } from 'node:inspector' import type { Profiler } from 'node:inspector'
@ -20,8 +20,9 @@ export function startCoverage() {
export async function takeCoverage() { export async function takeCoverage() {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
session.post('Profiler.takePreciseCoverage', async (error, coverage) => { session.post('Profiler.takePreciseCoverage', async (error, coverage) => {
if (error) if (error) {
return reject(error) return reject(error)
}
// Reduce amount of data sent over rpc by doing some early result filtering // Reduce amount of data sent over rpc by doing some early result filtering
const result = coverage.result.filter(filterResult) const result = coverage.result.filter(filterResult)
@ -29,8 +30,9 @@ export async function takeCoverage() {
resolve({ result }) resolve({ result })
}) })
if (provider === 'stackblitz') if (provider === 'stackblitz') {
resolve({ result: [] }) resolve({ result: [] })
}
}) })
} }
@ -41,11 +43,13 @@ export function stopCoverage() {
} }
function filterResult(coverage: Profiler.ScriptCoverage): boolean { function filterResult(coverage: Profiler.ScriptCoverage): boolean {
if (!coverage.url.startsWith('file://')) if (!coverage.url.startsWith('file://')) {
return false return false
}
if (coverage.url.includes('/node_modules/')) if (coverage.url.includes('/node_modules/')) {
return false return false
}
return true return true
} }

View File

@ -6,7 +6,11 @@ Jest's expect matchers as a Chai plugin.
```js ```js
import * as chai from 'chai' import * as chai from 'chai'
import { JestAsymmetricMatchers, JestChaiExpect, JestExtend } from '@vitest/expect' import {
JestAsymmetricMatchers,
JestChaiExpect,
JestExtend,
} from '@vitest/expect'
// allows using expect.extend instead of chai.use to extend plugins // allows using expect.extend instead of chai.use to extend plugins
chai.use(JestExtend) chai.use(JestExtend)

View File

@ -20,7 +20,11 @@ const plugins = [
}), }),
copy({ copy({
targets: [ targets: [
{ src: 'node_modules/@types/chai/index.d.ts', dest: 'dist', rename: 'chai.d.cts' }, {
src: 'node_modules/@types/chai/index.d.ts',
dest: 'dist',
rename: 'chai.d.cts',
},
], ],
}), }),
] ]
@ -46,15 +50,14 @@ export default defineConfig([
format: 'esm', format: 'esm',
}, },
external, external,
plugins: [ plugins: [dts({ respectExternal: true })],
dts({ respectExternal: true }),
],
onwarn, onwarn,
}, },
]) ])
function onwarn(message) { function onwarn(message) {
if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) {
return return
}
console.error(message) console.error(message)
} }

View File

@ -1,4 +1,6 @@
export const MATCHERS_OBJECT = Symbol.for('matchers-object') export const MATCHERS_OBJECT = Symbol.for('matchers-object')
export const JEST_MATCHERS_OBJECT = Symbol.for('$$jest-matchers-object') export const JEST_MATCHERS_OBJECT = Symbol.for('$$jest-matchers-object')
export const GLOBAL_EXPECT = Symbol.for('expect-global') export const GLOBAL_EXPECT = Symbol.for('expect-global')
export const ASYMMETRIC_MATCHERS_OBJECT = Symbol.for('asymmetric-matchers-object') export const ASYMMETRIC_MATCHERS_OBJECT = Symbol.for(
'asymmetric-matchers-object',
)

View File

@ -1,9 +1,20 @@
import type { ChaiPlugin, MatcherState } from './types' import type { ChaiPlugin, MatcherState } from './types'
import { GLOBAL_EXPECT } from './constants' import { GLOBAL_EXPECT } from './constants'
import { getState } from './state' import { getState } from './state'
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils' import {
diff,
getCustomEqualityTesters,
getMatcherUtils,
stringify,
} from './jest-matcher-utils'
import { equals, isA, iterableEquality, pluralize, subsetEquality } from './jest-utils' import {
equals,
isA,
iterableEquality,
pluralize,
subsetEquality,
} from './jest-utils'
export interface AsymmetricMatcherInterface { export interface AsymmetricMatcherInterface {
asymmetricMatch: (other: unknown) => boolean asymmetricMatch: (other: unknown) => boolean
@ -40,23 +51,25 @@ export abstract class AsymmetricMatcher<
abstract asymmetricMatch(other: unknown): boolean abstract asymmetricMatch(other: unknown): boolean
abstract toString(): string abstract toString(): string
getExpectedType?(): string getExpectedType?(): string
toAsymmetricMatcher?(): string toAsymmetricMatcher?(): string;
// implement custom chai/loupe inspect for better AssertionError.message formatting // implement custom chai/loupe inspect for better AssertionError.message formatting
// https://github.com/chaijs/loupe/blob/9b8a6deabcd50adc056a64fb705896194710c5c6/src/index.ts#L29 // https://github.com/chaijs/loupe/blob/9b8a6deabcd50adc056a64fb705896194710c5c6/src/index.ts#L29
[Symbol.for('chai/inspect')](options: { depth: number; truncate: number }) { [Symbol.for('chai/inspect')](options: { depth: number; truncate: number }) {
// minimal pretty-format with simple manual truncation // minimal pretty-format with simple manual truncation
const result = stringify(this, options.depth, { min: true }) const result = stringify(this, options.depth, { min: true })
if (result.length <= options.truncate) if (result.length <= options.truncate) {
return result return result
}
return `${this.toString()}{…}` return `${this.toString()}{…}`
} }
} }
export class StringContaining extends AsymmetricMatcher<string> { export class StringContaining extends AsymmetricMatcher<string> {
constructor(sample: string, inverse = false) { constructor(sample: string, inverse = false) {
if (!isA('String', sample)) if (!isA('String', sample)) {
throw new Error('Expected is not a string') throw new Error('Expected is not a string')
}
super(sample, inverse) super(sample, inverse)
} }
@ -90,27 +103,33 @@ export class Anything extends AsymmetricMatcher<void> {
} }
} }
export class ObjectContaining extends AsymmetricMatcher<Record<string, unknown>> { export class ObjectContaining extends AsymmetricMatcher<
Record<string, unknown>
> {
constructor(sample: Record<string, unknown>, inverse = false) { constructor(sample: Record<string, unknown>, inverse = false) {
super(sample, inverse) super(sample, inverse)
} }
getPrototype(obj: object) { getPrototype(obj: object) {
if (Object.getPrototypeOf) if (Object.getPrototypeOf) {
return Object.getPrototypeOf(obj) return Object.getPrototypeOf(obj)
}
if (obj.constructor.prototype === obj) if (obj.constructor.prototype === obj) {
return null return null
}
return obj.constructor.prototype return obj.constructor.prototype
} }
hasProperty(obj: object | null, property: string): boolean { hasProperty(obj: object | null, property: string): boolean {
if (!obj) if (!obj) {
return false return false
}
if (Object.prototype.hasOwnProperty.call(obj, property)) if (Object.prototype.hasOwnProperty.call(obj, property)) {
return true return true
}
return this.hasProperty(this.getPrototype(obj), property) return this.hasProperty(this.getPrototype(obj), property)
} }
@ -118,9 +137,8 @@ export class ObjectContaining extends AsymmetricMatcher<Record<string, unknown>>
asymmetricMatch(other: any) { asymmetricMatch(other: any) {
if (typeof this.sample !== 'object') { if (typeof this.sample !== 'object') {
throw new TypeError( throw new TypeError(
`You must provide an object to ${this.toString()}, not '${ `You must provide an object to ${this.toString()}, not '${typeof this
typeof this.sample .sample}'.`,
}'.`,
) )
} }
@ -128,7 +146,14 @@ export class ObjectContaining extends AsymmetricMatcher<Record<string, unknown>>
const matcherContext = this.getMatcherContext() const matcherContext = this.getMatcherContext()
for (const property in this.sample) { for (const property in this.sample) {
if (!this.hasProperty(other, property) || !equals(this.sample[property], other[property], matcherContext.customTesters)) { if (
!this.hasProperty(other, property)
|| !equals(
this.sample[property],
other[property],
matcherContext.customTesters,
)
) {
result = false result = false
break break
} }
@ -154,9 +179,8 @@ export class ArrayContaining<T = unknown> extends AsymmetricMatcher<Array<T>> {
asymmetricMatch(other: Array<T>) { asymmetricMatch(other: Array<T>) {
if (!Array.isArray(this.sample)) { if (!Array.isArray(this.sample)) {
throw new TypeError( throw new TypeError(
`You must provide an array to ${this.toString()}, not '${ `You must provide an array to ${this.toString()}, not '${typeof this
typeof this.sample .sample}'.`,
}'.`,
) )
} }
@ -165,7 +189,9 @@ export class ArrayContaining<T = unknown> extends AsymmetricMatcher<Array<T>> {
= this.sample.length === 0 = this.sample.length === 0
|| (Array.isArray(other) || (Array.isArray(other)
&& this.sample.every(item => && this.sample.every(item =>
other.some(another => equals(item, another, matcherContext.customTesters)), other.some(another =>
equals(item, another, matcherContext.customTesters),
),
)) ))
return this.inverse ? !result : result return this.inverse ? !result : result
@ -192,8 +218,9 @@ export class Any extends AsymmetricMatcher<any> {
} }
fnNameFor(func: Function) { fnNameFor(func: Function) {
if (func.name) if (func.name) {
return func.name return func.name
}
const functionToString = Function.prototype.toString const functionToString = Function.prototype.toString
@ -204,26 +231,33 @@ export class Any extends AsymmetricMatcher<any> {
} }
asymmetricMatch(other: unknown) { asymmetricMatch(other: unknown) {
if (this.sample === String) if (this.sample === String) {
return typeof other == 'string' || other instanceof String return typeof other == 'string' || other instanceof String
}
if (this.sample === Number) if (this.sample === Number) {
return typeof other == 'number' || other instanceof Number return typeof other == 'number' || other instanceof Number
}
if (this.sample === Function) if (this.sample === Function) {
return typeof other == 'function' || other instanceof Function return typeof other == 'function' || other instanceof Function
}
if (this.sample === Boolean) if (this.sample === Boolean) {
return typeof other == 'boolean' || other instanceof Boolean return typeof other == 'boolean' || other instanceof Boolean
}
if (this.sample === BigInt) if (this.sample === BigInt) {
return typeof other == 'bigint' || other instanceof BigInt return typeof other == 'bigint' || other instanceof BigInt
}
if (this.sample === Symbol) if (this.sample === Symbol) {
return typeof other == 'symbol' || other instanceof Symbol return typeof other == 'symbol' || other instanceof Symbol
}
if (this.sample === Object) if (this.sample === Object) {
return typeof other == 'object' return typeof other == 'object'
}
return other instanceof this.sample return other instanceof this.sample
} }
@ -233,20 +267,25 @@ export class Any extends AsymmetricMatcher<any> {
} }
getExpectedType() { getExpectedType() {
if (this.sample === String) if (this.sample === String) {
return 'string' return 'string'
}
if (this.sample === Number) if (this.sample === Number) {
return 'number' return 'number'
}
if (this.sample === Function) if (this.sample === Function) {
return 'function' return 'function'
}
if (this.sample === Object) if (this.sample === Object) {
return 'object' return 'object'
}
if (this.sample === Boolean) if (this.sample === Boolean) {
return 'boolean' return 'boolean'
}
return this.fnNameFor(this.sample) return this.fnNameFor(this.sample)
} }
@ -258,8 +297,9 @@ export class Any extends AsymmetricMatcher<any> {
export class StringMatching extends AsymmetricMatcher<RegExp> { export class StringMatching extends AsymmetricMatcher<RegExp> {
constructor(sample: string | RegExp, inverse = false) { constructor(sample: string | RegExp, inverse = false) {
if (!isA('String', sample) && !isA('RegExp', sample)) if (!isA('String', sample) && !isA('RegExp', sample)) {
throw new Error('Expected is not a String or a RegExp') throw new Error('Expected is not a String or a RegExp')
}
super(new RegExp(sample), inverse) super(new RegExp(sample), inverse)
} }
@ -283,11 +323,13 @@ class CloseTo extends AsymmetricMatcher<number> {
private readonly precision: number private readonly precision: number
constructor(sample: number, precision = 2, inverse = false) { constructor(sample: number, precision = 2, inverse = false) {
if (!isA('Number', sample)) if (!isA('Number', sample)) {
throw new Error('Expected is not a Number') throw new Error('Expected is not a Number')
}
if (!isA('Number', precision)) if (!isA('Number', precision)) {
throw new Error('Precision is not a Number') throw new Error('Precision is not a Number')
}
super(sample) super(sample)
this.inverse = inverse this.inverse = inverse
@ -295,19 +337,25 @@ class CloseTo extends AsymmetricMatcher<number> {
} }
asymmetricMatch(other: number) { asymmetricMatch(other: number) {
if (!isA('Number', other)) if (!isA('Number', other)) {
return false return false
}
let result = false let result = false
if (other === Number.POSITIVE_INFINITY && this.sample === Number.POSITIVE_INFINITY) { if (
other === Number.POSITIVE_INFINITY
&& this.sample === Number.POSITIVE_INFINITY
) {
result = true // Infinity - Infinity is NaN result = true // Infinity - Infinity is NaN
} }
else if (other === Number.NEGATIVE_INFINITY && this.sample === Number.NEGATIVE_INFINITY) { else if (
other === Number.NEGATIVE_INFINITY
&& this.sample === Number.NEGATIVE_INFINITY
) {
result = true // -Infinity - -Infinity is NaN result = true // -Infinity - -Infinity is NaN
} }
else { else {
result result = Math.abs(this.sample - other) < 10 ** -this.precision / 2
= Math.abs(this.sample - other) < 10 ** -this.precision / 2
} }
return this.inverse ? !result : result return this.inverse ? !result : result
} }
@ -330,17 +378,9 @@ class CloseTo extends AsymmetricMatcher<number> {
} }
export const JestAsymmetricMatchers: ChaiPlugin = (chai, utils) => { export const JestAsymmetricMatchers: ChaiPlugin = (chai, utils) => {
utils.addMethod( utils.addMethod(chai.expect, 'anything', () => new Anything())
chai.expect,
'anything',
() => new Anything(),
)
utils.addMethod( utils.addMethod(chai.expect, 'any', (expected: unknown) => new Any(expected))
chai.expect,
'any',
(expected: unknown) => new Any(expected),
)
utils.addMethod( utils.addMethod(
chai.expect, chai.expect,
@ -370,14 +410,18 @@ export const JestAsymmetricMatchers: ChaiPlugin = (chai, utils) => {
chai.expect, chai.expect,
'closeTo', 'closeTo',
(expected: any, precision?: number) => new CloseTo(expected, precision), (expected: any, precision?: number) => new CloseTo(expected, precision),
) );
// defineProperty does not work // defineProperty does not work
;(chai.expect as any).not = { (chai.expect as any).not = {
stringContaining: (expected: string) => new StringContaining(expected, true), stringContaining: (expected: string) =>
new StringContaining(expected, true),
objectContaining: (expected: any) => new ObjectContaining(expected, true), objectContaining: (expected: any) => new ObjectContaining(expected, true),
arrayContaining: <T = unknown>(expected: Array<T>) => new ArrayContaining<T>(expected, true), arrayContaining: <T = unknown>(expected: Array<T>) =>
stringMatching: (expected: string | RegExp) => new StringMatching(expected, true), new ArrayContaining<T>(expected, true),
closeTo: (expected: any, precision?: number) => new CloseTo(expected, precision, true), stringMatching: (expected: string | RegExp) =>
new StringMatching(expected, true),
closeTo: (expected: any, precision?: number) =>
new CloseTo(expected, precision, true),
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -10,16 +10,20 @@ import { ASYMMETRIC_MATCHERS_OBJECT, JEST_MATCHERS_OBJECT } from './constants'
import { AsymmetricMatcher } from './jest-asymmetric-matchers' import { AsymmetricMatcher } from './jest-asymmetric-matchers'
import { getState } from './state' import { getState } from './state'
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils'
import { import {
equals, diff,
iterableEquality, getCustomEqualityTesters,
subsetEquality, getMatcherUtils,
} from './jest-utils' stringify,
} from './jest-matcher-utils'
import { equals, iterableEquality, subsetEquality } from './jest-utils'
import { wrapSoft } from './utils' import { wrapSoft } from './utils'
function getMatcherState(assertion: Chai.AssertionStatic & Chai.Assertion, expect: ExpectStatic) { function getMatcherState(
assertion: Chai.AssertionStatic & Chai.Assertion,
expect: ExpectStatic,
) {
const obj = assertion._obj const obj = assertion._obj
const isNot = util.flag(assertion, 'negate') as boolean const isNot = util.flag(assertion, 'negate') as boolean
const promise = util.flag(assertion, 'promise') || '' const promise = util.flag(assertion, 'promise') || ''
@ -57,90 +61,123 @@ class JestExtendError extends Error {
} }
} }
function JestExtendPlugin(c: Chai.ChaiStatic, expect: ExpectStatic, matchers: MatchersObject): ChaiPlugin { function JestExtendPlugin(
c: Chai.ChaiStatic,
expect: ExpectStatic,
matchers: MatchersObject,
): ChaiPlugin {
return (_, utils) => { return (_, utils) => {
Object.entries(matchers).forEach(([expectAssertionName, expectAssertion]) => { Object.entries(matchers).forEach(
function expectWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) { ([expectAssertionName, expectAssertion]) => {
const { state, isNot, obj } = getMatcherState(this, expect) function expectWrapper(
this: Chai.AssertionStatic & Chai.Assertion,
...args: any[]
) {
const { state, isNot, obj } = getMatcherState(this, expect)
// @ts-expect-error args wanting tuple // @ts-expect-error args wanting tuple
const result = expectAssertion.call(state, obj, ...args) const result = expectAssertion.call(state, obj, ...args)
if (result && typeof result === 'object' && result instanceof Promise) { if (
return result.then(({ pass, message, actual, expected }) => { result
if ((pass && isNot) || (!pass && !isNot)) && typeof result === 'object'
throw new JestExtendError(message(), actual, expected) && result instanceof Promise
}) ) {
return result.then(({ pass, message, actual, expected }) => {
if ((pass && isNot) || (!pass && !isNot)) {
throw new JestExtendError(message(), actual, expected)
}
})
}
const { pass, message, actual, expected } = result
if ((pass && isNot) || (!pass && !isNot)) {
throw new JestExtendError(message(), actual, expected)
}
} }
const { pass, message, actual, expected } = result const softWrapper = wrapSoft(utils, expectWrapper)
utils.addMethod(
(globalThis as any)[JEST_MATCHERS_OBJECT].matchers,
expectAssertionName,
softWrapper,
)
utils.addMethod(
c.Assertion.prototype,
expectAssertionName,
softWrapper,
)
if ((pass && isNot) || (!pass && !isNot)) class CustomMatcher extends AsymmetricMatcher<[unknown, ...unknown[]]> {
throw new JestExtendError(message(), actual, expected) constructor(inverse = false, ...sample: [unknown, ...unknown[]]) {
} super(sample, inverse)
}
const softWrapper = wrapSoft(utils, expectWrapper) asymmetricMatch(other: unknown) {
utils.addMethod((globalThis as any)[JEST_MATCHERS_OBJECT].matchers, expectAssertionName, softWrapper) const { pass } = expectAssertion.call(
utils.addMethod(c.Assertion.prototype, expectAssertionName, softWrapper) this.getMatcherContext(expect),
other,
...this.sample,
) as SyncExpectationResult
class CustomMatcher extends AsymmetricMatcher<[unknown, ...unknown[]]> { return this.inverse ? !pass : pass
constructor(inverse = false, ...sample: [unknown, ...unknown[]]) { }
super(sample, inverse)
toString() {
return `${this.inverse ? 'not.' : ''}${expectAssertionName}`
}
getExpectedType() {
return 'any'
}
toAsymmetricMatcher() {
return `${this.toString()}<${this.sample.map(String).join(', ')}>`
}
} }
asymmetricMatch(other: unknown) { const customMatcher = (...sample: [unknown, ...unknown[]]) =>
const { pass } = expectAssertion.call( new CustomMatcher(false, ...sample)
this.getMatcherContext(expect),
other,
...this.sample,
) as SyncExpectationResult
return this.inverse ? !pass : pass Object.defineProperty(expect, expectAssertionName, {
} configurable: true,
enumerable: true,
value: customMatcher,
writable: true,
})
toString() { Object.defineProperty(expect.not, expectAssertionName, {
return `${this.inverse ? 'not.' : ''}${expectAssertionName}` configurable: true,
} enumerable: true,
value: (...sample: [unknown, ...unknown[]]) =>
new CustomMatcher(true, ...sample),
writable: true,
})
getExpectedType() { // keep track of asymmetric matchers on global so that it can be copied over to local context's `expect`.
return 'any' // note that the negated variant is automatically shared since it's assigned on the single `expect.not` object.
} Object.defineProperty(
(globalThis as any)[ASYMMETRIC_MATCHERS_OBJECT],
toAsymmetricMatcher() { expectAssertionName,
return `${this.toString()}<${this.sample.map(String).join(', ')}>` {
} configurable: true,
} enumerable: true,
value: customMatcher,
const customMatcher = (...sample: [unknown, ...unknown[]]) => new CustomMatcher(false, ...sample) writable: true,
},
Object.defineProperty(expect, expectAssertionName, { )
configurable: true, },
enumerable: true, )
value: customMatcher,
writable: true,
})
Object.defineProperty(expect.not, expectAssertionName, {
configurable: true,
enumerable: true,
value: (...sample: [unknown, ...unknown[]]) => new CustomMatcher(true, ...sample),
writable: true,
})
// keep track of asymmetric matchers on global so that it can be copied over to local context's `expect`.
// note that the negated variant is automatically shared since it's assigned on the single `expect.not` object.
Object.defineProperty(((globalThis as any)[ASYMMETRIC_MATCHERS_OBJECT]), expectAssertionName, {
configurable: true,
enumerable: true,
value: customMatcher,
writable: true,
})
})
} }
} }
export const JestExtend: ChaiPlugin = (chai, utils) => { export const JestExtend: ChaiPlugin = (chai, utils) => {
utils.addMethod(chai.expect, 'extend', (expect: ExpectStatic, expects: MatchersObject) => { utils.addMethod(
use(JestExtendPlugin(chai, expect, expects)) chai.expect,
}) 'extend',
(expect: ExpectStatic, expects: MatchersObject) => {
use(JestExtendPlugin(chai, expect, expects))
},
)
} }

View File

@ -49,12 +49,12 @@ export function getMatcherUtils() {
} }
if (matcherName.includes('.')) { if (matcherName.includes('.')) {
// Old format: for backward compatibility, // Old format: for backward compatibility,
// especially without promise or isNot options // especially without promise or isNot options
dimString += matcherName dimString += matcherName
} }
else { else {
// New format: omit period from matcherName arg // New format: omit period from matcherName arg
hint += DIM_COLOR(`${dimString}.`) + matcherName hint += DIM_COLOR(`${dimString}.`) + matcherName
dimString = '' dimString = ''
} }
@ -64,16 +64,19 @@ export function getMatcherUtils() {
} }
else { else {
hint += DIM_COLOR(`${dimString}(`) + expectedColor(expected) hint += DIM_COLOR(`${dimString}(`) + expectedColor(expected)
if (secondArgument) if (secondArgument) {
hint += DIM_COLOR(', ') + secondArgumentColor(secondArgument) hint += DIM_COLOR(', ') + secondArgumentColor(secondArgument)
}
dimString = ')' dimString = ')'
} }
if (comment !== '') if (comment !== '') {
dimString += ` // ${comment}` dimString += ` // ${comment}`
}
if (dimString !== '') if (dimString !== '') {
hint += DIM_COLOR(dimString) hint += DIM_COLOR(dimString)
}
return hint return hint
} }

View File

@ -39,21 +39,31 @@ export function equals(
const functionToString = Function.prototype.toString const functionToString = Function.prototype.toString
export function isAsymmetric(obj: any) { export function isAsymmetric(obj: any) {
return !!obj && typeof obj === 'object' && 'asymmetricMatch' in obj && isA('Function', obj.asymmetricMatch) return (
!!obj
&& typeof obj === 'object'
&& 'asymmetricMatch' in obj
&& isA('Function', obj.asymmetricMatch)
)
} }
export function hasAsymmetric(obj: any, seen = new Set()): boolean { export function hasAsymmetric(obj: any, seen = new Set()): boolean {
if (seen.has(obj)) if (seen.has(obj)) {
return false return false
}
seen.add(obj) seen.add(obj)
if (isAsymmetric(obj)) if (isAsymmetric(obj)) {
return true return true
if (Array.isArray(obj)) }
if (Array.isArray(obj)) {
return obj.some(i => hasAsymmetric(i, seen)) return obj.some(i => hasAsymmetric(i, seen))
if (obj instanceof Set) }
if (obj instanceof Set) {
return Array.from(obj).some(i => hasAsymmetric(i, seen)) return Array.from(obj).some(i => hasAsymmetric(i, seen))
if (isObject(obj)) }
if (isObject(obj)) {
return Object.values(obj).some(v => hasAsymmetric(v, seen)) return Object.values(obj).some(v => hasAsymmetric(v, seen))
}
return false return false
} }
@ -61,14 +71,17 @@ function asymmetricMatch(a: any, b: any) {
const asymmetricA = isAsymmetric(a) const asymmetricA = isAsymmetric(a)
const asymmetricB = isAsymmetric(b) const asymmetricB = isAsymmetric(b)
if (asymmetricA && asymmetricB) if (asymmetricA && asymmetricB) {
return undefined return undefined
}
if (asymmetricA) if (asymmetricA) {
return a.asymmetricMatch(b) return a.asymmetricMatch(b)
}
if (asymmetricB) if (asymmetricB) {
return b.asymmetricMatch(a) return b.asymmetricMatch(a)
}
} }
// Equality function lovingly adapted from isEqual in // Equality function lovingly adapted from isEqual in
@ -84,32 +97,44 @@ function eq(
let result = true let result = true
const asymmetricResult = asymmetricMatch(a, b) const asymmetricResult = asymmetricMatch(a, b)
if (asymmetricResult !== undefined) if (asymmetricResult !== undefined) {
return asymmetricResult return asymmetricResult
}
const testerContext: TesterContext = { equals } const testerContext: TesterContext = { equals }
for (let i = 0; i < customTesters.length; i++) { for (let i = 0; i < customTesters.length; i++) {
const customTesterResult = customTesters[i].call(testerContext, a, b, customTesters) const customTesterResult = customTesters[i].call(
if (customTesterResult !== undefined) testerContext,
a,
b,
customTesters,
)
if (customTesterResult !== undefined) {
return customTesterResult return customTesterResult
}
} }
if (a instanceof Error && b instanceof Error) if (a instanceof Error && b instanceof Error) {
return a.message === b.message return a.message === b.message
}
if (typeof URL === 'function' && a instanceof URL && b instanceof URL) if (typeof URL === 'function' && a instanceof URL && b instanceof URL) {
return a.href === b.href return a.href === b.href
}
if (Object.is(a, b)) if (Object.is(a, b)) {
return true return true
}
// A strict comparison is necessary because `null == undefined`. // A strict comparison is necessary because `null == undefined`.
if (a === null || b === null) if (a === null || b === null) {
return a === b return a === b
}
const className = Object.prototype.toString.call(a) const className = Object.prototype.toString.call(a)
if (className !== Object.prototype.toString.call(b)) if (className !== Object.prototype.toString.call(b)) {
return false return false
}
switch (className) { switch (className) {
case '[object Boolean]': case '[object Boolean]':
@ -133,18 +158,20 @@ function eq(
// Coerce dates to numeric primitive values. Dates are compared by their // Coerce dates to numeric primitive values. Dates are compared by their
// millisecond representations. Note that invalid dates with millisecond representations // millisecond representations. Note that invalid dates with millisecond representations
// of `NaN` are equivalent. // of `NaN` are equivalent.
return (numA === numB) || (Number.isNaN(numA) && Number.isNaN(numB)) return numA === numB || (Number.isNaN(numA) && Number.isNaN(numB))
} }
// RegExps are compared by their source patterns and flags. // RegExps are compared by their source patterns and flags.
case '[object RegExp]': case '[object RegExp]':
return a.source === b.source && a.flags === b.flags return a.source === b.source && a.flags === b.flags
} }
if (typeof a !== 'object' || typeof b !== 'object') if (typeof a !== 'object' || typeof b !== 'object') {
return false return false
}
// Use DOM3 method isEqualNode (IE>=9) // Use DOM3 method isEqualNode (IE>=9)
if (isDomNode(a) && isDomNode(b)) if (isDomNode(a) && isDomNode(b)) {
return a.isEqualNode(b) return a.isEqualNode(b)
}
// Used to detect circular references. // Used to detect circular references.
let length = aStack.length let length = aStack.length
@ -153,19 +180,21 @@ function eq(
// unique nested structures. // unique nested structures.
// circular references at same depth are equal // circular references at same depth are equal
// circular reference is not equal to non-circular one // circular reference is not equal to non-circular one
if (aStack[length] === a) if (aStack[length] === a) {
return bStack[length] === b return bStack[length] === b
}
else if (bStack[length] === b) else if (bStack[length] === b) {
return false return false
}
} }
// Add the first object to the stack of traversed objects. // Add the first object to the stack of traversed objects.
aStack.push(a) aStack.push(a)
bStack.push(b) bStack.push(b)
// Recursively compare objects and arrays. // Recursively compare objects and arrays.
// Compare array lengths to determine if a deep comparison is necessary. // Compare array lengths to determine if a deep comparison is necessary.
if (className === '[object Array]' && a.length !== b.length) if (className === '[object Array]' && a.length !== b.length) {
return false return false
}
// Deep compare objects. // Deep compare objects.
const aKeys = keys(a, hasKey) const aKeys = keys(a, hasKey)
@ -173,8 +202,9 @@ function eq(
let size = aKeys.length let size = aKeys.length
// Ensure that both objects contain the same number of properties before comparing deep equality. // Ensure that both objects contain the same number of properties before comparing deep equality.
if (keys(b, hasKey).length !== size) if (keys(b, hasKey).length !== size) {
return false return false
}
while (size--) { while (size--) {
key = aKeys[size] key = aKeys[size]
@ -184,8 +214,9 @@ function eq(
= hasKey(b, key) = hasKey(b, key)
&& eq(a[key], b[key], aStack, bStack, customTesters, hasKey) && eq(a[key], b[key], aStack, bStack, customTesters, hasKey)
if (!result) if (!result) {
return false return false
}
} }
// Remove the first object from the stack of traversed objects. // Remove the first object from the stack of traversed objects.
aStack.pop() aStack.pop()
@ -198,8 +229,9 @@ function keys(obj: object, hasKey: (obj: object, key: string) => boolean) {
const keys = [] const keys = []
for (const key in obj) { for (const key in obj) {
if (hasKey(obj, key)) if (hasKey(obj, key)) {
keys.push(key) keys.push(key)
}
} }
return keys.concat( return keys.concat(
(Object.getOwnPropertySymbols(obj) as Array<any>).filter( (Object.getOwnPropertySymbols(obj) as Array<any>).filter(
@ -236,8 +268,9 @@ function isDomNode(obj: any): boolean {
} }
export function fnNameFor(func: Function) { export function fnNameFor(func: Function) {
if (func.name) if (func.name) {
return func.name return func.name
}
const matches = functionToString const matches = functionToString
.call(func) .call(func)
@ -246,21 +279,25 @@ export function fnNameFor(func: Function) {
} }
function getPrototype(obj: object) { function getPrototype(obj: object) {
if (Object.getPrototypeOf) if (Object.getPrototypeOf) {
return Object.getPrototypeOf(obj) return Object.getPrototypeOf(obj)
}
if (obj.constructor.prototype === obj) if (obj.constructor.prototype === obj) {
return null return null
}
return obj.constructor.prototype return obj.constructor.prototype
} }
export function hasProperty(obj: object | null, property: string): boolean { export function hasProperty(obj: object | null, property: string): boolean {
if (!obj) if (!obj) {
return false return false
}
if (Object.prototype.hasOwnProperty.call(obj, property)) if (Object.prototype.hasOwnProperty.call(obj, property)) {
return true return true
}
return hasProperty(getPrototype(obj), property) return hasProperty(getPrototype(obj), property)
} }
@ -331,7 +368,13 @@ function hasIterator(object: any) {
return !!(object != null && object[IteratorSymbol]) return !!(object != null && object[IteratorSymbol])
} }
export function iterableEquality(a: any, b: any, customTesters: Array<Tester> = [], aStack: Array<any> = [], bStack: Array<any> = []): boolean | undefined { export function iterableEquality(
a: any,
b: any,
customTesters: Array<Tester> = [],
aStack: Array<any> = [],
bStack: Array<any> = [],
): boolean | undefined {
if ( if (
typeof a !== 'object' typeof a !== 'object'
|| typeof b !== 'object' || typeof b !== 'object'
@ -343,8 +386,9 @@ export function iterableEquality(a: any, b: any, customTesters: Array<Tester> =
return undefined return undefined
} }
if (a.constructor !== b.constructor) if (a.constructor !== b.constructor) {
return false return false
}
let length = aStack.length let length = aStack.length
while (length--) { while (length--) {
@ -352,8 +396,9 @@ export function iterableEquality(a: any, b: any, customTesters: Array<Tester> =
// unique nested structures. // unique nested structures.
// circular references at same depth are equal // circular references at same depth are equal
// circular reference is not equal to non-circular one // circular reference is not equal to non-circular one
if (aStack[length] === a) if (aStack[length] === a) {
return bStack[length] === b return bStack[length] === b
}
} }
aStack.push(a) aStack.push(a)
bStack.push(b) bStack.push(b)
@ -364,13 +409,7 @@ export function iterableEquality(a: any, b: any, customTesters: Array<Tester> =
] ]
function iterableEqualityWithStack(a: any, b: any) { function iterableEqualityWithStack(a: any, b: any) {
return iterableEquality( return iterableEquality(a, b, [...customTesters], [...aStack], [...bStack])
a,
b,
[...customTesters],
[...aStack],
[...bStack],
)
} }
if (a.size !== undefined) { if (a.size !== undefined) {
@ -384,8 +423,9 @@ export function iterableEquality(a: any, b: any, customTesters: Array<Tester> =
let has = false let has = false
for (const bValue of b) { for (const bValue of b) {
const isEqual = equals(aValue, bValue, filteredCustomTesters) const isEqual = equals(aValue, bValue, filteredCustomTesters)
if (isEqual === true) if (isEqual === true) {
has = true has = true
}
} }
if (has === false) { if (has === false) {
@ -408,14 +448,24 @@ export function iterableEquality(a: any, b: any, customTesters: Array<Tester> =
) { ) {
let has = false let has = false
for (const bEntry of b) { for (const bEntry of b) {
const matchedKey = equals(aEntry[0], bEntry[0], filteredCustomTesters) const matchedKey = equals(
aEntry[0],
bEntry[0],
filteredCustomTesters,
)
let matchedValue = false let matchedValue = false
if (matchedKey === true) if (matchedKey === true) {
matchedValue = equals(aEntry[1], bEntry[1], filteredCustomTesters) matchedValue = equals(
aEntry[1],
bEntry[1],
filteredCustomTesters,
)
}
if (matchedValue === true) if (matchedValue === true) {
has = true has = true
}
} }
if (has === false) { if (has === false) {
@ -435,15 +485,13 @@ export function iterableEquality(a: any, b: any, customTesters: Array<Tester> =
for (const aValue of a) { for (const aValue of a) {
const nextB = bIterator.next() const nextB = bIterator.next()
if ( if (nextB.done || !equals(aValue, nextB.value, filteredCustomTesters)) {
nextB.done
|| !equals(aValue, nextB.value, filteredCustomTesters)
) {
return false return false
} }
} }
if (!bIterator.next().done) if (!bIterator.next().done) {
return false return false
}
if ( if (
!isImmutableList(a) !isImmutableList(a)
@ -453,8 +501,9 @@ export function iterableEquality(a: any, b: any, customTesters: Array<Tester> =
) { ) {
const aEntries = Object.entries(a) const aEntries = Object.entries(a)
const bEntries = Object.entries(b) const bEntries = Object.entries(b)
if (!equals(aEntries, bEntries)) if (!equals(aEntries, bEntries)) {
return false return false
}
} }
// Remove the first value from the stack of traversed values. // Remove the first value from the stack of traversed values.
@ -470,8 +519,9 @@ function hasPropertyInObject(object: object, key: string | symbol): boolean {
const shouldTerminate const shouldTerminate
= !object || typeof object !== 'object' || object === Object.prototype = !object || typeof object !== 'object' || object === Object.prototype
if (shouldTerminate) if (shouldTerminate) {
return false return false
}
return ( return (
Object.prototype.hasOwnProperty.call(object, key) Object.prototype.hasOwnProperty.call(object, key)
@ -480,37 +530,47 @@ function hasPropertyInObject(object: object, key: string | symbol): boolean {
} }
function isObjectWithKeys(a: any) { function isObjectWithKeys(a: any) {
return isObject(a) return (
isObject(a)
&& !(a instanceof Error) && !(a instanceof Error)
&& !(Array.isArray(a)) && !Array.isArray(a)
&& !(a instanceof Date) && !(a instanceof Date)
)
} }
export function subsetEquality(object: unknown, subset: unknown, customTesters: Array<Tester> = []): boolean | undefined { export function subsetEquality(
const filteredCustomTesters = customTesters.filter(t => t !== subsetEquality) object: unknown,
subset: unknown,
customTesters: Array<Tester> = [],
): boolean | undefined {
const filteredCustomTesters = customTesters.filter(
t => t !== subsetEquality,
)
// subsetEquality needs to keep track of the references // subsetEquality needs to keep track of the references
// it has already visited to avoid infinite loops in case // it has already visited to avoid infinite loops in case
// there are circular references in the subset passed to it. // there are circular references in the subset passed to it.
const subsetEqualityWithContext const subsetEqualityWithContext
= (seenReferences: WeakMap<object, boolean> = new WeakMap()) => = (seenReferences: WeakMap<object, boolean> = new WeakMap()) =>
(object: any, subset: any): boolean | undefined => { (object: any, subset: any): boolean | undefined => {
if (!isObjectWithKeys(subset)) if (!isObjectWithKeys(subset)) {
return undefined return undefined
}
return Object.keys(subset).every((key) => { return Object.keys(subset).every((key) => {
if (subset[key] != null && typeof subset[key] === 'object') { if (subset[key] != null && typeof subset[key] === 'object') {
if (seenReferences.has(subset[key])) if (seenReferences.has(subset[key])) {
return equals(object[key], subset[key], filteredCustomTesters) return equals(object[key], subset[key], filteredCustomTesters)
}
seenReferences.set(subset[key], true) seenReferences.set(subset[key], true)
} }
const result const result
= object != null = object != null
&& hasPropertyInObject(object, key) && hasPropertyInObject(object, key)
&& equals(object[key], subset[key], [ && equals(object[key], subset[key], [
...filteredCustomTesters, ...filteredCustomTesters,
subsetEqualityWithContext(seenReferences), subsetEqualityWithContext(seenReferences),
]) ])
// The main goal of using seenReference is to avoid circular node on tree. // The main goal of using seenReference is to avoid circular node on tree.
// It will only happen within a parent and its child, not a node and nodes next to it (same level) // It will only happen within a parent and its child, not a node and nodes next to it (same level)
// We should keep the reference for a parent and its child only // We should keep the reference for a parent and its child only
@ -525,19 +585,24 @@ export function subsetEquality(object: unknown, subset: unknown, customTesters:
} }
export function typeEquality(a: any, b: any): boolean | undefined { export function typeEquality(a: any, b: any): boolean | undefined {
if (a == null || b == null || a.constructor === b.constructor) if (a == null || b == null || a.constructor === b.constructor) {
return undefined return undefined
}
return false return false
} }
export function arrayBufferEquality(a: unknown, b: unknown): boolean | undefined { export function arrayBufferEquality(
a: unknown,
b: unknown,
): boolean | undefined {
let dataViewA = a as DataView let dataViewA = a as DataView
let dataViewB = b as DataView let dataViewB = b as DataView
if (!(a instanceof DataView && b instanceof DataView)) { if (!(a instanceof DataView && b instanceof DataView)) {
if (!(a instanceof ArrayBuffer) || !(b instanceof ArrayBuffer)) if (!(a instanceof ArrayBuffer) || !(b instanceof ArrayBuffer)) {
return undefined return undefined
}
try { try {
dataViewA = new DataView(a) dataViewA = new DataView(a)
@ -549,36 +614,48 @@ export function arrayBufferEquality(a: unknown, b: unknown): boolean | undefined
} }
// Buffers are not equal when they do not have the same byte length // Buffers are not equal when they do not have the same byte length
if (dataViewA.byteLength !== dataViewB.byteLength) if (dataViewA.byteLength !== dataViewB.byteLength) {
return false return false
}
// Check if every byte value is equal to each other // Check if every byte value is equal to each other
for (let i = 0; i < dataViewA.byteLength; i++) { for (let i = 0; i < dataViewA.byteLength; i++) {
if (dataViewA.getUint8(i) !== dataViewB.getUint8(i)) if (dataViewA.getUint8(i) !== dataViewB.getUint8(i)) {
return false return false
}
} }
return true return true
} }
export function sparseArrayEquality(a: unknown, b: unknown, customTesters: Array<Tester> = []): boolean | undefined { export function sparseArrayEquality(
if (!Array.isArray(a) || !Array.isArray(b)) a: unknown,
b: unknown,
customTesters: Array<Tester> = [],
): boolean | undefined {
if (!Array.isArray(a) || !Array.isArray(b)) {
return undefined return undefined
}
// A sparse array [, , 1] will have keys ["2"] whereas [undefined, undefined, 1] will have keys ["0", "1", "2"] // A sparse array [, , 1] will have keys ["2"] whereas [undefined, undefined, 1] will have keys ["0", "1", "2"]
const aKeys = Object.keys(a) const aKeys = Object.keys(a)
const bKeys = Object.keys(b) const bKeys = Object.keys(b)
const filteredCustomTesters = customTesters.filter(t => t !== sparseArrayEquality) const filteredCustomTesters = customTesters.filter(
return ( t => t !== sparseArrayEquality,
equals(a, b, filteredCustomTesters, true) && equals(aKeys, bKeys)
) )
return equals(a, b, filteredCustomTesters, true) && equals(aKeys, bKeys)
} }
export function generateToBeMessage(deepEqualityName: string, expected = '#{this}', actual = '#{exp}') { export function generateToBeMessage(
deepEqualityName: string,
expected = '#{this}',
actual = '#{exp}',
) {
const toBeMessage = `expected ${expected} to be ${actual} // Object.is equality` const toBeMessage = `expected ${expected} to be ${actual} // Object.is equality`
if (['toStrictEqual', 'toEqual'].includes(deepEqualityName)) if (['toStrictEqual', 'toEqual'].includes(deepEqualityName)) {
return `${toBeMessage}\n\nIf it should pass with deep equality, replace "toBe" with "${deepEqualityName}"\n\nExpected: ${expected}\nReceived: serializes to the same string\n` return `${toBeMessage}\n\nIf it should pass with deep equality, replace "toBe" with "${deepEqualityName}"\n\nExpected: ${expected}\nReceived: serializes to the same string\n`
}
return toBeMessage return toBeMessage
} }
@ -596,59 +673,73 @@ export function getObjectKeys(object: object): Array<string | symbol> {
] ]
} }
export function getObjectSubset(object: any, subset: any, customTesters: Array<Tester> = []): { subset: any; stripped: number } { export function getObjectSubset(
object: any,
subset: any,
customTesters: Array<Tester> = [],
): { subset: any; stripped: number } {
let stripped = 0 let stripped = 0
const getObjectSubsetWithContext = (seenReferences: WeakMap<object, boolean> = new WeakMap()) => (object: any, subset: any): any => { const getObjectSubsetWithContext
if (Array.isArray(object)) { = (seenReferences: WeakMap<object, boolean> = new WeakMap()) =>
if (Array.isArray(subset) && subset.length === object.length) { (object: any, subset: any): any => {
// The map method returns correct subclass of subset. if (Array.isArray(object)) {
return subset.map((sub: any, i: number) => if (Array.isArray(subset) && subset.length === object.length) {
getObjectSubsetWithContext(seenReferences)(object[i], sub), // The map method returns correct subclass of subset.
) return subset.map((sub: any, i: number) =>
} getObjectSubsetWithContext(seenReferences)(object[i], sub),
} )
else if (object instanceof Date) {
return object
}
else if (isObject(object) && isObject(subset)) {
if (
equals(object, subset, [
...customTesters,
iterableEquality,
subsetEquality,
])
) {
// Avoid unnecessary copy which might return Object instead of subclass.
return subset
}
const trimmed: any = {}
seenReferences.set(object, trimmed)
for (const key of getObjectKeys(object)) {
if (hasPropertyInObject(subset, key)) {
trimmed[key] = seenReferences.has(object[key])
? seenReferences.get(object[key])
: getObjectSubsetWithContext(seenReferences)(object[key], subset[key])
}
else {
if (!seenReferences.has(object[key])) {
stripped += 1
if (isObject(object[key]))
stripped += getObjectKeys(object[key]).length
getObjectSubsetWithContext(seenReferences)(object[key], subset[key])
} }
} }
else if (object instanceof Date) {
return object
}
else if (isObject(object) && isObject(subset)) {
if (
equals(object, subset, [
...customTesters,
iterableEquality,
subsetEquality,
])
) {
// Avoid unnecessary copy which might return Object instead of subclass.
return subset
}
const trimmed: any = {}
seenReferences.set(object, trimmed)
for (const key of getObjectKeys(object)) {
if (hasPropertyInObject(subset, key)) {
trimmed[key] = seenReferences.has(object[key])
? seenReferences.get(object[key])
: getObjectSubsetWithContext(seenReferences)(
object[key],
subset[key],
)
}
else {
if (!seenReferences.has(object[key])) {
stripped += 1
if (isObject(object[key])) {
stripped += getObjectKeys(object[key]).length
}
getObjectSubsetWithContext(seenReferences)(
object[key],
subset[key],
)
}
}
}
if (getObjectKeys(trimmed).length > 0) {
return trimmed
}
}
return object
} }
if (getObjectKeys(trimmed).length > 0)
return trimmed
}
return object
}
return { subset: getObjectSubsetWithContext()(object, subset), stripped } return { subset: getObjectSubsetWithContext()(object, subset), stripped }
} }

View File

@ -1,5 +1,10 @@
import type { ExpectStatic, MatcherState, Tester } from './types' import type { ExpectStatic, MatcherState, Tester } from './types'
import { ASYMMETRIC_MATCHERS_OBJECT, GLOBAL_EXPECT, JEST_MATCHERS_OBJECT, MATCHERS_OBJECT } from './constants' import {
ASYMMETRIC_MATCHERS_OBJECT,
GLOBAL_EXPECT,
JEST_MATCHERS_OBJECT,
MATCHERS_OBJECT,
} from './constants'
if (!Object.prototype.hasOwnProperty.call(globalThis, MATCHERS_OBJECT)) { if (!Object.prototype.hasOwnProperty.call(globalThis, MATCHERS_OBJECT)) {
const globalState = new WeakMap<ExpectStatic, MatcherState>() const globalState = new WeakMap<ExpectStatic, MatcherState>()
@ -22,7 +27,9 @@ if (!Object.prototype.hasOwnProperty.call(globalThis, MATCHERS_OBJECT)) {
}) })
} }
export function getState<State extends MatcherState = MatcherState>(expect: ExpectStatic): State { export function getState<State extends MatcherState = MatcherState>(
expect: ExpectStatic,
): State {
return (globalThis as any)[MATCHERS_OBJECT].get(expect) return (globalThis as any)[MATCHERS_OBJECT].get(expect)
} }

View File

@ -16,7 +16,7 @@ export type Tester = (
this: TesterContext, this: TesterContext,
a: any, a: any,
b: any, b: any,
customTesters: Array<Tester>, customTesters: Array<Tester>
) => boolean | undefined ) => boolean | undefined
export interface TesterContext { export interface TesterContext {
@ -24,7 +24,7 @@ export interface TesterContext {
a: unknown, a: unknown,
b: unknown, b: unknown,
customTesters?: Array<Tester>, customTesters?: Array<Tester>,
strictCheck?: boolean, strictCheck?: boolean
) => boolean ) => boolean
} }
export type { DiffOptions } from '@vitest/utils/diff' export type { DiffOptions } from '@vitest/utils/diff'
@ -50,7 +50,7 @@ export interface MatcherState {
a: unknown, a: unknown,
b: unknown, b: unknown,
customTesters?: Array<Tester>, customTesters?: Array<Tester>,
strictCheck?: boolean, strictCheck?: boolean
) => boolean ) => boolean
expand?: boolean expand?: boolean
expectedAssertionsNumber?: number | null expectedAssertionsNumber?: number | null
@ -88,9 +88,14 @@ export interface RawMatcherFn<T extends MatcherState = MatcherState> {
(this: T, received: any, expected: any, options?: any): ExpectationResult (this: T, received: any, expected: any, options?: any): ExpectationResult
} }
export type MatchersObject<T extends MatcherState = MatcherState> = Record<string, RawMatcherFn<T>> export type MatchersObject<T extends MatcherState = MatcherState> = Record<
string,
RawMatcherFn<T>
>
export interface ExpectStatic extends Chai.ExpectStatic, AsymmetricMatchersContaining { export interface ExpectStatic
extends Chai.ExpectStatic,
AsymmetricMatchersContaining {
<T>(actual: T, message?: string): Assertion<T> <T>(actual: T, message?: string): Assertion<T>
extend: (expects: MatchersObject) => void extend: (expects: MatchersObject) => void
anything: () => any anything: () => any
@ -130,7 +135,10 @@ export interface JestAssertion<T = any> extends jest.Matchers<void, T> {
toBeInstanceOf: <E>(expected: E) => void toBeInstanceOf: <E>(expected: E) => void
toBeCalledTimes: (times: number) => void toBeCalledTimes: (times: number) => void
toHaveLength: (length: number) => void toHaveLength: (length: number) => void
toHaveProperty: <E>(property: string | (string | number)[], value?: E) => void toHaveProperty: <E>(
property: string | (string | number)[],
value?: E
) => void
toBeCloseTo: (number: number, numDigits?: number) => void toBeCloseTo: (number: number, numDigits?: number) => void
toHaveBeenCalledTimes: (times: number) => void toHaveBeenCalledTimes: (times: number) => void
toHaveBeenCalled: () => void toHaveBeenCalled: () => void
@ -160,7 +168,7 @@ type VitestAssertion<A, T> = {
? Assertion<T> ? Assertion<T>
: A[K] extends (...args: any[]) => any : A[K] extends (...args: any[]) => any
? A[K] // not converting function since they may contain overload ? A[K] // not converting function since they may contain overload
: VitestAssertion<A[K], T> : VitestAssertion<A[K], T>;
} & ((type: string, message?: string) => Assertion) } & ((type: string, message?: string) => Assertion)
type Promisify<O> = { type Promisify<O> = {
@ -168,13 +176,25 @@ type Promisify<O> = {
? O extends R ? O extends R
? Promisify<O[K]> ? Promisify<O[K]>
: (...args: A) => Promise<R> : (...args: A) => Promise<R>
: O[K] : O[K];
} }
export type PromisifyAssertion<T> = Promisify<Assertion<T>> export type PromisifyAssertion<T> = Promisify<Assertion<T>>
export interface Assertion<T = any> extends VitestAssertion<Chai.Assertion, T>, JestAssertion<T> { export interface Assertion<T = any>
toBeTypeOf: (expected: 'bigint' | 'boolean' | 'function' | 'number' | 'object' | 'string' | 'symbol' | 'undefined') => void extends VitestAssertion<Chai.Assertion, T>,
JestAssertion<T> {
toBeTypeOf: (
expected:
| 'bigint'
| 'boolean'
| 'function'
| 'number'
| 'object'
| 'string'
| 'symbol'
| 'undefined'
) => void
toHaveBeenCalledOnce: () => void toHaveBeenCalledOnce: () => void
toSatisfy: <E>(matcher: (value: E) => boolean, message?: string) => void toSatisfy: <E>(matcher: (value: E) => boolean, message?: string) => void
@ -192,7 +212,6 @@ declare global {
// support augmenting jest.Matchers by other libraries // support augmenting jest.Matchers by other libraries
// eslint-disable-next-line ts/no-namespace // eslint-disable-next-line ts/no-namespace
namespace jest { namespace jest {
// eslint-disable-next-line unused-imports/no-unused-vars // eslint-disable-next-line unused-imports/no-unused-vars
interface Matchers<R, T = {}> {} interface Matchers<R, T = {}> {}
} }

View File

@ -2,34 +2,44 @@ import { processError } from '@vitest/utils/error'
import type { Test } from '@vitest/runner/types' import type { Test } from '@vitest/runner/types'
import type { Assertion } from './types' import type { Assertion } from './types'
export function recordAsyncExpect(test: any, promise: Promise<any> | PromiseLike<any>) { export function recordAsyncExpect(
test: any,
promise: Promise<any> | PromiseLike<any>,
) {
// record promise for test, that resolves before test ends // record promise for test, that resolves before test ends
if (test && promise instanceof Promise) { if (test && promise instanceof Promise) {
// if promise is explicitly awaited, remove it from the list // if promise is explicitly awaited, remove it from the list
promise = promise.finally(() => { promise = promise.finally(() => {
const index = test.promises.indexOf(promise) const index = test.promises.indexOf(promise)
if (index !== -1) if (index !== -1) {
test.promises.splice(index, 1) test.promises.splice(index, 1)
}
}) })
// record promise // record promise
if (!test.promises) if (!test.promises) {
test.promises = [] test.promises = []
}
test.promises.push(promise) test.promises.push(promise)
} }
return promise return promise
} }
export function wrapSoft(utils: Chai.ChaiUtils, fn: (this: Chai.AssertionStatic & Assertion, ...args: any[]) => void) { export function wrapSoft(
utils: Chai.ChaiUtils,
fn: (this: Chai.AssertionStatic & Assertion, ...args: any[]) => void,
) {
return function (this: Chai.AssertionStatic & Assertion, ...args: any[]) { return function (this: Chai.AssertionStatic & Assertion, ...args: any[]) {
if (!utils.flag(this, 'soft')) if (!utils.flag(this, 'soft')) {
return fn.apply(this, args) return fn.apply(this, args)
}
const test: Test = utils.flag(this, 'vitest-test') const test: Test = utils.flag(this, 'vitest-test')
if (!test) if (!test) {
throw new Error('expect.soft() can only be used inside a test') throw new Error('expect.soft() can only be used inside a test')
}
try { try {
return fn.apply(this, args) return fn.apply(this, args)

View File

@ -48,15 +48,14 @@ export default defineConfig([
format: 'esm', format: 'esm',
}, },
external, external,
plugins: [ plugins: [dts({ respectExternal: true })],
dts({ respectExternal: true }),
],
onwarn, onwarn,
}, },
]) ])
function onwarn(message) { function onwarn(message) {
if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) {
return return
}
console.error(message) console.error(message)
} }

View File

@ -1,15 +1,27 @@
import { processError } from '@vitest/utils/error' import { processError } from '@vitest/utils/error'
import type { File, SuiteHooks } from './types' import type { File, SuiteHooks } from './types'
import type { VitestRunner } from './types/runner' import type { VitestRunner } from './types/runner'
import { calculateSuiteHash, createFileTask, interpretTaskModes, someTasksAreOnly } from './utils/collect' import {
import { clearCollectorContext, createSuiteHooks, getDefaultSuite } from './suite' calculateSuiteHash,
createFileTask,
interpretTaskModes,
someTasksAreOnly,
} from './utils/collect'
import {
clearCollectorContext,
createSuiteHooks,
getDefaultSuite,
} from './suite'
import { getHooks, setHooks } from './map' import { getHooks, setHooks } from './map'
import { collectorContext } from './context' import { collectorContext } from './context'
import { runSetupFiles } from './setup' import { runSetupFiles } from './setup'
const now = Date.now const now = Date.now
export async function collectTests(paths: string[], runner: VitestRunner): Promise<File[]> { export async function collectTests(
paths: string[],
runner: VitestRunner,
): Promise<File[]> {
const files: File[] = [] const files: File[] = []
const config = runner.config const config = runner.config
@ -66,13 +78,20 @@ export async function collectTests(paths: string[], runner: VitestRunner): Promi
calculateSuiteHash(file) calculateSuiteHash(file)
const hasOnlyTasks = someTasksAreOnly(file) const hasOnlyTasks = someTasksAreOnly(file)
interpretTaskModes(file, config.testNamePattern, hasOnlyTasks, false, config.allowOnly) interpretTaskModes(
file,
config.testNamePattern,
hasOnlyTasks,
false,
config.allowOnly,
)
file.tasks.forEach((task) => { file.tasks.forEach((task) => {
// task.suite refers to the internal default suite object // task.suite refers to the internal default suite object
// it should not be reported // it should not be reported
if (task.suite?.id === '') if (task.suite?.id === '') {
delete task.suite delete task.suite
}
}) })
files.push(file) files.push(file)
} }
@ -83,7 +102,7 @@ export async function collectTests(paths: string[], runner: VitestRunner): Promi
function mergeHooks(baseHooks: SuiteHooks, hooks: SuiteHooks): SuiteHooks { function mergeHooks(baseHooks: SuiteHooks, hooks: SuiteHooks): SuiteHooks {
for (const _key in hooks) { for (const _key in hooks) {
const key = _key as keyof SuiteHooks const key = _key as keyof SuiteHooks
baseHooks[key].push(...hooks[key] as any) baseHooks[key].push(...(hooks[key] as any))
} }
return baseHooks return baseHooks

View File

@ -1,6 +1,13 @@
import type { Awaitable } from '@vitest/utils' import type { Awaitable } from '@vitest/utils'
import { getSafeTimers } from '@vitest/utils' import { getSafeTimers } from '@vitest/utils'
import type { Custom, ExtendedContext, RuntimeContext, SuiteCollector, TaskContext, Test } from './types' import type {
Custom,
ExtendedContext,
RuntimeContext,
SuiteCollector,
TaskContext,
Test,
} from './types'
import type { VitestRunner } from './types/runner' import type { VitestRunner } from './types/runner'
import { PendingError } from './errors' import { PendingError } from './errors'
@ -13,36 +20,46 @@ export function collectTask(task: SuiteCollector) {
collectorContext.currentSuite?.tasks.push(task) collectorContext.currentSuite?.tasks.push(task)
} }
export async function runWithSuite(suite: SuiteCollector, fn: (() => Awaitable<void>)) { export async function runWithSuite(
suite: SuiteCollector,
fn: () => Awaitable<void>,
) {
const prev = collectorContext.currentSuite const prev = collectorContext.currentSuite
collectorContext.currentSuite = suite collectorContext.currentSuite = suite
await fn() await fn()
collectorContext.currentSuite = prev collectorContext.currentSuite = prev
} }
export function withTimeout<T extends((...args: any[]) => any)>( export function withTimeout<T extends (...args: any[]) => any>(
fn: T, fn: T,
timeout: number, timeout: number,
isHook = false, isHook = false,
): T { ): T {
if (timeout <= 0 || timeout === Number.POSITIVE_INFINITY) if (timeout <= 0 || timeout === Number.POSITIVE_INFINITY) {
return fn return fn
}
const { setTimeout, clearTimeout } = getSafeTimers() const { setTimeout, clearTimeout } = getSafeTimers()
return ((...args: (T extends ((...args: infer A) => any) ? A : never)) => { return ((...args: T extends (...args: infer A) => any ? A : never) => {
return Promise.race([fn(...args), new Promise((resolve, reject) => { return Promise.race([
const timer = setTimeout(() => { fn(...args),
clearTimeout(timer) new Promise((resolve, reject) => {
reject(new Error(makeTimeoutMsg(isHook, timeout))) const timer = setTimeout(() => {
}, timeout) clearTimeout(timer)
// `unref` might not exist in browser reject(new Error(makeTimeoutMsg(isHook, timeout)))
timer.unref?.() }, timeout)
})]) as Awaitable<void> // `unref` might not exist in browser
timer.unref?.()
}),
]) as Awaitable<void>
}) as T }) as T
} }
export function createTestContext<T extends Test | Custom>(test: T, runner: VitestRunner): ExtendedContext<T> { export function createTestContext<T extends Test | Custom>(
test: T,
runner: VitestRunner,
): ExtendedContext<T> {
const context = function () { const context = function () {
throw new Error('done() callback is deprecated, use promise instead') throw new Error('done() callback is deprecated, use promise instead')
} as unknown as TaskContext<T> } as unknown as TaskContext<T>
@ -64,9 +81,15 @@ export function createTestContext<T extends Test | Custom>(test: T, runner: Vite
test.onFinished.push(fn) test.onFinished.push(fn)
} }
return runner.extendTaskContext?.(context) as ExtendedContext<T> || context return (runner.extendTaskContext?.(context) as ExtendedContext<T>) || context
} }
function makeTimeoutMsg(isHook: boolean, timeout: number) { function makeTimeoutMsg(isHook: boolean, timeout: number) {
return `${isHook ? 'Hook' : 'Test'} timed out in ${timeout}ms.\nIf this is a long-running ${isHook ? 'hook' : 'test'}, pass a timeout value as the last argument or configure it globally with "${isHook ? 'hookTimeout' : 'testTimeout'}".` return `${
isHook ? 'Hook' : 'Test'
} timed out in ${timeout}ms.\nIf this is a long-running ${
isHook ? 'hook' : 'test'
}, pass a timeout value as the last argument or configure it globally with "${
isHook ? 'hookTimeout' : 'testTimeout'
}".`
} }

View File

@ -15,14 +15,18 @@ export interface FixtureItem extends FixtureOptions {
deps?: FixtureItem[] deps?: FixtureItem[]
} }
export function mergeContextFixtures(fixtures: Record<string, any>, context: { fixtures?: FixtureItem[] } = {}) { export function mergeContextFixtures(
fixtures: Record<string, any>,
context: { fixtures?: FixtureItem[] } = {},
) {
const fixtureOptionKeys = ['auto'] const fixtureOptionKeys = ['auto']
const fixtureArray: FixtureItem[] = Object.entries(fixtures) const fixtureArray: FixtureItem[] = Object.entries(fixtures).map(
.map(([prop, value]) => { ([prop, value]) => {
const fixtureItem = { value } as FixtureItem const fixtureItem = { value } as FixtureItem
if ( if (
Array.isArray(value) && value.length >= 2 Array.isArray(value)
&& value.length >= 2
&& isObject(value[1]) && isObject(value[1])
&& Object.keys(value[1]).some(key => fixtureOptionKeys.includes(key)) && Object.keys(value[1]).some(key => fixtureOptionKeys.includes(key))
) { ) {
@ -34,19 +38,25 @@ export function mergeContextFixtures(fixtures: Record<string, any>, context: { f
fixtureItem.prop = prop fixtureItem.prop = prop
fixtureItem.isFn = typeof fixtureItem.value === 'function' fixtureItem.isFn = typeof fixtureItem.value === 'function'
return fixtureItem return fixtureItem
}) },
)
if (Array.isArray(context.fixtures)) if (Array.isArray(context.fixtures)) {
context.fixtures = context.fixtures.concat(fixtureArray) context.fixtures = context.fixtures.concat(fixtureArray)
else }
else {
context.fixtures = fixtureArray context.fixtures = fixtureArray
}
// Update dependencies of fixture functions // Update dependencies of fixture functions
fixtureArray.forEach((fixture) => { fixtureArray.forEach((fixture) => {
if (fixture.isFn) { if (fixture.isFn) {
const usedProps = getUsedProps(fixture.value) const usedProps = getUsedProps(fixture.value)
if (usedProps.length) if (usedProps.length) {
fixture.deps = context.fixtures!.filter(({ prop }) => prop !== fixture.prop && usedProps.includes(prop)) fixture.deps = context.fixtures!.filter(
({ prop }) => prop !== fixture.prop && usedProps.includes(prop),
)
}
} }
}) })
@ -54,52 +64,69 @@ export function mergeContextFixtures(fixtures: Record<string, any>, context: { f
} }
const fixtureValueMaps = new Map<TestContext, Map<FixtureItem, any>>() const fixtureValueMaps = new Map<TestContext, Map<FixtureItem, any>>()
const cleanupFnArrayMap = new Map<TestContext, Array<() => void | Promise<void>>>() const cleanupFnArrayMap = new Map<
TestContext,
Array<() => void | Promise<void>>
>()
export async function callFixtureCleanup(context: TestContext) { export async function callFixtureCleanup(context: TestContext) {
const cleanupFnArray = cleanupFnArrayMap.get(context) ?? [] const cleanupFnArray = cleanupFnArrayMap.get(context) ?? []
for (const cleanup of cleanupFnArray.reverse()) for (const cleanup of cleanupFnArray.reverse()) {
await cleanup() await cleanup()
}
cleanupFnArrayMap.delete(context) cleanupFnArrayMap.delete(context)
} }
export function withFixtures(fn: Function, testContext?: TestContext) { export function withFixtures(fn: Function, testContext?: TestContext) {
return (hookContext?: TestContext) => { return (hookContext?: TestContext) => {
const context: TestContext & { [key: string]: any } | undefined = hookContext || testContext const context: (TestContext & { [key: string]: any }) | undefined
= hookContext || testContext
if (!context) if (!context) {
return fn({}) return fn({})
}
const fixtures = getFixture(context) const fixtures = getFixture(context)
if (!fixtures?.length) if (!fixtures?.length) {
return fn(context) return fn(context)
}
const usedProps = getUsedProps(fn) const usedProps = getUsedProps(fn)
const hasAutoFixture = fixtures.some(({ auto }) => auto) const hasAutoFixture = fixtures.some(({ auto }) => auto)
if (!usedProps.length && !hasAutoFixture) if (!usedProps.length && !hasAutoFixture) {
return fn(context) return fn(context)
}
if (!fixtureValueMaps.get(context)) if (!fixtureValueMaps.get(context)) {
fixtureValueMaps.set(context, new Map<FixtureItem, any>()) fixtureValueMaps.set(context, new Map<FixtureItem, any>())
const fixtureValueMap: Map<FixtureItem, any> = fixtureValueMaps.get(context)! }
const fixtureValueMap: Map<FixtureItem, any>
= fixtureValueMaps.get(context)!
if (!cleanupFnArrayMap.has(context)) if (!cleanupFnArrayMap.has(context)) {
cleanupFnArrayMap.set(context, []) cleanupFnArrayMap.set(context, [])
}
const cleanupFnArray = cleanupFnArrayMap.get(context)! const cleanupFnArray = cleanupFnArrayMap.get(context)!
const usedFixtures = fixtures.filter(({ prop, auto }) => auto || usedProps.includes(prop)) const usedFixtures = fixtures.filter(
({ prop, auto }) => auto || usedProps.includes(prop),
)
const pendingFixtures = resolveDeps(usedFixtures) const pendingFixtures = resolveDeps(usedFixtures)
if (!pendingFixtures.length) if (!pendingFixtures.length) {
return fn(context) return fn(context)
}
async function resolveFixtures() { async function resolveFixtures() {
for (const fixture of pendingFixtures) { for (const fixture of pendingFixtures) {
// fixture could be already initialized during "before" hook // fixture could be already initialized during "before" hook
if (fixtureValueMap.has(fixture)) if (fixtureValueMap.has(fixture)) {
continue continue
}
const resolvedValue = fixture.isFn ? await resolveFixtureFunction(fixture.value, context, cleanupFnArray) : fixture.value const resolvedValue = fixture.isFn
? await resolveFixtureFunction(fixture.value, context, cleanupFnArray)
: fixture.value
context![fixture.prop] = resolvedValue context![fixture.prop] = resolvedValue
fixtureValueMap.set(fixture, resolvedValue) fixtureValueMap.set(fixture, resolvedValue)
cleanupFnArray.unshift(() => { cleanupFnArray.unshift(() => {
@ -113,9 +140,12 @@ export function withFixtures(fn: Function, testContext?: TestContext) {
} }
async function resolveFixtureFunction( async function resolveFixtureFunction(
fixtureFn: (context: unknown, useFn: (arg: unknown) => Promise<void>) => Promise<void>, fixtureFn: (
context: unknown,
useFn: (arg: unknown) => Promise<void>
) => Promise<void>,
context: unknown, context: unknown,
cleanupFnArray: (() => (void | Promise<void>))[], cleanupFnArray: (() => void | Promise<void>)[],
): Promise<unknown> { ): Promise<unknown> {
// wait for `use` call to extract fixture value // wait for `use` call to extract fixture value
const useFnArgPromise = createDefer() const useFnArgPromise = createDefer()
@ -148,16 +178,27 @@ async function resolveFixtureFunction(
return useFnArgPromise return useFnArgPromise
} }
function resolveDeps(fixtures: FixtureItem[], depSet = new Set<FixtureItem>(), pendingFixtures: FixtureItem[] = []) { function resolveDeps(
fixtures: FixtureItem[],
depSet = new Set<FixtureItem>(),
pendingFixtures: FixtureItem[] = [],
) {
fixtures.forEach((fixture) => { fixtures.forEach((fixture) => {
if (pendingFixtures.includes(fixture)) if (pendingFixtures.includes(fixture)) {
return return
}
if (!fixture.isFn || !fixture.deps) { if (!fixture.isFn || !fixture.deps) {
pendingFixtures.push(fixture) pendingFixtures.push(fixture)
return return
} }
if (depSet.has(fixture)) if (depSet.has(fixture)) {
throw new Error(`Circular fixture dependency detected: ${fixture.prop} <- ${[...depSet].reverse().map(d => d.prop).join(' <- ')}`) throw new Error(
`Circular fixture dependency detected: ${fixture.prop} <- ${[...depSet]
.reverse()
.map(d => d.prop)
.join(' <- ')}`,
)
}
depSet.add(fixture) depSet.add(fixture)
resolveDeps(fixture.deps, depSet, pendingFixtures) resolveDeps(fixture.deps, depSet, pendingFixtures)
@ -170,22 +211,28 @@ function resolveDeps(fixtures: FixtureItem[], depSet = new Set<FixtureItem>(), p
function getUsedProps(fn: Function) { function getUsedProps(fn: Function) {
const match = fn.toString().match(/[^(]*\(([^)]*)/) const match = fn.toString().match(/[^(]*\(([^)]*)/)
if (!match) if (!match) {
return [] return []
}
const args = splitByComma(match[1]) const args = splitByComma(match[1])
if (!args.length) if (!args.length) {
return [] return []
}
let first = args[0] let first = args[0]
if ('__VITEST_FIXTURE_INDEX__' in fn) { if ('__VITEST_FIXTURE_INDEX__' in fn) {
first = args[(fn as any).__VITEST_FIXTURE_INDEX__] first = args[(fn as any).__VITEST_FIXTURE_INDEX__]
if (!first) if (!first) {
return [] return []
}
} }
if (!(first.startsWith('{') && first.endsWith('}'))) if (!(first.startsWith('{') && first.endsWith('}'))) {
throw new Error(`The first argument inside a fixture must use object destructuring pattern, e.g. ({ test } => {}). Instead, received "${first}".`) throw new Error(
`The first argument inside a fixture must use object destructuring pattern, e.g. ({ test } => {}). Instead, received "${first}".`,
)
}
const _first = first.slice(1, -1).replace(/\s/g, '') const _first = first.slice(1, -1).replace(/\s/g, '')
const props = splitByComma(_first).map((prop) => { const props = splitByComma(_first).map((prop) => {
@ -193,8 +240,11 @@ function getUsedProps(fn: Function) {
}) })
const last = props.at(-1) const last = props.at(-1)
if (last && last.startsWith('...')) if (last && last.startsWith('...')) {
throw new Error(`Rest parameters are not supported in fixtures, received "${last}".`) throw new Error(
`Rest parameters are not supported in fixtures, received "${last}".`,
)
}
return props return props
} }
@ -212,13 +262,15 @@ function splitByComma(s: string) {
} }
else if (!stack.length && s[i] === ',') { else if (!stack.length && s[i] === ',') {
const token = s.substring(start, i).trim() const token = s.substring(start, i).trim()
if (token) if (token) {
result.push(token) result.push(token)
}
start = i + 1 start = i + 1
} }
} }
const lastToken = s.substring(start).trim() const lastToken = s.substring(start).trim()
if (lastToken) if (lastToken) {
result.push(lastToken) result.push(lastToken)
}
return result return result
} }

View File

@ -1,4 +1,9 @@
import type { OnTestFailedHandler, OnTestFinishedHandler, SuiteHooks, TaskPopulated } from './types' import type {
OnTestFailedHandler,
OnTestFinishedHandler,
SuiteHooks,
TaskPopulated,
} from './types'
import { getCurrentSuite, getRunner } from './suite' import { getCurrentSuite, getRunner } from './suite'
import { getCurrentTest } from './test-state' import { getCurrentTest } from './test-state'
import { withTimeout } from './context' import { withTimeout } from './context'
@ -10,34 +15,62 @@ function getDefaultHookTimeout() {
// suite hooks // suite hooks
export function beforeAll(fn: SuiteHooks['beforeAll'][0], timeout?: number) { export function beforeAll(fn: SuiteHooks['beforeAll'][0], timeout?: number) {
return getCurrentSuite().on('beforeAll', withTimeout(fn, timeout ?? getDefaultHookTimeout(), true)) return getCurrentSuite().on(
'beforeAll',
withTimeout(fn, timeout ?? getDefaultHookTimeout(), true),
)
} }
export function afterAll(fn: SuiteHooks['afterAll'][0], timeout?: number) { export function afterAll(fn: SuiteHooks['afterAll'][0], timeout?: number) {
return getCurrentSuite().on('afterAll', withTimeout(fn, timeout ?? getDefaultHookTimeout(), true)) return getCurrentSuite().on(
'afterAll',
withTimeout(fn, timeout ?? getDefaultHookTimeout(), true),
)
} }
export function beforeEach<ExtraContext = {}>(fn: SuiteHooks<ExtraContext>['beforeEach'][0], timeout?: number) { export function beforeEach<ExtraContext = {}>(
return getCurrentSuite<ExtraContext>().on('beforeEach', withTimeout(withFixtures(fn), timeout ?? getDefaultHookTimeout(), true)) fn: SuiteHooks<ExtraContext>['beforeEach'][0],
timeout?: number,
) {
return getCurrentSuite<ExtraContext>().on(
'beforeEach',
withTimeout(withFixtures(fn), timeout ?? getDefaultHookTimeout(), true),
)
} }
export function afterEach<ExtraContext = {}>(fn: SuiteHooks<ExtraContext>['afterEach'][0], timeout?: number) { export function afterEach<ExtraContext = {}>(
return getCurrentSuite<ExtraContext>().on('afterEach', withTimeout(withFixtures(fn), timeout ?? getDefaultHookTimeout(), true)) fn: SuiteHooks<ExtraContext>['afterEach'][0],
timeout?: number,
) {
return getCurrentSuite<ExtraContext>().on(
'afterEach',
withTimeout(withFixtures(fn), timeout ?? getDefaultHookTimeout(), true),
)
} }
export const onTestFailed = createTestHook<OnTestFailedHandler>('onTestFailed', (test, handler) => { export const onTestFailed = createTestHook<OnTestFailedHandler>(
test.onFailed ||= [] 'onTestFailed',
test.onFailed.push(handler) (test, handler) => {
}) test.onFailed ||= []
test.onFailed.push(handler)
},
)
export const onTestFinished = createTestHook<OnTestFinishedHandler>('onTestFinished', (test, handler) => { export const onTestFinished = createTestHook<OnTestFinishedHandler>(
test.onFinished ||= [] 'onTestFinished',
test.onFinished.push(handler) (test, handler) => {
}) test.onFinished ||= []
test.onFinished.push(handler)
},
)
function createTestHook<T>(name: string, handler: (test: TaskPopulated, handler: T) => void) { function createTestHook<T>(
name: string,
handler: (test: TaskPopulated, handler: T) => void,
) {
return (fn: T) => { return (fn: T) => {
const current = getCurrentTest() const current = getCurrentTest()
if (!current) if (!current) {
throw new Error(`Hook ${name}() can only be called inside a test`) throw new Error(`Hook ${name}() can only be called inside a test`)
}
return handler(current, fn) return handler(current, fn)
} }

View File

@ -1,6 +1,20 @@
export { startTests, updateTask } from './run' export { startTests, updateTask } from './run'
export { test, it, describe, suite, getCurrentSuite, createTaskCollector } from './suite' export {
export { beforeAll, beforeEach, afterAll, afterEach, onTestFailed, onTestFinished } from './hooks' test,
it,
describe,
suite,
getCurrentSuite,
createTaskCollector,
} from './suite'
export {
beforeAll,
beforeEach,
afterAll,
afterEach,
onTestFailed,
onTestFinished,
} from './hooks'
export { setFn, getFn, getHooks, setHooks } from './map' export { setFn, getFn, getHooks, setHooks } from './map'
export { getCurrentTest } from './test-state' export { getCurrentTest } from './test-state'
export { processError } from '@vitest/utils/error' export { processError } from '@vitest/utils/error'

View File

@ -7,15 +7,18 @@ const fnMap = new WeakMap()
const fixtureMap = new WeakMap() const fixtureMap = new WeakMap()
const hooksMap = new WeakMap() const hooksMap = new WeakMap()
export function setFn(key: Test | Custom, fn: (() => Awaitable<void>)) { export function setFn(key: Test | Custom, fn: () => Awaitable<void>) {
fnMap.set(key, fn) fnMap.set(key, fn)
} }
export function getFn<Task = Test | Custom>(key: Task): (() => Awaitable<void>) { export function getFn<Task = Test | Custom>(key: Task): () => Awaitable<void> {
return fnMap.get(key as any) return fnMap.get(key as any)
} }
export function setFixture(key: TestContext, fixture: FixtureItem[] | undefined) { export function setFixture(
key: TestContext,
fixture: FixtureItem[] | undefined,
) {
fixtureMap.set(key, fixture) fixtureMap.set(key, fixture)
} }

View File

@ -4,7 +4,21 @@ import { getSafeTimers, shuffle } from '@vitest/utils'
import { processError } from '@vitest/utils/error' import { processError } from '@vitest/utils/error'
import type { DiffOptions } from '@vitest/utils/diff' import type { DiffOptions } from '@vitest/utils/diff'
import type { VitestRunner } from './types/runner' import type { VitestRunner } from './types/runner'
import type { Custom, File, HookCleanupCallback, HookListener, SequenceHooks, Suite, SuiteHooks, Task, TaskMeta, TaskResult, TaskResultPack, TaskState, Test } from './types' import type {
Custom,
File,
HookCleanupCallback,
HookListener,
SequenceHooks,
Suite,
SuiteHooks,
Task,
TaskMeta,
TaskResult,
TaskResultPack,
TaskState,
Test,
} from './types'
import { partitionSuiteChildren } from './utils/suite' import { partitionSuiteChildren } from './utils/suite'
import { getFn, getHooks } from './map' import { getFn, getHooks } from './map'
import { collectTests } from './collect' import { collectTests } from './collect'
@ -15,11 +29,18 @@ import { callFixtureCleanup } from './fixture'
const now = Date.now const now = Date.now
function updateSuiteHookState(suite: Task, name: keyof SuiteHooks, state: TaskState, runner: VitestRunner) { function updateSuiteHookState(
if (!suite.result) suite: Task,
name: keyof SuiteHooks,
state: TaskState,
runner: VitestRunner,
) {
if (!suite.result) {
suite.result = { state: 'run' } suite.result = { state: 'run' }
if (!suite.result?.hooks) }
if (!suite.result?.hooks) {
suite.result.hooks = {} suite.result.hooks = {}
}
const suiteHooks = suite.result.hooks const suiteHooks = suite.result.hooks
if (suiteHooks) { if (suiteHooks) {
suiteHooks[name] = state suiteHooks[name] = state
@ -27,23 +48,34 @@ function updateSuiteHookState(suite: Task, name: keyof SuiteHooks, state: TaskSt
} }
} }
function getSuiteHooks(suite: Suite, name: keyof SuiteHooks, sequence: SequenceHooks) { function getSuiteHooks(
suite: Suite,
name: keyof SuiteHooks,
sequence: SequenceHooks,
) {
const hooks = getHooks(suite)[name] const hooks = getHooks(suite)[name]
if (sequence === 'stack' && (name === 'afterAll' || name === 'afterEach')) if (sequence === 'stack' && (name === 'afterAll' || name === 'afterEach')) {
return hooks.slice().reverse() return hooks.slice().reverse()
}
return hooks return hooks
} }
async function callTaskHooks(task: Task, hooks: ((result: TaskResult) => Awaitable<void>)[], sequence: SequenceHooks) { async function callTaskHooks(
if (sequence === 'stack') task: Task,
hooks: ((result: TaskResult) => Awaitable<void>)[],
sequence: SequenceHooks,
) {
if (sequence === 'stack') {
hooks = hooks.slice().reverse() hooks = hooks.slice().reverse()
}
if (sequence === 'parallel') { if (sequence === 'parallel') {
await Promise.all(hooks.map(fn => fn(task.result!))) await Promise.all(hooks.map(fn => fn(task.result!)))
} }
else { else {
for (const fn of hooks) for (const fn of hooks) {
await fn(task.result!) await fn(task.result!)
}
} }
} }
@ -58,13 +90,12 @@ export async function callSuiteHook<T extends keyof SuiteHooks>(
const callbacks: HookCleanupCallback[] = [] const callbacks: HookCleanupCallback[] = []
// stop at file level // stop at file level
const parentSuite: Suite | null = 'filepath' in suite const parentSuite: Suite | null
? null = 'filepath' in suite ? null : suite.suite || suite.file
: (suite.suite || suite.file)
if (name === 'beforeEach' && parentSuite) { if (name === 'beforeEach' && parentSuite) {
callbacks.push( callbacks.push(
...await callSuiteHook(parentSuite, currentTask, name, runner, args), ...(await callSuiteHook(parentSuite, currentTask, name, runner, args)),
) )
} }
@ -73,18 +104,21 @@ export async function callSuiteHook<T extends keyof SuiteHooks>(
const hooks = getSuiteHooks(suite, name, sequence) const hooks = getSuiteHooks(suite, name, sequence)
if (sequence === 'parallel') { if (sequence === 'parallel') {
callbacks.push(...await Promise.all(hooks.map(fn => fn(...args as any)))) callbacks.push(
...(await Promise.all(hooks.map(fn => fn(...(args as any))))),
)
} }
else { else {
for (const hook of hooks) for (const hook of hooks) {
callbacks.push(await hook(...args as any)) callbacks.push(await hook(...(args as any)))
}
} }
updateSuiteHookState(currentTask, name, 'pass', runner) updateSuiteHookState(currentTask, name, 'pass', runner)
if (name === 'afterEach' && parentSuite) { if (name === 'afterEach' && parentSuite) {
callbacks.push( callbacks.push(
...await callSuiteHook(parentSuite, currentTask, name, runner, args), ...(await callSuiteHook(parentSuite, currentTask, name, runner, args)),
) )
} }
@ -113,11 +147,7 @@ async function sendTasksUpdate(runner: VitestRunner) {
if (packs.size) { if (packs.size) {
const taskPacks = Array.from(packs).map<TaskResultPack>(([id, task]) => { const taskPacks = Array.from(packs).map<TaskResultPack>(([id, task]) => {
return [ return [id, task[0], task[1]]
id,
task[0],
task[1],
]
}) })
const p = runner.onTaskUpdate?.(taskPacks) const p = runner.onTaskUpdate?.(taskPacks)
packs.clear() packs.clear()
@ -126,18 +156,22 @@ async function sendTasksUpdate(runner: VitestRunner) {
} }
async function callCleanupHooks(cleanups: HookCleanupCallback[]) { async function callCleanupHooks(cleanups: HookCleanupCallback[]) {
await Promise.all(cleanups.map(async (fn) => { await Promise.all(
if (typeof fn !== 'function') cleanups.map(async (fn) => {
return if (typeof fn !== 'function') {
await fn() return
})) }
await fn()
}),
)
} }
export async function runTest(test: Test | Custom, runner: VitestRunner) { export async function runTest(test: Test | Custom, runner: VitestRunner) {
await runner.onBeforeRunTask?.(test) await runner.onBeforeRunTask?.(test)
if (test.mode !== 'run') if (test.mode !== 'run') {
return return
}
if (test.result?.state === 'fail') { if (test.result?.state === 'fail') {
updateTask(test, runner) updateTask(test, runner)
@ -163,36 +197,56 @@ export async function runTest(test: Test | Custom, runner: VitestRunner) {
for (let retryCount = 0; retryCount <= retry; retryCount++) { for (let retryCount = 0; retryCount <= retry; retryCount++) {
let beforeEachCleanups: HookCleanupCallback[] = [] let beforeEachCleanups: HookCleanupCallback[] = []
try { try {
await runner.onBeforeTryTask?.(test, { retry: retryCount, repeats: repeatCount }) await runner.onBeforeTryTask?.(test, {
retry: retryCount,
repeats: repeatCount,
})
test.result.repeatCount = repeatCount test.result.repeatCount = repeatCount
beforeEachCleanups = await callSuiteHook(suite, test, 'beforeEach', runner, [test.context, suite]) beforeEachCleanups = await callSuiteHook(
suite,
test,
'beforeEach',
runner,
[test.context, suite],
)
if (runner.runTask) { if (runner.runTask) {
await runner.runTask(test) await runner.runTask(test)
} }
else { else {
const fn = getFn(test) const fn = getFn(test)
if (!fn) if (!fn) {
throw new Error('Test function is not found. Did you add it using `setFn`?') throw new Error(
'Test function is not found. Did you add it using `setFn`?',
)
}
await fn() await fn()
} }
// some async expect will be added to this array, in case user forget to await theme // some async expect will be added to this array, in case user forget to await theme
if (test.promises) { if (test.promises) {
const result = await Promise.allSettled(test.promises) const result = await Promise.allSettled(test.promises)
const errors = result.map(r => r.status === 'rejected' ? r.reason : undefined).filter(Boolean) const errors = result
if (errors.length) .map(r => (r.status === 'rejected' ? r.reason : undefined))
.filter(Boolean)
if (errors.length) {
throw errors throw errors
}
} }
await runner.onAfterTryTask?.(test, { retry: retryCount, repeats: repeatCount }) await runner.onAfterTryTask?.(test, {
retry: retryCount,
repeats: repeatCount,
})
if (test.result.state !== 'fail') { if (test.result.state !== 'fail') {
if (!test.repeats) if (!test.repeats) {
test.result.state = 'pass' test.result.state = 'pass'
else if (test.repeats && retry === retryCount) }
else if (test.repeats && retry === retryCount) {
test.result.state = 'pass' test.result.state = 'pass'
}
} }
} }
catch (e) { catch (e) {
@ -209,7 +263,10 @@ export async function runTest(test: Test | Custom, runner: VitestRunner) {
} }
try { try {
await callSuiteHook(suite, test, 'afterEach', runner, [test.context, suite]) await callSuiteHook(suite, test, 'afterEach', runner, [
test.context,
suite,
])
await callCleanupHooks(beforeEachCleanups) await callCleanupHooks(beforeEachCleanups)
await callFixtureCleanup(test.context) await callFixtureCleanup(test.context)
} }
@ -217,8 +274,9 @@ export async function runTest(test: Test | Custom, runner: VitestRunner) {
failTask(test.result, e, runner.config.diffOptions) failTask(test.result, e, runner.config.diffOptions)
} }
if (test.result.state === 'pass') if (test.result.state === 'pass') {
break break
}
if (retryCount < retry) { if (retryCount < retry) {
// reset state when retry test // reset state when retry test
@ -240,7 +298,11 @@ export async function runTest(test: Test | Custom, runner: VitestRunner) {
if (test.result.state === 'fail') { if (test.result.state === 'fail') {
try { try {
await callTaskHooks(test, test.onFailed || [], runner.config.sequence.hooks) await callTaskHooks(
test,
test.onFailed || [],
runner.config.sequence.hooks,
)
} }
catch (e) { catch (e) {
failTask(test.result, e, runner.config.diffOptions) failTask(test.result, e, runner.config.diffOptions)
@ -276,9 +338,7 @@ function failTask(result: TaskResult, err: unknown, diffOptions?: DiffOptions) {
} }
result.state = 'fail' result.state = 'fail'
const errors = Array.isArray(err) const errors = Array.isArray(err) ? err : [err]
? err
: [err]
for (const e of errors) { for (const e of errors) {
const error = processError(e, diffOptions) const error = processError(e, diffOptions)
result.errors ??= [] result.errors ??= []
@ -291,8 +351,9 @@ function markTasksAsSkipped(suite: Suite, runner: VitestRunner) {
t.mode = 'skip' t.mode = 'skip'
t.result = { ...t.result, state: 'skip' } t.result = { ...t.result, state: 'skip' }
updateTask(t, runner) updateTask(t, runner)
if (t.type === 'suite') if (t.type === 'suite') {
markTasksAsSkipped(t, runner) markTasksAsSkipped(t, runner)
}
}) })
} }
@ -324,7 +385,13 @@ export async function runSuite(suite: Suite, runner: VitestRunner) {
} }
else { else {
try { try {
beforeAllCleanups = await callSuiteHook(suite, suite, 'beforeAll', runner, [suite]) beforeAllCleanups = await callSuiteHook(
suite,
suite,
'beforeAll',
runner,
[suite],
)
if (runner.runSuite) { if (runner.runSuite) {
await runner.runSuite(suite) await runner.runSuite(suite)
@ -338,13 +405,18 @@ export async function runSuite(suite: Suite, runner: VitestRunner) {
const { sequence } = runner.config const { sequence } = runner.config
if (sequence.shuffle || suite.shuffle) { if (sequence.shuffle || suite.shuffle) {
// run describe block independently from tests // run describe block independently from tests
const suites = tasksGroup.filter(group => group.type === 'suite') const suites = tasksGroup.filter(
group => group.type === 'suite',
)
const tests = tasksGroup.filter(group => group.type === 'test') const tests = tasksGroup.filter(group => group.type === 'test')
const groups = shuffle<Task[]>([suites, tests], sequence.seed) const groups = shuffle<Task[]>([suites, tests], sequence.seed)
tasksGroup = groups.flatMap(group => shuffle(group, sequence.seed)) tasksGroup = groups.flatMap(group =>
shuffle(group, sequence.seed),
)
} }
for (const c of tasksGroup) for (const c of tasksGroup) {
await runSuiteChild(c, runner) await runSuiteChild(c, runner)
}
} }
} }
} }
@ -365,7 +437,9 @@ export async function runSuite(suite: Suite, runner: VitestRunner) {
if (!runner.config.passWithNoTests && !hasTests(suite)) { if (!runner.config.passWithNoTests && !hasTests(suite)) {
suite.result.state = 'fail' suite.result.state = 'fail'
if (!suite.result.errors?.length) { if (!suite.result.errors?.length) {
const error = processError(new Error(`No test found in suite ${suite.name}`)) const error = processError(
new Error(`No test found in suite ${suite.name}`),
)
suite.result.errors = [error] suite.result.errors = [error]
} }
} }
@ -388,11 +462,12 @@ export async function runSuite(suite: Suite, runner: VitestRunner) {
let limitMaxConcurrency: ReturnType<typeof limit> let limitMaxConcurrency: ReturnType<typeof limit>
async function runSuiteChild(c: Task, runner: VitestRunner) { async function runSuiteChild(c: Task, runner: VitestRunner) {
if (c.type === 'test' || c.type === 'custom') if (c.type === 'test' || c.type === 'custom') {
return limitMaxConcurrency(() => runTest(c, runner)) return limitMaxConcurrency(() => runTest(c, runner))
}
else if (c.type === 'suite') else if (c.type === 'suite') {
return runSuite(c, runner) return runSuite(c, runner)
}
} }
export async function runFiles(files: File[], runner: VitestRunner) { export async function runFiles(files: File[], runner: VitestRunner) {
@ -401,7 +476,9 @@ export async function runFiles(files: File[], runner: VitestRunner) {
for (const file of files) { for (const file of files) {
if (!file.tasks.length && !runner.config.passWithNoTests) { if (!file.tasks.length && !runner.config.passWithNoTests) {
if (!file.result?.errors?.length) { if (!file.result?.errors?.length) {
const error = processError(new Error(`No test suite found in file ${file.filepath}`)) const error = processError(
new Error(`No test suite found in file ${file.filepath}`),
)
file.result = { file.result = {
state: 'fail', state: 'fail',
errors: [error], errors: [error],

View File

@ -1,7 +1,10 @@
import { toArray } from '@vitest/utils' import { toArray } from '@vitest/utils'
import type { VitestRunner, VitestRunnerConfig } from './types' import type { VitestRunner, VitestRunnerConfig } from './types'
export async function runSetupFiles(config: VitestRunnerConfig, runner: VitestRunner) { export async function runSetupFiles(
config: VitestRunnerConfig,
runner: VitestRunner,
) {
const files = toArray(config.setupFiles) const files = toArray(config.setupFiles)
if (config.sequence.setupFiles === 'parallel') { if (config.sequence.setupFiles === 'parallel') {
await Promise.all( await Promise.all(
@ -11,7 +14,8 @@ export async function runSetupFiles(config: VitestRunnerConfig, runner: VitestRu
) )
} }
else { else {
for (const fsPath of files) for (const fsPath of files) {
await runner.importFile(fsPath, 'setup') await runner.importFile(fsPath, 'setup')
}
} }
} }

View File

@ -1,9 +1,39 @@
import { format, isNegativeNaN, isObject, objDisplay, objectAttr, toArray } from '@vitest/utils' import {
format,
isNegativeNaN,
isObject,
objDisplay,
objectAttr,
toArray,
} from '@vitest/utils'
import { parseSingleStack } from '@vitest/utils/source-map' import { parseSingleStack } from '@vitest/utils/source-map'
import type { Custom, CustomAPI, File, Fixtures, RunMode, Suite, SuiteAPI, SuiteCollector, SuiteFactory, SuiteHooks, Task, TaskCustomOptions, Test, TestAPI, TestFunction, TestOptions } from './types' import type {
Custom,
CustomAPI,
File,
Fixtures,
RunMode,
Suite,
SuiteAPI,
SuiteCollector,
SuiteFactory,
SuiteHooks,
Task,
TaskCustomOptions,
Test,
TestAPI,
TestFunction,
TestOptions,
} from './types'
import type { VitestRunner } from './types/runner' import type { VitestRunner } from './types/runner'
import { createChainable } from './utils/chain' import { createChainable } from './utils/chain'
import { collectTask, collectorContext, createTestContext, runWithSuite, withTimeout } from './context' import {
collectTask,
collectorContext,
createTestContext,
runWithSuite,
withTimeout,
} from './context'
import { getHooks, setFixture, setFn, setHooks } from './map' import { getHooks, setFixture, setFn, setHooks } from './map'
import type { FixtureItem } from './fixture' import type { FixtureItem } from './fixture'
import { mergeContextFixtures, withFixtures } from './fixture' import { mergeContextFixtures, withFixtures } from './fixture'
@ -11,14 +41,24 @@ import { getCurrentTest } from './test-state'
// apis // apis
export const suite = createSuite() export const suite = createSuite()
export const test = createTest( export const test = createTest(function (
function (name: string | Function, optionsOrFn?: TestOptions | TestFunction, optionsOrTest?: number | TestOptions | TestFunction) { name: string | Function,
if (getCurrentTest()) optionsOrFn?: TestOptions | TestFunction,
throw new Error('Calling the test function inside another test function is not allowed. Please put it inside "describe" or "suite" so it can be properly collected.') optionsOrTest?: number | TestOptions | TestFunction,
) {
if (getCurrentTest()) {
throw new Error(
'Calling the test function inside another test function is not allowed. Please put it inside "describe" or "suite" so it can be properly collected.',
)
}
getCurrentSuite().test.fn.call(this, formatName(name), optionsOrFn as TestOptions, optionsOrTest as TestFunction) getCurrentSuite().test.fn.call(
}, this,
) formatName(name),
optionsOrFn as TestOptions,
optionsOrTest as TestFunction,
)
})
// alias // alias
export const describe = suite export const describe = suite
@ -46,9 +86,13 @@ function createDefaultSuite(runner: VitestRunner) {
return api('', { concurrent: config.concurrent }, () => {}) return api('', { concurrent: config.concurrent }, () => {})
} }
export function clearCollectorContext(filepath: string, currentRunner: VitestRunner) { export function clearCollectorContext(
if (!defaultSuite) filepath: string,
currentRunner: VitestRunner,
) {
if (!defaultSuite) {
defaultSuite = createDefaultSuite(currentRunner) defaultSuite = createDefaultSuite(currentRunner)
}
runner = currentRunner runner = currentRunner
currentTestFilepath = filepath currentTestFilepath = filepath
collectorContext.tasks.length = 0 collectorContext.tasks.length = 0
@ -57,7 +101,8 @@ export function clearCollectorContext(filepath: string, currentRunner: VitestRun
} }
export function getCurrentSuite<ExtraContext = {}>() { export function getCurrentSuite<ExtraContext = {}>() {
return (collectorContext.currentSuite || defaultSuite) as SuiteCollector<ExtraContext> return (collectorContext.currentSuite
|| defaultSuite) as SuiteCollector<ExtraContext>
} }
export function createSuiteHooks() { export function createSuiteHooks() {
@ -79,10 +124,13 @@ function parseArguments<T extends (...args: any[]) => any>(
// it('', () => {}, { retry: 2 }) // it('', () => {}, { retry: 2 })
if (typeof optionsOrTest === 'object') { if (typeof optionsOrTest === 'object') {
// it('', { retry: 2 }, { retry: 3 }) // it('', { retry: 2 }, { retry: 3 })
if (typeof optionsOrFn === 'object') if (typeof optionsOrFn === 'object') {
throw new TypeError('Cannot use two objects as arguments. Please provide options and a function callback in that order.') throw new TypeError(
// TODO: more info, add a name 'Cannot use two objects as arguments. Please provide options and a function callback in that order.',
// console.warn('The third argument is deprecated. Please use the second argument for options.') )
}
// TODO: more info, add a name
// console.warn('The third argument is deprecated. Please use the second argument for options.')
options = optionsOrTest options = optionsOrTest
} }
// it('', () => {}, 1000) // it('', () => {}, 1000)
@ -95,8 +143,11 @@ function parseArguments<T extends (...args: any[]) => any>(
} }
if (typeof optionsOrFn === 'function') { if (typeof optionsOrFn === 'function') {
if (typeof optionsOrTest === 'function') if (typeof optionsOrTest === 'function') {
throw new TypeError('Cannot use two functions as arguments. Please use the second argument for options.') throw new TypeError(
'Cannot use two functions as arguments. Please use the second argument for options.',
)
}
fn = optionsOrFn as T fn = optionsOrFn as T
} }
else if (typeof optionsOrTest === 'function') { else if (typeof optionsOrTest === 'function') {
@ -110,7 +161,14 @@ function parseArguments<T extends (...args: any[]) => any>(
} }
// implementations // implementations
function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, mode: RunMode, shuffle?: boolean, each?: boolean, suiteOptions?: TestOptions) { function createSuiteCollector(
name: string,
factory: SuiteFactory = () => {},
mode: RunMode,
shuffle?: boolean,
each?: boolean,
suiteOptions?: TestOptions,
) {
const tasks: (Test | Custom | Suite | SuiteCollector)[] = [] const tasks: (Test | Custom | Suite | SuiteCollector)[] = []
const factoryQueue: (Test | Suite | SuiteCollector)[] = [] const factoryQueue: (Test | Suite | SuiteCollector)[] = []
@ -130,14 +188,25 @@ function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, m
file: undefined!, file: undefined!,
retry: options.retry ?? runner.config.retry, retry: options.retry ?? runner.config.retry,
repeats: options.repeats, repeats: options.repeats,
mode: options.only ? 'only' : options.skip ? 'skip' : options.todo ? 'todo' : 'run', mode: options.only
? 'only'
: options.skip
? 'skip'
: options.todo
? 'todo'
: 'run',
meta: options.meta ?? Object.create(null), meta: options.meta ?? Object.create(null),
} }
const handler = options.handler const handler = options.handler
if (options.concurrent || (!options.sequential && runner.config.sequence.concurrent)) if (
options.concurrent
|| (!options.sequential && runner.config.sequence.concurrent)
) {
task.concurrent = true task.concurrent = true
if (shuffle) }
if (shuffle) {
task.shuffle = true task.shuffle = true
}
const context = createTestContext(task, runner) const context = createTestContext(task, runner)
// create test context // create test context
@ -148,10 +217,13 @@ function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, m
setFixture(context, options.fixtures) setFixture(context, options.fixtures)
if (handler) { if (handler) {
setFn(task, withTimeout( setFn(
withFixtures(handler, context), task,
options?.timeout ?? runner.config.testTimeout, withTimeout(
)) withFixtures(handler, context),
options?.timeout ?? runner.config.testTimeout,
),
)
} }
if (runner.config.includeTaskLocation) { if (runner.config.includeTaskLocation) {
@ -161,32 +233,38 @@ function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, m
const error = new Error('stacktrace').stack! const error = new Error('stacktrace').stack!
Error.stackTraceLimit = limit Error.stackTraceLimit = limit
const stack = findTestFileStackTrace(error, task.each ?? false) const stack = findTestFileStackTrace(error, task.each ?? false)
if (stack) if (stack) {
task.location = stack task.location = stack
}
} }
tasks.push(task) tasks.push(task)
return task return task
} }
const test = createTest(function (name: string | Function, optionsOrFn?: TestOptions | TestFunction, optionsOrTest?: number | TestOptions | TestFunction) { const test = createTest(function (
let { options, handler } = parseArguments( name: string | Function,
optionsOrFn, optionsOrFn?: TestOptions | TestFunction,
optionsOrTest, optionsOrTest?: number | TestOptions | TestFunction,
) ) {
let { options, handler } = parseArguments(optionsOrFn, optionsOrTest)
// inherit repeats, retry, timeout from suite // inherit repeats, retry, timeout from suite
if (typeof suiteOptions === 'object') if (typeof suiteOptions === 'object') {
options = Object.assign({}, suiteOptions, options) options = Object.assign({}, suiteOptions, options)
}
// inherit concurrent / sequential from suite // inherit concurrent / sequential from suite
options.concurrent = this.concurrent || (!this.sequential && options?.concurrent) options.concurrent
options.sequential = this.sequential || (!this.concurrent && options?.sequential) = this.concurrent || (!this.sequential && options?.concurrent)
options.sequential
= this.sequential || (!this.concurrent && options?.sequential)
const test = task( const test = task(formatName(name), {
formatName(name), ...this,
{ ...this, ...options, handler }, ...options,
) as unknown as Test handler,
}) as unknown as Test
test.type = 'test' test.type = 'test'
}) })
@ -205,12 +283,13 @@ function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, m
} }
function addHook<T extends keyof SuiteHooks>(name: T, ...fn: SuiteHooks[T]) { function addHook<T extends keyof SuiteHooks>(name: T, ...fn: SuiteHooks[T]) {
getHooks(suite)[name].push(...fn as any) getHooks(suite)[name].push(...(fn as any))
} }
function initSuite(includeLocation: boolean) { function initSuite(includeLocation: boolean) {
if (typeof suiteOptions === 'number') if (typeof suiteOptions === 'number') {
suiteOptions = { timeout: suiteOptions } suiteOptions = { timeout: suiteOptions }
}
suite = { suite = {
id: '', id: '',
@ -231,8 +310,9 @@ function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, m
const error = new Error('stacktrace').stack! const error = new Error('stacktrace').stack!
Error.stackTraceLimit = limit Error.stackTraceLimit = limit
const stack = findTestFileStackTrace(error, suite.each ?? false) const stack = findTestFileStackTrace(error, suite.each ?? false)
if (stack) if (stack) {
suite.location = stack suite.location = stack
}
} }
setHooks(suite, createSuiteHooks()) setHooks(suite, createSuiteHooks())
@ -245,17 +325,20 @@ function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, m
} }
async function collect(file: File) { async function collect(file: File) {
if (!file) if (!file) {
throw new TypeError('File is required to collect tasks.') throw new TypeError('File is required to collect tasks.')
}
factoryQueue.length = 0 factoryQueue.length = 0
if (factory) if (factory) {
await runWithSuite(collector, () => factory(test)) await runWithSuite(collector, () => factory(test))
}
const allChildren: Task[] = [] const allChildren: Task[] = []
for (const i of [...factoryQueue, ...tasks]) for (const i of [...factoryQueue, ...tasks]) {
allChildren.push(i.type === 'collector' ? await i.collect(file) : i) allChildren.push(i.type === 'collector' ? await i.collect(file) : i)
}
suite.file = file suite.file = file
suite.tasks = allChildren suite.tasks = allChildren
@ -273,8 +356,19 @@ function createSuiteCollector(name: string, factory: SuiteFactory = () => { }, m
} }
function createSuite() { function createSuite() {
function suiteFn(this: Record<string, boolean | undefined>, name: string | Function, factoryOrOptions?: SuiteFactory | TestOptions, optionsOrFactory: number | TestOptions | SuiteFactory = {}) { function suiteFn(
const mode: RunMode = this.only ? 'only' : this.skip ? 'skip' : this.todo ? 'todo' : 'run' this: Record<string, boolean | undefined>,
name: string | Function,
factoryOrOptions?: SuiteFactory | TestOptions,
optionsOrFactory: number | TestOptions | SuiteFactory = {},
) {
const mode: RunMode = this.only
? 'only'
: this.skip
? 'skip'
: this.todo
? 'todo'
: 'run'
const currentSuite = getCurrentSuite() const currentSuite = getCurrentSuite()
let { options, handler: factory } = parseArguments( let { options, handler: factory } = parseArguments(
@ -283,33 +377,52 @@ function createSuite() {
) )
// inherit options from current suite // inherit options from current suite
if (currentSuite?.options) if (currentSuite?.options) {
options = { ...currentSuite.options, ...options } options = { ...currentSuite.options, ...options }
}
// inherit concurrent / sequential from suite // inherit concurrent / sequential from suite
const isConcurrent = options.concurrent || (this.concurrent && !this.sequential) const isConcurrent
const isSequential = options.sequential || (this.sequential && !this.concurrent) = options.concurrent || (this.concurrent && !this.sequential)
const isSequential
= options.sequential || (this.sequential && !this.concurrent)
options.concurrent = isConcurrent && !isSequential options.concurrent = isConcurrent && !isSequential
options.sequential = isSequential && !isConcurrent options.sequential = isSequential && !isConcurrent
return createSuiteCollector(formatName(name), factory, mode, this.shuffle, this.each, options) return createSuiteCollector(
formatName(name),
factory,
mode,
this.shuffle,
this.each,
options,
)
} }
suiteFn.each = function<T>(this: { withContext: () => SuiteAPI; setContext: (key: string, value: boolean | undefined) => SuiteAPI }, cases: ReadonlyArray<T>, ...args: any[]) { suiteFn.each = function <T>(
this: {
withContext: () => SuiteAPI
setContext: (key: string, value: boolean | undefined) => SuiteAPI
},
cases: ReadonlyArray<T>,
...args: any[]
) {
const suite = this.withContext() const suite = this.withContext()
this.setContext('each', true) this.setContext('each', true)
if (Array.isArray(cases) && args.length) if (Array.isArray(cases) && args.length) {
cases = formatTemplateString(cases, args) cases = formatTemplateString(cases, args)
}
return (name: string | Function, optionsOrFn: ((...args: T[]) => void) | TestOptions, fnOrOptions?: ((...args: T[]) => void) | number | TestOptions) => { return (
name: string | Function,
optionsOrFn: ((...args: T[]) => void) | TestOptions,
fnOrOptions?: ((...args: T[]) => void) | number | TestOptions,
) => {
const _name = formatName(name) const _name = formatName(name)
const arrayOnlyCases = cases.every(Array.isArray) const arrayOnlyCases = cases.every(Array.isArray)
const { options, handler } = parseArguments( const { options, handler } = parseArguments(optionsOrFn, fnOrOptions)
optionsOrFn,
fnOrOptions,
)
const fnFirst = typeof optionsOrFn === 'function' const fnFirst = typeof optionsOrFn === 'function'
@ -317,12 +430,17 @@ function createSuite() {
const items = Array.isArray(i) ? i : [i] const items = Array.isArray(i) ? i : [i]
if (fnFirst) { if (fnFirst) {
arrayOnlyCases arrayOnlyCases
? suite(formatTitle(_name, items, idx), () => handler(...items), options) ? suite(
formatTitle(_name, items, idx),
() => handler(...items),
options,
)
: suite(formatTitle(_name, items, idx), () => handler(i), options) : suite(formatTitle(_name, items, idx), () => handler(i), options)
} }
else { else {
arrayOnlyCases arrayOnlyCases
? suite(formatTitle(_name, items, idx), options, () => handler(...items)) ? suite(formatTitle(_name, items, idx), options, () =>
handler(...items))
: suite(formatTitle(_name, items, idx), options, () => handler(i)) : suite(formatTitle(_name, items, idx), options, () => handler(i))
} }
}) })
@ -331,8 +449,10 @@ function createSuite() {
} }
} }
suiteFn.skipIf = (condition: any) => (condition ? suite.skip : suite) as SuiteAPI suiteFn.skipIf = (condition: any) =>
suiteFn.runIf = (condition: any) => (condition ? suite : suite.skip) as SuiteAPI (condition ? suite.skip : suite) as SuiteAPI
suiteFn.runIf = (condition: any) =>
(condition ? suite : suite.skip) as SuiteAPI
return createChainable( return createChainable(
['concurrent', 'sequential', 'shuffle', 'skip', 'only', 'todo'], ['concurrent', 'sequential', 'shuffle', 'skip', 'only', 'todo'],
@ -346,21 +466,30 @@ export function createTaskCollector(
) { ) {
const taskFn = fn as any const taskFn = fn as any
taskFn.each = function<T>(this: { withContext: () => SuiteAPI; setContext: (key: string, value: boolean | undefined) => SuiteAPI }, cases: ReadonlyArray<T>, ...args: any[]) { taskFn.each = function <T>(
this: {
withContext: () => SuiteAPI
setContext: (key: string, value: boolean | undefined) => SuiteAPI
},
cases: ReadonlyArray<T>,
...args: any[]
) {
const test = this.withContext() const test = this.withContext()
this.setContext('each', true) this.setContext('each', true)
if (Array.isArray(cases) && args.length) if (Array.isArray(cases) && args.length) {
cases = formatTemplateString(cases, args) cases = formatTemplateString(cases, args)
}
return (name: string | Function, optionsOrFn: ((...args: T[]) => void) | TestOptions, fnOrOptions?: ((...args: T[]) => void) | number | TestOptions) => { return (
name: string | Function,
optionsOrFn: ((...args: T[]) => void) | TestOptions,
fnOrOptions?: ((...args: T[]) => void) | number | TestOptions,
) => {
const _name = formatName(name) const _name = formatName(name)
const arrayOnlyCases = cases.every(Array.isArray) const arrayOnlyCases = cases.every(Array.isArray)
const { options, handler } = parseArguments( const { options, handler } = parseArguments(optionsOrFn, fnOrOptions)
optionsOrFn,
fnOrOptions,
)
const fnFirst = typeof optionsOrFn === 'function' const fnFirst = typeof optionsOrFn === 'function'
@ -369,12 +498,17 @@ export function createTaskCollector(
if (fnFirst) { if (fnFirst) {
arrayOnlyCases arrayOnlyCases
? test(formatTitle(_name, items, idx), () => handler(...items), options) ? test(
formatTitle(_name, items, idx),
() => handler(...items),
options,
)
: test(formatTitle(_name, items, idx), () => handler(i), options) : test(formatTitle(_name, items, idx), () => handler(i), options)
} }
else { else {
arrayOnlyCases arrayOnlyCases
? test(formatTitle(_name, items, idx), options, () => handler(...items)) ? test(formatTitle(_name, items, idx), options, () =>
handler(...items))
: test(formatTitle(_name, items, idx), options, () => handler(i)) : test(formatTitle(_name, items, idx), options, () => handler(i))
} }
}) })
@ -393,8 +527,9 @@ export function createTaskCollector(
) { ) {
const test = this.withContext() const test = this.withContext()
if (Array.isArray(cases) && args.length) if (Array.isArray(cases) && args.length) {
cases = formatTemplateString(cases, args) cases = formatTemplateString(cases, args)
}
return ( return (
name: string | Function, name: string | Function,
@ -423,8 +558,17 @@ export function createTaskCollector(
taskFn.extend = function (fixtures: Fixtures<Record<string, any>>) { taskFn.extend = function (fixtures: Fixtures<Record<string, any>>) {
const _context = mergeContextFixtures(fixtures, context) const _context = mergeContextFixtures(fixtures, context)
return createTest(function fn(name: string | Function, optionsOrFn?: TestOptions | TestFunction, optionsOrTest?: number | TestOptions | TestFunction) { return createTest(function fn(
getCurrentSuite().test.fn.call(this, formatName(name), optionsOrFn as TestOptions, optionsOrTest as TestFunction) name: string | Function,
optionsOrFn?: TestOptions | TestFunction,
optionsOrTest?: number | TestOptions | TestFunction,
) {
getCurrentSuite().test.fn.call(
this,
formatName(name),
optionsOrFn as TestOptions,
optionsOrTest as TestFunction,
)
}, _context) }, _context)
} }
@ -433,25 +577,34 @@ export function createTaskCollector(
taskFn, taskFn,
) as CustomAPI ) as CustomAPI
if (context) if (context) {
(_test as any).mergeContext(context) (_test as any).mergeContext(context)
}
return _test return _test
} }
function createTest(fn: ( function createTest(
( fn: (
this: Record<'concurrent' | 'sequential' | 'skip' | 'only' | 'todo' | 'fails' | 'each', boolean | undefined> & { fixtures?: FixtureItem[] }, this: Record<
'concurrent' | 'sequential' | 'skip' | 'only' | 'todo' | 'fails' | 'each',
boolean | undefined
> & { fixtures?: FixtureItem[] },
title: string, title: string,
optionsOrFn?: TestOptions | TestFunction, optionsOrFn?: TestOptions | TestFunction,
optionsOrTest?: number | TestOptions | TestFunction, optionsOrTest?: number | TestOptions | TestFunction
) => void ) => void,
), context?: Record<string, any>) { context?: Record<string, any>,
) {
return createTaskCollector(fn, context) as TestAPI return createTaskCollector(fn, context) as TestAPI
} }
function formatName(name: string | Function) { function formatName(name: string | Function) {
return typeof name === 'string' ? name : name instanceof Function ? (name.name || '<anonymous>') : String(name) return typeof name === 'string'
? name
: name instanceof Function
? name.name || '<anonymous>'
: String(name)
} }
function formatTitle(template: string, items: any[], idx: number) { function formatTitle(template: string, items: any[], idx: number) {
@ -483,19 +636,28 @@ function formatTitle(template: string, items: any[], idx: number) {
formatted = formatted.replace( formatted = formatted.replace(
/\$([$\w.]+)/g, /\$([$\w.]+)/g,
// https://github.com/chaijs/chai/pull/1490 // https://github.com/chaijs/chai/pull/1490
(_, key) => objDisplay(objectAttr(items[0], key), { truncate: runner?.config?.chaiConfig?.truncateThreshold }) as unknown as string, (_, key) =>
objDisplay(objectAttr(items[0], key), {
truncate: runner?.config?.chaiConfig?.truncateThreshold,
}) as unknown as string,
) )
} }
return formatted return formatted
} }
function formatTemplateString(cases: any[], args: any[]): any[] { function formatTemplateString(cases: any[], args: any[]): any[] {
const header = cases.join('').trim().replace(/ /g, '').split('\n').map(i => i.split('|'))[0] const header = cases
.join('')
.trim()
.replace(/ /g, '')
.split('\n')
.map(i => i.split('|'))[0]
const res: any[] = [] const res: any[] = []
for (let i = 0; i < Math.floor((args.length) / header.length); i++) { for (let i = 0; i < Math.floor(args.length / header.length); i++) {
const oneCase: Record<string, any> = {} const oneCase: Record<string, any> = {}
for (let j = 0; j < header.length; j++) for (let j = 0; j < header.length; j++) {
oneCase[header[j]] = args[i * header.length + j] as any oneCase[header[j]] = args[i * header.length + j] as any
}
res.push(oneCase) res.push(oneCase)
} }
return res return res

View File

@ -40,13 +40,13 @@ export interface VitestRunnerConfig {
export type VitestRunnerImportSource = 'collect' | 'setup' export type VitestRunnerImportSource = 'collect' | 'setup'
export interface VitestRunnerConstructor { export interface VitestRunnerConstructor {
new(config: VitestRunnerConfig): VitestRunner new (config: VitestRunnerConfig): VitestRunner
} }
export type CancelReason = export type CancelReason =
| 'keyboard-input' | 'keyboard-input'
| 'test-failure' | 'test-failure'
| string & Record<string, never> | (string & Record<string, never>)
export interface VitestRunner { export interface VitestRunner {
/** /**
@ -76,7 +76,10 @@ export interface VitestRunner {
/** /**
* Called before actually running the test function. Already has "result" with "state" and "startTime". * Called before actually running the test function. Already has "result" with "state" and "startTime".
*/ */
onBeforeTryTask?: (test: Task, options: { retry: number; repeats: number }) => unknown onBeforeTryTask?: (
test: Task,
options: { retry: number; repeats: number }
) => unknown
/** /**
* Called after result and state are set. * Called after result and state are set.
*/ */

View File

@ -50,7 +50,11 @@ export interface TaskResult {
repeatCount?: number repeatCount?: number
} }
export type TaskResultPack = [id: string, result: TaskResult | undefined, meta: TaskMeta] export type TaskResultPack = [
id: string,
result: TaskResult | undefined,
meta: TaskMeta,
]
export interface Suite extends TaskBase { export interface Suite extends TaskBase {
file: File file: File
@ -78,7 +82,9 @@ export interface Custom<ExtraContext = {}> extends TaskPopulated {
export type Task = Test | Suite | Custom | File export type Task = Test | Suite | Custom | File
export type DoneCallback = (error?: any) => void export type DoneCallback = (error?: any) => void
export type TestFunction<ExtraContext = {}> = (context: ExtendedContext<Test> & ExtraContext) => Awaitable<any> | void export type TestFunction<ExtraContext = {}> = (
context: ExtendedContext<Test> & ExtraContext
) => Awaitable<any> | void
// jest's ExtractEachCallbackArgs // jest's ExtractEachCallbackArgs
type ExtractEachCallbackArgs<T extends ReadonlyArray<any>> = { type ExtractEachCallbackArgs<T extends ReadonlyArray<any>> = {
@ -122,23 +128,25 @@ interface EachFunctionReturn<T extends any[]> {
( (
name: string | Function, name: string | Function,
fn: (...args: T) => Awaitable<void>, fn: (...args: T) => Awaitable<void>,
options: TestOptions, options: TestOptions
): void ): void
( (
name: string | Function, name: string | Function,
fn: (...args: T) => Awaitable<void>, fn: (...args: T) => Awaitable<void>,
options?: number | TestOptions, options?: number | TestOptions
): void ): void
( (
name: string | Function, name: string | Function,
options: TestOptions, options: TestOptions,
fn: (...args: T) => Awaitable<void>, fn: (...args: T) => Awaitable<void>
): void ): void
} }
interface TestEachFunction { interface TestEachFunction {
<T extends any[] | [any]>(cases: ReadonlyArray<T>): EachFunctionReturn<T> <T extends any[] | [any]>(cases: ReadonlyArray<T>): EachFunctionReturn<T>
<T extends ReadonlyArray<any>>(cases: ReadonlyArray<T>): EachFunctionReturn<ExtractEachCallbackArgs<T>> <T extends ReadonlyArray<any>>(cases: ReadonlyArray<T>): EachFunctionReturn<
ExtractEachCallbackArgs<T>
>
<T>(cases: ReadonlyArray<T>): EachFunctionReturn<T[]> <T>(cases: ReadonlyArray<T>): EachFunctionReturn<T[]>
(...args: [TemplateStringsArray, ...any]): EachFunctionReturn<any[]> (...args: [TemplateStringsArray, ...any]): EachFunctionReturn<any[]>
} }
@ -146,35 +154,53 @@ interface TestEachFunction {
interface TestForFunctionReturn<Arg, Context> { interface TestForFunctionReturn<Arg, Context> {
( (
name: string | Function, name: string | Function,
fn: (arg: Arg, context: Context) => Awaitable<void>, fn: (arg: Arg, context: Context) => Awaitable<void>
): void ): void
( (
name: string | Function, name: string | Function,
options: TestOptions, options: TestOptions,
fn: (args: Arg, context: Context) => Awaitable<void>, fn: (args: Arg, context: Context) => Awaitable<void>
): void ): void
} }
interface TestForFunction<ExtraContext> { interface TestForFunction<ExtraContext> {
// test.for([1, 2, 3]) // test.for([1, 2, 3])
// test.for([[1, 2], [3, 4, 5]]) // test.for([[1, 2], [3, 4, 5]])
<T>(cases: ReadonlyArray<T>): TestForFunctionReturn<T, ExtendedContext<Test> & ExtraContext> <T>(cases: ReadonlyArray<T>): TestForFunctionReturn<
T,
ExtendedContext<Test> & ExtraContext
>
// test.for` // test.for`
// a | b // a | b
// {1} | {2} // {1} | {2}
// {3} | {4} // {3} | {4}
// ` // `
(strings: TemplateStringsArray, ...values: any[]): TestForFunctionReturn<any, ExtendedContext<Test> & ExtraContext> (strings: TemplateStringsArray, ...values: any[]): TestForFunctionReturn<
any,
ExtendedContext<Test> & ExtraContext
>
} }
interface TestCollectorCallable<C = {}> { interface TestCollectorCallable<C = {}> {
/** /**
* @deprecated Use options as the second argument instead * @deprecated Use options as the second argument instead
*/ */
<ExtraContext extends C>(name: string | Function, fn: TestFunction<ExtraContext>, options: TestOptions): void <ExtraContext extends C>(
<ExtraContext extends C>(name: string | Function, fn?: TestFunction<ExtraContext>, options?: number | TestOptions): void name: string | Function,
<ExtraContext extends C>(name: string | Function, options?: TestOptions, fn?: TestFunction<ExtraContext>): void fn: TestFunction<ExtraContext>,
options: TestOptions
): void
<ExtraContext extends C>(
name: string | Function,
fn?: TestFunction<ExtraContext>,
options?: number | TestOptions
): void
<ExtraContext extends C>(
name: string | Function,
options?: TestOptions,
fn?: TestFunction<ExtraContext>
): void
} }
type ChainableTestAPI<ExtraContext = {}> = ChainableFunction< type ChainableTestAPI<ExtraContext = {}> = ChainableFunction<
@ -238,19 +264,31 @@ interface ExtendedAPI<ExtraContext> {
runIf: (condition: any) => ChainableTestAPI<ExtraContext> runIf: (condition: any) => ChainableTestAPI<ExtraContext>
} }
export type CustomAPI<ExtraContext = {}> = ChainableTestAPI<ExtraContext> & ExtendedAPI<ExtraContext> & { export type CustomAPI<ExtraContext = {}> = ChainableTestAPI<ExtraContext> &
extend: <T extends Record<string, any> = {}>(fixtures: Fixtures<T, ExtraContext>) => CustomAPI<{ ExtendedAPI<ExtraContext> & {
[K in keyof T | keyof ExtraContext]: extend: <T extends Record<string, any> = {}>(
K extends keyof T ? T[K] : fixtures: Fixtures<T, ExtraContext>
K extends keyof ExtraContext ? ExtraContext[K] : never }> ) => CustomAPI<{
} [K in keyof T | keyof ExtraContext]: K extends keyof T
? T[K]
: K extends keyof ExtraContext
? ExtraContext[K]
: never;
}>
}
export type TestAPI<ExtraContext = {}> = ChainableTestAPI<ExtraContext> & ExtendedAPI<ExtraContext> & { export type TestAPI<ExtraContext = {}> = ChainableTestAPI<ExtraContext> &
extend: <T extends Record<string, any> = {}>(fixtures: Fixtures<T, ExtraContext>) => TestAPI<{ ExtendedAPI<ExtraContext> & {
[K in keyof T | keyof ExtraContext]: extend: <T extends Record<string, any> = {}>(
K extends keyof T ? T[K] : fixtures: Fixtures<T, ExtraContext>
K extends keyof ExtraContext ? ExtraContext[K] : never }> ) => TestAPI<{
} [K in keyof T | keyof ExtraContext]: K extends keyof T
? T[K]
: K extends keyof ExtraContext
? ExtraContext[K]
: never;
}>
}
export interface FixtureOptions { export interface FixtureOptions {
/** /**
@ -260,14 +298,25 @@ export interface FixtureOptions {
} }
export type Use<T> = (value: T) => Promise<void> export type Use<T> = (value: T) => Promise<void>
export type FixtureFn<T, K extends keyof T, ExtraContext> = export type FixtureFn<T, K extends keyof T, ExtraContext> = (
(context: Omit<T, K> & ExtraContext, use: Use<T[K]>) => Promise<void> context: Omit<T, K> & ExtraContext,
export type Fixture<T, K extends keyof T, ExtraContext = {}> = use: Use<T[K]>
((...args: any) => any) extends T[K] ) => Promise<void>
? (T[K] extends any ? FixtureFn<T, K, Omit<ExtraContext, Exclude<keyof T, K>>> : never) export type Fixture<T, K extends keyof T, ExtraContext = {}> = ((
: T[K] | (T[K] extends any ? FixtureFn<T, K, Omit<ExtraContext, Exclude<keyof T, K>>> : never) ...args: any
) => any) extends T[K]
? T[K] extends any
? FixtureFn<T, K, Omit<ExtraContext, Exclude<keyof T, K>>>
: never
:
| T[K]
| (T[K] extends any
? FixtureFn<T, K, Omit<ExtraContext, Exclude<keyof T, K>>>
: never)
export type Fixtures<T extends Record<string, any>, ExtraContext = {}> = { export type Fixtures<T extends Record<string, any>, ExtraContext = {}> = {
[K in keyof T]: Fixture<T, K, ExtraContext & ExtendedContext<Test>> | [Fixture<T, K, ExtraContext & ExtendedContext<Test>>, FixtureOptions?] [K in keyof T]:
| Fixture<T, K, ExtraContext & ExtendedContext<Test>>
| [Fixture<T, K, ExtraContext & ExtendedContext<Test>>, FixtureOptions?];
} }
export type InferFixturesTypes<T> = T extends TestAPI<infer C> ? C : T export type InferFixturesTypes<T> = T extends TestAPI<infer C> ? C : T
@ -276,9 +325,21 @@ interface SuiteCollectorCallable<ExtraContext = {}> {
/** /**
* @deprecated Use options as the second argument instead * @deprecated Use options as the second argument instead
*/ */
<OverrideExtraContext extends ExtraContext = ExtraContext>(name: string | Function, fn: SuiteFactory<OverrideExtraContext>, options: TestOptions): SuiteCollector<OverrideExtraContext> <OverrideExtraContext extends ExtraContext = ExtraContext>(
<OverrideExtraContext extends ExtraContext = ExtraContext>(name: string | Function, fn?: SuiteFactory<OverrideExtraContext>, options?: number | TestOptions): SuiteCollector<OverrideExtraContext> name: string | Function,
<OverrideExtraContext extends ExtraContext = ExtraContext>(name: string | Function, options: TestOptions, fn?: SuiteFactory<OverrideExtraContext>): SuiteCollector<OverrideExtraContext> fn: SuiteFactory<OverrideExtraContext>,
options: TestOptions
): SuiteCollector<OverrideExtraContext>
<OverrideExtraContext extends ExtraContext = ExtraContext>(
name: string | Function,
fn?: SuiteFactory<OverrideExtraContext>,
options?: number | TestOptions
): SuiteCollector<OverrideExtraContext>
<OverrideExtraContext extends ExtraContext = ExtraContext>(
name: string | Function,
options: TestOptions,
fn?: SuiteFactory<OverrideExtraContext>
): SuiteCollector<OverrideExtraContext>
} }
type ChainableSuiteAPI<ExtraContext = {}> = ChainableFunction< type ChainableSuiteAPI<ExtraContext = {}> = ChainableFunction<
@ -294,15 +355,22 @@ export type SuiteAPI<ExtraContext = {}> = ChainableSuiteAPI<ExtraContext> & {
runIf: (condition: any) => ChainableSuiteAPI<ExtraContext> runIf: (condition: any) => ChainableSuiteAPI<ExtraContext>
} }
export type HookListener<T extends any[], Return = void> = (...args: T) => Awaitable<Return> export type HookListener<T extends any[], Return = void> = (
...args: T
) => Awaitable<Return>
export type HookCleanupCallback = (() => Awaitable<unknown>) | void export type HookCleanupCallback = (() => Awaitable<unknown>) | void
export interface SuiteHooks<ExtraContext = {}> { export interface SuiteHooks<ExtraContext = {}> {
beforeAll: HookListener<[Readonly<Suite | File>], HookCleanupCallback>[] beforeAll: HookListener<[Readonly<Suite | File>], HookCleanupCallback>[]
afterAll: HookListener<[Readonly<Suite | File>]>[] afterAll: HookListener<[Readonly<Suite | File>]>[]
beforeEach: HookListener<[ExtendedContext<Test | Custom> & ExtraContext, Readonly<Suite>], HookCleanupCallback>[] beforeEach: HookListener<
afterEach: HookListener<[ExtendedContext<Test | Custom> & ExtraContext, Readonly<Suite>]>[] [ExtendedContext<Test | Custom> & ExtraContext, Readonly<Suite>],
HookCleanupCallback
>[]
afterEach: HookListener<
[ExtendedContext<Test | Custom> & ExtraContext, Readonly<Suite>]
>[]
} }
export interface TaskCustomOptions extends TestOptions { export interface TaskCustomOptions extends TestOptions {
@ -324,14 +392,24 @@ export interface SuiteCollector<ExtraContext = {}> {
options?: TestOptions options?: TestOptions
type: 'collector' type: 'collector'
test: TestAPI<ExtraContext> test: TestAPI<ExtraContext>
tasks: (Suite | Custom<ExtraContext> | Test<ExtraContext> | SuiteCollector<ExtraContext>)[] tasks: (
| Suite
| Custom<ExtraContext>
| Test<ExtraContext>
| SuiteCollector<ExtraContext>
)[]
task: (name: string, options?: TaskCustomOptions) => Custom<ExtraContext> task: (name: string, options?: TaskCustomOptions) => Custom<ExtraContext>
collect: (file: File) => Promise<Suite> collect: (file: File) => Promise<Suite>
clear: () => void clear: () => void
on: <T extends keyof SuiteHooks<ExtraContext>>(name: T, ...fn: SuiteHooks<ExtraContext>[T]) => void on: <T extends keyof SuiteHooks<ExtraContext>>(
name: T,
...fn: SuiteHooks<ExtraContext>[T]
) => void
} }
export type SuiteFactory<ExtraContext = {}> = (test: TestAPI<ExtraContext>) => Awaitable<void> export type SuiteFactory<ExtraContext = {}> = (
test: TestAPI<ExtraContext>
) => Awaitable<void>
export interface RuntimeContext { export interface RuntimeContext {
tasks: (SuiteCollector | Test)[] tasks: (SuiteCollector | Test)[]
@ -362,7 +440,8 @@ export interface TaskContext<Task extends Custom | Test = Custom | Test> {
skip: () => void skip: () => void
} }
export type ExtendedContext<T extends Custom | Test> = TaskContext<T> & TestContext export type ExtendedContext<T extends Custom | Test> = TaskContext<T> &
TestContext
export type OnTestFailedHandler = (result: TaskResult) => Awaitable<void> export type OnTestFailedHandler = (result: TaskResult) => Awaitable<void>
export type OnTestFinishedHandler = (result: TaskResult) => Awaitable<void> export type OnTestFinishedHandler = (result: TaskResult) => Awaitable<void>

View File

@ -1,5 +1,9 @@
export type ChainableFunction<T extends string, F extends (...args: any) => any, C = {}> = F & { export type ChainableFunction<
[x in T]: ChainableFunction<T, F, C> T extends string,
F extends (...args: any) => any,
C = {},
> = F & {
[x in T]: ChainableFunction<T, F, C>;
} & { } & {
fn: (this: Record<T, any>, ...args: Parameters<F>) => ReturnType<F> fn: (this: Record<T, any>, ...args: Parameters<F>) => ReturnType<F>
} & C } & C

View File

@ -5,7 +5,13 @@ import type { File, Suite, TaskBase } from '../types'
/** /**
* If any tasks been marked as `only`, mark all other tasks as `skip`. * If any tasks been marked as `only`, mark all other tasks as `skip`.
*/ */
export function interpretTaskModes(suite: Suite, namePattern?: string | RegExp, onlyMode?: boolean, parentIsOnly?: boolean, allowOnly?: boolean) { export function interpretTaskModes(
suite: Suite,
namePattern?: string | RegExp,
onlyMode?: boolean,
parentIsOnly?: boolean,
allowOnly?: boolean,
) {
const suiteIsOnly = parentIsOnly || suite.mode === 'only' const suiteIsOnly = parentIsOnly || suite.mode === 'only'
suite.tasks.forEach((t) => { suite.tasks.forEach((t) => {
@ -28,21 +34,25 @@ export function interpretTaskModes(suite: Suite, namePattern?: string | RegExp,
} }
} }
if (t.type === 'test') { if (t.type === 'test') {
if (namePattern && !getTaskFullName(t).match(namePattern)) if (namePattern && !getTaskFullName(t).match(namePattern)) {
t.mode = 'skip' t.mode = 'skip'
}
} }
else if (t.type === 'suite') { else if (t.type === 'suite') {
if (t.mode === 'skip') if (t.mode === 'skip') {
skipAllTasks(t) skipAllTasks(t)
else }
else {
interpretTaskModes(t, namePattern, onlyMode, includeTask, allowOnly) interpretTaskModes(t, namePattern, onlyMode, includeTask, allowOnly)
}
} }
}) })
// if all subtasks are skipped, mark as skip // if all subtasks are skipped, mark as skip
if (suite.mode === 'run') { if (suite.mode === 'run') {
if (suite.tasks.length && suite.tasks.every(i => i.mode !== 'run')) if (suite.tasks.length && suite.tasks.every(i => i.mode !== 'run')) {
suite.mode = 'skip' suite.mode = 'skip'
}
} }
} }
@ -51,23 +61,31 @@ function getTaskFullName(task: TaskBase): string {
} }
export function someTasksAreOnly(suite: Suite): boolean { export function someTasksAreOnly(suite: Suite): boolean {
return suite.tasks.some(t => t.mode === 'only' || (t.type === 'suite' && someTasksAreOnly(t))) return suite.tasks.some(
t => t.mode === 'only' || (t.type === 'suite' && someTasksAreOnly(t)),
)
} }
function skipAllTasks(suite: Suite) { function skipAllTasks(suite: Suite) {
suite.tasks.forEach((t) => { suite.tasks.forEach((t) => {
if (t.mode === 'run') { if (t.mode === 'run') {
t.mode = 'skip' t.mode = 'skip'
if (t.type === 'suite') if (t.type === 'suite') {
skipAllTasks(t) skipAllTasks(t)
}
} }
}) })
} }
function checkAllowOnly(task: TaskBase, allowOnly?: boolean) { function checkAllowOnly(task: TaskBase, allowOnly?: boolean) {
if (allowOnly) if (allowOnly) {
return return
const error = processError(new Error('[Vitest] Unexpected .only modifier. Remove it or pass --allowOnly argument to bypass this error')) }
const error = processError(
new Error(
'[Vitest] Unexpected .only modifier. Remove it or pass --allowOnly argument to bypass this error',
),
)
task.result = { task.result = {
state: 'fail', state: 'fail',
errors: [error], errors: [error],
@ -76,8 +94,9 @@ function checkAllowOnly(task: TaskBase, allowOnly?: boolean) {
export function generateHash(str: string): string { export function generateHash(str: string): string {
let hash = 0 let hash = 0
if (str.length === 0) if (str.length === 0) {
return `${hash}` return `${hash}`
}
for (let i = 0; i < str.length; i++) { for (let i = 0; i < str.length; i++) {
const char = str.charCodeAt(i) const char = str.charCodeAt(i)
hash = (hash << 5) - hash + char hash = (hash << 5) - hash + char
@ -89,12 +108,17 @@ export function generateHash(str: string): string {
export function calculateSuiteHash(parent: Suite) { export function calculateSuiteHash(parent: Suite) {
parent.tasks.forEach((t, idx) => { parent.tasks.forEach((t, idx) => {
t.id = `${parent.id}_${idx}` t.id = `${parent.id}_${idx}`
if (t.type === 'suite') if (t.type === 'suite') {
calculateSuiteHash(t) calculateSuiteHash(t)
}
}) })
} }
export function createFileTask(filepath: string, root: string, projectName: string) { export function createFileTask(
filepath: string,
root: string,
projectName: string,
) {
const path = relative(root, filepath) const path = relative(root, filepath)
const file: File = { const file: File = {
id: generateHash(`${path}${projectName || ''}`), id: generateHash(`${path}${projectName || ''}`),

View File

@ -15,8 +15,9 @@ export function partitionSuiteChildren(suite: Suite) {
tasksGroup = [c] tasksGroup = [c]
} }
} }
if (tasksGroup.length > 0) if (tasksGroup.length > 0) {
tasksGroups.push(tasksGroup) tasksGroups.push(tasksGroup)
}
return tasksGroups return tasksGroups
} }

View File

@ -19,8 +19,9 @@ export function getTests(suite: Arrayable<Task>): (Test | Custom)[] {
} }
else { else {
const taskTests = getTests(task) const taskTests = getTests(task)
for (const test of taskTests) for (const test of taskTests) {
tests.push(test) tests.push(test)
}
} }
} }
} }
@ -29,19 +30,28 @@ export function getTests(suite: Arrayable<Task>): (Test | Custom)[] {
} }
export function getTasks(tasks: Arrayable<Task> = []): Task[] { export function getTasks(tasks: Arrayable<Task> = []): Task[] {
return toArray(tasks).flatMap(s => isAtomTest(s) ? [s] : [s, ...getTasks(s.tasks)]) return toArray(tasks).flatMap(s =>
isAtomTest(s) ? [s] : [s, ...getTasks(s.tasks)],
)
} }
export function getSuites(suite: Arrayable<Task>): Suite[] { export function getSuites(suite: Arrayable<Task>): Suite[] {
return toArray(suite).flatMap(s => s.type === 'suite' ? [s, ...getSuites(s.tasks)] : []) return toArray(suite).flatMap(s =>
s.type === 'suite' ? [s, ...getSuites(s.tasks)] : [],
)
} }
export function hasTests(suite: Arrayable<Suite>): boolean { export function hasTests(suite: Arrayable<Suite>): boolean {
return toArray(suite).some(s => s.tasks.some(c => isAtomTest(c) || hasTests(c))) return toArray(suite).some(s =>
s.tasks.some(c => isAtomTest(c) || hasTests(c)),
)
} }
export function hasFailed(suite: Arrayable<Task>): boolean { export function hasFailed(suite: Arrayable<Task>): boolean {
return toArray(suite).some(s => s.result?.state === 'fail' || (s.type === 'suite' && hasFailed(s.tasks))) return toArray(suite).some(
s =>
s.result?.state === 'fail' || (s.type === 'suite' && hasFailed(s.tasks)),
)
} }
export function getNames(task: Task) { export function getNames(task: Task) {
@ -50,12 +60,14 @@ export function getNames(task: Task) {
while (current?.suite) { while (current?.suite) {
current = current.suite current = current.suite
if (current?.name) if (current?.name) {
names.unshift(current.name) names.unshift(current.name)
}
} }
if (current !== task.file) if (current !== task.file) {
names.unshift(task.file.name) names.unshift(task.file.name)
}
return names return names
} }

View File

@ -12,7 +12,8 @@ import { SnapshotManager } from '@vitest/snapshot/manager'
const client = new SnapshotClient({ const client = new SnapshotClient({
// you need to provide your own equality check implementation if you use it // you need to provide your own equality check implementation if you use it
// this function is called when `.toMatchSnapshot({ property: 1 })` is called // this function is called when `.toMatchSnapshot({ property: 1 })` is called
isEqual: (received, expected) => equals(received, expected, [iterableEquality, subsetEquality]), isEqual: (received, expected) =>
equals(received, expected, [iterableEquality, subsetEquality]),
}) })
// class that implements snapshot saving and reading // class that implements snapshot saving and reading
@ -53,7 +54,11 @@ const options = {
snapshotEnvironment: environment, snapshotEnvironment: environment,
} }
await client.startCurrentRun(getCurrentFilepath(), getCurrentTestName(), options) await client.startCurrentRun(
getCurrentFilepath(),
getCurrentTestName(),
options
)
// this will save snapshot to a file which is returned by "snapshotEnvironment.resolvePath" // this will save snapshot to a file which is returned by "snapshotEnvironment.resolvePath"
client.assert({ client.assert({

View File

@ -51,15 +51,14 @@ export default defineConfig([
format: 'esm', format: 'esm',
}, },
external, external,
plugins: [ plugins: [dts({ respectExternal: true })],
dts({ respectExternal: true }),
],
onwarn, onwarn,
}, },
]) ])
function onwarn(message) { function onwarn(message) {
if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) {
return return
}
console.error(message) console.error(message)
} }

View File

@ -3,7 +3,12 @@ import SnapshotState from './port/state'
import type { SnapshotStateOptions } from './types' import type { SnapshotStateOptions } from './types'
import type { RawSnapshotInfo } from './port/rawSnapshot' import type { RawSnapshotInfo } from './port/rawSnapshot'
function createMismatchError(message: string, expand: boolean | undefined, actual: unknown, expected: unknown) { function createMismatchError(
message: string,
expand: boolean | undefined,
actual: unknown,
expected: unknown,
) {
const error = new Error(message) const error = new Error(message)
Object.defineProperty(error, 'actual', { Object.defineProperty(error, 'actual', {
value: actual, value: actual,
@ -52,7 +57,11 @@ export class SnapshotClient {
constructor(private options: SnapshotClientOptions = {}) {} constructor(private options: SnapshotClientOptions = {}) {}
async startCurrentRun(filepath: string, name: string, options: SnapshotStateOptions) { async startCurrentRun(
filepath: string,
name: string,
options: SnapshotStateOptions,
) {
this.filepath = filepath this.filepath = filepath
this.name = name this.name = name
@ -62,10 +71,7 @@ export class SnapshotClient {
if (!this.getSnapshotState(filepath)) { if (!this.getSnapshotState(filepath)) {
this.snapshotStateMap.set( this.snapshotStateMap.set(
filepath, filepath,
await SnapshotState.create( await SnapshotState.create(filepath, options),
filepath,
options,
),
) )
} }
this.snapshotState = this.getSnapshotState(filepath) this.snapshotState = this.getSnapshotState(filepath)
@ -99,20 +105,31 @@ export class SnapshotClient {
} = options } = options
let { received } = options let { received } = options
if (!filepath) if (!filepath) {
throw new Error('Snapshot cannot be used outside of test') throw new Error('Snapshot cannot be used outside of test')
}
if (typeof properties === 'object') { if (typeof properties === 'object') {
if (typeof received !== 'object' || !received) if (typeof received !== 'object' || !received) {
throw new Error('Received value must be an object when the matcher has properties') throw new Error(
'Received value must be an object when the matcher has properties',
)
}
try { try {
const pass = this.options.isEqual?.(received, properties) ?? false const pass = this.options.isEqual?.(received, properties) ?? false
// const pass = equals(received, properties, [iterableEquality, subsetEquality]) // const pass = equals(received, properties, [iterableEquality, subsetEquality])
if (!pass) if (!pass) {
throw createMismatchError('Snapshot properties mismatched', this.snapshotState?.expand, received, properties) throw createMismatchError(
else 'Snapshot properties mismatched',
this.snapshotState?.expand,
received,
properties,
)
}
else {
received = deepMergeSnapshot(received, properties) received = deepMergeSnapshot(received, properties)
}
} }
catch (err: any) { catch (err: any) {
err.message = errorMessage || 'Snapshot mismatched' err.message = errorMessage || 'Snapshot mismatched'
@ -120,10 +137,7 @@ export class SnapshotClient {
} }
} }
const testName = [ const testName = [name, ...(message ? [message] : [])].join(' > ')
name,
...(message ? [message] : []),
].join(' > ')
const snapshotState = this.getSnapshotState(filepath) const snapshotState = this.getSnapshotState(filepath)
@ -136,38 +150,49 @@ export class SnapshotClient {
rawSnapshot, rawSnapshot,
}) })
if (!pass) if (!pass) {
throw createMismatchError(`Snapshot \`${key || 'unknown'}\` mismatched`, this.snapshotState?.expand, actual?.trim(), expected?.trim()) throw createMismatchError(
`Snapshot \`${key || 'unknown'}\` mismatched`,
this.snapshotState?.expand,
actual?.trim(),
expected?.trim(),
)
}
} }
async assertRaw(options: AssertOptions): Promise<void> { async assertRaw(options: AssertOptions): Promise<void> {
if (!options.rawSnapshot) if (!options.rawSnapshot) {
throw new Error('Raw snapshot is required') throw new Error('Raw snapshot is required')
}
const { const { filepath = this.filepath, rawSnapshot } = options
filepath = this.filepath,
rawSnapshot,
} = options
if (rawSnapshot.content == null) { if (rawSnapshot.content == null) {
if (!filepath) if (!filepath) {
throw new Error('Snapshot cannot be used outside of test') throw new Error('Snapshot cannot be used outside of test')
}
const snapshotState = this.getSnapshotState(filepath) const snapshotState = this.getSnapshotState(filepath)
// save the filepath, so it don't lose even if the await make it out-of-context // save the filepath, so it don't lose even if the await make it out-of-context
options.filepath ||= filepath options.filepath ||= filepath
// resolve and read the raw snapshot file // resolve and read the raw snapshot file
rawSnapshot.file = await snapshotState.environment.resolveRawPath(filepath, rawSnapshot.file) rawSnapshot.file = await snapshotState.environment.resolveRawPath(
rawSnapshot.content = await snapshotState.environment.readSnapshotFile(rawSnapshot.file) ?? undefined filepath,
rawSnapshot.file,
)
rawSnapshot.content
= (await snapshotState.environment.readSnapshotFile(rawSnapshot.file))
?? undefined
} }
return this.assert(options) return this.assert(options)
} }
async finishCurrentRun() { async finishCurrentRun() {
if (!this.snapshotState) if (!this.snapshotState) {
return null return null
}
const result = await this.snapshotState.pack() const result = await this.snapshotState.pack()
this.snapshotState = undefined this.snapshotState = undefined

View File

@ -14,17 +14,12 @@ export class NodeSnapshotEnvironment implements SnapshotEnvironment {
} }
async resolveRawPath(testPath: string, rawPath: string) { async resolveRawPath(testPath: string, rawPath: string) {
return isAbsolute(rawPath) return isAbsolute(rawPath) ? rawPath : resolve(dirname(testPath), rawPath)
? rawPath
: resolve(dirname(testPath), rawPath)
} }
async resolvePath(filepath: string): Promise<string> { async resolvePath(filepath: string): Promise<string> {
return join( return join(
join( join(dirname(filepath), this.options.snapshotsDirName ?? '__snapshots__'),
dirname(filepath),
this.options.snapshotsDirName ?? '__snapshots__',
),
`${basename(filepath)}.snap`, `${basename(filepath)}.snap`,
) )
} }
@ -39,13 +34,15 @@ export class NodeSnapshotEnvironment implements SnapshotEnvironment {
} }
async readSnapshotFile(filepath: string): Promise<string | null> { async readSnapshotFile(filepath: string): Promise<string | null> {
if (!existsSync(filepath)) if (!existsSync(filepath)) {
return null return null
}
return fs.readFile(filepath, 'utf-8') return fs.readFile(filepath, 'utf-8')
} }
async removeSnapshotFile(filepath: string): Promise<void> { async removeSnapshotFile(filepath: string): Promise<void> {
if (existsSync(filepath)) if (existsSync(filepath)) {
await fs.unlink(filepath) await fs.unlink(filepath)
}
} }
} }

View File

@ -1,11 +1,17 @@
import { basename, dirname, isAbsolute, join, resolve } from 'pathe' import { basename, dirname, isAbsolute, join, resolve } from 'pathe'
import type { SnapshotResult, SnapshotStateOptions, SnapshotSummary } from './types' import type {
SnapshotResult,
SnapshotStateOptions,
SnapshotSummary,
} from './types'
export class SnapshotManager { export class SnapshotManager {
summary: SnapshotSummary = undefined! summary: SnapshotSummary = undefined!
extension = '.snap' extension = '.snap'
constructor(public options: Omit<SnapshotStateOptions, 'snapshotEnvironment'>) { constructor(
public options: Omit<SnapshotStateOptions, 'snapshotEnvironment'>,
) {
this.clear() this.clear()
} }
@ -18,28 +24,27 @@ export class SnapshotManager {
} }
resolvePath(testPath: string) { resolvePath(testPath: string) {
const resolver = this.options.resolveSnapshotPath || (() => { const resolver
return join( = this.options.resolveSnapshotPath
join( || (() => {
dirname(testPath), return join(
'__snapshots__', join(dirname(testPath), '__snapshots__'),
), `${basename(testPath)}${this.extension}`,
`${basename(testPath)}${this.extension}`, )
) })
})
const path = resolver(testPath, this.extension) const path = resolver(testPath, this.extension)
return path return path
} }
resolveRawPath(testPath: string, rawPath: string) { resolveRawPath(testPath: string, rawPath: string) {
return isAbsolute(rawPath) return isAbsolute(rawPath) ? rawPath : resolve(dirname(testPath), rawPath)
? rawPath
: resolve(dirname(testPath), rawPath)
} }
} }
export function emptySummary(options: Omit<SnapshotStateOptions, 'snapshotEnvironment'>): SnapshotSummary { export function emptySummary(
options: Omit<SnapshotStateOptions, 'snapshotEnvironment'>,
): SnapshotSummary {
const summary = { const summary = {
added: 0, added: 0,
failure: false, failure: false,
@ -59,15 +64,22 @@ export function emptySummary(options: Omit<SnapshotStateOptions, 'snapshotEnviro
return summary return summary
} }
export function addSnapshotResult(summary: SnapshotSummary, result: SnapshotResult): void { export function addSnapshotResult(
if (result.added) summary: SnapshotSummary,
result: SnapshotResult,
): void {
if (result.added) {
summary.filesAdded++ summary.filesAdded++
if (result.fileDeleted) }
if (result.fileDeleted) {
summary.filesRemoved++ summary.filesRemoved++
if (result.unmatched) }
if (result.unmatched) {
summary.filesUnmatched++ summary.filesUnmatched++
if (result.updated) }
if (result.updated) {
summary.filesUpdated++ summary.filesUpdated++
}
summary.added += result.added summary.added += result.added
summary.matched += result.matched summary.matched += result.matched
@ -81,5 +93,6 @@ export function addSnapshotResult(summary: SnapshotSummary, result: SnapshotResu
summary.unmatched += result.unmatched summary.unmatched += result.unmatched
summary.updated += result.updated summary.updated += result.updated
summary.total += result.added + result.matched + result.unmatched + result.updated summary.total
+= result.added + result.matched + result.unmatched + result.updated
} }

View File

@ -1,5 +1,10 @@
import type MagicString from 'magic-string' import type MagicString from 'magic-string'
import { getCallLastIndex, lineSplitRE, offsetToLineNumber, positionToOffset } from '../../../utils/src/index' import {
getCallLastIndex,
lineSplitRE,
offsetToLineNumber,
positionToOffset,
} from '../../../utils/src/index'
import type { SnapshotEnvironment } from '../types' import type { SnapshotEnvironment } from '../types'
export interface InlineSnapshot { export interface InlineSnapshot {
@ -15,35 +20,46 @@ export async function saveInlineSnapshots(
) { ) {
const MagicString = (await import('magic-string')).default const MagicString = (await import('magic-string')).default
const files = new Set(snapshots.map(i => i.file)) const files = new Set(snapshots.map(i => i.file))
await Promise.all(Array.from(files).map(async (file) => { await Promise.all(
const snaps = snapshots.filter(i => i.file === file) Array.from(files).map(async (file) => {
const code = await environment.readSnapshotFile(file) as string const snaps = snapshots.filter(i => i.file === file)
const s = new MagicString(code) const code = (await environment.readSnapshotFile(file)) as string
const s = new MagicString(code)
for (const snap of snaps) { for (const snap of snaps) {
const index = positionToOffset(code, snap.line, snap.column) const index = positionToOffset(code, snap.line, snap.column)
replaceInlineSnap(code, s, index, snap.snapshot) replaceInlineSnap(code, s, index, snap.snapshot)
} }
const transformed = s.toString() const transformed = s.toString()
if (transformed !== code) if (transformed !== code) {
await environment.saveSnapshotFile(file, transformed) await environment.saveSnapshotFile(file, transformed)
})) }
}),
)
} }
const startObjectRegex = /(?:toMatchInlineSnapshot|toThrowErrorMatchingInlineSnapshot)\s*\(\s*(?:\/\*[\s\S]*\*\/\s*|\/\/.*(?:[\n\r\u2028\u2029]\s*|[\t\v\f \xA0\u1680\u2000-\u200A\u202F\u205F\u3000\uFEFF]))*\{/ const startObjectRegex
= /(?:toMatchInlineSnapshot|toThrowErrorMatchingInlineSnapshot)\s*\(\s*(?:\/\*[\s\S]*\*\/\s*|\/\/.*(?:[\n\r\u2028\u2029]\s*|[\t\v\f \xA0\u1680\u2000-\u200A\u202F\u205F\u3000\uFEFF]))*\{/
function replaceObjectSnap(code: string, s: MagicString, index: number, newSnap: string) { function replaceObjectSnap(
code: string,
s: MagicString,
index: number,
newSnap: string,
) {
let _code = code.slice(index) let _code = code.slice(index)
const startMatch = startObjectRegex.exec(_code) const startMatch = startObjectRegex.exec(_code)
if (!startMatch) if (!startMatch) {
return false return false
}
_code = _code.slice(startMatch.index) _code = _code.slice(startMatch.index)
let callEnd = getCallLastIndex(_code) let callEnd = getCallLastIndex(_code)
if (callEnd === null) if (callEnd === null) {
return false return false
}
callEnd += index + startMatch.index callEnd += index + startMatch.index
const shapeStart = index + startMatch.index + startMatch[0].length const shapeStart = index + startMatch.index + startMatch[0].length
@ -67,10 +83,12 @@ function getObjectShapeEndIndex(code: string, index: number) {
let endBraces = 0 let endBraces = 0
while (startBraces !== endBraces && index < code.length) { while (startBraces !== endBraces && index < code.length) {
const s = code[index++] const s = code[index++]
if (s === '{') if (s === '{') {
startBraces++ startBraces++
else if (s === '}') }
else if (s === '}') {
endBraces++ endBraces++
}
} }
return index return index
} }
@ -81,28 +99,43 @@ function prepareSnapString(snap: string, source: string, index: number) {
const indent = line.match(/^\s*/)![0] || '' const indent = line.match(/^\s*/)![0] || ''
const indentNext = indent.includes('\t') ? `${indent}\t` : `${indent} ` const indentNext = indent.includes('\t') ? `${indent}\t` : `${indent} `
const lines = snap const lines = snap.trim().replace(/\\/g, '\\\\').split(/\n/g)
.trim()
.replace(/\\/g, '\\\\')
.split(/\n/g)
const isOneline = lines.length <= 1 const isOneline = lines.length <= 1
const quote = '`' const quote = '`'
if (isOneline) if (isOneline) {
return `${quote}${lines.join('\n').replace(/`/g, '\\`').replace(/\$\{/g, '\\${')}${quote}` return `${quote}${lines
return `${quote}\n${lines.map(i => i ? indentNext + i : '').join('\n').replace(/`/g, '\\`').replace(/\$\{/g, '\\${')}\n${indent}${quote}` .join('\n')
.replace(/`/g, '\\`')
.replace(/\$\{/g, '\\${')}${quote}`
}
return `${quote}\n${lines
.map(i => (i ? indentNext + i : ''))
.join('\n')
.replace(/`/g, '\\`')
.replace(/\$\{/g, '\\${')}\n${indent}${quote}`
} }
const startRegex = /(?:toMatchInlineSnapshot|toThrowErrorMatchingInlineSnapshot)\s*\(\s*(?:\/\*[\s\S]*\*\/\s*|\/\/.*(?:[\n\r\u2028\u2029]\s*|[\t\v\f \xA0\u1680\u2000-\u200A\u202F\u205F\u3000\uFEFF]))*[\w$]*(['"`)])/ const startRegex
export function replaceInlineSnap(code: string, s: MagicString, index: number, newSnap: string) { = /(?:toMatchInlineSnapshot|toThrowErrorMatchingInlineSnapshot)\s*\(\s*(?:\/\*[\s\S]*\*\/\s*|\/\/.*(?:[\n\r\u2028\u2029]\s*|[\t\v\f \xA0\u1680\u2000-\u200A\u202F\u205F\u3000\uFEFF]))*[\w$]*(['"`)])/
export function replaceInlineSnap(
code: string,
s: MagicString,
index: number,
newSnap: string,
) {
const codeStartingAtIndex = code.slice(index) const codeStartingAtIndex = code.slice(index)
const startMatch = startRegex.exec(codeStartingAtIndex) const startMatch = startRegex.exec(codeStartingAtIndex)
const firstKeywordMatch = /toMatchInlineSnapshot|toThrowErrorMatchingInlineSnapshot/.exec(codeStartingAtIndex) const firstKeywordMatch
= /toMatchInlineSnapshot|toThrowErrorMatchingInlineSnapshot/.exec(
codeStartingAtIndex,
)
if (!startMatch || startMatch.index !== firstKeywordMatch?.index) if (!startMatch || startMatch.index !== firstKeywordMatch?.index) {
return replaceObjectSnap(code, s, index, newSnap) return replaceObjectSnap(code, s, index, newSnap)
}
const quote = startMatch[1] const quote = startMatch[1]
const startIndex = index + startMatch.index + startMatch[0].length const startIndex = index + startMatch.index + startMatch[0].length
@ -115,8 +148,9 @@ export function replaceInlineSnap(code: string, s: MagicString, index: number, n
const quoteEndRE = new RegExp(`(?:^|[^\\\\])${quote}`) const quoteEndRE = new RegExp(`(?:^|[^\\\\])${quote}`)
const endMatch = quoteEndRE.exec(code.slice(startIndex)) const endMatch = quoteEndRE.exec(code.slice(startIndex))
if (!endMatch) if (!endMatch) {
return false return false
}
const endIndex = startIndex + endMatch.index! + endMatch[0].length const endIndex = startIndex + endMatch.index! + endMatch[0].length
s.overwrite(startIndex - 1, endIndex, snapString) s.overwrite(startIndex - 1, endIndex, snapString)

View File

@ -24,21 +24,21 @@ export const serialize: NewPlugin['serialize'] = (
let callsString = '' let callsString = ''
if (val.mock.calls.length !== 0) { if (val.mock.calls.length !== 0) {
const indentationNext = indentation + config.indent const indentationNext = indentation + config.indent
callsString callsString = ` {${config.spacingOuter}${indentationNext}"calls": ${printer(
= ` {${ val.mock.calls,
config.spacingOuter config,
}${indentationNext indentationNext,
}"calls": ${ depth,
printer(val.mock.calls, config, indentationNext, depth, refs) refs,
}${config.min ? ', ' : ',' )}${config.min ? ', ' : ','}${
}${config.spacingOuter config.spacingOuter
}${indentationNext }${indentationNext}"results": ${printer(
}"results": ${ val.mock.results,
printer(val.mock.results, config, indentationNext, depth, refs) config,
}${config.min ? '' : ',' indentationNext,
}${config.spacingOuter depth,
}${indentation refs,
}}` )}${config.min ? '' : ','}${config.spacingOuter}${indentation}}`
} }
return `[MockFunction${nameString}]${callsString}` return `[MockFunction${nameString}]${callsString}`

View File

@ -9,9 +9,7 @@ import type {
Plugin as PrettyFormatPlugin, Plugin as PrettyFormatPlugin,
Plugins as PrettyFormatPlugins, Plugins as PrettyFormatPlugins,
} from 'pretty-format' } from 'pretty-format'
import { import { plugins as prettyFormatPlugins } from 'pretty-format'
plugins as prettyFormatPlugins,
} from 'pretty-format'
import MockSerializer from './mockSerializer' import MockSerializer from './mockSerializer'

View File

@ -15,8 +15,11 @@ export async function saveRawSnapshots(
environment: SnapshotEnvironment, environment: SnapshotEnvironment,
snapshots: Array<RawSnapshot>, snapshots: Array<RawSnapshot>,
) { ) {
await Promise.all(snapshots.map(async (snap) => { await Promise.all(
if (!snap.readonly) snapshots.map(async (snap) => {
await environment.saveSnapshotFile(snap.file, snap.snapshot) if (!snap.readonly) {
})) await environment.saveSnapshotFile(snap.file, snap.snapshot)
}
}),
)
} }

View File

@ -8,7 +8,14 @@
import type { OptionsReceived as PrettyFormatOptions } from 'pretty-format' import type { OptionsReceived as PrettyFormatOptions } from 'pretty-format'
import type { ParsedStack } from '../../../utils/src/index' import type { ParsedStack } from '../../../utils/src/index'
import { parseErrorStacktrace } from '../../../utils/src/source-map' import { parseErrorStacktrace } from '../../../utils/src/source-map'
import type { SnapshotData, SnapshotEnvironment, SnapshotMatchOptions, SnapshotResult, SnapshotStateOptions, SnapshotUpdateState } from '../types' import type {
SnapshotData,
SnapshotEnvironment,
SnapshotMatchOptions,
SnapshotResult,
SnapshotStateOptions,
SnapshotUpdateState,
} from '../types'
import type { InlineSnapshot } from './inlineSnapshot' import type { InlineSnapshot } from './inlineSnapshot'
import { saveInlineSnapshots } from './inlineSnapshot' import { saveInlineSnapshots } from './inlineSnapshot'
import type { RawSnapshot, RawSnapshotInfo } from './rawSnapshot' import type { RawSnapshot, RawSnapshotInfo } from './rawSnapshot'
@ -64,10 +71,7 @@ export default class SnapshotState {
snapshotContent: string | null, snapshotContent: string | null,
options: SnapshotStateOptions, options: SnapshotStateOptions,
) { ) {
const { data, dirty } = getSnapshotData( const { data, dirty } = getSnapshotData(snapshotContent, options)
snapshotContent,
options,
)
this._fileExists = snapshotContent != null // TODO: update on watch? this._fileExists = snapshotContent != null // TODO: update on watch?
this._initialData = data this._initialData = data
this._snapshotData = data this._snapshotData = data
@ -90,12 +94,13 @@ export default class SnapshotState {
this._environment = options.snapshotEnvironment this._environment = options.snapshotEnvironment
} }
static async create( static async create(testFilePath: string, options: SnapshotStateOptions) {
testFilePath: string, const snapshotPath = await options.snapshotEnvironment.resolvePath(
options: SnapshotStateOptions, testFilePath,
) { )
const snapshotPath = await options.snapshotEnvironment.resolvePath(testFilePath) const content = await options.snapshotEnvironment.readSnapshotFile(
const content = await options.snapshotEnvironment.readSnapshotFile(snapshotPath) snapshotPath,
)
return new SnapshotState(testFilePath, snapshotPath, content, options) return new SnapshotState(testFilePath, snapshotPath, content, options)
} }
@ -105,20 +110,26 @@ export default class SnapshotState {
markSnapshotsAsCheckedForTest(testName: string): void { markSnapshotsAsCheckedForTest(testName: string): void {
this._uncheckedKeys.forEach((uncheckedKey) => { this._uncheckedKeys.forEach((uncheckedKey) => {
if (keyToTestName(uncheckedKey) === testName) if (keyToTestName(uncheckedKey) === testName) {
this._uncheckedKeys.delete(uncheckedKey) this._uncheckedKeys.delete(uncheckedKey)
}
}) })
} }
protected _inferInlineSnapshotStack(stacks: ParsedStack[]) { protected _inferInlineSnapshotStack(stacks: ParsedStack[]) {
// if called inside resolves/rejects, stacktrace is different // if called inside resolves/rejects, stacktrace is different
const promiseIndex = stacks.findIndex(i => i.method.match(/__VITEST_(RESOLVES|REJECTS)__/)) const promiseIndex = stacks.findIndex(i =>
if (promiseIndex !== -1) i.method.match(/__VITEST_(RESOLVES|REJECTS)__/),
)
if (promiseIndex !== -1) {
return stacks[promiseIndex + 3] return stacks[promiseIndex + 3]
}
// inline snapshot function is called __INLINE_SNAPSHOT__ // inline snapshot function is called __INLINE_SNAPSHOT__
// in integrations/snapshot/chai.ts // in integrations/snapshot/chai.ts
const stackIndex = stacks.findIndex(i => i.method.includes('__INLINE_SNAPSHOT__')) const stackIndex = stacks.findIndex(i =>
i.method.includes('__INLINE_SNAPSHOT__'),
)
return stackIndex !== -1 ? stacks[stackIndex + 2] : null return stackIndex !== -1 ? stacks[stackIndex + 2] : null
} }
@ -129,11 +140,16 @@ export default class SnapshotState {
): void { ): void {
this._dirty = true this._dirty = true
if (options.isInline) { if (options.isInline) {
const stacks = parseErrorStacktrace(options.error || new Error('snapshot'), { ignoreStackEntries: [] }) const stacks = parseErrorStacktrace(
options.error || new Error('snapshot'),
{ ignoreStackEntries: [] },
)
const stack = this._inferInlineSnapshotStack(stacks) const stack = this._inferInlineSnapshotStack(stacks)
if (!stack) { if (!stack) {
throw new Error( throw new Error(
`@vitest/snapshot: Couldn't infer stack frame for inline snapshot.\n${JSON.stringify(stacks)}`, `@vitest/snapshot: Couldn't infer stack frame for inline snapshot.\n${JSON.stringify(
stacks,
)}`,
) )
} }
// removing 1 column, because source map points to the wrong // removing 1 column, because source map points to the wrong
@ -171,7 +187,8 @@ export default class SnapshotState {
const hasExternalSnapshots = Object.keys(this._snapshotData).length const hasExternalSnapshots = Object.keys(this._snapshotData).length
const hasInlineSnapshots = this._inlineSnapshots.length const hasInlineSnapshots = this._inlineSnapshots.length
const hasRawSnapshots = this._rawSnapshots.length const hasRawSnapshots = this._rawSnapshots.length
const isEmpty = !hasExternalSnapshots && !hasInlineSnapshots && !hasRawSnapshots const isEmpty
= !hasExternalSnapshots && !hasInlineSnapshots && !hasRawSnapshots
const status: SaveStatus = { const status: SaveStatus = {
deleted: false, deleted: false,
@ -180,13 +197,19 @@ export default class SnapshotState {
if ((this._dirty || this._uncheckedKeys.size) && !isEmpty) { if ((this._dirty || this._uncheckedKeys.size) && !isEmpty) {
if (hasExternalSnapshots) { if (hasExternalSnapshots) {
await saveSnapshotFile(this._environment, this._snapshotData, this.snapshotPath) await saveSnapshotFile(
this._environment,
this._snapshotData,
this.snapshotPath,
)
this._fileExists = true this._fileExists = true
} }
if (hasInlineSnapshots) if (hasInlineSnapshots) {
await saveInlineSnapshots(this._environment, this._inlineSnapshots) await saveInlineSnapshots(this._environment, this._inlineSnapshots)
if (hasRawSnapshots) }
if (hasRawSnapshots) {
await saveRawSnapshots(this._environment, this._rawSnapshots) await saveRawSnapshots(this._environment, this._rawSnapshots)
}
status.saved = true status.saved = true
} }
@ -230,26 +253,35 @@ export default class SnapshotState {
this._counters.set(testName, (this._counters.get(testName) || 0) + 1) this._counters.set(testName, (this._counters.get(testName) || 0) + 1)
const count = Number(this._counters.get(testName)) const count = Number(this._counters.get(testName))
if (!key) if (!key) {
key = testNameToKey(testName, count) key = testNameToKey(testName, count)
}
// Do not mark the snapshot as "checked" if the snapshot is inline and // Do not mark the snapshot as "checked" if the snapshot is inline and
// there's an external snapshot. This way the external snapshot can be // there's an external snapshot. This way the external snapshot can be
// removed with `--updateSnapshot`. // removed with `--updateSnapshot`.
if (!(isInline && this._snapshotData[key] !== undefined)) if (!(isInline && this._snapshotData[key] !== undefined)) {
this._uncheckedKeys.delete(key) this._uncheckedKeys.delete(key)
}
let receivedSerialized = (rawSnapshot && typeof received === 'string') let receivedSerialized
? received as string = rawSnapshot && typeof received === 'string'
: serialize(received, undefined, this._snapshotFormat) ? (received as string)
: serialize(received, undefined, this._snapshotFormat)
if (!rawSnapshot) if (!rawSnapshot) {
receivedSerialized = addExtraLineBreaks(receivedSerialized) receivedSerialized = addExtraLineBreaks(receivedSerialized)
}
if (rawSnapshot) { if (rawSnapshot) {
// normalize EOL when snapshot contains CRLF but received is LF // normalize EOL when snapshot contains CRLF but received is LF
if (rawSnapshot.content && rawSnapshot.content.match(/\r\n/) && !receivedSerialized.match(/\r\n/)) if (
rawSnapshot.content
&& rawSnapshot.content.match(/\r\n/)
&& !receivedSerialized.match(/\r\n/)
) {
rawSnapshot.content = normalizeNewlines(rawSnapshot.content) rawSnapshot.content = normalizeNewlines(rawSnapshot.content)
}
} }
const expected = isInline const expected = isInline
@ -260,7 +292,10 @@ export default class SnapshotState {
const expectedTrimmed = prepareExpected(expected) const expectedTrimmed = prepareExpected(expected)
const pass = expectedTrimmed === prepareExpected(receivedSerialized) const pass = expectedTrimmed === prepareExpected(receivedSerialized)
const hasSnapshot = expected !== undefined const hasSnapshot = expected !== undefined
const snapshotIsPersisted = isInline || this._fileExists || (rawSnapshot && rawSnapshot.content != null) const snapshotIsPersisted
= isInline
|| this._fileExists
|| (rawSnapshot && rawSnapshot.content != null)
if (pass && !isInline && !rawSnapshot) { if (pass && !isInline && !rawSnapshot) {
// Executing a snapshot file as JavaScript and writing the strings back // Executing a snapshot file as JavaScript and writing the strings back
@ -286,19 +321,29 @@ export default class SnapshotState {
) { ) {
if (this._updateSnapshot === 'all') { if (this._updateSnapshot === 'all') {
if (!pass) { if (!pass) {
if (hasSnapshot) if (hasSnapshot) {
this.updated++ this.updated++
else }
else {
this.added++ this.added++
}
this._addSnapshot(key, receivedSerialized, { error, isInline, rawSnapshot }) this._addSnapshot(key, receivedSerialized, {
error,
isInline,
rawSnapshot,
})
} }
else { else {
this.matched++ this.matched++
} }
} }
else { else {
this._addSnapshot(key, receivedSerialized, { error, isInline, rawSnapshot }) this._addSnapshot(key, receivedSerialized, {
error,
isInline,
rawSnapshot,
})
this.added++ this.added++
} }
@ -317,9 +362,9 @@ export default class SnapshotState {
actual: removeExtraLineBreaks(receivedSerialized), actual: removeExtraLineBreaks(receivedSerialized),
count, count,
expected: expected:
expectedTrimmed !== undefined expectedTrimmed !== undefined
? removeExtraLineBreaks(expectedTrimmed) ? removeExtraLineBreaks(expectedTrimmed)
: undefined, : undefined,
key, key,
pass: false, pass: false,
} }
@ -350,8 +395,9 @@ export default class SnapshotState {
} }
const uncheckedCount = this.getUncheckedCount() const uncheckedCount = this.getUncheckedCount()
const uncheckedKeys = this.getUncheckedKeys() const uncheckedKeys = this.getUncheckedKeys()
if (uncheckedCount) if (uncheckedCount) {
this.removeUncheckedKeys() this.removeUncheckedKeys()
}
const status = await this.save() const status = await this.save()
snapshot.fileDeleted = status.deleted snapshot.fileDeleted = status.deleted

View File

@ -7,9 +7,7 @@
import naturalCompare from 'natural-compare' import naturalCompare from 'natural-compare'
import type { OptionsReceived as PrettyFormatOptions } from 'pretty-format' import type { OptionsReceived as PrettyFormatOptions } from 'pretty-format'
import { import { format as prettyFormat } from 'pretty-format'
format as prettyFormat,
} from 'pretty-format'
import { isObject } from '../../../utils/src/index' import { isObject } from '../../../utils/src/index'
import type { SnapshotData, SnapshotStateOptions } from '../types' import type { SnapshotData, SnapshotStateOptions } from '../types'
import type { SnapshotEnvironment } from '../types/environment' import type { SnapshotEnvironment } from '../types/environment'
@ -22,16 +20,20 @@ export function testNameToKey(testName: string, count: number): string {
} }
export function keyToTestName(key: string): string { export function keyToTestName(key: string): string {
if (!/ \d+$/.test(key)) if (!/ \d+$/.test(key)) {
throw new Error('Snapshot keys must end with a number.') throw new Error('Snapshot keys must end with a number.')
}
return key.replace(/ \d+$/, '') return key.replace(/ \d+$/, '')
} }
export function getSnapshotData(content: string | null, options: SnapshotStateOptions): { export function getSnapshotData(
data: SnapshotData content: string | null,
dirty: boolean options: SnapshotStateOptions,
} { ): {
data: SnapshotData
dirty: boolean
} {
const update = options.updateSnapshot const update = options.updateSnapshot
const data = Object.create(null) const data = Object.create(null)
let snapshotContents = '' let snapshotContents = ''
@ -53,8 +55,9 @@ export function getSnapshotData(content: string | null, options: SnapshotStateOp
// if (update === 'none' && isInvalid) // if (update === 'none' && isInvalid)
// throw validationResult // throw validationResult
if ((update === 'all' || update === 'new') && isInvalid) if ((update === 'all' || update === 'new') && isInvalid) {
dirty = true dirty = true
}
return { data, dirty } return { data, dirty }
} }
@ -69,7 +72,7 @@ export function addExtraLineBreaks(string: string): string {
// Instead of trim, which can remove additional newlines or spaces // Instead of trim, which can remove additional newlines or spaces
// at beginning or end of the content from a custom serializer. // at beginning or end of the content from a custom serializer.
export function removeExtraLineBreaks(string: string): string { export function removeExtraLineBreaks(string: string): string {
return (string.length > 2 && string.startsWith('\n') && string.endsWith('\n')) return string.length > 2 && string.startsWith('\n') && string.endsWith('\n')
? string.slice(1, -1) ? string.slice(1, -1)
: string : string
} }
@ -90,7 +93,11 @@ export function removeExtraLineBreaks(string: string): string {
const escapeRegex = true const escapeRegex = true
const printFunctionName = false const printFunctionName = false
export function serialize(val: unknown, indent = 2, formatOverrides: PrettyFormatOptions = {}): string { export function serialize(
val: unknown,
indent = 2,
formatOverrides: PrettyFormatOptions = {},
): string {
return normalizeNewlines( return normalizeNewlines(
prettyFormat(val, { prettyFormat(val, {
escapeRegex, escapeRegex,
@ -136,20 +143,21 @@ export async function saveSnapshotFile(
const snapshots = Object.keys(snapshotData) const snapshots = Object.keys(snapshotData)
.sort(naturalCompare) .sort(naturalCompare)
.map( .map(
key => `exports[${printBacktickString(key)}] = ${printBacktickString(normalizeNewlines(snapshotData[key]))};`, key =>
`exports[${printBacktickString(key)}] = ${printBacktickString(
normalizeNewlines(snapshotData[key]),
)};`,
) )
const content = `${environment.getHeader()}\n\n${snapshots.join('\n\n')}\n` const content = `${environment.getHeader()}\n\n${snapshots.join('\n\n')}\n`
const oldContent = await environment.readSnapshotFile(snapshotPath) const oldContent = await environment.readSnapshotFile(snapshotPath)
const skipWriting = oldContent != null && oldContent === content const skipWriting = oldContent != null && oldContent === content
if (skipWriting) if (skipWriting) {
return return
}
await environment.saveSnapshotFile( await environment.saveSnapshotFile(snapshotPath, content)
snapshotPath,
content,
)
} }
export async function saveSnapshotFileRaw( export async function saveSnapshotFileRaw(
@ -160,13 +168,11 @@ export async function saveSnapshotFileRaw(
const oldContent = await environment.readSnapshotFile(snapshotPath) const oldContent = await environment.readSnapshotFile(snapshotPath)
const skipWriting = oldContent != null && oldContent === content const skipWriting = oldContent != null && oldContent === content
if (skipWriting) if (skipWriting) {
return return
}
await environment.saveSnapshotFile( await environment.saveSnapshotFile(snapshotPath, content)
snapshotPath,
content,
)
} }
export function prepareExpected(expected?: string) { export function prepareExpected(expected?: string) {
@ -176,8 +182,9 @@ export function prepareExpected(expected?: string) {
const matchObject = /^( +)\}\s+$/m.exec(expected || '') const matchObject = /^( +)\}\s+$/m.exec(expected || '')
const objectIndent = matchObject?.[1]?.length const objectIndent = matchObject?.[1]?.length
if (objectIndent) if (objectIndent) {
return objectIndent return objectIndent
}
// Attempts to find indentation for texts. // Attempts to find indentation for texts.
// Matches the quote of first line. // Matches the quote of first line.
@ -191,7 +198,8 @@ export function prepareExpected(expected?: string) {
if (startIndent) { if (startIndent) {
expectedTrimmed = expectedTrimmed expectedTrimmed = expectedTrimmed
?.replace(new RegExp(`^${' '.repeat(startIndent)}`, 'gm'), '').replace(/ +\}$/, '}') ?.replace(new RegExp(`^${' '.repeat(startIndent)}`, 'gm'), '')
.replace(/ +\}$/, '}')
} }
return expectedTrimmed return expectedTrimmed
@ -235,9 +243,12 @@ export function deepMergeSnapshot(target: any, source: any): any {
const mergedOutput = { ...target } const mergedOutput = { ...target }
Object.keys(source).forEach((key) => { Object.keys(source).forEach((key) => {
if (isObject(source[key]) && !source[key].$$typeof) { if (isObject(source[key]) && !source[key].$$typeof) {
if (!(key in target)) if (!(key in target)) {
Object.assign(mergedOutput, { [key]: source[key] }) Object.assign(mergedOutput, { [key]: source[key] })
else mergedOutput[key] = deepMergeSnapshot(target[key], source[key]) }
else {
mergedOutput[key] = deepMergeSnapshot(target[key], source[key])
}
} }
else if (Array.isArray(source[key])) { else if (Array.isArray(source[key])) {
mergedOutput[key] = deepMergeArray(target[key], source[key]) mergedOutput[key] = deepMergeArray(target[key], source[key])

View File

@ -1,6 +1,12 @@
import type { OptionsReceived as PrettyFormatOptions, Plugin as PrettyFormatPlugin } from 'pretty-format' import type {
OptionsReceived as PrettyFormatOptions,
Plugin as PrettyFormatPlugin,
} from 'pretty-format'
import type { RawSnapshotInfo } from '../port/rawSnapshot' import type { RawSnapshotInfo } from '../port/rawSnapshot'
import type { SnapshotEnvironment, SnapshotEnvironmentOptions } from './environment' import type {
SnapshotEnvironment,
SnapshotEnvironmentOptions,
} from './environment'
export type { SnapshotEnvironment, SnapshotEnvironmentOptions } export type { SnapshotEnvironment, SnapshotEnvironmentOptions }
export type SnapshotData = Record<string, string> export type SnapshotData = Record<string, string>

View File

@ -39,15 +39,14 @@ export default defineConfig([
format: 'esm', format: 'esm',
}, },
external, external,
plugins: [ plugins: [dts({ respectExternal: true })],
dts({ respectExternal: true }),
],
onwarn, onwarn,
}, },
]) ])
function onwarn(message) { function onwarn(message) {
if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) if (['EMPTY_BUNDLE', 'CIRCULAR_DEPENDENCY'].includes(message.code)) {
return return
}
console.error(message) console.error(message)
} }

View File

@ -30,8 +30,13 @@ interface MockSettledResultRejected {
value: any value: any
} }
export type MockResult<T> = MockResultReturn<T> | MockResultThrow | MockResultIncomplete export type MockResult<T> =
export type MockSettledResult<T> = MockSettledResultFulfilled<T> | MockSettledResultRejected | MockResultReturn<T>
| MockResultThrow
| MockResultIncomplete
export type MockSettledResult<T> =
| MockSettledResultFulfilled<T>
| MockSettledResultRejected
export interface MockContext<TArgs, TReturns> { export interface MockContext<TArgs, TReturns> {
/** /**
@ -137,16 +142,19 @@ type Methods<T> = keyof {
[K in keyof T as T[K] extends Procedure ? K : never]: T[K]; [K in keyof T as T[K] extends Procedure ? K : never]: T[K];
} }
type Properties<T> = { type Properties<T> = {
[K in keyof T]: T[K] extends Procedure ? never : K [K in keyof T]: T[K] extends Procedure ? never : K;
}[keyof T] & (string | symbol) }[keyof T] &
(string | symbol)
type Classes<T> = { type Classes<T> = {
[K in keyof T]: T[K] extends new (...args: any[]) => any ? K : never [K in keyof T]: T[K] extends new (...args: any[]) => any ? K : never;
}[keyof T] & (string | symbol) }[keyof T] &
(string | symbol)
/** /**
* @deprecated Use MockInstance<A, R> instead * @deprecated Use MockInstance<A, R> instead
*/ */
export interface SpyInstance<TArgs extends any[] = any[], TReturns = any> extends MockInstance<TArgs, TReturns> {} export interface SpyInstance<TArgs extends any[] = any[], TReturns = any>
extends MockInstance<TArgs, TReturns> {}
export interface MockInstance<TArgs extends any[] = any[], TReturns = any> { export interface MockInstance<TArgs extends any[] = any[], TReturns = any> {
/** /**
@ -193,7 +201,7 @@ export interface MockInstance<TArgs extends any[] = any[], TReturns = any> {
* const increment = vi.fn().mockImplementation(count => count + 1); * const increment = vi.fn().mockImplementation(count => count + 1);
* expect(increment(3)).toBe(4); * expect(increment(3)).toBe(4);
*/ */
mockImplementation: (fn: ((...args: TArgs) => TReturns)) => this mockImplementation: (fn: (...args: TArgs) => TReturns) => this
/** /**
* Accepts a function that will be used as a mock implementation during the next call. Can be chained so that multiple function calls produce different results. * Accepts a function that will be used as a mock implementation during the next call. Can be chained so that multiple function calls produce different results.
* @example * @example
@ -201,7 +209,7 @@ export interface MockInstance<TArgs extends any[] = any[], TReturns = any> {
* expect(fn(3)).toBe(4); * expect(fn(3)).toBe(4);
* expect(fn(3)).toBe(3); * expect(fn(3)).toBe(3);
*/ */
mockImplementationOnce: (fn: ((...args: TArgs) => TReturns)) => this mockImplementationOnce: (fn: (...args: TArgs) => TReturns) => this
/** /**
* Overrides the original mock implementation temporarily while the callback is being executed. * Overrides the original mock implementation temporarily while the callback is being executed.
* @example * @example
@ -213,7 +221,10 @@ export interface MockInstance<TArgs extends any[] = any[], TReturns = any> {
* *
* myMockFn() // 'original' * myMockFn() // 'original'
*/ */
withImplementation: <T>(fn: ((...args: TArgs) => TReturns), cb: () => T) => T extends Promise<unknown> ? Promise<this> : this withImplementation: <T>(
fn: (...args: TArgs) => TReturns,
cb: () => T
) => T extends Promise<unknown> ? Promise<this> : this
/** /**
* Use this if you need to return `this` context from the method without invoking actual implementation. * Use this if you need to return `this` context from the method without invoking actual implementation.
*/ */
@ -278,11 +289,18 @@ export interface MockInstance<TArgs extends any[] = any[], TReturns = any> {
mockRejectedValueOnce: (obj: any) => this mockRejectedValueOnce: (obj: any) => this
} }
export interface Mock<TArgs extends any[] = any, TReturns = any> extends MockInstance<TArgs, TReturns> { export interface Mock<TArgs extends any[] = any, TReturns = any>
extends MockInstance<TArgs, TReturns> {
new (...args: TArgs): TReturns new (...args: TArgs): TReturns
(...args: TArgs): TReturns (...args: TArgs): TReturns
} }
export interface PartialMock<TArgs extends any[] = any, TReturns = any> extends MockInstance<TArgs, TReturns extends Promise<Awaited<TReturns>> ? Promise<Partial<Awaited<TReturns>>> : Partial<TReturns>> { export interface PartialMock<TArgs extends any[] = any, TReturns = any>
extends MockInstance<
TArgs,
TReturns extends Promise<Awaited<TReturns>>
? Promise<Partial<Awaited<TReturns>>>
: Partial<TReturns>
> {
new (...args: TArgs): TReturns new (...args: TArgs): TReturns
(...args: TArgs): TReturns (...args: TArgs): TReturns
} }
@ -292,23 +310,33 @@ export type MaybeMockedConstructor<T> = T extends new (
) => infer R ) => infer R
? Mock<ConstructorParameters<T>, R> ? Mock<ConstructorParameters<T>, R>
: T : T
export type MockedFunction<T extends Procedure> = Mock<Parameters<T>, ReturnType<T>> & { export type MockedFunction<T extends Procedure> = Mock<
Parameters<T>,
ReturnType<T>
> & {
[K in keyof T]: T[K]; [K in keyof T]: T[K];
} }
export type PartiallyMockedFunction<T extends Procedure> = PartialMock<Parameters<T>, ReturnType<T>> & { export type PartiallyMockedFunction<T extends Procedure> = PartialMock<
Parameters<T>,
ReturnType<T>
> & {
[K in keyof T]: T[K]; [K in keyof T]: T[K];
} }
export type MockedFunctionDeep<T extends Procedure> = Mock<Parameters<T>, ReturnType<T>> & MockedObjectDeep<T> export type MockedFunctionDeep<T extends Procedure> = Mock<
export type PartiallyMockedFunctionDeep<T extends Procedure> = PartialMock<Parameters<T>, ReturnType<T>> & MockedObjectDeep<T> Parameters<T>,
ReturnType<T>
> &
MockedObjectDeep<T>
export type PartiallyMockedFunctionDeep<T extends Procedure> = PartialMock<
Parameters<T>,
ReturnType<T>
> &
MockedObjectDeep<T>
export type MockedObject<T> = MaybeMockedConstructor<T> & { export type MockedObject<T> = MaybeMockedConstructor<T> & {
[K in Methods<T>]: T[K] extends Procedure [K in Methods<T>]: T[K] extends Procedure ? MockedFunction<T[K]> : T[K];
? MockedFunction<T[K]>
: T[K];
} & { [K in Properties<T>]: T[K] } } & { [K in Properties<T>]: T[K] }
export type MockedObjectDeep<T> = MaybeMockedConstructor<T> & { export type MockedObjectDeep<T> = MaybeMockedConstructor<T> & {
[K in Methods<T>]: T[K] extends Procedure [K in Methods<T>]: T[K] extends Procedure ? MockedFunctionDeep<T[K]> : T[K];
? MockedFunctionDeep<T[K]>
: T[K];
} & { [K in Properties<T>]: MaybeMockedDeep<T[K]> } } & { [K in Properties<T>]: MaybeMockedDeep<T[K]> }
export type MaybeMockedDeep<T> = T extends Procedure export type MaybeMockedDeep<T> = T extends Procedure
@ -340,8 +368,8 @@ interface Constructable {
} }
export type MockedClass<T extends Constructable> = MockInstance< export type MockedClass<T extends Constructable> = MockInstance<
T extends new (...args: infer P) => any ? P : never, T extends new (...args: infer P) => any ? P : never,
InstanceType<T> InstanceType<T>
> & { > & {
prototype: T extends { prototype: any } ? Mocked<T['prototype']> : never prototype: T extends { prototype: any } ? Mocked<T['prototype']> : never
} & T } & T
@ -351,32 +379,35 @@ export type Mocked<T> = {
? MockInstance<Args, Returns> ? MockInstance<Args, Returns>
: T[P] extends Constructable : T[P] extends Constructable
? MockedClass<T[P]> ? MockedClass<T[P]>
: T[P] : T[P];
} & } & T
T
export const mocks = new Set<MockInstance>() export const mocks = new Set<MockInstance>()
export function isMockFunction(fn: any): fn is MockInstance { export function isMockFunction(fn: any): fn is MockInstance {
return typeof fn === 'function' return (
&& '_isMockFunction' in fn typeof fn === 'function' && '_isMockFunction' in fn && fn._isMockFunction
&& fn._isMockFunction )
} }
export function spyOn<T, S extends Properties<Required<T>>>( export function spyOn<T, S extends Properties<Required<T>>>(
obj: T, obj: T,
methodName: S, methodName: S,
accessType: 'get', accessType: 'get'
): MockInstance<[], T[S]> ): MockInstance<[], T[S]>
export function spyOn<T, G extends Properties<Required<T>>>( export function spyOn<T, G extends Properties<Required<T>>>(
obj: T, obj: T,
methodName: G, methodName: G,
accessType: 'set', accessType: 'set'
): MockInstance<[T[G]], void> ): MockInstance<[T[G]], void>
export function spyOn<T, M extends (Classes<Required<T>> | Methods<Required<T>>)>( export function spyOn<T, M extends Classes<Required<T>> | Methods<Required<T>>>(
obj: T, obj: T,
methodName: M, methodName: M
): Required<T>[M] extends ({ new (...args: infer A): infer R }) | ((...args: infer A) => infer R) ? MockInstance<A, R> : never ): Required<T>[M] extends
| { new (...args: infer A): infer R }
| ((...args: infer A) => infer R)
? MockInstance<A, R>
: never
export function spyOn<T, K extends keyof T>( export function spyOn<T, K extends keyof T>(
obj: T, obj: T,
method: K, method: K,
@ -419,13 +450,15 @@ function enhanceSpy<TArgs extends any[], TReturns>(
}, },
get results() { get results() {
return state.results.map(([callType, value]) => { return state.results.map(([callType, value]) => {
const type = callType === 'error' ? 'throw' as const : 'return' as const const type
= callType === 'error' ? ('throw' as const) : ('return' as const)
return { type, value } return { type, value }
}) })
}, },
get settledResults() { get settledResults() {
return state.resolves.map(([callType, value]) => { return state.resolves.map(([callType, value]) => {
const type = callType === 'error' ? 'rejected' as const : 'fulfilled' as const const type
= callType === 'error' ? ('rejected' as const) : ('fulfilled' as const)
return { type, value } return { type, value }
}) })
}, },
@ -440,7 +473,12 @@ function enhanceSpy<TArgs extends any[], TReturns>(
function mockCall(this: unknown, ...args: any) { function mockCall(this: unknown, ...args: any) {
instances.push(this) instances.push(this)
invocations.push(++callOrder) invocations.push(++callOrder)
const impl = implementationChangedTemporarily ? implementation! : (onceImplementations.shift() || implementation || state.getOriginal() || (() => {})) const impl = implementationChangedTemporarily
? implementation!
: onceImplementations.shift()
|| implementation
|| state.getOriginal()
|| (() => {})
return impl.apply(this, args) return impl.apply(this, args)
} }
@ -485,9 +523,18 @@ function enhanceSpy<TArgs extends any[], TReturns>(
return stub return stub
} }
function withImplementation(fn: (...args: TArgs) => TReturns, cb: () => void): MockInstance<TArgs, TReturns> function withImplementation(
function withImplementation(fn: (...args: TArgs) => TReturns, cb: () => Promise<void>): Promise<MockInstance<TArgs, TReturns>> fn: (...args: TArgs) => TReturns,
function withImplementation(fn: (...args: TArgs) => TReturns, cb: () => void | Promise<void>): MockInstance<TArgs, TReturns> | Promise<MockInstance<TArgs, TReturns>> { cb: () => void
): MockInstance<TArgs, TReturns>
function withImplementation(
fn: (...args: TArgs) => TReturns,
cb: () => Promise<void>
): Promise<MockInstance<TArgs, TReturns>>
function withImplementation(
fn: (...args: TArgs) => TReturns,
cb: () => void | Promise<void>,
): MockInstance<TArgs, TReturns> | Promise<MockInstance<TArgs, TReturns>> {
const originalImplementation = implementation const originalImplementation = implementation
implementation = fn implementation = fn
@ -521,7 +568,8 @@ function enhanceSpy<TArgs extends any[], TReturns>(
}) })
stub.mockReturnValue = (val: TReturns) => stub.mockImplementation(() => val) stub.mockReturnValue = (val: TReturns) => stub.mockImplementation(() => val)
stub.mockReturnValueOnce = (val: TReturns) => stub.mockImplementationOnce(() => val) stub.mockReturnValueOnce = (val: TReturns) =>
stub.mockImplementationOnce(() => val)
stub.mockResolvedValue = (val: Awaited<TReturns>) => stub.mockResolvedValue = (val: Awaited<TReturns>) =>
stub.mockImplementation(() => Promise.resolve(val as TReturns) as any) stub.mockImplementation(() => Promise.resolve(val as TReturns) as any)
@ -553,9 +601,12 @@ export function fn<TArgs extends any[] = any[], R = any>(
export function fn<TArgs extends any[] = any[], R = any>( export function fn<TArgs extends any[] = any[], R = any>(
implementation?: (...args: TArgs) => R, implementation?: (...args: TArgs) => R,
): Mock<TArgs, R> { ): Mock<TArgs, R> {
const enhancedSpy = enhanceSpy(tinyspy.internalSpyOn({ spy: implementation || (() => {}) }, 'spy')) const enhancedSpy = enhanceSpy(
if (implementation) tinyspy.internalSpyOn({ spy: implementation || (() => {}) }, 'spy'),
)
if (implementation) {
enhancedSpy.mockImplementation(implementation) enhancedSpy.mockImplementation(implementation)
}
return enhancedSpy as Mock return enhancedSpy as Mock
} }

View File

@ -99,7 +99,7 @@ declare global {
const pausableWatch: typeof import('@vueuse/core')['pausableWatch'] const pausableWatch: typeof import('@vueuse/core')['pausableWatch']
const provide: typeof import('vue')['provide'] const provide: typeof import('vue')['provide']
const provideLocal: typeof import('@vueuse/core')['provideLocal'] const provideLocal: typeof import('@vueuse/core')['provideLocal']
const provideResizing: typeof import('./composables/browser')['provideResizing'] const provideResizing: typeof import("./composables/browser")["provideResizing"]
const reactify: typeof import('@vueuse/core')['reactify'] const reactify: typeof import('@vueuse/core')['reactify']
const reactifyObject: typeof import('@vueuse/core')['reactifyObject'] const reactifyObject: typeof import('@vueuse/core')['reactifyObject']
const reactive: typeof import('vue')['reactive'] const reactive: typeof import('vue')['reactive']
@ -107,14 +107,14 @@ declare global {
const reactiveOmit: typeof import('@vueuse/core')['reactiveOmit'] const reactiveOmit: typeof import('@vueuse/core')['reactiveOmit']
const reactivePick: typeof import('@vueuse/core')['reactivePick'] const reactivePick: typeof import('@vueuse/core')['reactivePick']
const readonly: typeof import('vue')['readonly'] const readonly: typeof import('vue')['readonly']
const recalculateDetailPanels: typeof import('./composables/browser')['recalculateDetailPanels'] const recalculateDetailPanels: typeof import("./composables/browser")["recalculateDetailPanels"]
const ref: typeof import('vue')['ref'] const ref: typeof import('vue')['ref']
const refAutoReset: typeof import('@vueuse/core')['refAutoReset'] const refAutoReset: typeof import('@vueuse/core')['refAutoReset']
const refDebounced: typeof import('@vueuse/core')['refDebounced'] const refDebounced: typeof import('@vueuse/core')['refDebounced']
const refDefault: typeof import('@vueuse/core')['refDefault'] const refDefault: typeof import('@vueuse/core')['refDefault']
const refThrottled: typeof import('@vueuse/core')['refThrottled'] const refThrottled: typeof import('@vueuse/core')['refThrottled']
const refWithControl: typeof import('@vueuse/core')['refWithControl'] const refWithControl: typeof import('@vueuse/core')['refWithControl']
const registerResizingListener: typeof import('./composables/browser')['registerResizingListener'] const registerResizingListener: typeof import("./composables/browser")["registerResizingListener"]
const resolveComponent: typeof import('vue')['resolveComponent'] const resolveComponent: typeof import('vue')['resolveComponent']
const resolveRef: typeof import('@vueuse/core')['resolveRef'] const resolveRef: typeof import('@vueuse/core')['resolveRef']
const resolveUnref: typeof import('@vueuse/core')['resolveUnref'] const resolveUnref: typeof import('@vueuse/core')['resolveUnref']
@ -150,7 +150,7 @@ declare global {
const tryOnMounted: typeof import('@vueuse/core')['tryOnMounted'] const tryOnMounted: typeof import('@vueuse/core')['tryOnMounted']
const tryOnScopeDispose: typeof import('@vueuse/core')['tryOnScopeDispose'] const tryOnScopeDispose: typeof import('@vueuse/core')['tryOnScopeDispose']
const tryOnUnmounted: typeof import('@vueuse/core')['tryOnUnmounted'] const tryOnUnmounted: typeof import('@vueuse/core')['tryOnUnmounted']
const unifiedDiff: typeof import('./composables/diff')['unifiedDiff'] const unifiedDiff: typeof import("./composables/diff")["unifiedDiff"]
const unref: typeof import('vue')['unref'] const unref: typeof import('vue')['unref']
const unrefElement: typeof import('@vueuse/core')['unrefElement'] const unrefElement: typeof import('@vueuse/core')['unrefElement']
const until: typeof import('@vueuse/core')['until'] const until: typeof import('@vueuse/core')['until']
@ -245,7 +245,7 @@ declare global {
const useMutationObserver: typeof import('@vueuse/core')['useMutationObserver'] const useMutationObserver: typeof import('@vueuse/core')['useMutationObserver']
const useNavigatorLanguage: typeof import('@vueuse/core')['useNavigatorLanguage'] const useNavigatorLanguage: typeof import('@vueuse/core')['useNavigatorLanguage']
const useNetwork: typeof import('@vueuse/core')['useNetwork'] const useNetwork: typeof import('@vueuse/core')['useNetwork']
const useNotifyResizing: typeof import('./composables/browser')['useNotifyResizing'] const useNotifyResizing: typeof import("./composables/browser")["useNotifyResizing"]
const useNow: typeof import('@vueuse/core')['useNow'] const useNow: typeof import('@vueuse/core')['useNow']
const useObjectUrl: typeof import('@vueuse/core')['useObjectUrl'] const useObjectUrl: typeof import('@vueuse/core')['useObjectUrl']
const useOffsetPagination: typeof import('@vueuse/core')['useOffsetPagination'] const useOffsetPagination: typeof import('@vueuse/core')['useOffsetPagination']
@ -267,7 +267,7 @@ declare global {
const useRafFn: typeof import('@vueuse/core')['useRafFn'] const useRafFn: typeof import('@vueuse/core')['useRafFn']
const useRefHistory: typeof import('@vueuse/core')['useRefHistory'] const useRefHistory: typeof import('@vueuse/core')['useRefHistory']
const useResizeObserver: typeof import('@vueuse/core')['useResizeObserver'] const useResizeObserver: typeof import('@vueuse/core')['useResizeObserver']
const useResizing: typeof import('./composables/browser')['useResizing'] const useResizing: typeof import("./composables/browser")["useResizing"]
const useRoute: typeof import('vue-router')['useRoute'] const useRoute: typeof import('vue-router')['useRoute']
const useRouter: typeof import('vue-router')['useRouter'] const useRouter: typeof import('vue-router')['useRouter']
const useScreenOrientation: typeof import('@vueuse/core')['useScreenOrientation'] const useScreenOrientation: typeof import('@vueuse/core')['useScreenOrientation']

View File

@ -1,65 +1,47 @@
<script setup lang="ts"> <script setup lang="ts">
import { viewport, customViewport } from '~/composables/browser' import { viewport, customViewport } from "~/composables/browser";
import type { ViewportSize } from '~/composables/browser' import type { ViewportSize } from "~/composables/browser";
import { setIframeViewport, getCurrentBrowserIframe } from '~/composables/api' import { setIframeViewport, getCurrentBrowserIframe } from "~/composables/api";
const sizes: Record<ViewportSize, [width: string, height: string] | null> = { const sizes: Record<ViewportSize, [width: string, height: string] | null> = {
'small-mobile': ['320px', '568px'], "small-mobile": ["320px", "568px"],
'large-mobile': ['414px', '896px'], "large-mobile": ["414px", "896px"],
tablet: ['834px', '1112px'], tablet: ["834px", "1112px"],
full: ['100%', '100%'], full: ["100%", "100%"],
// should not be used manually, this is just // should not be used manually, this is just
// a fallback for the case when the viewport is not set correctly // a fallback for the case when the viewport is not set correctly
custom: null, custom: null,
} };
async function changeViewport(name: ViewportSize) { async function changeViewport(name: ViewportSize) {
if (viewport.value === name) { if (viewport.value === name) {
viewport.value = customViewport.value ? 'custom' : 'full' viewport.value = customViewport.value ? "custom" : "full";
} else { } else {
viewport.value = name viewport.value = name;
} }
const iframe = getCurrentBrowserIframe() const iframe = getCurrentBrowserIframe();
if (!iframe) { if (!iframe) {
console.warn('Iframe not found') console.warn("Iframe not found");
return return;
} }
const [width, height] = sizes[viewport.value] || customViewport.value || sizes.full const [width, height] =
sizes[viewport.value] || customViewport.value || sizes.full;
await setIframeViewport(width, height) await setIframeViewport(width, height);
} }
</script> </script>
<template> <template>
<div h="full" flex="~ col"> <div h="full" flex="~ col">
<div <div p="3" h-10 flex="~ gap-2" items-center bg-header border="b base">
p="3"
h-10
flex="~ gap-2"
items-center
bg-header
border="b base"
>
<div class="i-carbon-content-delivery-network" /> <div class="i-carbon-content-delivery-network" />
<span <span pl-1 font-bold text-sm flex-auto ws-nowrap overflow-hidden truncate
pl-1 >Browser UI</span
font-bold >
text-sm
flex-auto
ws-nowrap
overflow-hidden
truncate
>Browser UI</span>
</div> </div>
<div <div p="l3 y2 r2" flex="~ gap-2" items-center bg-header border="b-2 base">
p="l3 y2 r2"
flex="~ gap-2"
items-center
bg-header
border="b-2 base"
>
<!-- TODO: these are only for preview (thank you Storybook!), we need to support more different and custom sizes (as a dropdown) --> <!-- TODO: these are only for preview (thank you Storybook!), we need to support more different and custom sizes (as a dropdown) -->
<IconButton <IconButton
v-tooltip.bottom="'Flexible'" v-tooltip.bottom="'Flexible'"
@ -91,7 +73,11 @@ async function changeViewport(name: ViewportSize) {
/> />
</div> </div>
<div flex-auto class="scrolls"> <div flex-auto class="scrolls">
<div id="tester-ui" class="flex h-full justify-center items-center font-light op70" style="overflow: auto; width: 100%; height: 100%"> <div
id="tester-ui"
class="flex h-full justify-center items-center font-light op70"
style="overflow: auto; width: 100%; height: 100%"
>
Select a test to run Select a test to run
</div> </div>
</div> </div>

View File

@ -1,66 +1,61 @@
<script setup lang="ts"> <script setup lang="ts">
import type CodeMirror from 'codemirror' import type CodeMirror from "codemirror";
const { mode, readOnly } = defineProps<{ const { mode, readOnly } = defineProps<{
mode?: string mode?: string;
readOnly?: boolean readOnly?: boolean;
}>() }>();
const emit = defineEmits<{ const emit = defineEmits<{
(event: 'save', content: string): void (event: "save", content: string): void;
}>() }>();
const modelValue = defineModel<string>() const modelValue = defineModel<string>();
const attrs = useAttrs() const attrs = useAttrs();
const modeMap: Record<string, any> = { const modeMap: Record<string, any> = {
// html: 'htmlmixed', // html: 'htmlmixed',
// vue: 'htmlmixed', // vue: 'htmlmixed',
// svelte: 'htmlmixed', // svelte: 'htmlmixed',
js: 'javascript', js: "javascript",
mjs: 'javascript', mjs: "javascript",
cjs: 'javascript', cjs: "javascript",
ts: { name: 'javascript', typescript: true }, ts: { name: "javascript", typescript: true },
mts: { name: 'javascript', typescript: true }, mts: { name: "javascript", typescript: true },
cts: { name: 'javascript', typescript: true }, cts: { name: "javascript", typescript: true },
jsx: { name: 'javascript', jsx: true }, jsx: { name: "javascript", jsx: true },
tsx: { name: 'javascript', typescript: true, jsx: true }, tsx: { name: "javascript", typescript: true, jsx: true },
} };
const el = ref<HTMLTextAreaElement>() const el = ref<HTMLTextAreaElement>();
const cm = shallowRef<CodeMirror.EditorFromTextArea>() const cm = shallowRef<CodeMirror.EditorFromTextArea>();
defineExpose({ cm }) defineExpose({ cm });
onMounted(async () => { onMounted(async () => {
cm.value = useCodeMirror(el, modelValue as unknown as Ref<string>, { cm.value = useCodeMirror(el, modelValue as unknown as Ref<string>, {
...attrs, ...attrs,
mode: modeMap[mode || ''] || mode, mode: modeMap[mode || ""] || mode,
readOnly: readOnly ? true : undefined, readOnly: readOnly ? true : undefined,
extraKeys: { extraKeys: {
'Cmd-S': function (cm) { "Cmd-S": function (cm) {
emit('save', cm.getValue()) emit("save", cm.getValue());
}, },
'Ctrl-S': function (cm) { "Ctrl-S": function (cm) {
emit('save', cm.getValue()) emit("save", cm.getValue());
}, },
}, },
}) });
cm.value.setSize('100%', '100%') cm.value.setSize("100%", "100%");
cm.value.clearHistory() cm.value.clearHistory();
setTimeout(() => cm.value!.refresh(), 100) setTimeout(() => cm.value!.refresh(), 100);
}) });
</script> </script>
<template> <template>
<div <div relative font-mono text-sm class="codemirror-scrolls">
relative
font-mono
text-sm
class="codemirror-scrolls"
>
<textarea ref="el" /> <textarea ref="el" />
</div> </div>
</template> </template>

View File

@ -1,11 +1,19 @@
<script setup lang="ts"> <script setup lang="ts">
import { client, isConnected, isConnecting, browserState } from '~/composables/client' import {
client,
isConnected,
isConnecting,
browserState,
} from "~/composables/client";
</script> </script>
<template> <template>
<template v-if="!isConnected"> <template v-if="!isConnected">
<div <div
fixed inset-0 p2 z-10 fixed
inset-0
p2
z-10
select-none select-none
text="center sm" text="center sm"
bg="overlay" bg="overlay"
@ -16,18 +24,27 @@ import { client, isConnected, isConnecting, browserState } from '~/composables/c
<div <div
h-full h-full
flex="~ col gap-2" flex="~ col gap-2"
items-center justify-center items-center
justify-center
:class="isConnecting ? 'animate-pulse' : ''" :class="isConnecting ? 'animate-pulse' : ''"
> >
<div <div
text="5xl" text="5xl"
:class="isConnecting ? 'i-carbon:renew animate-spin animate-reverse' : 'i-carbon-wifi-off'" :class="
isConnecting
? 'i-carbon:renew animate-spin animate-reverse'
: 'i-carbon-wifi-off'
"
/> />
<div text-2xl> <div text-2xl>
{{ isConnecting ? 'Connecting...' : 'Disconnected' }} {{ isConnecting ? "Connecting..." : "Disconnected" }}
</div> </div>
<div text-lg op50> <div text-lg op50>
Check your terminal or start a new server with `{{ browserState ? `vitest --browser=${browserState.config.browser.name}` : 'vitest --ui' }}` Check your terminal or start a new server with `{{
browserState
? `vitest --browser=${browserState.config.browser.name}`
: "vitest --ui"
}}`
</div> </div>
</div> </div>
</div> </div>

View File

@ -1,29 +1,16 @@
<script setup lang="ts"> <script setup lang="ts">
defineProps<{ defineProps<{
src: string src: string;
}>() }>();
</script> </script>
<template> <template>
<div h="full" flex="~ col"> <div h="full" flex="~ col">
<div <div p="3" h-10 flex="~ gap-2" items-center bg-header border="b base">
p="3"
h-10
flex="~ gap-2"
items-center
bg-header
border="b base"
>
<div class="i-carbon:folder-details-reference" /> <div class="i-carbon:folder-details-reference" />
<span <span pl-1 font-bold text-sm flex-auto ws-nowrap overflow-hidden truncate
pl-1 >Coverage</span
font-bold >
text-sm
flex-auto
ws-nowrap
overflow-hidden
truncate
>Coverage</span>
</div> </div>
<div flex-auto py-1 bg-white> <div flex-auto py-1 bg-white>
<iframe id="vitest-ui-coverage" :src="src" /> <iframe id="vitest-ui-coverage" :src="src" />

View File

@ -1,23 +1,10 @@
<template> <template>
<div h="full" flex="~ col"> <div h="full" flex="~ col">
<div <div p="3" h-10 flex="~ gap-2" items-center bg-header border="b base">
p="3"
h-10
flex="~ gap-2"
items-center
bg-header
border="b base"
>
<div class="i-carbon-dashboard" /> <div class="i-carbon-dashboard" />
<span <span pl-1 font-bold text-sm flex-auto ws-nowrap overflow-hidden truncate
pl-1 >Dashboard</span
font-bold >
text-sm
flex-auto
ws-nowrap
overflow-hidden
truncate
>Dashboard</span>
</div> </div>
<div class="scrolls" flex-auto py-1> <div class="scrolls" flex-auto py-1>
<TestsFilesContainer /> <TestsFilesContainer />

View File

@ -1,14 +1,32 @@
<script setup lang="ts"> <script setup lang="ts">
defineProps<{ defineProps<{
color?: string color?: string;
}>() }>();
const open = ref(true) const open = ref(true);
</script> </script>
<template> <template>
<div :open="open" class="details-panel" data-testid="details-panel" @toggle="open = $event.target.open"> <div
<div p="y1" text-sm bg-base items-center z-5 gap-2 :class="color" w-full flex select-none sticky top="-1"> :open="open"
class="details-panel"
data-testid="details-panel"
@toggle="open = $event.target.open"
>
<div
p="y1"
text-sm
bg-base
items-center
z-5
gap-2
:class="color"
w-full
flex
select-none
sticky
top="-1"
>
<div flex-1 h-1px border="base b" op80 /> <div flex-1 h-1px border="base b" op80 />
<slot name="summary" :open="open" /> <slot name="summary" :open="open" />
<div flex-1 h-1px border="base b" op80 /> <div flex-1 h-1px border="base b" op80 />

Some files were not shown because too many files have changed in this diff Show More