diff --git a/crates/napi/src/bindgen_runtime/js_values/task.rs b/crates/napi/src/bindgen_runtime/js_values/task.rs index 9b7df92c..3c37ae7d 100644 --- a/crates/napi/src/bindgen_runtime/js_values/task.rs +++ b/crates/napi/src/bindgen_runtime/js_values/task.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::ffi::c_void; use std::marker::PhantomData; use std::ptr; @@ -50,10 +51,19 @@ impl ScopedTask<'task>> AsyncTask { } } +type AbortCallback = Rc>>>; + /// pub struct AbortSignal { raw_work: Rc>, status: Rc>, + abort: AbortCallback, +} + +impl AbortSignal { + pub fn on_abort(&self, cb: F) { + self.abort.borrow_mut().push(Box::new(cb)); + } } impl UnwindSafe for AbortSignal {} @@ -74,9 +84,11 @@ impl FromNapiValue for AbortSignal { ); let async_work_inner: Rc> = Rc::new(Cell::new(ptr::null_mut())); let task_status = Rc::new(Cell::new(0)); + let abort_cbs = Rc::new(RefCell::new(vec![])); let abort_signal = AbortSignal { raw_work: async_work_inner.clone(), status: task_status.clone(), + abort: abort_cbs.clone(), }; let js_env = Env::from_raw(env); @@ -111,6 +123,7 @@ impl FromNapiValue for AbortSignal { Ok(AbortSignal { raw_work: async_work_inner, status: task_status, + abort: abort_cbs, }) } } @@ -153,6 +166,11 @@ fn on_abort_impl( )?; let abort_controller_stack = Box::leak(Box::from_raw(async_task as *mut AbortSignalStack)); for abort_controller in abort_controller_stack.0.iter() { + // call abort callback + for cb in abort_controller.abort.borrow().iter() { + cb(); + } + // Task Completed, return now if abort_controller.status.get() == 1 { return Ok(ptr::null_mut()); diff --git a/examples/napi/__tests__/__snapshots__/values.spec.ts.md b/examples/napi/__tests__/__snapshots__/values.spec.ts.md index e31773b3..444d3160 100644 --- a/examples/napi/__tests__/__snapshots__/values.spec.ts.md +++ b/examples/napi/__tests__/__snapshots__/values.spec.ts.md @@ -1234,6 +1234,8 @@ Generated by [AVA](https://avajs.dev). ␊ export declare function withAbortController(a: number, b: number, signal: AbortSignal): Promise␊ ␊ + export declare function withAbortSignalHandle(signal: AbortSignal): Promise␊ + ␊ export declare function withinAsyncRuntimeIfAvailable(): void␊ ␊ export declare function withoutAbortController(a: number, b: number): Promise␊ diff --git a/examples/napi/__tests__/__snapshots__/values.spec.ts.snap b/examples/napi/__tests__/__snapshots__/values.spec.ts.snap index 3d124f6d..3b4e8373 100644 Binary files a/examples/napi/__tests__/__snapshots__/values.spec.ts.snap and b/examples/napi/__tests__/__snapshots__/values.spec.ts.snap differ diff --git a/examples/napi/__tests__/values.spec.ts b/examples/napi/__tests__/values.spec.ts index 04a28314..404570e0 100644 --- a/examples/napi/__tests__/values.spec.ts +++ b/examples/napi/__tests__/values.spec.ts @@ -287,6 +287,7 @@ import { indexSetToRust, indexSetToJs, intoUtf8, + withAbortSignalHandle, } from '../index.cjs' // import other stuff in `#[napi(module_exports)]` import nativeAddon from '../index.cjs' @@ -1460,6 +1461,20 @@ test('async task with different resolved values', async (t) => { t.deepEqual(r2, [0, 1]) }) +AbortSignalTest('with abort signal handle', async (t) => { + const ctrl = new AbortController() + const promise = withAbortSignalHandle(ctrl.signal) + try { + ctrl.abort() + const ret = await promise + t.is(ret, 999) + } catch (err: unknown) { + // sometimes on CI, the scheduled task is able to abort + // so we only allow it to throw AbortError + t.is((err as Error).message, 'AbortError') + } +}) + AbortSignalTest('abort resolved task', async (t) => { const ctrl = new AbortController() await withAbortController(1, 2, ctrl.signal).then(() => ctrl.abort()) diff --git a/examples/napi/example.wasi-browser.js b/examples/napi/example.wasi-browser.js index 86c43280..2583e86b 100644 --- a/examples/napi/example.wasi-browser.js +++ b/examples/napi/example.wasi-browser.js @@ -406,6 +406,7 @@ export const validateTypedArraySlice = __napiModule.exports.validateTypedArraySl export const validateUint8ClampedSlice = __napiModule.exports.validateUint8ClampedSlice export const validateUndefined = __napiModule.exports.validateUndefined export const withAbortController = __napiModule.exports.withAbortController +export const withAbortSignalHandle = __napiModule.exports.withAbortSignalHandle export const withinAsyncRuntimeIfAvailable = __napiModule.exports.withinAsyncRuntimeIfAvailable export const withoutAbortController = __napiModule.exports.withoutAbortController export const xxh64Alias = __napiModule.exports.xxh64Alias diff --git a/examples/napi/example.wasi.cjs b/examples/napi/example.wasi.cjs index 56dfdaf1..12ab879a 100644 --- a/examples/napi/example.wasi.cjs +++ b/examples/napi/example.wasi.cjs @@ -451,6 +451,7 @@ module.exports.validateTypedArraySlice = __napiModule.exports.validateTypedArray module.exports.validateUint8ClampedSlice = __napiModule.exports.validateUint8ClampedSlice module.exports.validateUndefined = __napiModule.exports.validateUndefined module.exports.withAbortController = __napiModule.exports.withAbortController +module.exports.withAbortSignalHandle = __napiModule.exports.withAbortSignalHandle module.exports.withinAsyncRuntimeIfAvailable = __napiModule.exports.withinAsyncRuntimeIfAvailable module.exports.withoutAbortController = __napiModule.exports.withoutAbortController module.exports.xxh64Alias = __napiModule.exports.xxh64Alias diff --git a/examples/napi/index.cjs b/examples/napi/index.cjs index 1fb0aec9..681a2fbc 100644 --- a/examples/napi/index.cjs +++ b/examples/napi/index.cjs @@ -108,7 +108,24 @@ function requireNative() { } } else if (process.platform === 'win32') { if (process.arch === 'x64') { + if (process.report?.getReport?.()?.header?.osName?.startsWith?.('MINGW')) { + try { + return require('./example.win32-x64-gnu.node') + } catch (e) { + loadErrors.push(e) + } try { + const binding = require('@examples/napi-win32-x64-gnu') + const bindingPackageVersion = require('@examples/napi-win32-x64-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { return require('./example.win32-x64-msvc.node') } catch (e) { loadErrors.push(e) @@ -123,6 +140,7 @@ function requireNative() { } catch (e) { loadErrors.push(e) } + } } else if (process.arch === 'ia32') { try { return require('./example.win32-ia32-msvc.node') @@ -900,6 +918,7 @@ module.exports.validateTypedArraySlice = nativeBinding.validateTypedArraySlice module.exports.validateUint8ClampedSlice = nativeBinding.validateUint8ClampedSlice module.exports.validateUndefined = nativeBinding.validateUndefined module.exports.withAbortController = nativeBinding.withAbortController +module.exports.withAbortSignalHandle = nativeBinding.withAbortSignalHandle module.exports.withinAsyncRuntimeIfAvailable = nativeBinding.withinAsyncRuntimeIfAvailable module.exports.withoutAbortController = nativeBinding.withoutAbortController module.exports.xxh64Alias = nativeBinding.xxh64Alias diff --git a/examples/napi/index.d.cts b/examples/napi/index.d.cts index b87142d1..27628d39 100644 --- a/examples/napi/index.d.cts +++ b/examples/napi/index.d.cts @@ -1195,6 +1195,8 @@ export type VoidNullable = export declare function withAbortController(a: number, b: number, signal: AbortSignal): Promise +export declare function withAbortSignalHandle(signal: AbortSignal): Promise + export declare function withinAsyncRuntimeIfAvailable(): void export declare function withoutAbortController(a: number, b: number): Promise diff --git a/examples/napi/src/task.rs b/examples/napi/src/task.rs index be8a7352..0cac1902 100644 --- a/examples/napi/src/task.rs +++ b/examples/napi/src/task.rs @@ -1,7 +1,34 @@ -use std::thread::sleep; +use std::{sync::mpsc, thread::sleep}; use napi::{bindgen_prelude::*, ScopedTask}; +pub struct SimpleTask { + receiver: mpsc::Receiver, +} + +#[napi] +impl napi::Task for SimpleTask { + type Output = i32; + type JsValue = i32; + + fn compute(&mut self) -> Result { + self.receiver.recv().map_err(|e| { + Error::new( + Status::GenericFailure, + format!("Channel receive error: {}", e), + ) + }) + } + + fn resolve(&mut self, _env: napi::Env, output: Self::Output) -> Result { + Ok(output) + } + + fn finally(self, _env: napi::Env) -> Result<()> { + Ok(()) + } +} + pub struct DelaySum(u32, u32); #[napi] @@ -33,6 +60,13 @@ pub fn with_abort_controller(a: u32, b: u32, signal: AbortSignal) -> AsyncTask AsyncTask { + let (sender, receiver) = mpsc::channel::(); + signal.on_abort(move || sender.send(999).unwrap()); + AsyncTask::with_signal(SimpleTask { receiver }, signal) +} + struct AsyncTaskVoidReturn {} #[napi]