//! User-defined functions and hooks for SQLite
//!
//! Provides support for scalar functions, aggregate functions, window functions,
//! collations, and database hooks (commit, rollback, update, pre-update).

const std = @import("std");
const c = @import("c.zig").c;
const types = @import("types.zig");

const ColumnType = types.ColumnType;
const UpdateOperation = types.UpdateOperation;
const AuthAction = types.AuthAction;
const AuthResult = types.AuthResult;

// ============================================================================
// Function Context and Values
// ============================================================================

/// Context for user-defined function results.
///
/// Thin handle around the `sqlite3_context` given to a scalar callback. Each
/// setter records the value (or error) that the SQL function call evaluates to.
pub const FunctionContext = struct {
    ctx: *c.sqlite3_context,

    const Self = @This();

    /// Return SQL NULL.
    pub fn setNull(self: Self) void {
        c.sqlite3_result_null(self.ctx);
    }

    /// Return a 64-bit INTEGER.
    pub fn setInt(self: Self, value: i64) void {
        c.sqlite3_result_int64(self.ctx, value);
    }

    /// Return a REAL.
    pub fn setFloat(self: Self, value: f64) void {
        c.sqlite3_result_double(self.ctx, value);
    }

    /// Return `value` as TEXT.
    ///
    /// SQLite copies the bytes (`SQLITE_TRANSIENT`), so `value` only needs to
    /// outlive this call. A slice whose length does not fit in a `c_int` is
    /// reported as a "string or blob too big" error instead of overflowing the
    /// length cast (which would panic or be UB inside a C callback).
    pub fn setText(self: Self, value: []const u8) void {
        const len = std.math.cast(c_int, value.len) orelse {
            c.sqlite3_result_error_toobig(self.ctx);
            return;
        };
        c.sqlite3_result_text(self.ctx, value.ptr, len, c.SQLITE_TRANSIENT);
    }

    /// Return `value` as a BLOB.
    ///
    /// SQLite copies the bytes (`SQLITE_TRANSIENT`). Oversized slices are
    /// reported as a "too big" error, exactly like `setText`.
    pub fn setBlob(self: Self, value: []const u8) void {
        const len = std.math.cast(c_int, value.len) orelse {
            c.sqlite3_result_error_toobig(self.ctx);
            return;
        };
        c.sqlite3_result_blob(self.ctx, value.ptr, len, c.SQLITE_TRANSIENT);
    }

    /// Make the SQL function raise an error with message `msg`.
    ///
    /// SQLite keeps its own copy of the message. A message longer than
    /// `maxInt(c_int)` bytes is truncated rather than overflowing the cast.
    pub fn setError(self: Self, msg: []const u8) void {
        const len: c_int = @intCast(@min(msg.len, std.math.maxInt(c_int)));
        c.sqlite3_result_error(self.ctx, msg.ptr, len);
    }
};
/// Value passed to user-defined functions.
///
/// Borrowed handle to one `sqlite3_value` argument. Slices returned by
/// `asText`/`asBlob` point into SQLite-owned memory; copy them if they must
/// outlive the callback.
pub const FunctionValue = struct {
    value: *c.sqlite3_value,

    const Self = @This();

    /// Storage class of the value; any unrecognised code maps to `.null_value`.
    pub fn getType(self: Self) ColumnType {
        return switch (c.sqlite3_value_type(self.value)) {
            c.SQLITE_INTEGER => .integer,
            c.SQLITE_FLOAT => .float,
            c.SQLITE_TEXT => .text,
            c.SQLITE_BLOB => .blob,
            c.SQLITE_NULL => .null_value,
            else => .null_value,
        };
    }

    /// True when the value is SQL NULL.
    pub fn isNull(self: Self) bool {
        return self.getType() == .null_value;
    }

    /// Value converted to a 64-bit integer by `sqlite3_value_int64`.
    pub fn asInt(self: Self) i64 {
        return c.sqlite3_value_int64(self.value);
    }

    /// Value converted to a double by `sqlite3_value_double`.
    pub fn asFloat(self: Self) f64 {
        return c.sqlite3_value_double(self.value);
    }

    /// Value as UTF-8 text, or null when SQLite returns a NULL pointer
    /// (SQL NULL, or a failed conversion).
    ///
    /// The pointer is fetched before the length, which is the order SQLite
    /// documents as safe: fetching the text may convert the value and change
    /// its byte count.
    pub fn asText(self: Self) ?[]const u8 {
        const text: ?[*]const u8 = c.sqlite3_value_text(self.value);
        const ptr = text orelse return null;
        const len: usize = @intCast(c.sqlite3_value_bytes(self.value));
        return ptr[0..len];
    }

    /// Value as raw bytes, or null for SQL NULL (or a failed conversion).
    ///
    /// `sqlite3_value_blob` returns a NULL pointer for a zero-length BLOB/TEXT.
    /// That case is reported here as an empty slice, so an empty BLOB is no
    /// longer confused with SQL NULL.
    pub fn asBlob(self: Self) ?[]const u8 {
        // Capture the storage class before sqlite3_value_blob() may convert it.
        const original_type = self.getType();
        if (original_type == .null_value) return null;

        // Pointer first, then length (SQLite's documented safe order).
        const blob = c.sqlite3_value_blob(self.value);
        const len: usize = @intCast(c.sqlite3_value_bytes(self.value));
        if (blob) |b| {
            const bytes: [*]const u8 = @ptrCast(b);
            return bytes[0..len];
        }
        return switch (original_type) {
            .blob, .text => if (len == 0) "" else null,
            else => null,
        };
    }
};
/// Context for aggregate functions with state management.
///
/// Handle around the `sqlite3_context` passed to aggregate and window
/// callbacks. It exposes per-group state plus the same result setters as
/// `FunctionContext`.
pub const AggregateContext = struct {
    ctx: *c.sqlite3_context,

    const Self = @This();

    /// Per-group state of type `T`, or null if SQLite could not allocate it.
    ///
    /// Per SQLite's documentation, the first call for a group allocates and
    /// zero-fills `@sizeOf(T)` bytes, and later calls return the same memory.
    /// NOTE(review): SQLite only guarantees 8-byte alignment for this memory,
    /// so a `T` with stricter alignment would trip the `@alignCast` check.
    pub fn getAggregateContext(self: Self, comptime T: type) ?*T {
        const raw = c.sqlite3_aggregate_context(self.ctx, @sizeOf(T)) orelse return null;
        return @ptrCast(@alignCast(raw));
    }

    /// View as a `FunctionContext` so both context types share one
    /// implementation of the result setters.
    fn results(self: Self) FunctionContext {
        return .{ .ctx = self.ctx };
    }

    pub fn setNull(self: Self) void {
        self.results().setNull();
    }

    pub fn setInt(self: Self, value: i64) void {
        self.results().setInt(value);
    }

    pub fn setFloat(self: Self, value: f64) void {
        self.results().setFloat(value);
    }

    pub fn setText(self: Self, value: []const u8) void {
        self.results().setText(value);
    }

    pub fn setBlob(self: Self, value: []const u8) void {
        self.results().setBlob(value);
    }

    pub fn setError(self: Self, msg: []const u8) void {
        self.results().setError(msg);
    }
};

