diff --git a/lib/resty/aes.lua b/lib/resty/aes.lua index 377476f..4a61aeb 100644 --- a/lib/resty/aes.lua +++ b/lib/resty/aes.lua @@ -58,6 +58,7 @@ const EVP_CIPHER *EVP_aes_256_cfb1(void); const EVP_CIPHER *EVP_aes_256_cfb8(void); const EVP_CIPHER *EVP_aes_256_cfb128(void); const EVP_CIPHER *EVP_aes_256_ofb(void); +const EVP_CIPHER *EVP_aes_256_ctr(void); const EVP_CIPHER *EVP_aes_128_gcm(void); const EVP_CIPHER *EVP_aes_192_gcm(void); const EVP_CIPHER *EVP_aes_256_gcm(void); @@ -103,11 +104,14 @@ hash = { _M.hash = hash local EVP_MAX_BLOCK_LENGTH = 32 +local shared_out_len_ptr = ffi_new("int[1]") +local shared_tmp_len_ptr = ffi_new("int[1]") +local shared_tag_buf_ptr = ffi_new("unsigned char[?]", 16) local cipher cipher = function (size, _cipher) local _size = size or 128 - local _cipher = _cipher or "cbc" + _cipher = _cipher or "cbc" local func = "EVP_aes_" .. _size .. "_" .. _cipher if C[func] then return { size=_size, cipher=_cipher, method=C[func]()} @@ -132,12 +136,12 @@ function _M.new(self, key, salt, _cipher, _hash, hash_rounds, iv_len, enable_pad ffi_gc(decrypt_ctx, C.EVP_CIPHER_CTX_free) - local _cipher = _cipher or cipher() - local _hash = _hash or hash.md5 - local hash_rounds = hash_rounds or 1 + _cipher = _cipher or cipher() + _hash = _hash or hash.md5 + hash_rounds = hash_rounds or 1 local _cipherLength = _cipher.size/8 - local gen_key = ffi_new("unsigned char[?]",_cipherLength) - local gen_iv = ffi_new("unsigned char[?]",_cipherLength) + local gen_key = ffi_new("unsigned char[?]", _cipherLength) + local gen_iv = ffi_new("unsigned char[?]", _cipherLength) iv_len = iv_len or _cipherLength -- enable padding by default local padding = (enable_padding == nil or enable_padding) and 1 or 0 @@ -223,6 +227,21 @@ function _M.new(self, key, salt, _cipher, _hash, hash_rounds, iv_len, enable_pad }, mt) end +local function alloc_buf(max_len) + local buf = ffi_new("unsigned char[?]", max_len) + return buf, max_len +end + +do + local ok, str_buf_mod = pcall(require, "string.buffer") + if ok then + local str_buf = str_buf_mod.new(4096) + function alloc_buf(max_len) + local buf, sz = str_buf:reset():reserve(max_len) + return buf, sz + end + end +end function _M.encrypt(self, s, aad) local typ = type(self) @@ -231,10 +250,9 @@ function _M.encrypt(self, s, aad) end local s_len = #s - local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH - local buf = ffi_new("unsigned char[?]", max_len) - local out_len = ffi_new("int[1]") - local tmp_len = ffi_new("int[1]") + local buf = alloc_buf(s_len + 2 * EVP_MAX_BLOCK_LENGTH) + local out_len = shared_out_len_ptr + local tmp_len = shared_tmp_len_ptr local ctx = self._encrypt_ctx if C.EVP_EncryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then @@ -254,7 +272,7 @@ function _M.encrypt(self, s, aad) if self._cipher == "gcm" then local encrypt_data = ffi_str(buf, out_len[0]) if C.EVP_EncryptFinal_ex(ctx, buf, out_len) == 0 then - return nil, "EVP_DecryptFinal_ex failed" + return nil, "EVP_EncryptFinal_ex failed" end -- FIXME: For OCB mode the taglen must either be 16 @@ -280,14 +298,13 @@ function _M.decrypt(self, s, tag, aad) end local s_len = #s - local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH - local buf = ffi_new("unsigned char[?]", max_len) - local out_len = ffi_new("int[1]") - local tmp_len = ffi_new("int[1]") + local buf = alloc_buf(s_len + 2 * EVP_MAX_BLOCK_LENGTH) + local out_len = shared_out_len_ptr + local tmp_len = shared_tmp_len_ptr local ctx = self._decrypt_ctx if C.EVP_DecryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then - return nil, "EVP_DecryptInit_ex failed" + return nil, "EVP_DecryptInit_ex failed" end if self._cipher == "gcm" and aad ~= nil then @@ -297,14 +314,14 @@ function _M.decrypt(self, s, tag, aad) end if C.EVP_DecryptUpdate(ctx, buf, out_len, s, s_len) == 0 then - return nil, "EVP_DecryptUpdate failed" + return nil, "EVP_DecryptUpdate failed" end if self._cipher == "gcm" then local plain_txt = ffi_str(buf, out_len[0]) if tag ~= nil then - local tag_buf = ffi_new("unsigned char[?]", 16) - ffi.copy(tag_buf, tag, 16) + local tag_buf = shared_tag_buf_ptr + ffi_copy(tag_buf, tag, 16) C.EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, 16, tag_buf); end