opti: cache left_shifts in multiplication

This commit is contained in:
gbrochar 2024-02-18 14:16:28 +01:00
parent fa5c3d7f96
commit c8e5e6bf67
2 changed files with 36 additions and 33 deletions

View File

@ -161,16 +161,17 @@ void bigint_destroy(bigint_t n) {
n.data = NULL;
}
void custom_bigint_add(bigint_t a, bigint_t b, bigint_t result) {
bigint_set_zeros(result);
uint32_t carriage = 0;
void custom_bigint_add(bigint_t a, bigint_t b, int index) {
uint64_t carriage = 0;
for (size_t cursor = 0; cursor < a.len; cursor++) {
uint64_t tmp = (uint64_t)a.data[cursor] + (uint64_t)b.data[cursor] + carriage;
memcpy(result.data + cursor, &tmp, sizeof(uint32_t));
uint64_t tmp = (uint64_t)a.data[cursor] + carriage;
if ((int)cursor - index >= 0) {
tmp += (uint64_t)b.data[cursor - index];
}
a.data[cursor] = (uint32_t)tmp;
carriage = tmp >> 32;
}
memcpy(a.data, result.data, a.len * sizeof(uint32_t));
}
void bigint_set_zeros(bigint_t n) {
@ -179,31 +180,37 @@ void bigint_set_zeros(bigint_t n) {
}
}
void custom_bigint_mul(bigint_t a, bigint_t b, bigint_t result, bigint_t custom) {
//bigint_t b_tool = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 4);
bigint_t b_tool = bigint_zero(a.len + b.len);
//memcpy(b_tool.data + (cursor >> 5), b.data, b.len * sizeof(uint32_t));
size_t size = sizeof(uint32_t) * 8;
int width = a.len * size;
void custom_bigint_mul(bigint_t a, bigint_t b, bigint_t result) {
int width = a.len * 32;
bigint_set_zeros(result);
bigint_t *b_tool = (bigint_t *)malloc(32 * sizeof(bigint_t));
b_tool[0] = bigint_zero(a.len + b.len);
bigint_set_zeros(b_tool[0]);
memcpy(b_tool[0].data, b.data, b.len * sizeof(uint32_t));
for (int i = 1; i < 32; i++) {
b_tool[i] = bigint_zero(a.len + b.len);
bigint_set_zeros(b_tool[i]);
memcpy(b_tool[i].data, b_tool[i - 1].data, b.len * sizeof(uint32_t));
bigint_bitwise_left_shift(b_tool[i]);
}
for (int cursor = 0; cursor < width; cursor++) {
if (a.data[cursor >> 5] >> (cursor % 32) & 1) {
bigint_set_zeros(b_tool);
memcpy(b_tool.data + (cursor >> 5), b.data, b.len * sizeof(uint32_t));
int i = cursor - cursor % 32;
while (i < cursor) {
bigint_bitwise_left_shift(b_tool);
i++;
}
custom_bigint_add(result, b_tool, custom);
int offset = cursor % 32;
int index = cursor >> 5;
if (a.data[index] >> offset & 1) {
custom_bigint_add(result, b_tool[offset], index);
}
}
bigint_destroy(b_tool);
for (int i = 0; i < 32; i++) {
bigint_destroy(b_tool[i]);
}
}
// a^e mod n
// clean memory tricks !!!
void custom_bigint_pow_mod(bigint_t a, bigint_t e, bigint_t n, bigint_t result, bigint_t custom, bigint_t custom2, bigint_t custom3) {
void custom_bigint_pow_mod(bigint_t a, bigint_t e, bigint_t n, bigint_t result, bigint_t custom, bigint_t custom2) {
bigint_set_zeros(result);
bigint_set_zeros(custom);
bigint_set_zeros(custom2);
@ -215,12 +222,12 @@ void custom_bigint_pow_mod(bigint_t a, bigint_t e, bigint_t n, bigint_t result,
}
cursor--;
while (cursor >= 0) {
custom_bigint_mul(result, result, custom, custom3);
custom_bigint_mul(result, result, custom);
custom_bigint_modulo(custom, n, custom2);
bigint_set_zeros(result);
memcpy(result.data, custom2.data, custom2.len * sizeof(uint32_t));
if (e.data[cursor / 32] & 1 << (cursor % 32)) {
custom_bigint_mul(result, a, custom, custom3);
custom_bigint_mul(result, a, custom);
custom_bigint_modulo(custom, n, custom2);
memcpy(result.data, custom2.data, custom2.len * sizeof(uint32_t));
}
@ -269,7 +276,6 @@ bigint_t bigint_prime(size_t len) {
bigint_t y = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 2);
bigint_t custom = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 2);
bigint_t custom2 = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 2);
bigint_t custom3 = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 2);
bigint_t two = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 2);
bigint_t one = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 2);
@ -287,14 +293,13 @@ bigint_t bigint_prime(size_t len) {
while (bigint_cmp(a, two) < 0 || bigint_cmp(a, n_minus_two) > 0) {
bigint_set_random_bytes(a, len);
}
custom_bigint_pow_mod(a, d, n, x, custom, custom2, custom3);
custom_bigint_pow_mod(a, d, n, x, custom, custom2);
for (uint32_t i = 0; i < s; i++) {
custom_bigint_pow_mod(x, two, n, y, custom, custom2, custom3);
custom_bigint_pow_mod(x, two, n, y, custom, custom2);
if (!bigint_dif(y, one) && bigint_dif(x, one) && bigint_dif(x, n_minus_one)) {
bulk_destroy(x, y, n, d, two, one, n_minus_two, n_minus_one);
bigint_destroy(custom);
bigint_destroy(custom2);
bigint_destroy(custom3);
bigint_destroy(a);
return bigint_prime(len);
}
@ -305,14 +310,12 @@ bigint_t bigint_prime(size_t len) {
bulk_destroy(x, y, n, d, two, one, n_minus_two, n_minus_one);
bigint_destroy(custom);
bigint_destroy(custom2);
bigint_destroy(custom3);
bigint_destroy(a);
return bigint_prime(len);
}
}
bulk_destroy(x, y, custom, d, two, one, n_minus_two, n_minus_one);
bigint_destroy(custom2);
bigint_destroy(custom3);
bigint_destroy(a);
return n;
}

View File

@ -9,7 +9,7 @@
#include <unistd.h>
#include <string.h>
#define RSA_BLOCK_SIZE 128
#define RSA_BLOCK_SIZE 256
typedef struct bigint_s {
uint32_t *data;
@ -39,7 +39,7 @@ bigint_t bigint_new(size_t len);
bigint_t bigint_zero(size_t len);
bigint_t bigint_clone(bigint_t src);
void bigint_add(bigint_t a, bigint_t b);
void custom_bigint_add(bigint_t a, bigint_t b, bigint_t result);
void custom_bigint_add(bigint_t a, bigint_t b, int index);
bigint_t assignable_bigint_mul(bigint_t a, bigint_t b);
bigint_t assignable_bigint_modulo(bigint_t a, bigint_t b);
bigint_t assignable_bigint_pow_mod(bigint_t a, bigint_t e, bigint_t n);