重新学习c++一周了,写个MD5练练手。

参考资料:

  1. 在线MD5
  2. 伪代码
  3. RFC中的C实现
  4. RFC 1321

基础实现

由于C++的流操作还不熟悉,没有实现读取任意长度输入的功能,目前是把输入整个读进来,存在一个字符串里。

md5.hpp

实现和定义都在头文件里,放在namespace MD5中。

类型别名

计算MD5的过程中需要操作字节和无符号整数,用到两种类型:

  1. 表示单个字节的类型,用于位运算
  2. 表示四个字节的类型,用作无符号整数计算和位运算

using起了两个别名:

using uint1 = unsigned char;
using uint = unsigned int;

给用户提供的函数

md5的输入是一个字节流,目前没有实现字节流,将输入读入一个字符串中,然后构建字节数组,作为算法的输入。 md5算法描述中说的是每次处理512个bit,也就是64byte的数据,代码中用了一个循环,每次处理输入的字节数组中64个元素。 md5的输出是128bit的数据,记录在state变量中,RFC中的初始值为

          word A: 01 23 45 67
          word B: 89 ab cd ef
          word C: fe dc ba 98
          word D: 76 54 32 10

但是需要小端序,所以初始化时应该这么写:

    std::vector<uint> state{
        0x67452301,
        0xefcdab89,
        0x98badcfe,
        0x10325476,
    };

完整的hex_digest函数:

std::string hex_digest(const std::string &input_str) {
    // 输入
    std::vector<uint1> msg{input_str.begin(), input_str.end()};
    // 补全长度
    pad_msg(msg);
    
    // 初始化
    std::vector<uint> state{
        0x67452301,
        0xefcdab89,
        0x98badcfe,
        0x10325476,
    };

    // 每次处理 64 byte = 512 bit = 16 word(in RFC, 4 byte = 1 word)
    for (auto i = 0; i < msg.size() - 63; i += 64) {
        // TODO do not copy
        // O(64)
        std::vector<uint1> temp_block{msg.begin() + i, msg.begin() + i + 64};
        process_block(temp_block, state);
    }

    // 用字符串表示结果
    return assemble_result(state);
}

实现细节

pad_msg

如果输入的字节数不是64的整数倍,补全到64的整数倍:

  1. 在消息后面加上一个bit,1,然后后面都是0,(换成字节表示是:先加上0x80,然后一直加0x00)直到按字节算的长度对64取余结果为56
  2. 将原本的消息长度表示为64bit(8byte)的无符号整数,用小端序字节序加到消息后面,注意,长度是按bit算的。

代码:

void pad_msg(std::vector<uint1> &msg) {
    // add 0x80 to vector, then add 0x00 until msg.length() mod 64 = 56
    if (msg.size() % 64 == 0 && !msg.empty()) { // no need to pad
        return;
    }
    // assuming full byte data
    auto original_size = msg.size() * 8;
    auto byte_num = sizeof original_size;
    // assert byte_num = 8;
    std::vector<uint1> size_vec(byte_num, 0);
    for (size_t i = 0; i < byte_num; i++) {
        size_vec[i] = (original_size >> (i * 8)) & 0xFF;
    }

    msg.push_back(0x80);
    while (msg.size() % 64 != 56) {
        msg.push_back(0);
    }
    //    std::cout << "original_size = " << original_size << '\n';
    msg.insert(msg.end(), size_vec.begin(), size_vec.end());
}

process_block

对一个64byte的块进行操作,更新state。实现时需要注意,算法中此处的单位是word,也就是四个字节为单位,作为无符号整数计算。 但是为了方便,我直接传了字节数组进去。所以在循环内部需要转换。一开始没注意到,debug两个小时才看出来。

void process_block(const std::vector<uint1> &block, std::vector<uint> &state) {
    // process 64 bytes (512 bits) chunk, update _state

    uint a = state[0];
    uint b = state[1];
    uint c = state[2];
    uint d = state[3];

    for (uint i = 0; i < 64; i++) {
        uint F, g;

        if (i < 16) {
            F = (b & c) | ((~b) & d);
            g = i;
        } else if (i < 32) {
            F = (d & b) | ((~d) & c);
            g = (5u * i + 1) % 16u;
        } else if (i < 48) {
            F = b ^ c ^ d;
            g = (3u * i + 5) % 16;
        } else {
            F = c ^ (b | (~d));
            g = (7 * i) % 16;
        }
        // 拼出uint
        uint block_g{};
        for (size_t j = 0; j < 4; j++) {
            block_g += block[g * 4 + j] << (j * 8);
        }

        F = F + a + K[i] + block_g;
        a = d;
        d = c;
        c = b;
        b = b + left_rotate(F, s[i]);
    }

    state[0] += a;
    state[1] += b;
    state[2] += c;
    state[3] += d;
}

assemble_result

MD5的输出是一个128bit的指纹,通常表示为16进制字符串,如,空输入的输出为d41d8cd98f00b204e9800998ecf8427e。 RFC中规定了,输出从A的low-order byte开始,以D的high-order byte结束。所以,windows和linux系统直接按内存中的 顺序翻一下再转换成16进制字符串表示就可以了,注意需要补上0

std::string assemble_result(const std::vector<uint> &result) {
    std::stringstream ss;
    for (auto u : result) {
        while (u != 0) {
            uint b = u & 0xFF;
            ss << std::setfill('0') << std::setw(2) << std::hex << b;
            u = u >> 8;
        }
    }
    return ss.str();
}