summaryrefslogtreecommitdiff
path: root/hkdf.h
blob: 711f250cd8c8c1540253c91c232ee741328af2fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// hkdf.h - written and placed in public domain by Jeffrey Walton. Copyright assigned to Crypto++ project

#ifndef CRYPTOPP_HASH_KEY_DERIVATION_FUNCTION_H
#define CRYPTOPP_HASH_KEY_DERIVATION_FUNCTION_H

#include "cryptlib.h"
#include "hmac.h"
#include "hrtimer.h"
#include "secblock.h"

#include <cstring>

NAMESPACE_BEGIN(CryptoPP)

//! abstract base class for key derivation function
class KeyDerivationFunction
{
public:
	virtual size_t MaxDerivedKeyLength() const =0;
	virtual bool UsesContext() const =0;
	//! derive key from secret
	virtual unsigned int DeriveKey(byte *derived, size_t derivedLen, const byte *secret, size_t secretLen, const byte *salt, size_t saltLen, const byte* context=NULL, size_t contextLen=0) const =0;
	
	// If salt is missing, then use the NULL vector. The length depends on the Hash function.
	static const byte s_NullVector[64];
};

const byte KeyDerivationFunction::s_NullVector[64] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};

//! General, multipurpose KDF from RFC 5869. T should be a HashTransformation class
//! https://eprint.iacr.org/2010/264 and https://tools.ietf.org/html/rfc5869
template <class T>
class CRYPTOPP_DLL HKDF : public KeyDerivationFunction
{
public:
	size_t MaxDerivedKeyLength() const {return static_cast<size_t>(T::DIGESTSIZE) * 255;}
	bool UsesContext() const {return true;}
	unsigned int DeriveKey(byte *derived, size_t derivedLen, const byte *secret, size_t secretLen, const byte *salt, size_t saltLen, const byte* context, size_t contextLen) const;
};

template <class T>
unsigned int HKDF<T>::DeriveKey(byte *derived, size_t derivedLen, const byte *secret, size_t secretLen, const byte *salt, size_t saltLen, const byte* context, size_t contextLen) const
{
	static const size_t DIGEST_SIZE = static_cast<size_t>(T::DIGESTSIZE);
	CRYPTOPP_COMPILE_ASSERT(DIGEST_SIZE <= COUNTOF(s_NullVector));
	const unsigned int req = static_cast<unsigned int>(derivedLen);
	
	assert(secret && secretLen);
	assert(derived && derivedLen);
	assert(derivedLen <= MaxDerivedKeyLength());

	if(derivedLen > MaxDerivedKeyLength())
		throw InvalidArgument("HKDF: derivedLen must be less than or equal to MaxDerivedKeyLength");

	HMAC<T> hmac;
	FixedSizeSecBlock<byte, DIGEST_SIZE> prk, buffer;

	// Extract
	const byte* key = (salt ? salt : s_NullVector);
	const size_t klen = (salt ? saltLen : DIGEST_SIZE);

	hmac.SetKey(key, klen);
	hmac.CalculateDigest(prk, secret, secretLen);

	// Expand
	hmac.SetKey(prk.data(), prk.size());
	byte block = 0;

	while (derivedLen > 0)
	{
		if(block++) {hmac.Update(buffer, buffer.size());}
		if(context && contextLen) {hmac.Update(context, contextLen);}
		hmac.CalculateDigest(buffer, &block, 1);

		size_t segmentLen = STDMIN(derivedLen, DIGEST_SIZE);
		std::memcpy(derived, buffer, segmentLen);

		derived += segmentLen;
		derivedLen -= segmentLen;
	}

	return req;
}

NAMESPACE_END

#endif // CRYPTOPP_HASH_KEY_DERIVATION_FUNCTION_H