diff --git a/src/net/grunt/ClientLink.cpp b/src/net/grunt/ClientLink.cpp index 0769ae9..3b317aa 100644 --- a/src/net/grunt/ClientLink.cpp +++ b/src/net/grunt/ClientLink.cpp @@ -8,6 +8,11 @@ #include #include +#define SERVER_PUBLIC_KEY_LEN 32 +#define SALT_LEN 32 +#define VERSION_CHALLENGE_LEN 16 +#define PIN_SALT_LEN 16 + Grunt::Command Grunt::s_clientCommands[] = { { Grunt::ClientLink::CMD_AUTH_LOGON_CHALLENGE, "ClientLink::CMD_AUTH_LOGON_CHALLENGE", &Grunt::ClientLink::CmdAuthLogonChallenge, 0 }, { Grunt::ClientLink::CMD_AUTH_LOGON_PROOF, "ClientLink::CMD_AUTH_LOGON_PROOF", &Grunt::ClientLink::CmdAuthLogonProof, 0 }, @@ -47,24 +52,24 @@ void Grunt::ClientLink::Call() { } int32_t Grunt::ClientLink::CmdAuthLogonChallenge(CDataStore& msg) { - // sizeof(protocol) + sizeof(result) - if (!CanRead(msg, 2)) { + uint8_t protocol; + uint8_t result; + + if (!CanRead(msg, sizeof(protocol) + sizeof(result))) { return 0; } - uint8_t protocol; msg.Get(protocol); if (protocol != 0) { return 1; } - uint8_t result; msg.Get(result); // Auth failure (success == 0) if (result != 0) { - if (msg.Tell() > msg.Size()) { + if (!msg.IsValid()) { return 1; } @@ -79,65 +84,57 @@ int32_t Grunt::ClientLink::CmdAuthLogonChallenge(CDataStore& msg) { return 2; } - // sizeof(serverPublicKey) + sizeof(generatorLen) - if (!CanRead(msg, 33)) { + uint8_t* serverPublicKey; + uint8_t generatorLen; + + if (!CanRead(msg, SERVER_PUBLIC_KEY_LEN + sizeof(generatorLen))) { return 0; } - uint8_t* serverPublicKey; - msg.GetDataInSitu(reinterpret_cast(serverPublicKey), 32); - - uint8_t generatorLen; + msg.GetDataInSitu(reinterpret_cast(serverPublicKey), SERVER_PUBLIC_KEY_LEN); msg.Get(generatorLen); - // generatorLen + sizeof(largeSafePrimeLen) - if (!CanRead(msg, generatorLen + 1)) { + uint8_t* generator; + uint8_t largeSafePrimeLen; + + if (!CanRead(msg, generatorLen + sizeof(largeSafePrimeLen))) { return 0; } - uint8_t* generator; msg.GetDataInSitu(reinterpret_cast(generator), generatorLen); - - uint8_t largeSafePrimeLen; msg.Get(largeSafePrimeLen); - // largeSafePrimeLen + sizeof(salt) + sizeof(versionChallenge) - if (!CanRead(msg, largeSafePrimeLen + 48)) { - return 0; - } - uint8_t* largeSafePrime; - msg.GetDataInSitu(reinterpret_cast(largeSafePrime), largeSafePrimeLen); - uint8_t* salt; - msg.GetDataInSitu(reinterpret_cast(salt), 32); - uint8_t* versionChallenge; - msg.GetDataInSitu(reinterpret_cast(versionChallenge), 16); - // sizeof(logonFlags) - if (!CanRead(msg, 1)) { + if (!CanRead(msg, largeSafePrimeLen + SALT_LEN + VERSION_CHALLENGE_LEN)) { return 0; } + msg.GetDataInSitu(reinterpret_cast(largeSafePrime), largeSafePrimeLen); + msg.GetDataInSitu(reinterpret_cast(salt), SALT_LEN); + msg.GetDataInSitu(reinterpret_cast(versionChallenge), VERSION_CHALLENGE_LEN); + uint8_t logonFlags; + + if (!CanRead(msg, sizeof(logonFlags))) { + return 0; + } + msg.Get(logonFlags); + bool pinEnabled = logonFlags & 0x1; + bool matrixEnabled = logonFlags & 0x2; + bool tokenEnabled = logonFlags & 0x4; + + // PIN (0x1) + uint32_t pinGridSeed = 0; uint8_t* pinSalt = nullptr; - uint8_t matrixWidth = 0; - uint8_t matrixHeight = 0; - uint8_t matrixDigitCount = 0; - uint8_t matrixChallengeCount = 0; - uint64_t matrixSeed = 0; - - uint8_t tokenRequired = 0; - - // PIN - if (logonFlags & 0x1) { - // sizeof(pinGridSeed) + sizeof(pinSalt) - if (!CanRead(msg, 20)) { + if (pinEnabled) { + if (!CanRead(msg, sizeof(pinGridSeed) + PIN_SALT_LEN)) { return 0; } @@ -145,37 +142,43 @@ int32_t Grunt::ClientLink::CmdAuthLogonChallenge(CDataStore& msg) { msg.GetDataInSitu(reinterpret_cast(pinSalt), 16); } - // MATRIX - if (logonFlags & 0x2) { - // TODO - /* - if (CanRead(msg, 12)) { - msg.Get(matrixWidth); - msg.Get(matrixHeight); - msg.Get(matrixDigitCount); - msg.Get(matrixChallengeCount); - msg.Get(matrixSeed); + // MATRIX (0x2) - if ((logonFlags & 0x2) && matrixChallengeCount == 0) { - return 1; - } - } else { + uint8_t matrixWidth = 0; + uint8_t matrixHeight = 0; + uint8_t matrixDigitCount = 0; + uint8_t matrixChallengeCount = 0; + uint64_t matrixSeed = 0; + + if (matrixEnabled) { + if (!CanRead(msg, sizeof(matrixWidth) + sizeof(matrixHeight) + sizeof(matrixDigitCount) + sizeof(matrixChallengeCount) + sizeof(matrixSeed))) { return 0; } - */ + + msg.Get(matrixWidth); + msg.Get(matrixHeight); + msg.Get(matrixDigitCount); + msg.Get(matrixChallengeCount); + msg.Get(matrixSeed); + + if (matrixChallengeCount == 0) { + return 1; + } } - // TOKEN (authenticator) - if (logonFlags & 0x4) { - // sizeof(tokenRequired) - if (!CanRead(msg, 1)) { + // TOKEN (aka authenticator) (0x4) + + uint8_t tokenRequired = 0; + + if (tokenEnabled) { + if (!CanRead(msg, sizeof(tokenRequired))) { return 0; } msg.Get(tokenRequired); } - if (msg.Tell() > msg.Size()) { + if (!msg.IsValid()) { return 1; } @@ -188,13 +191,28 @@ int32_t Grunt::ClientLink::CmdAuthLogonChallenge(CDataStore& msg) { if (this->m_srpClient.CalculateProof(largeSafePrime, largeSafePrimeLen, generator, generatorLen, salt, 32, serverPublicKey, 32, srpRandom)) { this->SetState(2); + this->m_clientResponse->LogonResult(GRUNT_RESULT_5, nullptr, 0, 0); } else { this->SetState(4); - this->m_clientResponse->SetPinInfo(logonFlags & 0x1, pinGridSeed, pinSalt); - // TODO - // this->m_clientResponse->SetMatrixInfo(logonFlags & 0x2, matrixWidth, matrixHeight, matrixDigitCount, matrixDigitCount, 0, matrixChallengeCount, matrixSeed, this->m_srpClient.buf20, 40); - this->m_clientResponse->SetTokenInfo(logonFlags & 0x4, tokenRequired); + + this->m_clientResponse->SetPinInfo(pinEnabled, pinGridSeed, pinSalt); + + this->m_clientResponse->SetMatrixInfo( + matrixEnabled, + matrixWidth, + matrixHeight, + matrixDigitCount, + matrixDigitCount, + false, + matrixChallengeCount, + matrixSeed, + this->m_srpClient.sessionKey, + 40 + ); + + this->m_clientResponse->SetTokenInfo(tokenEnabled, tokenRequired); + this->m_clientResponse->GetVersionProof(versionChallenge); } diff --git a/src/net/grunt/ClientResponse.hpp b/src/net/grunt/ClientResponse.hpp index d5e604a..ff06736 100644 --- a/src/net/grunt/ClientResponse.hpp +++ b/src/net/grunt/ClientResponse.hpp @@ -13,8 +13,8 @@ class Grunt::ClientResponse { virtual bool OnlineIdle() = 0; virtual void GetLogonMethod() = 0; virtual void GetVersionProof(const uint8_t* versionChallenge) = 0; - virtual void SetPinInfo(bool enabled, uint32_t a3, const uint8_t* a4) = 0; - virtual void SetMatrixInfo(bool enabled, uint8_t a3, uint8_t a4, uint8_t a5, uint8_t a6, bool a7, uint8_t a8, uint64_t a9, const uint8_t* a10, uint32_t a11) = 0; + virtual void SetPinInfo(bool enabled, uint32_t gridSeed, const uint8_t* salt) = 0; + virtual void SetMatrixInfo(bool enabled, uint8_t width, uint8_t height, uint8_t a5, uint8_t a6, bool a7, uint8_t challengeCount, uint64_t seed, const uint8_t* sessionKey, uint32_t a11) = 0; virtual void SetTokenInfo(bool enabled, uint8_t required) = 0; virtual void LogonResult(Result result, const uint8_t* sessionKey, uint32_t sessionKeyLen, uint16_t flags) = 0; virtual void RealmListResult(CDataStore* msg) = 0; diff --git a/src/net/login/GruntLogin.cpp b/src/net/login/GruntLogin.cpp index cdbb24c..e5d84c8 100644 --- a/src/net/login/GruntLogin.cpp +++ b/src/net/login/GruntLogin.cpp @@ -282,14 +282,14 @@ void GruntLogin::ProveVersion(const uint8_t* versionChecksum) { ); } -void GruntLogin::SetMatrixInfo(bool enabled, uint8_t a3, uint8_t a4, uint8_t a5, uint8_t a6, bool a7, uint8_t a8, uint64_t a9, const uint8_t* a10, uint32_t a11) { +void GruntLogin::SetMatrixInfo(bool enabled, uint8_t width, uint8_t height, uint8_t a5, uint8_t a6, bool a7, uint8_t challengeCount, uint64_t seed, const uint8_t* sessionKey, uint32_t a11) { // TODO } -void GruntLogin::SetPinInfo(bool enabled, uint32_t a3, const uint8_t* a4) { +void GruntLogin::SetPinInfo(bool enabled, uint32_t gridSeed, const uint8_t* salt) { // TODO } -void GruntLogin::SetTokenInfo(bool enabled, uint8_t tokenRequired) { +void GruntLogin::SetTokenInfo(bool enabled, uint8_t required) { // TODO } diff --git a/src/net/login/GruntLogin.hpp b/src/net/login/GruntLogin.hpp index d8c33c5..5fab315 100644 --- a/src/net/login/GruntLogin.hpp +++ b/src/net/login/GruntLogin.hpp @@ -16,9 +16,9 @@ class GruntLogin : public Login { virtual bool Connected(const NETADDR& addr); virtual void GetLogonMethod(); virtual void GetVersionProof(const uint8_t* versionChallenge); - virtual void SetPinInfo(bool enabled, uint32_t a3, const uint8_t* a4); - virtual void SetMatrixInfo(bool enabled, uint8_t a3, uint8_t a4, uint8_t a5, uint8_t a6, bool a7, uint8_t a8, uint64_t a9, const uint8_t* a10, uint32_t a11); - virtual void SetTokenInfo(bool enabled, uint8_t tokenRequired); + virtual void SetPinInfo(bool enabled, uint32_t gridSeed, const uint8_t* salt); + virtual void SetMatrixInfo(bool enabled, uint8_t width, uint8_t height, uint8_t a5, uint8_t a6, bool a7, uint8_t challengeCount, uint64_t seed, const uint8_t* sessionKey, uint32_t a11); + virtual void SetTokenInfo(bool enabled, uint8_t required); virtual void LogonResult(Grunt::Result result, const uint8_t* sessionKey, uint32_t sessionKeyLen, uint16_t flags); virtual LOGIN_STATE NextSecurityState(LOGIN_STATE state); virtual int32_t GetServerId();