/// Context for pre-update hook with access to old/new values.
pub const PreUpdateContext = struct {
    db: *c.sqlite3,

    const Self = @This();

    /// Column count reported by `sqlite3_preupdate_count`.
    pub fn columnCount(self: Self) i32 {
        return c.sqlite3_preupdate_count(self.db);
    }

    /// Trigger nesting depth reported by `sqlite3_preupdate_depth`.
    pub fn depth(self: Self) i32 {
        return c.sqlite3_preupdate_depth(self.db);
    }

    /// Column `col` of the row before the change. Returns null when SQLite
    /// reports an error (e.g. the operation has no old row), when no value is
    /// produced, or when `col` does not fit in a `c_int`.
    pub fn oldValue(self: Self, col: u32) ?FunctionValue {
        const index = std.math.cast(c_int, col) orelse return null;
        var value: ?*c.sqlite3_value = null;
        if (c.sqlite3_preupdate_old(self.db, index, &value) != c.SQLITE_OK) return null;
        const v = value orelse return null;
        return FunctionValue{ .value = v };
    }

    /// Column `col` of the row after the change. The null cases are the same
    /// as for `oldValue`.
    pub fn newValue(self: Self, col: u32) ?FunctionValue {
        const index = std.math.cast(c_int, col) orelse return null;
        var value: ?*c.sqlite3_value = null;
        if (c.sqlite3_preupdate_new(self.db, index, &value) != c.SQLITE_OK) return null;
        const v = value orelse return null;
        return FunctionValue{ .value = v };
    }
};

