From 59da77c87f8a356d2b9d39144096365a5833de6d Mon Sep 17 00:00:00 2001 From: vit9696 Date: Sun, 6 Mar 2022 16:33:34 +0300 Subject: [PATCH] OcCryptoLib: Use caller-provided scratch in BigNumCalculateMontParams --- Library/OcCryptoLib/BigNumLib.h | 12 +++++++++++- Library/OcCryptoLib/BigNumMontgomery.c | 16 ++-------------- Library/OcCryptoLib/RsaDigitalSign.c | 8 ++++++-- Utilities/TestRsaPreprocess/RsaPreprocess.c | 11 +++++++++-- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/Library/OcCryptoLib/BigNumLib.h b/Library/OcCryptoLib/BigNumLib.h index 10a113bf..3d593a1b 100644 --- a/Library/OcCryptoLib/BigNumLib.h +++ b/Library/OcCryptoLib/BigNumLib.h @@ -76,12 +76,21 @@ BigNumSwapWord ( // Montgomery arithmetics // +/** + 1 + 2 * NumWords for RSqr, and then twice more than that for Mod. + + @param[in] NumWords The number of Words of RSqrMod and N. +**/ +#define BIG_NUM_MONT_PARAMS_SCRATCH_SIZE(NumWords) \ + ((1 + 2 * NumWords) * 3 * OC_BN_WORD_SIZE) + /** Calculates the Montgomery Inverse and R^2 mod N. @param[in,out] RSqrMod The buffer to return R^2 mod N into. @param[in] NumWords The number of Words of RSqrMod and N. @param[in] N The Montgomery Modulus. + @param[in] Scratch Scratch buffer BIG_NUM_MONT_PARAMS_SCRATCH_SIZE(NumWords). @returns The Montgomery Inverse of N. @@ -90,7 +99,8 @@ OC_BN_WORD BigNumCalculateMontParams ( IN OUT OC_BN_WORD *RSqrMod, IN OC_BN_NUM_WORDS NumWords, - IN CONST OC_BN_WORD *N + IN CONST OC_BN_WORD *N, + IN OC_BN_WORD *Scratch ); /** diff --git a/Library/OcCryptoLib/BigNumMontgomery.c b/Library/OcCryptoLib/BigNumMontgomery.c index d9bb9fb7..7dc331de 100644 --- a/Library/OcCryptoLib/BigNumMontgomery.c +++ b/Library/OcCryptoLib/BigNumMontgomery.c @@ -154,15 +154,14 @@ OC_BN_WORD BigNumCalculateMontParams ( IN OUT OC_BN_WORD *RSqrMod, IN OC_BN_NUM_WORDS NumWords, - IN CONST OC_BN_WORD *N + IN CONST OC_BN_WORD *N, + IN OC_BN_WORD *Scratch ) { OC_BN_WORD N0Inv; UINT32 NumBits; - UINTN SizeScratch; OC_BN_NUM_WORDS NumWordsRSqr; OC_BN_NUM_WORDS NumWordsMod; - OC_BN_WORD *Scratch; OC_BN_WORD *RSqr; ASSERT (RSqrMod != NULL); @@ -191,15 +190,6 @@ BigNumCalculateMontParams ( // NumWordsRSqr = (OC_BN_NUM_WORDS)(1 + 2 * NumWords); NumWordsMod = 2 * NumWordsRSqr; - SizeScratch = (NumWordsRSqr + NumWordsMod) * OC_BN_WORD_SIZE; - if (SizeScratch > OC_BN_MAX_SIZE) { - return 0; - } - - Scratch = AllocatePool (SizeScratch); - if (Scratch == NULL) { - return 0; - } RSqr = Scratch + NumWordsMod; @@ -214,8 +204,6 @@ BigNumCalculateMontParams ( BigNumMod (RSqrMod, NumWords, RSqr, NumWordsRSqr, N, Scratch); - FreePool (Scratch); - return N0Inv; } diff --git a/Library/OcCryptoLib/RsaDigitalSign.c b/Library/OcCryptoLib/RsaDigitalSign.c index 3ba4da20..6c63537d 100644 --- a/Library/OcCryptoLib/RsaDigitalSign.c +++ b/Library/OcCryptoLib/RsaDigitalSign.c @@ -502,6 +502,7 @@ RsaVerifySigDataFromData ( OC_BN_NUM_WORDS ModulusNumWords; VOID *Memory; + VOID *Scratch; OC_BN_WORD *N; OC_BN_WORD *RSqrMod; @@ -529,17 +530,20 @@ RsaVerifySigDataFromData ( "An overflow verification must be added" ); - Memory = AllocatePool (2 * ModulusSize); + Memory = AllocatePool ( + 2 * ModulusSize + BIG_NUM_MONT_PARAMS_SCRATCH_SIZE (ModulusNumWords) + ); if (Memory == NULL) { return FALSE; } N = (OC_BN_WORD *)Memory; RSqrMod = (OC_BN_WORD *)((UINTN)N + ModulusSize); + Scratch = (UINT8 *)Memory + 2 * ModulusSize; BigNumParseBuffer (N, ModulusNumWords, Modulus, ModulusSize); - N0Inv = BigNumCalculateMontParams (RSqrMod, ModulusNumWords, N); + N0Inv = BigNumCalculateMontParams (RSqrMod, ModulusNumWords, N, Scratch); if (N0Inv == 0) { FreePool (Memory); return FALSE; diff --git a/Utilities/TestRsaPreprocess/RsaPreprocess.c b/Utilities/TestRsaPreprocess/RsaPreprocess.c index f139497a..5b538a6e 100644 --- a/Utilities/TestRsaPreprocess/RsaPreprocess.c +++ b/Utilities/TestRsaPreprocess/RsaPreprocess.c @@ -32,15 +32,21 @@ int verifyRsa (CONST OC_RSA_PUBLIC_KEY *PublicKey, char *Name) UINTN ModulusSize = PublicKey->Hdr.NumQwords * sizeof (UINT64); OC_BN_WORD *RSqrMod = malloc(ModulusSize); - if (RSqrMod == NULL) { + OC_BN_WORD *Scratch = malloc( + BIG_NUM_MONT_PARAMS_SCRATCH_SIZE(ModulusSize / OC_BN_WORD_SIZE) + ); + if (RSqrMod == NULL || Scratch == NULL) { printf ("memory allocation error!\n"); + free(RSqrMod); + free(Scratch); return -1; } N0Inv = BigNumCalculateMontParams ( RSqrMod, ModulusSize / OC_BN_WORD_SIZE, - (CONST OC_BN_WORD *) PublicKey->Data + (CONST OC_BN_WORD *) PublicKey->Data, + Scratch ); printf ( @@ -56,6 +62,7 @@ int verifyRsa (CONST OC_RSA_PUBLIC_KEY *PublicKey, char *Name) (unsigned long long) PublicKey->Hdr.N0Inv ); + free(Scratch); free(RSqrMod); return 0; }