Files
m2dev-server-src/tests/smoke_auth.cpp
2026-04-14 06:04:08 +02:00

311 lines
11 KiB
C++

#include <array>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <exception>
#include <iostream>
#include <stdexcept>
#include "common/packet_headers.h"
#include "game/stdafx.h"
#include "game/SecureCipher.h"
#include "libthecore/fdwatch.h"
namespace
{
#pragma pack(push, 1)
// Byte-exact images of the handshake packets as they are written to / read from a socket.
// 1-byte packing stops the compiler from inserting padding, so sizeof() of each struct
// equals the number of bytes it occupies on the wire.

// Phase notification: header code, full packet length, then a one-byte phase value.
struct WirePhasePacket
{
	uint16_t header; // packet code, e.g. GC::PHASE
	uint16_t length; // full packet size in bytes (the tests fill in sizeof(packet))
	uint8_t phase;
};

// Server -> client key challenge: server public key, random challenge, server time.
struct WireKeyChallengePacket
{
	uint16_t header; // packet code, e.g. GC::KEY_CHALLENGE
	uint16_t length; // full packet size in bytes
	uint8_t server_pk[SecureCipher::PK_SIZE];
	uint8_t challenge[SecureCipher::CHALLENGE_SIZE];
	uint32_t server_time;
};

// Client -> server key response: client public key and its answer to the challenge.
struct WireKeyResponsePacket
{
	uint16_t header; // packet code, e.g. CG::KEY_RESPONSE
	uint16_t length; // full packet size in bytes
	uint8_t client_pk[SecureCipher::PK_SIZE];
	uint8_t challenge_response[SecureCipher::HMAC_SIZE];
};

// Server -> client key completion: encrypted session token (token + auth tag) and its nonce.
struct WireKeyCompletePacket
{
	uint16_t header; // packet code, e.g. GC::KEY_COMPLETE
	uint16_t length; // full packet size in bytes
	uint8_t encrypted_token[SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE];
	uint8_t nonce[SecureCipher::NONCE_SIZE];
};
#pragma pack(pop)

// Compile-time proof that packing took effect: every struct must be exactly the sum of
// its fields. This holds regardless of the concrete SecureCipher size constants.
static_assert(sizeof(WirePhasePacket) == sizeof(uint16_t) * 2 + sizeof(uint8_t),
	"WirePhasePacket must not contain padding");
static_assert(sizeof(WireKeyChallengePacket) ==
		sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::CHALLENGE_SIZE + sizeof(uint32_t),
	"WireKeyChallengePacket must not contain padding");
static_assert(sizeof(WireKeyResponsePacket) ==
		sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::HMAC_SIZE,
	"WireKeyResponsePacket must not contain padding");
static_assert(sizeof(WireKeyCompletePacket) ==
		sizeof(uint16_t) * 2 + SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE + SecureCipher::NONCE_SIZE,
	"WireKeyCompletePacket must not contain padding");
// Test assertion helper: does nothing when `condition` holds, otherwise throws a
// std::runtime_error whose what() is `message`.
void Expect(bool condition, const char* message)
{
	if (condition)
		return;

	throw std::runtime_error(message);
}
// Writes exactly `length` bytes from `data` to `fd`, looping over short writes.
// A write() interrupted by a signal before transferring anything (-1 with errno EINTR)
// is retried. Any other error, or a call that makes no progress, throws
// std::runtime_error(message) through Expect().
void WriteExact(int fd, const void* data, size_t length, const char* message)
{
	const uint8_t* cursor = static_cast<const uint8_t*>(data);
	size_t remaining = length;
	while (remaining > 0)
	{
		const ssize_t written = write(fd, cursor, remaining);
		if (written < 0 && errno == EINTR)
			continue; // transient: nothing was written, just try again

		Expect(written > 0, message);
		cursor += written;
		remaining -= static_cast<size_t>(written);
	}
}
// Reads exactly `length` bytes from `fd` into `data`, looping over short reads.
// A read() interrupted by a signal (-1 with errno EINTR) is retried. A zero return
// (the peer closed before sending everything) or any other error throws
// std::runtime_error(message) through Expect().
void ReadExact(int fd, void* data, size_t length, const char* message)
{
	uint8_t* cursor = static_cast<uint8_t*>(data);
	size_t remaining = length;
	while (remaining > 0)
	{
		const ssize_t bytes_read = read(fd, cursor, remaining);
		if (bytes_read < 0 && errno == EINTR)
			continue; // transient: nothing was read, just try again

		Expect(bytes_read > 0, message);
		cursor += bytes_read;
		remaining -= static_cast<size_t>(bytes_read);
	}
}
// Pins the exact on-the-wire byte size of every packed handshake struct. The struct
// sizes are derived from SecureCipher constants, so changing one of those constants
// makes the matching check below fail.
void TestPacketLayouts()
{
	constexpr size_t kPhaseWireSize = 5;
	constexpr size_t kKeyChallengeWireSize = 72;
	constexpr size_t kKeyResponseWireSize = 68;
	constexpr size_t kKeyCompleteWireSize = 76;

	Expect(sizeof(WirePhasePacket) == kPhaseWireSize, "Unexpected phase wire size");
	Expect(sizeof(WireKeyChallengePacket) == kKeyChallengeWireSize, "Unexpected key challenge wire size");
	Expect(sizeof(WireKeyResponsePacket) == kKeyResponseWireSize, "Unexpected key response wire size");
	Expect(sizeof(WireKeyCompletePacket) == kKeyCompleteWireSize, "Unexpected key complete wire size");
}
// In-memory exercise of every SecureCipher primitive the handshake relies on:
// key agreement, challenge/response, session-token encrypt/decrypt, and the in-place
// stream cipher in both directions. Besides checking that data survives each round
// trip, it checks that encryption actually changed the bytes. A round-trip-only check
// would also pass for an accidental identity "cipher".
void TestSecureCipherRoundTrip()
{
	SecureCipher server;
	SecureCipher client;
	Expect(server.Initialize(), "Server SecureCipher init failed");
	Expect(client.Initialize(), "Client SecureCipher init failed");

	// Key agreement: each side derives its session keys from the peer's public key.
	std::array<uint8_t, SecureCipher::PK_SIZE> server_pk {};
	std::array<uint8_t, SecureCipher::PK_SIZE> client_pk {};
	server.GetPublicKey(server_pk.data());
	client.GetPublicKey(client_pk.data());
	Expect(client.ComputeClientKeys(server_pk.data()), "Client session key derivation failed");
	Expect(server.ComputeServerKeys(client_pk.data()), "Server session key derivation failed");

	// Challenge/response: the server must accept the client's answer to its own challenge.
	std::array<uint8_t, SecureCipher::CHALLENGE_SIZE> challenge {};
	std::array<uint8_t, SecureCipher::HMAC_SIZE> response {};
	server.GenerateChallenge(challenge.data());
	client.ComputeChallengeResponse(challenge.data(), response.data());
	Expect(server.VerifyChallengeResponse(challenge.data(), response.data()), "Challenge verification failed");

	// Session token: the server encrypts into a SESSION_TOKEN_SIZE + TAG_SIZE buffer plus
	// a nonce, and the client must recover the original bytes.
	std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> token {};
	for (size_t i = 0; i < token.size(); ++i)
		token[i] = static_cast<uint8_t>(i);
	std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE> ciphertext {};
	std::array<uint8_t, SecureCipher::NONCE_SIZE> nonce {};
	std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> plaintext {};
	Expect(server.EncryptToken(token.data(), token.size(), ciphertext.data(), nonce.data()), "Token encryption failed");
	Expect(std::memcmp(ciphertext.data(), token.data(), token.size()) != 0, "Token ciphertext equals plaintext");
	Expect(client.DecryptToken(ciphertext.data(), ciphertext.size(), nonce.data(), plaintext.data()), "Token decryption failed");
	Expect(std::memcmp(token.data(), plaintext.data(), token.size()) == 0, "Token round-trip mismatch");

	// Stream cipher, once activated: each direction must scramble the payload and the
	// peer must restore it exactly.
	server.SetActivated(true);
	client.SetActivated(true);
	std::array<uint8_t, 96> payload {};
	for (size_t i = 0; i < payload.size(); ++i)
		payload[i] = static_cast<uint8_t>(0xA0 + (i % 31));

	auto encrypted = payload;
	server.EncryptInPlace(encrypted.data(), encrypted.size());
	Expect(encrypted != payload, "Server to client stream cipher left payload unchanged");
	client.DecryptInPlace(encrypted.data(), encrypted.size());
	Expect(encrypted == payload, "Server to client stream cipher round-trip failed");

	auto reverse = payload;
	client.EncryptInPlace(reverse.data(), reverse.size());
	Expect(reverse != payload, "Client to server stream cipher left payload unchanged");
	server.DecryptInPlace(reverse.data(), reverse.size());
	Expect(reverse == payload, "Client to server stream cipher round-trip failed");
}
// Runs the whole key-exchange handshake over a real AF_UNIX socket pair using the packed
// Wire*Packet structs, then sends one encrypted payload from server to client.
// sockets[0] acts as the server end and sockets[1] as the client end.
void TestSocketAuthWireFlow()
{
	SecureCipher server;
	SecureCipher client;
	Expect(server.Initialize(), "Server auth cipher init failed");
	Expect(client.Initialize(), "Client auth cipher init failed");

	// Owns both socket ends so they are closed on every exit path, including when an
	// Expect() below throws.
	struct SocketPair
	{
		int fds[2] = { -1, -1 };
		~SocketPair()
		{
			for (const int fd : fds)
			{
				if (fd >= 0)
					close(fd);
			}
		}
	} socket_pair;
	int* const sockets = socket_pair.fds;
	Expect(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0, "socketpair for auth flow failed");

	// Server -> client: phase switch followed by the key challenge.
	WirePhasePacket phase_packet {};
	phase_packet.header = GC::PHASE;
	phase_packet.length = static_cast<uint16_t>(sizeof(phase_packet));
	phase_packet.phase = PHASE_HANDSHAKE;
	WireKeyChallengePacket key_challenge {};
	key_challenge.header = GC::KEY_CHALLENGE;
	key_challenge.length = static_cast<uint16_t>(sizeof(key_challenge));
	server.GetPublicKey(key_challenge.server_pk);
	server.GenerateChallenge(key_challenge.challenge);
	key_challenge.server_time = 0x12345678;
	WriteExact(sockets[0], &phase_packet, sizeof(phase_packet), "Failed to write phase packet");
	WriteExact(sockets[0], &key_challenge, sizeof(key_challenge), "Failed to write key challenge");

	WirePhasePacket client_phase {};
	WireKeyChallengePacket client_challenge {};
	ReadExact(sockets[1], &client_phase, sizeof(client_phase), "Failed to read phase packet");
	ReadExact(sockets[1], &client_challenge, sizeof(client_challenge), "Failed to read key challenge");
	Expect(client_phase.header == GC::PHASE, "Unexpected phase header");
	Expect(client_phase.length == sizeof(client_phase), "Unexpected phase packet length");
	Expect(client_phase.phase == PHASE_HANDSHAKE, "Unexpected phase value");
	Expect(client_challenge.header == GC::KEY_CHALLENGE, "Unexpected key challenge header");
	Expect(client_challenge.length == sizeof(client_challenge), "Unexpected key challenge length");
	Expect(std::memcmp(client_challenge.server_pk, key_challenge.server_pk, sizeof(key_challenge.server_pk)) == 0,
		"Server public key changed on the wire");
	Expect(std::memcmp(client_challenge.challenge, key_challenge.challenge, sizeof(key_challenge.challenge)) == 0,
		"Challenge bytes changed on the wire");

	// Client -> server: derive keys from the received server key and answer the challenge.
	Expect(client.ComputeClientKeys(client_challenge.server_pk), "Client auth key derivation failed");
	WireKeyResponsePacket key_response {};
	key_response.header = CG::KEY_RESPONSE;
	key_response.length = static_cast<uint16_t>(sizeof(key_response));
	client.GetPublicKey(key_response.client_pk);
	client.ComputeChallengeResponse(client_challenge.challenge, key_response.challenge_response);
	WriteExact(sockets[1], &key_response, sizeof(key_response), "Failed to write key response");

	WireKeyResponsePacket server_response {};
	ReadExact(sockets[0], &server_response, sizeof(server_response), "Failed to read key response");
	Expect(server_response.header == CG::KEY_RESPONSE, "Unexpected key response header");
	Expect(server_response.length == sizeof(server_response), "Unexpected key response length");
	Expect(server.ComputeServerKeys(server_response.client_pk), "Server auth key derivation failed");
	Expect(server.VerifyChallengeResponse(key_challenge.challenge, server_response.challenge_response),
		"Server rejected challenge response");

	// Server -> client: encrypted session token completes the handshake.
	std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> session_token {};
	for (size_t i = 0; i < session_token.size(); ++i)
		session_token[i] = static_cast<uint8_t>(0x30 + i);
	server.SetSessionToken(session_token.data());
	WireKeyCompletePacket key_complete {};
	key_complete.header = GC::KEY_COMPLETE;
	key_complete.length = static_cast<uint16_t>(sizeof(key_complete));
	Expect(server.EncryptToken(session_token.data(), session_token.size(), key_complete.encrypted_token, key_complete.nonce),
		"Failed to encrypt key complete token");
	WriteExact(sockets[0], &key_complete, sizeof(key_complete), "Failed to write key complete");

	WireKeyCompletePacket client_complete {};
	ReadExact(sockets[1], &client_complete, sizeof(client_complete), "Failed to read key complete");
	Expect(client_complete.header == GC::KEY_COMPLETE, "Unexpected key complete header");
	Expect(client_complete.length == sizeof(client_complete), "Unexpected key complete length");
	std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> decrypted_token {};
	Expect(client.DecryptToken(client_complete.encrypted_token, sizeof(client_complete.encrypted_token),
			client_complete.nonce, decrypted_token.data()),
		"Failed to decrypt key complete token");
	Expect(decrypted_token == session_token, "Session token changed on the wire");

	// Both sides activated: application data must travel encrypted and decrypt cleanly.
	server.SetActivated(true);
	client.SetSessionToken(decrypted_token.data());
	client.SetActivated(true);
	std::array<uint8_t, 32> payload {};
	for (size_t i = 0; i < payload.size(); ++i)
		payload[i] = static_cast<uint8_t>(0x41 + i);
	auto encrypted_payload = payload;
	server.EncryptInPlace(encrypted_payload.data(), encrypted_payload.size());
	WriteExact(sockets[0], encrypted_payload.data(), encrypted_payload.size(), "Failed to write encrypted payload");

	std::array<uint8_t, 32> received_payload {};
	ReadExact(sockets[1], received_payload.data(), received_payload.size(), "Failed to read encrypted payload");
	// If the received bytes were the plaintext, the channel would not actually be encrypted.
	Expect(received_payload != payload, "Payload crossed the wire unencrypted");
	client.DecryptInPlace(received_payload.data(), received_payload.size());
	Expect(received_payload == payload, "Encrypted payload round-trip mismatch");
	// socket_pair's destructor closes both ends.
}
void TestFdwatchReadAndOneshotWrite()
{
int sockets[2] = { -1, -1 };
Expect(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0, "socketpair failed");
LPFDWATCH fdw = fdwatch_new(64);
Expect(fdw != nullptr, "fdwatch_new failed");
int marker = 42;
fdwatch_add_fd(fdw, sockets[1], &marker, FDW_READ, false);
const uint8_t byte = 0x7F;
Expect(write(sockets[0], &byte, sizeof(byte)) == sizeof(byte), "socketpair write failed");
timeval timeout {};
timeout.tv_sec = 0;
timeout.tv_usec = 200000;
int num_events = fdwatch(fdw, &timeout);
Expect(num_events == 1, "Expected one read event");
Expect(fdwatch_get_client_data(fdw, 0) == &marker, "Unexpected client data");
Expect(fdwatch_check_event(fdw, sockets[1], 0) == FDW_READ, "Expected FDW_READ event");
uint8_t read_back = 0;
Expect(read(sockets[1], &read_back, sizeof(read_back)) == sizeof(read_back), "socketpair read failed");
Expect(read_back == byte, "Read payload mismatch");
fdwatch_add_fd(fdw, sockets[1], &marker, FDW_WRITE, true);
num_events = fdwatch(fdw, &timeout);
Expect(num_events >= 1, "Expected at least one write event");
Expect(fdwatch_check_event(fdw, sockets[1], 0) == FDW_WRITE, "Expected FDW_WRITE event");
timeout.tv_sec = 0;
timeout.tv_usec = 0;
num_events = fdwatch(fdw, &timeout);
Expect(num_events == 0, "FDW_WRITE oneshot was not cleared");
fdwatch_del_fd(fdw, sockets[1]);
fdwatch_delete(fdw);
close(sockets[0]);
close(sockets[1]);
}
}
// Runs every smoke test in a fixed order. The first test that throws a std::exception
// stops the run.
// Returns 0 when all tests pass; prints the failure reason and returns 1 otherwise.
int main()
{
	using SmokeTest = void (*)();
	const SmokeTest tests[] = {
		TestPacketLayouts,
		TestSecureCipherRoundTrip,
		TestSocketAuthWireFlow,
		TestFdwatchReadAndOneshotWrite,
	};

	try
	{
		for (const SmokeTest run_test : tests)
			run_test();

		std::cout << "metin smoke tests passed\n";
	}
	catch (const std::exception& e)
	{
		std::cerr << "metin smoke tests failed: " << e.what() << '\n';
		return 1;
	}

	return 0;
}