146 lines
4.9 KiB
C++
146 lines
4.9 KiB
C++
#include <array>
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
#include <exception>
|
|
#include <iostream>
|
|
#include <stdexcept>
|
|
|
|
#include "game/stdafx.h"
|
|
#include "game/SecureCipher.h"
|
|
#include "libthecore/fdwatch.h"
|
|
|
|
namespace
|
|
{
|
|
void Expect(bool condition, const char* message)
|
|
{
|
|
if (!condition)
|
|
throw std::runtime_error(message);
|
|
}
|
|
|
|
void TestPacketLayouts()
|
|
{
|
|
constexpr size_t key_challenge_size =
|
|
sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::CHALLENGE_SIZE + sizeof(uint32_t);
|
|
constexpr size_t key_response_size =
|
|
sizeof(uint16_t) * 2 + SecureCipher::PK_SIZE + SecureCipher::HMAC_SIZE;
|
|
constexpr size_t key_complete_size =
|
|
sizeof(uint16_t) * 2 + SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE + SecureCipher::NONCE_SIZE;
|
|
|
|
Expect(key_challenge_size == 72, "Unexpected key challenge wire size");
|
|
Expect(key_response_size == 68, "Unexpected key response wire size");
|
|
Expect(key_complete_size == 76, "Unexpected key complete wire size");
|
|
}
|
|
|
|
void TestSecureCipherRoundTrip()
|
|
{
|
|
SecureCipher server;
|
|
SecureCipher client;
|
|
|
|
Expect(server.Initialize(), "Server SecureCipher init failed");
|
|
Expect(client.Initialize(), "Client SecureCipher init failed");
|
|
|
|
std::array<uint8_t, SecureCipher::PK_SIZE> server_pk {};
|
|
std::array<uint8_t, SecureCipher::PK_SIZE> client_pk {};
|
|
server.GetPublicKey(server_pk.data());
|
|
client.GetPublicKey(client_pk.data());
|
|
|
|
Expect(client.ComputeClientKeys(server_pk.data()), "Client session key derivation failed");
|
|
Expect(server.ComputeServerKeys(client_pk.data()), "Server session key derivation failed");
|
|
|
|
std::array<uint8_t, SecureCipher::CHALLENGE_SIZE> challenge {};
|
|
std::array<uint8_t, SecureCipher::HMAC_SIZE> response {};
|
|
server.GenerateChallenge(challenge.data());
|
|
client.ComputeChallengeResponse(challenge.data(), response.data());
|
|
Expect(server.VerifyChallengeResponse(challenge.data(), response.data()), "Challenge verification failed");
|
|
|
|
std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> token {};
|
|
for (size_t i = 0; i < token.size(); ++i)
|
|
token[i] = static_cast<uint8_t>(i);
|
|
|
|
std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE + SecureCipher::TAG_SIZE> ciphertext {};
|
|
std::array<uint8_t, SecureCipher::NONCE_SIZE> nonce {};
|
|
std::array<uint8_t, SecureCipher::SESSION_TOKEN_SIZE> plaintext {};
|
|
|
|
Expect(server.EncryptToken(token.data(), token.size(), ciphertext.data(), nonce.data()), "Token encryption failed");
|
|
Expect(client.DecryptToken(ciphertext.data(), ciphertext.size(), nonce.data(), plaintext.data()), "Token decryption failed");
|
|
Expect(std::memcmp(token.data(), plaintext.data(), token.size()) == 0, "Token round-trip mismatch");
|
|
|
|
server.SetActivated(true);
|
|
client.SetActivated(true);
|
|
|
|
std::array<uint8_t, 96> payload {};
|
|
for (size_t i = 0; i < payload.size(); ++i)
|
|
payload[i] = static_cast<uint8_t>(0xA0 + (i % 31));
|
|
|
|
auto encrypted = payload;
|
|
server.EncryptInPlace(encrypted.data(), encrypted.size());
|
|
client.DecryptInPlace(encrypted.data(), encrypted.size());
|
|
Expect(encrypted == payload, "Server to client stream cipher round-trip failed");
|
|
|
|
auto reverse = payload;
|
|
client.EncryptInPlace(reverse.data(), reverse.size());
|
|
server.DecryptInPlace(reverse.data(), reverse.size());
|
|
Expect(reverse == payload, "Client to server stream cipher round-trip failed");
|
|
}
|
|
|
|
void TestFdwatchReadAndOneshotWrite()
|
|
{
|
|
int sockets[2] = { -1, -1 };
|
|
Expect(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0, "socketpair failed");
|
|
|
|
LPFDWATCH fdw = fdwatch_new(64);
|
|
Expect(fdw != nullptr, "fdwatch_new failed");
|
|
|
|
int marker = 42;
|
|
fdwatch_add_fd(fdw, sockets[1], &marker, FDW_READ, false);
|
|
|
|
const uint8_t byte = 0x7F;
|
|
Expect(write(sockets[0], &byte, sizeof(byte)) == sizeof(byte), "socketpair write failed");
|
|
|
|
timeval timeout {};
|
|
timeout.tv_sec = 0;
|
|
timeout.tv_usec = 200000;
|
|
|
|
int num_events = fdwatch(fdw, &timeout);
|
|
Expect(num_events == 1, "Expected one read event");
|
|
Expect(fdwatch_get_client_data(fdw, 0) == &marker, "Unexpected client data");
|
|
Expect(fdwatch_check_event(fdw, sockets[1], 0) == FDW_READ, "Expected FDW_READ event");
|
|
|
|
uint8_t read_back = 0;
|
|
Expect(read(sockets[1], &read_back, sizeof(read_back)) == sizeof(read_back), "socketpair read failed");
|
|
Expect(read_back == byte, "Read payload mismatch");
|
|
|
|
fdwatch_add_fd(fdw, sockets[1], &marker, FDW_WRITE, true);
|
|
num_events = fdwatch(fdw, &timeout);
|
|
Expect(num_events >= 1, "Expected at least one write event");
|
|
Expect(fdwatch_check_event(fdw, sockets[1], 0) == FDW_WRITE, "Expected FDW_WRITE event");
|
|
|
|
timeout.tv_sec = 0;
|
|
timeout.tv_usec = 0;
|
|
num_events = fdwatch(fdw, &timeout);
|
|
Expect(num_events == 0, "FDW_WRITE oneshot was not cleared");
|
|
|
|
fdwatch_del_fd(fdw, sockets[1]);
|
|
fdwatch_delete(fdw);
|
|
close(sockets[0]);
|
|
close(sockets[1]);
|
|
}
|
|
}
|
|
|
|
int main()
|
|
{
|
|
try
|
|
{
|
|
TestPacketLayouts();
|
|
TestSecureCipherRoundTrip();
|
|
TestFdwatchReadAndOneshotWrite();
|
|
std::cout << "metin smoke tests passed\n";
|
|
return 0;
|
|
}
|
|
catch (const std::exception& e)
|
|
{
|
|
std::cerr << "metin smoke tests failed: " << e.what() << '\n';
|
|
return 1;
|
|
}
|
|
}
|