zcatsql/src/functions.zig
reugenio 5e28cbe4bf refactor: modularize root.zig into specialized modules
Split monolithic root.zig (4200 lines) into 9 focused modules:
- c.zig: centralized @cImport for SQLite
- errors.zig: Error enum and resultToError
- types.zig: OpenFlags, ColumnType, Limit, enums
- database.zig: Database struct with all methods
- statement.zig: Statement struct with bindings/columns
- functions.zig: UDFs, hooks, and C callbacks
- backup.zig: Backup and Blob I/O
- pool.zig: ConnectionPool (thread-safe)
- root.zig: re-exports + tests (~1100 lines)

Total: ~3600 lines (74% reduction in root.zig)
All 47 tests passing.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-08 19:54:19 +01:00

567 lines
18 KiB
Zig

//! User-defined functions and hooks for SQLite
//!
//! Provides support for scalar functions, aggregate functions, window functions,
//! collations, and database hooks (commit, rollback, update, pre-update).
const std = @import("std");
const c = @import("c.zig").c;
const types = @import("types.zig");
const ColumnType = types.ColumnType;
const UpdateOperation = types.UpdateOperation;
const AuthAction = types.AuthAction;
const AuthResult = types.AuthResult;
// ============================================================================
// Function Context and Values
// ============================================================================
/// Context for user-defined function results.
///
/// Wraps the `sqlite3_context` handed to a scalar function and exposes the
/// `sqlite3_result_*` family. Text and blob inputs are passed with
/// `SQLITE_TRANSIENT`, so SQLite copies them and the caller keeps ownership
/// of every slice it passes in.
pub const FunctionContext = struct {
    ctx: *c.sqlite3_context,
    const Self = @This();

    /// Sets the result to SQL NULL.
    pub fn setNull(self: Self) void {
        c.sqlite3_result_null(self.ctx);
    }

    /// Sets the result to a 64-bit INTEGER.
    pub fn setInt(self: Self, value: i64) void {
        c.sqlite3_result_int64(self.ctx, value);
    }

    /// Sets the result to a REAL.
    pub fn setFloat(self: Self, value: f64) void {
        c.sqlite3_result_double(self.ctx, value);
    }

    /// Sets the result to UTF-8 TEXT (copied).
    ///
    /// Uses the 64-bit length API: a slice longer than SQLite accepts makes
    /// SQLite raise SQLITE_TOOBIG instead of tripping an `@intCast` to c_int.
    pub fn setText(self: Self, value: []const u8) void {
        c.sqlite3_result_text64(self.ctx, value.ptr, value.len, c.SQLITE_TRANSIENT, c.SQLITE_UTF8);
    }

    /// Sets the result to a BLOB (copied). Oversized input is reported by
    /// SQLite as SQLITE_TOOBIG rather than trapping.
    pub fn setBlob(self: Self, value: []const u8) void {
        c.sqlite3_result_blob64(self.ctx, value.ptr, value.len, c.SQLITE_TRANSIENT);
    }

    /// Makes the function raise an SQL error with `msg` as its message.
    /// A message longer than `maxInt(c_int)` bytes is truncated to that length
    /// instead of trapping in the length conversion.
    pub fn setError(self: Self, msg: []const u8) void {
        const len = std.math.cast(c_int, msg.len) orelse std.math.maxInt(c_int);
        c.sqlite3_result_error(self.ctx, msg.ptr, len);
    }
};
/// Value passed to user-defined functions.
///
/// Thin view over an `sqlite3_value`. Slices returned by `asText` and
/// `asBlob` point into SQLite-owned memory. Per the SQLite docs they stay
/// valid only until the value is converted again or the callback returns,
/// so copy them if they must outlive that.
pub const FunctionValue = struct {
    value: *c.sqlite3_value,
    const Self = @This();

    /// Storage class of the value; unrecognised codes map to `.null_value`.
    pub fn getType(self: Self) ColumnType {
        const vtype = c.sqlite3_value_type(self.value);
        return switch (vtype) {
            c.SQLITE_INTEGER => .integer,
            c.SQLITE_FLOAT => .float,
            c.SQLITE_TEXT => .text,
            c.SQLITE_BLOB => .blob,
            c.SQLITE_NULL => .null_value,
            else => .null_value,
        };
    }

    /// True when the value is SQL NULL.
    pub fn isNull(self: Self) bool {
        return self.getType() == .null_value;
    }

    /// Value converted to a 64-bit integer by SQLite.
    pub fn asInt(self: Self) i64 {
        return c.sqlite3_value_int64(self.value);
    }

    /// Value converted to a double by SQLite.
    pub fn asFloat(self: Self) f64 {
        return c.sqlite3_value_double(self.value);
    }

    /// Value as UTF-8 text, or null for SQL NULL (or if conversion failed).
    /// A zero-length TEXT/BLOB yields an empty slice, never null.
    pub fn asText(self: Self) ?[]const u8 {
        // Read the storage class before any conversion; SQLite documents the
        // type as unreliable once an automatic conversion has happened.
        const kind = self.getType();
        // Pointer first, size second: the access order SQLite documents as safe.
        const text = c.sqlite3_value_text(self.value);
        const len: usize = @intCast(c.sqlite3_value_bytes(self.value));
        if (text) |t| return t[0..len];
        return emptyOrNull(kind, len);
    }

    /// Value as raw bytes, or null for SQL NULL (or if conversion failed).
    /// SQLite returns a NULL pointer for a zero-length BLOB. That case is
    /// mapped to an empty slice so an empty blob is distinguishable from NULL.
    pub fn asBlob(self: Self) ?[]const u8 {
        const kind = self.getType();
        const blob = c.sqlite3_value_blob(self.value);
        const len: usize = @intCast(c.sqlite3_value_bytes(self.value));
        if (blob) |b| {
            const ptr: [*]const u8 = @ptrCast(b);
            return ptr[0..len];
        }
        return emptyOrNull(kind, len);
    }

    /// Decides what a NULL data pointer means: empty TEXT/BLOB becomes "",
    /// anything else (SQL NULL, failed conversion) stays null.
    fn emptyOrNull(kind: ColumnType, len: usize) ?[]const u8 {
        if (len == 0 and (kind == .text or kind == .blob)) return "";
        return null;
    }
};
/// Context for aggregate functions with state management.
///
/// Passed to aggregate and window callbacks. Besides the result setters it
/// gives access to per-invocation state through `getAggregateContext`.
/// Text and blob inputs are passed with `SQLITE_TRANSIENT`, so SQLite copies
/// them.
pub const AggregateContext = struct {
    ctx: *c.sqlite3_context,
    const Self = @This();

    /// Returns this aggregate invocation's state block of type `T`, or null if
    /// SQLite could not provide it.
    ///
    /// Per the SQLite docs the block is zero-filled on first use and released
    /// by SQLite when the aggregate finishes. SQLite's allocator only promises
    /// 8-byte alignment. Over-aligned state types are therefore rejected at
    /// compile time instead of failing the `@alignCast` at run time.
    pub fn getAggregateContext(self: Self, comptime T: type) ?*T {
        if (@alignOf(T) > 8) {
            @compileError("aggregate state type '" ++ @typeName(T) ++ "' needs alignment > 8, which SQLite's allocator does not guarantee");
        }
        const ptr = c.sqlite3_aggregate_context(self.ctx, @sizeOf(T)) orelse return null;
        return @ptrCast(@alignCast(ptr));
    }

    /// Sets the result to SQL NULL.
    pub fn setNull(self: Self) void {
        c.sqlite3_result_null(self.ctx);
    }

    /// Sets the result to a 64-bit INTEGER.
    pub fn setInt(self: Self, value: i64) void {
        c.sqlite3_result_int64(self.ctx, value);
    }

    /// Sets the result to a REAL.
    pub fn setFloat(self: Self, value: f64) void {
        c.sqlite3_result_double(self.ctx, value);
    }

    /// Sets the result to UTF-8 TEXT (copied). Oversized input is reported by
    /// SQLite as SQLITE_TOOBIG instead of trapping in a c_int conversion.
    pub fn setText(self: Self, value: []const u8) void {
        c.sqlite3_result_text64(self.ctx, value.ptr, value.len, c.SQLITE_TRANSIENT, c.SQLITE_UTF8);
    }

    /// Sets the result to a BLOB (copied). Oversized input is reported by
    /// SQLite as SQLITE_TOOBIG instead of trapping in a c_int conversion.
    pub fn setBlob(self: Self, value: []const u8) void {
        c.sqlite3_result_blob64(self.ctx, value.ptr, value.len, c.SQLITE_TRANSIENT);
    }

    /// Makes the function raise an SQL error with `msg` as its message.
    /// A message longer than `maxInt(c_int)` bytes is truncated to that length.
    pub fn setError(self: Self, msg: []const u8) void {
        const len = std.math.cast(c_int, msg.len) orelse std.math.maxInt(c_int);
        c.sqlite3_result_error(self.ctx, msg.ptr, len);
    }
};
/// Context for pre-update hook with access to old/new values.
///
/// Wraps the connection handle passed to the pre-update hook. Per the SQLite
/// docs the `sqlite3_preupdate_*` accessors may only be used from inside that
/// hook, so do not keep this context or its values after the hook returns.
pub const PreUpdateContext = struct {
    db: *c.sqlite3,
    const Self = @This();

    /// Number of columns in the row being changed (`sqlite3_preupdate_count`).
    pub fn columnCount(self: Self) i32 {
        return c.sqlite3_preupdate_count(self.db);
    }

    /// Trigger nesting depth of the change (`sqlite3_preupdate_depth`).
    pub fn depth(self: Self) i32 {
        return c.sqlite3_preupdate_depth(self.db);
    }

    /// Column `col` as it was before the change.
    /// Returns null if SQLite does not return SQLITE_OK (e.g. SQLITE_RANGE for
    /// a bad index) or supplies no value.
    pub fn oldValue(self: Self, col: u32) ?FunctionValue {
        // An index that does not fit in c_int is out of range anyway; report
        // it as null instead of trapping in an @intCast.
        const index = std.math.cast(c_int, col) orelse return null;
        var value: ?*c.sqlite3_value = null;
        if (c.sqlite3_preupdate_old(self.db, index, &value) != c.SQLITE_OK) return null;
        const v = value orelse return null;
        return FunctionValue{ .value = v };
    }

    /// Column `col` as it will be after the change.
    /// Returns null under the same conditions as `oldValue`.
    pub fn newValue(self: Self, col: u32) ?FunctionValue {
        const index = std.math.cast(c_int, col) orelse return null;
        var value: ?*c.sqlite3_value = null;
        if (c.sqlite3_preupdate_new(self.db, index, &value) != c.SQLITE_OK) return null;
        const v = value orelse return null;
        return FunctionValue{ .value = v };
    }
};
// ============================================================================
// Function Types
// ============================================================================
/// Zig implementation of a scalar SQL function: receives the call's argument
/// values and reports its result through `ctx`.
pub const ScalarFn = *const fn (ctx: FunctionContext, args: []const FunctionValue) void;
/// Step callback of an aggregate (or window) function; receives one row's arguments.
pub const AggregateStepFn = *const fn (ctx: AggregateContext, args: []const FunctionValue) void;
/// Final callback of an aggregate (or window) function; reports the result through `ctx`.
pub const AggregateFinalFn = *const fn (ctx: AggregateContext) void;
/// Window-function value callback; reports the aggregate's current value through `ctx`.
pub const WindowValueFn = *const fn (ctx: AggregateContext) void;
/// Window-function inverse callback; receives the arguments of a row leaving the window.
pub const WindowInverseFn = *const fn (ctx: AggregateContext, args: []const FunctionValue) void;
/// Collation comparator; the result is handed straight back to SQLite, which
/// expects negative / zero / positive for a < b / a == b / a > b.
pub const CollationFn = *const fn (a: []const u8, b: []const u8) i32;
// Hook function types
/// Commit hook: return `true` to let the commit proceed; `false` makes the
/// trampoline return non-zero, which SQLite turns into a rollback.
pub const ZigCommitHookFn = *const fn () bool;
/// Rollback hook: invoked with no arguments.
pub const ZigRollbackHookFn = *const fn () void;
/// Update hook. `db_name` and `table_name` are borrowed (not copied) from
/// SQLite; copy them if they must outlive the call.
pub const ZigUpdateHookFn = *const fn (operation: UpdateOperation, db_name: []const u8, table_name: []const u8, rowid: i64) void;
/// Pre-update hook; `ctx` exposes the old/new column values. Name slices are
/// borrowed from SQLite, as for the update hook.
pub const ZigPreUpdateHookFn = *const fn (ctx: PreUpdateContext, operation: UpdateOperation, db_name: []const u8, table_name: []const u8, old_rowid: i64, new_rowid: i64) void;
/// Authorizer; each `argN` is null when SQLite passes a NULL string. The
/// returned `AuthResult` is converted to SQLite's integer code via `@intFromEnum`.
pub const ZigAuthorizerFn = *const fn (action: AuthAction, arg1: ?[]const u8, arg2: ?[]const u8, arg3: ?[]const u8, arg4: ?[]const u8) AuthResult;
/// Progress handler: return `true` to continue, `false` to interrupt the statement.
pub const ZigProgressFn = *const fn () bool;
/// Busy handler: `count` is forwarded from SQLite; return `true` to retry,
/// `false` to give up.
pub const ZigBusyHandlerFn = *const fn (count: i32) bool;
// ============================================================================
// Wrappers (stored in SQLite user_data)
// ============================================================================
/// Heap box that carries a `ScalarFn` through SQLite's opaque user-data pointer.
pub const ScalarFnWrapper = struct {
    func: ScalarFn,

    const Self = @This();
    /// NOTE(review): page_allocator works in whole pages, so each wrapper
    /// costs at least one page for a single function pointer.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ScalarFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying an aggregate's step and final callbacks as SQLite user data.
pub const AggregateFnWrapper = struct {
    step_fn: AggregateStepFn,
    final_fn: AggregateFinalFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding both callbacks; release it with `destroy`.
    pub fn create(step_fn: AggregateStepFn, final_fn: AggregateFinalFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .step_fn = step_fn, .final_fn = final_fn };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying the four window-function callbacks as SQLite user data.
pub const WindowFnWrapper = struct {
    step_fn: AggregateStepFn,
    final_fn: AggregateFinalFn,
    value_fn: WindowValueFn,
    inverse_fn: WindowInverseFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding all four callbacks; release it with `destroy`.
    pub fn create(step_fn: AggregateStepFn, final_fn: AggregateFinalFn, value_fn: WindowValueFn, inverse_fn: WindowInverseFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{
            .step_fn = step_fn,
            .final_fn = final_fn,
            .value_fn = value_fn,
            .inverse_fn = inverse_fn,
        };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying a `CollationFn` as SQLite user data.
pub const CollationWrapper = struct {
    func: CollationFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: CollationFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying a commit hook as SQLite user data.
pub const CommitHookWrapper = struct {
    func: ZigCommitHookFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ZigCommitHookFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying a rollback hook as SQLite user data.
pub const RollbackHookWrapper = struct {
    func: ZigRollbackHookFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ZigRollbackHookFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying an update hook as SQLite user data.
pub const UpdateHookWrapper = struct {
    func: ZigUpdateHookFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ZigUpdateHookFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying a pre-update hook as SQLite user data.
pub const PreUpdateHookWrapper = struct {
    func: ZigPreUpdateHookFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ZigPreUpdateHookFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying an authorizer callback as SQLite user data.
pub const AuthorizerWrapper = struct {
    func: ZigAuthorizerFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ZigAuthorizerFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying a progress handler as SQLite user data.
pub const ProgressWrapper = struct {
    func: ZigProgressFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ZigProgressFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
/// Heap box carrying a busy handler as SQLite user data.
pub const BusyHandlerWrapper = struct {
    func: ZigBusyHandlerFn,

    const Self = @This();
    /// NOTE(review): page_allocator rounds every allocation up to a page.
    const backing_allocator = std.heap.page_allocator;

    /// Allocates a wrapper holding `func`; release it with `destroy`.
    pub fn create(func: ZigBusyHandlerFn) !*Self {
        const box = try backing_allocator.create(Self);
        box.* = .{ .func = func };
        return box;
    }

    /// Frees a wrapper obtained from `create`.
    pub fn destroy(self: *Self) void {
        backing_allocator.destroy(self);
    }
};
// ============================================================================
// C Callback Trampolines
// ============================================================================
/// C trampoline for scalar functions.
///
/// Recovers the `ScalarFnWrapper` from the function's user data, wraps every
/// entry of `argv` as a `FunctionValue` and forwards them to the Zig callback.
/// Up to 16 arguments are staged on the stack. Larger calls use a temporary
/// `sqlite3_malloc64` buffer, so no argument is silently dropped. Without
/// user data the call does nothing.
pub fn scalarCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const user_data = c.sqlite3_user_data(ctx) orelse return;
    const wrapper: *ScalarFnWrapper = @ptrCast(@alignCast(user_data));
    const sqlite_ctx = ctx.?;
    const arg_count: usize = @intCast(argc);

    var stack_args: [16]FunctionValue = undefined;
    var heap_args: ?*anyopaque = null;
    // Evaluated at scope exit with the final value; sqlite3_free(NULL) is a no-op.
    defer c.sqlite3_free(heap_args);

    const args: []FunctionValue = if (arg_count <= stack_args.len)
        stack_args[0..arg_count]
    else blk: {
        heap_args = c.sqlite3_malloc64(arg_count * @sizeOf(FunctionValue));
        const raw = heap_args orelse {
            c.sqlite3_result_error_nomem(sqlite_ctx);
            return;
        };
        const many: [*]FunctionValue = @ptrCast(@alignCast(raw));
        break :blk many[0..arg_count];
    };

    // Never hand an uninitialised slot to user code: reject NULL entries explicitly.
    for (args, 0..) |*slot, i| {
        const value = argv[i] orelse {
            c.sqlite3_result_error(sqlite_ctx, "invalid NULL argument pointer", -1);
            return;
        };
        slot.* = .{ .value = value };
    }
    wrapper.func(FunctionContext{ .ctx = sqlite_ctx }, args);
}
/// C-callable destructor for a `ScalarFnWrapper` passed as an opaque pointer.
/// A null pointer is ignored.
pub fn scalarDestructor(ptr: ?*anyopaque) callconv(.c) void {
    const raw = ptr orelse return;
    const wrapper: *ScalarFnWrapper = @ptrCast(@alignCast(raw));
    wrapper.destroy();
}
/// C trampoline for an aggregate's step callback.
///
/// Recovers the `AggregateFnWrapper`, wraps every `argv` entry as a
/// `FunctionValue` and calls `step_fn`. Up to 16 arguments are staged on the
/// stack. Larger calls use a temporary `sqlite3_malloc64` buffer instead of
/// dropping arguments. Without user data the call does nothing.
pub fn aggregateStepCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const user_data = c.sqlite3_user_data(ctx) orelse return;
    const wrapper: *AggregateFnWrapper = @ptrCast(@alignCast(user_data));
    const sqlite_ctx = ctx.?;
    const arg_count: usize = @intCast(argc);

    var stack_args: [16]FunctionValue = undefined;
    var heap_args: ?*anyopaque = null;
    // Evaluated at scope exit with the final value; sqlite3_free(NULL) is a no-op.
    defer c.sqlite3_free(heap_args);

    const args: []FunctionValue = if (arg_count <= stack_args.len)
        stack_args[0..arg_count]
    else blk: {
        heap_args = c.sqlite3_malloc64(arg_count * @sizeOf(FunctionValue));
        const raw = heap_args orelse {
            c.sqlite3_result_error_nomem(sqlite_ctx);
            return;
        };
        const many: [*]FunctionValue = @ptrCast(@alignCast(raw));
        break :blk many[0..arg_count];
    };

    // Never hand an uninitialised slot to user code: reject NULL entries explicitly.
    for (args, 0..) |*slot, i| {
        const value = argv[i] orelse {
            c.sqlite3_result_error(sqlite_ctx, "invalid NULL argument pointer", -1);
            return;
        };
        slot.* = .{ .value = value };
    }
    wrapper.step_fn(AggregateContext{ .ctx = sqlite_ctx }, args);
}
/// C trampoline for an aggregate's final callback: forwards to `final_fn`.
/// Does nothing when the function carries no user data.
pub fn aggregateFinalCallback(ctx: ?*c.sqlite3_context) callconv(.c) void {
    const raw = c.sqlite3_user_data(ctx) orelse return;
    const wrapper: *AggregateFnWrapper = @ptrCast(@alignCast(raw));
    wrapper.final_fn(.{ .ctx = ctx.? });
}
/// C-callable destructor for an `AggregateFnWrapper` passed as an opaque
/// pointer. A null pointer is ignored.
pub fn aggregateDestructor(ptr: ?*anyopaque) callconv(.c) void {
    const raw = ptr orelse return;
    const wrapper: *AggregateFnWrapper = @ptrCast(@alignCast(raw));
    wrapper.destroy();
}
/// C trampoline for a window function's step callback.
///
/// Recovers the `WindowFnWrapper`, wraps every `argv` entry as a
/// `FunctionValue` and calls `step_fn`. Up to 16 arguments are staged on the
/// stack. Larger calls use a temporary `sqlite3_malloc64` buffer instead of
/// dropping arguments. Without user data the call does nothing.
pub fn windowStepCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const user_data = c.sqlite3_user_data(ctx) orelse return;
    const wrapper: *WindowFnWrapper = @ptrCast(@alignCast(user_data));
    const sqlite_ctx = ctx.?;
    const arg_count: usize = @intCast(argc);

    var stack_args: [16]FunctionValue = undefined;
    var heap_args: ?*anyopaque = null;
    // Evaluated at scope exit with the final value; sqlite3_free(NULL) is a no-op.
    defer c.sqlite3_free(heap_args);

    const args: []FunctionValue = if (arg_count <= stack_args.len)
        stack_args[0..arg_count]
    else blk: {
        heap_args = c.sqlite3_malloc64(arg_count * @sizeOf(FunctionValue));
        const raw = heap_args orelse {
            c.sqlite3_result_error_nomem(sqlite_ctx);
            return;
        };
        const many: [*]FunctionValue = @ptrCast(@alignCast(raw));
        break :blk many[0..arg_count];
    };

    // Never hand an uninitialised slot to user code: reject NULL entries explicitly.
    for (args, 0..) |*slot, i| {
        const value = argv[i] orelse {
            c.sqlite3_result_error(sqlite_ctx, "invalid NULL argument pointer", -1);
            return;
        };
        slot.* = .{ .value = value };
    }
    wrapper.step_fn(AggregateContext{ .ctx = sqlite_ctx }, args);
}
/// C trampoline for a window function's final callback: forwards to `final_fn`.
/// Does nothing when the function carries no user data.
pub fn windowFinalCallback(ctx: ?*c.sqlite3_context) callconv(.c) void {
    const raw = c.sqlite3_user_data(ctx) orelse return;
    const wrapper: *WindowFnWrapper = @ptrCast(@alignCast(raw));
    wrapper.final_fn(.{ .ctx = ctx.? });
}
/// C trampoline for a window function's value callback: forwards to `value_fn`.
/// Does nothing when the function carries no user data.
pub fn windowValueCallback(ctx: ?*c.sqlite3_context) callconv(.c) void {
    const raw = c.sqlite3_user_data(ctx) orelse return;
    const wrapper: *WindowFnWrapper = @ptrCast(@alignCast(raw));
    wrapper.value_fn(.{ .ctx = ctx.? });
}
/// C trampoline for a window function's inverse callback.
///
/// Recovers the `WindowFnWrapper`, wraps every `argv` entry as a
/// `FunctionValue` and calls `inverse_fn`. Up to 16 arguments are staged on
/// the stack. Larger calls use a temporary `sqlite3_malloc64` buffer instead
/// of dropping arguments. Without user data the call does nothing.
pub fn windowInverseCallback(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.c) void {
    const user_data = c.sqlite3_user_data(ctx) orelse return;
    const wrapper: *WindowFnWrapper = @ptrCast(@alignCast(user_data));
    const sqlite_ctx = ctx.?;
    const arg_count: usize = @intCast(argc);

    var stack_args: [16]FunctionValue = undefined;
    var heap_args: ?*anyopaque = null;
    // Evaluated at scope exit with the final value; sqlite3_free(NULL) is a no-op.
    defer c.sqlite3_free(heap_args);

    const args: []FunctionValue = if (arg_count <= stack_args.len)
        stack_args[0..arg_count]
    else blk: {
        heap_args = c.sqlite3_malloc64(arg_count * @sizeOf(FunctionValue));
        const raw = heap_args orelse {
            c.sqlite3_result_error_nomem(sqlite_ctx);
            return;
        };
        const many: [*]FunctionValue = @ptrCast(@alignCast(raw));
        break :blk many[0..arg_count];
    };

    // Never hand an uninitialised slot to user code: reject NULL entries explicitly.
    for (args, 0..) |*slot, i| {
        const value = argv[i] orelse {
            c.sqlite3_result_error(sqlite_ctx, "invalid NULL argument pointer", -1);
            return;
        };
        slot.* = .{ .value = value };
    }
    wrapper.inverse_fn(AggregateContext{ .ctx = sqlite_ctx }, args);
}
/// C-callable destructor for a `WindowFnWrapper` passed as an opaque pointer.
/// A null pointer is ignored.
pub fn windowDestructor(ptr: ?*anyopaque) callconv(.c) void {
    const raw = ptr orelse return;
    const wrapper: *WindowFnWrapper = @ptrCast(@alignCast(raw));
    wrapper.destroy();
}
/// C trampoline for custom collations.
///
/// Turns both (pointer, length) pairs into byte slices (a null pointer becomes
/// the empty string) and returns the Zig comparator's result. Without user
/// data the strings are reported as equal (0).
pub fn collationCallback(user_data: ?*anyopaque, len_a: c_int, data_a: ?*const anyopaque, len_b: c_int, data_b: ?*const anyopaque) callconv(.c) c_int {
    const raw = user_data orelse return 0;
    const wrapper: *CollationWrapper = @ptrCast(@alignCast(raw));
    const Bytes = struct {
        fn view(data: ?*const anyopaque, len: c_int) []const u8 {
            const p = data orelse return "";
            const bytes: [*]const u8 = @ptrCast(p);
            return bytes[0..@intCast(len)];
        }
    };
    return wrapper.func(Bytes.view(data_a, len_a), Bytes.view(data_b, len_b));
}
/// C-callable destructor for a `CollationWrapper` passed as an opaque pointer.
/// A null pointer is ignored.
pub fn collationDestructor(ptr: ?*anyopaque) callconv(.c) void {
    const raw = ptr orelse return;
    const wrapper: *CollationWrapper = @ptrCast(@alignCast(raw));
    wrapper.destroy();
}
/// C trampoline for the commit hook.
/// Returns 0 when the Zig hook allows the commit and 1 when it vetoes it.
/// Without user data the commit is allowed.
pub fn commitHookCallback(user_data: ?*anyopaque) callconv(.c) c_int {
    const raw = user_data orelse return 0;
    const wrapper: *CommitHookWrapper = @ptrCast(@alignCast(raw));
    return @intFromBool(!wrapper.func());
}
/// C trampoline for the rollback hook; a null user-data pointer is ignored.
pub fn rollbackHookCallback(user_data: ?*anyopaque) callconv(.c) void {
    const raw = user_data orelse return;
    const wrapper: *RollbackHookWrapper = @ptrCast(@alignCast(raw));
    wrapper.func();
}
/// C trampoline for the update hook.
///
/// Skips the call when there is no user data or `UpdateOperation.fromInt`
/// does not recognise `operation`. Database and table names are passed as
/// borrowed slices over SQLite's C strings (no copy).
pub fn updateHookCallback(user_data: ?*anyopaque, operation: c_int, db_name: [*c]const u8, table_name: [*c]const u8, rowid: c.sqlite3_int64) callconv(.c) void {
    const raw = user_data orelse return;
    const wrapper: *UpdateHookWrapper = @ptrCast(@alignCast(raw));
    const op = UpdateOperation.fromInt(operation) orelse return;
    wrapper.func(op, std.mem.span(db_name), std.mem.span(table_name), rowid);
}
/// C trampoline for the pre-update hook.
///
/// Skips the call when user data or the connection is null, or when
/// `operation` is not a recognised `UpdateOperation`. The Zig hook receives a
/// `PreUpdateContext` for the connection plus borrowed name slices.
pub fn preUpdateHookCallback(user_data: ?*anyopaque, db: ?*c.sqlite3, operation: c_int, db_name: [*c]const u8, table_name: [*c]const u8, old_rowid: c.sqlite3_int64, new_rowid: c.sqlite3_int64) callconv(.c) void {
    const raw = user_data orelse return;
    const conn = db orelse return;
    const wrapper: *PreUpdateHookWrapper = @ptrCast(@alignCast(raw));
    const op = UpdateOperation.fromInt(operation) orelse return;
    wrapper.func(.{ .db = conn }, op, std.mem.span(db_name), std.mem.span(table_name), old_rowid, new_rowid);
}
/// C trampoline for the authorizer.
///
/// Converts each C string argument to an optional slice (null stays null),
/// calls the Zig authorizer and returns its `AuthResult` as an integer.
/// Returns SQLITE_OK when there is no user data.
pub fn authorizerCallback(user_data: ?*anyopaque, action: c_int, arg1: [*c]const u8, arg2: [*c]const u8, arg3: [*c]const u8, arg4: [*c]const u8) callconv(.c) c_int {
    const raw = user_data orelse return c.SQLITE_OK;
    const wrapper: *AuthorizerWrapper = @ptrCast(@alignCast(raw));
    // NOTE(review): action codes AuthAction.fromInt does not recognise are
    // answered with SQLITE_OK, i.e. allowed (fail-open). Confirm this is intended.
    const auth_action = AuthAction.fromInt(action) orelse return c.SQLITE_OK;
    const Str = struct {
        fn optional(s: [*c]const u8) ?[]const u8 {
            if (s == null) return null;
            return std.mem.span(s);
        }
    };
    const verdict = wrapper.func(auth_action, Str.optional(arg1), Str.optional(arg2), Str.optional(arg3), Str.optional(arg4));
    return @intFromEnum(verdict);
}
/// C trampoline for the progress handler.
/// Returns 0 to let the statement continue and 1 to interrupt it.
/// Without user data the statement continues.
pub fn progressCallback(user_data: ?*anyopaque) callconv(.c) c_int {
    const raw = user_data orelse return 0;
    const wrapper: *ProgressWrapper = @ptrCast(@alignCast(raw));
    return @intFromBool(!wrapper.func());
}
/// C trampoline for the busy handler.
/// Returns 1 when the Zig handler asks to retry and 0 otherwise.
/// Without user data it never retries.
pub fn busyHandlerCallback(user_data: ?*anyopaque, count: c_int) callconv(.c) c_int {
    const raw = user_data orelse return 0;
    const wrapper: *BusyHandlerWrapper = @ptrCast(@alignCast(raw));
    return @intFromBool(wrapper.func(count));
}