fix(napi): Promise callbacks should require static lifetime (#2172)

This commit is contained in:
LongYinan 2024-07-07 20:42:17 +08:00 committed by GitHub
parent 27030c8dae
commit 3a511bacee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 208 additions and 56 deletions

View File

@ -9,7 +9,7 @@ use tokio::sync::oneshot::{channel, Receiver};
use crate::{sys, Error, Result, Status};
use super::{FromNapiValue, PromiseRaw, TypeName, Unknown, ValidateNapiValue};
use super::{CallbackContext, FromNapiValue, PromiseRaw, TypeName, Unknown, ValidateNapiValue};
/// The JavaScript Promise object representation
///
@ -29,7 +29,7 @@ use super::{FromNapiValue, PromiseRaw, TypeName, Unknown, ValidateNapiValue};
///
/// But this `Promise<T>` can not be pass back to `JavaScript`.
/// If you want to use raw JavaScript `Promise` API, you can use the [`PromiseRaw`](./PromiseRaw) instead.
pub struct Promise<T: FromNapiValue> {
pub struct Promise<T: 'static + FromNapiValue> {
value: Pin<Box<Receiver<Result<T>>>>,
}
@ -63,17 +63,17 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
let tx_box = Arc::new(Cell::new(Some(tx)));
let tx_in_catch = tx_box.clone();
promise_object
.then(move |value| {
.then(move |ctx| {
if let Some(sender) = tx_box.replace(None) {
// no need to handle the send error here, the receiver has been dropped
let _ = sender.send(Ok(value));
let _ = sender.send(Ok(ctx.value));
}
Ok(())
})?
.catch(move |err: Unknown| {
.catch(move |ctx: CallbackContext<Unknown>| {
if let Some(sender) = tx_in_catch.replace(None) {
// no need to handle the send error here, the receiver has been dropped
let _ = sender.send(Err(err.into()));
let _ = sender.send(Err(ctx.value.into()));
}
Ok(())
})?;

View File

@ -4,11 +4,11 @@ use std::ptr;
#[cfg(all(feature = "napi4", feature = "tokio_rt"))]
use crate::bindgen_runtime::Promise;
use crate::NapiRaw;
use crate::{
bindgen_prelude::{FromNapiValue, Result, ToNapiValue, TypeName, ValidateNapiValue},
check_status, sys,
};
use crate::{Env, Error, NapiRaw, NapiValue, Status};
pub struct PromiseRaw<T> {
pub(crate) inner: sys::napi_value,
@ -31,12 +31,12 @@ impl<T: FromNapiValue> PromiseRaw<T> {
pub fn then<Callback, U>(&mut self, cb: Callback) -> Result<PromiseRaw<U>>
where
U: ToNapiValue,
Callback: FnOnce(T) -> Result<U>,
Callback: 'static + FnOnce(CallbackContext<T>) -> Result<U>,
{
let mut then_fn = ptr::null_mut();
let then_c_string = unsafe { CStr::from_bytes_with_nul_unchecked(b"then\0") };
const THEN: &[u8; 5] = b"then\0";
check_status!(unsafe {
sys::napi_get_named_property(self.env, self.inner, then_c_string.as_ptr(), &mut then_fn)
sys::napi_get_named_property(self.env, self.inner, THEN.as_ptr().cast(), &mut then_fn)
})?;
let mut then_callback = ptr::null_mut();
let rust_cb = Box::into_raw(Box::new(cb));
@ -44,7 +44,7 @@ impl<T: FromNapiValue> PromiseRaw<T> {
unsafe {
sys::napi_create_function(
self.env,
then_c_string.as_ptr(),
THEN.as_ptr().cast(),
4,
Some(raw_promise_then_callback::<T, U, Callback>),
rust_cb.cast(),
@ -80,23 +80,19 @@ impl<T: FromNapiValue> PromiseRaw<T> {
where
E: FromNapiValue,
U: ToNapiValue,
Callback: FnOnce(E) -> Result<U>,
Callback: 'static + FnOnce(CallbackContext<E>) -> Result<U>,
{
let mut catch_fn = ptr::null_mut();
const CATCH: &[u8; 6] = b"catch\0";
check_status!(unsafe {
sys::napi_get_named_property(
self.env,
self.inner,
"catch\0".as_ptr().cast(),
&mut catch_fn,
)
sys::napi_get_named_property(self.env, self.inner, CATCH.as_ptr().cast(), &mut catch_fn)
})?;
let mut catch_callback = ptr::null_mut();
let rust_cb = Box::into_raw(Box::new(cb));
check_status!(unsafe {
sys::napi_create_function(
self.env,
"catch\0".as_ptr().cast(),
CATCH.as_ptr().cast(),
5,
Some(raw_promise_catch_callback::<E, U, Callback>),
rust_cb.cast(),
@ -122,6 +118,55 @@ impl<T: FromNapiValue> PromiseRaw<T> {
})
}
/// Promise.finally method
pub fn finally<U, Callback>(&mut self, cb: Callback) -> Result<PromiseRaw<T>>
where
U: ToNapiValue,
Callback: 'static + FnOnce(Env) -> Result<U>,
{
let mut then_fn = ptr::null_mut();
const FINALLY: &[u8; 8] = b"finally\0";
check_status!(unsafe {
sys::napi_get_named_property(self.env, self.inner, FINALLY.as_ptr().cast(), &mut then_fn)
})?;
let mut then_callback = ptr::null_mut();
let rust_cb = Box::into_raw(Box::new(cb));
check_status!(
unsafe {
sys::napi_create_function(
self.env,
FINALLY.as_ptr().cast(),
7,
Some(raw_promise_finally_callback::<U, Callback>),
rust_cb.cast(),
&mut then_callback,
)
},
"Create then function for PromiseRaw failed"
)?;
let mut new_promise = ptr::null_mut();
check_status!(
unsafe {
sys::napi_call_function(
self.env,
self.inner,
then_fn,
1,
[then_callback].as_ptr(),
&mut new_promise,
)
},
"Call then callback on PromiseRaw failed"
)?;
Ok(Self {
env: self.env,
inner: new_promise,
_phantom: PhantomData,
})
}
#[cfg(all(feature = "napi4", feature = "tokio_rt"))]
/// Convert `PromiseRaw<T>` to `Promise<T>`
///
@ -156,6 +201,28 @@ impl<T> NapiRaw for PromiseRaw<T> {
}
}
impl<T> NapiValue for PromiseRaw<T> {
unsafe fn from_raw(env: napi_sys::napi_env, value: napi_sys::napi_value) -> Result<Self> {
let mut is_promise = false;
check_status!(unsafe { sys::napi_is_promise(env, value, &mut is_promise) })?;
is_promise
.then_some(Self {
env,
inner: value,
_phantom: PhantomData,
})
.ok_or_else(|| Error::new(Status::InvalidArg, "JavaScript value is not Promise"))
}
unsafe fn from_raw_unchecked(env: napi_sys::napi_env, value: napi_sys::napi_value) -> Self {
Self {
env,
inner: value,
_phantom: PhantomData,
}
}
}
pub(crate) fn validate_promise(
env: napi_sys::napi_env,
napi_val: napi_sys::napi_value,
@ -210,16 +277,6 @@ pub(crate) fn validate_promise(
Ok(ptr::null_mut())
}
impl<T> FromNapiValue for PromiseRaw<T> {
unsafe fn from_napi_value(env: sys::napi_env, napi_val: sys::napi_value) -> crate::Result<Self> {
Ok(PromiseRaw {
inner: napi_val,
env,
_phantom: PhantomData,
})
}
}
unsafe extern "C" fn raw_promise_then_callback<T, U, Cb>(
env: sys::napi_env,
cbinfo: sys::napi_callback_info,
@ -227,19 +284,13 @@ unsafe extern "C" fn raw_promise_then_callback<T, U, Cb>(
where
T: FromNapiValue,
U: ToNapiValue,
Cb: FnOnce(T) -> Result<U>,
Cb: FnOnce(CallbackContext<T>) -> Result<U>,
{
match handle_then_callback::<T, U, Cb>(env, cbinfo) {
Ok(v) => v,
Err(err) => {
let code = CString::new(err.status.as_ref()).unwrap();
let msg = CString::new(err.reason).unwrap();
unsafe { sys::napi_throw_error(env, code.as_ptr(), msg.as_ptr()) };
ptr::null_mut()
}
}
handle_then_callback::<T, U, Cb>(env, cbinfo)
.unwrap_or_else(|err| throw_error(env, err, "Error in Promise.then"))
}
#[inline(always)]
fn handle_then_callback<T, U, Cb>(
env: sys::napi_env,
cbinfo: sys::napi_callback_info,
@ -247,7 +298,7 @@ fn handle_then_callback<T, U, Cb>(
where
T: FromNapiValue,
U: ToNapiValue,
Cb: FnOnce(T) -> Result<U>,
Cb: FnOnce(CallbackContext<T>) -> Result<U>,
{
let mut callback_values = [ptr::null_mut()];
let mut rust_cb = ptr::null_mut();
@ -267,7 +318,15 @@ where
let then_value: T = unsafe { FromNapiValue::from_napi_value(env, callback_values[0]) }?;
let cb: Box<Cb> = unsafe { Box::from_raw(rust_cb.cast()) };
unsafe { U::to_napi_value(env, cb(then_value)?) }
unsafe {
U::to_napi_value(
env,
cb(CallbackContext {
env: Env(env),
value: then_value,
})?,
)
}
}
unsafe extern "C" fn raw_promise_catch_callback<E, U, Cb>(
@ -277,19 +336,13 @@ unsafe extern "C" fn raw_promise_catch_callback<E, U, Cb>(
where
E: FromNapiValue,
U: ToNapiValue,
Cb: FnOnce(E) -> Result<U>,
Cb: FnOnce(CallbackContext<E>) -> Result<U>,
{
match handle_catch_callback::<E, U, Cb>(env, cbinfo) {
Ok(v) => v,
Err(err) => {
let code = CString::new(err.status.as_ref()).unwrap();
let msg = CString::new(err.reason).unwrap();
unsafe { sys::napi_throw_error(env, code.as_ptr(), msg.as_ptr()) };
ptr::null_mut()
}
}
handle_catch_callback::<E, U, Cb>(env, cbinfo)
.unwrap_or_else(|err| throw_error(env, err, "Error in Promise.catch"))
}
#[inline(always)]
fn handle_catch_callback<E, U, Cb>(
env: sys::napi_env,
cbinfo: sys::napi_callback_info,
@ -297,7 +350,7 @@ fn handle_catch_callback<E, U, Cb>(
where
E: FromNapiValue,
U: ToNapiValue,
Cb: FnOnce(E) -> Result<U>,
Cb: FnOnce(CallbackContext<E>) -> Result<U>,
{
let mut callback_values = [ptr::null_mut(); 1];
let mut rust_cb = ptr::null_mut();
@ -317,5 +370,84 @@ where
let catch_value: E = unsafe { FromNapiValue::from_napi_value(env, callback_values[0]) }?;
let cb: Box<Cb> = unsafe { Box::from_raw(rust_cb.cast()) };
unsafe { U::to_napi_value(env, cb(catch_value)?) }
unsafe {
U::to_napi_value(
env,
cb(CallbackContext {
env: Env(env),
value: catch_value,
})?,
)
}
}
unsafe extern "C" fn raw_promise_finally_callback<U, Cb>(
env: sys::napi_env,
cbinfo: sys::napi_callback_info,
) -> sys::napi_value
where
U: ToNapiValue,
Cb: FnOnce(Env) -> Result<U>,
{
handle_finally_callback::<U, Cb>(env, cbinfo)
.unwrap_or_else(|err| throw_error(env, err, "Error in Promise.finally"))
}
#[inline(always)]
fn handle_finally_callback<U, Cb>(
env: sys::napi_env,
cbinfo: sys::napi_callback_info,
) -> Result<sys::napi_value>
where
U: ToNapiValue,
Cb: FnOnce(Env) -> Result<U>,
{
let mut rust_cb = ptr::null_mut();
check_status!(
unsafe {
sys::napi_get_cb_info(
env,
cbinfo,
&mut 0,
ptr::null_mut(),
ptr::null_mut(),
&mut rust_cb,
)
},
"Get callback info from finally callback failed"
)?;
let cb: Box<Cb> = unsafe { Box::from_raw(rust_cb.cast()) };
unsafe { U::to_napi_value(env, cb(Env(env))?) }
}
pub struct CallbackContext<T> {
pub env: Env,
pub value: T,
}
impl<T: ToNapiValue> ToNapiValue for CallbackContext<T> {
unsafe fn to_napi_value(env: napi_sys::napi_env, val: Self) -> Result<napi_sys::napi_value> {
T::to_napi_value(env, val.value)
}
}
#[inline(never)]
fn throw_error(env: sys::napi_env, err: Error, default_msg: &str) -> sys::napi_value {
let code = if err.status.as_ref().is_empty() {
CString::new(Status::GenericFailure.as_ref())
} else {
CString::new(err.status.as_ref())
}
.map(|s| s.as_ptr())
.unwrap_or(ptr::null_mut());
let msg = if err.reason.is_empty() {
CString::new(default_msg)
} else {
CString::new(err.reason)
}
.map(|s| s.as_ptr())
.unwrap_or(ptr::null_mut());
unsafe { sys::napi_throw_error(env, code, msg) };
ptr::null_mut()
}

View File

@ -340,6 +340,8 @@ Generated by [AVA](https://avajs.dev).
export declare function callCatchOnPromise(input: PromiseRaw<number>): PromiseRaw<string>
export declare function callFinallyOnPromise(input: PromiseRaw<number>, onFinally: () => void): PromiseRaw<number>
export declare function callFunction(cb: () => number): number␊
export declare function callFunctionWithArg(cb: (arg0: number, arg1: number) => number, arg0: number, arg1: number): number␊

View File

@ -3,7 +3,7 @@ import { join } from 'node:path'
import { fileURLToPath } from 'node:url'
import { Subject, take } from 'rxjs'
import { spy } from 'sinon'
import Sinon, { spy } from 'sinon'
import {
DEFAULT_COST,
@ -185,6 +185,7 @@ import {
uInit8ArrayFromString,
callThenOnPromise,
callCatchOnPromise,
callFinallyOnPromise,
} from '../index.cjs'
import { test } from './test.framework.js'
@ -498,6 +499,9 @@ test('promise', async (t) => {
t.is(res, '1')
const cat = await callCatchOnPromise(Promise.reject('cat'))
t.is(cat, 'cat')
const spy = Sinon.spy()
await callFinallyOnPromise(Promise.resolve(1), spy)
t.true(spy.calledOnce)
})
test('object', (t) => {

View File

@ -435,6 +435,7 @@ module.exports.call2 = nativeBinding.call2
module.exports.callbackReturnPromise = nativeBinding.callbackReturnPromise
module.exports.callbackReturnPromiseAndSpawn = nativeBinding.callbackReturnPromiseAndSpawn
module.exports.callCatchOnPromise = nativeBinding.callCatchOnPromise
module.exports.callFinallyOnPromise = nativeBinding.callFinallyOnPromise
module.exports.callFunction = nativeBinding.callFunction
module.exports.callFunctionWithArg = nativeBinding.callFunctionWithArg
module.exports.callFunctionWithArgAndCtx = nativeBinding.callFunctionWithArgAndCtx

View File

@ -330,6 +330,8 @@ export declare function callbackReturnPromiseAndSpawn(jsFunc: (arg0: string) =>
export declare function callCatchOnPromise(input: PromiseRaw<number>): PromiseRaw<string>
export declare function callFinallyOnPromise(input: PromiseRaw<number>, onFinally: () => void): PromiseRaw<number>
export declare function callFunction(cb: () => number): number
export declare function callFunctionWithArg(cb: (arg0: number, arg1: number) => number, arg0: number, arg1: number): number

View File

@ -8,10 +8,21 @@ pub async fn async_plus_100(p: Promise<u32>) -> Result<u32> {
#[napi]
pub fn call_then_on_promise(mut input: PromiseRaw<u32>) -> Result<PromiseRaw<String>> {
input.then(|v| Ok(format!("{}", v)))
input.then(|v| Ok(format!("{}", v.value)))
}
#[napi]
pub fn call_catch_on_promise(mut input: PromiseRaw<u32>) -> Result<PromiseRaw<String>> {
input.catch(|e: String| Ok(e))
input.catch(|e: CallbackContext<String>| Ok(e.value))
}
#[napi]
pub fn call_finally_on_promise(
mut input: PromiseRaw<u32>,
on_finally: FunctionRef<(), ()>,
) -> Result<PromiseRaw<u32>> {
input.finally(move |env| {
on_finally.borrow_back(&env)?.call(())?;
Ok(())
})
}