package nstore import ( "database/sql" "fmt" "log" "reflect" "strings" ) func (d *Database) createStructTable(tx *sql.Tx, rType reflect.Type) error { if rType.Kind() != reflect.Struct { return errInvalidStruct } tableName := structTypeTableName(rType) colNames, colTypes, err := structTypeToColumns(rType) if err != nil { return err } var colSql strings.Builder for i := range colNames { fmt.Fprintf(&colSql, `"%s" %s,`, colNames[i], colTypes[i]) } query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS "%s" ( _id INTEGER PRIMARY KEY AUTOINCREMENT, %s ) `, tableName, strings.TrimSuffix(colSql.String(), ",")) d.debugQuery(query) if _, err := tx.Exec(query); err != nil { return err } return nil } func (d *Database) createFieldRelationTable(tx *sql.Tx, rType reflect.Type, rField reflect.StructField) error { if !rField.IsExported() { return errFieldNotExported } if rType.Kind() != reflect.Struct { return errInvalidStruct } databaseType := fieldRelationTypeToDatabaseType(rField.Type) if databaseType == "" { return errInvalidRelationType } tableName := fmt.Sprintf("_%s_%s", structTypeTableName(rType), databaseSlug(rField.Name)) query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS "%s" ( _id INTEGER PRIMARY KEY AUTOINCREMENT, src_id INTEGER NOT NULL, value %s NOT NULL, UNIQUE(src_id, value) ) `, tableName, databaseType) d.debugQuery(query) if _, err := tx.Exec(query); err != nil { return err } return nil } func (d *Database) insertStructPlaceholder(tx *sql.Tx, rType reflect.Type) (int64, error) { if err := d.createStructTable(tx, rType); err != nil { return 0, err } tableName := structTypeTableName(rType) columnNames, err := structTypeToColumnNames(rType) if err != nil { return 0, err } values := make([]any, 0) for range columnNames { values = append(values, nil) } query := fmt.Sprintf( `INSERT INTO "%s" ("%s") VALUES (%s)`, tableName, strings.Join(columnNames, `","`), generateColumnPlaceholders(len(columnNames))) d.debugQuery(query, values...) res, err := tx.Exec(query, values...) if err != nil { return 0, err } return res.LastInsertId() } func (d *Database) saveStructDepth(tx *sql.Tx, rValue reflect.Value, depth int) (int64, error) { skipRelations := depth >= maxRecursionDepth rType := rValue.Type() ID, _ := d.ID(rValue.Addr().Interface()) if ID == 0 { var err error ID, err = d.insertStructPlaceholder(tx, rType) if err != nil { return 0, err } if ID == 0 { return 0, errNotExists } d.recordIDs[rValue.Addr().Interface()] = ID } tableName := structTypeTableName(rType) colNames, err := structTypeToColumnNames(rType) if err != nil { return 0, err } colValues := make([]any, 0) for rField, rFieldValue := range rValue.Fields() { if !rField.IsExported() { if d.debug { log.Printf("skipping non exported field %s", rField.Name) } continue } rFieldType := rFieldValue.Type() if rFieldType.Kind() == reflect.Pointer { rFieldValue = rFieldValue.Elem() rFieldType = rFieldValue.Type() } colType := fieldTypeToDatabaseType(rField.Type) if colType == "" { relType := fieldRelationTypeToDatabaseType(rField.Type) if relType == "" { log.Printf("unable save field %s of type %s, type not implemented", rField.Name, rField.Type.Name()) continue } if skipRelations { continue } if err := d.saveFieldRelation(tx, ID, rType, rField, rFieldValue, depth); err != nil { return 0, err } continue } colValues = append(colValues, sanitizeValue(rFieldValue.Interface())) } if len(colNames) == 0 || len(colValues) == 0 { return 0, errReadStruct } updateClauses := strings.TrimSuffix(strings.Join(colNames, `" = ?, "`), `?, "`) args := append(colValues, ID) query := fmt.Sprintf(`UPDATE "%s" SET "%s" = ? WHERE _id = ?`, tableName, updateClauses) d.debugQuery(query, args...) stmt, err := tx.Prepare(query) if err != nil { return 0, err } defer stmt.Close() _, err = stmt.Exec(args...) if err != nil { return 0, err } return ID, nil } func (d *Database) saveFieldRelation(tx *sql.Tx, srcID int64, rType reflect.Type, rField reflect.StructField, rFieldValue reflect.Value, depth int) error { if err := d.createFieldRelationTable(tx, rType, rField); err != nil { return err } saveValues := make([]any, 0) rFieldType := rField.Type if rFieldType.Kind() == reflect.Pointer { rFieldType = rFieldType.Elem() } switch rFieldType.Kind() { case reflect.Struct: dstID, err := d.saveStructDepth(tx, rFieldValue, depth+1) if err != nil { return err } d.recordIDs[rFieldValue.Addr().Interface()] = dstID saveValues = append(saveValues, dstID) case reflect.Slice: for _, item := range rFieldValue.Seq2() { log.Println(rField.Name, item.Type().Kind()) if item.Type().Kind() == reflect.Pointer { item = item.Elem() } if item.Type().Kind() == reflect.Struct { dstID, err := d.saveStructDepth(tx, item, depth+1) if err != nil { return err } d.recordIDs[item.Addr().Interface()] = dstID saveValues = append(saveValues, dstID) continue } saveValues = append(saveValues, item.Interface()) } default: return errInvalidRelationType } if err := d.deleteFieldRelationsByValues(tx, srcID, rType, rField, saveValues); err != nil { return err } tableName := fmt.Sprintf("_%s_%s", structTypeTableName(rType), databaseSlug(rField.Name)) for _, saveValue := range saveValues { query := fmt.Sprintf(`INSERT OR IGNORE INTO "%s" (src_id, value) VALUES (?, ?);`, tableName) d.debugQuery(query, srcID, saveValue) _, err := tx.Exec(query, srcID, saveValue) if err != nil { return err } } return nil }