vikunja-api/vendor/xorm.io/xorm/session_insert.go
konrad d28f005552 Update xorm to v1 (#323)
Fix limit for databases other than sqlite

go mod tidy && go mod vendor

Remove unneeded break statements

Make everything work with the new xorm version

Fix xorm logging

Fix lint

Fix redis init

Fix using id field

Fix database init for testing

Change default database log level

Add xorm logger

Use const for postgres

go mod tidy

Merge branch 'master' into update/xorm

# Conflicts:
#	go.mod
#	go.sum
#	vendor/modules.txt

go mod vendor

Fix loading fixtures for postgres

go mod vendor

Update xorm to version 1

Co-authored-by: kolaente <k@knt.li>
Reviewed-on: https://kolaente.dev/vikunja/api/pulls/323
2020-04-12 17:29:24 +00:00

707 lines
18 KiB
Go

// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// ErrNoElementsOnSlice is the error Insert returns when one of its beans is an
// empty slice, i.e. there is no element to build an INSERT from.
var (
	ErrNoElementsOnSlice = errors.New("No element on slice when insert")
)
// Insert inserts one or more beans and returns the total number of affected
// rows.
//
// Each bean is dispatched by its dynamic type:
//   - map[string]interface{} / map[string]string: inserted as one row into the
//     session's current table.
//   - []map[string]interface{} / []map[string]string: every map is inserted as
//     its own row.
//   - any other slice (or pointer to a slice): inserted with one multi-row
//     INSERT; an empty slice yields ErrNoElementsOnSlice.
//   - anything else (e.g. a pointer to a struct): inserted as a single row.
//
// Beans are processed in order. On the first error Insert stops and returns
// the rows affected by the beans that were already inserted, together with
// that error.
func (session *Session) Insert(beans ...interface{}) (int64, error) {
	if session.isAutoClose {
		defer session.Close()
	}

	// Keep the statement intact while all beans are inserted and reset it only
	// once, when Insert returns.
	session.autoResetStatement = false
	defer func() {
		session.autoResetStatement = true
		session.resetStatement()
	}()

	var affected int64
	for _, bean := range beans {
		switch b := bean.(type) {
		case map[string]interface{}:
			cnt, err := session.insertMapInterface(b)
			if err != nil {
				return affected, err
			}
			affected += cnt
		case []map[string]interface{}:
			for _, m := range b {
				cnt, err := session.insertMapInterface(m)
				if err != nil {
					return affected, err
				}
				affected += cnt
			}
		case map[string]string:
			cnt, err := session.insertMapString(b)
			if err != nil {
				return affected, err
			}
			affected += cnt
		case []map[string]string:
			for _, m := range b {
				cnt, err := session.insertMapString(m)
				if err != nil {
					return affected, err
				}
				affected += cnt
			}
		default:
			sliceValue := reflect.Indirect(reflect.ValueOf(bean))
			if sliceValue.Kind() == reflect.Slice {
				if sliceValue.Len() <= 0 {
					// Report rows inserted by earlier beans, like every other error path.
					return affected, ErrNoElementsOnSlice
				}
				cnt, err := session.innerInsertMulti(bean)
				if err != nil {
					return affected, err
				}
				affected += cnt
			} else {
				cnt, err := session.innerInsert(bean)
				if err != nil {
					return affected, err
				}
				affected += cnt
			}
		}
	}
	return affected, nil
}
// innerInsertMulti inserts every element of rowsSlicePtr (a slice, or a
// pointer to a slice) with a single multi-row INSERT statement and returns the
// number of affected rows.
//
// The table is resolved from the first element. Before-insert closures and
// BeforeInsert hooks run for every element while the statement is built. After
// a successful execution, the session's after closures plus the closures that
// write generated values (created/updated time, initial version) back into
// that particular row run for each element — immediately in auto-commit mode,
// otherwise they are queued in session.afterInsertBeans.
//
// NOTE(review): column names are taken from row 0 only, but the skip decision
// for a zero auto-increment value is made per row; rows that disagree with
// row 0 on that produce a placeholder count that does not match the column list.
func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
	sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
	if sliceValue.Kind() != reflect.Slice {
		return 0, errors.New("needs a pointer to a slice")
	}
	if sliceValue.Len() <= 0 {
		return 0, errors.New("could not insert a empty slice")
	}

	if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil {
		return 0, err
	}
	tableName := session.statement.TableName()
	if len(tableName) <= 0 {
		return 0, ErrTableNotFound
	}

	table := session.statement.RefTable
	size := sliceValue.Len()

	var colNames []string
	var colMultiPlaces []string
	var args []interface{}
	// rowAfterClosures[i] holds the closures that copy the values generated
	// for row i back into element i only. Keeping them per row (instead of
	// appending them to session.afterClosures) avoids running every row's
	// closures on every element and gives each element its own timestamp.
	rowAfterClosures := make([][]func(interface{}), size)

	for i := 0; i < size; i++ {
		v := sliceValue.Index(i)
		var vv reflect.Value
		switch v.Kind() {
		case reflect.Interface:
			vv = reflect.Indirect(v.Elem())
		default:
			vv = reflect.Indirect(v)
		}
		elemValue := v.Interface()
		var colPlaces []string

		// handle BeforeInsertProcessor
		// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
		for _, closure := range session.beforeClosures {
			closure(elemValue)
		}
		if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
			processor.BeforeInsert()
		}

		for _, col := range table.Columns() {
			ptrFieldValue, err := col.ValueOfV(&vv)
			if err != nil {
				return 0, err
			}
			fieldValue := *ptrFieldValue

			if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) {
				continue
			}
			if col.MapType == schemas.ONLYFROMDB {
				continue
			}
			if col.IsDeleted {
				continue
			}
			if session.statement.OmitColumnMap.Contain(col.Name) {
				continue
			}
			if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
				continue
			}

			if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
				val, t := session.engine.nowTime(col)
				args = append(args, val)
				var colName = col.Name
				rowAfterClosures[i] = append(rowAfterClosures[i], func(bean interface{}) {
					col := table.GetColumn(colName)
					setColumnTime(bean, col, t)
				})
			} else if col.IsVersion && session.statement.CheckVersion {
				args = append(args, 1)
				var colName = col.Name
				rowAfterClosures[i] = append(rowAfterClosures[i], func(bean interface{}) {
					col := table.GetColumn(colName)
					setColumnInt(bean, col, 1)
				})
			} else {
				arg, err := session.statement.Value2Interface(col, fieldValue)
				if err != nil {
					return 0, err
				}
				args = append(args, arg)
			}

			if i == 0 {
				colNames = append(colNames, col.Name)
			}
			colPlaces = append(colPlaces, "?")
		}

		colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
	}
	cleanupProcessorsClosures(&session.beforeClosures)

	quoter := session.engine.dialect.Quoter()
	var sql string
	colStr := quoter.Join(colNames, ",")
	if session.engine.dialect.URI().DBType == schemas.ORACLE {
		// Oracle has no multi-row VALUES list; use INSERT ALL ... SELECT 1 FROM DUAL.
		temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
			quoter.Quote(tableName),
			colStr)
		sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
			quoter.Quote(tableName),
			colStr,
			strings.Join(colMultiPlaces, temp))
	} else {
		sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
			quoter.Quote(tableName),
			colStr,
			strings.Join(colMultiPlaces, "),("))
	}

	res, err := session.exec(sql, args...)
	if err != nil {
		return 0, err
	}

	session.cacheInsert(tableName)

	for i := 0; i < size; i++ {
		elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()

		// Closures already on the session apply to every element; the
		// generated-value closures apply only to the row that produced them.
		closures := make([]func(interface{}), 0, len(session.afterClosures)+len(rowAfterClosures[i]))
		closures = append(closures, session.afterClosures...)
		closures = append(closures, rowAfterClosures[i]...)

		// handle AfterInsertProcessor
		if session.isAutoCommit {
			for _, closure := range closures {
				closure(elemValue)
			}
			if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
				processor.AfterInsert()
			}
		} else if len(closures) > 0 {
			if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
				*value = append(*value, closures...)
			} else {
				session.afterInsertBeans[elemValue] = &closures
			}
		} else if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
			session.afterInsertBeans[elemValue] = nil
		}
	}

	cleanupProcessorsClosures(&session.afterClosures)
	return res.RowsAffected()
}
// InsertMulti inserts all elements of rowsSlicePtr (a slice or a pointer to a
// slice) with one multi-row INSERT. A non-slice argument yields ErrParamsType;
// an empty slice is a no-op that reports zero affected rows.
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
	if session.isAutoClose {
		defer session.Close()
	}

	rows := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
	switch {
	case rows.Kind() != reflect.Slice:
		return 0, ErrParamsType
	case rows.Len() == 0:
		return 0, nil
	default:
		return session.innerInsertMulti(rowsSlicePtr)
	}
}
// innerInsert inserts a single bean and returns the number of affected rows.
//
// It runs the session's before-insert closures and the bean's BeforeInsert
// hook, builds the statement with genInsertColumns/GenInsertSQL and executes
// it. On success it clears the table's id cache, bumps the bean's version
// field when version checking is enabled, and writes the generated
// auto-increment id back into the bean when one can be obtained. After-insert
// closures and the AfterInsert hook run when the function returns —
// immediately in auto-commit mode, otherwise they are queued in
// session.afterInsertBeans.
func (session *Session) innerInsert(bean interface{}) (int64, error) {
	if err := session.statement.SetRefBean(bean); err != nil {
		return 0, err
	}
	if len(session.statement.TableName()) <= 0 {
		return 0, ErrTableNotFound
	}

	// handle BeforeInsertProcessor
	for _, closure := range session.beforeClosures {
		closure(bean)
	}
	cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used

	if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
		processor.BeforeInsert()
	}

	var tableName = session.statement.TableName()
	table := session.statement.RefTable

	colNames, args, err := session.genInsertColumns(bean)
	if err != nil {
		return 0, err
	}
	sqlStr, args, err := session.statement.GenInsertSQL(colNames, args)
	if err != nil {
		return 0, err
	}

	handleAfterInsertProcessorFunc := func(bean interface{}) {
		if session.isAutoCommit {
			for _, closure := range session.afterClosures {
				closure(bean)
			}
			if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
				processor.AfterInsert()
			}
		} else {
			lenAfterClosures := len(session.afterClosures)
			if lenAfterClosures > 0 {
				if value, has := session.afterInsertBeans[bean]; has && value != nil {
					*value = append(*value, session.afterClosures...)
				} else {
					afterClosures := make([]func(interface{}), lenAfterClosures)
					copy(afterClosures, session.afterClosures)
					session.afterInsertBeans[bean] = &afterClosures
				}
			} else {
				if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
					session.afterInsertBeans[bean] = nil
				}
			}
		}
		cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
	}

	// incrVersion increments the bean's version field after a successful insert
	// when version checking is enabled; lookup failures are only logged.
	incrVersion := func() {
		if table.Version == "" || !session.statement.CheckVersion {
			return
		}
		verValue, err := table.VersionColumn().ValueOf(bean)
		if err != nil {
			session.engine.logger.Errorf("%v", err)
		} else if verValue.IsValid() && verValue.CanSet() {
			session.incrVersionFieldValue(verValue)
		}
	}

	// setAutoIncrID stores id in the bean's auto-increment field when that field
	// can be resolved and set; otherwise it leaves the bean unchanged.
	setAutoIncrID := func(id int64) {
		aiValue, err := table.AutoIncrColumn().ValueOf(bean)
		if err != nil {
			session.engine.logger.Errorf("%v", err)
		}
		if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
			return
		}
		aiValue.Set(int64ToIntValue(id, aiValue.Type()))
	}

	// For these databases many drivers don't implement LastInsertId, so the
	// generated id is read from a query result instead.
	dbType := session.engine.dialect.URI().DBType
	if len(table.AutoIncrement) > 0 &&
		(dbType == schemas.ORACLE || dbType == schemas.POSTGRES || dbType == schemas.MSSQL) {
		query := sqlStr
		noIDMsg := "insert successfully but not returned id"
		if dbType == schemas.ORACLE {
			// NOTE(review): this path never executes sqlStr, queries a hard-coded
			// sequence name and then looks the id up under table.AutoIncrement
			// rather than the currval column; Oracle inserts with an
			// auto-increment column look broken here. Preserved as-is.
			query = "select seq_atable.currval from dual"
			noIDMsg = "insert no error but not returned id"
		}

		res, err := session.queryBytes(query, args...)
		if err != nil {
			return 0, err
		}

		defer handleAfterInsertProcessorFunc(bean)

		session.cacheInsert(tableName)
		incrVersion()

		if len(res) < 1 {
			return 0, errors.New(noIDMsg)
		}

		id, err := strconv.ParseInt(string(res[0][table.AutoIncrement]), 10, 64)
		if err != nil || id <= 0 {
			return 1, err
		}
		setAutoIncrID(id)
		return 1, nil
	}

	res, err := session.exec(sqlStr, args...)
	if err != nil {
		return 0, err
	}

	defer handleAfterInsertProcessorFunc(bean)

	session.cacheInsert(tableName)
	incrVersion()

	if table.AutoIncrement == "" {
		return res.RowsAffected()
	}

	// A missing or non-positive LastInsertId leaves the bean untouched.
	if id, err := res.LastInsertId(); err == nil && id > 0 {
		setAutoIncrID(id)
	}
	return res.RowsAffected()
}
// InsertOne inserts a single bean (a struct or a pointer to a struct) as one
// record and returns the number of affected rows and any error.
func (session *Session) InsertOne(bean interface{}) (int64, error) {
	if !session.isAutoClose {
		return session.innerInsert(bean)
	}
	defer session.Close()
	return session.innerInsert(bean)
}
// cacheInsert clears the cached ids of table after an insert, but only when the
// statement uses the cache and a cacher is registered for that table. It
// currently always returns nil.
func (session *Session) cacheInsert(table string) error {
	if session.statement.UseCache {
		if cacher := session.engine.cacherMgr.GetCacher(table); cacher != nil {
			session.engine.logger.Debugf("[cache] clear sql: %v", table)
			cacher.ClearIds(table)
		}
	}
	return nil
}
// genInsertColumns collects the column names and argument values for a
// single-row INSERT of bean.
//
// A column is skipped when it is ONLYFROMDB, the soft-delete column, omitted,
// outside an explicit column list, registered as an incr/decr/expression
// column, or an auto-increment column whose field value is zero.
//
// Created/updated columns receive the current time (when auto time is enabled)
// and register an after-insert closure that writes that time back into the
// bean. The version column is inserted as 1 when version checking is enabled.
//
// On error both returned slices are nil.
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
	table := session.statement.RefTable
	colNames := make([]string, 0, len(table.ColumnsSeq()))
	args := make([]interface{}, 0, len(table.ColumnsSeq()))

	for _, col := range table.Columns() {
		if col.MapType == schemas.ONLYFROMDB || col.IsDeleted {
			continue
		}
		if session.statement.OmitColumnMap.Contain(col.Name) {
			continue
		}
		if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
			continue
		}
		// Incr/decr/expression columns do not take their value from the bean.
		if session.statement.IncrColumns.IsColExist(col.Name) ||
			session.statement.DecrColumns.IsColExist(col.Name) ||
			session.statement.ExprColumns.IsColExist(col.Name) {
			continue
		}

		fieldValuePtr, err := col.ValueOf(bean)
		if err != nil {
			return nil, nil, err
		}
		fieldValue := *fieldValuePtr

		// Let the database assign the id when the bean does not carry one.
		if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
			continue
		}

		// !evalphobia! set fieldValue as nil when column is nullable and zero-value
		// NOTE(review): only the presence of a NullableMap entry is checked; the
		// flag value itself is discarded.
		if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
			if col.Nullable && utils.IsValueZero(fieldValue) {
				var nilValue *int
				fieldValue = reflect.ValueOf(nilValue)
			}
		}

		switch {
		case (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime:
			// The current time is always used, even if the field is already set.
			val, t := session.engine.nowTime(col)
			args = append(args, val)
			colName := col.Name
			session.afterClosures = append(session.afterClosures, func(bean interface{}) {
				col := table.GetColumn(colName)
				setColumnTime(bean, col, t)
			})
		case col.IsVersion && session.statement.CheckVersion:
			args = append(args, 1)
		default:
			arg, err := session.statement.Value2Interface(col, fieldValue)
			if err != nil {
				return nil, nil, err
			}
			args = append(args, arg)
		}

		colNames = append(colNames, col.Name)
	}
	return colNames, args, nil
}
// insertMapInterface inserts m as one row into the session's current table.
// Keys registered as expression columns are left out (their SQL comes from
// the statement); the remaining keys are sorted so the column order is
// deterministic. An empty map yields ErrParamsType and a missing table name
// ErrTableNotFound.
func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
	if len(m) == 0 {
		return 0, ErrParamsType
	}
	if len(session.statement.TableName()) <= 0 {
		return 0, ErrTableNotFound
	}

	exprs := session.statement.ExprColumns
	columns := make([]string, 0, len(m))
	for name := range m {
		if exprs.IsColExist(name) {
			continue
		}
		columns = append(columns, name)
	}
	sort.Strings(columns)

	args := make([]interface{}, len(columns))
	for i, name := range columns {
		args[i] = m[name]
	}
	return session.insertMap(columns, args)
}
// insertMapString is the map[string]string counterpart of insertMapInterface:
// it inserts m as one row, leaving out keys registered as expression columns
// and ordering the remaining columns alphabetically. An empty map yields
// ErrParamsType and a missing table name ErrTableNotFound.
func (session *Session) insertMapString(m map[string]string) (int64, error) {
	switch {
	case len(m) == 0:
		return 0, ErrParamsType
	case len(session.statement.TableName()) <= 0:
		return 0, ErrTableNotFound
	}

	exprs := session.statement.ExprColumns
	var columns = make([]string, 0, len(m))
	for name := range m {
		if !exprs.IsColExist(name) {
			columns = append(columns, name)
		}
	}
	sort.Strings(columns)

	var args = make([]interface{}, 0, len(m))
	for _, name := range columns {
		value := m[name]
		args = append(args, value)
	}
	return session.insertMap(columns, args)
}
// insertMap inserts one row built from the parallel columns/args slices plus
// any expression columns configured on the statement, and returns the number
// of affected rows.
//
// When the statement has conditions it emits
//
//	INSERT INTO t (cols, exprCols) SELECT args, exprs FROM t WHERE conds
//
// otherwise
//
//	INSERT INTO t (cols, exprCols) VALUES (?, ..., exprs)
//
// columns may be empty (e.g. when every map key was an expression column); in
// that case no placeholders and no leading comma are written.
func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
	tableName := session.statement.TableName()
	if len(tableName) <= 0 {
		return 0, ErrTableNotFound
	}

	exprs := session.statement.ExprColumns
	w := builder.NewWriter()

	// Shared prefix: INSERT INTO <table> (<columns>, <expression columns>
	if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
		return 0, err
	}
	if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil {
		return 0, err
	}

	// writeExprArgs appends the expression values; the separating comma is
	// only needed when plain values were written before them.
	writeExprArgs := func() error {
		if len(exprs.Args) == 0 {
			return nil
		}
		if len(columns) > 0 {
			if _, err := w.WriteString(","); err != nil {
				return err
			}
		}
		return exprs.WriteArgs(w)
	}

	if session.statement.Conds().IsValid() {
		// if insert where
		if _, err := w.WriteString(") SELECT "); err != nil {
			return 0, err
		}
		if err := session.statement.WriteArgs(w, args); err != nil {
			return 0, err
		}
		if err := writeExprArgs(); err != nil {
			return 0, err
		}
		if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
			return 0, err
		}
		if err := session.statement.Conds().WriteTo(w); err != nil {
			return 0, err
		}
	} else {
		// TrimSuffix (unlike slicing off the last byte) is safe when columns is empty.
		placeholders := strings.TrimSuffix(strings.Repeat("?,", len(columns)), ",")
		if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", placeholders)); err != nil {
			return 0, err
		}
		w.Append(args...)
		if err := writeExprArgs(); err != nil {
			return 0, err
		}
		if _, err := w.WriteString(")"); err != nil {
			return 0, err
		}
	}

	sql := w.String()
	args = w.Args()

	if err := session.cacheInsert(tableName); err != nil {
		return 0, err
	}

	res, err := session.exec(sql, args...)
	if err != nil {
		return 0, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	return affected, nil
}