tests: add socket auth smoke flow

This commit is contained in:
server
2026-04-14 06:04:08 +02:00
parent 25ec562ab0
commit 3e3f0918e9

View File

@@ -5,30 +5,87 @@
#include <iostream>
#include <stdexcept>
#include "common/packet_headers.h"
#include "game/stdafx.h"
#include "game/SecureCipher.h"
#include "libthecore/fdwatch.h"
namespace
{
#pragma pack(push, 1)
struct WirePhasePacket
{
uint16_t header;
uint16_t length;
uint8_t phase;
};
struct WireKeyChallengePacket
{
uint16_t header;
uint16_t length;
uint8_t server_pk[SecureCipher::PK_SIZE];
uint8_t challenge[SecureCipher::CHALLENGE_SIZE];
uint32_t server_time;
};
struct WireKeyResponsePacket
{
uint16_t header;
uint16_t length;
uint8_t client_pk[SecureCipher::PK_SIZE];
uint8_t challenge_response[SecureCipher::HMAC_SIZE];
};
struct WireKeyCompletePacket
{
uint16_t header;
uint16_t length;
uint8_t encrypted_token[SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE];
uint8_t nonce[SecureCipher::NONCE_SIZE];
};
#pragma pack(pop)
void Expect(bool condition, const char* message)
{
if (!condition)
throw std::runtime_error(message);
}
void WriteExact(int fd, const void* data, size_t length, const char* message)
{
const uint8_t* cursor = static_cast<const uint8_t*>(data);
size_t remaining = length;
while (remaining > 0)
{
const ssize_t written = write(fd, cursor, remaining);
Expect(written > 0, message);
cursor += written;
remaining -= static_cast<size_t>(written);
}
}
void ReadExact(int fd, void* data, size_t length, const char* message)
{
uint8_t* cursor = static_cast<uint8_t*>(data);
size_t remaining = length;
while (remaining > 0)
{
const ssize_t bytes_read = read(fd, cursor, remaining);
Expect(bytes_read > 0, message);
cursor += bytes_read;
remaining -= static_cast<size_t>(bytes_read);
}
}
void TestPacketLayouts()
{
constexpr size_t key_challenge_size =
sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::CHALLENGE_SIZE + sizeof(uint32_t);
constexpr size_t key_response_size =
sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::HMAC_SIZE;
constexpr size_t key_complete_size =
sizeof(uint16_t) * 2 + SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE + SecureCipher::NONCE_SIZE;
Expect(key_challenge_size == 72, "Unexpected key challenge wire size");
Expect(key_response_size == 68, "Unexpected key response wire size");
Expect(key_complete_size == 76, "Unexpected key complete wire size");
Expect(sizeof(WirePhasePacket) == 5, "Unexpected phase wire size");
Expect(sizeof(WireKeyChallengePacket) == 72, "Unexpected key challenge wire size");
Expect(sizeof(WireKeyResponsePacket) == 68, "Unexpected key response wire size");
Expect(sizeof(WireKeyCompletePacket) == 76, "Unexpected key complete wire size");
}
void TestSecureCipherRoundTrip()
@@ -83,6 +140,113 @@ void TestSecureCipherRoundTrip()
Expect(reverse == payload, "Client to server stream cipher round-trip failed");
}
void TestSocketAuthWireFlow()
{
SecureCipher server;
SecureCipher client;
Expect(server.Initialize(), "Server auth cipher init failed");
Expect(client.Initialize(), "Client auth cipher init failed");
int sockets[2] = { -1, -1 };
Expect(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0, "socketpair for auth flow failed");
WirePhasePacket phase_packet {};
phase_packet.header = GC::PHASE;
phase_packet.length = sizeof(phase_packet);
phase_packet.phase = PHASE_HANDSHAKE;
WireKeyChallengePacket key_challenge {};
key_challenge.header = GC::KEY_CHALLENGE;
key_challenge.length = sizeof(key_challenge);
server.GetPublicKey(key_challenge.server_pk);
server.GenerateChallenge(key_challenge.challenge);
key_challenge.server_time = 0x12345678;
WriteExact(sockets[0], &phase_packet, sizeof(phase_packet), "Failed to write phase packet");
WriteExact(sockets[0], &key_challenge, sizeof(key_challenge), "Failed to write key challenge");
WirePhasePacket client_phase {};
WireKeyChallengePacket client_challenge {};
ReadExact(sockets[1], &client_phase, sizeof(client_phase), "Failed to read phase packet");
ReadExact(sockets[1], &client_challenge, sizeof(client_challenge), "Failed to read key challenge");
Expect(client_phase.header == GC::PHASE, "Unexpected phase header");
Expect(client_phase.length == sizeof(client_phase), "Unexpected phase packet length");
Expect(client_phase.phase == PHASE_HANDSHAKE, "Unexpected phase value");
Expect(client_challenge.header == GC::KEY_CHALLENGE, "Unexpected key challenge header");
Expect(client_challenge.length == sizeof(client_challenge), "Unexpected key challenge length");
Expect(std::memcmp(client_challenge.server_pk, key_challenge.server_pk, sizeof(key_challenge.server_pk)) == 0,
"Server public key changed on the wire");
Expect(std::memcmp(client_challenge.challenge, key_challenge.challenge, sizeof(key_challenge.challenge)) == 0,
"Challenge bytes changed on the wire");
Expect(client.ComputeClientKeys(client_challenge.server_pk), "Client auth key derivation failed");
WireKeyResponsePacket key_response {};
key_response.header = CG::KEY_RESPONSE;
key_response.length = sizeof(key_response);
client.GetPublicKey(key_response.client_pk);
client.ComputeChallengeResponse(client_challenge.challenge, key_response.challenge_response);
WriteExact(sockets[1], &key_response, sizeof(key_response), "Failed to write key response");
WireKeyResponsePacket server_response {};
ReadExact(sockets[0], &server_response, sizeof(server_response), "Failed to read key response");
Expect(server_response.header == CG::KEY_RESPONSE, "Unexpected key response header");
Expect(server_response.length == sizeof(server_response), "Unexpected key response length");
Expect(server.ComputeServerKeys(server_response.client_pk), "Server auth key derivation failed");
Expect(server.VerifyChallengeResponse(key_challenge.challenge, server_response.challenge_response),
"Server rejected challenge response");
std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> session_token {};
for (size_t i = 0; i < session_token.size(); ++i)
session_token[i] = static_cast<uint8_t>(0x30 + i);
server.SetSessionToken(session_token.data());
WireKeyCompletePacket key_complete {};
key_complete.header = GC::KEY_COMPLETE;
key_complete.length = sizeof(key_complete);
Expect(server.EncryptToken(session_token.data(), session_token.size(), key_complete.encrypted_token, key_complete.nonce),
"Failed to encrypt key complete token");
WriteExact(sockets[0], &key_complete, sizeof(key_complete), "Failed to write key complete");
WireKeyCompletePacket client_complete {};
ReadExact(sockets[1], &client_complete, sizeof(client_complete), "Failed to read key complete");
Expect(client_complete.header == GC::KEY_COMPLETE, "Unexpected key complete header");
Expect(client_complete.length == sizeof(client_complete), "Unexpected key complete length");
std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> decrypted_token {};
Expect(client.DecryptToken(client_complete.encrypted_token, sizeof(client_complete.encrypted_token),
client_complete.nonce, decrypted_token.data()),
"Failed to decrypt key complete token");
Expect(decrypted_token == session_token, "Session token changed on the wire");
server.SetActivated(true);
client.SetSessionToken(decrypted_token.data());
client.SetActivated(true);
std::array<uint8_t, 32> payload {};
for (size_t i = 0; i < payload.size(); ++i)
payload[i] = static_cast<uint8_t>(0x41 + i);
auto encrypted_payload = payload;
server.EncryptInPlace(encrypted_payload.data(), encrypted_payload.size());
WriteExact(sockets[0], encrypted_payload.data(), encrypted_payload.size(), "Failed to write encrypted payload");
std::array<uint8_t, 32> received_payload {};
ReadExact(sockets[1], received_payload.data(), received_payload.size(), "Failed to read encrypted payload");
client.DecryptInPlace(received_payload.data(), received_payload.size());
Expect(received_payload == payload, "Encrypted payload round-trip mismatch");
close(sockets[0]);
close(sockets[1]);
}
void TestFdwatchReadAndOneshotWrite()
{
int sockets[2] = { -1, -1 };
@@ -133,6 +297,7 @@ int main()
{
TestPacketLayouts();
TestSecureCipherRoundTrip();
TestSocketAuthWireFlow();
TestFdwatchReadAndOneshotWrite();
std::cout << "metin smoke tests passed\n";
return 0;