// ============================================================================
// Function Types
// ============================================================================

pub const ScalarFn = *const fn (ctx: FunctionContext, args: []const FunctionValue) void;
pub const AggregateStepFn = *const fn (ctx: AggregateContext, args: []const FunctionValue) void;
pub const AggregateFinalFn = *const fn (ctx: AggregateContext) void;
pub const WindowValueFn = *const fn (ctx: AggregateContext) void;
pub const WindowInverseFn = *const fn (ctx: AggregateContext, args: []const FunctionValue) void;
pub const CollationFn = *const fn (a: []const u8, b: []const u8) i32;

// Hook function types
pub const ZigCommitHookFn = *const fn () bool;
pub const ZigRollbackHookFn = *const fn () void;
pub const ZigUpdateHookFn = *const fn (operation: UpdateOperation, db_name: []const u8, table_name: []const u8, rowid: i64) void;
pub const ZigPreUpdateHookFn = *const fn (ctx: PreUpdateContext, operation: UpdateOperation, db_name: []const u8, table_name: []const u8, old_rowid: i64, new_rowid: i64) void;
pub const ZigAuthorizerFn = *const fn (action: AuthAction, arg1: ?[]const u8, arg2: ?[]const u8, arg3: ?[]const u8, arg4: ?[]const u8) AuthResult;
pub const ZigProgressFn = *const fn () bool;
pub const ZigBusyHandlerFn = *const fn (count: i32) bool;

// ============================================================================
// Wrappers (stored in SQLite user_data)
// ============================================================================

/// Allocator behind every wrapper below. Each `create` and its matching
/// `destroy` must use the same allocator, so the choice lives in one place.
const wrapper_allocator = std.heap.page_allocator;

