#include #include #include #include #include #include #include "common/packet_headers.h" #include "game/stdafx.h" #include "game/SecureCipher.h" #include "libthecore/fdwatch.h" namespace { #pragma pack(push, 1) struct WirePhasePacket { uint16_t header; uint16_t length; uint8_t phase; }; struct WireKeyChallengePacket { uint16_t header; uint16_t length; uint8_t server_pk[SecureCipher::PK_SIZE]; uint8_t challenge[SecureCipher::CHALLENGE_SIZE]; uint32_t server_time; }; struct WireKeyResponsePacket { uint16_t header; uint16_t length; uint8_t client_pk[SecureCipher::PK_SIZE]; uint8_t challenge_response[SecureCipher::HMAC_SIZE]; }; struct WireKeyCompletePacket { uint16_t header; uint16_t length; uint8_t encrypted_token[SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE]; uint8_t nonce[SecureCipher::NONCE_SIZE]; }; #pragma pack(pop) void Expect(bool condition, const char* message) { if (!condition) throw std::runtime_error(message); } void WriteExact(int fd, const void* data, size_t length, const char* message) { const uint8_t* cursor = static_cast(data); size_t remaining = length; while (remaining > 0) { const ssize_t written = write(fd, cursor, remaining); Expect(written > 0, message); cursor += written; remaining -= static_cast(written); } } void ReadExact(int fd, void* data, size_t length, const char* message) { uint8_t* cursor = static_cast(data); size_t remaining = length; while (remaining > 0) { const ssize_t bytes_read = read(fd, cursor, remaining); Expect(bytes_read > 0, message); cursor += bytes_read; remaining -= static_cast(bytes_read); } } void TestPacketLayouts() { Expect(sizeof(WirePhasePacket) == 5, "Unexpected phase wire size"); Expect(sizeof(WireKeyChallengePacket) == 72, "Unexpected key challenge wire size"); Expect(sizeof(WireKeyResponsePacket) == 68, "Unexpected key response wire size"); Expect(sizeof(WireKeyCompletePacket) == 76, "Unexpected key complete wire size"); } void TestSecureCipherRoundTrip() { SecureCipher server; SecureCipher client; Expect(server.Initialize(), "Server SecureCipher init failed"); Expect(client.Initialize(), "Client SecureCipher init failed"); std::array server_pk {}; std::array client_pk {}; server.GetPublicKey(server_pk.data()); client.GetPublicKey(client_pk.data()); Expect(client.ComputeClientKeys(server_pk.data()), "Client session key derivation failed"); Expect(server.ComputeServerKeys(client_pk.data()), "Server session key derivation failed"); std::array challenge {}; std::array response {}; server.GenerateChallenge(challenge.data()); client.ComputeChallengeResponse(challenge.data(), response.data()); Expect(server.VerifyChallengeResponse(challenge.data(), response.data()), "Challenge verification failed"); std::array token {}; for (size_t i = 0; i < token.size(); ++i) token[i] = static_cast(i); std::array ciphertext {}; std::array nonce {}; std::array plaintext {}; Expect(server.EncryptToken(token.data(), token.size(), ciphertext.data(), nonce.data()), "Token encryption failed"); Expect(client.DecryptToken(ciphertext.data(), ciphertext.size(), nonce.data(), plaintext.data()), "Token decryption failed"); Expect(std::memcmp(token.data(), plaintext.data(), token.size()) == 0, "Token round-trip mismatch"); server.SetActivated(true); client.SetActivated(true); std::array payload {}; for (size_t i = 0; i < payload.size(); ++i) payload[i] = static_cast(0xA0 + (i % 31)); auto encrypted = payload; server.EncryptInPlace(encrypted.data(), encrypted.size()); client.DecryptInPlace(encrypted.data(), encrypted.size()); Expect(encrypted == payload, "Server to client stream cipher round-trip failed"); auto reverse = payload; client.EncryptInPlace(reverse.data(), reverse.size()); server.DecryptInPlace(reverse.data(), reverse.size()); Expect(reverse == payload, "Client to server stream cipher round-trip failed"); } void TestSocketAuthWireFlow() { SecureCipher server; SecureCipher client; Expect(server.Initialize(), "Server auth cipher init failed"); Expect(client.Initialize(), "Client auth cipher init failed"); int sockets[2] = { -1, -1 }; Expect(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0, "socketpair for auth flow failed"); WirePhasePacket phase_packet {}; phase_packet.header = GC::PHASE; phase_packet.length = sizeof(phase_packet); phase_packet.phase = PHASE_HANDSHAKE; WireKeyChallengePacket key_challenge {}; key_challenge.header = GC::KEY_CHALLENGE; key_challenge.length = sizeof(key_challenge); server.GetPublicKey(key_challenge.server_pk); server.GenerateChallenge(key_challenge.challenge); key_challenge.server_time = 0x12345678; WriteExact(sockets[0], &phase_packet, sizeof(phase_packet), "Failed to write phase packet"); WriteExact(sockets[0], &key_challenge, sizeof(key_challenge), "Failed to write key challenge"); WirePhasePacket client_phase {}; WireKeyChallengePacket client_challenge {}; ReadExact(sockets[1], &client_phase, sizeof(client_phase), "Failed to read phase packet"); ReadExact(sockets[1], &client_challenge, sizeof(client_challenge), "Failed to read key challenge"); Expect(client_phase.header == GC::PHASE, "Unexpected phase header"); Expect(client_phase.length == sizeof(client_phase), "Unexpected phase packet length"); Expect(client_phase.phase == PHASE_HANDSHAKE, "Unexpected phase value"); Expect(client_challenge.header == GC::KEY_CHALLENGE, "Unexpected key challenge header"); Expect(client_challenge.length == sizeof(client_challenge), "Unexpected key challenge length"); Expect(std::memcmp(client_challenge.server_pk, key_challenge.server_pk, sizeof(key_challenge.server_pk)) == 0, "Server public key changed on the wire"); Expect(std::memcmp(client_challenge.challenge, key_challenge.challenge, sizeof(key_challenge.challenge)) == 0, "Challenge bytes changed on the wire"); Expect(client.ComputeClientKeys(client_challenge.server_pk), "Client auth key derivation failed"); WireKeyResponsePacket key_response {}; key_response.header = CG::KEY_RESPONSE; key_response.length = sizeof(key_response); client.GetPublicKey(key_response.client_pk); client.ComputeChallengeResponse(client_challenge.challenge, key_response.challenge_response); WriteExact(sockets[1], &key_response, sizeof(key_response), "Failed to write key response"); WireKeyResponsePacket server_response {}; ReadExact(sockets[0], &server_response, sizeof(server_response), "Failed to read key response"); Expect(server_response.header == CG::KEY_RESPONSE, "Unexpected key response header"); Expect(server_response.length == sizeof(server_response), "Unexpected key response length"); Expect(server.ComputeServerKeys(server_response.client_pk), "Server auth key derivation failed"); Expect(server.VerifyChallengeResponse(key_challenge.challenge, server_response.challenge_response), "Server rejected challenge response"); std::array session_token {}; for (size_t i = 0; i < session_token.size(); ++i) session_token[i] = static_cast(0x30 + i); server.SetSessionToken(session_token.data()); WireKeyCompletePacket key_complete {}; key_complete.header = GC::KEY_COMPLETE; key_complete.length = sizeof(key_complete); Expect(server.EncryptToken(session_token.data(), session_token.size(), key_complete.encrypted_token, key_complete.nonce), "Failed to encrypt key complete token"); WriteExact(sockets[0], &key_complete, sizeof(key_complete), "Failed to write key complete"); WireKeyCompletePacket client_complete {}; ReadExact(sockets[1], &client_complete, sizeof(client_complete), "Failed to read key complete"); Expect(client_complete.header == GC::KEY_COMPLETE, "Unexpected key complete header"); Expect(client_complete.length == sizeof(client_complete), "Unexpected key complete length"); std::array decrypted_token {}; Expect(client.DecryptToken(client_complete.encrypted_token, sizeof(client_complete.encrypted_token), client_complete.nonce, decrypted_token.data()), "Failed to decrypt key complete token"); Expect(decrypted_token == session_token, "Session token changed on the wire"); server.SetActivated(true); client.SetSessionToken(decrypted_token.data()); client.SetActivated(true); std::array payload {}; for (size_t i = 0; i < payload.size(); ++i) payload[i] = static_cast(0x41 + i); auto encrypted_payload = payload; server.EncryptInPlace(encrypted_payload.data(), encrypted_payload.size()); WriteExact(sockets[0], encrypted_payload.data(), encrypted_payload.size(), "Failed to write encrypted payload"); std::array received_payload {}; ReadExact(sockets[1], received_payload.data(), received_payload.size(), "Failed to read encrypted payload"); client.DecryptInPlace(received_payload.data(), received_payload.size()); Expect(received_payload == payload, "Encrypted payload round-trip mismatch"); close(sockets[0]); close(sockets[1]); } void TestFdwatchReadAndOneshotWrite() { int sockets[2] = { -1, -1 }; Expect(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0, "socketpair failed"); LPFDWATCH fdw = fdwatch_new(64); Expect(fdw != nullptr, "fdwatch_new failed"); int marker = 42; fdwatch_add_fd(fdw, sockets[1], &marker, FDW_READ, false); const uint8_t byte = 0x7F; Expect(write(sockets[0], &byte, sizeof(byte)) == sizeof(byte), "socketpair write failed"); timeval timeout {}; timeout.tv_sec = 0; timeout.tv_usec = 200000; int num_events = fdwatch(fdw, &timeout); Expect(num_events == 1, "Expected one read event"); Expect(fdwatch_get_client_data(fdw, 0) == &marker, "Unexpected client data"); Expect(fdwatch_check_event(fdw, sockets[1], 0) == FDW_READ, "Expected FDW_READ event"); uint8_t read_back = 0; Expect(read(sockets[1], &read_back, sizeof(read_back)) == sizeof(read_back), "socketpair read failed"); Expect(read_back == byte, "Read payload mismatch"); fdwatch_add_fd(fdw, sockets[1], &marker, FDW_WRITE, true); num_events = fdwatch(fdw, &timeout); Expect(num_events >= 1, "Expected at least one write event"); Expect(fdwatch_check_event(fdw, sockets[1], 0) == FDW_WRITE, "Expected FDW_WRITE event"); timeout.tv_sec = 0; timeout.tv_usec = 0; num_events = fdwatch(fdw, &timeout); Expect(num_events == 0, "FDW_WRITE oneshot was not cleared"); fdwatch_del_fd(fdw, sockets[1]); fdwatch_delete(fdw); close(sockets[0]); close(sockets[1]); } } int main() { try { TestPacketLayouts(); TestSecureCipherRoundTrip(); TestSocketAuthWireFlow(); TestFdwatchReadAndOneshotWrite(); std::cout << "metin smoke tests passed\n"; return 0; } catch (const std::exception& e) { std::cerr << "metin smoke tests failed: " << e.what() << '\n'; return 1; } }