package main

import (
	"database/sql"
	"fmt"
	"path/filepath"
	"sync"

	lua "github.com/yuin/gopher-lua"
	_ "modernc.org/sqlite"
)

// dbRegistry tracks every database connection opened from Lua so that
// CleanupDBs can close the ones still open.
var dbRegistry = struct {
	sync.Mutex
	dbs []*sql.DB
}{}

// setupDBModule installs the global "db" table, which gives Lua scripts
// SQLite access. File databases are restricted to projectDir via sanitizePath.
//
// Lua API (every function returns handle, nil on success or nil, errmsg):
//
//	db.open([filename])  -- "" / missing opens a private in-memory database;
//	                     -- relative paths are resolved against projectDir
//	db.openMemory()      -- same as db.open("")
func setupDBModule(L *lua.LState, projectDir string) {
	mod := L.NewTable()

	L.SetField(mod, "open", L.NewFunction(func(ls *lua.LState) int {
		filename := ls.ToString(1)
		if filename == "" {
			return pushDBOpenResult(ls, ":memory:", "")
		}
		if !filepath.IsAbs(filename) {
			filename = filepath.Join(projectDir, filename)
		}
		safePath := sanitizePath(filename, projectDir)
		if safePath == "" {
			ls.Push(lua.LNil)
			ls.Push(lua.LString("database path access denied"))
			return 2
		}
		return pushDBOpenResult(ls, safePath, safePath)
	}))

	L.SetField(mod, "openMemory", L.NewFunction(func(ls *lua.LState) int {
		return pushDBOpenResult(ls, ":memory:", "")
	}))

	L.SetGlobal("db", mod)
}

// pushDBOpenResult opens dsn, pushes (handle, nil) or (nil, errmsg) onto ls
// and returns the number of values pushed. path is recorded on the handle for
// tostring(); it is "" for in-memory databases.
func pushDBOpenResult(ls *lua.LState, dsn, path string) int {
	conn, err := openTrackedSQLite(dsn)
	if err != nil {
		ls.Push(lua.LNil)
		ls.Push(lua.LString(err.Error()))
		return 2
	}
	ls.Push(toLuaDB(ls, conn, path))
	ls.Push(lua.LNil)
	return 2
}

// openTrackedSQLite opens dsn with the sqlite driver, applies the module's
// PRAGMAs and records the connection in dbRegistry.
func openTrackedSQLite(dsn string) (*sql.DB, error) {
	conn, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, err
	}
	// database/sql pools connections, but every SQLite connection to
	// ":memory:" is a separate, empty database, and PRAGMA foreign_keys only
	// affects the connection that ran it. One pooled connection keeps both
	// consistent; a Lua state is single-threaded, so nothing is lost.
	conn.SetMaxOpenConns(1)

	// Best-effort: WAL is a no-op for in-memory databases and may fail on
	// read-only files, which should still be openable for queries.
	_, _ = conn.Exec("PRAGMA journal_mode=WAL")

	// sql.Open does not connect, so this is also where an unopenable file
	// surfaces its error instead of returning a handle that always fails.
	if _, err := conn.Exec("PRAGMA foreign_keys=ON"); err != nil {
		_ = conn.Close()
		return nil, err
	}

	dbRegistry.Lock()
	dbRegistry.dbs = append(dbRegistry.dbs, conn)
	dbRegistry.Unlock()
	return conn, nil
}

// dbMeta is the Go state behind a Lua database handle.
type dbMeta struct {
	conn   *sql.DB
	path   string // file path, or "" for in-memory databases
	closed bool
}

