feat: square and multiply, a^e mod n

This commit is contained in:
gbrochar 2024-02-16 13:32:06 +01:00
parent 5028b0dd9f
commit 3883287e8d
3 changed files with 159 additions and 12 deletions

View File

@ -55,7 +55,6 @@ int bigint_cmp(bigint_t a, bigint_t b) {
uint32_t size = sizeof(uint32_t) * 8;
uint32_t acursor = size * a.len - 1;
uint32_t bcursor = size * b.len - 1;
printf("cursors a and b %d %d\n", acursor, bcursor);
while (acursor > bcursor) {
if (a.data[acursor / size] & (1 << acursor % size)) {
return 1;
@ -84,6 +83,7 @@ int bigint_cmp(bigint_t a, bigint_t b) {
return 0;
}
// TODO check opti
bigint_t assignable_bigint_substraction(bigint_t a, bigint_t b) {
if (a.len != b.len) {
printf("error: attempting to substract numbers of different length\n");
@ -108,7 +108,13 @@ bigint_t assignable_bigint_substraction(bigint_t a, bigint_t b) {
return result;
}
void bigint_substraction(bigint_t a, bigint_t b) {
// TODO check opti
void bigint_substraction(bigint_t a, bigint_t bb) {
bigint_t b = bigint_clone(bb);
if (a.len > bb.len) {
b = bigint_zero(a.len);
memcpy(b.data, bb.data, b.len * sizeof(uint32_t));
}
if (a.len != b.len) {
printf("error: attempting to substract numbers of different length\n");
exit(1);
@ -130,35 +136,154 @@ void bigint_substraction(bigint_t a, bigint_t b) {
bigint_destroy(zero);
}
// TODO check opti
bigint_t assignable_bigint_modulo(bigint_t a, bigint_t b) {
bigint_t result = bigint_clone(a);
bigint_t mod = bigint_clone(b);
printf("a = %ud\nb = %ud\nresult = %ud\nmod = %ud\n", a.data[0], b.data[0], result.data[0], mod.data[0]);
if (a.len > b.len) {
mod = bigint_zero(a.len);
memcpy(mod.data, b.data, b.len * sizeof(uint32_t));
}
if (bigint_cmp(result, b) == -1) {
bigint_destroy(mod);
return result;
}
bigint_bitwise_left_shift(mod);
printf("after bitwise_shift\na = %ud\nb = %ud\nresult = %ud\nmod = %ud\n", a.data[0], b.data[0], result.data[0], mod.data[0]);
while (bigint_cmp(b, mod) == -1) {
while (bigint_cmp(result, mod) == 1) {
bigint_bitwise_left_shift(mod);
printf("DOUBLE after bitwise_shift\na = %ud\nb = %ud\nresult = %ud\nmod = %ud\n", a.data[0], b.data[0], result.data[0], mod.data[0]);
}
bigint_bitwise_right_shift(mod);
printf("before sub \na = %ud\nb = %ud\nresult = %ud\nmod = %ud\n", a.data[0], b.data[0], result.data[0], mod.data[0]);
if (bigint_cmp(result, mod) == 1) {
bigint_substraction(result, mod);
printf("subbed\na = %ud\nb = %ud\nresult = %ud\nmod = %ud\n", a.data[0], b.data[0], result.data[0], mod.data[0]);
bigint_substraction(result, mod);
}
bigint_bitwise_right_shift(mod);
}
while (bigint_cmp(result, b) == 1) {
while (bigint_cmp(result, b) == 1) {
bigint_substraction(result, b);
}
bigint_destroy(mod);
return result;
}
printf("subbed\na = %ud\nb = %ud\nresult = %ud\nmod = %ud\n", a.data[0], b.data[0], result.data[0], mod.data[0]);
void bigint_add(bigint_t a, bigint_t b) {
bigint_t result = bigint_zero(a.len);
size_t size = sizeof(uint32_t) * 8;
size_t width = a.len * size;
uint32_t carriage = 0;
// printf("hello add\n");
for (size_t cursor = 0; cursor < width; cursor++) {
// printf("hahaha %ld %ld\n", cursor, width);
uint32_t a_bit = a.data[cursor / size] >> (cursor % size) & 1;
uint32_t b_bit = b.data[cursor / size] >> (cursor % size) & 1;
result.data[cursor / size] |= (a_bit ^ b_bit ^ carriage) << (cursor % size);
carriage = (a_bit & b_bit) | ((a_bit ^ b_bit) & carriage);
}
// printf("im out\n");
bigint_destroy(a);
a = bigint_clone(result);
bigint_destroy(result);
}
void bigint_set_zeros(bigint_t n) {
// printf("hello set zeros\n");
for (size_t i = 0; i < n.len; i++) {
n.data[i] = 0;
}
// printf("goodbye set zeros\n");
}
bigint_t assignable_bigint_mul(bigint_t a, bigint_t b) {
bigint_t result = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 4);
bigint_t b_tool = bigint_zero(RSA_BLOCK_SIZE / 8 / sizeof(uint32_t) * 4);
/*if (a.len > b.len) {
result = bigint_zero(a.len);
b_tool = bigint_zero(a.len);
} else {
result = bigint_zero(a.len + b.len);
b_tool = bigint_zero(a.len + b.len);
}*/
size_t size = sizeof(uint32_t) * 8;
size_t width = a.len * size;
printf("multiplying %d and %d\n", a.data[0], b.data[0]);
// printf("hello mul\n");
for (size_t cursor = 0; cursor < width; cursor++) {
// printf("hello BIG LOOP ls %ld %ld\n", cursor, width);
if (a.data[cursor / 32] >> (cursor % 32) & 1) {
bigint_set_zeros(b_tool);
printf("bef %d\n", b_tool.data[0]);
// printf("hello memcpy\n");
memcpy(b_tool.data, b.data, b.len * sizeof(uint32_t));
printf("aft %d\n", b_tool.data[0]);
// printf("goodbye memcpy\n");
for (size_t i = 0; i < cursor; i++) {
// printf("hello bitwise ls %ld %ld\n", i, cursor);
bigint_bitwise_left_shift(b_tool);
// printf("goodbye bitwise ls\n");
}
// printf("before hello add\n");
bigint_add(result, b_tool);
}
}
// printf("GOODBYE BIG LOOP ls \n");
bigint_destroy(b_tool);
return result;
}
// a^e mod n
// clean memory tricks !!!
bigint_t assignable_bigint_pow_mod(bigint_t a, bigint_t e, bigint_t n) {
printf("print a\n");
bigint_print(a);
printf("print e\n");
bigint_print(e);
printf("print n\n");
bigint_print(n);
bigint_t result = bigint_clone(a);
size_t size = sizeof(uint32_t) * 8;
int cursor = e.len * size - 1;
while (!(e.data[cursor / 32] & 1 << (cursor % 32))) {
cursor--;
}
cursor--;
/*
printf("SQUARE\n");
bigint_t tmp_result2 = assignable_bigint_mul(result, result);
bigint_destroy(result);
result = bigint_clone(tmp_result2);
bigint_destroy(tmp_result2);
tmp_result2 = assignable_bigint_modulo(result, n);
bigint_destroy(result);
result = bigint_clone(tmp_result2);
bigint_destroy(tmp_result2);
*/
printf("cursor %d\n", cursor);
while (cursor >= 0) {
printf("SQUARE\n");
bigint_t tmp_result2 = assignable_bigint_mul(result, result);
bigint_destroy(result);
result = bigint_clone(tmp_result2);
bigint_destroy(tmp_result2);
tmp_result2 = assignable_bigint_modulo(result, n);
bigint_destroy(result);
result = bigint_clone(tmp_result2);
bigint_destroy(tmp_result2);
if (e.data[cursor / 32] & 1 << (cursor % 32)) {
printf("MULTIPLY\n");
bigint_t tmp_result = assignable_bigint_mul(result, a);
bigint_destroy(result);
result = bigint_clone(tmp_result);
bigint_destroy(tmp_result);
tmp_result = assignable_bigint_modulo(result, n);
bigint_destroy(result);
result = bigint_clone(tmp_result);
bigint_destroy(tmp_result);
}
cursor -= 1;
}
printf("this time its over\n");
return result;
}

View File

@ -80,5 +80,24 @@ rsa_t rsa_generate_keys(size_t block_size) {
printf("result is %ud\n", result.data[0]);
a = bigint_clone(result);
b.data[0] = 5764;
printf("length\na: %lu e: %lu n: %lu\n", result.len, a.len, b.len);
bigint_t result2 = assignable_bigint_pow_mod(result, a, b);
printf("bigpowmod is %u \n", result2.data[0]);
/* result.data[0] = 8;
a.data[0] = 4;
result2 = assignable_bigint_mul(result, a);
printf("result2 is %u \n", result2.data[0]);
result.data[0] = 84;
a.data[0] = 463;
result2 = assignable_bigint_mul(result, a);
printf("result2 is %u \n", result2.data[0]);
bigint_add(result, a);
printf("result2 is %u \n", result.data[0]);
*/
return rsa;
}

View File

@ -38,7 +38,10 @@ void bigint_print(bigint_t n);
bigint_t bigint_new(size_t len);
bigint_t bigint_zero(size_t len);
bigint_t bigint_clone(bigint_t src);
void bigint_add(bigint_t a, bigint_t b);
bigint_t assignable_bigint_mul(bigint_t a, bigint_t b);
bigint_t assignable_bigint_modulo(bigint_t a, bigint_t b);
bigint_t assignable_bigint_pow_mod(bigint_t a, bigint_t e, bigint_t n);
void bigint_destroy(bigint_t n);