feat(napi): add on_abort for AbortSignal (#2942)

* feat(napi): add on_abort for AbortSignal

* chore: upgrade example code

* fix: fix lint error

* fix: fix lint error

* Update examples/napi/src/task.rs

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update crates/napi/src/bindgen_runtime/js_values/task.rs

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fmt

---------

Co-authored-by: LongYinan <lynweklm@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
richerfu 2025-10-03 16:14:32 +08:00 committed by GitHub
parent 70b66ee1b0
commit 9df9f890f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 93 additions and 1 deletions

View File

@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::ffi::c_void;
use std::marker::PhantomData;
use std::ptr;
@ -50,10 +51,19 @@ impl<T: for<'task> ScopedTask<'task>> AsyncTask<T> {
}
}
type AbortCallback = Rc<RefCell<Vec<Box<dyn Fn()>>>>;
/// <https://developer.mozilla.org/zh-CN/docs/Web/API/AbortController>
pub struct AbortSignal {
raw_work: Rc<Cell<sys::napi_async_work>>,
status: Rc<Cell<u8>>,
abort: AbortCallback,
}
impl AbortSignal {
pub fn on_abort<F: Fn() + 'static>(&self, cb: F) {
self.abort.borrow_mut().push(Box::new(cb));
}
}
impl UnwindSafe for AbortSignal {}
@ -74,9 +84,11 @@ impl FromNapiValue for AbortSignal {
);
let async_work_inner: Rc<Cell<sys::napi_async_work>> = Rc::new(Cell::new(ptr::null_mut()));
let task_status = Rc::new(Cell::new(0));
let abort_cbs = Rc::new(RefCell::new(vec![]));
let abort_signal = AbortSignal {
raw_work: async_work_inner.clone(),
status: task_status.clone(),
abort: abort_cbs.clone(),
};
let js_env = Env::from_raw(env);
@ -111,6 +123,7 @@ impl FromNapiValue for AbortSignal {
Ok(AbortSignal {
raw_work: async_work_inner,
status: task_status,
abort: abort_cbs,
})
}
}
@ -153,6 +166,11 @@ fn on_abort_impl(
)?;
let abort_controller_stack = Box::leak(Box::from_raw(async_task as *mut AbortSignalStack));
for abort_controller in abort_controller_stack.0.iter() {
// call abort callback
for cb in abort_controller.abort.borrow().iter() {
cb();
}
// Task Completed, return now
if abort_controller.status.get() == 1 {
return Ok(ptr::null_mut());

View File

@ -1234,6 +1234,8 @@ Generated by [AVA](https://avajs.dev).
export declare function withAbortController(a: number, b: number, signal: AbortSignal): Promise<number>
export declare function withAbortSignalHandle(signal: AbortSignal): Promise<number>
export declare function withinAsyncRuntimeIfAvailable(): void␊
export declare function withoutAbortController(a: number, b: number): Promise<number>

View File

@ -287,6 +287,7 @@ import {
indexSetToRust,
indexSetToJs,
intoUtf8,
withAbortSignalHandle,
} from '../index.cjs'
// import other stuff in `#[napi(module_exports)]`
import nativeAddon from '../index.cjs'
@ -1460,6 +1461,20 @@ test('async task with different resolved values', async (t) => {
t.deepEqual(r2, [0, 1])
})
AbortSignalTest('with abort signal handle', async (t) => {
const ctrl = new AbortController()
const promise = withAbortSignalHandle(ctrl.signal)
try {
ctrl.abort()
const ret = await promise
t.is(ret, 999)
} catch (err: unknown) {
// sometimes on CI, the scheduled task is able to abort
// so we only allow it to throw AbortError
t.is((err as Error).message, 'AbortError')
}
})
AbortSignalTest('abort resolved task', async (t) => {
const ctrl = new AbortController()
await withAbortController(1, 2, ctrl.signal).then(() => ctrl.abort())

View File

@ -406,6 +406,7 @@ export const validateTypedArraySlice = __napiModule.exports.validateTypedArraySl
export const validateUint8ClampedSlice = __napiModule.exports.validateUint8ClampedSlice
export const validateUndefined = __napiModule.exports.validateUndefined
export const withAbortController = __napiModule.exports.withAbortController
export const withAbortSignalHandle = __napiModule.exports.withAbortSignalHandle
export const withinAsyncRuntimeIfAvailable = __napiModule.exports.withinAsyncRuntimeIfAvailable
export const withoutAbortController = __napiModule.exports.withoutAbortController
export const xxh64Alias = __napiModule.exports.xxh64Alias

View File

@ -451,6 +451,7 @@ module.exports.validateTypedArraySlice = __napiModule.exports.validateTypedArray
module.exports.validateUint8ClampedSlice = __napiModule.exports.validateUint8ClampedSlice
module.exports.validateUndefined = __napiModule.exports.validateUndefined
module.exports.withAbortController = __napiModule.exports.withAbortController
module.exports.withAbortSignalHandle = __napiModule.exports.withAbortSignalHandle
module.exports.withinAsyncRuntimeIfAvailable = __napiModule.exports.withinAsyncRuntimeIfAvailable
module.exports.withoutAbortController = __napiModule.exports.withoutAbortController
module.exports.xxh64Alias = __napiModule.exports.xxh64Alias

View File

@ -108,7 +108,24 @@ function requireNative() {
}
} else if (process.platform === 'win32') {
if (process.arch === 'x64') {
if (process.report?.getReport?.()?.header?.osName?.startsWith?.('MINGW')) {
try {
return require('./example.win32-x64-gnu.node')
} catch (e) {
loadErrors.push(e)
}
try {
const binding = require('@examples/napi-win32-x64-gnu')
const bindingPackageVersion = require('@examples/napi-win32-x64-gnu/package.json').version
if (bindingPackageVersion !== '0.0.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') {
throw new Error(`Native binding package version mismatch, expected 0.0.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`)
}
return binding
} catch (e) {
loadErrors.push(e)
}
} else {
try {
return require('./example.win32-x64-msvc.node')
} catch (e) {
loadErrors.push(e)
@ -123,6 +140,7 @@ function requireNative() {
} catch (e) {
loadErrors.push(e)
}
}
} else if (process.arch === 'ia32') {
try {
return require('./example.win32-ia32-msvc.node')
@ -900,6 +918,7 @@ module.exports.validateTypedArraySlice = nativeBinding.validateTypedArraySlice
module.exports.validateUint8ClampedSlice = nativeBinding.validateUint8ClampedSlice
module.exports.validateUndefined = nativeBinding.validateUndefined
module.exports.withAbortController = nativeBinding.withAbortController
module.exports.withAbortSignalHandle = nativeBinding.withAbortSignalHandle
module.exports.withinAsyncRuntimeIfAvailable = nativeBinding.withinAsyncRuntimeIfAvailable
module.exports.withoutAbortController = nativeBinding.withoutAbortController
module.exports.xxh64Alias = nativeBinding.xxh64Alias

View File

@ -1195,6 +1195,8 @@ export type VoidNullable<T = void> =
export declare function withAbortController(a: number, b: number, signal: AbortSignal): Promise<number>
export declare function withAbortSignalHandle(signal: AbortSignal): Promise<number>
export declare function withinAsyncRuntimeIfAvailable(): void
export declare function withoutAbortController(a: number, b: number): Promise<number>

View File

@ -1,7 +1,34 @@
use std::thread::sleep;
use std::{sync::mpsc, thread::sleep};
use napi::{bindgen_prelude::*, ScopedTask};
pub struct SimpleTask {
receiver: mpsc::Receiver<i32>,
}
#[napi]
impl napi::Task for SimpleTask {
type Output = i32;
type JsValue = i32;
fn compute(&mut self) -> Result<Self::Output> {
self.receiver.recv().map_err(|e| {
Error::new(
Status::GenericFailure,
format!("Channel receive error: {}", e),
)
})
}
fn resolve(&mut self, _env: napi::Env, output: Self::Output) -> Result<Self::JsValue> {
Ok(output)
}
fn finally(self, _env: napi::Env) -> Result<()> {
Ok(())
}
}
pub struct DelaySum(u32, u32);
#[napi]
@ -33,6 +60,13 @@ pub fn with_abort_controller(a: u32, b: u32, signal: AbortSignal) -> AsyncTask<D
AsyncTask::with_signal(DelaySum(a, b), signal)
}
#[napi]
fn with_abort_signal_handle(signal: AbortSignal) -> AsyncTask<SimpleTask> {
let (sender, receiver) = mpsc::channel::<i32>();
signal.on_abort(move || sender.send(999).unwrap());
AsyncTask::with_signal(SimpleTask { receiver }, signal)
}
struct AsyncTaskVoidReturn {}
#[napi]