diff --git a/crates/napi/src/bindgen_runtime/js_values/promise.rs b/crates/napi/src/bindgen_runtime/js_values/promise.rs index 3df56b1f..874a1663 100644 --- a/crates/napi/src/bindgen_runtime/js_values/promise.rs +++ b/crates/napi/src/bindgen_runtime/js_values/promise.rs @@ -9,7 +9,7 @@ use tokio::sync::oneshot::{channel, Receiver}; use crate::{sys, Error, Result, Status}; -use super::{FromNapiValue, PromiseRaw, TypeName, Unknown, ValidateNapiValue}; +use super::{CallbackContext, FromNapiValue, PromiseRaw, TypeName, Unknown, ValidateNapiValue}; /// The JavaScript Promise object representation /// @@ -29,7 +29,7 @@ use super::{FromNapiValue, PromiseRaw, TypeName, Unknown, ValidateNapiValue}; /// /// But this `Promise` can not be pass back to `JavaScript`. /// If you want to use raw JavaScript `Promise` API, you can use the [`PromiseRaw`](./PromiseRaw) instead. -pub struct Promise { +pub struct Promise { value: Pin>>>, } @@ -63,17 +63,17 @@ impl FromNapiValue for Promise { let tx_box = Arc::new(Cell::new(Some(tx))); let tx_in_catch = tx_box.clone(); promise_object - .then(move |value| { + .then(move |ctx| { if let Some(sender) = tx_box.replace(None) { // no need to handle the send error here, the receiver has been dropped - let _ = sender.send(Ok(value)); + let _ = sender.send(Ok(ctx.value)); } Ok(()) })? - .catch(move |err: Unknown| { + .catch(move |ctx: CallbackContext| { if let Some(sender) = tx_in_catch.replace(None) { // no need to handle the send error here, the receiver has been dropped - let _ = sender.send(Err(err.into())); + let _ = sender.send(Err(ctx.value.into())); } Ok(()) })?; diff --git a/crates/napi/src/bindgen_runtime/js_values/promise_raw.rs b/crates/napi/src/bindgen_runtime/js_values/promise_raw.rs index a6c52c9f..1c9d042a 100644 --- a/crates/napi/src/bindgen_runtime/js_values/promise_raw.rs +++ b/crates/napi/src/bindgen_runtime/js_values/promise_raw.rs @@ -4,11 +4,11 @@ use std::ptr; #[cfg(all(feature = "napi4", feature = "tokio_rt"))] use crate::bindgen_runtime::Promise; -use crate::NapiRaw; use crate::{ bindgen_prelude::{FromNapiValue, Result, ToNapiValue, TypeName, ValidateNapiValue}, check_status, sys, }; +use crate::{Env, Error, NapiRaw, NapiValue, Status}; pub struct PromiseRaw { pub(crate) inner: sys::napi_value, @@ -31,12 +31,12 @@ impl PromiseRaw { pub fn then(&mut self, cb: Callback) -> Result> where U: ToNapiValue, - Callback: FnOnce(T) -> Result, + Callback: 'static + FnOnce(CallbackContext) -> Result, { let mut then_fn = ptr::null_mut(); - let then_c_string = unsafe { CStr::from_bytes_with_nul_unchecked(b"then\0") }; + const THEN: &[u8; 5] = b"then\0"; check_status!(unsafe { - sys::napi_get_named_property(self.env, self.inner, then_c_string.as_ptr(), &mut then_fn) + sys::napi_get_named_property(self.env, self.inner, THEN.as_ptr().cast(), &mut then_fn) })?; let mut then_callback = ptr::null_mut(); let rust_cb = Box::into_raw(Box::new(cb)); @@ -44,7 +44,7 @@ impl PromiseRaw { unsafe { sys::napi_create_function( self.env, - then_c_string.as_ptr(), + THEN.as_ptr().cast(), 4, Some(raw_promise_then_callback::), rust_cb.cast(), @@ -80,23 +80,19 @@ impl PromiseRaw { where E: FromNapiValue, U: ToNapiValue, - Callback: FnOnce(E) -> Result, + Callback: 'static + FnOnce(CallbackContext) -> Result, { let mut catch_fn = ptr::null_mut(); + const CATCH: &[u8; 6] = b"catch\0"; check_status!(unsafe { - sys::napi_get_named_property( - self.env, - self.inner, - "catch\0".as_ptr().cast(), - &mut catch_fn, - ) + sys::napi_get_named_property(self.env, self.inner, CATCH.as_ptr().cast(), &mut catch_fn) })?; let mut catch_callback = ptr::null_mut(); let rust_cb = Box::into_raw(Box::new(cb)); check_status!(unsafe { sys::napi_create_function( self.env, - "catch\0".as_ptr().cast(), + CATCH.as_ptr().cast(), 5, Some(raw_promise_catch_callback::), rust_cb.cast(), @@ -122,6 +118,55 @@ impl PromiseRaw { }) } + /// Promise.finally method + pub fn finally(&mut self, cb: Callback) -> Result> + where + U: ToNapiValue, + Callback: 'static + FnOnce(Env) -> Result, + { + let mut then_fn = ptr::null_mut(); + const FINALLY: &[u8; 8] = b"finally\0"; + + check_status!(unsafe { + sys::napi_get_named_property(self.env, self.inner, FINALLY.as_ptr().cast(), &mut then_fn) + })?; + let mut then_callback = ptr::null_mut(); + let rust_cb = Box::into_raw(Box::new(cb)); + check_status!( + unsafe { + sys::napi_create_function( + self.env, + FINALLY.as_ptr().cast(), + 7, + Some(raw_promise_finally_callback::), + rust_cb.cast(), + &mut then_callback, + ) + }, + "Create then function for PromiseRaw failed" + )?; + let mut new_promise = ptr::null_mut(); + check_status!( + unsafe { + sys::napi_call_function( + self.env, + self.inner, + then_fn, + 1, + [then_callback].as_ptr(), + &mut new_promise, + ) + }, + "Call then callback on PromiseRaw failed" + )?; + + Ok(Self { + env: self.env, + inner: new_promise, + _phantom: PhantomData, + }) + } + #[cfg(all(feature = "napi4", feature = "tokio_rt"))] /// Convert `PromiseRaw` to `Promise` /// @@ -156,6 +201,28 @@ impl NapiRaw for PromiseRaw { } } +impl NapiValue for PromiseRaw { + unsafe fn from_raw(env: napi_sys::napi_env, value: napi_sys::napi_value) -> Result { + let mut is_promise = false; + check_status!(unsafe { sys::napi_is_promise(env, value, &mut is_promise) })?; + is_promise + .then_some(Self { + env, + inner: value, + _phantom: PhantomData, + }) + .ok_or_else(|| Error::new(Status::InvalidArg, "JavaScript value is not Promise")) + } + + unsafe fn from_raw_unchecked(env: napi_sys::napi_env, value: napi_sys::napi_value) -> Self { + Self { + env, + inner: value, + _phantom: PhantomData, + } + } +} + pub(crate) fn validate_promise( env: napi_sys::napi_env, napi_val: napi_sys::napi_value, @@ -210,16 +277,6 @@ pub(crate) fn validate_promise( Ok(ptr::null_mut()) } -impl FromNapiValue for PromiseRaw { - unsafe fn from_napi_value(env: sys::napi_env, napi_val: sys::napi_value) -> crate::Result { - Ok(PromiseRaw { - inner: napi_val, - env, - _phantom: PhantomData, - }) - } -} - unsafe extern "C" fn raw_promise_then_callback( env: sys::napi_env, cbinfo: sys::napi_callback_info, @@ -227,19 +284,13 @@ unsafe extern "C" fn raw_promise_then_callback( where T: FromNapiValue, U: ToNapiValue, - Cb: FnOnce(T) -> Result, + Cb: FnOnce(CallbackContext) -> Result, { - match handle_then_callback::(env, cbinfo) { - Ok(v) => v, - Err(err) => { - let code = CString::new(err.status.as_ref()).unwrap(); - let msg = CString::new(err.reason).unwrap(); - unsafe { sys::napi_throw_error(env, code.as_ptr(), msg.as_ptr()) }; - ptr::null_mut() - } - } + handle_then_callback::(env, cbinfo) + .unwrap_or_else(|err| throw_error(env, err, "Error in Promise.then")) } +#[inline(always)] fn handle_then_callback( env: sys::napi_env, cbinfo: sys::napi_callback_info, @@ -247,7 +298,7 @@ fn handle_then_callback( where T: FromNapiValue, U: ToNapiValue, - Cb: FnOnce(T) -> Result, + Cb: FnOnce(CallbackContext) -> Result, { let mut callback_values = [ptr::null_mut()]; let mut rust_cb = ptr::null_mut(); @@ -267,7 +318,15 @@ where let then_value: T = unsafe { FromNapiValue::from_napi_value(env, callback_values[0]) }?; let cb: Box = unsafe { Box::from_raw(rust_cb.cast()) }; - unsafe { U::to_napi_value(env, cb(then_value)?) } + unsafe { + U::to_napi_value( + env, + cb(CallbackContext { + env: Env(env), + value: then_value, + })?, + ) + } } unsafe extern "C" fn raw_promise_catch_callback( @@ -277,19 +336,13 @@ unsafe extern "C" fn raw_promise_catch_callback( where E: FromNapiValue, U: ToNapiValue, - Cb: FnOnce(E) -> Result, + Cb: FnOnce(CallbackContext) -> Result, { - match handle_catch_callback::(env, cbinfo) { - Ok(v) => v, - Err(err) => { - let code = CString::new(err.status.as_ref()).unwrap(); - let msg = CString::new(err.reason).unwrap(); - unsafe { sys::napi_throw_error(env, code.as_ptr(), msg.as_ptr()) }; - ptr::null_mut() - } - } + handle_catch_callback::(env, cbinfo) + .unwrap_or_else(|err| throw_error(env, err, "Error in Promise.catch")) } +#[inline(always)] fn handle_catch_callback( env: sys::napi_env, cbinfo: sys::napi_callback_info, @@ -297,7 +350,7 @@ fn handle_catch_callback( where E: FromNapiValue, U: ToNapiValue, - Cb: FnOnce(E) -> Result, + Cb: FnOnce(CallbackContext) -> Result, { let mut callback_values = [ptr::null_mut(); 1]; let mut rust_cb = ptr::null_mut(); @@ -317,5 +370,84 @@ where let catch_value: E = unsafe { FromNapiValue::from_napi_value(env, callback_values[0]) }?; let cb: Box = unsafe { Box::from_raw(rust_cb.cast()) }; - unsafe { U::to_napi_value(env, cb(catch_value)?) } + unsafe { + U::to_napi_value( + env, + cb(CallbackContext { + env: Env(env), + value: catch_value, + })?, + ) + } +} + +unsafe extern "C" fn raw_promise_finally_callback( + env: sys::napi_env, + cbinfo: sys::napi_callback_info, +) -> sys::napi_value +where + U: ToNapiValue, + Cb: FnOnce(Env) -> Result, +{ + handle_finally_callback::(env, cbinfo) + .unwrap_or_else(|err| throw_error(env, err, "Error in Promise.finally")) +} + +#[inline(always)] +fn handle_finally_callback( + env: sys::napi_env, + cbinfo: sys::napi_callback_info, +) -> Result +where + U: ToNapiValue, + Cb: FnOnce(Env) -> Result, +{ + let mut rust_cb = ptr::null_mut(); + check_status!( + unsafe { + sys::napi_get_cb_info( + env, + cbinfo, + &mut 0, + ptr::null_mut(), + ptr::null_mut(), + &mut rust_cb, + ) + }, + "Get callback info from finally callback failed" + )?; + let cb: Box = unsafe { Box::from_raw(rust_cb.cast()) }; + + unsafe { U::to_napi_value(env, cb(Env(env))?) } +} + +pub struct CallbackContext { + pub env: Env, + pub value: T, +} + +impl ToNapiValue for CallbackContext { + unsafe fn to_napi_value(env: napi_sys::napi_env, val: Self) -> Result { + T::to_napi_value(env, val.value) + } +} + +#[inline(never)] +fn throw_error(env: sys::napi_env, err: Error, default_msg: &str) -> sys::napi_value { + let code = if err.status.as_ref().is_empty() { + CString::new(Status::GenericFailure.as_ref()) + } else { + CString::new(err.status.as_ref()) + } + .map(|s| s.as_ptr()) + .unwrap_or(ptr::null_mut()); + let msg = if err.reason.is_empty() { + CString::new(default_msg) + } else { + CString::new(err.reason) + } + .map(|s| s.as_ptr()) + .unwrap_or(ptr::null_mut()); + unsafe { sys::napi_throw_error(env, code, msg) }; + ptr::null_mut() } diff --git a/examples/napi/__tests__/__snapshots__/typegen.spec.ts.md b/examples/napi/__tests__/__snapshots__/typegen.spec.ts.md index 4ee610be..0670e58f 100644 --- a/examples/napi/__tests__/__snapshots__/typegen.spec.ts.md +++ b/examples/napi/__tests__/__snapshots__/typegen.spec.ts.md @@ -340,6 +340,8 @@ Generated by [AVA](https://avajs.dev). ␊ export declare function callCatchOnPromise(input: PromiseRaw): PromiseRaw␊ ␊ + export declare function callFinallyOnPromise(input: PromiseRaw, onFinally: () => void): PromiseRaw␊ + ␊ export declare function callFunction(cb: () => number): number␊ ␊ export declare function callFunctionWithArg(cb: (arg0: number, arg1: number) => number, arg0: number, arg1: number): number␊ diff --git a/examples/napi/__tests__/__snapshots__/typegen.spec.ts.snap b/examples/napi/__tests__/__snapshots__/typegen.spec.ts.snap index 469bb6c9..c1b1ff72 100644 Binary files a/examples/napi/__tests__/__snapshots__/typegen.spec.ts.snap and b/examples/napi/__tests__/__snapshots__/typegen.spec.ts.snap differ diff --git a/examples/napi/__tests__/values.spec.ts b/examples/napi/__tests__/values.spec.ts index 789fe705..69669a66 100644 --- a/examples/napi/__tests__/values.spec.ts +++ b/examples/napi/__tests__/values.spec.ts @@ -3,7 +3,7 @@ import { join } from 'node:path' import { fileURLToPath } from 'node:url' import { Subject, take } from 'rxjs' -import { spy } from 'sinon' +import Sinon, { spy } from 'sinon' import { DEFAULT_COST, @@ -185,6 +185,7 @@ import { uInit8ArrayFromString, callThenOnPromise, callCatchOnPromise, + callFinallyOnPromise, } from '../index.cjs' import { test } from './test.framework.js' @@ -498,6 +499,9 @@ test('promise', async (t) => { t.is(res, '1') const cat = await callCatchOnPromise(Promise.reject('cat')) t.is(cat, 'cat') + const spy = Sinon.spy() + await callFinallyOnPromise(Promise.resolve(1), spy) + t.true(spy.calledOnce) }) test('object', (t) => { diff --git a/examples/napi/index.cjs b/examples/napi/index.cjs index 88cf408d..0d71884b 100644 --- a/examples/napi/index.cjs +++ b/examples/napi/index.cjs @@ -435,6 +435,7 @@ module.exports.call2 = nativeBinding.call2 module.exports.callbackReturnPromise = nativeBinding.callbackReturnPromise module.exports.callbackReturnPromiseAndSpawn = nativeBinding.callbackReturnPromiseAndSpawn module.exports.callCatchOnPromise = nativeBinding.callCatchOnPromise +module.exports.callFinallyOnPromise = nativeBinding.callFinallyOnPromise module.exports.callFunction = nativeBinding.callFunction module.exports.callFunctionWithArg = nativeBinding.callFunctionWithArg module.exports.callFunctionWithArgAndCtx = nativeBinding.callFunctionWithArgAndCtx diff --git a/examples/napi/index.d.cts b/examples/napi/index.d.cts index 637d9c2e..51f21488 100644 --- a/examples/napi/index.d.cts +++ b/examples/napi/index.d.cts @@ -330,6 +330,8 @@ export declare function callbackReturnPromiseAndSpawn(jsFunc: (arg0: string) => export declare function callCatchOnPromise(input: PromiseRaw): PromiseRaw +export declare function callFinallyOnPromise(input: PromiseRaw, onFinally: () => void): PromiseRaw + export declare function callFunction(cb: () => number): number export declare function callFunctionWithArg(cb: (arg0: number, arg1: number) => number, arg0: number, arg1: number): number diff --git a/examples/napi/src/promise.rs b/examples/napi/src/promise.rs index 69c9ac82..47e1e6af 100644 --- a/examples/napi/src/promise.rs +++ b/examples/napi/src/promise.rs @@ -8,10 +8,21 @@ pub async fn async_plus_100(p: Promise) -> Result { #[napi] pub fn call_then_on_promise(mut input: PromiseRaw) -> Result> { - input.then(|v| Ok(format!("{}", v))) + input.then(|v| Ok(format!("{}", v.value))) } #[napi] pub fn call_catch_on_promise(mut input: PromiseRaw) -> Result> { - input.catch(|e: String| Ok(e)) + input.catch(|e: CallbackContext| Ok(e.value)) +} + +#[napi] +pub fn call_finally_on_promise( + mut input: PromiseRaw, + on_finally: FunctionRef<(), ()>, +) -> Result> { + input.finally(move |env| { + on_finally.borrow_back(&env)?.call(())?; + Ok(()) + }) }