From 3e3f0918e9646ae5c5f61a9a50ccdcf79eb54894 Mon Sep 17 00:00:00 2001 From: server Date: Tue, 14 Apr 2026 06:04:08 +0200 Subject: [PATCH] tests: add socket auth smoke flow --- tests/smoke_auth.cpp | 185 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 175 insertions(+), 10 deletions(-) diff --git a/tests/smoke_auth.cpp b/tests/smoke_auth.cpp index ec84900..040795f 100644 --- a/tests/smoke_auth.cpp +++ b/tests/smoke_auth.cpp @@ -5,30 +5,87 @@ #include #include +#include "common/packet_headers.h" #include "game/stdafx.h" #include "game/SecureCipher.h" #include "libthecore/fdwatch.h" namespace { +#pragma pack(push, 1) +struct WirePhasePacket +{ + uint16_t header; + uint16_t length; + uint8_t phase; +}; + +struct WireKeyChallengePacket +{ + uint16_t header; + uint16_t length; + uint8_t server_pk[SecureCipher::PK_SIZE]; + uint8_t challenge[SecureCipher::CHALLENGE_SIZE]; + uint32_t server_time; +}; + +struct WireKeyResponsePacket +{ + uint16_t header; + uint16_t length; + uint8_t client_pk[SecureCipher::PK_SIZE]; + uint8_t challenge_response[SecureCipher::HMAC_SIZE]; +}; + +struct WireKeyCompletePacket +{ + uint16_t header; + uint16_t length; + uint8_t encrypted_token[SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE]; + uint8_t nonce[SecureCipher::NONCE_SIZE]; +}; +#pragma pack(pop) + void Expect(bool condition, const char* message) { if (!condition) throw std::runtime_error(message); } +void WriteExact(int fd, const void* data, size_t length, const char* message) +{ + const uint8_t* cursor = static_cast(data); + size_t remaining = length; + + while (remaining > 0) + { + const ssize_t written = write(fd, cursor, remaining); + Expect(written > 0, message); + cursor += written; + remaining -= static_cast(written); + } +} + +void ReadExact(int fd, void* data, size_t length, const char* message) +{ + uint8_t* cursor = static_cast(data); + size_t remaining = length; + + while (remaining > 0) + { + const ssize_t bytes_read = read(fd, cursor, remaining); + Expect(bytes_read > 0, message); + cursor += bytes_read; + remaining -= static_cast(bytes_read); + } +} + void TestPacketLayouts() { - constexpr size_t key_challenge_size = - sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::CHALLENGE_SIZE + sizeof(uint32_t); - constexpr size_t key_response_size = - sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::HMAC_SIZE; - constexpr size_t key_complete_size = - sizeof(uint16_t) * 2 + SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE + SecureCipher::NONCE_SIZE; - - Expect(key_challenge_size == 72, "Unexpected key challenge wire size"); - Expect(key_response_size == 68, "Unexpected key response wire size"); - Expect(key_complete_size == 76, "Unexpected key complete wire size"); + Expect(sizeof(WirePhasePacket) == 5, "Unexpected phase wire size"); + Expect(sizeof(WireKeyChallengePacket) == 72, "Unexpected key challenge wire size"); + Expect(sizeof(WireKeyResponsePacket) == 68, "Unexpected key response wire size"); + Expect(sizeof(WireKeyCompletePacket) == 76, "Unexpected key complete wire size"); } void TestSecureCipherRoundTrip() @@ -83,6 +140,113 @@ void TestSecureCipherRoundTrip() Expect(reverse == payload, "Client to server stream cipher round-trip failed"); } +void TestSocketAuthWireFlow() +{ + SecureCipher server; + SecureCipher client; + + Expect(server.Initialize(), "Server auth cipher init failed"); + Expect(client.Initialize(), "Client auth cipher init failed"); + + int sockets[2] = { -1, -1 }; + Expect(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0, "socketpair for auth flow failed"); + + WirePhasePacket phase_packet {}; + phase_packet.header = GC::PHASE; + phase_packet.length = sizeof(phase_packet); + phase_packet.phase = PHASE_HANDSHAKE; + + WireKeyChallengePacket key_challenge {}; + key_challenge.header = GC::KEY_CHALLENGE; + key_challenge.length = sizeof(key_challenge); + server.GetPublicKey(key_challenge.server_pk); + server.GenerateChallenge(key_challenge.challenge); + key_challenge.server_time = 0x12345678; + + WriteExact(sockets[0], &phase_packet, sizeof(phase_packet), "Failed to write phase packet"); + WriteExact(sockets[0], &key_challenge, sizeof(key_challenge), "Failed to write key challenge"); + + WirePhasePacket client_phase {}; + WireKeyChallengePacket client_challenge {}; + ReadExact(sockets[1], &client_phase, sizeof(client_phase), "Failed to read phase packet"); + ReadExact(sockets[1], &client_challenge, sizeof(client_challenge), "Failed to read key challenge"); + + Expect(client_phase.header == GC::PHASE, "Unexpected phase header"); + Expect(client_phase.length == sizeof(client_phase), "Unexpected phase packet length"); + Expect(client_phase.phase == PHASE_HANDSHAKE, "Unexpected phase value"); + Expect(client_challenge.header == GC::KEY_CHALLENGE, "Unexpected key challenge header"); + Expect(client_challenge.length == sizeof(client_challenge), "Unexpected key challenge length"); + Expect(std::memcmp(client_challenge.server_pk, key_challenge.server_pk, sizeof(key_challenge.server_pk)) == 0, + "Server public key changed on the wire"); + Expect(std::memcmp(client_challenge.challenge, key_challenge.challenge, sizeof(key_challenge.challenge)) == 0, + "Challenge bytes changed on the wire"); + + Expect(client.ComputeClientKeys(client_challenge.server_pk), "Client auth key derivation failed"); + + WireKeyResponsePacket key_response {}; + key_response.header = CG::KEY_RESPONSE; + key_response.length = sizeof(key_response); + client.GetPublicKey(key_response.client_pk); + client.ComputeChallengeResponse(client_challenge.challenge, key_response.challenge_response); + + WriteExact(sockets[1], &key_response, sizeof(key_response), "Failed to write key response"); + + WireKeyResponsePacket server_response {}; + ReadExact(sockets[0], &server_response, sizeof(server_response), "Failed to read key response"); + + Expect(server_response.header == CG::KEY_RESPONSE, "Unexpected key response header"); + Expect(server_response.length == sizeof(server_response), "Unexpected key response length"); + Expect(server.ComputeServerKeys(server_response.client_pk), "Server auth key derivation failed"); + Expect(server.VerifyChallengeResponse(key_challenge.challenge, server_response.challenge_response), + "Server rejected challenge response"); + + std::array session_token {}; + for (size_t i = 0; i < session_token.size(); ++i) + session_token[i] = static_cast(0x30 + i); + + server.SetSessionToken(session_token.data()); + + WireKeyCompletePacket key_complete {}; + key_complete.header = GC::KEY_COMPLETE; + key_complete.length = sizeof(key_complete); + Expect(server.EncryptToken(session_token.data(), session_token.size(), key_complete.encrypted_token, key_complete.nonce), + "Failed to encrypt key complete token"); + + WriteExact(sockets[0], &key_complete, sizeof(key_complete), "Failed to write key complete"); + + WireKeyCompletePacket client_complete {}; + ReadExact(sockets[1], &client_complete, sizeof(client_complete), "Failed to read key complete"); + + Expect(client_complete.header == GC::KEY_COMPLETE, "Unexpected key complete header"); + Expect(client_complete.length == sizeof(client_complete), "Unexpected key complete length"); + + std::array decrypted_token {}; + Expect(client.DecryptToken(client_complete.encrypted_token, sizeof(client_complete.encrypted_token), + client_complete.nonce, decrypted_token.data()), + "Failed to decrypt key complete token"); + Expect(decrypted_token == session_token, "Session token changed on the wire"); + + server.SetActivated(true); + client.SetSessionToken(decrypted_token.data()); + client.SetActivated(true); + + std::array payload {}; + for (size_t i = 0; i < payload.size(); ++i) + payload[i] = static_cast(0x41 + i); + + auto encrypted_payload = payload; + server.EncryptInPlace(encrypted_payload.data(), encrypted_payload.size()); + WriteExact(sockets[0], encrypted_payload.data(), encrypted_payload.size(), "Failed to write encrypted payload"); + + std::array received_payload {}; + ReadExact(sockets[1], received_payload.data(), received_payload.size(), "Failed to read encrypted payload"); + client.DecryptInPlace(received_payload.data(), received_payload.size()); + Expect(received_payload == payload, "Encrypted payload round-trip mismatch"); + + close(sockets[0]); + close(sockets[1]); +} + void TestFdwatchReadAndOneshotWrite() { int sockets[2] = { -1, -1 }; @@ -133,6 +297,7 @@ int main() { TestPacketLayouts(); TestSecureCipherRoundTrip(); + TestSocketAuthWireFlow(); TestFdwatchReadAndOneshotWrite(); std::cout << "metin smoke tests passed\n"; return 0;