From e5f8c0fc463a31b67f4f6c675cbff6a330825c3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 5 Apr 2024 14:00:35 +0900 Subject: [PATCH] fix(file): check sum --- base1464.c | 33 --------------------- binary.h | 55 ++++++++++------------------------ file.c | 86 +++++++++++++++++++++++++++++++++--------------------- 3 files changed, 68 insertions(+), 106 deletions(-) diff --git a/base1464.c b/base1464.c index 23d3f69..1fbe6e4 100644 --- a/base1464.c +++ b/base1464.c @@ -39,9 +39,6 @@ int base16384_encode_safe(const char* data, int dlen, char* buf) { case 6: outlen += 10; break; default: break; } - #ifdef DEBUG - printf("outlen: %llu, offset: %u, malloc: %llu\n", outlen, offset, outlen + 8); - #endif uint64_t* vals = (uint64_t*)buf; uint64_t n = 0; int64_t i = 0; @@ -57,9 +54,6 @@ int base16384_encode_safe(const char* data, int dlen, char* buf) { sum |= shift & 0x0000000000003fff; sum += 0x4e004e004e004e00; vals[n++] = be64toh(sum); - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif } remainder valbuf; if(dlen - i == 7) { @@ -108,9 +102,6 @@ int base16384_encode_safe(const char* data, int dlen, char* buf) { valbuf.val = sum; #endif memcpy(&vals[n], valbuf.buf, outlen-2-(int)n*(int)sizeof(uint64_t)); - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif buf[outlen - 2] = '='; buf[outlen - 1] = offset; } @@ -130,9 +121,6 @@ int base16384_encode(const char* data, int dlen, char* buf) { case 6: outlen += 10; break; default: break; } - #ifdef DEBUG - printf("outlen: %llu, offset: %u, malloc: %llu\n", outlen, offset, outlen + 8); - #endif uint64_t* vals = (uint64_t*)buf; uint64_t n = 0; int64_t i = 0; @@ -148,9 +136,6 @@ int base16384_encode(const char* data, int dlen, char* buf) { sum |= shift & 0x0000000000003fff; sum += 0x4e004e004e004e00; vals[n++] = be64toh(sum); - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif } int o = offset; if(o--) { @@ -182,9 +167,6 @@ int base16384_encode(const char* data, int dlen, char* buf) { #else vals[n] = sum; #endif - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif buf[outlen - 2] = '='; buf[outlen - 1] = offset; } @@ -204,9 +186,6 @@ int base16384_encode_unsafe(const char* data, int dlen, char* buf) { case 6: outlen += 10; break; default: break; } - #ifdef DEBUG - printf("outlen: %llu, offset: %u, malloc: %llu\n", outlen, offset, outlen + 8); - #endif uint64_t* vals = (uint64_t*)buf; uint64_t n = 0; int64_t i = 0; @@ -222,9 +201,6 @@ int base16384_encode_unsafe(const char* data, int dlen, char* buf) { sum |= shift & 0x0000000000003fff; sum += 0x4e004e004e004e00; vals[n++] = be64toh(sum); - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif } if(offset) { buf[outlen - 2] = '='; @@ -265,9 +241,6 @@ int base16384_decode_safe(const char* data, int dlen, char* buf) { shift <<= 2; sum |= shift & 0x00000000003fff00; *(uint64_t*)(buf+i) = be64toh(sum); - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif } remainder valbuf; if(outlen - i == 7) { @@ -345,9 +318,6 @@ int base16384_decode(const char* data, int dlen, char* buf) { shift <<= 2; sum |= shift & 0x00000000003fff00; *(uint64_t*)(buf+i) = be64toh(sum); - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif } if(offset--) { // 这里有读取越界 @@ -411,9 +381,6 @@ int base16384_decode_unsafe(const char* data, int dlen, char* buf) { shift <<= 2; sum |= shift & 0x00000000003fff00; *(uint64_t*)(buf+i) = be64toh(sum); - #ifdef DEBUG - printf("i: %llu, add sum: %016llx\n", i, sum); - #endif } register uint64_t sum = 0; register uint64_t shift = htobe64(vals[n]); diff --git a/binary.h b/binary.h index 1feaef5..43b48ca 100644 --- a/binary.h +++ b/binary.h @@ -87,33 +87,15 @@ // leftrotate function definition #define LEFTROTATE(x, c) (((x) << (c)) | ((x) >> (sizeof(x)*8 - (c)))) -static inline uint32_t calc_sum(uint32_t sum, size_t cnt, char* encbuf) { - uint32_t i; - #ifdef DEBUG - fprintf(stderr, "cnt: %zu, roundin: %08x, ", cnt, sum); - #endif - for(i = 0; i < cnt/sizeof(sum); i++) { - #ifdef DEBUG - if (!i) { - fprintf(stderr, "firstval: %08x, ", htobe32(((uint32_t*)encbuf)[i])); - } - #endif - sum += ~LEFTROTATE(htobe32(((uint32_t*)encbuf)[i]), encbuf[i*sizeof(sum)]%(8*sizeof(sum))); +static inline uint32_t calc_sum(uint32_t sum, size_t cnt, const char* encbuf) { + size_t i; + uint32_t buf; + for(i = 0; i < cnt; i++) { + buf = (uint32_t)(encbuf[i])&0xff; + buf = ((buf<<(24-6))&0x03000000) | ((buf<<(16-4))&0x00030000) | ((buf<<(8-2))&0x00000300) | (buf&0x03); + sum += buf; + sum = ~LEFTROTATE(sum, 3); } - #ifdef DEBUG - fprintf(stderr, "roundmid: %08x", sum); - #endif - size_t rem = cnt % sizeof(sum); - if(rem) { - uint32_t x = htobe32(((uint32_t*)encbuf)[i]) & (0xffffffff << (8*(sizeof(sum)-rem))); - sum += ~LEFTROTATE(x, encbuf[i*sizeof(sum)]%(8*sizeof(sum))); - #ifdef DEBUG - fprintf(stderr, ", roundrem:%08x\n", sum); - #endif - } - #ifdef DEBUG - else fprintf(stderr, "\n"); - #endif return sum; } @@ -125,19 +107,14 @@ static inline uint32_t calc_and_embed_sum(uint32_t sum, size_t cnt, char* encbuf return sum; } -static inline int calc_and_check_sum(uint32_t* s, size_t cnt, char* encbuf) { - uint32_t sum = calc_sum(*s, cnt, encbuf); - if(cnt%7) { // is last decode block - int shift = (int[]){0, 26, 20, 28, 22, 30, 24}[cnt%7]; - uint32_t sum_read = be32toh((*(uint32_t*)(&encbuf[cnt]))) >> shift; - sum >>= shift; - #ifdef DEBUG - fprintf(stderr, "cntrm: %lu, mysum: %08x, sumrd: %08x\n", cnt%7, sum, sum_read); - #endif - return sum != sum_read; - } - *s = sum; - return 0; +static inline int check_sum(uint32_t sum, uint32_t sum_read_raw, int offset) { + int shift = (int[]){0, 26, 20, 28, 22, 30, 24}[offset%7]; + uint32_t sum_read = be32toh(sum_read_raw) >> shift; + sum >>= shift; + #ifdef DEBUG + fprintf(stderr, "offset: %d, mysum: %08x, sumrd: %08x\n", offset, sum, sum_read); + #endif + return sum != sum_read; } #endif diff --git a/file.c b/file.c index 87e6a5e..e950305 100644 --- a/file.c +++ b/file.c @@ -56,9 +56,7 @@ static inline off_t get_file_size(const char* filepath) { base16384_err_t base16384_encode_file_detailed(const char* input, const char* output, char* encbuf, char* decbuf, int flag) { off_t inputsize; - FILE* fp = NULL; - FILE* fpo; - uint32_t sum = BASE16384_SIMPLE_SUM_INIT_VALUE; + FILE *fp = NULL, *fpo; int errnobak = 0, is_stdin = is_standard_io(input); base16384_err_t retval = base16384_err_ok; if(!input || !output || strlen(input) <= 0 || strlen(output) <= 0) { @@ -91,11 +89,8 @@ base16384_err_t base16384_encode_file_detailed(const char* input, const char* ou fputc(0xFE, fpo); fputc(0xFF, fpo); } - #ifdef DEBUG - inputsize = 917504; - fprintf(stderr, "inputsize: %lld\n", inputsize); - #endif size_t cnt; + uint32_t sum = BASE16384_SIMPLE_SUM_INIT_VALUE; while((cnt = fread(encbuf, sizeof(char), inputsize, fp)) > 0) { int n; while(cnt%7) { @@ -252,10 +247,8 @@ base16384_err_t base16384_decode_file_detailed(const char* input, const char* ou if(errno) { goto_base16384_file_detailed_cleanup(decode, base16384_err_read_file, {}); } - #ifdef DEBUG - fprintf(stderr, "inputsize: %lld\n", inputsize); - #endif - int cnt; + int cnt, last_encbuf_cnt = 0, last_decbuf_cnt = 0, offset = 0; + size_t total_decoded_len = 0; while((cnt = fread(decbuf, sizeof(char), inputsize, fp)) > 0) { int n; while(cnt%8) { @@ -269,16 +262,23 @@ base16384_err_t base16384_decode_file_detailed(const char* input, const char* ou decbuf[cnt++] = end; } if(errno) goto_base16384_file_detailed_cleanup(decode, base16384_err_read_file, {}); + offset = decbuf[cnt-1]; + last_decbuf_cnt = cnt; cnt = base16384_decode_unsafe(decbuf, cnt, encbuf); if(cnt && fwrite(encbuf, cnt, 1, fpo) <= 0) { goto_base16384_file_detailed_cleanup(decode, base16384_err_write_file, {}); } - if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN) { - if(calc_and_check_sum(&sum, cnt, encbuf)) { - errno = EINVAL; - goto_base16384_file_detailed_cleanup(decode, base16384_err_invalid_decoding_checksum, {}); - } - } + total_decoded_len += cnt; + if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN) sum = calc_sum(sum, cnt, encbuf); + last_encbuf_cnt = cnt; + } + if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN + && total_decoded_len >= _BASE16384_ENCBUFSZ + && last_decbuf_cnt > 2 + && decbuf[last_decbuf_cnt-2] == '=' + && check_sum(sum, *(uint32_t*)(&encbuf[last_encbuf_cnt]), offset)) { + errno = EINVAL; + goto_base16384_file_detailed_cleanup(decode, base16384_err_invalid_decoding_checksum, {}); } #if !defined _WIN32 && !defined __cosmopolitan } else { // small file, use mmap & fwrite @@ -324,7 +324,8 @@ base16384_err_t base16384_decode_fp_detailed(FILE* input, FILE* output, char* en if(errno) { return base16384_err_read_file; } - int cnt; + int cnt, last_encbuf_cnt = 0, last_decbuf_cnt = 0, offset = 0; + size_t total_decoded_len = 0; while((cnt = fread(decbuf, sizeof(char), inputsize, input)) > 0) { int n; while(cnt%8) { @@ -338,16 +339,23 @@ base16384_err_t base16384_decode_fp_detailed(FILE* input, FILE* output, char* en decbuf[cnt++] = end; } if(errno) return base16384_err_read_file; + offset = decbuf[cnt-1]; + last_decbuf_cnt = cnt; cnt = base16384_decode_unsafe(decbuf, cnt, encbuf); if(cnt && fwrite(encbuf, cnt, 1, output) <= 0) { return base16384_err_write_file; } - if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN) { - if (calc_and_check_sum(&sum, cnt, encbuf)) { - errno = EINVAL; - return base16384_err_invalid_decoding_checksum; - } - } + total_decoded_len += cnt; + if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN) sum = calc_sum(sum, cnt, encbuf); + last_encbuf_cnt = cnt; + } + if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN + && total_decoded_len >= _BASE16384_ENCBUFSZ + && last_decbuf_cnt > 2 + && decbuf[last_decbuf_cnt-2] == '=' + && check_sum(sum, *(uint32_t*)(&encbuf[last_encbuf_cnt]), offset)) { + errno = EINVAL; + return base16384_err_invalid_decoding_checksum; } return base16384_err_ok; } @@ -373,15 +381,21 @@ base16384_err_t base16384_decode_fd_detailed(int input, int output, char* encbuf errno = EINVAL; return base16384_err_fopen_output_file; } + off_t inputsize = _BASE16384_DECBUFSZ; - int p = 0, n; uint32_t sum = BASE16384_SIMPLE_SUM_INIT_VALUE; uint8_t remains[8]; + decbuf[0] = 0; if(read(input, remains, 2) != 2) { return base16384_err_read_file; } + + int p = 0; if(remains[0] != (uint8_t)(0xfe)) p = 2; + + int n, last_encbuf_cnt = 0, last_decbuf_cnt = 0, offset = 0; + size_t total_decoded_len = 0; while((n = read(input, decbuf+p, inputsize-p)) > 0) { if(p) { memcpy(decbuf, remains, p); @@ -404,19 +418,23 @@ base16384_err_t base16384_decode_fd_detailed(int input, int output, char* encbuf decbuf[n++] = (char)(next&0x00ff); } else remains[p++] = (char)(next&0x00ff); } - #ifdef DEBUG - fprintf(stderr, "decode chunk: %d, last2: %c %02x\n", cnt, decbuf[cnt-2], (uint8_t)decbuf[cnt-1]); - #endif + offset = decbuf[n-1]; + last_decbuf_cnt = n; n = base16384_decode_unsafe(decbuf, n, encbuf); if(n && write(output, encbuf, n) != n) { return base16384_err_write_file; } - if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN) { - if (calc_and_check_sum(&sum, n, encbuf)) { - errno = EINVAL; - return base16384_err_invalid_decoding_checksum; - } - } + total_decoded_len += n; + if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN) sum = calc_sum(sum, n, encbuf); + last_encbuf_cnt = n; + } + if(flag&BASE16384_FLAG_SUM_CHECK_ON_REMAIN + && total_decoded_len >= _BASE16384_ENCBUFSZ + && last_decbuf_cnt > 2 + && decbuf[last_decbuf_cnt-2] == '=' + && check_sum(sum, *(uint32_t*)(&encbuf[last_encbuf_cnt]), offset)) { + errno = EINVAL; + return base16384_err_invalid_decoding_checksum; } return base16384_err_ok; }