pub const ScalarFnWrapper = struct {
    func: ScalarFn,

    pub fn create(func: ScalarFn) !*ScalarFnWrapper {
        const wrapper = try wrapper_allocator.create(ScalarFnWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *ScalarFnWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const AggregateFnWrapper = struct {
    step_fn: AggregateStepFn,
    final_fn: AggregateFinalFn,

    pub fn create(step_fn: AggregateStepFn, final_fn: AggregateFinalFn) !*AggregateFnWrapper {
        const wrapper = try wrapper_allocator.create(AggregateFnWrapper);
        wrapper.* = .{ .step_fn = step_fn, .final_fn = final_fn };
        return wrapper;
    }

    pub fn destroy(self: *AggregateFnWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const WindowFnWrapper = struct {
    step_fn: AggregateStepFn,
    final_fn: AggregateFinalFn,
    value_fn: WindowValueFn,
    inverse_fn: WindowInverseFn,

    pub fn create(step_fn: AggregateStepFn, final_fn: AggregateFinalFn, value_fn: WindowValueFn, inverse_fn: WindowInverseFn) !*WindowFnWrapper {
        const wrapper = try wrapper_allocator.create(WindowFnWrapper);
        wrapper.* = .{
            .step_fn = step_fn,
            .final_fn = final_fn,
            .value_fn = value_fn,
            .inverse_fn = inverse_fn,
        };
        return wrapper;
    }

    pub fn destroy(self: *WindowFnWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const CollationWrapper = struct {
    func: CollationFn,

    pub fn create(func: CollationFn) !*CollationWrapper {
        const wrapper = try wrapper_allocator.create(CollationWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *CollationWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const CommitHookWrapper = struct {
    func: ZigCommitHookFn,

    pub fn create(func: ZigCommitHookFn) !*CommitHookWrapper {
        const wrapper = try wrapper_allocator.create(CommitHookWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *CommitHookWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const RollbackHookWrapper = struct {
    func: ZigRollbackHookFn,

    pub fn create(func: ZigRollbackHookFn) !*RollbackHookWrapper {
        const wrapper = try wrapper_allocator.create(RollbackHookWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *RollbackHookWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const UpdateHookWrapper = struct {
    func: ZigUpdateHookFn,

    pub fn create(func: ZigUpdateHookFn) !*UpdateHookWrapper {
        const wrapper = try wrapper_allocator.create(UpdateHookWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *UpdateHookWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const PreUpdateHookWrapper = struct {
    func: ZigPreUpdateHookFn,

    pub fn create(func: ZigPreUpdateHookFn) !*PreUpdateHookWrapper {
        const wrapper = try wrapper_allocator.create(PreUpdateHookWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *PreUpdateHookWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const AuthorizerWrapper = struct {
    func: ZigAuthorizerFn,

    pub fn create(func: ZigAuthorizerFn) !*AuthorizerWrapper {
        const wrapper = try wrapper_allocator.create(AuthorizerWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *AuthorizerWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const ProgressWrapper = struct {
    func: ZigProgressFn,

    pub fn create(func: ZigProgressFn) !*ProgressWrapper {
        const wrapper = try wrapper_allocator.create(ProgressWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *ProgressWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

pub const BusyHandlerWrapper = struct {
    func: ZigBusyHandlerFn,

    pub fn create(func: ZigBusyHandlerFn) !*BusyHandlerWrapper {
        const wrapper = try wrapper_allocator.create(BusyHandlerWrapper);
        wrapper.* = .{ .func = func };
        return wrapper;
    }

    pub fn destroy(self: *BusyHandlerWrapper) void {
        wrapper_allocator.destroy(self);
    }
};

// ============================================================================
// C Callback Trampolines
// ============================================================================

/// Number of arguments staged on the stack before spilling to the heap.
const inline_arg_capacity = 16;

/// Stack bytes reserved for staged arguments (plus slack for alignment).
const arg_stack_bytes = inline_arg_capacity * @sizeOf(FunctionValue) + @alignOf(FunctionValue);

/// Recover a typed wrapper from SQLite's `void*` user-data slot (null stays null).
fn fromUserData(comptime T: type, user_data: ?*anyopaque) ?*T {
    const ptr = user_data orelse return null;
    return @ptrCast(@alignCast(ptr));
}

/// Copy SQLite's `argv` into a `FunctionValue` slice allocated from `allocator`.
///
/// Every argument is kept; nothing is truncated. On allocation failure or a
/// NULL `argv` entry, an error is set on `ctx` and null is returned, so the
/// caller should just return. On success the caller frees the slice with
/// `allocator`.
fn collectArgs(allocator: std.mem.Allocator, ctx: *c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) ?[]FunctionValue {
    const count = std.math.cast(usize, argc) orelse 0;
    const args = allocator.alloc(FunctionValue, count) catch {
        c.sqlite3_result_error_nomem(ctx);
        return null;
    };
    for (args, 0..) |*slot, i| {
        const raw = argv[i] orelse {
            allocator.free(args);
            c.sqlite3_result_error(ctx, "NULL argument pointer passed to user-defined function", -1);
            return null;
        };
        slot.* = .{ .value = raw };
    }
    return args;
}

/// Call `func` with a `Context` for `ctx` and all `argc` arguments.
///
/// Up to `inline_arg_capacity` arguments are staged in a stack buffer; larger
/// calls fall back to the page allocator instead of dropping arguments.
fn invokeWithArgs(
    comptime Context: type,
    func: *const fn (Context, []const FunctionValue) void,
    ctx: *c.sqlite3_context,
    argc: c_int,
    argv: [*c]?*c.sqlite3_value,
) void {
    var arg_alloc = std.heap.stackFallback(arg_stack_bytes, std.heap.page_allocator);
    const allocator = arg_alloc.get();
    const args = collectArgs(allocator, ctx, argc, argv) orelse return;
    defer allocator.free(args);
    func(.{ .ctx = ctx }, args);
}

/// Turn a C (pointer, length) pair from SQLite into a slice; a NULL pointer becomes "".
fn bytesFrom(data: ?*const anyopaque, len: c_int) []const u8 {
    const ptr = data orelse return "";
    const bytes: [*]const u8 = @ptrCast(ptr);
    return bytes[0 .. std.math.cast(usize, len) orelse 0];
}

/// Convert a possibly-NULL, NUL-terminated C string into an optional slice.
fn optionalSpan(str: [*c]const u8) ?[]const u8 {
    return if (str != null) std.mem.span(str) else null;
}

pub fn scalarCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const context = ctx orelse return;
    const wrapper = fromUserData(ScalarFnWrapper, c.sqlite3_user_data(context)) orelse return;
    invokeWithArgs(FunctionContext, wrapper.func, context, argc, argv);
}

pub fn scalarDestructor(ptr: ?*anyopaque) callconv(.c) void {
    if (fromUserData(ScalarFnWrapper, ptr)) |wrapper| wrapper.destroy();
}

pub fn aggregateStepCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const context = ctx orelse return;
    const wrapper = fromUserData(AggregateFnWrapper, c.sqlite3_user_data(context)) orelse return;
    invokeWithArgs(AggregateContext, wrapper.step_fn, context, argc, argv);
}

pub fn aggregateFinalCallback(ctx: ?*c.sqlite3_context) callconv(.c) void {
    const context = ctx orelse return;
    const wrapper = fromUserData(AggregateFnWrapper, c.sqlite3_user_data(context)) orelse return;
    wrapper.final_fn(.{ .ctx = context });
}

pub fn aggregateDestructor(ptr: ?*anyopaque) callconv(.c) void {
    if (fromUserData(AggregateFnWrapper, ptr)) |wrapper| wrapper.destroy();
}

pub fn windowStepCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const context = ctx orelse return;
    const wrapper = fromUserData(WindowFnWrapper, c.sqlite3_user_data(context)) orelse return;
    invokeWithArgs(AggregateContext, wrapper.step_fn, context, argc, argv);
}

pub fn windowFinalCallback(ctx: ?*c.sqlite3_context) callconv(.c) void {
    const context = ctx orelse return;
    const wrapper = fromUserData(WindowFnWrapper, c.sqlite3_user_data(context)) orelse return;
    wrapper.final_fn(.{ .ctx = context });
}

pub fn windowValueCallback(ctx: ?*c.sqlite3_context) callconv(.c) void {
    const context = ctx orelse return;
    const wrapper = fromUserData(WindowFnWrapper, c.sqlite3_user_data(context)) orelse return;
    wrapper.value_fn(.{ .ctx = context });
}

pub fn windowInverseCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const context = ctx orelse return;
    const wrapper = fromUserData(WindowFnWrapper, c.sqlite3_user_data(context)) orelse return;
    invokeWithArgs(AggregateContext, wrapper.inverse_fn, context, argc, argv);
}

pub fn windowDestructor(ptr: ?*anyopaque) callconv(.c) void {
    if (fromUserData(WindowFnWrapper, ptr)) |wrapper| wrapper.destroy();
}

pub fn collationCallback(user_data: ?*anyopaque, len_a: c_int, data_a: ?*const anyopaque, len_b: c_int, data_b: ?*const anyopaque) callconv(.c) c_int {
    const wrapper = fromUserData(CollationWrapper, user_data) orelse return 0;
    return wrapper.func(bytesFrom(data_a, len_a), bytesFrom(data_b, len_b));
}

pub fn collationDestructor(ptr: ?*anyopaque) callconv(.c) void {
    if (fromUserData(CollationWrapper, ptr)) |wrapper| wrapper.destroy();
}

pub fn commitHookCallback(user_data: ?*anyopaque) callconv(.c) c_int {
    const wrapper = fromUserData(CommitHookWrapper, user_data) orelse return 0;
    // The Zig hook returns true to allow the commit; SQLite expects 0 for that.
    return if (wrapper.func()) 0 else 1;
}

pub fn rollbackHookCallback(user_data: ?*anyopaque) callconv(.c) void {
    const wrapper = fromUserData(RollbackHookWrapper, user_data) orelse return;
    wrapper.func();
}

pub fn updateHookCallback(user_data: ?*anyopaque, operation: c_int, db_name: [*c]const u8, table_name: [*c]const u8, rowid: c.sqlite3_int64) callconv(.c) void {
    const wrapper = fromUserData(UpdateHookWrapper, user_data) orelse return;
    const op = UpdateOperation.fromInt(operation) orelse return;
    wrapper.func(op, std.mem.span(db_name), std.mem.span(table_name), rowid);
}

pub fn preUpdateHookCallback(user_data: ?*anyopaque, db: ?*c.sqlite3, operation: c_int, db_name: [*c]const u8, table_name: [*c]const u8, old_rowid: c.sqlite3_int64, new_rowid: c.sqlite3_int64) callconv(.c) void {
    const wrapper = fromUserData(PreUpdateHookWrapper, user_data) orelse return;
    const handle = db orelse return;
    const op = UpdateOperation.fromInt(operation) orelse return;
    wrapper.func(.{ .db = handle }, op, std.mem.span(db_name), std.mem.span(table_name), old_rowid, new_rowid);
}

pub fn authorizerCallback(user_data: ?*anyopaque, action: c_int, arg1: [*c]const u8, arg2: [*c]const u8, arg3: [*c]const u8, arg4: [*c]const u8) callconv(.c) c_int {
    const wrapper = fromUserData(AuthorizerWrapper, user_data) orelse return c.SQLITE_OK;
    // NOTE(review): action codes AuthAction does not know are allowed (fail-open);
    // confirm this is intended for security-sensitive uses of the authorizer.
    const auth_action = AuthAction.fromInt(action) orelse return c.SQLITE_OK;
    const result = wrapper.func(auth_action, optionalSpan(arg1), optionalSpan(arg2), optionalSpan(arg3), optionalSpan(arg4));
    return @intFromEnum(result);
}

pub fn progressCallback(user_data: ?*anyopaque) callconv(.c) c_int {
    const wrapper = fromUserData(ProgressWrapper, user_data) orelse return 0;
    // The Zig hook returns true to keep going; any non-zero value interrupts.
    return if (wrapper.func()) 0 else 1;
}

pub fn busyHandlerCallback(user_data: ?*anyopaque, count: c_int) callconv(.c) c_int {
    const wrapper = fromUserData(BusyHandlerWrapper, user_data) orelse return 0;
    // Non-zero asks SQLite to retry; zero gives up with SQLITE_BUSY.
    return if (wrapper.func(count)) 1 else 0;
}