d28f005552
Fix limit for databases other than sqlite go mod tidy && go mod vendor Remove unneeded break statements Make everything work with the new xorm version Fix xorm logging Fix lint Fix redis init Fix using id field Fix database init for testing Change default database log level Add xorm logger Use const for postgres go mod tidy Merge branch 'master' into update/xorm # Conflicts: # go.mod # go.sum # vendor/modules.txt go mod vendor Fix loading fixtures for postgres Go mod vendor1 Update xorm to version 1 Co-authored-by: kolaente <k@knt.li> Reviewed-on: https://kolaente.dev/vikunja/api/pulls/323
490 lines
13 KiB
Go
490 lines
13 KiB
Go
// Copyright 2016 The Xorm Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package xorm
|
|
|
|
import (
|
|
"bufio"
|
|
"database/sql"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
|
|
"xorm.io/xorm/internal/utils"
|
|
"xorm.io/xorm/schemas"
|
|
)
|
|
|
|
// Ping test if database is ok
|
|
func (session *Session) Ping() error {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
|
|
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
|
|
return session.DB().PingContext(session.ctx)
|
|
}
|
|
|
|
// CreateTable create a table according a bean
|
|
func (session *Session) CreateTable(bean interface{}) error {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
|
|
return session.createTable(bean)
|
|
}
|
|
|
|
func (session *Session) createTable(bean interface{}) error {
|
|
if err := session.statement.SetRefBean(bean); err != nil {
|
|
return err
|
|
}
|
|
|
|
sqlStrs := session.statement.GenCreateTableSQL()
|
|
for _, s := range sqlStrs {
|
|
_, err := session.exec(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateIndexes create indexes
|
|
func (session *Session) CreateIndexes(bean interface{}) error {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
|
|
return session.createIndexes(bean)
|
|
}
|
|
|
|
func (session *Session) createIndexes(bean interface{}) error {
|
|
if err := session.statement.SetRefBean(bean); err != nil {
|
|
return err
|
|
}
|
|
|
|
sqls := session.statement.GenIndexSQL()
|
|
for _, sqlStr := range sqls {
|
|
_, err := session.exec(sqlStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateUniques create uniques
|
|
func (session *Session) CreateUniques(bean interface{}) error {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
return session.createUniques(bean)
|
|
}
|
|
|
|
func (session *Session) createUniques(bean interface{}) error {
|
|
if err := session.statement.SetRefBean(bean); err != nil {
|
|
return err
|
|
}
|
|
|
|
sqls := session.statement.GenUniqueSQL()
|
|
for _, sqlStr := range sqls {
|
|
_, err := session.exec(sqlStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DropIndexes drop indexes
|
|
func (session *Session) DropIndexes(bean interface{}) error {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
|
|
return session.dropIndexes(bean)
|
|
}
|
|
|
|
func (session *Session) dropIndexes(bean interface{}) error {
|
|
if err := session.statement.SetRefBean(bean); err != nil {
|
|
return err
|
|
}
|
|
|
|
sqls := session.statement.GenDelIndexSQL()
|
|
for _, sqlStr := range sqls {
|
|
_, err := session.exec(sqlStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DropTable drop table will drop table if exist, if drop failed, it will return error
|
|
func (session *Session) DropTable(beanOrTableName interface{}) error {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
|
|
return session.dropTable(beanOrTableName)
|
|
}
|
|
|
|
func (session *Session) dropTable(beanOrTableName interface{}) error {
|
|
tableName := session.engine.TableName(beanOrTableName)
|
|
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
|
|
if !checkIfExist {
|
|
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
checkIfExist = exist
|
|
}
|
|
|
|
if checkIfExist {
|
|
_, err := session.exec(sqlStr)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// IsTableExist if a table is exist
|
|
func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
|
|
tableName := session.engine.TableName(beanOrTableName)
|
|
|
|
return session.isTableExist(tableName)
|
|
}
|
|
|
|
func (session *Session) isTableExist(tableName string) (bool, error) {
|
|
return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
|
|
}
|
|
|
|
// IsTableEmpty if table have any records
|
|
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
|
|
if session.isAutoClose {
|
|
defer session.Close()
|
|
}
|
|
return session.isTableEmpty(session.engine.TableName(bean))
|
|
}
|
|
|
|
func (session *Session) isTableEmpty(tableName string) (bool, error) {
|
|
var total int64
|
|
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
|
|
err := session.queryRow(sqlStr).Scan(&total)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
err = nil
|
|
}
|
|
return true, err
|
|
}
|
|
|
|
return total == 0, nil
|
|
}
|
|
|
|
// find if index is exist according cols
|
|
func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
|
|
indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
for _, index := range indexes {
|
|
if utils.SliceEq(index.Cols, cols) {
|
|
if unique {
|
|
return index.Type == schemas.UniqueType, nil
|
|
}
|
|
return index.Type == schemas.IndexType, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (session *Session) addColumn(colName string) error {
|
|
col := session.statement.RefTable.GetColumn(colName)
|
|
sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col)
|
|
_, err := session.exec(sql)
|
|
return err
|
|
}
|
|
|
|
func (session *Session) addIndex(tableName, idxName string) error {
|
|
index := session.statement.RefTable.Indexes[idxName]
|
|
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
|
|
_, err := session.exec(sqlStr)
|
|
return err
|
|
}
|
|
|
|
func (session *Session) addUnique(tableName, uqeName string) error {
|
|
index := session.statement.RefTable.Indexes[uqeName]
|
|
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
|
|
_, err := session.exec(sqlStr)
|
|
return err
|
|
}
|
|
|
|
// Sync2 synchronize structs to database tables
|
|
func (session *Session) Sync2(beans ...interface{}) error {
|
|
engine := session.engine
|
|
|
|
if session.isAutoClose {
|
|
session.isAutoClose = false
|
|
defer session.Close()
|
|
}
|
|
|
|
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
session.autoResetStatement = false
|
|
defer func() {
|
|
session.autoResetStatement = true
|
|
session.resetStatement()
|
|
}()
|
|
|
|
for _, bean := range beans {
|
|
v := utils.ReflectValue(bean)
|
|
table, err := engine.tagParser.ParseWithCache(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var tbName string
|
|
if len(session.statement.AltTableName) > 0 {
|
|
tbName = session.statement.AltTableName
|
|
} else {
|
|
tbName = engine.TableName(bean)
|
|
}
|
|
tbNameWithSchema := engine.tbNameWithSchema(tbName)
|
|
|
|
var oriTable *schemas.Table
|
|
for _, tb := range tables {
|
|
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
|
|
oriTable = tb
|
|
break
|
|
}
|
|
}
|
|
|
|
// this is a new table
|
|
if oriTable == nil {
|
|
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = session.createUniques(bean)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = session.createIndexes(bean)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
|
|
// this will modify an old table
|
|
if err = engine.loadTableInfo(oriTable); err != nil {
|
|
return err
|
|
}
|
|
|
|
// check columns
|
|
for _, col := range table.Columns() {
|
|
var oriCol *schemas.Column
|
|
for _, col2 := range oriTable.Columns() {
|
|
if strings.EqualFold(col.Name, col2.Name) {
|
|
oriCol = col2
|
|
break
|
|
}
|
|
}
|
|
|
|
// column is not exist on table
|
|
if oriCol == nil {
|
|
session.statement.RefTable = table
|
|
session.statement.SetTableName(tbNameWithSchema)
|
|
if err = session.addColumn(col.Name); err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
|
|
err = nil
|
|
expectedType := engine.dialect.SQLType(col)
|
|
curType := engine.dialect.SQLType(oriCol)
|
|
if expectedType != curType {
|
|
if expectedType == schemas.Text &&
|
|
strings.HasPrefix(curType, schemas.Varchar) {
|
|
// currently only support mysql & postgres
|
|
if engine.dialect.URI().DBType == schemas.MYSQL ||
|
|
engine.dialect.URI().DBType == schemas.POSTGRES {
|
|
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
|
|
tbNameWithSchema, col.Name, curType, expectedType)
|
|
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
|
|
} else {
|
|
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
|
|
tbNameWithSchema, col.Name, curType, expectedType)
|
|
}
|
|
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
|
|
if engine.dialect.URI().DBType == schemas.MYSQL {
|
|
if oriCol.Length < col.Length {
|
|
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
|
|
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
|
|
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
|
|
}
|
|
}
|
|
} else {
|
|
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
|
|
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
|
|
tbNameWithSchema, col.Name, curType, expectedType)
|
|
}
|
|
}
|
|
} else if expectedType == schemas.Varchar {
|
|
if engine.dialect.URI().DBType == schemas.MYSQL {
|
|
if oriCol.Length < col.Length {
|
|
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
|
|
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
|
|
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
|
|
}
|
|
}
|
|
}
|
|
|
|
if col.Default != oriCol.Default {
|
|
switch {
|
|
case col.IsAutoIncrement: // For autoincrement column, don't check default
|
|
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
|
|
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
|
|
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
|
|
default:
|
|
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
|
|
tbName, col.Name, oriCol.Default, col.Default)
|
|
}
|
|
}
|
|
if col.Nullable != oriCol.Nullable {
|
|
engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
|
|
tbName, col.Name, oriCol.Nullable, col.Nullable)
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
var foundIndexNames = make(map[string]bool)
|
|
var addedNames = make(map[string]*schemas.Index)
|
|
|
|
for name, index := range table.Indexes {
|
|
var oriIndex *schemas.Index
|
|
for name2, index2 := range oriTable.Indexes {
|
|
if index.Equal(index2) {
|
|
oriIndex = index2
|
|
foundIndexNames[name2] = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if oriIndex != nil {
|
|
if oriIndex.Type != index.Type {
|
|
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
|
|
_, err = session.exec(sql)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
oriIndex = nil
|
|
}
|
|
}
|
|
|
|
if oriIndex == nil {
|
|
addedNames[name] = index
|
|
}
|
|
}
|
|
|
|
for name2, index2 := range oriTable.Indexes {
|
|
if _, ok := foundIndexNames[name2]; !ok {
|
|
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
|
|
_, err = session.exec(sql)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
for name, index := range addedNames {
|
|
if index.Type == schemas.UniqueType {
|
|
session.statement.RefTable = table
|
|
session.statement.SetTableName(tbNameWithSchema)
|
|
err = session.addUnique(tbNameWithSchema, name)
|
|
} else if index.Type == schemas.IndexType {
|
|
session.statement.RefTable = table
|
|
session.statement.SetTableName(tbNameWithSchema)
|
|
err = session.addIndex(tbNameWithSchema, name)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// check all the columns which removed from struct fields but left on database tables.
|
|
for _, colName := range oriTable.ColumnsSeq() {
|
|
if table.GetColumn(colName) == nil {
|
|
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ImportFile SQL DDL file
|
|
func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) {
|
|
file, err := os.Open(ddlPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer file.Close()
|
|
return session.Import(file)
|
|
}
|
|
|
|
// Import SQL DDL from io.Reader
|
|
func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
|
|
var results []sql.Result
|
|
var lastError error
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
var inSingleQuote bool
|
|
semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
if atEOF && len(data) == 0 {
|
|
return 0, nil, nil
|
|
}
|
|
for i, b := range data {
|
|
if b == '\'' {
|
|
inSingleQuote = !inSingleQuote
|
|
}
|
|
if !inSingleQuote && b == ';' {
|
|
return i + 1, data[0:i], nil
|
|
}
|
|
}
|
|
// If we're at EOF, we have a final, non-terminated line. Return it.
|
|
if atEOF {
|
|
return len(data), data, nil
|
|
}
|
|
// Request more data.
|
|
return 0, nil, nil
|
|
}
|
|
|
|
scanner.Split(semiColSpliter)
|
|
|
|
for scanner.Scan() {
|
|
query := strings.Trim(scanner.Text(), " \t\n\r")
|
|
if len(query) > 0 {
|
|
result, err := session.Exec(query)
|
|
results = append(results, result)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
return results, lastError
|
|
}
|