// toLuaDB creates a Lua
userdata representing an open database func toLuaDB(L *lua.LState, conn *sql.DB, path string) lua.LValue { meta := &dbMeta{conn: conn, path: path, closed: false} ud := L.NewUserData() ud.Value = meta mt := L.NewTable() L.SetField(mt, "__index", L.NewTable()) index := L.GetField(mt, "__index").(*lua.LTable) // db:exec(sql, args...) - execute without returning rows L.SetField(index, "exec", L.NewFunction(func(ls *lua.LState) int { meta := checkDB(ls, 1) if meta == nil { ls.Push(lua.LNil) ls.Push(lua.LString("invalid database handle")) return 2 } if meta.closed { ls.Push(lua.LNil) ls.Push(lua.LString("database is closed")) return 2 } sqlStr := ls.ToString(2) args := collectArgs(ls, 3) _, err := meta.conn.Exec(sqlStr, args...) if err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } ls.Push(lua.LTrue) ls.Push(lua.LNil) return 2 })) // db:query(sql, args...) - query and return rows as array of tables L.SetField(index, "query", L.NewFunction(func(ls *lua.LState) int { meta := checkDB(ls, 1) if meta == nil { ls.Push(lua.LNil) ls.Push(lua.LString("invalid database handle")) return 2 } if meta.closed { ls.Push(lua.LNil) ls.Push(lua.LString("database is closed")) return 2 } sqlStr := ls.ToString(2) args := collectArgs(ls, 3) rows, err := meta.conn.Query(sqlStr, args...) if err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } defer rows.Close() columns, err := rows.Columns() if err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } result := ls.NewTable() for rows.Next() { values := make([]any, len(columns)) valuePtrs := make([]any, len(columns)) for i := range values { valuePtrs[i] = &values[i] } err := rows.Scan(valuePtrs...) 
if err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } rowTable := ls.NewTable() for i, col := range columns { rowTable.RawSetString(col, scanToLua(ls, values[i])) } result.Append(rowTable) } if err := rows.Err(); err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } ls.Push(result) ls.Push(lua.LNil) return 2 })) // db:queryRow(sql, args...) - query single row L.SetField(index, "queryRow", L.NewFunction(func(ls *lua.LState) int { meta := checkDB(ls, 1) if meta == nil { ls.Push(lua.LNil) ls.Push(lua.LString("invalid database handle")) return 2 } if meta.closed { ls.Push(lua.LNil) ls.Push(lua.LString("database is closed")) return 2 } sqlStr := ls.ToString(2) args := collectArgs(ls, 3) rows, err := meta.conn.Query(sqlStr, args...) if err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } defer rows.Close() if !rows.Next() { if err := rows.Err(); err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } ls.Push(lua.LNil) ls.Push(lua.LString("no rows")) return 2 } columns, err := rows.Columns() if err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } values := make([]any, len(columns)) valuePtrs := make([]any, len(columns)) for i := range values { valuePtrs[i] = &values[i] } err = rows.Scan(valuePtrs...) 
if err != nil { ls.Push(lua.LNil) ls.Push(lua.LString(err.Error())) return 2 } rowTable := ls.NewTable() for i, col := range columns { rowTable.RawSetString(col, scanToLua(ls, values[i])) } ls.Push(rowTable) ls.Push(lua.LNil) return 2 })) // db:close() - close the database connection L.SetField(index, "close", L.NewFunction(func(ls *lua.LState) int { meta := checkDB(ls, 1) if meta == nil { ls.Push(lua.LNil) ls.Push(lua.LString("invalid database handle")) return 2 } if meta.closed { ls.Push(lua.LTrue) ls.Push(lua.LNil) return 2 } err := meta.conn.Close() meta.closed = true if err != nil { ls.Push(lua.LFalse) ls.Push(lua.LString(err.Error())) return 2 } ls.Push(lua.LTrue) ls.Push(lua.LNil) return 2 })) L.SetField(mt, "__tostring", L.NewFunction(func(ls *lua.LState) int { ud := ls.CheckUserData(1) meta := ud.Value.(*dbMeta) if meta.path == "" { ls.Push(lua.LString("[db:memory]")) } else { ls.Push(lua.LString(fmt.Sprintf("[db:%s]", filepath.Base(meta.path)))) } return 1 })) L.SetField(mt, "__gc", L.NewFunction(func(ls *lua.LState) int { ud := ls.CheckUserData(1) meta := ud.Value.(*dbMeta) if !meta.closed { meta.conn.Close() meta.closed = true } return 0 })) L.SetField(mt, "__eq", L.NewFunction(func(ls *lua.LState) int { a := ls.CheckUserData(1) b := ls.CheckUserData(2) ls.Push(lua.LBool(a == b)) return 1 })) L.SetMetatable(ud, mt) return ud } // checkDB validates that the argument is a database userdata func checkDB(L *lua.LState, pos int) *dbMeta { ud, ok := L.Get(pos).(*lua.LUserData) if !ok { return nil } meta, ok := ud.Value.(*dbMeta) if !ok { return nil } return meta } // scanToLua converts a scanned SQL value to a Lua value func scanToLua(L *lua.LState, v any) lua.LValue { switch val := v.(type) { case nil: return lua.LNil case []byte: return lua.LString(string(val)) case int64: return lua.LNumber(float64(val)) case float64: return lua.LNumber(val) case bool: return lua.LBool(val) case string: return lua.LString(val) default: return lua.LString(fmt.Sprintf("%v", 
v)) } } // collectArgs collects variadic arguments from Lua as []any for SQL parameters func collectArgs(L *lua.LState, pos int) []any { var args []any for i := pos; ; i++ { v := L.Get(i) if v.Type() == lua.LTNil { break } switch val := v.(type) { case *lua.LString: args = append(args, string(*val)) case lua.LNumber: args = append(args, int64(val)) case lua.LBool: args = append(args, bool(val)) default: if v.Type() != lua.LTNil { args = append(args, v.String()) } } } return args } // CleanupDBs closes all tracked database connections func CleanupDBs() { dbRegistry.Lock() defer dbRegistry.Unlock() for _, conn := range dbRegistry.dbs { conn.Close() } dbRegistry.dbs = nil }