diff --git a/src/main.zig b/src/main.zig index e5cd0c8..c3561ef 100644 --- a/src/main.zig +++ b/src/main.zig @@ -12,6 +12,9 @@ pub const crypto = @import("crypto.zig"); pub const protocol = @import("protocol.zig"); pub const discovery = @import("discovery.zig"); pub const connection = @import("connection.zig"); +pub const tls = @import("tls.zig"); +pub const stun = @import("stun.zig"); +pub const relay = @import("relay.zig"); // Re-exports principales pub const DeviceId = identity.DeviceId; diff --git a/src/relay.zig b/src/relay.zig new file mode 100644 index 0000000..5f88a0f --- /dev/null +++ b/src/relay.zig @@ -0,0 +1,539 @@ +//! Módulo Relay - Retransmisión cuando la conexión directa falla +//! +//! Implementa el protocolo de relay para atravesar NATs simétricos. +//! Compatible con el protocolo relay de Syncthing. + +const std = @import("std"); +const identity = @import("identity.zig"); +const crypto = @import("crypto.zig"); +const tls = @import("tls.zig"); + +pub const DeviceId = identity.DeviceId; + +/// Puerto por defecto para relay +pub const RELAY_PORT: u16 = 22067; + +/// Magic bytes del protocolo relay +const RELAY_MAGIC: u32 = 0x9E79BC40; + +/// Tipos de mensaje relay +pub const MessageType = enum(u32) { + ping = 0, + pong = 1, + join_relay_request = 2, + join_session_request = 3, + response = 4, + connect_request = 5, + session_invitation = 6, + _, +}; + +/// Códigos de respuesta +pub const ResponseCode = enum(u32) { + success = 0, + not_found = 1, + already_connected = 2, + limit_exceeded = 3, + unexpected_message = 100, + _, +}; + +/// Estado de la sesión relay +pub const SessionState = enum { + disconnected, + connecting, + joined, + session_pending, + connected, + @"error", +}; + +/// Mensaje del protocolo relay +pub const RelayMessage = struct { + msg_type: MessageType, + data: []const u8, + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator, msg_type: MessageType, data: []const u8) !RelayMessage { + return .{ + .msg_type = msg_type, + .data = try allocator.dupe(u8, data), + .allocator = allocator, + }; + } + + pub fn deinit(self: *RelayMessage) void { + self.allocator.free(self.data); + } + + /// Codifica el mensaje + pub fn encode(self: *RelayMessage, out: []u8) usize { + std.mem.writeInt(u32, out[0..4], RELAY_MAGIC, .big); + std.mem.writeInt(u32, out[4..8], @intFromEnum(self.msg_type), .big); + std.mem.writeInt(u32, out[8..12], @intCast(self.data.len), .big); + @memcpy(out[12 .. 12 + self.data.len], self.data); + return 12 + self.data.len; + } + + /// Decodifica un mensaje + pub fn decode(allocator: std.mem.Allocator, data: []const u8) !RelayMessage { + if (data.len < 12) return error.MessageTooShort; + + const magic = std.mem.readInt(u32, data[0..4], .big); + if (magic != RELAY_MAGIC) return error.InvalidMagic; + + const msg_type: MessageType = @enumFromInt(std.mem.readInt(u32, data[4..8], .big)); + const length = std.mem.readInt(u32, data[8..12], .big); + + if (data.len < 12 + length) return error.MessageTooShort; + + return .{ + .msg_type = msg_type, + .data = try allocator.dupe(u8, data[12 .. 12 + length]), + .allocator = allocator, + }; + } +}; + +/// JoinRelayRequest - Unirse a un servidor relay +pub const JoinRelayRequest = struct { + pub fn encode(_: []u8) usize { + // Mensaje vacío + return 0; + } +}; + +/// JoinSessionRequest - Unirse a una sesión con otro dispositivo +pub const JoinSessionRequest = struct { + device_id: DeviceId, + + pub fn encode(self: JoinSessionRequest, out: []u8) usize { + @memcpy(out[0..32], &self.device_id); + return 32; + } + + pub fn decode(data: []const u8) !JoinSessionRequest { + if (data.len < 32) return error.InvalidData; + return .{ + .device_id = data[0..32].*, + }; + } +}; + +/// ConnectRequest - Solicitar conexión a un dispositivo +pub const ConnectRequest = struct { + device_id: DeviceId, + + pub fn encode(self: ConnectRequest, out: []u8) usize { + @memcpy(out[0..32], &self.device_id); + return 32; + } + + pub fn decode(data: []const u8) !ConnectRequest { + if (data.len < 32) return error.InvalidData; + return .{ + .device_id = data[0..32].*, + }; + } +}; + +/// SessionInvitation - Invitación a una sesión +pub const SessionInvitation = struct { + from: DeviceId, + key: [32]u8, + address: []const u8, + port: u16, + server_socket: bool, + + pub fn decode(allocator: std.mem.Allocator, data: []const u8) !SessionInvitation { + if (data.len < 68) return error.InvalidData; + + const addr_len = std.mem.readInt(u32, data[64..68], .big); + if (data.len < 70 + addr_len) return error.InvalidData; + + return .{ + .from = data[0..32].*, + .key = data[32..64].*, + .address = try allocator.dupe(u8, data[68 .. 68 + addr_len]), + .port = std.mem.readInt(u16, data[68 + addr_len ..][0..2], .big), + .server_socket = data[70 + addr_len] != 0, + }; + } +}; + +/// Response - Respuesta del servidor +pub const Response = struct { + code: ResponseCode, + message: []const u8, + + pub fn decode(allocator: std.mem.Allocator, data: []const u8) !Response { + if (data.len < 8) return error.InvalidData; + + const code: ResponseCode = @enumFromInt(std.mem.readInt(u32, data[0..4], .big)); + const msg_len = std.mem.readInt(u32, data[4..8], .big); + + return .{ + .code = code, + .message = if (msg_len > 0 and data.len >= 8 + msg_len) + try allocator.dupe(u8, data[8 .. 8 + msg_len]) + else + "", + }; + } +}; + +/// Cliente relay +pub const RelayClient = struct { + allocator: std.mem.Allocator, + my_device_id: DeviceId, + servers: std.ArrayListUnmanaged([]const u8), + socket: ?std.posix.socket_t, + tls_conn: ?*tls.TlsConnection, + state: SessionState, + session_key: ?[32]u8, + + pub fn init(allocator: std.mem.Allocator, device_id: DeviceId) RelayClient { + return .{ + .allocator = allocator, + .my_device_id = device_id, + .servers = .{}, + .socket = null, + .tls_conn = null, + .state = .disconnected, + .session_key = null, + }; + } + + pub fn deinit(self: *RelayClient) void { + if (self.socket) |sock| { + std.posix.close(sock); + } + if (self.tls_conn) |conn| { + conn.deinit(); + self.allocator.destroy(conn); + } + for (self.servers.items) |server| { + self.allocator.free(server); + } + self.servers.deinit(self.allocator); + } + + /// Añade un servidor relay + pub fn addServer(self: *RelayClient, server: []const u8) !void { + const owned = try self.allocator.dupe(u8, server); + try self.servers.append(self.allocator, owned); + } + + /// Conecta a un servidor relay + pub fn connect(self: *RelayClient, server_addr: std.net.Address) !void { + self.state = .connecting; + + // Crear socket TCP + self.socket = try std.posix.socket( + std.posix.AF.INET, + std.posix.SOCK.STREAM, + 0, + ); + errdefer { + if (self.socket) |sock| std.posix.close(sock); + self.socket = null; + } + + // Conectar + try std.posix.connect(self.socket.?, &server_addr.any, server_addr.getOsSockLen()); + + // Iniciar TLS + const tls_conn = try self.allocator.create(tls.TlsConnection); + tls_conn.* = tls.TlsConnection.init(self.allocator); + self.tls_conn = tls_conn; + + // Enviar ClientHello + var hello_buf: [512]u8 = undefined; + const hello_len = try tls_conn.generateClientHello(&hello_buf); + + // Wrap en TLS record + var record_buf: [600]u8 = undefined; + const record = tls.TlsRecord{ + .content_type = .handshake, + .version = tls.ProtocolVersion.TLS_1_2, + .length = @intCast(hello_len), + .fragment = hello_buf[0..hello_len], + }; + const record_len = record.encode(&record_buf); + + _ = try std.posix.send(self.socket.?, record_buf[0..record_len], 0); + + // TODO: Procesar respuesta del servidor TLS + } + + /// Se une al pool del relay + pub fn joinRelay(self: *RelayClient) !void { + if (self.socket == null) return error.NotConnected; + + var msg_data: [0]u8 = undefined; + var msg = try RelayMessage.init(self.allocator, .join_relay_request, &msg_data); + defer msg.deinit(); + + var buf: [256]u8 = undefined; + const len = msg.encode(&buf); + + // Cifrar y enviar + if (self.tls_conn) |tls_conn| { + var encrypted: [512]u8 = undefined; + const enc_len = try tls_conn.encrypt(buf[0..len], &encrypted); + _ = try std.posix.send(self.socket.?, encrypted[0..enc_len], 0); + } else { + _ = try std.posix.send(self.socket.?, buf[0..len], 0); + } + + self.state = .joined; + } + + /// Solicita conexión a otro dispositivo + pub fn requestConnection(self: *RelayClient, target_device: DeviceId) !void { + if (self.state != .joined) return error.NotJoined; + + var data_buf: [32]u8 = undefined; + const req = ConnectRequest{ .device_id = target_device }; + const data_len = req.encode(&data_buf); + + var msg = try RelayMessage.init(self.allocator, .connect_request, data_buf[0..data_len]); + defer msg.deinit(); + + var buf: [256]u8 = undefined; + const len = msg.encode(&buf); + + if (self.tls_conn) |tls_conn| { + var encrypted: [512]u8 = undefined; + const enc_len = try tls_conn.encrypt(buf[0..len], &encrypted); + _ = try std.posix.send(self.socket.?, encrypted[0..enc_len], 0); + } else { + _ = try std.posix.send(self.socket.?, buf[0..len], 0); + } + + self.state = .session_pending; + } + + /// Procesa un mensaje entrante + pub fn processMessage(self: *RelayClient, data: []const u8) !void { + var msg = try RelayMessage.decode(self.allocator, data); + defer msg.deinit(); + + switch (msg.msg_type) { + .ping => { + try self.sendPong(); + }, + .response => { + const resp = try Response.decode(self.allocator, msg.data); + if (resp.code != .success) { + self.state = .@"error"; + } + }, + .session_invitation => { + const invitation = try SessionInvitation.decode(self.allocator, msg.data); + self.session_key = invitation.key; + self.state = .connected; + }, + else => {}, + } + } + + fn sendPong(self: *RelayClient) !void { + var msg_data: [0]u8 = undefined; + var msg = try RelayMessage.init(self.allocator, .pong, &msg_data); + defer msg.deinit(); + + var buf: [256]u8 = undefined; + const len = msg.encode(&buf); + + if (self.socket) |sock| { + _ = try std.posix.send(sock, buf[0..len], 0); + } + } + + /// Envía datos a través del relay + pub fn send(self: *RelayClient, data: []const u8) !void { + if (self.state != .connected) return error.NotConnected; + if (self.socket == null) return error.NotConnected; + + // Los datos van directamente por la sesión relay + _ = try std.posix.send(self.socket.?, data, 0); + } + + /// Recibe datos del relay + pub fn receive(self: *RelayClient, buf: []u8) !usize { + if (self.state != .connected) return error.NotConnected; + if (self.socket == null) return error.NotConnected; + + const result = std.posix.recv(self.socket.?, buf, 0); + return result catch error.ReceiveFailed; + } +}; + +/// Pool de conexiones relay +pub const RelayPool = struct { + allocator: std.mem.Allocator, + device_id: DeviceId, + clients: std.ArrayListUnmanaged(*RelayClient), + active_client: ?*RelayClient, + + pub fn init(allocator: std.mem.Allocator, device_id: DeviceId) RelayPool { + return .{ + .allocator = allocator, + .device_id = device_id, + .clients = .{}, + .active_client = null, + }; + } + + pub fn deinit(self: *RelayPool) void { + for (self.clients.items) |client| { + client.deinit(); + self.allocator.destroy(client); + } + self.clients.deinit(self.allocator); + } + + /// Añade servidores relay + pub fn addServers(self: *RelayPool, servers: []const []const u8) !void { + for (servers) |server| { + const client = try self.allocator.create(RelayClient); + client.* = RelayClient.init(self.allocator, self.device_id); + try client.addServer(server); + try self.clients.append(self.allocator, client); + } + } + + /// Conecta al mejor relay disponible + pub fn connect(self: *RelayPool) !void { + for (self.clients.items) |client| { + // Parsear servidor + if (client.servers.items.len == 0) continue; + + const server = client.servers.items[0]; + const addr = parseServerAddress(server) catch continue; + + client.connect(addr) catch continue; + client.joinRelay() catch continue; + + self.active_client = client; + return; + } + + return error.NoRelayAvailable; + } + + /// Solicita conexión a un dispositivo + pub fn connectToDevice(self: *RelayPool, device_id: DeviceId) !void { + if (self.active_client) |client| { + try client.requestConnection(device_id); + } else { + return error.NotConnected; + } + } +}; + +fn parseServerAddress(server: []const u8) !std.net.Address { + // Formato: host:port o relay://host:port + var start: usize = 0; + if (std.mem.startsWith(u8, server, "relay://")) { + start = 8; + } + + const rest = server[start..]; + var host_end = rest.len; + var port: u16 = RELAY_PORT; + + if (std.mem.lastIndexOf(u8, rest, ":")) |colon| { + host_end = colon; + port = std.fmt.parseInt(u16, rest[colon + 1 ..], 10) catch RELAY_PORT; + } + + const host = rest[0..host_end]; + + // Parsear IP o resolver DNS (simplificado) + return parseIpOrResolve(host, port); +} + +fn parseIpOrResolve(host: []const u8, port: u16) !std.net.Address { + var octets: [4]u8 = undefined; + var octet_idx: usize = 0; + var current: u16 = 0; + + for (host) |c| { + if (c == '.') { + if (octet_idx >= 4) return error.InvalidAddress; + octets[octet_idx] = @intCast(current); + octet_idx += 1; + current = 0; + } else if (c >= '0' and c <= '9') { + current = current * 10 + (c - '0'); + if (current > 255) return error.InvalidAddress; + } else { + // Es hostname - necesita DNS lookup + // Por ahora, error + return error.DnsLookupRequired; + } + } + + if (octet_idx == 3) { + octets[3] = @intCast(current); + return std.net.Address.initIp4(octets, port); + } + + return error.InvalidAddress; +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "relay message encode/decode" { + const allocator = std.testing.allocator; + + var msg = try RelayMessage.init(allocator, .ping, &.{}); + defer msg.deinit(); + + var buf: [256]u8 = undefined; + const len = msg.encode(&buf); + + try std.testing.expect(len == 12); // Header only + + var decoded = try RelayMessage.decode(allocator, buf[0..len]); + defer decoded.deinit(); + + try std.testing.expect(decoded.msg_type == .ping); +} + +test "join session request" { + const device_id = [_]u8{0xab} ** 32; + const req = JoinSessionRequest{ .device_id = device_id }; + + var buf: [64]u8 = undefined; + const len = req.encode(&buf); + + try std.testing.expect(len == 32); + + const decoded = try JoinSessionRequest.decode(buf[0..len]); + try std.testing.expectEqualSlices(u8, &device_id, &decoded.device_id); +} + +test "relay client init" { + const allocator = std.testing.allocator; + const device_id = [_]u8{0xcd} ** 32; + + var client = RelayClient.init(allocator, device_id); + defer client.deinit(); + + try std.testing.expect(client.state == .disconnected); +} + +test "relay pool init" { + const allocator = std.testing.allocator; + const device_id = [_]u8{0xef} ** 32; + + var pool = RelayPool.init(allocator, device_id); + defer pool.deinit(); + + try std.testing.expect(pool.active_client == null); +} diff --git a/src/stun.zig b/src/stun.zig new file mode 100644 index 0000000..35587af --- /dev/null +++ b/src/stun.zig @@ -0,0 +1,546 @@ +//! Módulo STUN - Session Traversal Utilities for NAT (RFC 5389) +//! +//! Cliente STUN para descubrir dirección IP externa y tipo de NAT. + +const std = @import("std"); +const crypto = @import("crypto.zig"); + +/// Puerto STUN estándar +pub const STUN_PORT: u16 = 3478; + +/// Magic cookie STUN +const MAGIC_COOKIE: u32 = 0x2112A442; + +/// Tipos de mensaje STUN +pub const MessageType = enum(u16) { + binding_request = 0x0001, + binding_response = 0x0101, + binding_error = 0x0111, + _, +}; + +/// Tipos de atributo STUN +pub const AttributeType = enum(u16) { + mapped_address = 0x0001, + response_address = 0x0002, + change_request = 0x0003, + source_address = 0x0004, + changed_address = 0x0005, + username = 0x0006, + password = 0x0007, + message_integrity = 0x0008, + error_code = 0x0009, + unknown_attributes = 0x000A, + reflected_from = 0x000B, + realm = 0x0014, + nonce = 0x0015, + xor_mapped_address = 0x0020, + software = 0x8022, + alternate_server = 0x8023, + fingerprint = 0x8028, + other_address = 0x802C, + _, +}; + +/// Familia de direcciones +pub const AddressFamily = enum(u8) { + ipv4 = 0x01, + ipv6 = 0x02, + _, +}; + +/// Dirección mapeada +pub const MappedAddress = struct { + family: AddressFamily, + port: u16, + address: union { + ipv4: [4]u8, + ipv6: [16]u8, + }, + + pub fn format(self: MappedAddress, buf: []u8) []const u8 { + if (self.family == .ipv4) { + return std.fmt.bufPrint(buf, "{}.{}.{}.{}:{}", .{ + self.address.ipv4[0], + self.address.ipv4[1], + self.address.ipv4[2], + self.address.ipv4[3], + self.port, + }) catch ""; + } + return ""; + } +}; + +/// Mensaje STUN +pub const StunMessage = struct { + message_type: MessageType, + transaction_id: [12]u8, + attributes: std.ArrayListUnmanaged(Attribute), + allocator: std.mem.Allocator, + + pub const Attribute = struct { + attr_type: AttributeType, + data: []const u8, + }; + + pub fn init(allocator: std.mem.Allocator, msg_type: MessageType) StunMessage { + var transaction_id: [12]u8 = undefined; + std.crypto.random.bytes(&transaction_id); + + return .{ + .message_type = msg_type, + .transaction_id = transaction_id, + .attributes = .{}, + .allocator = allocator, + }; + } + + pub fn deinit(self: *StunMessage) void { + for (self.attributes.items) |attr| { + self.allocator.free(attr.data); + } + self.attributes.deinit(self.allocator); + } + + /// Codifica el mensaje STUN + pub fn encode(self: *StunMessage) ![]u8 { + // Calcular longitud de atributos + var attrs_len: usize = 0; + for (self.attributes.items) |attr| { + attrs_len += 4 + attr.data.len; + // Padding a 4 bytes + if (attr.data.len % 4 != 0) { + attrs_len += 4 - (attr.data.len % 4); + } + } + + const total_len = 20 + attrs_len; + const buf = try self.allocator.alloc(u8, total_len); + errdefer self.allocator.free(buf); + + var pos: usize = 0; + + // Header + std.mem.writeInt(u16, buf[0..2], @intFromEnum(self.message_type), .big); + std.mem.writeInt(u16, buf[2..4], @intCast(attrs_len), .big); + std.mem.writeInt(u32, buf[4..8], MAGIC_COOKIE, .big); + @memcpy(buf[8..20], &self.transaction_id); + pos = 20; + + // Atributos + for (self.attributes.items) |attr| { + std.mem.writeInt(u16, buf[pos..][0..2], @intFromEnum(attr.attr_type), .big); + std.mem.writeInt(u16, buf[pos + 2 ..][0..2], @intCast(attr.data.len), .big); + @memcpy(buf[pos + 4 .. pos + 4 + attr.data.len], attr.data); + pos += 4 + attr.data.len; + + // Padding + const pad = (4 - (attr.data.len % 4)) % 4; + if (pad > 0) { + @memset(buf[pos .. pos + pad], 0); + pos += pad; + } + } + + return buf; + } + + /// Decodifica un mensaje STUN + pub fn decode(allocator: std.mem.Allocator, data: []const u8) !StunMessage { + if (data.len < 20) return error.MessageTooShort; + + const msg_type: MessageType = @enumFromInt(std.mem.readInt(u16, data[0..2], .big)); + const msg_len = std.mem.readInt(u16, data[2..4], .big); + const magic = std.mem.readInt(u32, data[4..8], .big); + + if (magic != MAGIC_COOKIE) return error.InvalidMagicCookie; + if (data.len < 20 + msg_len) return error.MessageTooShort; + + var msg = StunMessage{ + .message_type = msg_type, + .transaction_id = data[8..20].*, + .attributes = .{}, + .allocator = allocator, + }; + errdefer msg.deinit(); + + // Parsear atributos + var pos: usize = 20; + const end = 20 + msg_len; + + while (pos + 4 <= end) { + const attr_type: AttributeType = @enumFromInt(std.mem.readInt(u16, data[pos..][0..2], .big)); + const attr_len = std.mem.readInt(u16, data[pos + 2 ..][0..2], .big); + pos += 4; + + if (pos + attr_len > end) break; + + const attr_data = try allocator.dupe(u8, data[pos .. pos + attr_len]); + try msg.attributes.append(allocator, .{ + .attr_type = attr_type, + .data = attr_data, + }); + + pos += attr_len; + // Skip padding + pos += (4 - (attr_len % 4)) % 4; + } + + return msg; + } + + /// Obtiene la dirección XOR-MAPPED-ADDRESS + pub fn getXorMappedAddress(self: *StunMessage) ?MappedAddress { + for (self.attributes.items) |attr| { + if (attr.attr_type == .xor_mapped_address) { + return parseXorMappedAddress(attr.data, self.transaction_id); + } + } + // Fallback a MAPPED-ADDRESS + for (self.attributes.items) |attr| { + if (attr.attr_type == .mapped_address) { + return parseMappedAddress(attr.data); + } + } + return null; + } + + /// Obtiene OTHER-ADDRESS (para detección de NAT) + pub fn getOtherAddress(self: *StunMessage) ?MappedAddress { + for (self.attributes.items) |attr| { + if (attr.attr_type == .other_address or attr.attr_type == .changed_address) { + return parseMappedAddress(attr.data); + } + } + return null; + } +}; + +fn parseMappedAddress(data: []const u8) ?MappedAddress { + if (data.len < 8) return null; + + const family: AddressFamily = @enumFromInt(data[1]); + + if (family == .ipv4 and data.len >= 8) { + return .{ + .family = .ipv4, + .port = std.mem.readInt(u16, data[2..4], .big), + .address = .{ .ipv4 = data[4..8].* }, + }; + } else if (family == .ipv6 and data.len >= 20) { + return .{ + .family = .ipv6, + .port = std.mem.readInt(u16, data[2..4], .big), + .address = .{ .ipv6 = data[4..20].* }, + }; + } + return null; +} + +fn parseXorMappedAddress(data: []const u8, transaction_id: [12]u8) ?MappedAddress { + if (data.len < 8) return null; + + const family: AddressFamily = @enumFromInt(data[1]); + + // XOR port with magic cookie high bytes + const port = std.mem.readInt(u16, data[2..4], .big) ^ @as(u16, @truncate(MAGIC_COOKIE >> 16)); + + if (family == .ipv4 and data.len >= 8) { + // XOR address with magic cookie + var addr: [4]u8 = data[4..8].*; + const magic_bytes = std.mem.toBytes(std.mem.nativeToBig(u32, MAGIC_COOKIE)); + for (0..4) |i| { + addr[i] ^= magic_bytes[i]; + } + return .{ + .family = .ipv4, + .port = port, + .address = .{ .ipv4 = addr }, + }; + } else if (family == .ipv6 and data.len >= 20) { + // XOR address with magic cookie + transaction_id + var addr: [16]u8 = data[4..20].*; + const magic_bytes = std.mem.toBytes(std.mem.nativeToBig(u32, MAGIC_COOKIE)); + for (0..4) |i| { + addr[i] ^= magic_bytes[i]; + } + for (0..12) |i| { + addr[4 + i] ^= transaction_id[i]; + } + return .{ + .family = .ipv6, + .port = port, + .address = .{ .ipv6 = addr }, + }; + } + return null; +} + +/// Tipo de NAT detectado +pub const NatType = enum { + unknown, + open_internet, // Sin NAT + full_cone, // Cualquier host externo puede enviar + restricted, // Solo hosts a los que hemos enviado + port_restricted, // Solo host:port a los que hemos enviado + symmetric, // Diferente mapeo por destino + blocked, // UDP bloqueado + + pub fn canPunch(self: NatType) bool { + return switch (self) { + .open_internet, .full_cone, .restricted, .port_restricted => true, + .symmetric, .blocked, .unknown => false, + }; + } + + pub fn needsRelay(self: NatType) bool { + return self == .symmetric or self == .blocked; + } +}; + +/// Cliente STUN +pub const StunClient = struct { + allocator: std.mem.Allocator, + servers: std.ArrayListUnmanaged([]const u8), + socket: ?std.posix.socket_t, + external_address: ?MappedAddress, + nat_type: NatType, + + pub fn init(allocator: std.mem.Allocator) StunClient { + return .{ + .allocator = allocator, + .servers = .{}, + .socket = null, + .external_address = null, + .nat_type = .unknown, + }; + } + + pub fn deinit(self: *StunClient) void { + if (self.socket) |sock| { + std.posix.close(sock); + } + for (self.servers.items) |server| { + self.allocator.free(server); + } + self.servers.deinit(self.allocator); + } + + /// Añade un servidor STUN + pub fn addServer(self: *StunClient, server: []const u8) !void { + const owned = try self.allocator.dupe(u8, server); + try self.servers.append(self.allocator, owned); + } + + /// Crea el socket UDP + pub fn createSocket(self: *StunClient) !void { + self.socket = try std.posix.socket( + std.posix.AF.INET, + std.posix.SOCK.DGRAM, + 0, + ); + + // Bind a un puerto aleatorio + const addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, 0); + try std.posix.bind(self.socket.?, &addr.any, addr.getOsSockLen()); + } + + /// Envía un Binding Request a un servidor + pub fn sendBindingRequest(self: *StunClient, server_addr: std.net.Address) !StunMessage { + if (self.socket == null) try self.createSocket(); + + var request = StunMessage.init(self.allocator, .binding_request); + errdefer request.deinit(); + + const encoded = try request.encode(); + defer self.allocator.free(encoded); + + _ = try std.posix.sendto( + self.socket.?, + encoded, + 0, + &server_addr.any, + server_addr.getOsSockLen(), + ); + + return request; + } + + /// Recibe una respuesta STUN + pub fn receiveResponse(self: *StunClient, timeout_ms: u32) !StunMessage { + if (self.socket == null) return error.SocketNotCreated; + + // Configurar timeout + const tv = std.posix.timeval{ + .sec = @intCast(timeout_ms / 1000), + .usec = @intCast((timeout_ms % 1000) * 1000), + }; + try std.posix.setsockopt( + self.socket.?, + std.posix.SOL.SOCKET, + std.posix.SO.RCVTIMEO, + std.mem.asBytes(&tv), + ); + + var buf: [1024]u8 = undefined; + const result = std.posix.recvfrom(self.socket.?, &buf, 0, null, null); + const len = result catch return error.Timeout; + + return StunMessage.decode(self.allocator, buf[0..len]); + } + + /// Descubre la dirección externa + pub fn discoverExternalAddress(self: *StunClient) !?MappedAddress { + if (self.servers.items.len == 0) { + // Añadir servidores por defecto + try self.addServer("stun.l.google.com:19302"); + try self.addServer("stun.syncthing.net:3478"); + } + + for (self.servers.items) |server| { + const result = self.queryServer(server) catch continue; + if (result) |addr| { + self.external_address = addr; + return addr; + } + } + + return null; + } + + fn queryServer(self: *StunClient, server: []const u8) !?MappedAddress { + // Parsear host:port + var host_end: usize = server.len; + var port: u16 = STUN_PORT; + + if (std.mem.lastIndexOf(u8, server, ":")) |colon| { + host_end = colon; + port = std.fmt.parseInt(u16, server[colon + 1 ..], 10) catch STUN_PORT; + } + const host = server[0..host_end]; + + // Resolver DNS (simplificado - solo IPv4) + // En producción usar std.net.getAddressList + const addr = try parseIpv4(host, port); + + const request = try self.sendBindingRequest(addr); + var response = try self.receiveResponse(3000); + defer response.deinit(); + + // Verificar transaction ID + if (!std.mem.eql(u8, &request.transaction_id, &response.transaction_id)) { + return error.TransactionIdMismatch; + } + + return response.getXorMappedAddress(); + } + + /// Detecta el tipo de NAT + pub fn detectNatType(self: *StunClient) !NatType { + // Algoritmo simplificado de detección de NAT + // Para detección completa se necesitan 2 servidores STUN con 2 IPs cada uno + + const external = try self.discoverExternalAddress(); + if (external == null) { + self.nat_type = .blocked; + return .blocked; + } + + // Verificar si la IP externa coincide con la local (sin NAT) + // Simplificado: asumir que hay NAT + self.nat_type = .restricted; + return .restricted; + } +}; + +fn parseIpv4(host: []const u8, port: u16) !std.net.Address { + // Parsear IP directamente o usar DNS + var octets: [4]u8 = undefined; + var octet_idx: usize = 0; + var current: u16 = 0; + + for (host) |c| { + if (c == '.') { + if (octet_idx >= 4) return error.InvalidAddress; + octets[octet_idx] = @intCast(current); + octet_idx += 1; + current = 0; + } else if (c >= '0' and c <= '9') { + current = current * 10 + (c - '0'); + if (current > 255) return error.InvalidAddress; + } else { + // Es un hostname, no una IP + // Usar lookup DNS sería necesario aquí + // Por ahora, usar Google STUN como fallback + return std.net.Address.initIp4(.{ 142, 250, 187, 127 }, port); + } + } + + if (octet_idx == 3) { + octets[3] = @intCast(current); + return std.net.Address.initIp4(octets, port); + } + + return error.InvalidAddress; +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "stun message encode/decode" { + const allocator = std.testing.allocator; + + var msg = StunMessage.init(allocator, .binding_request); + defer msg.deinit(); + + const encoded = try msg.encode(); + defer allocator.free(encoded); + + try std.testing.expect(encoded.len == 20); // Header only + + var decoded = try StunMessage.decode(allocator, encoded); + defer decoded.deinit(); + + try std.testing.expect(decoded.message_type == .binding_request); + try std.testing.expectEqualSlices(u8, &msg.transaction_id, &decoded.transaction_id); +} + +test "parse xor mapped address ipv4" { + const transaction_id = [_]u8{0} ** 12; + + // XOR-MAPPED-ADDRESS para 192.0.2.1:32853 + // Family: 0x01 (IPv4), XOR'd Port: 0x1234, XOR'd IP: XOR con magic cookie + const port_xored: u16 = 32853 ^ 0x2112; // XOR with high bytes of magic cookie + const ip = [4]u8{ 192 ^ 0x21, 0 ^ 0x12, 2 ^ 0xA4, 1 ^ 0x42 }; // XOR with magic cookie + + var data: [8]u8 = undefined; + data[0] = 0; // Reserved + data[1] = 0x01; // IPv4 + std.mem.writeInt(u16, data[2..4], port_xored, .big); + @memcpy(data[4..8], &ip); + + const addr = parseXorMappedAddress(&data, transaction_id); + try std.testing.expect(addr != null); + try std.testing.expect(addr.?.family == .ipv4); + try std.testing.expect(addr.?.port == 32853); + try std.testing.expectEqual([4]u8{ 192, 0, 2, 1 }, addr.?.address.ipv4); +} + +test "stun client init" { + const allocator = std.testing.allocator; + + var client = StunClient.init(allocator); + defer client.deinit(); + + try std.testing.expect(client.nat_type == .unknown); +} + +test "nat type capabilities" { + try std.testing.expect(NatType.full_cone.canPunch()); + try std.testing.expect(NatType.restricted.canPunch()); + try std.testing.expect(!NatType.symmetric.canPunch()); + try std.testing.expect(!NatType.blocked.canPunch()); + try std.testing.expect(NatType.symmetric.needsRelay()); +} diff --git a/src/tls.zig b/src/tls.zig new file mode 100644 index 0000000..7f1b82c --- /dev/null +++ b/src/tls.zig @@ -0,0 +1,704 @@ +//! Módulo TLS 1.3 - Transporte seguro +//! +//! Implementación de TLS 1.3 (RFC 8446). +//! Usa std.crypto para primitivas criptográficas (parte de la librería estándar de Zig). + +const std = @import("std"); +const crypto = @import("crypto.zig"); + +// ============================================================================= +// X25519 - Curve25519 Diffie-Hellman (RFC 7748) +// Usa la implementación de la librería estándar de Zig +// ============================================================================= + +pub const X25519_KEY_SIZE: usize = 32; +pub const X25519_SHARED_SIZE: usize = 32; + +/// Par de claves X25519 +pub const X25519KeyPair = struct { + private_key: [32]u8, + public_key: [32]u8, + + /// Genera un par de claves aleatorio + pub fn generate() X25519KeyPair { + const kp = std.crypto.dh.X25519.KeyPair.generate(); + return .{ + .private_key = kp.secret_key, + .public_key = kp.public_key, + }; + } + + /// Crea un par de claves desde una clave privada + pub fn fromPrivate(private: [32]u8) X25519KeyPair { + const public_key = std.crypto.dh.X25519.recoverPublicKey(private) catch + [_]u8{0} ** 32; + return .{ + .private_key = private, + .public_key = public_key, + }; + } + + /// Calcula el secreto compartido + pub fn sharedSecret(self: X25519KeyPair, their_public: [32]u8) ?[32]u8 { + return std.crypto.dh.X25519.scalarmult(self.private_key, their_public) catch null; + } +}; + +// ============================================================================= +// HKDF - HMAC-based Key Derivation Function (RFC 5869) +// ============================================================================= + +pub const HKDF_SHA256_KEY_SIZE: usize = 32; + +/// HMAC-SHA256 +pub fn hmacSha256(key: []const u8, data: []const u8) [32]u8 { + const block_size = 64; + var k_ipad: [block_size]u8 = undefined; + var k_opad: [block_size]u8 = undefined; + + // Si la clave es más larga que el bloque, hashear + var actual_key: [32]u8 = undefined; + var key_len: usize = undefined; + + if (key.len > block_size) { + actual_key = crypto.sha256(key); + key_len = 32; + } else { + @memcpy(actual_key[0..key.len], key); + key_len = key.len; + } + + // Preparar pads + @memset(&k_ipad, 0x36); + @memset(&k_opad, 0x5c); + for (0..key_len) |i| { + k_ipad[i] ^= actual_key[i]; + k_opad[i] ^= actual_key[i]; + } + + // Hash interno: SHA256(k_ipad || data) + var inner = crypto.Sha256.init(); + inner.update(&k_ipad); + inner.update(data); + const inner_hash = inner.final(); + + // Hash externo: SHA256(k_opad || inner_hash) + var outer = crypto.Sha256.init(); + outer.update(&k_opad); + outer.update(&inner_hash); + + return outer.final(); +} + +/// HKDF-Extract +pub fn hkdfExtract(salt: []const u8, ikm: []const u8) [32]u8 { + if (salt.len == 0) { + const zero_salt = [_]u8{0} ** 32; + return hmacSha256(&zero_salt, ikm); + } + return hmacSha256(salt, ikm); +} + +/// HKDF-Expand +pub fn hkdfExpand(prk: [32]u8, info: []const u8, length: usize, out: []u8) void { + var t: [32]u8 = undefined; + var t_len: usize = 0; + var pos: usize = 0; + var counter: u8 = 1; + + while (pos < length) { + // Preparar datos para HMAC + var hmac_data_buf: [32 + 256 + 1]u8 = undefined; + var hmac_data_len: usize = 0; + + if (t_len > 0) { + @memcpy(hmac_data_buf[0..t_len], t[0..t_len]); + hmac_data_len = t_len; + } + @memcpy(hmac_data_buf[hmac_data_len .. hmac_data_len + info.len], info); + hmac_data_len += info.len; + hmac_data_buf[hmac_data_len] = counter; + hmac_data_len += 1; + + t = hmacSha256(&prk, hmac_data_buf[0..hmac_data_len]); + t_len = 32; + + const to_copy = @min(32, length - pos); + @memcpy(out[pos .. pos + to_copy], t[0..to_copy]); + pos += to_copy; + counter += 1; + } +} + +/// HKDF completo (Extract + Expand) +pub fn hkdf(salt: []const u8, ikm: []const u8, info: []const u8, length: usize, out: []u8) void { + const prk = hkdfExtract(salt, ikm); + hkdfExpand(prk, info, length, out); +} + +// ============================================================================= +// TLS 1.3 Record Layer +// ============================================================================= + +/// Tipos de contenido TLS +pub const ContentType = enum(u8) { + invalid = 0, + change_cipher_spec = 20, + alert = 21, + handshake = 22, + application_data = 23, + _, +}; + +/// Tipos de handshake TLS 1.3 +pub const HandshakeType = enum(u8) { + client_hello = 1, + server_hello = 2, + new_session_ticket = 4, + end_of_early_data = 5, + encrypted_extensions = 8, + certificate = 11, + certificate_request = 13, + certificate_verify = 15, + finished = 20, + key_update = 24, + message_hash = 254, + _, +}; + +/// TLS Alert levels +pub const AlertLevel = enum(u8) { + warning = 1, + fatal = 2, + _, +}; + +/// TLS Alert descriptions +pub const AlertDescription = enum(u8) { + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + internal_error = 80, + _, +}; + +/// Versión del protocolo +pub const ProtocolVersion = struct { + major: u8, + minor: u8, + + pub const TLS_1_2: ProtocolVersion = .{ .major = 3, .minor = 3 }; + pub const TLS_1_3: ProtocolVersion = .{ .major = 3, .minor = 3 }; // TLS 1.3 usa 0x0303 en wire +}; + +/// Record TLS +pub const TlsRecord = struct { + content_type: ContentType, + version: ProtocolVersion, + length: u16, + fragment: []const u8, + + pub fn encode(self: TlsRecord, out: []u8) usize { + out[0] = @intFromEnum(self.content_type); + out[1] = self.version.major; + out[2] = self.version.minor; + std.mem.writeInt(u16, out[3..5], self.length, .big); + @memcpy(out[5 .. 5 + self.length], self.fragment); + return 5 + self.length; + } + + pub fn decode(data: []const u8) ?TlsRecord { + if (data.len < 5) return null; + + const length = std.mem.readInt(u16, data[3..5], .big); + if (data.len < 5 + length) return null; + + return .{ + .content_type = @enumFromInt(data[0]), + .version = .{ .major = data[1], .minor = data[2] }, + .length = length, + .fragment = data[5 .. 5 + length], + }; + } +}; + +// ============================================================================= +// TLS 1.3 Handshake State Machine +// ============================================================================= + +/// Estado del handshake TLS 1.3 +pub const HandshakeState = enum { + start, + wait_server_hello, + wait_encrypted_extensions, + wait_certificate, + wait_certificate_verify, + wait_finished, + connected, + @"error", +}; + +/// Contexto de conexión TLS +pub const TlsConnection = struct { + allocator: std.mem.Allocator, + state: HandshakeState, + + // Claves de handshake + client_keypair: X25519KeyPair, + server_public_key: ?[32]u8, + shared_secret: ?[32]u8, + + // Claves derivadas + handshake_secret: ?[32]u8, + client_handshake_traffic_secret: ?[32]u8, + server_handshake_traffic_secret: ?[32]u8, + client_traffic_secret: ?[32]u8, + server_traffic_secret: ?[32]u8, + + // Claves de cifrado + client_write_key: ?[32]u8, + client_write_iv: ?[12]u8, + server_write_key: ?[32]u8, + server_write_iv: ?[12]u8, + + // Transcript hash + transcript: crypto.Sha256, + + // Sequence numbers + client_seq: u64, + server_seq: u64, + + // Random values + client_random: [32]u8, + server_random: ?[32]u8, + + pub fn init(allocator: std.mem.Allocator) TlsConnection { + var client_random: [32]u8 = undefined; + std.crypto.random.bytes(&client_random); + + return .{ + .allocator = allocator, + .state = .start, + .client_keypair = X25519KeyPair.generate(), + .server_public_key = null, + .shared_secret = null, + .handshake_secret = null, + .client_handshake_traffic_secret = null, + .server_handshake_traffic_secret = null, + .client_traffic_secret = null, + .server_traffic_secret = null, + .client_write_key = null, + .client_write_iv = null, + .server_write_key = null, + .server_write_iv = null, + .transcript = crypto.Sha256.init(), + .client_seq = 0, + .server_seq = 0, + .client_random = client_random, + .server_random = null, + }; + } + + pub fn deinit(self: *TlsConnection) void { + _ = self; + } + + /// Genera ClientHello + pub fn generateClientHello(self: *TlsConnection, out: []u8) !usize { + var pos: usize = 0; + + // Handshake header (type + length placeholder) + out[pos] = @intFromEnum(HandshakeType.client_hello); + pos += 1; + const length_pos = pos; + pos += 3; // Length placeholder + + // Legacy version (TLS 1.2 for compatibility) + out[pos] = 3; + out[pos + 1] = 3; + pos += 2; + + // Random + @memcpy(out[pos .. pos + 32], &self.client_random); + pos += 32; + + // Session ID (empty for TLS 1.3) + out[pos] = 0; + pos += 1; + + // Cipher suites + std.mem.writeInt(u16, out[pos..][0..2], 2, .big); // Length + pos += 2; + std.mem.writeInt(u16, out[pos..][0..2], 0x1301, .big); // TLS_AES_128_GCM_SHA256 + pos += 2; + + // Compression methods (null only) + out[pos] = 1; + pos += 1; + out[pos] = 0; + pos += 1; + + // Extensions + const ext_start = pos; + pos += 2; // Extensions length placeholder + + // Supported Versions extension + pos += self.writeExtension(out[pos..], 43, &.{ 2, 3, 4 }); // TLS 1.3 + + // Supported Groups extension + const groups = [_]u8{ 0, 2, 0, 29 }; // x25519 + pos += self.writeExtension(out[pos..], 10, &groups); + + // Key Share extension + var key_share: [36]u8 = undefined; + std.mem.writeInt(u16, key_share[0..2], 32 + 2, .big); // Total length + std.mem.writeInt(u16, key_share[2..4], 29, .big); // x25519 + @memcpy(key_share[4..36], &self.client_keypair.public_key); + pos += self.writeExtension(out[pos..], 51, &key_share); + + // Signature Algorithms extension + const sig_algs = [_]u8{ 0, 2, 4, 3 }; // ECDSA-SECP256r1-SHA256 + pos += self.writeExtension(out[pos..], 13, &sig_algs); + + // Update extensions length + std.mem.writeInt(u16, out[ext_start..][0..2], @intCast(pos - ext_start - 2), .big); + + // Update handshake length + const msg_len: u24 = @intCast(pos - length_pos - 3); + out[length_pos] = @truncate(msg_len >> 16); + out[length_pos + 1] = @truncate(msg_len >> 8); + out[length_pos + 2] = @truncate(msg_len); + + // Update transcript + self.transcript.update(out[0..pos]); + + self.state = .wait_server_hello; + return pos; + } + + fn writeExtension(self: *TlsConnection, out: []u8, ext_type: u16, data: []const u8) usize { + _ = self; + std.mem.writeInt(u16, out[0..2], ext_type, .big); + std.mem.writeInt(u16, out[2..4], @intCast(data.len), .big); + @memcpy(out[4 .. 4 + data.len], data); + return 4 + data.len; + } + + /// Procesa ServerHello + pub fn processServerHello(self: *TlsConnection, data: []const u8) !void { + if (data.len < 38) return error.InvalidServerHello; + + // Skip handshake header + var pos: usize = 4; + + // Legacy version + pos += 2; + + // Server random + self.server_random = data[pos..][0..32].*; + pos += 32; + + // Session ID + const session_len = data[pos]; + pos += 1 + session_len; + + // Cipher suite + pos += 2; + + // Compression + pos += 1; + + // Extensions + if (pos + 2 > data.len) return error.InvalidServerHello; + const ext_len = std.mem.readInt(u16, data[pos..][0..2], .big); + pos += 2; + + const ext_end = pos + ext_len; + while (pos < ext_end) { + const ext_type = std.mem.readInt(u16, data[pos..][0..2], .big); + pos += 2; + const ext_data_len = std.mem.readInt(u16, data[pos..][0..2], .big); + pos += 2; + + if (ext_type == 51) { // Key share + // Skip group + pos += 2; + const key_len = std.mem.readInt(u16, data[pos..][0..2], .big); + pos += 2; + if (key_len != 32) return error.InvalidKeyShare; + self.server_public_key = data[pos..][0..32].*; + } + pos += ext_data_len; + } + + // Actualizar transcript + self.transcript.update(data); + + // Calcular secreto compartido + if (self.server_public_key) |server_pub| { + self.shared_secret = self.client_keypair.sharedSecret(server_pub); + if (self.shared_secret == null) return error.InvalidKeyShare; + try self.deriveHandshakeKeys(); + } else { + return error.MissingKeyShare; + } + + self.state = .wait_encrypted_extensions; + } + + fn deriveHandshakeKeys(self: *TlsConnection) !void { + const shared = self.shared_secret orelse return error.NoSharedSecret; + + // Early secret (con PSK = 0) + const zero_key = [_]u8{0} ** 32; + const early_secret = hkdfExtract(&.{}, &zero_key); + + // Derive-Secret para handshake + var derived_secret: [32]u8 = undefined; + self.deriveSecret(early_secret, "derived", &.{}, &derived_secret); + + // Handshake secret + self.handshake_secret = hkdfExtract(&derived_secret, &shared); + + // Client/Server handshake traffic secrets + const hs = self.handshake_secret.?; + const transcript_hash = self.transcript.final(); + self.transcript = crypto.Sha256.init(); + self.transcript.update(&transcript_hash); + + var client_hs_secret: [32]u8 = undefined; + var server_hs_secret: [32]u8 = undefined; + self.deriveSecret(hs, "c hs traffic", &transcript_hash, &client_hs_secret); + self.deriveSecret(hs, "s hs traffic", &transcript_hash, &server_hs_secret); + + self.client_handshake_traffic_secret = client_hs_secret; + self.server_handshake_traffic_secret = server_hs_secret; + + // Derive write keys + var client_key: [32]u8 = undefined; + var client_iv: [12]u8 = undefined; + var server_key: [32]u8 = undefined; + var server_iv: [12]u8 = undefined; + + hkdfExpand(client_hs_secret, "tls13 key", 32, &client_key); + hkdfExpand(client_hs_secret, "tls13 iv", 12, &client_iv); + hkdfExpand(server_hs_secret, "tls13 key", 32, &server_key); + hkdfExpand(server_hs_secret, "tls13 iv", 12, &server_iv); + + self.client_write_key = client_key; + self.client_write_iv = client_iv; + self.server_write_key = server_key; + self.server_write_iv = server_iv; + } + + fn deriveSecret(self: *TlsConnection, secret: [32]u8, label: []const u8, context: []const u8, out: *[32]u8) void { + _ = self; + // TLS 1.3 label format: "tls13 " + label + var info_buf: [256]u8 = undefined; + var info_len: usize = 0; + + // Length (2 bytes) + std.mem.writeInt(u16, info_buf[0..2], 32, .big); + info_len += 2; + + // Label length + "tls13 " + label + const tls13_label = "tls13 "; + info_buf[info_len] = @intCast(tls13_label.len + label.len); + info_len += 1; + @memcpy(info_buf[info_len .. info_len + tls13_label.len], tls13_label); + info_len += tls13_label.len; + @memcpy(info_buf[info_len .. info_len + label.len], label); + info_len += label.len; + + // Context length + context + info_buf[info_len] = @intCast(context.len); + info_len += 1; + if (context.len > 0) { + @memcpy(info_buf[info_len .. info_len + context.len], context); + info_len += context.len; + } + + hkdfExpand(secret, info_buf[0..info_len], 32, out); + } + + /// Cifra datos de aplicación + pub fn encrypt(self: *TlsConnection, plaintext: []const u8, out: []u8) !usize { + const key = self.client_write_key orelse return error.NotReady; + const iv = self.client_write_iv orelse return error.NotReady; + + // XOR IV con sequence number + var nonce: [12]u8 = iv; + const seq_bytes = std.mem.toBytes(std.mem.nativeToBig(u64, self.client_seq)); + for (0..8) |i| { + nonce[4 + i] ^= seq_bytes[i]; + } + self.client_seq += 1; + + // Construir inner plaintext: data || content_type || padding + var inner: [16384 + 1]u8 = undefined; + @memcpy(inner[0..plaintext.len], plaintext); + inner[plaintext.len] = @intFromEnum(ContentType.application_data); + + // Cifrar con ChaCha20-Poly1305 + const encrypted = try crypto.chachaPoly1305Encrypt( + &key, + &nonce, + inner[0 .. plaintext.len + 1], + &.{}, // AAD vacío para datos de aplicación + self.allocator, + ); + defer self.allocator.free(encrypted); + + // TLS record: type(1) || legacy_version(2) || length(2) || encrypted + out[0] = @intFromEnum(ContentType.application_data); + out[1] = 3; + out[2] = 3; + const enc_len: u16 = @intCast(encrypted.len - 12); // Sin nonce prepended + std.mem.writeInt(u16, out[3..5], enc_len, .big); + @memcpy(out[5 .. 5 + enc_len], encrypted[12..]); // Skip nonce + + return 5 + enc_len; + } + + /// Descifra datos de aplicación + pub fn decrypt(self: *TlsConnection, record: []const u8) ![]u8 { + const key = self.server_write_key orelse return error.NotReady; + const iv = self.server_write_iv orelse return error.NotReady; + + // XOR IV con sequence number + var nonce: [12]u8 = iv; + const seq_bytes = std.mem.toBytes(std.mem.nativeToBig(u64, self.server_seq)); + for (0..8) |i| { + nonce[4 + i] ^= seq_bytes[i]; + } + self.server_seq += 1; + + // Construir input para descifrado: nonce || ciphertext + const ciphertext = record[5..]; + var input = try self.allocator.alloc(u8, 12 + ciphertext.len); + defer self.allocator.free(input); + @memcpy(input[0..12], &nonce); + @memcpy(input[12..], ciphertext); + + // Descifrar + const decrypted = try crypto.chachaPoly1305Decrypt( + &key, + input, + &.{}, + self.allocator, + ); + + // Remover content type y padding del final + var end = decrypted.len; + while (end > 0 and decrypted[end - 1] == 0) { + end -= 1; + } + if (end > 0) end -= 1; // Remove content type + + const result = try self.allocator.alloc(u8, end); + @memcpy(result, decrypted[0..end]); + self.allocator.free(decrypted); + + return result; + } +}; + +// ============================================================================= +// Tests +// ============================================================================= + +test "x25519 key exchange" { + const alice = X25519KeyPair.generate(); + const bob = X25519KeyPair.generate(); + + const alice_shared = alice.sharedSecret(bob.public_key); + const bob_shared = bob.sharedSecret(alice.public_key); + + try std.testing.expect(alice_shared != null); + try std.testing.expect(bob_shared != null); + try std.testing.expectEqualSlices(u8, &alice_shared.?, &bob_shared.?); +} + +test "x25519 known vector" { + // RFC 7748 test vector + const scalar = [_]u8{ + 0xa5, 0x46, 0xe3, 0x6b, 0xf0, 0x52, 0x7c, 0x9d, + 0x3b, 0x16, 0x15, 0x4b, 0x82, 0x46, 0x5e, 0xdd, + 0x62, 0x14, 0x4c, 0x0a, 0xc1, 0xfc, 0x5a, 0x18, + 0x50, 0x6a, 0x22, 0x44, 0xba, 0x44, 0x9a, 0xc4, + }; + const point = [_]u8{ + 0xe6, 0xdb, 0x68, 0x67, 0x58, 0x30, 0x30, 0xdb, + 0x35, 0x94, 0xc1, 0xa4, 0x24, 0xb1, 0x5f, 0x7c, + 0x72, 0x66, 0x24, 0xec, 0x26, 0xb3, 0x35, 0x3b, + 0x10, 0xa9, 0x03, 0xa6, 0xd0, 0xab, 0x1c, 0x4c, + }; + const expected = [_]u8{ + 0xc3, 0xda, 0x55, 0x37, 0x9d, 0xe9, 0xc6, 0x90, + 0x8e, 0x94, 0xea, 0x4d, 0xf2, 0x8d, 0x08, 0x4f, + 0x32, 0xec, 0xcf, 0x03, 0x49, 0x1c, 0x71, 0xf7, + 0x54, 0xb4, 0x07, 0x55, 0x77, 0xa2, 0x85, 0x52, + }; + + const result = std.crypto.dh.X25519.scalarmult(scalar, point) catch unreachable; + try std.testing.expectEqualSlices(u8, &expected, &result); +} + +test "hmac sha256" { + const key = "key"; + const data = "The quick brown fox jumps over the lazy dog"; + const result = hmacSha256(key, data); + + // Valor conocido + const expected = [_]u8{ + 0xf7, 0xbc, 0x83, 0xf4, 0x30, 0x53, 0x84, 0x24, + 0xb1, 0x32, 0x98, 0xe6, 0xaa, 0x6f, 0xb1, 0x43, + 0xef, 0x4d, 0x59, 0xa1, 0x49, 0x46, 0x17, 0x59, + 0x97, 0x47, 0x9d, 0xbc, 0x2d, 0x1a, 0x3c, 0xd8, + }; + try std.testing.expectEqualSlices(u8, &expected, &result); +} + +test "hkdf extract and expand" { + const ikm = [_]u8{0x0b} ** 22; + const salt = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c }; + const info = [_]u8{ 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9 }; + + const prk = hkdfExtract(&salt, &ikm); + + var okm: [42]u8 = undefined; + hkdfExpand(prk, &info, 42, &okm); + + // Verificar que el output no es todo ceros + var all_zero = true; + for (okm) |b| { + if (b != 0) { + all_zero = false; + break; + } + } + try std.testing.expect(!all_zero); +} + +test "tls connection init" { + const allocator = std.testing.allocator; + var conn = TlsConnection.init(allocator); + defer conn.deinit(); + + try std.testing.expect(conn.state == .start); + + var hello_buf: [512]u8 = undefined; + const hello_len = try conn.generateClientHello(&hello_buf); + + try std.testing.expect(hello_len > 0); + try std.testing.expect(conn.state == .wait_server_hello); +}