DB Migrations (#67)

This commit is contained in:
konrad 2019-03-29 17:54:35 +00:00 committed by Gitea
parent e21471a193
commit be5a17e993
127 changed files with 7917 additions and 1456 deletions

View file

@ -0,0 +1,71 @@
---
date: "2019-03-29:00:00+02:00"
title: "Database migrations"
draft: false
type: "doc"
menu:
sidebar:
parent: "development"
---
# Database Migrations
Vikunja runs all database migrations automatically on each start if needed.
Additionally, they can also be run directly by using the `migrate` command.
We use [xormigrate](https://github.com/techknowlogick/xormigrate) to handle migrations,
which is based on gormigrate.
## Add a new migration
All migrations are stored in `pkg/migrations` and files should have the same name as their id.
Each migration should have a function to apply and roll it back, as well as a numeric id (the datetime)
and a more in-depth description of what the migration actually does.
To easily get a new id, run the following on any unix system:
```bash
date +%Y%m%d%H%M%S
```
New migrations should be added via the `init()` function to the `migrations` variable.
All migrations are sorted before being executed, since `init()` does not guarantee the order.
When you're adding a new struct, you also need to add it to the `models.GetTables()` function
to ensure it will be created on new installations.
### Example
```go
package migration
import (
"github.com/go-xorm/xorm"
"src.techknowlogick.com/xormigrate"
)
// Used for rollback
type teamMembersMigration20190328074430 struct {
Updated int64 `xorm:"updated"`
}
func (teamMembersMigration20190328074430) TableName() string {
return "team_members"
}
func init() {
migrations = append(migrations, &xormigrate.Migration{
ID: "20190328074430",
Description: "Remove updated from team_members",
Migrate: func(tx *xorm.Engine) error {
return dropTableColum(tx, "team_members", "updated")
},
Rollback: func(tx *xorm.Engine) error {
return tx.Sync2(teamMembersMigration20190328074430{})
},
})
}
```
You should always copy the changed parts of the struct you're changing when adding migraitons.

View file

@ -20,6 +20,7 @@ In general, this api repo has the following structure:
* `log`
* `mail`
* `metrics`
* `migration`
* `models`
* `red`
* `routes`
@ -71,6 +72,13 @@ This package handles all mail sending. To learn how to send a mail, see [sending
This package handles all metrics which are exposed to the prometheus endpoint.
To learn how it works and how to add new metrics, take a look at [how metrics work]({{< ref "../practical-instructions/metrics.md">}}).
### migration
This package handles all migrations.
All migrations are stored and executed here.
To learn more, take a look at the [migrations docs]({{< ref "../development/migrations.md">}}).
### models
This is where most of the magic happens.

13
go.mod
View file

@ -31,10 +31,10 @@ require (
github.com/go-openapi/swag v0.17.2 // indirect
github.com/go-redis/redis v6.14.2+incompatible
github.com/go-sql-driver/mysql v1.4.1
github.com/go-xorm/builder v0.0.0-20170519032130-c8871c857d25
github.com/go-xorm/core v0.5.8
github.com/go-xorm/builder v0.3.2
github.com/go-xorm/core v0.6.0
github.com/go-xorm/tests v0.5.6 // indirect
github.com/go-xorm/xorm v0.0.0-20170930012613-29d4a0330a00
github.com/go-xorm/xorm v0.7.1
github.com/go-xorm/xorm-redis-cache v0.0.0-20180727005610-859b313566b2
github.com/gordonklaus/ineffassign v0.0.0-20180909121442-1003c8bd00dc
github.com/imdario/mergo v0.3.6
@ -42,19 +42,19 @@ require (
github.com/jgautheron/goconst v0.0.0-20170703170152-9740945f5dcb
github.com/karalabe/xgo v0.0.0-20181007145344-72da7d1d3970
github.com/kisielk/gotool v1.0.0 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/labstack/echo v3.3.10+incompatible
github.com/labstack/gommon v0.2.8
github.com/mattn/go-colorable v0.1.1 // indirect
github.com/mattn/go-isatty v0.0.7 // indirect
github.com/mattn/go-oci8 v0.0.0-20181130072307-052f5d97b9b6 // indirect
github.com/mattn/go-runewidth v0.0.4 // indirect
github.com/mattn/go-sqlite3 v1.10.0
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/olekukonko/tablewriter v0.0.1
github.com/onsi/ginkgo v1.7.0 // indirect
github.com/onsi/gomega v1.4.3 // indirect
github.com/op/go-logging v0.0.0-20160315200505-970db520ece7
github.com/pkg/errors v0.8.0 // indirect
github.com/prometheus/client_golang v0.9.2
github.com/spf13/cobra v0.0.3
github.com/spf13/viper v1.2.0
@ -63,16 +63,15 @@ require (
github.com/swaggo/swag v1.4.1-0.20181210033626-0e12fd5eb026
github.com/urfave/cli v1.20.0 // indirect
github.com/valyala/fasttemplate v1.0.1 // indirect
github.com/ziutek/mymysql v1.5.4 // indirect
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3
golang.org/x/net v0.0.0-20181217023233-e147a9138326 // indirect
golang.org/x/sys v0.0.0-20190329044733-9eb1bfa1ce65 // indirect
golang.org/x/tools v0.0.0-20181026183834-f60e5f99f081 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
gopkg.in/testfixtures.v2 v2.5.3
gopkg.in/yaml.v2 v2.2.2 // indirect
honnef.co/go/tools v0.0.0-20190215041234-466a0476246c
src.techknowlogick.com/xormigrate v0.0.0-20190321151057-24497c23c09c
)

21
go.sum
View file

@ -29,6 +29,7 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLM
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/cweill/gotests v1.5.2 h1:kKqmKmS2wCV3tuLnfpbiuN8OlkosQZTpCfiqmiuNAsA=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@ -56,16 +57,24 @@ github.com/go-openapi/swag v0.17.2 h1:K/ycE/XTUDFltNHSO32cGRUhrVGJD64o8WgAIZNyc3
github.com/go-openapi/swag v0.17.2/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg=
github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0=
github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-xorm/builder v0.0.0-20170519032130-c8871c857d25 h1:jUX9yw6+iKrs/WuysV2M6ap/ObK/07SE/a7I2uxitwM=
github.com/go-xorm/builder v0.0.0-20170519032130-c8871c857d25/go.mod h1:M+P3wv0K2C+ynucGDEqJCeOTc+6DcAtiiqU8GrCksXY=
github.com/go-xorm/builder v0.3.2 h1:pSsZQRRzJNapKEAEhigw3xLmiLPeAYv5GFlpYZ8+a5I=
github.com/go-xorm/builder v0.3.2/go.mod h1:v8mE3MFBgtL+RGFNfUnAMUqqfk/Y4W5KuwCFQIEpQLk=
github.com/go-xorm/core v0.5.8 h1:vQ0ghlVGnlnFmm4SpHY+xNnPlH810paMcw+Hwz9BCqE=
github.com/go-xorm/core v0.5.8/go.mod h1:d8FJ9Br8OGyQl12MCclmYBuBqqxsyeedpXciV5Myih8=
github.com/go-xorm/core v0.6.0 h1:tp6hX+ku4OD9khFZS8VGBDRY3kfVCtelPfmkgCyHxL0=
github.com/go-xorm/core v0.6.0/go.mod h1:d8FJ9Br8OGyQl12MCclmYBuBqqxsyeedpXciV5Myih8=
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:56xuuqnHyryaerycW3BfssRdxQstACi0Epw/yC5E2xM=
github.com/go-xorm/tests v0.5.6 h1:E4nmVkKfHQAm+i2/pmOJ5JUej6sORVcvwl6/LQybif4=
github.com/go-xorm/tests v0.5.6/go.mod h1:s8J/EnVBcXQR93dN7Jy6Dwlo92HUP5nTgKWF1wGeCDg=
github.com/go-xorm/xorm v0.0.0-20170930012613-29d4a0330a00 h1:jlA1XEj8QHl6my6FUkHwRCGu/J5hQ1zkW7RqULZ2XGc=
github.com/go-xorm/xorm v0.0.0-20170930012613-29d4a0330a00/go.mod h1:i7qRPD38xj/v75UV+a9pEzr5tfRaH2ndJfwt/fGbQhs=
github.com/go-xorm/xorm v0.7.1 h1:Kj7mfuqctPdX60zuxP6EoEut0f3E6K66H6hcoxiHUMc=
github.com/go-xorm/xorm v0.7.1/go.mod h1:EHS1htMQFptzMaIHKyzqpHGw6C9Rtug75nsq6DA9unI=
github.com/go-xorm/xorm-redis-cache v0.0.0-20180727005610-859b313566b2 h1:57QbyUkFcFjipHJQstYR5owRxsQzgD8/OAO/hr4yl/E=
github.com/go-xorm/xorm-redis-cache v0.0.0-20180727005610-859b313566b2/go.mod h1:xxK9FGkFXrau9/vGdDYSOyQfSgKXBV7iHXpQfNuv6B0=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
@ -82,6 +91,8 @@ github.com/imdario/mergo v0.3.6 h1:xTNEAn+kxVO7dTZGu0CegyqKZmoWFI0rF8UxjlB2d28=
github.com/imdario/mergo v0.3.6/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ=
github.com/jackc/pgx v3.2.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
github.com/jgautheron/goconst v0.0.0-20170703170152-9740945f5dcb h1:D5s1HIu80AcMGcqmk7fNIVptmAubVHHaj3v5Upex6Zs=
github.com/jgautheron/goconst v0.0.0-20170703170152-9740945f5dcb/go.mod h1:82TxjOpWQiPmywlbIaB2ZkqJoSYJdLGPgAJDvM3PbKc=
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
@ -119,6 +130,9 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-oci8 v0.0.0-20181115070430-6eefff3c767c/go.mod h1:/M9VLO+lUPmxvoOK2PfWRZ8mTtB4q1Hy9lEGijv9Nr8=
github.com/mattn/go-oci8 v0.0.0-20181130072307-052f5d97b9b6 h1:gheNi9lnffYyVyqQzJqY7lo+M3bCDVw5fLU/jSuCMhc=
github.com/mattn/go-oci8 v0.0.0-20181130072307-052f5d97b9b6/go.mod h1:/M9VLO+lUPmxvoOK2PfWRZ8mTtB4q1Hy9lEGijv9Nr8=
github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y=
github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
@ -129,6 +143,8 @@ github.com/mitchellh/mapstructure v1.0.0 h1:vVpGvMXJPqSDh2VYHF7gsfQj8Ncx+Xw5Y1KH
github.com/mitchellh/mapstructure v1.0.0/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/olekukonko/tablewriter v0.0.1 h1:b3iUnf1v+ppJiOfNX4yxxqfWKMQPZR5yoh8urCTFX88=
github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@ -150,6 +166,8 @@ github.com/prometheus/common v0.0.0-20181126121408-4724e9255275 h1:PnBWHBf+6L0jO
github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a h1:9a8MnZMP0X2nLJdBg+pBmGgkJlSaKC2KaQmTCk1XDtE=
github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.2.0 h1:HHl1DSRbEQN2i8tJmtS6ViPyHx35+p51amrdsiTCrkg=
@ -231,6 +249,7 @@ gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE=
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
gopkg.in/stretchr/testify.v1 v1.2.2/go.mod h1:QI5V/q6UbPmuhtm10CaFZxED9NreB8PnFYN9JcR6TxU=
gopkg.in/testfixtures.v2 v2.5.3 h1:P8gDACSLJGxutzBqbzvfiXYgmQ2s00LIr4uAvWBCPAg=
gopkg.in/testfixtures.v2 v2.5.3/go.mod h1:rGPtsOtPcZhs7AsHYf1WmufW1hEsM6DXdLrYz60nrQQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
@ -245,3 +264,5 @@ honnef.co/go/tools v0.0.0-20190128043916-71123fcbb8fe h1:/GZ/onp6W295MEgrIwtlbnx
honnef.co/go/tools v0.0.0-20190128043916-71123fcbb8fe/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190215041234-466a0476246c h1:z+UFwlQ7KVwdlQTE5JjvDvfZmyyAVrEiiwau20b7X8k=
honnef.co/go/tools v0.0.0-20190215041234-466a0476246c/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
src.techknowlogick.com/xormigrate v0.0.0-20190321151057-24497c23c09c h1:fTwL7EZ3ouk3xeiPiRBYEjSPWTREb9T57bjzpRBNOpQ=
src.techknowlogick.com/xormigrate v0.0.0-20190321151057-24497c23c09c/go.mod h1:B2NutmcRaDDw4EGe7DoCwyWCELA8W+KxXPhLtgqFUaU=

59
pkg/cmd/migrate.go Normal file
View file

@ -0,0 +1,59 @@
// Vikunja is a todo-list application to facilitate your life.
// Copyright 2019 Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package cmd
import (
"code.vikunja.io/api/pkg/migration"
"github.com/spf13/cobra"
)
func init() {
migrateCmd.AddCommand(migrateListCmd)
migrationRollbackCmd.Flags().StringVarP(&rollbackUntilFlag, "name", "n", "", "The id of the migration you want to roll back until.")
migrationRollbackCmd.MarkFlagRequired("name")
migrateCmd.AddCommand(migrationRollbackCmd)
rootCmd.AddCommand(migrateCmd)
}
// TODO: add args to run migrations up or down, until a certain point etc
// Rollback until
// list -> Essentially just show the table, maybe with an extra column if the migration did run or not
var migrateCmd = &cobra.Command{
Use: "migrate",
Short: "Run all database migrations which didn't already run.",
Run: func(cmd *cobra.Command, args []string) {
migration.Migrate(nil)
},
}
var migrateListCmd = &cobra.Command{
Use: "list",
Short: "Show a list with all database migrations.",
Run: func(cmd *cobra.Command, args []string) {
migration.ListMigrations()
},
}
var rollbackUntilFlag string
var migrationRollbackCmd = &cobra.Command{
Use: "rollback",
Short: "Roll migrations back until a certain point.",
Run: func(cmd *cobra.Command, args []string) {
migration.Rollback(rollbackUntilFlag)
},
}

View file

@ -19,6 +19,7 @@ package cmd
import (
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/mail"
"code.vikunja.io/api/pkg/migration"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/routes"
"code.vikunja.io/api/pkg/swagger"
@ -43,6 +44,9 @@ var webCmd = &cobra.Command{
// Set logger
log.InitLogger()
// Run the migrations
migration.Migrate(nil)
// Set Engine
err := models.SetEngine()
if err != nil {

78
pkg/db/db.go Normal file
View file

@ -0,0 +1,78 @@
// Vikunja is a todo-list application to facilitate your life.
// Copyright 2019 Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package db
import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/log"
"fmt"
"github.com/go-xorm/core"
"github.com/go-xorm/xorm"
"github.com/spf13/viper"
_ "github.com/go-sql-driver/mysql" // Because.
_ "github.com/mattn/go-sqlite3" // Because.
)
// CreateDBEngine initializes a db engine from the config
func CreateDBEngine() (engine *xorm.Engine, err error) {
// If the database type is not set, this likely means we need to initialize the config first
if viper.GetString("database.type") == "" {
config.InitConfig()
}
// Use Mysql if set
if viper.GetString("database.type") == "mysql" {
engine, err = initMysqlEngine()
if err != nil {
return
}
} else {
// Otherwise use sqlite
engine, err = initSqliteEngine()
if err != nil {
return
}
}
engine.SetMapper(core.GonicMapper{})
engine.ShowSQL(viper.GetString("log.database") != "off")
engine.SetLogger(xorm.NewSimpleLogger(log.GetLogWriter("database")))
return
}
func initMysqlEngine() (engine *xorm.Engine, err error) {
connStr := fmt.Sprintf(
"%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true",
viper.GetString("database.user"),
viper.GetString("database.password"),
viper.GetString("database.host"),
viper.GetString("database.database"))
engine, err = xorm.NewEngine("mysql", connStr)
engine.SetMaxOpenConns(viper.GetInt("database.openconnections"))
return
}
func initSqliteEngine() (engine *xorm.Engine, err error) {
path := viper.GetString("database.path")
if path == "" {
path = "./db.db"
}
return xorm.NewEngine("sqlite3", path)
}

View file

@ -0,0 +1,44 @@
// Vikunja is a todo-list application to facilitate your life.
// Copyright 2019 Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package migration
import (
"github.com/go-xorm/xorm"
"src.techknowlogick.com/xormigrate"
)
// Used for rollback
type tasksReminderDateMigration20190324205606 struct {
ReminderUnix int64 `xorm:"int(11) INDEX"`
}
func (tasksReminderDateMigration20190324205606) TableName() string {
return "tasks"
}
func init() {
migrations = append(migrations, &xormigrate.Migration{
ID: "20190324205606",
Description: "Remove reminders_unix from tasks",
Migrate: func(tx *xorm.Engine) error {
return dropTableColum(tx, "tasks", "reminders_unix")
},
Rollback: func(tx *xorm.Engine) error {
return tx.Sync2(tasksReminderDateMigration20190324205606{})
},
})
}

View file

@ -0,0 +1,44 @@
// Vikunja is a todo-list application to facilitate your life.
// Copyright 2019 Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package migration
import (
"github.com/go-xorm/xorm"
"src.techknowlogick.com/xormigrate"
)
// Used for rollback
type teamMembersMigration20190328074430 struct {
Updated int64 `xorm:"updated"`
}
func (teamMembersMigration20190328074430) TableName() string {
return "team_members"
}
func init() {
migrations = append(migrations, &xormigrate.Migration{
ID: "20190328074430",
Description: "Remove updated from team_members",
Migrate: func(tx *xorm.Engine) error {
return dropTableColum(tx, "team_members", "updated")
},
Rollback: func(tx *xorm.Engine) error {
return tx.Sync2(teamMembersMigration20190328074430{})
},
})
}

124
pkg/migration/migration.go Normal file
View file

@ -0,0 +1,124 @@
// Vikunja is a todo-list application to facilitate your life.
// Copyright 2019 Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package migration
import (
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
"github.com/go-xorm/xorm"
"github.com/olekukonko/tablewriter"
"github.com/spf13/viper"
"os"
"sort"
"src.techknowlogick.com/xormigrate"
)
// You can get the id string for new migrations by running `date +%Y%m%d%H%M%S` on a unix system.
var migrations []*xormigrate.Migration
// A helper function because we need a migration in various places which we can't really solve with an init() function.
func initMigration(x *xorm.Engine) *xormigrate.Xormigrate {
// Get our own xorm engine if we don't have one
if x == nil {
var err error
x, err = db.CreateDBEngine()
if err != nil {
log.Log.Criticalf("Could not connect to db: %v", err.Error())
return nil
}
}
// Because init() does not guarantee the order in which these are added to the slice,
// we need to sort them to ensure that they are in order
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].ID < migrations[j].ID
})
m := xormigrate.New(x, migrations)
m.NewLogger(log.GetLogWriter("database"))
m.InitSchema(initSchema)
return m
}
// Migrate runs all migrations
func Migrate(x *xorm.Engine) {
m := initMigration(x)
err := m.Migrate()
if err != nil {
log.Log.Fatalf("Migration failed: %v", err)
}
log.Log.Info("Ran all migrations successfully.")
}
// ListMigrations pretty-prints a list with all migrations.
func ListMigrations() {
x, err := db.CreateDBEngine()
if err != nil {
log.Log.Fatalf("Could not connect to db: %v", err.Error())
}
ms := []*xormigrate.Migration{}
err = x.Find(&ms)
if err != nil {
log.Log.Fatalf("Error getting migration table: %v", err.Error())
}
table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"ID", "Description"})
table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeaderColor(tablewriter.Colors{tablewriter.Bold, tablewriter.BgGreenColor},
tablewriter.Colors{tablewriter.Bold, tablewriter.BgGreenColor})
for _, m := range ms {
table.Append([]string{m.ID, m.Description})
}
table.Render()
}
// Rollback rolls back all migrations until a certain point.
func Rollback(migrationID string) {
m := initMigration(nil)
err := m.RollbackTo(migrationID)
if err != nil {
log.Log.Fatalf("Could not rollback: %v", err)
}
log.Log.Info("Rolled back successfully.")
}
// Deletes a column from a table. All arguments are strings, to let them be standalone and not depending on any struct.
func dropTableColum(x *xorm.Engine, tableName, col string) error {
switch viper.GetString("database.type") {
case "sqlite":
log.Log.Warning("Unable to drop columns in SQLite")
case "mysql":
_, err := x.Exec("ALTER TABLE " + tableName + " DROP COLUMN " + col)
if err != nil {
return err
}
default:
log.Log.Fatal("Unknown db.")
}
return nil
}
func initSchema(tx *xorm.Engine) error {
return tx.Sync2(
models.GetTables()...,
)
}

View file

@ -1,3 +1,4 @@
- id: 1
task_id: 1
label_id: 4
label_id: 4
created: 0

View file

@ -1,12 +1,20 @@
- id: 1
title: 'Label #1'
created_by_id: 1
updated: 0
created: 0
- id: 2
title: 'Label #2'
created_by_id: 1
updated: 0
created: 0
- id: 3
title: 'Label #3 - other user'
created_by_id: 2
updated: 0
created: 0
- id: 4
title: 'Label #4 - visible via other task'
created_by_id: 2
created_by_id: 2
updated: 0
created: 0

View file

@ -4,27 +4,37 @@
description: Lorem Ipsum
owner_id: 1
namespace_id: 1
updated: 0
created: 0
-
id: 2
title: Test2
description: Lorem Ipsum
owner_id: 3
namespace_id: 1
updated: 0
created: 0
-
id: 3
title: Test3
description: Lorem Ipsum
owner_id: 3
namespace_id: 2
updated: 0
created: 0
-
id: 4
title: Test4
description: Lorem Ipsum
owner_id: 3
namespace_id: 3
updated: 0
created: 0
-
id: 5
title: Test5
description: Lorem Ipsum
owner_id: 5
namespace_id: 5
namespace_id: 5
updated: 0
created: 0

View file

@ -3,13 +3,19 @@
name: testnamespace
description: Lorem Ipsum
owner_id: 1
updated: 0
created: 0
-
id: 2
name: testnamespace2
description: Lorem Ipsum
owner_id: 2
updated: 0
created: 0
-
id: 3
name: testnamespace3
description: Lorem Ipsum
owner_id: 3
owner_id: 3
updated: 0
created: 0

View file

@ -1,6 +1,10 @@
- id: 1
team_id: 1
list_id: 3
updated: 0
created: 0
- id: 2
team_id: 2
list_id: 3
updated: 0
created: 0

View file

@ -2,6 +2,8 @@
team_id: 1
user_id: 1
admin: true
created: 0
-
team_id: 1
user_id: 2
created: 0

View file

@ -1,6 +1,10 @@
- id: 1
team_id: 1
namespace_id: 3
updated: 0
created: 0
- id: 2
team_id: 2
namespace_id: 3
updated: 0
created: 0

View file

@ -3,26 +3,36 @@
username: 'user1'
password: '1234'
email: 'user1@example.com'
updated: 0
created: 0
-
id: 2
username: 'user2'
password: '1234'
email: 'user2@example.com'
updated: 0
created: 0
-
id: 3
username: 'user3'
password: '1234'
email: 'user3@example.com'
updated: 0
created: 0
-
id: 4
username: 'user4'
password: '1234'
email: 'user4@example.com'
email_confirm_token: tiepiQueed8ahc7zeeFe1eveiy4Ein8osooxegiephauph2Ael
updated: 0
created: 0
-
id: 5
username: 'user5'
password: '1234'
email: 'user4@example.com'
email_confirm_token: tiepiQueed8ahc7zeeFe1eveiy4Ein8osooxegiephauph2Ael
is_active: false
is_active: false
updated: 0
created: 0

View file

@ -1,6 +1,10 @@
- id: 1
user_id: 1
list_id: 3
updated: 0
created: 0
- id: 2
user_id: 2
list_id: 3
updated: 0
created: 0

View file

@ -1,6 +1,10 @@
- id: 1
user_id: 1
namespace_id: 3
updated: 0
created: 0
- id: 2
user_id: 2
namespace_id: 3
updated: 0
created: 0

View file

@ -27,18 +27,18 @@ type Label struct {
// The title of the lable. You'll see this one on tasks associated with it.
Title string `xorm:"varchar(250) not null" json:"title" valid:"runelength(3|250)" minLength:"3" maxLength:"250"`
// The label description.
Description string `xorm:"varchar(250)" json:"description" valid:"runelength(0|250)" maxLength:"250"`
Description string `xorm:"varchar(250) null" json:"description" valid:"runelength(0|250)" maxLength:"250"`
// The color this label has
HexColor string `xorm:"varchar(6)" json:"hex_color" valid:"runelength(0|6)" maxLength:"6"`
HexColor string `xorm:"varchar(6) null" json:"hex_color" valid:"runelength(0|6)" maxLength:"6"`
CreatedByID int64 `xorm:"int(11) not null" json:"-"`
// The user who created this label
CreatedBy *User `xorm:"-" json:"created_by"`
// A unix timestamp when this label was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this label was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -29,7 +29,7 @@ type LabelTask struct {
// The label id you want to associate with a task.
LabelID int64 `xorm:"int(11) INDEX not null" json:"label_id" param:"label"`
// A unix timestamp when this task was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -25,11 +25,11 @@ type List struct {
// The unique, numeric id of this list.
ID int64 `xorm:"int(11) autoincr not null unique pk" json:"id" param:"list"`
// The title of the list. You'll see this in the namespace overview.
Title string `xorm:"varchar(250)" json:"title" valid:"required,runelength(3|250)" minLength:"3" maxLength:"250"`
Title string `xorm:"varchar(250) not null" json:"title" valid:"required,runelength(3|250)" minLength:"3" maxLength:"250"`
// The description of the list.
Description string `xorm:"varchar(1000)" json:"description" valid:"runelength(0|1000)" maxLength:"1000"`
OwnerID int64 `xorm:"int(11) INDEX" json:"-"`
NamespaceID int64 `xorm:"int(11) INDEX" json:"-" param:"namespace"`
Description string `xorm:"varchar(1000) null" json:"description" valid:"runelength(0|1000)" maxLength:"1000"`
OwnerID int64 `xorm:"int(11) INDEX not null" json:"-"`
NamespaceID int64 `xorm:"int(11) INDEX not null" json:"-" param:"namespace"`
// The user who created this list.
Owner User `xorm:"-" json:"owner" valid:"-"`
@ -37,9 +37,9 @@ type List struct {
Tasks []*ListTask `xorm:"-" json:"tasks"`
// A unix timestamp when this list was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this list was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -25,7 +25,7 @@ type ListTaskAssginee struct {
ID int64 `xorm:"int(11) autoincr not null unique pk" json:"-"`
TaskID int64 `xorm:"int(11) INDEX not null" json:"-" param:"listtask"`
UserID int64 `xorm:"int(11) INDEX not null" json:"user_id" param:"user"`
Created int64 `xorm:"created"`
Created int64 `xorm:"created not null"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -26,28 +26,28 @@ type ListTask struct {
// The unique, numeric id of this task.
ID int64 `xorm:"int(11) autoincr not null unique pk" json:"id" param:"listtask"`
// The task text. This is what you'll see in the list.
Text string `xorm:"varchar(250)" json:"text" valid:"runelength(3|250)" minLength:"3" maxLength:"250"`
Text string `xorm:"varchar(250) not null" json:"text" valid:"runelength(3|250)" minLength:"3" maxLength:"250"`
// The task description.
Description string `xorm:"varchar(250)" json:"description" valid:"runelength(0|250)" maxLength:"250"`
// Whether a task is done or not.
Done bool `xorm:"INDEX" json:"done"`
Done bool `xorm:"INDEX null" json:"done"`
// A unix timestamp when the task is due.
DueDateUnix int64 `xorm:"int(11) INDEX" json:"dueDate"`
DueDateUnix int64 `xorm:"int(11) INDEX null" json:"dueDate"`
// An array of unix timestamps when the user wants to be reminded of the task.
RemindersUnix []int64 `xorm:"JSON TEXT" json:"reminderDates"`
CreatedByID int64 `xorm:"int(11)" json:"-"` // ID of the user who put that task on the list
RemindersUnix []int64 `xorm:"JSON TEXT null" json:"reminderDates"`
CreatedByID int64 `xorm:"int(11) not null" json:"-"` // ID of the user who put that task on the list
// The list this task belongs to.
ListID int64 `xorm:"int(11) INDEX" json:"-" param:"list"`
ListID int64 `xorm:"int(11) INDEX not null" json:"-" param:"list"`
// An amount in seconds this task repeats itself. If this is set, when marking the task as done, it will mark itself as "undone" and then increase all remindes and the due date by its amount.
RepeatAfter int64 `xorm:"int(11) INDEX" json:"repeatAfter"`
RepeatAfter int64 `xorm:"int(11) INDEX null" json:"repeatAfter"`
// If the task is a subtask, this is the id of its parent.
ParentTaskID int64 `xorm:"int(11) INDEX" json:"parentTaskID"`
ParentTaskID int64 `xorm:"int(11) INDEX null" json:"parentTaskID"`
// The task priority. Can be anything you want, it is possible to sort by this later.
Priority int64 `xorm:"int(11)" json:"priority"`
Priority int64 `xorm:"int(11) null" json:"priority"`
// When this task starts.
StartDateUnix int64 `xorm:"int(11) INDEX" json:"startDate" query:"-"`
StartDateUnix int64 `xorm:"int(11) INDEX null" json:"startDate" query:"-"`
// When this task ends.
EndDateUnix int64 `xorm:"int(11) INDEX" json:"endDate" query:"-"`
EndDateUnix int64 `xorm:"int(11) INDEX null" json:"endDate" query:"-"`
// An array of users who are assigned to this task
Assignees []*User `xorm:"-" json:"assignees"`
// An array of labels which are associated with this task.
@ -61,9 +61,9 @@ type ListTask struct {
Subtasks []*ListTask `xorm:"-" json:"subtasks"`
// A unix timestamp when this task was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this task was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
// The user who initially created the task.
CreatedBy User `xorm:"-" json:"createdBy" valid:"-"`

View file

@ -27,12 +27,12 @@ type ListUser struct {
// The list id.
ListID int64 `xorm:"int(11) not null INDEX" json:"-" param:"list"`
// The right this user has. 0 = Read only, 1 = Read & Write, 2 = Admin. See the docs for more details.
Right Right `xorm:"int(11) INDEX" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
Right Right `xorm:"int(11) INDEX null" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
// A unix timestamp when this relation was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this relation was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -17,11 +17,10 @@
package models
import (
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"encoding/gob"
"fmt"
_ "github.com/go-sql-driver/mysql" // Because.
"github.com/go-xorm/core"
"github.com/go-xorm/xorm"
xrc "github.com/go-xorm/xorm-redis-cache"
_ "github.com/mattn/go-sqlite3" // Because.
@ -30,51 +29,11 @@ import (
var (
x *xorm.Engine
tables []interface{}
tablesWithPointer []interface{}
)
func getEngine() (*xorm.Engine, error) {
// Use Mysql if set
if viper.GetString("database.type") == "mysql" {
connStr := fmt.Sprintf(
"%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true",
viper.GetString("database.user"),
viper.GetString("database.password"),
viper.GetString("database.host"),
viper.GetString("database.database"))
e, err := xorm.NewEngine("mysql", connStr)
e.SetMaxOpenConns(viper.GetInt("database.openconnections"))
return e, err
}
// Otherwise use sqlite
path := viper.GetString("database.path")
if path == "" {
path = "./db.db"
}
return xorm.NewEngine("sqlite3", path)
}
func init() {
tables = append(tables,
new(User),
new(List),
new(ListTask),
new(Team),
new(TeamMember),
new(TeamList),
new(TeamNamespace),
new(Namespace),
new(ListUser),
new(NamespaceUser),
new(ListTaskAssginee),
new(Label),
new(LabelTask),
)
tablesWithPointer = append(tables,
// GetTables returns all structs which are also a table.
func GetTables() []interface{} {
return []interface{}{
&User{},
&List{},
&ListTask{},
@ -88,17 +47,19 @@ func init() {
&ListTaskAssginee{},
&Label{},
&LabelTask{},
)
}
}
// SetEngine sets the xorm.Engine
func SetEngine() (err error) {
x, err = getEngine()
x, err = db.CreateDBEngine()
if err != nil {
return fmt.Errorf("failed to connect to database: %v", err)
log.Log.Criticalf("Could not connect to db: %v", err.Error())
return
}
// Cache
// We have to initialize the cache here to avoid import cycles
if viper.GetBool("cache.enabled") {
switch viper.GetString("cache.type") {
case "memory":
@ -107,23 +68,12 @@ func SetEngine() (err error) {
case "redis":
cacher := xrc.NewRedisCacher(viper.GetString("redis.host"), viper.GetString("redis.password"), xrc.DEFAULT_EXPIRATION, x.Logger())
x.SetDefaultCacher(cacher)
gob.Register(tables)
gob.Register(tablesWithPointer) // Need to register tables with pointer as well...
gob.Register(GetTables())
default:
log.Log.Info("Did not find a valid cache type. Caching disabled. Please refer to the docs for poosible cache types.")
}
}
x.SetMapper(core.GonicMapper{})
// Sync dat shit
if err = x.StoreEngine("InnoDB").Sync2(tables...); err != nil {
return fmt.Errorf("sync database struct error: %v", err)
}
x.ShowSQL(viper.GetString("log.database") != "off")
x.SetLogger(xorm.NewSimpleLogger(log.GetLogWriter("database")))
return nil
}

View file

@ -26,18 +26,18 @@ type Namespace struct {
// The unique, numeric id of this namespace.
ID int64 `xorm:"int(11) autoincr not null unique pk" json:"id" param:"namespace"`
// The name of this namespace.
Name string `xorm:"varchar(250)" json:"name" valid:"required,runelength(5|250)" minLength:"5" maxLength:"250"`
Name string `xorm:"varchar(250) not null" json:"name" valid:"required,runelength(5|250)" minLength:"5" maxLength:"250"`
// The description of the namespace
Description string `xorm:"varchar(1000)" json:"description" valid:"runelength(0|250)" maxLength:"250"`
Description string `xorm:"varchar(1000) null" json:"description" valid:"runelength(0|250)" maxLength:"250"`
OwnerID int64 `xorm:"int(11) not null INDEX" json:"-"`
// The user who owns this namespace
Owner User `xorm:"-" json:"owner" valid:"-"`
// A unix timestamp when this namespace was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this namespace was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`
@ -189,7 +189,7 @@ func (n *Namespace) ReadAll(search string, a web.Auth, page int) (interface{}, e
// Get all lists
lists := []*List{}
err = x.Table(&lists).
err = x.
In("namespace_id", namespaceids).
Find(&lists)
if err != nil {

View file

@ -27,12 +27,12 @@ type NamespaceUser struct {
// The namespace id
NamespaceID int64 `xorm:"int(11) not null INDEX" json:"-" param:"namespace"`
// The right this user has. 0 = Read only, 1 = Read & Write, 2 = Admin. See the docs for more details.
Right Right `xorm:"int(11) INDEX" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
Right Right `xorm:"int(11) INDEX null" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
// A unix timestamp when this relation was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this relation was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -27,12 +27,12 @@ type TeamList struct {
// The list id.
ListID int64 `xorm:"int(11) not null INDEX" json:"-" param:"list"`
// The right this team has. 0 = Read only, 1 = Read & Write, 2 = Admin. See the docs for more details.
Right Right `xorm:"int(11) INDEX" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
Right Right `xorm:"int(11) INDEX null" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
// A unix timestamp when this relation was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this relation was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -27,12 +27,12 @@ type TeamNamespace struct {
// The namespace id.
NamespaceID int64 `xorm:"int(11) not null INDEX" json:"-" param:"namespace"`
// The right this team has. 0 = Read only, 1 = Read & Write, 2 = Admin. See the docs for more details.
Right Right `xorm:"int(11) INDEX" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
Right Right `xorm:"int(11) INDEX null" json:"right" valid:"length(0|2)" maximum:"2" default:"0"`
// A unix timestamp when this relation was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this relation was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -69,10 +69,10 @@ type TeamMember struct {
// The id of the member.
UserID int64 `xorm:"int(11) not null INDEX" json:"userID" param:"user"`
// Whether or not the member is an admin of the team. See the docs for more about what a team admin can do
Admin bool `xorm:"tinyint(1) INDEX" json:"admin"`
Admin bool `xorm:"tinyint(1) INDEX null" json:"admin"`
// A unix timestamp when this relation was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
web.CRUDable `xorm:"-" json:"-"`
web.Rights `xorm:"-" json:"-"`

View file

@ -17,6 +17,7 @@
package models
import (
"code.vikunja.io/api/pkg/config"
_ "code.vikunja.io/api/pkg/config" // To trigger its init() which initializes the config
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/mail"
@ -51,12 +52,23 @@ func MainTest(m *testing.M, pathToRoot string) {
func createTestEngine(fixturesDir string) error {
var err error
var fixturesHelper testfixtures.Helper = &testfixtures.SQLite{}
// If set, use the config we provided instead of normal
if os.Getenv("VIKUNJA_TESTS_USE_CONFIG") == "1" {
config.InitConfig()
err = SetEngine()
if err != nil {
return err
}
err = x.Sync2(GetTables()...)
if err != nil {
return err
}
if viper.GetString("database.type") == "mysql" {
fixturesHelper = &testfixtures.MySQL{}
}
} else {
x, err = xorm.NewEngine("sqlite3", "file::memory:?cache=shared")
if err != nil {
@ -66,7 +78,7 @@ func createTestEngine(fixturesDir string) error {
x.SetMapper(core.GonicMapper{})
// Sync dat shit
if err := x.StoreEngine("InnoDB").Sync2(tables...); err != nil {
if err := x.Sync2(GetTables()...); err != nil {
return fmt.Errorf("sync database struct error: %v", err)
}
@ -76,13 +88,6 @@ func createTestEngine(fixturesDir string) error {
}
}
var fixturesHelper testfixtures.Helper
if viper.GetString("database.type") == "mysql" {
fixturesHelper = &testfixtures.MySQL{}
} else {
fixturesHelper = &testfixtures.SQLite{}
}
return InitFixtures(fixturesHelper, fixturesDir)
}

View file

@ -44,16 +44,16 @@ type User struct {
Username string `xorm:"varchar(250) not null unique" json:"username" valid:"length(3|250)" minLength:"3" maxLength:"250"`
Password string `xorm:"varchar(250) not null" json:"-"`
// The user's email address.
Email string `xorm:"varchar(250)" json:"email" valid:"email,length(0|250)" maxLength:"250"`
IsActive bool `json:"-"`
Email string `xorm:"varchar(250) null" json:"email" valid:"email,length(0|250)" maxLength:"250"`
IsActive bool `xorm:"null" json:"-"`
PasswordResetToken string `xorm:"varchar(450)" json:"-"`
EmailConfirmToken string `xorm:"varchar(450)" json:"-"`
PasswordResetToken string `xorm:"varchar(450) null" json:"-"`
EmailConfirmToken string `xorm:"varchar(450) null" json:"-"`
// A unix timestamp when this task was created. You cannot change this value.
Created int64 `xorm:"created" json:"created"`
Created int64 `xorm:"created not null" json:"created"`
// A unix timestamp when this task was last updated. You cannot change this value.
Updated int64 `xorm:"updated" json:"updated"`
Updated int64 `xorm:"updated not null" json:"updated"`
web.Auth `xorm:"-" json:"-"`
}

37
vendor/github.com/go-xorm/builder/.drone.yml generated vendored Normal file
View file

@ -0,0 +1,37 @@
workspace:
base: /go
path: src/github.com/go-xorm/builder
clone:
git:
image: plugins/git:next
depth: 50
tags: true
matrix:
GO_VERSION:
- 1.8
- 1.9
- 1.10
- 1.11
pipeline:
test:
image: golang:${GO_VERSION}
commands:
- go get -u github.com/golang/lint/golint
- go get -u github.com/stretchr/testify/assert
- go get -u github.com/go-xorm/sqlfiddle
- golint ./...
- go test -v -race -coverprofile=coverage.txt -covermode=atomic
when:
event: [ push, tag, pull_request ]
codecov:
image: robertstettner/drone-codecov
group: build
secrets: [ codecov_token ]
files:
- coverage.txt
when:
event: [ push, pull_request ]

View file

@ -1,26 +1,47 @@
# SQL builder
[![CircleCI](https://circleci.com/gh/go-xorm/builder/tree/master.svg?style=svg)](https://circleci.com/gh/go-xorm/builder/tree/master)
[![GitCI.cn](https://gitci.cn/api/badges/go-xorm/builder/status.svg)](https://gitci.cn/go-xorm/builder) [![codecov](https://codecov.io/gh/go-xorm/builder/branch/master/graph/badge.svg)](https://codecov.io/gh/go-xorm/builder)
[![](https://goreportcard.com/badge/github.com/go-xorm/builder)](https://goreportcard.com/report/github.com/go-xorm/builder)
Package builder is a lightweight and fast SQL builder for Go and XORM.
Make sure you have installed Go 1.1+ and then:
Make sure you have installed Go 1.8+ and then:
go get github.com/go-xorm/builder
# Insert
```Go
sql, args, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
sql, args, err := builder.Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
// INSERT INTO table1 SELECT * FROM table2
sql, err := builder.Insert().Into("table1").Select().From("table2").ToBoundSQL()
// INSERT INTO table1 (a, b) SELECT b, c FROM table2
sql, err = builder.Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
```
# Select
```Go
// Simple Query
sql, args, err := Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
// With join
sql, args, err = Select("c, d").From("table1").LeftJoin("table2", Eq{"table1.id": 1}.And(Lt{"table2.id": 3})).
RightJoin("table3", "table2.id = table3.tid").Where(Eq{"a": 1}).ToSQL()
// From sub query
sql, args, err := Select("sub.id").From(Select("c").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
// From union query
sql, args, err = Select("sub.id").From(
Select("id").From("table1").Where(Eq{"a": 1}).Union("all", Select("id").From("table1").Where(Eq{"a": 2})),"sub").
Where(Eq{"b": 1}).ToSQL()
// With order by
sql, args, err = Select("a", "b", "c").From("table1").Where(Eq{"f1": "v1", "f2": "v2"}).
OrderBy("a ASC").ToSQL()
// With limit.
// Be careful! You should set up specific dialect for builder before performing a query with LIMIT
sql, args, err = Dialect(MYSQL).Select("a", "b", "c").From("table1").OrderBy("a ASC").
Limit(5, 10).ToSQL()
```
# Update
@ -35,6 +56,16 @@ sql, args, err := Update(Eq{"a": 2}).From("table1").Where(Eq{"a": 1}).ToSQL()
sql, args, err := Delete(Eq{"a": 1}).From("table1").ToSQL()
```
# Union
```Go
sql, args, err := Select("*").From("a").Where(Eq{"status": "1"}).
Union("all", Select("*").From("a").Where(Eq{"status": "2"})).
Union("distinct", Select("*").From("a").Where(Eq{"status": "3"})).
Union("", Select("*").From("a").Where(Eq{"status": "4"})).
ToSQL()
```
# Conditions
* `Eq` is a redefine of a map, you can give one or more conditions to `Eq`

View file

@ -4,6 +4,12 @@
package builder
import (
sql2 "database/sql"
"fmt"
"sort"
)
type optype byte
const (
@ -12,6 +18,15 @@ const (
insertType // insert
updateType // update
deleteType // delete
unionType // union
)
const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
MSSQL = "mssql"
ORACLE = "oracle"
)
type join struct {
@ -20,60 +35,115 @@ type join struct {
joinCond Cond
}
type union struct {
unionType string
builder *Builder
}
type limit struct {
limitN int
offset int
}
// Builder describes a SQL statement
type Builder struct {
optype
tableName string
cond Cond
selects []string
joins []join
inserts Eq
updates []Eq
dialect string
isNested bool
into string
from string
subQuery *Builder
cond Cond
selects []string
joins []join
unions []union
limitation *limit
insertCols []string
insertVals []interface{}
updates []Eq
orderBy string
groupBy string
having string
}
// Select creates a select Builder
func Select(cols ...string) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Select(cols...)
// Dialect sets the db dialect of Builder.
func Dialect(dialect string) *Builder {
builder := &Builder{cond: NewCond(), dialect: dialect}
return builder
}
// Insert creates an insert Builder
func Insert(eq Eq) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Insert(eq)
// MySQL is shortcut of Dialect(MySQL)
func MySQL() *Builder {
return Dialect(MYSQL)
}
// Update creates an update Builder
func Update(updates ...Eq) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Update(updates...)
// MsSQL is shortcut of Dialect(MsSQL)
func MsSQL() *Builder {
return Dialect(MSSQL)
}
// Delete creates a delete Builder
func Delete(conds ...Cond) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Delete(conds...)
// Oracle is shortcut of Dialect(Oracle)
func Oracle() *Builder {
return Dialect(ORACLE)
}
// Postgres is shortcut of Dialect(Postgres)
func Postgres() *Builder {
return Dialect(POSTGRES)
}
// SQLite is shortcut of Dialect(SQLITE)
func SQLite() *Builder {
return Dialect(SQLITE)
}
// Where sets where SQL
func (b *Builder) Where(cond Cond) *Builder {
b.cond = b.cond.And(cond)
if b.cond.IsValid() {
b.cond = b.cond.And(cond)
} else {
b.cond = cond
}
return b
}
// From sets the table name
func (b *Builder) From(tableName string) *Builder {
b.tableName = tableName
// From sets from subject(can be a table name in string or a builder pointer) and its alias
func (b *Builder) From(subject interface{}, alias ...string) *Builder {
switch subject.(type) {
case *Builder:
b.subQuery = subject.(*Builder)
if len(alias) > 0 {
b.from = alias[0]
} else {
b.isNested = true
}
case string:
b.from = subject.(string)
if len(alias) > 0 {
b.from = b.from + " " + alias[0]
}
}
return b
}
// TableName returns the table name
func (b *Builder) TableName() string {
if b.optype == insertType {
return b.into
}
return b.from
}
// Into sets insert table name
func (b *Builder) Into(tableName string) *Builder {
b.tableName = tableName
b.into = tableName
return b
}
// Join sets join table and contions
// Join sets join table and conditions
func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builder {
switch joinCond.(type) {
case Cond:
@ -85,6 +155,50 @@ func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builde
return b
}
// Union sets union conditions
func (b *Builder) Union(unionTp string, unionCond *Builder) *Builder {
var builder *Builder
if b.optype != unionType {
builder = &Builder{cond: NewCond()}
builder.optype = unionType
builder.dialect = b.dialect
builder.selects = b.selects
currentUnions := b.unions
// erase sub unions (actually append to new Builder.unions)
b.unions = nil
for e := range currentUnions {
currentUnions[e].builder.dialect = b.dialect
}
builder.unions = append(append(builder.unions, union{"", b}), currentUnions...)
} else {
builder = b
}
if unionCond != nil {
if unionCond.dialect == "" && builder.dialect != "" {
unionCond.dialect = builder.dialect
}
builder.unions = append(builder.unions, union{unionTp, unionCond})
}
return builder
}
// Limit sets limitN condition
func (b *Builder) Limit(limitN int, offset ...int) *Builder {
b.limitation = &limit{limitN: limitN}
if len(offset) > 0 {
b.limitation.offset = offset[0]
}
return b
}
// InnerJoin sets inner join
func (b *Builder) InnerJoin(joinTable string, joinCond interface{}) *Builder {
return b.Join("INNER", joinTable, joinCond)
@ -113,7 +227,9 @@ func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
// Select sets select SQL
func (b *Builder) Select(cols ...string) *Builder {
b.selects = cols
b.optype = selectType
if b.optype == condType {
b.optype = selectType
}
return b
}
@ -130,15 +246,52 @@ func (b *Builder) Or(cond Cond) *Builder {
}
// Insert sets insert SQL
func (b *Builder) Insert(eq Eq) *Builder {
b.inserts = eq
func (b *Builder) Insert(eq ...interface{}) *Builder {
if len(eq) > 0 {
var paramType = -1
for _, e := range eq {
switch t := e.(type) {
case Eq:
if paramType == -1 {
paramType = 0
}
if paramType != 0 {
break
}
for k, v := range t {
b.insertCols = append(b.insertCols, k)
b.insertVals = append(b.insertVals, v)
}
case string:
if paramType == -1 {
paramType = 1
}
if paramType != 1 {
break
}
b.insertCols = append(b.insertCols, t)
}
}
}
if len(b.insertCols) == len(b.insertVals) {
sort.Slice(b.insertVals, func(i, j int) bool {
return b.insertCols[i] < b.insertCols[j]
})
sort.Strings(b.insertCols)
}
b.optype = insertType
return b
}
// Update sets update SQL
func (b *Builder) Update(updates ...Eq) *Builder {
b.updates = updates
b.updates = make([]Eq, 0, len(updates))
for _, update := range updates {
if update.IsValid() {
b.updates = append(b.updates, update)
}
}
b.optype = updateType
return b
}
@ -153,8 +306,8 @@ func (b *Builder) Delete(conds ...Cond) *Builder {
// WriteTo implements Writer interface
func (b *Builder) WriteTo(w Writer) error {
switch b.optype {
case condType:
return b.cond.WriteTo(w)
/*case condType:
return b.cond.WriteTo(w)*/
case selectType:
return b.selectWriteTo(w)
case insertType:
@ -163,6 +316,8 @@ func (b *Builder) WriteTo(w Writer) error {
return b.updateWriteTo(w)
case deleteType:
return b.deleteWriteTo(w)
case unionType:
return b.unionWriteTo(w)
}
return ErrNotSupportType
@ -175,16 +330,48 @@ func (b *Builder) ToSQL() (string, []interface{}, error) {
return "", nil, err
}
return w.writer.String(), w.args, nil
// in case of sql.NamedArg in args
for e := range w.args {
if namedArg, ok := w.args[e].(sql2.NamedArg); ok {
w.args[e] = namedArg.Value
}
}
var sql = w.writer.String()
var err error
switch b.dialect {
case ORACLE, MSSQL:
// This is for compatibility with different sql drivers
for e := range w.args {
w.args[e] = sql2.Named(fmt.Sprintf("p%d", e+1), w.args[e])
}
var prefix string
if b.dialect == ORACLE {
prefix = ":p"
} else {
prefix = "@p"
}
if sql, err = ConvertPlaceholder(sql, prefix); err != nil {
return "", nil, err
}
case POSTGRES:
if sql, err = ConvertPlaceholder(sql, "$"); err != nil {
return "", nil, err
}
}
return sql, w.args, nil
}
// ToSQL convert a builder or condtions to SQL and args
func ToSQL(cond interface{}) (string, []interface{}, error) {
switch cond.(type) {
case Cond:
return condToSQL(cond.(Cond))
case *Builder:
return cond.(*Builder).ToSQL()
// ToBoundSQL
func (b *Builder) ToBoundSQL() (string, error) {
w := NewWriter()
if err := b.WriteTo(w); err != nil {
return "", err
}
return "", nil, ErrNotSupportType
return ConvertToBoundSQL(w.writer.String(), w.args)
}

View file

@ -5,16 +5,21 @@
package builder
import (
"errors"
"fmt"
)
// Delete creates a delete Builder
func Delete(conds ...Cond) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Delete(conds...)
}
func (b *Builder) deleteWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
return errors.New("no table indicated")
if len(b.from) <= 0 {
return ErrNoTableName
}
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil {
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil {
return err
}

View file

@ -6,37 +6,63 @@ package builder
import (
"bytes"
"errors"
"fmt"
)
func (b *Builder) insertWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
return errors.New("no table indicated")
}
if len(b.inserts) <= 0 {
return errors.New("no column to be update")
// Insert creates an insert Builder
func Insert(eq ...interface{}) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Insert(eq...)
}
func (b *Builder) insertSelectWriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil {
return err
}
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil {
if len(b.insertCols) > 0 {
fmt.Fprintf(w, "(")
for _, col := range b.insertCols {
fmt.Fprintf(w, col)
}
fmt.Fprintf(w, ") ")
}
return b.selectWriteTo(w)
}
func (b *Builder) insertWriteTo(w Writer) error {
if len(b.into) <= 0 {
return ErrNoTableName
}
if len(b.insertCols) <= 0 && b.from == "" {
return ErrNoColumnToInsert
}
if b.into != "" && b.from != "" {
return b.insertSelectWriteTo(w)
}
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil {
return err
}
var args = make([]interface{}, 0)
var bs []byte
var valBuffer = bytes.NewBuffer(bs)
var i = 0
for col, value := range b.inserts {
for i, col := range b.insertCols {
value := b.insertVals[i]
fmt.Fprint(w, col)
if e, ok := value.(expr); ok {
fmt.Fprint(valBuffer, e.sql)
fmt.Fprintf(valBuffer, "(%s)", e.sql)
args = append(args, e.args...)
} else {
fmt.Fprint(valBuffer, "?")
args = append(args, value)
}
if i != len(b.inserts)-1 {
if i != len(b.insertCols)-1 {
if _, err := fmt.Fprint(w, ","); err != nil {
return err
}
@ -44,7 +70,6 @@ func (b *Builder) insertWriteTo(w Writer) error {
return err
}
}
i = i + 1
}
if _, err := fmt.Fprint(w, ") Values ("); err != nil {

100
vendor/github.com/go-xorm/builder/builder_limit.go generated vendored Normal file
View file

@ -0,0 +1,100 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"fmt"
"strings"
)
func (b *Builder) limitWriteTo(w Writer) error {
if strings.TrimSpace(b.dialect) == "" {
return ErrDialectNotSetUp
}
if b.limitation != nil {
limit := b.limitation
if limit.offset < 0 || limit.limitN <= 0 {
return ErrInvalidLimitation
}
// erase limit condition
b.limitation = nil
ow := w.(*BytesWriter)
switch strings.ToLower(strings.TrimSpace(b.dialect)) {
case ORACLE:
if len(b.selects) == 0 {
b.selects = append(b.selects, "*")
}
var final *Builder
selects := b.selects
b.selects = append(selects, "ROWNUM RN")
var wb *Builder
if b.optype == unionType {
wb = Dialect(b.dialect).Select("at.*", "ROWNUM RN").
From(b, "at")
} else {
wb = b
}
if limit.offset == 0 {
final = Dialect(b.dialect).Select(selects...).From(wb, "at").
Where(Lte{"at.RN": limit.limitN})
} else {
sub := Dialect(b.dialect).Select("*").
From(b, "at").Where(Lte{"at.RN": limit.offset + limit.limitN})
final = Dialect(b.dialect).Select(selects...).From(sub, "att").
Where(Gt{"att.RN": limit.offset})
}
return final.WriteTo(ow)
case SQLITE, MYSQL, POSTGRES:
// if type UNION, we need to write previous content back to current writer
if b.optype == unionType {
if err := b.WriteTo(ow); err != nil {
return err
}
}
if limit.offset == 0 {
fmt.Fprint(ow, " LIMIT ", limit.limitN)
} else {
fmt.Fprintf(ow, " LIMIT %v OFFSET %v", limit.limitN, limit.offset)
}
case MSSQL:
if len(b.selects) == 0 {
b.selects = append(b.selects, "*")
}
var final *Builder
selects := b.selects
b.selects = append(append([]string{fmt.Sprintf("TOP %d %v", limit.limitN+limit.offset, b.selects[0])},
b.selects[1:]...), "ROW_NUMBER() OVER (ORDER BY (SELECT 1)) AS RN")
var wb *Builder
if b.optype == unionType {
wb = Dialect(b.dialect).Select("*", "ROW_NUMBER() OVER (ORDER BY (SELECT 1)) AS RN").
From(b, "at")
} else {
wb = b
}
if limit.offset == 0 {
final = Dialect(b.dialect).Select(selects...).From(wb, "at")
} else {
final = Dialect(b.dialect).Select(selects...).From(wb, "at").Where(Gt{"at.RN": limit.offset})
}
return final.WriteTo(ow)
default:
return ErrNotSupportType
}
}
return nil
}

View file

@ -5,13 +5,24 @@
package builder
import (
"errors"
"fmt"
)
// Select creates a select Builder
func Select(cols ...string) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Select(cols...)
}
func (b *Builder) selectWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
return errors.New("no table indicated")
if len(b.from) <= 0 && !b.isNested {
return ErrNoTableName
}
// perform limit before writing to writer when b.dialect between ORACLE and MSSQL
// this avoid a duplicate writing problem in simple limit query
if b.limitation != nil && (b.dialect == ORACLE || b.dialect == MSSQL) {
return b.limitWriteTo(w)
}
if _, err := fmt.Fprint(w, "SELECT "); err != nil {
@ -34,20 +45,101 @@ func (b *Builder) selectWriteTo(w Writer) error {
}
}
if _, err := fmt.Fprintf(w, " FROM %s", b.tableName); err != nil {
return err
if b.subQuery == nil {
if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil {
return err
}
} else {
if b.cond.IsValid() && len(b.from) <= 0 {
return ErrUnnamedDerivedTable
}
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect {
return ErrInconsistentDialect
}
// dialect of sub-query will inherit from the main one (if not set up)
if b.dialect != "" && b.subQuery.dialect == "" {
b.subQuery.dialect = b.dialect
}
switch b.subQuery.optype {
case selectType, unionType:
fmt.Fprint(w, " FROM (")
if err := b.subQuery.WriteTo(w); err != nil {
return err
}
if len(b.from) == 0 {
fmt.Fprintf(w, ")")
} else {
fmt.Fprintf(w, ") %v", b.from)
}
default:
return ErrUnexpectedSubQuery
}
}
for _, v := range b.joins {
fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable)
if _, err := fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable); err != nil {
return err
}
if err := v.joinCond.WriteTo(w); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
if b.cond.IsValid() {
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := b.cond.WriteTo(w); err != nil {
return err
}
}
return b.cond.WriteTo(w)
if len(b.groupBy) > 0 {
if _, err := fmt.Fprint(w, " GROUP BY ", b.groupBy); err != nil {
return err
}
}
if len(b.having) > 0 {
if _, err := fmt.Fprint(w, " HAVING ", b.having); err != nil {
return err
}
}
if len(b.orderBy) > 0 {
if _, err := fmt.Fprint(w, " ORDER BY ", b.orderBy); err != nil {
return err
}
}
if b.limitation != nil {
if err := b.limitWriteTo(w); err != nil {
return err
}
}
return nil
}
// OrderBy orderBy SQL
func (b *Builder) OrderBy(orderBy string) *Builder {
b.orderBy = orderBy
return b
}
// GroupBy groupby SQL
func (b *Builder) GroupBy(groupby string) *Builder {
b.groupBy = groupby
return b
}
// Having having SQL
func (b *Builder) Having(having string) *Builder {
b.having = having
return b
}

47
vendor/github.com/go-xorm/builder/builder_union.go generated vendored Normal file
View file

@ -0,0 +1,47 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"fmt"
"strings"
)
func (b *Builder) unionWriteTo(w Writer) error {
if b.limitation != nil || b.cond.IsValid() ||
b.orderBy != "" || b.having != "" || b.groupBy != "" {
return ErrNotUnexpectedUnionConditions
}
for idx, u := range b.unions {
current := u.builder
if current.optype != selectType {
return ErrUnsupportedUnionMembers
}
if len(b.unions) == 1 {
if err := current.selectWriteTo(w); err != nil {
return err
}
} else {
if b.dialect != "" && b.dialect != current.dialect {
return ErrInconsistentDialect
}
if idx != 0 {
fmt.Fprint(w, fmt.Sprintf(" UNION %v ", strings.ToUpper(u.unionType)))
}
fmt.Fprint(w, "(")
if err := current.selectWriteTo(w); err != nil {
return err
}
fmt.Fprint(w, ")")
}
}
return nil
}

View file

@ -5,19 +5,24 @@
package builder
import (
"errors"
"fmt"
)
// Update creates an update Builder
func Update(updates ...Eq) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Update(updates...)
}
func (b *Builder) updateWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
return errors.New("no table indicated")
if len(b.from) <= 0 {
return ErrNoTableName
}
if len(b.updates) <= 0 {
return errors.New("no column to be update")
return ErrNoColumnToUpdate
}
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.tableName); err != nil {
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.from); err != nil {
return err
}

View file

@ -1,12 +0,0 @@
dependencies:
override:
# './...' is a relative pattern which means all subdirectories
- go get -t -d -v ./...
- go build -v
- go get -u github.com/golang/lint/golint
test:
override:
# './...' is a relative pattern which means all subdirectories
- golint ./...
- go test -v -race

View file

@ -5,7 +5,6 @@
package builder
import (
"bytes"
"io"
)
@ -19,15 +18,15 @@ var _ Writer = NewWriter()
// BytesWriter implments Writer and save SQL in bytes.Buffer
type BytesWriter struct {
writer *bytes.Buffer
buffer []byte
writer *StringBuilder
args []interface{}
}
// NewWriter creates a new string writer
func NewWriter() *BytesWriter {
w := &BytesWriter{}
w.writer = bytes.NewBuffer(w.buffer)
w := &BytesWriter{
writer: &StringBuilder{},
}
return w
}
@ -73,15 +72,3 @@ func (condEmpty) Or(conds ...Cond) Cond {
func (condEmpty) IsValid() bool {
return false
}
func condToSQL(cond Cond) (string, []interface{}, error) {
if cond == nil || !cond.IsValid() {
return "", nil, nil
}
w := NewWriter()
if err := cond.WriteTo(w); err != nil {
return "", nil, err
}
return w.writer.String(), w.args, nil
}

View file

@ -25,7 +25,9 @@ func And(conds ...Cond) Cond {
func (and condAnd) WriteTo(w Writer) error {
for i, cond := range and {
_, isOr := cond.(condOr)
if isOr {
_, isExpr := cond.(expr)
wrap := isOr || isExpr
if wrap {
fmt.Fprint(w, "(")
}
@ -34,7 +36,7 @@ func (and condAnd) WriteTo(w Writer) error {
return err
}
if isOr {
if wrap {
fmt.Fprint(w, ")")
}

View file

@ -17,10 +17,35 @@ var _ Cond = Between{}
// WriteTo write data to Writer
func (between Between) WriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "%s BETWEEN ? AND ?", between.Col); err != nil {
if _, err := fmt.Fprintf(w, "%s BETWEEN ", between.Col); err != nil {
return err
}
w.Append(between.LessVal, between.MoreVal)
if lv, ok := between.LessVal.(expr); ok {
if err := lv.WriteTo(w); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, "?"); err != nil {
return err
}
w.Append(between.LessVal)
}
if _, err := fmt.Fprint(w, " AND "); err != nil {
return err
}
if mv, ok := between.MoreVal.(expr); ok {
if err := mv.WriteTo(w); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, "?"); err != nil {
return err
}
w.Append(between.MoreVal)
}
return nil
}

View file

@ -10,7 +10,13 @@ import "fmt"
func WriteMap(w Writer, data map[string]interface{}, op string) error {
var args = make([]interface{}, 0, len(data))
var i = 0
for k, v := range data {
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
for _, k := range keys {
v := data[k]
switch v.(type) {
case expr:
if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil {

View file

@ -4,7 +4,10 @@
package builder
import "fmt"
import (
"fmt"
"sort"
)
// Incr implements a type used by Eq
type Incr int
@ -19,7 +22,8 @@ var _ Cond = Eq{}
func (eq Eq) opWriteTo(op string, w Writer) error {
var i = 0
for k, v := range eq {
for _, k := range eq.sortedKeys() {
v := eq[k]
switch v.(type) {
case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}:
if err := In(k, v).WriteTo(w); err != nil {
@ -94,3 +98,15 @@ func (eq Eq) Or(conds ...Cond) Cond {
func (eq Eq) IsValid() bool {
return len(eq) > 0
}
// sortedKeys returns all keys of this Eq sorted with sort.Strings.
// It is used internally for consistent ordering when generating
// SQL, see https://github.com/go-xorm/builder/issues/10
func (eq Eq) sortedKeys() []string {
keys := make([]string, 0, len(eq))
for key := range eq {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}

View file

@ -16,7 +16,7 @@ func (like Like) WriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil {
return err
}
// FIXME: if use other regular express, this will be failed. but for compitable, keep this
// FIXME: if use other regular express, this will be failed. but for compatible, keep this
if like[1][0] == '%' || like[1][len(like[1])-1] == '%' {
w.Append(like[1])
} else {

View file

@ -4,7 +4,10 @@
package builder
import "fmt"
import (
"fmt"
"sort"
)
// Neq defines not equal conditions
type Neq map[string]interface{}
@ -15,7 +18,8 @@ var _ Cond = Neq{}
func (neq Neq) WriteTo(w Writer) error {
var args = make([]interface{}, 0, len(neq))
var i = 0
for k, v := range neq {
for _, k := range neq.sortedKeys() {
v := neq[k]
switch v.(type) {
case []int, []int64, []string, []int32, []int16, []int8:
if err := NotIn(k, v).WriteTo(w); err != nil {
@ -76,3 +80,15 @@ func (neq Neq) Or(conds ...Cond) Cond {
func (neq Neq) IsValid() bool {
return len(neq) > 0
}
// sortedKeys returns all keys of this Neq sorted with sort.Strings.
// It is used internally for consistent ordering when generating
// SQL, see https://github.com/go-xorm/builder/issues/10
func (neq Neq) sortedKeys() []string {
keys := make([]string, 0, len(neq))
for key := range neq {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}

View file

@ -21,6 +21,18 @@ func (not Not) WriteTo(w Writer) error {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
case Eq:
if len(not[0].(Eq)) > 1 {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
}
case Neq:
if len(not[0].(Neq)) > 1 {
if _, err := fmt.Fprint(w, "("); err != nil {
return err
}
}
}
if err := not[0].WriteTo(w); err != nil {
@ -32,6 +44,18 @@ func (not Not) WriteTo(w Writer) error {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
case Eq:
if len(not[0].(Eq)) > 1 {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
}
case Neq:
if len(not[0].(Neq)) > 1 {
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
}
}
return nil

View file

@ -27,10 +27,12 @@ func (o condOr) WriteTo(w Writer) error {
for i, cond := range o {
var needQuote bool
switch cond.(type) {
case condAnd:
case condAnd, expr:
needQuote = true
case Eq:
needQuote = (len(cond.(Eq)) > 1)
case Neq:
needQuote = (len(cond.(Neq)) > 1)
}
if needQuote {

View file

@ -8,9 +8,33 @@ import "errors"
var (
// ErrNotSupportType not supported SQL type error
ErrNotSupportType = errors.New("not supported SQL type")
ErrNotSupportType = errors.New("Not supported SQL type")
// ErrNoNotInConditions no NOT IN params error
ErrNoNotInConditions = errors.New("No NOT IN conditions")
// ErrNoInConditions no IN params error
ErrNoInConditions = errors.New("No IN conditions")
// ErrNeedMoreArguments need more arguments
ErrNeedMoreArguments = errors.New("Need more sql arguments")
// ErrNoTableName no table name
ErrNoTableName = errors.New("No table indicated")
// ErrNoColumnToInsert no column to update
ErrNoColumnToUpdate = errors.New("No column(s) to update")
// ErrNoColumnToInsert no column to update
ErrNoColumnToInsert = errors.New("No column(s) to insert")
// ErrNotSupportDialectType not supported dialect type error
ErrNotSupportDialectType = errors.New("Not supported dialect type")
// ErrNotUnexpectedUnionConditions using union in a wrong way
ErrNotUnexpectedUnionConditions = errors.New("Unexpected conditional fields in UNION query")
// ErrUnsupportedUnionMembers unexpected members in UNION query
ErrUnsupportedUnionMembers = errors.New("Unexpected members in UNION query")
// ErrUnexpectedSubQuery Unexpected sub-query in SELECT query
ErrUnexpectedSubQuery = errors.New("Unexpected sub-query in SELECT query")
// ErrDialectNotSetUp dialect is not setup yet
ErrDialectNotSetUp = errors.New("Dialect is not setup yet, try to use `Dialect(dbType)` at first")
// ErrInvalidLimitation offset or limit is not correct
ErrInvalidLimitation = errors.New("Offset or limit is not correct")
// ErrUnnamedDerivedTable Every derived table must have its own alias
ErrUnnamedDerivedTable = errors.New("Every derived table must have its own alias")
// ErrInconsistentDialect Inconsistent dialect in same builder
ErrInconsistentDialect = errors.New("Inconsistent dialect in same builder")
)

1
vendor/github.com/go-xorm/builder/go.mod generated vendored Normal file
View file

@ -0,0 +1 @@
module "github.com/go-xorm/builder"

156
vendor/github.com/go-xorm/builder/sql.go generated vendored Normal file
View file

@ -0,0 +1,156 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
sql2 "database/sql"
"fmt"
"reflect"
"time"
)
func condToSQL(cond Cond) (string, []interface{}, error) {
if cond == nil || !cond.IsValid() {
return "", nil, nil
}
w := NewWriter()
if err := cond.WriteTo(w); err != nil {
return "", nil, err
}
return w.writer.String(), w.args, nil
}
func condToBoundSQL(cond Cond) (string, error) {
if cond == nil || !cond.IsValid() {
return "", nil
}
w := NewWriter()
if err := cond.WriteTo(w); err != nil {
return "", err
}
return ConvertToBoundSQL(w.writer.String(), w.args)
}
// ToSQL convert a builder or conditions to SQL and args
func ToSQL(cond interface{}) (string, []interface{}, error) {
switch cond.(type) {
case Cond:
return condToSQL(cond.(Cond))
case *Builder:
return cond.(*Builder).ToSQL()
}
return "", nil, ErrNotSupportType
}
// ToBoundSQL convert a builder or conditions to parameters bound SQL
func ToBoundSQL(cond interface{}) (string, error) {
switch cond.(type) {
case Cond:
return condToBoundSQL(cond.(Cond))
case *Builder:
return cond.(*Builder).ToBoundSQL()
}
return "", ErrNotSupportType
}
func noSQLQuoteNeeded(a interface{}) bool {
switch a.(type) {
case int, int8, int16, int32, int64:
return true
case uint, uint8, uint16, uint32, uint64:
return true
case float32, float64:
return true
case bool:
return true
case string:
return false
case time.Time, *time.Time:
return false
}
t := reflect.TypeOf(a)
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
case reflect.Bool:
return true
case reflect.String:
return false
}
return false
}
// ConvertToBoundSQL will convert SQL and args to a bound SQL
func ConvertToBoundSQL(sql string, args []interface{}) (string, error) {
buf := StringBuilder{}
var i, j, start int
for ; i < len(sql); i++ {
if sql[i] == '?' {
_, err := buf.WriteString(sql[start:i])
if err != nil {
return "", err
}
start = i + 1
if len(args) == j {
return "", ErrNeedMoreArguments
}
arg := args[j]
if namedArg, ok := arg.(sql2.NamedArg); ok {
arg = namedArg.Value
}
if noSQLQuoteNeeded(arg) {
_, err = fmt.Fprint(&buf, arg)
} else {
_, err = fmt.Fprintf(&buf, "'%v'", arg)
}
if err != nil {
return "", err
}
j = j + 1
}
}
_, err := buf.WriteString(sql[start:])
if err != nil {
return "", err
}
return buf.String(), nil
}
// ConvertPlaceholder replaces ? to $1, $2 ... or :1, :2 ... according prefix
func ConvertPlaceholder(sql, prefix string) (string, error) {
buf := StringBuilder{}
var i, j, start int
for ; i < len(sql); i++ {
if sql[i] == '?' {
if _, err := buf.WriteString(sql[start:i]); err != nil {
return "", err
}
start = i + 1
j = j + 1
if _, err := buf.WriteString(fmt.Sprintf("%v%d", prefix, j)); err != nil {
return "", err
}
}
}
if _, err := buf.WriteString(sql[start:]); err != nil {
return "", err
}
return buf.String(), nil
}

119
vendor/github.com/go-xorm/builder/string_builder.go generated vendored Normal file
View file

@ -0,0 +1,119 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"unicode/utf8"
"unsafe"
)
// A StringBuilder is used to efficiently build a string using Write methods.
// It minimizes memory copying. The zero value is ready to use.
// Do not copy a non-zero Builder.
type StringBuilder struct {
addr *StringBuilder // of receiver, to detect copies by value
buf []byte
}
// noescape hides a pointer from escape analysis. noescape is
// the identity function but escape analysis doesn't think the
// output depends on the input. noescape is inlined and currently
// compiles down to zero instructions.
// USE CAREFULLY!
// This was copied from the runtime; see issues 23382 and 7921.
//go:nosplit
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}
func (b *StringBuilder) copyCheck() {
if b.addr == nil {
// This hack works around a failing of Go's escape analysis
// that was causing b to escape and be heap allocated.
// See issue 23382.
// TODO: once issue 7921 is fixed, this should be reverted to
// just "b.addr = b".
b.addr = (*StringBuilder)(noescape(unsafe.Pointer(b)))
} else if b.addr != b {
panic("strings: illegal use of non-zero Builder copied by value")
}
}
// String returns the accumulated string.
func (b *StringBuilder) String() string {
return *(*string)(unsafe.Pointer(&b.buf))
}
// Len returns the number of accumulated bytes; b.Len() == len(b.String()).
func (b *StringBuilder) Len() int { return len(b.buf) }
// Reset resets the Builder to be empty.
func (b *StringBuilder) Reset() {
b.addr = nil
b.buf = nil
}
// grow copies the buffer to a new, larger buffer so that there are at least n
// bytes of capacity beyond len(b.buf).
func (b *StringBuilder) grow(n int) {
buf := make([]byte, len(b.buf), 2*cap(b.buf)+n)
copy(buf, b.buf)
b.buf = buf
}
// Grow grows b's capacity, if necessary, to guarantee space for
// another n bytes. After Grow(n), at least n bytes can be written to b
// without another allocation. If n is negative, Grow panics.
func (b *StringBuilder) Grow(n int) {
b.copyCheck()
if n < 0 {
panic("strings.Builder.Grow: negative count")
}
if cap(b.buf)-len(b.buf) < n {
b.grow(n)
}
}
// Write appends the contents of p to b's buffer.
// Write always returns len(p), nil.
func (b *StringBuilder) Write(p []byte) (int, error) {
b.copyCheck()
b.buf = append(b.buf, p...)
return len(p), nil
}
// WriteByte appends the byte c to b's buffer.
// The returned error is always nil.
func (b *StringBuilder) WriteByte(c byte) error {
b.copyCheck()
b.buf = append(b.buf, c)
return nil
}
// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer.
// It returns the length of r and a nil error.
func (b *StringBuilder) WriteRune(r rune) (int, error) {
b.copyCheck()
if r < utf8.RuneSelf {
b.buf = append(b.buf, byte(r))
return 1, nil
}
l := len(b.buf)
if cap(b.buf)-l < utf8.UTFMax {
b.grow(utf8.UTFMax)
}
n := utf8.EncodeRune(b.buf[l:l+utf8.UTFMax], r)
b.buf = b.buf[:l+n]
return n, nil
}
// WriteString appends the contents of s to b's buffer.
// It returns the length of s and a nil error.
func (b *StringBuilder) WriteString(s string) (int, error) {
b.copyCheck()
b.buf = append(b.buf, s...)
return len(s), nil
}

View file

@ -147,12 +147,12 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
}
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
} else {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
}
if !fieldValue.IsValid() {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
return &fieldValue, nil

View file

@ -49,7 +49,6 @@ func NewTable(name string, t reflect.Type) *Table {
}
func (table *Table) columnsByName(name string) []*Column {
n := len(name)
for k := range table.columnsMap {
@ -75,7 +74,6 @@ func (table *Table) GetColumn(name string) *Column {
}
func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name)
if cols != nil && idx < len(cols) {

View file

@ -74,6 +74,7 @@ var (
NVarchar = "NVARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
NText = "NTEXT"
Clob = "CLOB"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
@ -130,6 +131,7 @@ var (
NVarchar: TEXT_TYPE,
TinyText: TEXT_TYPE,
Text: TEXT_TYPE,
NText: TEXT_TYPE,
MediumText: TEXT_TYPE,
LongText: TEXT_TYPE,
Uuid: TEXT_TYPE,
@ -293,7 +295,7 @@ func SQLType2Type(st SQLType) reflect.Type {
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
case Char, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier:
return reflect.TypeOf([]byte{})

125
vendor/github.com/go-xorm/xorm/.drone.yml generated vendored Normal file
View file

@ -0,0 +1,125 @@
workspace:
base: /go
path: src/github.com/go-xorm/xorm
clone:
git:
image: plugins/git:next
depth: 50
tags: true
services:
mysql:
image: mysql:5.7
environment:
- MYSQL_DATABASE=xorm_test
- MYSQL_ALLOW_EMPTY_PASSWORD=yes
when:
event: [ push, tag, pull_request ]
pgsql:
image: postgres:9.5
environment:
- POSTGRES_USER=postgres
- POSTGRES_DB=xorm_test
when:
event: [ push, tag, pull_request ]
#mssql:
# image: microsoft/mssql-server-linux:2017-CU11
# environment:
# - ACCEPT_EULA=Y
# - SA_PASSWORD=yourStrong(!)Password
# - MSSQL_PID=Developer
# commands:
# - echo 'CREATE DATABASE xorm_test' > create.sql
# - /opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P yourStrong(!)Password -i "create.sql"
matrix:
GO_VERSION:
- 1.8
- 1.9
- 1.10
- 1.11
pipeline:
init_postgres:
image: postgres:9.5
commands:
# wait for postgres service to become available
- |
until psql -U postgres -d xorm_test -h pgsql \
-c "SELECT 1;" >/dev/null 2>&1; do sleep 1; done
# query the database
- |
psql -U postgres -d xorm_test -h pgsql \
-c "create schema xorm;"
build:
image: golang:${GO_VERSION}
commands:
- go get -t -d -v ./...
- go get -u github.com/go-xorm/core
- go get -u github.com/go-xorm/builder
- go build -v
when:
event: [ push, pull_request ]
test-sqlite:
image: golang:${GO_VERSION}
commands:
- go get -u github.com/wadey/gocovmerge
- go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic
- go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic
when:
event: [ push, pull_request ]
test-mysql:
image: golang:${GO_VERSION}
commands:
- go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic
- go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic
when:
event: [ push, pull_request ]
test-mysql-utf8mb4:
image: golang:${GO_VERSION}
commands:
- go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test?charset=utf8mb4" -coverprofile=coverage2.1-1.txt -covermode=atomic
- go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test?charset=utf8mb4" -cache=true -coverprofile=coverage2.1-2.txt -covermode=atomic
when:
event: [ push, pull_request ]
test-mymysql:
image: golang:${GO_VERSION}
commands:
- go test -v -race -db="mymysql" -conn_str="tcp:mysql:3306*xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic
- go test -v -race -db="mymysql" -conn_str="tcp:mysql:3306*xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic
when:
event: [ push, pull_request ]
test-postgres:
image: golang:${GO_VERSION}
commands:
- go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic
when:
event: [ push, pull_request ]
test-postgres-schema:
image: golang:${GO_VERSION}
commands:
- go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage2.1-1.txt coverage2.1-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt
when:
event: [ push, pull_request ]
#coverage:
# image: robertstettner/drone-codecov
# secrets: [ codecov_token ]
# files:
# - coverage.txt
# when:
# event: [ push, pull_request ]
# branch: [ master ]

View file

@ -28,3 +28,6 @@ temp_test.go
.vscode
xorm.test
*.sqlite3
test.db.sql
.idea/

View file

@ -32,13 +32,10 @@ proposed functionality.
We appreciate any bug reports, but especially ones with self-contained
(doesn't depend on code outside of xorm), minimal (can't be simplified
further) test cases. It's especially helpful if you can submit a pull
request with just the failing test case (you'll probably want to
pattern it after the tests in
[base.go](https://github.com/go-xorm/tests/blob/master/base.go) AND
[benchmark.go](https://github.com/go-xorm/tests/blob/master/benchmark.go).
request with just the failing test case(you can find some example test file like [session_get_test.go](https://github.com/go-xorm/xorm/blob/master/session_get_test.go)).
If you implements a new database interface, you maybe need to add a <databasename>_test.go file.
For example, [mysql_test.go](https://github.com/go-xorm/tests/blob/master/mysql/mysql_test.go)
If you implements a new database interface, you maybe need to add a test_<databasename>.sh file.
For example, [mysql_test.go](https://github.com/go-xorm/xorm/blob/master/test_mysql.sh)
### New functionality

View file

@ -1,3 +1,5 @@
# xorm
[中文](https://github.com/go-xorm/xorm/blob/master/README_CN.md)
Xorm is a simple and powerful ORM for Go.
@ -6,7 +8,7 @@ Xorm is a simple and powerful ORM for Go.
[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm)
[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3)
# Features
## Features
* Struct <-> Table Mapping Support
@ -28,7 +30,13 @@ Xorm is a simple and powerful ORM for Go.
* SQL Builder support via [github.com/go-xorm/builder](https://github.com/go-xorm/builder)
# Drivers Support
* Automatical Read/Write seperatelly
* Postgres schema support
* Context Cache support
## Drivers Support
Drivers for Go's sql package which currently support database/sql includes:
@ -46,43 +54,17 @@ Drivers for Go's sql package which currently support database/sql includes:
* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment)
# Changelog
* **v0.6.3**
* merge tests to main project
* add `Exist` function
* add `SumInt` function
* Mysql now support read and create column comment.
* fix time related bugs.
* fix some other bugs.
* **v0.6.2**
* refactor tag parse methods
* add Scan features to Get
* add QueryString method
* **v0.6.0**
* remove support for ql
* add query condition builder support via [github.com/go-xorm/builder](https://github.com/go-xorm/builder), so `Where`, `And`, `Or`
methods can use `builder.Cond` as parameter
* add Sum, SumInt, SumInt64 and NotIn methods
* some bugs fixed
[More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16)
# Installation
## Installation
go get github.com/go-xorm/xorm
# Documents
## Documents
* [Manual](http://xorm.io/docs)
* [GoDoc](http://godoc.org/github.com/go-xorm/xorm)
* [GoWalker](http://gowalker.org/github.com/go-xorm/xorm)
# Quick Start
## Quick Start
* Create Engine
@ -106,15 +88,36 @@ type User struct {
err := engine.Sync2(new(User))
```
* `Query` runs a SQL string, the returned results is `[]map[string][]byte`, `QueryString` returns `[]map[string]string`.
* Create Engine Group
```Go
dataSourceNameSlice := []string{masterDataSourceName, slave1DataSourceName, slave2DataSourceName}
engineGroup, err := xorm.NewEngineGroup(driverName, dataSourceNameSlice)
```
```Go
masterEngine, err := xorm.NewEngine(driverName, masterDataSourceName)
slave1Engine, err := xorm.NewEngine(driverName, slave1DataSourceName)
slave2Engine, err := xorm.NewEngine(driverName, slave2DataSourceName)
engineGroup, err := xorm.NewEngineGroup(masterEngine, []*Engine{slave1Engine, slave2Engine})
```
Then all place where `engine` you can just use `engineGroup`.
* `Query` runs a SQL string, the returned results is `[]map[string][]byte`, `QueryString` returns `[]map[string]string`, `QueryInterface` returns `[]map[string]interface{}`.
```Go
results, err := engine.Query("select * from user")
results, err := engine.Where("a = 1").Query()
results, err := engine.QueryString("select * from user")
results, err := engine.Where("a = 1").QueryString()
results, err := engine.QueryInterface("select * from user")
results, err := engine.Where("a = 1").QueryInterface()
```
* `Execute` runs a SQL string, it returns `affected` and `error`
* `Exec` runs a SQL string, it returns `affected` and `error`
```Go
affected, err := engine.Exec("update user set age = ? where name = ?", age, name)
@ -125,62 +128,76 @@ affected, err := engine.Exec("update user set age = ? where name = ?", age, name
```Go
affected, err := engine.Insert(&user)
// INSERT INTO struct () values ()
affected, err := engine.Insert(&user1, &user2)
// INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values ()
affected, err := engine.Insert(&users)
// INSERT INTO struct () values (),(),()
affected, err := engine.Insert(&user1, &users)
// INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values (),(),()
```
* Query one record from database
* `Get` query one record from database
```Go
has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1
has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ?
var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ?
```
* Check if one record exist on table
* `Exist` check if one record exist on table
```Go
has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{
Name: "test1",
})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
```
* Query multiple records from database, also you can use join and extends
* `Find` query multiple records from database, also you can use join and extends
```Go
var users []User
err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users)
// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10
// SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0
type Detail struct {
Id int64
@ -193,14 +210,14 @@ type UserDetail struct {
}
var users []UserDetail
err := engine.Table("user").Select("user.*, detail.*")
err := engine.Table("user").Select("user.*, detail.*").
Join("INNER", "detail", "detail.user_id = user.id").
Where("user.name = ?", name).Limit(10, 0).
Find(&users)
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0
```
* Query multiple records and record by record handle, there are two methods Iterate and Rows
* `Iterate` and `Rows` query multiple records and record by record handle, there are two methods Iterate and Rows
```Go
err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
@ -209,6 +226,13 @@ err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
})
// SELECT * FROM user
err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
user := bean.(*User)
return nil
})
// SELECT * FROM user Limit 0, 100
// SELECT * FROM user Limit 101, 100
rows, err := engine.Rows(&User{Name:name})
// SELECT * FROM user
defer rows.Close()
@ -218,10 +242,10 @@ for rows.Next() {
}
```
* Update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on.
* `Update` update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on.
```Go
affected, err := engine.Id(1).Update(&user)
affected, err := engine.ID(1).Update(&user)
// UPDATE user SET ... Where id = ?
affected, err := engine.Update(&user, &User{Name:name})
@ -232,32 +256,50 @@ affected, err := engine.In("id", ids).Update(&user)
// UPDATE user SET ... Where id IN (?, ?, ?)
// force update indicated columns by Cols
affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
// force NOT update indicated columns by Omit
affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
affected, err := engine.Id(1).AllCols().Update(&user)
affected, err := engine.ID(1).AllCols().Update(&user)
// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ?
```
* Delete one or more records, Delete MUST have condition
* `Delete` delete one or more records, Delete MUST have condition
```Go
affected, err := engine.Where(...).Delete(&user)
// DELETE FROM user Where ...
affected, err := engine.Id(2).Delete(&user)
affected, err := engine.ID(2).Delete(&user)
// DELETE FROM user Where id = ?
```
* Count records
* `Count` count records
```Go
counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user
```
* `Sum` sum functions
```Go
agesFloat64, err := engine.Sum(&user, "age")
// SELECT sum(age) AS total FROM user
agesInt64, err := engine.SumInt(&user, "age")
// SELECT sum(age) AS total FROM user
sumFloat64Slice, err := engine.Sums(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
sumInt64Slice, err := engine.SumsInt(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
```
* Query conditions builder
```Go
@ -265,7 +307,155 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
// SELECT id, name ... FROM user WHERE a NOT IN (?, ?) AND b IN (?, ?, ?)
```
# Cases
* Multiple operations in one go routine, no transation here but resue session memory
```Go
session := engine.NewSession()
defer session.Close()
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
return nil
```
* Transation should on one go routine. There is transaction and resue session memory
```Go
session := engine.NewSession()
defer session.Close()
// add Begin() before any action
if err := session.Begin(); err != nil {
// if returned then will rollback automatically
return err
}
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
// add Commit() after all actions
return session.Commit()
```
* Or you can use `Transaction` to replace above codes.
```Go
res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) {
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return nil, err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return nil, err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return nil, err
}
return nil, nil
})
```
* Context Cache, if enabled, current query result will be cached on session and be used by next same statement on the same session.
```Go
sess := engine.NewSession()
defer sess.Close()
var context = xorm.NewMemoryContextCache()
var c2 ContextGetStruct
has, err := sess.ID(1).ContextCache(context).Get(&c2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c2.Id)
assert.EqualValues(t, "1", c2.Name)
sql, args := sess.LastSQL()
assert.True(t, len(sql) > 0)
assert.True(t, len(args) > 0)
var c3 ContextGetStruct
has, err = sess.ID(1).ContextCache(context).Get(&c3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c3.Id)
assert.EqualValues(t, "1", c3.Name)
sql, args = sess.LastSQL()
assert.True(t, len(sql) == 0)
assert.True(t, len(args) == 0)
```
## Contributing
If you want to pull request, please see [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md). And we also provide [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) to discuss.
## Credits
### Contributors
This project exists thanks to all the people who contribute. [[Contribute](CONTRIBUTING.md)].
<a href="graphs/contributors"><img src="https://opencollective.com/xorm/contributors.svg?width=890&button=false" /></a>
### Backers
Thank you to all our backers! 🙏 [[Become a backer](https://opencollective.com/xorm#backer)]
<a href="https://opencollective.com/xorm#backers" target="_blank"><img src="https://opencollective.com/xorm/backers.svg?width=890"></a>
### Sponsors
Support this project by becoming a sponsor. Your logo will show up here with a link to your website. [[Become a sponsor](https://opencollective.com/xorm#sponsor)]
## Changelog
* **v0.7.0**
* Some bugs fixed
* **v0.6.6**
* Some bugs fixed
* **v0.6.5**
* Postgres schema support
* vgo support
* Add FindAndCount
* Database special params support via NewEngineWithParams
* Some bugs fixed
* **v0.6.4**
* Automatical Read/Write seperatelly
* Query/QueryString/QueryInterface and action with Where/And
* Get support non-struct variables
* BufferSize on Iterate
* fix some other bugs.
[More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16)
## Cases
* [studygolang](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
@ -301,15 +491,6 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [go-blog](http://wangcheng.me) - [github.com/easykoo/go-blog](https://github.com/easykoo/go-blog)
# Discuss
## LICENSE
Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm)
# Contributing
If you want to pull request, please see [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md)
# LICENSE
BSD License
[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)
BSD License [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)

View file

@ -22,6 +22,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* 支持级联加载Struct
* Schema支持仅Postgres
* 支持缓存
* 支持根据数据库自动生成xorm的结构体
@ -30,6 +32,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* 内置SQL Builder支持
* 上下文缓存支持
## 驱动支持
目前支持的Go数据库驱动和对应的数据库如下
@ -50,35 +54,6 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持)
## 更新日志
* **v0.6.3**
* 合并单元测试到主工程
* 新增`Exist`方法
* 新增`SumInt`方法
* Mysql新增读取和创建字段注释支持
* 新增`SetConnMaxLifetime`方法
* 修正了时间相关的Bug
* 修复了一些其它Bug
* **v0.6.2**
* 重构Tag解析方式
* Get方法新增类似Scan的特性
* 新增 QueryString 方法
* **v0.6.0**
* 去除对 ql 的支持
* 新增条件查询分析器 [github.com/go-xorm/builder](https://github.com/go-xorm/builder), 从因此 `Where, And, Or` 函数
将可以用 `builder.Cond` 作为条件组合
* 新增 Sum, SumInt, SumInt64 和 NotIn 函数
* Bug修正
* **v0.5.0**
* logging接口进行不兼容改变
* Bug修正
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
## 安装
go get github.com/go-xorm/xorm
@ -115,12 +90,33 @@ type User struct {
err := engine.Sync2(new(User))
```
* `Query` 最原始的也支持SQL语句查询返回的结果类型为 []map[string][]byte。`QueryString` 返回 []map[string]string
* 创建Engine组
```Go
dataSourceNameSlice := []string{masterDataSourceName, slave1DataSourceName, slave2DataSourceName}
engineGroup, err := xorm.NewEngineGroup(driverName, dataSourceNameSlice)
```
```Go
masterEngine, err := xorm.NewEngine(driverName, masterDataSourceName)
slave1Engine, err := xorm.NewEngine(driverName, slave1DataSourceName)
slave2Engine, err := xorm.NewEngine(driverName, slave2DataSourceName)
engineGroup, err := xorm.NewEngineGroup(masterEngine, []*Engine{slave1Engine, slave2Engine})
```
所有使用 `engine` 都可以简单的用 `engineGroup` 来替换。
* `Query` 最原始的也支持SQL语句查询返回的结果类型为 []map[string][]byte。`QueryString` 返回 []map[string]string, `QueryInterface` 返回 `[]map[string]interface{}`.
```Go
results, err := engine.Query("select * from user")
results, err := engine.Where("a = 1").Query()
results, err := engine.QueryString("select * from user")
results, err := engine.Where("a = 1").QueryString()
results, err := engine.QueryInterface("select * from user")
results, err := engine.Where("a = 1").QueryInterface()
```
* `Exec` 执行一个SQL语句
@ -129,67 +125,81 @@ results, err := engine.QueryString("select * from user")
affected, err := engine.Exec("update user set age = ? where name = ?", age, name)
```
* 插入一条或者多条记录
* `Insert` 插入一条或者多条记录
```Go
affected, err := engine.Insert(&user)
// INSERT INTO struct () values ()
affected, err := engine.Insert(&user1, &user2)
// INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values ()
affected, err := engine.Insert(&users)
// INSERT INTO struct () values (),(),()
affected, err := engine.Insert(&user1, &users)
// INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values (),(),()
```
* 查询单条记录
* `Get` 查询单条记录
```Go
has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1
has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ?
var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ?
```
* 检测记录是否存在
* `Exist` 检测记录是否存在
```Go
has, err := testEngine.Exist(new(RecordExist))
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Exist(&RecordExist{
Name: "test1",
})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{})
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
// select * from record_exist where name = ?
has, err = testEngine.Table("record_exist").Exist()
// SELECT * FROM record_exist LIMIT 1
has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
// SELECT * FROM record_exist WHERE name = ? LIMIT 1
```
* 查询多条记录当然可以使用Join和extends来组合使用
* `Find` 查询多条记录当然可以使用Join和extends来组合使用
```Go
var users []User
err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users)
// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10
// SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0
type Detail struct {
Id int64
@ -206,10 +216,10 @@ err := engine.Table("user").Select("user.*, detail.*")
Join("INNER", "detail", "detail.user_id = user.id").
Where("user.name = ?", name).Limit(10, 0).
Find(&users)
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0
```
* 根据条件遍历数据库,可以有两种方式: Iterate and Rows
* `Iterate``Rows` 根据条件遍历数据库,可以有两种方式: Iterate and Rows
```Go
err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
@ -218,6 +228,13 @@ err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
})
// SELECT * FROM user
err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean interface{}) error {
user := bean.(*User)
return nil
})
// SELECT * FROM user Limit 0, 100
// SELECT * FROM user Limit 101, 100
rows, err := engine.Rows(&User{Name:name})
// SELECT * FROM user
defer rows.Close()
@ -227,10 +244,10 @@ for rows.Next() {
}
```
* 更新数据除非使用Cols,AllCols函数指明默认只更新非空和非0的字段
* `Update` 更新数据除非使用Cols,AllCols函数指明默认只更新非空和非0的字段
```Go
affected, err := engine.Id(1).Update(&user)
affected, err := engine.ID(1).Update(&user)
// UPDATE user SET ... Where id = ?
affected, err := engine.Update(&user, &User{Name:name})
@ -241,31 +258,50 @@ affected, err := engine.In(ids).Update(&user)
// UPDATE user SET ... Where id IN (?, ?, ?)
// force update indicated columns by Cols
affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
// force NOT update indicated columns by Omit
affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
affected, err := engine.Id(1).AllCols().Update(&user)
affected, err := engine.ID(1).AllCols().Update(&user)
// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ?
```
* 删除记录需要注意删除必须至少有一个条件否则会报错。要清空数据库可以用EmptyTable
* `Delete` 删除记录需要注意删除必须至少有一个条件否则会报错。要清空数据库可以用EmptyTable
```Go
affected, err := engine.Where(...).Delete(&user)
// DELETE FROM user Where ...
affected, err := engine.ID(2).Delete(&user)
// DELETE FROM user Where id = ?
```
* 获取记录条数
* `Count` 获取记录条数
```Go
counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user
```
* `Sum` 求和函数
```Go
agesFloat64, err := engine.Sum(&user, "age")
// SELECT sum(age) AS total FROM user
agesInt64, err := engine.SumInt(&user, "age")
// SELECT sum(age) AS total FROM user
sumFloat64Slice, err := engine.Sums(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
sumInt64Slice, err := engine.SumsInt(&user, "age", "score")
// SELECT sum(age), sum(score) FROM user
```
* 条件编辑器
```Go
@ -273,6 +309,132 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
// SELECT id, name ... FROM user WHERE a NOT IN (?, ?) AND b IN (?, ?, ?)
```
* 在一个Go程中多次操作数据库但没有事务
```Go
session := engine.NewSession()
defer session.Close()
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
return nil
```
* 在一个Go程中有事务
```Go
session := engine.NewSession()
defer session.Close()
// add Begin() before any action
if err := session.Begin(); err != nil {
// if returned then will rollback automatically
return err
}
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return err
}
// add Commit() after all actions
return session.Commit()
```
* 事物的简写方法
```Go
res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) {
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil {
return nil, err
}
user2 := Userinfo{Username: "yyy"}
if _, err := session.Where("id = ?", 2).Update(&user2); err != nil {
return nil, err
}
if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil {
return nil, err
}
return nil, nil
})
```
* Context Cache, if enabled, current query result will be cached on session and be used by next same statement on the same session.
```Go
sess := engine.NewSession()
defer sess.Close()
var context = xorm.NewMemoryContextCache()
var c2 ContextGetStruct
has, err := sess.ID(1).ContextCache(context).Get(&c2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c2.Id)
assert.EqualValues(t, "1", c2.Name)
sql, args := sess.LastSQL()
assert.True(t, len(sql) > 0)
assert.True(t, len(args) > 0)
var c3 ContextGetStruct
has, err = sess.ID(1).ContextCache(context).Get(&c3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c3.Id)
assert.EqualValues(t, "1", c3.Name)
sql, args = sess.LastSQL()
assert.True(t, len(sql) == 0)
assert.True(t, len(args) == 0)
```
## 贡献
如果您也想为Xorm贡献您的力量请查看 [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md)。您也可以加入QQ群 技术帮助和讨论。
群一280360085 (已满)
群二795010183
## Credits
### Contributors
感谢所有的贡献者. [[Contribute](CONTRIBUTING.md)].
<a href="graphs/contributors"><img src="https://opencollective.com/xorm/contributors.svg?width=890&button=false" /></a>
### Backers
感谢我们所有的 backers! 🙏 [[成为 backer](https://opencollective.com/xorm#backer)]
<a href="https://opencollective.com/xorm#backers" target="_blank"><img src="https://opencollective.com/xorm/backers.svg?width=890"></a>
### Sponsors
成为 sponsor 来支持 xorm。您的 logo 将会被显示并被链接到您的网站。 [[成为 sponsor](https://opencollective.com/xorm#sponsor)]
# 案例
* [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang)
@ -307,13 +469,30 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e"))
* [go-blog](http://wangcheng.me) - [github.com/easykoo/go-blog](https://github.com/easykoo/go-blog)
## 讨论
请加入QQ群280360085 进行讨论。
## 更新日志
## 贡献
* **v0.7.0**
* 修正部分Bug
如果您也想为Xorm贡献您的力量请查看 [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md)
* **v0.6.6**
* 修正部分Bug
* **v0.6.5**
* 通过 engine.SetSchema 来支持 schema当前仅支持Postgres
* vgo 支持
* 新增 `FindAndCount` 函数
* 通过 `NewEngineWithParams` 支持数据库特别参数
* 修正部分Bug
* **v0.6.4**
* 自动读写分离支持
* Query/QueryString/QueryInterface 支持与 Where/And 合用
* `Get` 支持获取非结构体变量
* `Iterate` 支持 `BufferSize`
* 修正部分Bug
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
## LICENSE

View file

@ -17,11 +17,12 @@ database:
- createdb -p 5432 -e -U postgres xorm_test1
- createdb -p 5432 -e -U postgres xorm_test2
- createdb -p 5432 -e -U postgres xorm_test3
- psql xorm_test postgres -c "create schema xorm"
test:
override:
# './...' is a relative pattern which means all subdirectories
- go get -u github.com/wadey/gocovmerge;
- go get -u github.com/wadey/gocovmerge
- go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic
- go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic
- go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic
@ -30,7 +31,9 @@ test:
- go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt > coverage.txt
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh

26
vendor/github.com/go-xorm/xorm/context.go generated vendored Normal file
View file

@ -0,0 +1,26 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package xorm
import "context"
// PingContext tests if database is alive
func (engine *Engine) PingContext(ctx context.Context) error {
session := engine.NewSession()
defer session.Close()
return session.PingContext(ctx)
}
// PingContext test if database is ok
func (session *Session) PingContext(ctx context.Context) error {
if session.isAutoClose {
defer session.Close()
}
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().PingContext(ctx)
}

30
vendor/github.com/go-xorm/xorm/context_cache.go generated vendored Normal file
View file

@ -0,0 +1,30 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
// ContextCache is the interface that operates the cache data.
type ContextCache interface {
// Put puts value into cache with key.
Put(key string, val interface{})
// Get gets cached value by given key.
Get(key string) interface{}
}
type memoryContextCache map[string]interface{}
// NewMemoryContextCache return memoryContextCache
func NewMemoryContextCache() memoryContextCache {
return make(map[string]interface{})
}
// Put puts value into cache with key.
func (m memoryContextCache) Put(key string, val interface{}) {
m[key] = val
}
// Get gets cached value by given key.
func (m memoryContextCache) Get(key string) interface{} {
return m[key]
}

View file

@ -209,10 +209,10 @@ func convertAssign(dest, src interface{}) error {
if src == nil {
dv.Set(reflect.Zero(dv.Type()))
return nil
} else {
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
}
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())

View file

@ -172,12 +172,33 @@ type mysql struct {
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
rowFormat string
}
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
}
func (db *mysql) SetParams(params map[string]string) {
rowFormat, ok := params["rowFormat"]
if ok {
var t = strings.ToUpper(rowFormat)
switch t {
case "COMPACT":
fallthrough
case "REDUNDANT":
fallthrough
case "DYNAMIC":
fallthrough
case "COMPRESSED":
db.rowFormat = t
break
default:
break
}
}
}
func (db *mysql) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
@ -487,6 +508,62 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil
}
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
sql += db.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db)
} else {
sql += col.StringNoPk(db)
}
sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
if storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if len(charset) == 0 {
charset = db.URI().Charset
}
if len(charset) != 0 {
sql += " DEFAULT CHARSET " + charset
}
if db.rowFormat != "" {
sql += " ROW_FORMAT=" + db.rowFormat
}
return sql
}
func (db *mysql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}}
}

View file

@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"net/url"
"sort"
"strconv"
"strings"
@ -765,14 +764,26 @@ var (
"YES": true,
"ZONE": true,
}
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
)
const postgresPublicSchema = "public"
type postgres struct {
core.Base
}
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
err := db.Base.Init(d, db, uri, drivername, dataSourceName)
if err != nil {
return err
}
if db.Schema == "" {
db.Schema = DefaultPostgresSchema
}
return nil
}
func (db *postgres) SqlType(c *core.Column) string {
@ -869,32 +880,42 @@ func (db *postgres) IndexOnTable() bool {
}
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName}
if len(db.Schema) == 0 {
args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
}
args := []interface{}{db.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
}
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
if len(db.Schema) == 0 {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
}
args := []interface{}{db.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
}
/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args
}*/
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
if len(db.Schema) == 0 {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
}
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.Schema, tableName, col.Name, db.SqlType(col))
}
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
//var unique string
quote := db.Quote
idxName := index.Name
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType {
@ -903,13 +924,21 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
}
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
args := []interface{}{tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
args := []interface{}{db.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3"
if len(db.Schema) == 0 {
args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
}
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...)
@ -922,8 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
}
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
// FIXME: the schema should be replaced by user custom's
args := []interface{}{tableName, "public"}
args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -934,7 +962,15 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;`
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
var f string
if len(db.Schema) != 0 {
args = append(args, db.Schema)
f = " AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@ -1024,9 +1060,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att
}
func (db *postgres) GetTables() ([]*core.Table, error) {
// FIXME: replace public to user customrize schema
args := []interface{}{"public"}
s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1")
args := []interface{}{}
s := "SELECT tablename FROM pg_tables"
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " WHERE schemaname = $1"
}
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@ -1050,9 +1090,12 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
}
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
// FIXME: replace the public schema to user specify schema
args := []interface{}{"public", tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2")
args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " AND schemaname=$2"
}
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@ -1117,10 +1160,6 @@ func (vs values) Get(k string) (v string) {
return vs[k]
}
func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
@ -1131,46 +1170,18 @@ func parseURL(connstr string) (string, error) {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" {
accrue("dbname", u.Path[1:])
return escaper.Replace(u.Path[1:]), nil
}
q := u.Query()
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
return "", nil
}
func parseOpts(name string, o values) {
func parseOpts(name string, o values) error {
if len(name) == 0 {
return
return fmt.Errorf("invalid options: %s", name)
}
name = strings.TrimSpace(name)
@ -1179,31 +1190,48 @@ func parseOpts(name string, o values) {
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
return fmt.Errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
return nil
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES}
o := make(values)
var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
dataSourceName, err = parseURL(dataSourceName)
db.DbName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
} else {
o := make(values)
err = parseOpts(dataSourceName, o)
if err != nil {
return nil, err
}
}
parseOpts(dataSourceName, o)
db.DbName = o.Get("dbname")
db.DbName = o.Get("dbname")
}
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
/*db.Schema = o.Get("schema")
if len(db.Schema) == 0 {
db.Schema = "public"
}*/
return db, nil
}
type pqDriverPgx struct {
pqDriver
}
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) {
// Remove the leading characters for driver to work
if len(dataSourceName) >= 9 && dataSourceName[0] == 0 {
dataSourceName = dataSourceName[9:]
}
return pgx.pqDriver.Parse(driverName, dataSourceName)
}

View file

@ -233,7 +233,7 @@ func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
}
func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string {
//var unique string
// var unique string
quote := db.Quote
idxName := index.Name
@ -452,5 +452,9 @@ type sqlite3Driver struct {
}
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
if strings.Contains(dataSourceName, "?") {
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]
}
return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil
}

View file

@ -47,6 +47,52 @@ type Engine struct {
disableGlobalCache bool
tagHandlers map[string]tagHandler
engineGroup *EngineGroup
cachers map[string]core.Cacher
cacherLock sync.RWMutex
}
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
engine.cacherLock.Lock()
engine.cachers[tableName] = cacher
engine.cacherLock.Unlock()
}
func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) {
engine.setCacher(tableName, cacher)
}
func (engine *Engine) getCacher(tableName string) core.Cacher {
var cacher core.Cacher
var ok bool
engine.cacherLock.RLock()
cacher, ok = engine.cachers[tableName]
engine.cacherLock.RUnlock()
if !ok && !engine.disableGlobalCache {
cacher = engine.Cacher
}
return cacher
}
func (engine *Engine) GetCacher(tableName string) core.Cacher {
return engine.getCacher(tableName)
}
// BufferSize sets buffer size for iterate
func (engine *Engine) BufferSize(size int) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.BufferSize(size)
}
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(colName string) builder.Cond {
if engine.dialect.DBType() == core.MSSQL {
return builder.IsNull{colName}
}
return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1})
}
// ShowSQL show SQL statement or not on logger if log level is great than INFO
@ -79,6 +125,11 @@ func (engine *Engine) SetLogger(logger core.ILogger) {
engine.dialect.SetLogger(logger)
}
// SetLogLevel sets the logger level
func (engine *Engine) SetLogLevel(level core.LogLevel) {
engine.logger.SetLevel(level)
}
// SetDisableGlobalCache disable global cache or not
func (engine *Engine) SetDisableGlobalCache(disable bool) {
if engine.disableGlobalCache != disable {
@ -126,6 +177,14 @@ func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr()
}
func (engine *Engine) quoteColumns(columnStr string) string {
columns := strings.Split(columnStr, ",")
for i := 0; i < len(columns); i++ {
columns[i] = engine.Quote(strings.TrimSpace(columns[i]))
}
return strings.Join(columns, ",")
}
// Quote Use QuoteStr quote the string sql
func (engine *Engine) Quote(value string) string {
value = strings.TrimSpace(value)
@ -143,7 +202,7 @@ func (engine *Engine) Quote(value string) string {
}
// QuoteTo quotes string and writes into the buffer
func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) {
func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
if buf == nil {
return
}
@ -186,6 +245,11 @@ func (engine *Engine) AutoIncrStr() string {
return engine.dialect.AutoIncrStr()
}
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
engine.db.SetConnMaxLifetime(d)
}
// SetMaxOpenConns is only available for go 1.2+
func (engine *Engine) SetMaxOpenConns(conns int) {
engine.db.SetMaxOpenConns(conns)
@ -201,6 +265,11 @@ func (engine *Engine) SetDefaultCacher(cacher core.Cacher) {
engine.Cacher = cacher
}
// GetDefaultCacher returns the default cacher
func (engine *Engine) GetDefaultCacher() core.Cacher {
return engine.Cacher
}
// NoCache If you has set default cacher, and you want temporilly stop use cache,
// you can use NoCache()
func (engine *Engine) NoCache() *Session {
@ -218,13 +287,7 @@ func (engine *Engine) NoCascade() *Session {
// MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
v := rValue(bean)
tb, err := engine.autoMapType(v)
if err != nil {
return err
}
tb.Cacher = cacher
engine.setCacher(engine.TableName(bean, true), cacher)
return nil
}
@ -509,33 +572,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return nil
}
func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
v := rValue(beanOrTableName)
if v.Type().Kind() == reflect.String {
return beanOrTableName.(string), nil
} else if v.Type().Kind() == reflect.Struct {
return engine.tbName(v), nil
}
return "", errors.New("bean should be a struct or struct's point")
}
func (engine *Engine) tbName(v reflect.Value) string {
if tb, ok := v.Interface().(TableName); ok {
return tb.TableName()
}
if v.Type().Kind() == reflect.Ptr {
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
return tb.TableName()
}
} else if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return tb.TableName()
}
}
return engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name())
}
// Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession()
@ -736,6 +772,13 @@ func (engine *Engine) OrderBy(order string) *Session {
return session.OrderBy(order)
}
// Prepare enables prepare statement
func (engine *Engine) Prepare() *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.Prepare()
}
// Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session {
session := engine.NewSession()
@ -757,7 +800,8 @@ func (engine *Engine) Having(conditions string) *Session {
return session.Having(conditions)
}
func (engine *Engine) unMapType(t reflect.Type) {
// UnMapType removes the datbase mapper of a type
func (engine *Engine) UnMapType(t reflect.Type) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
delete(engine.Tables, t)
@ -811,7 +855,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
if err != nil {
engine.logger.Error(err)
}
return &Table{tb, engine.tbName(v)}
return &Table{tb, engine.TableName(bean)}
}
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@ -826,15 +870,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i
}
}
func (engine *Engine) newTable() *core.Table {
table := core.NewEmptyTable()
if !engine.disableGlobalCache {
table.Cacher = engine.Cacher
}
return table
}
// TableName table name interface to define customerize table name
type TableName interface {
TableName() string
@ -846,21 +881,9 @@ var (
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type()
table := engine.newTable()
if tb, ok := v.Interface().(TableName); ok {
table.Name = tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
table.Name = tb.TableName()
}
}
if table.Name == "" {
table.Name = engine.TableMapper.Obj2Table(t.Name())
}
}
table := core.NewEmptyTable()
table.Type = t
table.Name = engine.tbNameForMap(v)
var idFieldColName string
var hasCacheTag, hasNoCacheTag bool
@ -914,7 +937,7 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
}
if pStart > -1 {
if !strings.HasSuffix(k, ")") {
return nil, errors.New("cannot match ) charactor")
return nil, fmt.Errorf("field %s tag %s cannot match ) charactor", col.FieldName, key)
}
ctx.tagName = k[:pStart]
@ -1014,15 +1037,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
if hasCacheTag {
if engine.Cacher != nil { // !nash! use engine's cacher if provided
engine.logger.Info("enable cache on table:", table.Name)
table.Cacher = engine.Cacher
engine.setCacher(table.Name, engine.Cacher)
} else {
engine.logger.Info("enable LRU cache on table:", table.Name)
table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now
engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000))
}
}
if hasNoCacheTag {
engine.logger.Info("no cache on table:", table.Name)
table.Cacher = nil
engine.logger.Info("disable cache on table:", table.Name)
engine.setCacher(table.Name, nil)
}
return table, nil
@ -1081,7 +1104,25 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
var err error
pkField := v.FieldByName(col.FieldName)
fieldName := col.FieldName
for {
parts := strings.SplitN(fieldName, ".", 2)
if len(parts) == 1 {
break
}
v = v.FieldByName(parts[0])
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, ErrUnSupportedType
}
fieldName = parts[1]
}
pkField := v.FieldByName(fieldName)
switch pkField.Kind() {
case reflect.String:
pk[i], err = engine.idTypeAssertion(col, pkField.String())
@ -1127,26 +1168,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
return session.CreateUniques(bean)
}
func (engine *Engine) getCacher2(table *core.Table) core.Cacher {
return table.Cacher
}
// ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
v := rValue(bean)
t := v.Type()
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
tableName := engine.TableName(bean)
cacher := engine.getCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.DelBean(tableName, id)
@ -1157,21 +1182,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
// ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans {
v := rValue(bean)
t := v.Type()
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
tableName := engine.TableName(bean)
cacher := engine.getCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
@ -1189,13 +1201,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans {
v := rValue(bean)
tableName := engine.tbName(v)
tableNameNoSchema := engine.TableName(bean)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
isExist, err := session.Table(bean).isTableExist(tableName)
isExist, err := session.Table(bean).isTableExist(tableNameNoSchema)
if err != nil {
return err
}
@ -1221,12 +1233,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
} else {
for _, col := range table.Columns() {
isExist, err := engine.dialect.IsColumnExist(tableName, col.Name)
isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name)
if err != nil {
return err
}
if !isExist {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
err = session.addColumn(col.Name)
@ -1237,35 +1249,35 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
for name, index := range table.Indexes {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
if index.Type == core.UniqueType {
isExist, err := session.isIndexExist2(tableName, index.Cols, true)
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
if err != nil {
return err
}
if !isExist {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
err = session.addUnique(tableName, name)
err = session.addUnique(tableNameNoSchema, name)
if err != nil {
return err
}
}
} else if index.Type == core.IndexType {
isExist, err := session.isIndexExist2(tableName, index.Cols, false)
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
if err != nil {
return err
}
if !isExist {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
err = session.addIndex(tableName, name)
err = session.addIndex(tableNameNoSchema, name)
if err != nil {
return err
}
@ -1334,31 +1346,31 @@ func (engine *Engine) DropIndexes(bean interface{}) error {
}
// Exec raw sql
func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) {
func (engine *Engine) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
session := engine.NewSession()
defer session.Close()
return session.Exec(sql, args...)
return session.Exec(sqlorArgs...)
}
// Query a raw sql and return records as []map[string][]byte
func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
func (engine *Engine) Query(sqlorArgs ...interface{}) (resultsSlice []map[string][]byte, err error) {
session := engine.NewSession()
defer session.Close()
return session.Query(sql, paramStr...)
return session.Query(sqlorArgs...)
}
// QueryString runs a raw sql and return records as []map[string]string
func (engine *Engine) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
func (engine *Engine) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) {
session := engine.NewSession()
defer session.Close()
return session.QueryString(sqlStr, args...)
return session.QueryString(sqlorArgs...)
}
// QueryInterface runs a raw sql and return records as []map[string]interface{}
func (engine *Engine) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) {
func (engine *Engine) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) {
session := engine.NewSession()
defer session.Close()
return session.QueryInterface(sqlStr, args...)
return session.QueryInterface(sqlorArgs...)
}
// Insert one or more records
@ -1418,6 +1430,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
return session.Find(beans, condiBeans...)
}
// FindAndCount find the results and also return the counts
func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.FindAndCount(rowsSlicePtr, condiBean...)
}
// Iterate record by record handle records from table, bean's non-empty fields
// are conditions.
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error {
@ -1564,24 +1583,44 @@ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}
return
}
// GetColumnMapper returns the column name mapper
func (engine *Engine) GetColumnMapper() core.IMapper {
return engine.ColumnMapper
}
// GetTableMapper returns the table name mapper
func (engine *Engine) GetTableMapper() core.IMapper {
return engine.TableMapper
}
// GetTZLocation returns time zone of the application
func (engine *Engine) GetTZLocation() *time.Location {
return engine.TZLocation
}
// SetTZLocation sets time zone of the application
func (engine *Engine) SetTZLocation(tz *time.Location) {
engine.TZLocation = tz
}
// GetTZDatabase returns time zone of the database
func (engine *Engine) GetTZDatabase() *time.Location {
return engine.DatabaseTZ
}
// SetTZDatabase sets time zone of the database
func (engine *Engine) SetTZDatabase(tz *time.Location) {
engine.DatabaseTZ = tz
}
// SetSchema sets the schema of database
func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().Schema = schema
}
// Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.Unscoped()
}
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(colName string) builder.Cond {
if engine.dialect.DBType() == core.MSSQL {
return builder.IsNull{colName}
}
return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1})
}
// BufferSize sets buffer size for iterate
func (engine *Engine) BufferSize(size int) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.BufferSize(size)
}

View file

@ -9,6 +9,7 @@ import (
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
"github.com/go-xorm/builder"
@ -51,7 +52,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
engine.logger.Error(err)
if !strings.Contains(err.Error(), "is not valid") {
engine.logger.Warn(err)
}
continue
}

204
vendor/github.com/go-xorm/xorm/engine_group.go generated vendored Normal file
View file

@ -0,0 +1,204 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"time"
"github.com/go-xorm/core"
)
// EngineGroup defines an engine group
type EngineGroup struct {
*Engine
slaves []*Engine
policy GroupPolicy
}
// NewEngineGroup creates a new engine group
func NewEngineGroup(args1 interface{}, args2 interface{}, policies ...GroupPolicy) (*EngineGroup, error) {
var eg EngineGroup
if len(policies) > 0 {
eg.policy = policies[0]
} else {
eg.policy = RoundRobinPolicy()
}
driverName, ok1 := args1.(string)
conns, ok2 := args2.([]string)
if ok1 && ok2 {
engines := make([]*Engine, len(conns))
for i, conn := range conns {
engine, err := NewEngine(driverName, conn)
if err != nil {
return nil, err
}
engine.engineGroup = &eg
engines[i] = engine
}
eg.Engine = engines[0]
eg.slaves = engines[1:]
return &eg, nil
}
master, ok3 := args1.(*Engine)
slaves, ok4 := args2.([]*Engine)
if ok3 && ok4 {
master.engineGroup = &eg
for i := 0; i < len(slaves); i++ {
slaves[i].engineGroup = &eg
}
eg.Engine = master
eg.slaves = slaves
return &eg, nil
}
return nil, ErrParamsType
}
// Close the engine
func (eg *EngineGroup) Close() error {
err := eg.Engine.Close()
if err != nil {
return err
}
for i := 0; i < len(eg.slaves); i++ {
err := eg.slaves[i].Close()
if err != nil {
return err
}
}
return nil
}
// Master returns the master engine
func (eg *EngineGroup) Master() *Engine {
return eg.Engine
}
// Ping tests if database is alive
func (eg *EngineGroup) Ping() error {
if err := eg.Engine.Ping(); err != nil {
return err
}
for _, slave := range eg.slaves {
if err := slave.Ping(); err != nil {
return err
}
}
return nil
}
// SetColumnMapper set the column name mapping rule
func (eg *EngineGroup) SetColumnMapper(mapper core.IMapper) {
eg.Engine.ColumnMapper = mapper
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ColumnMapper = mapper
}
}
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
func (eg *EngineGroup) SetConnMaxLifetime(d time.Duration) {
eg.Engine.SetConnMaxLifetime(d)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetConnMaxLifetime(d)
}
}
// SetDefaultCacher set the default cacher
func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) {
eg.Engine.SetDefaultCacher(cacher)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetDefaultCacher(cacher)
}
}
// SetLogger set the new logger
func (eg *EngineGroup) SetLogger(logger core.ILogger) {
eg.Engine.SetLogger(logger)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogger(logger)
}
}
// SetLogLevel sets the logger level
func (eg *EngineGroup) SetLogLevel(level core.LogLevel) {
eg.Engine.SetLogLevel(level)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogLevel(level)
}
}
// SetMapper set the name mapping rules
func (eg *EngineGroup) SetMapper(mapper core.IMapper) {
eg.Engine.SetMapper(mapper)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetMapper(mapper)
}
}
// SetMaxIdleConns set the max idle connections on pool, default is 2
func (eg *EngineGroup) SetMaxIdleConns(conns int) {
eg.Engine.db.SetMaxIdleConns(conns)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].db.SetMaxIdleConns(conns)
}
}
// SetMaxOpenConns is only available for go 1.2+
func (eg *EngineGroup) SetMaxOpenConns(conns int) {
eg.Engine.db.SetMaxOpenConns(conns)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].db.SetMaxOpenConns(conns)
}
}
// SetPolicy set the group policy
func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup {
eg.policy = policy
return eg
}
// SetTableMapper set the table name mapping rule
func (eg *EngineGroup) SetTableMapper(mapper core.IMapper) {
eg.Engine.TableMapper = mapper
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].TableMapper = mapper
}
}
// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO
func (eg *EngineGroup) ShowExecTime(show ...bool) {
eg.Engine.ShowExecTime(show...)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ShowExecTime(show...)
}
}
// ShowSQL show SQL statement or not on logger if log level is great than INFO
func (eg *EngineGroup) ShowSQL(show ...bool) {
eg.Engine.ShowSQL(show...)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ShowSQL(show...)
}
}
// Slave returns one of the physical databases which is a slave according the policy
func (eg *EngineGroup) Slave() *Engine {
switch len(eg.slaves) {
case 0:
return eg.Engine
case 1:
return eg.slaves[0]
}
return eg.policy.Slave(eg)
}
// Slaves returns all the slaves
func (eg *EngineGroup) Slaves() []*Engine {
return eg.slaves
}

116
vendor/github.com/go-xorm/xorm/engine_group_policy.go generated vendored Normal file
View file

@ -0,0 +1,116 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"math/rand"
"sync"
"time"
)
// GroupPolicy is be used by chosing the current slave from slaves
type GroupPolicy interface {
Slave(*EngineGroup) *Engine
}
// GroupPolicyHandler should be used when a function is a GroupPolicy
type GroupPolicyHandler func(*EngineGroup) *Engine
// Slave implements the chosen of slaves
func (h GroupPolicyHandler) Slave(eg *EngineGroup) *Engine {
return h(eg)
}
// RandomPolicy implmentes randomly chose the slave of slaves
func RandomPolicy() GroupPolicyHandler {
var r = rand.New(rand.NewSource(time.Now().UnixNano()))
return func(g *EngineGroup) *Engine {
return g.Slaves()[r.Intn(len(g.Slaves()))]
}
}
// WeightRandomPolicy implmentes randomly chose the slave of slaves
func WeightRandomPolicy(weights []int) GroupPolicyHandler {
var rands = make([]int, 0, len(weights))
for i := 0; i < len(weights); i++ {
for n := 0; n < weights[i]; n++ {
rands = append(rands, i)
}
}
var r = rand.New(rand.NewSource(time.Now().UnixNano()))
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
idx := rands[r.Intn(len(rands))]
if idx >= len(slaves) {
idx = len(slaves) - 1
}
return slaves[idx]
}
}
func RoundRobinPolicy() GroupPolicyHandler {
var pos = -1
var lock sync.Mutex
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
lock.Lock()
defer lock.Unlock()
pos++
if pos >= len(slaves) {
pos = 0
}
return slaves[pos]
}
}
func WeightRoundRobinPolicy(weights []int) GroupPolicyHandler {
var rands = make([]int, 0, len(weights))
for i := 0; i < len(weights); i++ {
for n := 0; n < weights[i]; n++ {
rands = append(rands, i)
}
}
var pos = -1
var lock sync.Mutex
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
lock.Lock()
defer lock.Unlock()
pos++
if pos >= len(rands) {
pos = 0
}
idx := rands[pos]
if idx >= len(slaves) {
idx = len(slaves) - 1
}
return slaves[idx]
}
}
// LeastConnPolicy implements GroupPolicy, every time will get the least connections slave
func LeastConnPolicy() GroupPolicyHandler {
return func(g *EngineGroup) *Engine {
var slaves = g.Slaves()
connections := 0
idx := 0
for i := 0; i < len(slaves); i++ {
openConnections := slaves[i].DB().Stats().OpenConnections
if i == 0 {
connections = openConnections
idx = i
} else if openConnections <= connections {
connections = openConnections
idx = i
}
}
return slaves[idx]
}
}

View file

@ -1,14 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.6
package xorm
import "time"
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
engine.db.SetConnMaxLifetime(d)
}

113
vendor/github.com/go-xorm/xorm/engine_table.go generated vendored Normal file
View file

@ -0,0 +1,113 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"reflect"
"strings"
"github.com/go-xorm/core"
)
// TableNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
if len(includeSchema) > 0 && includeSchema[0] {
tbName = engine.tbNameWithSchema(tbName)
}
return tbName
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
func (engine *Engine) tbNameForMap(v reflect.Value) string {
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
}
return engine.TableMapper.Obj2Table(v.Type().Name())
}
func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
return engine.Quote(t[0])
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = engine.tbNameForMap(v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return engine.Quote(table)
}
case TableName:
return tablename.(TableName).TableName()
case string:
return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return engine.tbNameForMap(v)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return engine.tbNameForMap(v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}
return ""
}

View file

@ -6,21 +6,44 @@ package xorm
import (
"errors"
"fmt"
)
var (
// ErrParamsType params error
ErrParamsType = errors.New("Params type error")
// ErrTableNotFound table not found error
ErrTableNotFound = errors.New("Not found table")
ErrTableNotFound = errors.New("Table not found")
// ErrUnSupportedType unsupported error
ErrUnSupportedType = errors.New("Unsupported type error")
// ErrNotExist record is not exist error
ErrNotExist = errors.New("Not exist error")
// ErrNotExist record does not exist error
ErrNotExist = errors.New("Record does not exist")
// ErrCacheFailed cache failed error
ErrCacheFailed = errors.New("Cache failed")
// ErrNeedDeletedCond delete needs less one condition error
ErrNeedDeletedCond = errors.New("Delete need at least one condition")
ErrNeedDeletedCond = errors.New("Delete action needs at least one condition")
// ErrNotImplemented not implemented
ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported condition type")
)
// ErrFieldIsNotExist columns does not exist
type ErrFieldIsNotExist struct {
FieldName string
TableName string
}
func (e ErrFieldIsNotExist) Error() string {
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName)
}
// ErrFieldIsNotValid is not valid
type ErrFieldIsNotValid struct {
FieldName string
TableName string
}
func (e ErrFieldIsNotValid) Error() string {
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName)
}

24
vendor/github.com/go-xorm/xorm/go.mod generated vendored Normal file
View file

@ -0,0 +1,24 @@
module github.com/go-xorm/xorm
require (
github.com/cockroachdb/apd v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/denisenkom/go-mssqldb v0.0.0-20181014144952-4e0d7dc8888f
github.com/go-sql-driver/mysql v1.4.0
github.com/go-xorm/builder v0.3.2
github.com/go-xorm/core v0.6.0
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a // indirect
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect
github.com/jackc/pgx v3.2.0+incompatible
github.com/kr/pretty v0.1.0 // indirect
github.com/lib/pq v1.0.0
github.com/mattn/go-sqlite3 v1.9.0
github.com/pkg/errors v0.8.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/satori/go.uuid v1.2.0 // indirect
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect
github.com/stretchr/testify v1.2.2
github.com/ziutek/mymysql v1.5.4
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/stretchr/testify.v1 v1.2.2
)

43
vendor/github.com/go-xorm/xorm/go.sum generated vendored Normal file
View file

@ -0,0 +1,43 @@
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.0.0-20181014144952-4e0d7dc8888f h1:WH0w/R4Yoey+04HhFxqZ6VX6I0d7RMyw5aXQ9UTvQPs=
github.com/denisenkom/go-mssqldb v0.0.0-20181014144952-4e0d7dc8888f/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc=
github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-xorm/builder v0.3.2 h1:pSsZQRRzJNapKEAEhigw3xLmiLPeAYv5GFlpYZ8+a5I=
github.com/go-xorm/builder v0.3.2/go.mod h1:v8mE3MFBgtL+RGFNfUnAMUqqfk/Y4W5KuwCFQIEpQLk=
github.com/go-xorm/core v0.6.0 h1:tp6hX+ku4OD9khFZS8VGBDRY3kfVCtelPfmkgCyHxL0=
github.com/go-xorm/core v0.6.0/go.mod h1:d8FJ9Br8OGyQl12MCclmYBuBqqxsyeedpXciV5Myih8=
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:9wScpmSP5A3Bk8V3XHWUcJmYTh+ZnlHVyc+A4oZYS3Y=
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:56xuuqnHyryaerycW3BfssRdxQstACi0Epw/yC5E2xM=
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 h1:vr3AYkKovP8uR8AvSGGUK1IDqRa5lAAvEkZG1LKaCRc=
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ=
github.com/jackc/pgx v3.2.0+incompatible h1:0Vihzu20St42/UDsvZGdNE6jak7oi/UOeMzwMPHkgFY=
github.com/jackc/pgx v3.2.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/4=
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs=
github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/stretchr/testify.v1 v1.2.2 h1:yhQC6Uy5CqibAIlk1wlusa/MJ3iAN49/BsR/dCCKz3M=
gopkg.in/stretchr/testify.v1 v1.2.2/go.mod h1:QI5V/q6UbPmuhtm10CaFZxED9NreB8PnFYN9JcR6TxU=

View file

@ -11,7 +11,6 @@ import (
"sort"
"strconv"
"strings"
"time"
"github.com/go-xorm/core"
)
@ -293,19 +292,6 @@ func structName(v reflect.Type) string {
return v.Name()
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
func sliceEq(left, right []string) bool {
if len(left) != len(right) {
return false
@ -320,154 +306,6 @@ func sliceEq(left, right []string) bool {
return true
}
func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t)
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t))
}
}
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
if col.IsDeleted {
continue
}
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
if includeQuote {
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}
}
return colNames, args, nil
}
func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}

116
vendor/github.com/go-xorm/xorm/interface.go generated vendored Normal file
View file

@ -0,0 +1,116 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql"
"reflect"
"time"
"github.com/go-xorm/core"
)
// Interface defines the interface which Engine, EngineGroup and Session will implementate.
type Interface interface {
AllCols() *Session
Alias(alias string) *Session
Asc(colNames ...string) *Session
BufferSize(size int) *Session
Cols(columns ...string) *Session
Count(...interface{}) (int64, error)
CreateIndexes(bean interface{}) error
CreateUniques(bean interface{}) error
Decr(column string, arg ...interface{}) *Session
Desc(...string) *Session
Delete(interface{}) (int64, error)
Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error
Exec(sqlOrAgrs ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error
FindAndCount(interface{}, ...interface{}) (int64, error)
Get(interface{}) (bool, error)
GroupBy(keys string) *Session
ID(interface{}) *Session
In(string, ...interface{}) *Session
Incr(column string, arg ...interface{}) *Session
Insert(...interface{}) (int64, error)
InsertOne(interface{}) (int64, error)
IsTableEmpty(bean interface{}) (bool, error)
IsTableExist(beanOrTableName interface{}) (bool, error)
Iterate(interface{}, IterFunc) error
Limit(int, ...int) *Session
MustCols(columns ...string) *Session
NoAutoCondition(...bool) *Session
NotIn(string, ...interface{}) *Session
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session
Omit(columns ...string) *Session
OrderBy(order string) *Session
Ping() error
Query(sqlOrAgrs ...interface{}) (resultsSlice []map[string][]byte, err error)
QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error)
QueryString(sqlorArgs ...interface{}) ([]map[string]string, error)
Rows(bean interface{}) (*Rows, error)
SetExpr(string, string) *Session
SQL(interface{}, ...interface{}) *Session
Sum(bean interface{}, colName string) (float64, error)
SumInt(bean interface{}, colName string) (int64, error)
Sums(bean interface{}, colNames ...string) ([]float64, error)
SumsInt(bean interface{}, colNames ...string) ([]int64, error)
Table(tableNameOrBean interface{}) *Session
Unscoped() *Session
Update(bean interface{}, condiBeans ...interface{}) (int64, error)
UseBool(...string) *Session
Where(interface{}, ...interface{}) *Session
}
// EngineInterface defines the interface which Engine, EngineGroup will implementate.
type EngineInterface interface {
Interface
Before(func(interface{})) *Session
Charset(charset string) *Session
ClearCache(...interface{}) error
CreateTables(...interface{}) error
DBMetas() ([]*core.Table, error)
Dialect() core.Dialect
DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...core.DbType) error
GetCacher(string) core.Cacher
GetColumnMapper() core.IMapper
GetDefaultCacher() core.Cacher
GetTableMapper() core.IMapper
GetTZDatabase() *time.Location
GetTZLocation() *time.Location
MapCacher(interface{}, core.Cacher) error
NewSession() *Session
NoAutoTime() *Session
Quote(string) string
SetCacher(string, core.Cacher)
SetConnMaxLifetime(time.Duration)
SetDefaultCacher(core.Cacher)
SetLogger(logger core.ILogger)
SetLogLevel(core.LogLevel)
SetMapper(core.IMapper)
SetMaxOpenConns(int)
SetMaxIdleConns(int)
SetSchema(string)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
ShowExecTime(...bool)
ShowSQL(show ...bool)
Sync(...interface{}) error
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table
TableName(interface{}, ...bool) string
UnMapType(reflect.Type)
}
var (
_ Interface = &Session{}
_ EngineInterface = &Engine{}
_ EngineInterface = &EngineGroup{}
)

View file

@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
var args []interface{}
var err error
if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
if err = rows.session.statement.setRefBean(bean); err != nil {
return nil, err
}
@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
}
dataStruct := rValue(bean)
if err := rows.session.statement.setRefValue(dataStruct); err != nil {
if err := rows.session.statement.setRefBean(bean); err != nil {
return err
}
@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return err
}
dataStruct := rValue(bean)
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil {
return err

View file

@ -76,6 +76,7 @@ func (session *Session) Init() {
session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0)
session.beforeClosures = make([]func(interface{}), 0)
session.afterClosures = make([]func(interface{}), 0)
session.stmtCache = make(map[uint32]*core.Stmt)
session.afterProcessors = make([]executedProcessor, 0)
@ -101,6 +102,12 @@ func (session *Session) Close() {
}
}
// ContextCache enable context cache or not
func (session *Session) ContextCache(context ContextCache) *Session {
session.statement.context = context
return session
}
// IsClosed returns if session is closed
func (session *Session) IsClosed() bool {
return session.db == nil
@ -262,13 +269,13 @@ func (session *Session) canCache() bool {
return true
}
func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) {
func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, err error) {
crc := crc32.ChecksumIEEE([]byte(sqlStr))
// TODO try hash(sqlStr+len(sqlStr))
var has bool
stmt, has = session.stmtCache[crc]
if !has {
stmt, err = session.DB().Prepare(sqlStr)
stmt, err = db.Prepare(sqlStr)
if err != nil {
return nil, err
}
@ -277,24 +284,22 @@ func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) {
return
}
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value {
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) {
var col *core.Column
if col = table.GetColumnIdx(key, idx); col == nil {
//session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq())
return nil
return nil, ErrFieldIsNotExist{key, table.Name}
}
fieldValue, err := col.ValueOfV(dataStruct)
if err != nil {
session.engine.logger.Error(err)
return nil
return nil, err
}
if !fieldValue.IsValid() || !fieldValue.CanSet() {
session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
return nil
return nil, ErrFieldIsNotValid{key, table.Name}
}
return fieldValue
return fieldValue, nil
}
// Cell cell is a result of one column field
@ -406,405 +411,417 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
tempMap[lKey] = idx
if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil {
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
// if row is null then ignore
if rawValue.Interface() == nil {
continue
fieldValue, err := session.getField(dataStruct, key, table, idx)
if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
session.engine.logger.Warn(err)
}
continue
}
if fieldValue == nil {
continue
}
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if err := structConvert.FromDB(data); err != nil {
return nil, err
}
} else {
// if row is null then ignore
if rawValue.Interface() == nil {
continue
}
if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if err := structConvert.FromDB(data); err != nil {
return nil, err
}
continue
}
}
if _, ok := fieldValue.Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue.Interface().(core.Conversion).FromDB(data)
} else {
return nil, err
}
continue
}
}
rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
col := table.GetColumnIdx(key, idx)
if col.IsPrimaryKey {
pk = append(pk, rawValue.Interface())
if _, ok := fieldValue.Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue.Interface().(core.Conversion).FromDB(data)
} else {
return nil, err
}
fieldType := fieldValue.Type()
hasAssigned := false
continue
}
if col.SQLType.IsJson() {
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
bs = vv.Bytes()
rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
col := table.GetColumnIdx(key, idx)
if col.IsPrimaryKey {
pk = append(pk, rawValue.Interface())
}
fieldType := fieldValue.Type()
hasAssigned := false
if col.SQLType.IsJson() {
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
bs = vv.Bytes()
} else {
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
}
hasAssigned = true
if len(bs) > 0 {
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
continue
}
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
}
hasAssigned = true
if len(bs) > 0 {
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
continue
}
switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128:
// TODO: reimplement this
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
bs = vv.Bytes()
}
continue
}
hasAssigned = true
if len(bs) > 0 {
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128:
// TODO: reimplement this
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
bs = vv.Bytes()
}
hasAssigned = true
if len(bs) > 0 {
if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
}
case reflect.Slice, reflect.Array:
switch rawValueType.Kind() {
case reflect.Slice, reflect.Array:
switch rawValueType.Kind() {
case reflect.Slice, reflect.Array:
switch rawValueType.Elem().Kind() {
case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 {
hasAssigned = true
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
} else {
if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
}
}
}
}
}
}
case reflect.String:
if rawValueType.Kind() == reflect.String {
hasAssigned = true
fieldValue.SetString(vv.String())
}
case reflect.Bool:
if rawValueType.Kind() == reflect.Bool {
hasAssigned = true
fieldValue.SetBool(vv.Bool())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch rawValueType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetInt(vv.Int())
}
case reflect.Float32, reflect.Float64:
switch rawValueType.Kind() {
case reflect.Float32, reflect.Float64:
hasAssigned = true
fieldValue.SetFloat(vv.Float())
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
hasAssigned = true
fieldValue.SetUint(vv.Uint())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetUint(uint64(vv.Int()))
}
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
}
if rawValueType == core.TimeType {
switch rawValueType.Elem().Kind() {
case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 {
hasAssigned = true
t := vv.Convert(core.TimeType).Interface().(time.Time)
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type {
hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else {
if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else if d, ok := vv.Interface().(string); ok {
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else {
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if col.SQLType.IsJson() {
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 {
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
}
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil {
return nil, err
}
hasAssigned = true
if len(table.PrimaryKeys) != 1 {
return nil, errors.New("unsupported non or composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, rawValueType)
if err != nil {
return nil, err
}
if !isPKZero(pk) {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
if has {
fieldValue.Set(structInter.Elem())
} else {
return nil, errors.New("cascade obj is not exist")
if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
}
}
}
}
}
case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType:
if rawValueType.Kind() == reflect.String {
x := vv.String()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrBoolType:
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrTimeType:
if rawValueType == core.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Complex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
case core.Complex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
} // switch fieldType
} // switch fieldType.Kind()
}
case reflect.String:
if rawValueType.Kind() == reflect.String {
hasAssigned = true
fieldValue.SetString(vv.String())
}
case reflect.Bool:
if rawValueType.Kind() == reflect.Bool {
hasAssigned = true
fieldValue.SetBool(vv.Bool())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch rawValueType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetInt(vv.Int())
}
case reflect.Float32, reflect.Float64:
switch rawValueType.Kind() {
case reflect.Float32, reflect.Float64:
hasAssigned = true
fieldValue.SetFloat(vv.Float())
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
hasAssigned = true
fieldValue.SetUint(vv.Uint())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetUint(uint64(vv.Int()))
}
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
}
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value
if !hasAssigned {
data, err := value2Bytes(&rawValue)
if rawValueType == core.TimeType {
hasAssigned = true
t := vv.Convert(core.TimeType).Interface().(time.Time)
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type {
hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else {
if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else if d, ok := vv.Interface().(string); ok {
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else {
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if col.SQLType.IsJson() {
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 {
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
}
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
if err != nil {
return nil, err
}
if err = session.bytes2Value(col, fieldValue, data); err != nil {
hasAssigned = true
if len(table.PrimaryKeys) != 1 {
return nil, errors.New("unsupported non or composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, rawValueType)
if err != nil {
return nil, err
}
if !isPKZero(pk) {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
if has {
fieldValue.Set(structInter.Elem())
} else {
return nil, errors.New("cascade obj is not exist")
}
}
}
case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType:
if rawValueType.Kind() == reflect.String {
x := vv.String()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrBoolType:
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrTimeType:
if rawValueType == core.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Complex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
case core.Complex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
} // switch fieldType
} // switch fieldType.Kind()
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value
if !hasAssigned {
data, err := value2Bytes(&rawValue)
if err != nil {
return nil, err
}
if err = session.bytes2Value(col, fieldValue, data); err != nil {
return nil, err
}
}
}
@ -823,15 +840,6 @@ func (session *Session) LastSQL() (string, []interface{}) {
return session.lastSQL, session.lastSQLArgs
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
// Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session {
session.statement.Unscoped()

View file

@ -4,6 +4,121 @@
package xorm
import (
"reflect"
"strings"
"time"
"github.com/go-xorm/core"
)
type incrParam struct {
colName string
arg interface{}
}
type decrParam struct {
colName string
arg interface{}
}
type exprParam struct {
colName string
expr string
}
type columnMap []string
func (m columnMap) contain(colName string) bool {
if len(m) == 0 {
return false
}
n := len(colName)
for _, mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, colName) {
return true
}
}
return false
}
func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
return false
}
*m = append(*m, colName)
return true
}
func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t)
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t))
}
}
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Incr provides a query string like "count = count + 1"
func (session *Session) Incr(column string, arg ...interface{}) *Session {
session.statement.Incr(column, arg...)

View file

@ -34,27 +34,27 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil {
x = time.Unix(sd, 0)
session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == core.Time {
if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ")
@ -68,7 +68,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
//session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
outErr = fmt.Errorf("unsupported time format %v", sdata)
return

View file

@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed
}
cacher := session.engine.getCacher2(table)
cacher := session.engine.getCacher(tableName)
pkColumns := table.PKColumns()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.Close()
}
if err := session.statement.setRefValue(rValue(bean)); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return 0, err
}
@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
})
}
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
}

View file

@ -10,6 +10,7 @@ import (
"reflect"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
// Exist returns true if the record exist otherwise return false
@ -35,10 +36,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, err
}
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s WHERE %s", tableName, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
}
args = condArgs
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s", tableName)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
}
args = []interface{}{}
}
} else {
@ -48,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
if err := session.statement.setRefBean(bean[0]); err != nil {
return false, err
}
}

View file

@ -29,6 +29,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return session.find(rowsSlicePtr, condiBean...)
}
// FindAndCount find the results and also return the counts
func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
session.autoResetStatement = false
err := session.find(rowsSlicePtr, condiBean...)
if err != nil {
return 0, err
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return 0, errors.New("needs a pointer to a slice or a map")
}
sliceElementType := sliceValue.Type().Elem()
if sliceElementType.Kind() == reflect.Ptr {
sliceElementType = sliceElementType.Elem()
}
session.autoResetStatement = true
if session.statement.selectStr != "" {
session.statement.selectStr = ""
}
if session.statement.OrderStr != "" {
session.statement.OrderStr = ""
}
return session.Count(reflect.New(sliceElementType).Interface())
}
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
@ -42,7 +75,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem())
if err := session.statement.setRefValue(pv.Elem()); err != nil {
if err := session.statement.setRefValue(pv); err != nil {
return err
}
} else {
@ -50,7 +83,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}
} else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType)
if err := session.statement.setRefValue(pv.Elem()); err != nil {
if err := session.statement.setRefValue(pv); err != nil {
return err
}
} else {
@ -102,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
@ -110,7 +143,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else {
columnStr = "*"
}
@ -128,7 +161,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}
args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return err
}
@ -143,7 +176,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}
if session.canCache() {
if cacher := session.engine.getCacher2(table); cacher != nil &&
if cacher := session.engine.getCacher(table.Name); cacher != nil &&
!session.statement.IsDistinct &&
!session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
@ -288,6 +321,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed
}
tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
if cacher == nil {
return nil
}
for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
@ -297,9 +336,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed
}
tableName := session.statement.TableName()
table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
rows, err := session.queryRows(newsql, args...)

View file

@ -5,7 +5,9 @@
package xorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
@ -30,7 +32,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return false, err
}
}
@ -56,7 +58,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
table := session.statement.RefTable
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.getCacher2(table); cacher != nil &&
if cacher := session.engine.getCacher(table.Name); cacher != nil &&
!session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed {
@ -65,7 +67,28 @@ func (session *Session) get(bean interface{}) (bool, error) {
}
}
return session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
context := session.statement.context
if context != nil {
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
if res != nil {
structValue := reflect.Indirect(reflect.ValueOf(bean))
structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
session.lastSQL = ""
session.lastSQLArgs = nil
return true, nil
}
}
has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
if err != nil || !has {
return has, err
}
if context != nil {
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean)
}
return true, nil
}
func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
@ -76,9 +99,19 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return false, rows.Err()
}
return false, nil
}
switch bean.(type) {
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString:
return true, rows.Scan(&bean)
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString:
return true, rows.Scan(bean)
}
switch beanKind {
case reflect.Struct:
fields, err := rows.Columns()
@ -126,8 +159,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed
}
cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)

View file

@ -66,11 +66,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice")
}
if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
@ -115,15 +116,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted {
continue
}
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
continue
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col)
@ -170,15 +167,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted {
continue
}
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
continue
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col)
@ -211,38 +204,33 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
}
cleanupProcessorsClosures(&session.beforeClosures)
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
var statement string
var tableName = session.statement.TableName()
var sql string
if session.engine.dialect.DBType() == core.ORACLE {
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr())
statement = fmt.Sprintf(sql,
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, temp))
} else {
statement = fmt.Sprintf(sql,
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(statement, args...)
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)
lenAfterClosures := len(session.afterClosures)
for i := 0; i < size; i++ {
@ -298,7 +286,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
}
func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.setRefValue(rValue(bean)); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {
@ -316,8 +304,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
processor.BeforeInsert()
}
// --
colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false)
colNames, args, err := session.genInsertColumns(bean)
if err != nil {
return 0, err
}
@ -400,11 +388,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err
}
handleAfterInsertProcessorFunc(bean)
defer handleAfterInsertProcessorFunc(bean)
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
@ -445,11 +431,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err != nil {
return 0, err
}
handleAfterInsertProcessorFunc(bean)
defer handleAfterInsertProcessorFunc(bean)
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
@ -490,9 +474,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
defer handleAfterInsertProcessorFunc(bean)
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
session.cacheInsert(tableName)
if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
@ -539,16 +521,104 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return session.innerInsert(bean)
}
func (session *Session) cacheInsert(table *core.Table, tables ...string) error {
if table == nil {
return ErrCacheFailed
func (session *Session) cacheInsert(table string) error {
if !session.statement.UseCache {
return nil
}
cacher := session.engine.getCacher2(table)
for _, t := range tables {
session.engine.logger.Debug("[cache] clear sql:", t)
cacher.ClearIds(t)
cacher := session.engine.getCacher(table)
if cacher == nil {
return nil
}
session.engine.logger.Debug("[cache] clear sql:", table)
cacher.ClearIds(table)
return nil
}
// genInsertColumns generates insert needed columns
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if col.MapType == core.ONLYFROMDB {
continue
}
if col.IsDeleted {
continue
}
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
continue
}
if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
colNames = append(colNames, col.Name)
}
return colNames, args, nil
}

View file

@ -8,17 +8,86 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) {
if len(sqlorArgs) > 0 {
return convertSQLOrArgs(sqlorArgs...)
}
if session.statement.RawSQL != "" {
return session.statement.RawSQL, session.statement.RawParams, nil
}
if len(session.statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = session.statement.ColumnStr
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
if err := session.statement.processIDParam(); err != nil {
return "", nil, err
}
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return "", nil, err
}
args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
func (session *Session) Query(sqlorArgs ...interface{}) ([]map[string][]byte, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
return session.queryBytes(sqlStr, args...)
}
@ -97,6 +166,34 @@ func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string,
return result, nil
}
func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) {
result := make([]string, 0, len(fields))
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for i := 0; i < len(fields); i++ {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i]))
// if row is null then as empty string
if rawValue.Interface() == nil {
result = append(result, "")
continue
}
if data, err := value2String(&rawValue); err == nil {
result = append(result, data)
} else {
return nil, err
}
}
return result, nil
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns()
if err != nil {
@ -113,12 +210,33 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error)
return resultsSlice, nil
}
func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
record, err := row2sliceStr(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, record)
}
return resultsSlice, nil
}
// QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
func (session *Session) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
@ -128,6 +246,26 @@ func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[s
return rows2Strings(rows)
}
// QuerySliceString runs a raw sql and return records as [][]string
func (session *Session) QuerySliceString(sqlorArgs ...interface{}) ([][]string, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2SliceString(rows)
}
func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) {
resultsMap = make(map[string]interface{}, len(fields))
scanResultContainers := make([]interface{}, len(fields))
@ -162,11 +300,16 @@ func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, er
}
// QueryInterface runs a raw sql and return records as []map[string]interface{}
func (session *Session) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) {
func (session *Session) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) {
if session.isAutoClose {
defer session.Close()
}
sqlStr, args, err := session.genQuerySQL(sqlorArgs...)
if err != nil {
return nil, err
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err

View file

@ -9,6 +9,7 @@ import (
"reflect"
"time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
@ -47,9 +48,16 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
}
if session.isAutoCommit {
var db *core.DB
if session.engine.engineGroup != nil {
db = session.engine.engineGroup.Slave().DB()
} else {
db = session.DB()
}
if session.prepareStmt {
// don't clear stmt since session will cache them
stmt, err := session.doPrepare(sqlStr)
stmt, err := session.doPrepare(db, sqlStr)
if err != nil {
return nil, err
}
@ -61,7 +69,7 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
return rows, nil
}
rows, err := session.DB().Query(sqlStr, args...)
rows, err := db.Query(sqlStr, args...)
if err != nil {
return nil, err
}
@ -171,7 +179,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
}
if session.prepareStmt {
stmt, err := session.doPrepare(sqlStr)
stmt, err := session.doPrepare(session.DB(), sqlStr)
if err != nil {
return nil, err
}
@ -186,11 +194,34 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
return session.DB().Exec(sqlStr, args...)
}
func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {
switch sqlorArgs[0].(type) {
case string:
return sqlorArgs[0].(string), sqlorArgs[1:], nil
case *builder.Builder:
return sqlorArgs[0].(*builder.Builder).ToSQL()
case builder.Builder:
bd := sqlorArgs[0].(builder.Builder)
return bd.ToSQL()
}
return "", nil, ErrUnSupportedType
}
// Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
func (session *Session) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
if session.isAutoClose {
defer session.Close()
}
if len(sqlorArgs) == 0 {
return nil, ErrUnSupportedType
}
sqlStr, args, err := convertSQLOrArgs(sqlorArgs...)
if err != nil {
return nil, err
}
return session.exec(sqlStr, args...)
}

View file

@ -6,9 +6,7 @@ package xorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"github.com/go-xorm/core"
@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error {
}
func (session *Session) createTable(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error {
}
func (session *Session) createIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error {
}
func (session *Session) createUniques(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error {
}
func (session *Session) dropIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return err
}
tableName := session.engine.TableName(beanOrTableName)
var needDrop = true
if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
}
if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(tableName)
_, err = session.exec(sqlStr)
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
_, err := session.exec(sqlStr)
return err
}
return nil
@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close()
}
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return false, err
}
tableName := session.engine.TableName(beanOrTableName)
return session.isTableExist(tableName)
}
@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) {
// IsTableEmpty if table have any records
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
v := rValue(bean)
t := v.Type()
if t.Kind() == reflect.String {
if session.isAutoClose {
defer session.Close()
}
return session.isTableEmpty(bean.(string))
} else if t.Kind() == reflect.Struct {
rows, err := session.Count(bean)
return rows == 0, err
if session.isAutoClose {
defer session.Close()
}
return false, errors.New("bean should be a struct or struct's point")
return session.isTableEmpty(session.engine.TableName(bean))
}
func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
err := session.queryRow(sqlStr).Scan(&total)
if err != nil {
if err == sql.ErrNoRows {
@ -255,6 +233,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
var structTables []*core.Table
for _, bean := range beans {
@ -264,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err
}
structTables = append(structTables, table)
var tbName = session.tbNameNoSchema(table)
tbName := engine.TableName(bean)
tbNameWithSchema := engine.TableName(tbName, true)
var oriTable *core.Table
for _, tb := range tables {
@ -309,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType)
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbName, col.Name, curType, expectedType)
tbNameWithSchema, col.Name, curType, expectedType)
}
}
} else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
}
}
}
@ -348,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
} else {
session.statement.RefTable = table
session.statement.tableName = tbName
session.statement.tableName = tbNameWithSchema
err = session.addColumn(col.Name)
}
if err != nil {
@ -371,7 +356,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex)
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return err
@ -387,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2)
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return err
@ -398,12 +383,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames {
if index.Type == core.UniqueType {
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addUnique(tbName, name)
session.statement.tableName = tbNameWithSchema
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType {
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addIndex(tbName, name)
session.statement.tableName = tbNameWithSchema
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return err
@ -428,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for _, colName := range table.ColumnsSeq() {
if oriTable.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName)
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName)
}
}
}

View file

@ -24,6 +24,7 @@ func (session *Session) Rollback() error {
if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL(session.engine.dialect.RollBackStr())
session.isCommitedOrRollbacked = true
session.isAutoCommit = true
return session.tx.Rollback()
}
return nil
@ -34,6 +35,7 @@ func (session *Session) Commit() error {
if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL("COMMIT")
session.isCommitedOrRollbacked = true
session.isAutoCommit = true
var err error
if err = session.tx.Commit(); err == nil {
// handle processors after tx committed

View file

@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
}
}
cacher := session.engine.getCacher2(table)
cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil {
@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct
if isStruct {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return 0, err
}
@ -176,12 +176,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
if session.statement.ColumnStr == "" {
colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false,
false, false, session.statement.allUseBool, session.statement.useAllCols,
session.statement.mustColumnMap, session.statement.nullableMap,
session.statement.columnMap, true, session.statement.unscoped)
colNames, args = session.statement.buildUpdates(bean, false, false,
false, false, true)
} else {
colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true)
colNames, args, err = session.genUpdateColumns(bean)
if err != nil {
return 0, err
}
@ -202,7 +200,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table := session.statement.RefTable
if session.statement.UseAutoTime && table != nil && table.Updated != "" {
if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok {
if !session.statement.columnMap.contain(table.Updated) &&
!session.statement.omitColumnMap.contain(table.Updated) {
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn()
val, t := session.engine.nowTime(col)
@ -242,10 +241,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var autoCond builder.Cond
if !session.statement.noAutoCondition && len(condiBean) > 0 {
var err error
autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
if c, ok := condiBean[0].(map[string]interface{}); ok {
autoCond = builder.Eq(c)
} else {
ct := reflect.TypeOf(condiBean[0])
k := ct.Kind()
if k == reflect.Ptr {
k = ct.Elem().Kind()
}
if k == reflect.Struct {
var err error
autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}
} else {
return 0, ErrConditionType
}
}
}
@ -349,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
if table != nil {
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
}
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...)
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
}
// handle after update processors
@ -389,3 +400,92 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return res.RowsAffected()
}
func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated {
continue
}
if len(session.statement.columnMap) > 0 {
if !session.statement.columnMap.contain(col.Name) {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
}
return colNames, args, nil
}

View file

@ -5,7 +5,6 @@
package xorm
import (
"bytes"
"database/sql/driver"
"encoding/json"
"errors"
@ -18,21 +17,6 @@ import (
"github.com/go-xorm/core"
)
type incrParam struct {
colName string
arg interface{}
}
type decrParam struct {
colName string
arg interface{}
}
type exprParam struct {
colName string
expr string
}
// Statement save all the sql info for executing SQL
type Statement struct {
RefTable *core.Table
@ -47,7 +31,6 @@ type Statement struct {
HavingStr string
ColumnStr string
selectStr string
columnMap map[string]bool
useAllCols bool
OmitStr string
AltTableName string
@ -67,6 +50,8 @@ type Statement struct {
allUseBool bool
checkVersion bool
unscoped bool
columnMap columnMap
omitColumnMap columnMap
mustColumnMap map[string]bool
nullableMap map[string]bool
incrColumns map[string]incrParam
@ -74,6 +59,7 @@ type Statement struct {
exprColumns map[string]exprParam
cond builder.Cond
bufferSize int
context ContextCache
}
// Init reset all the statement's fields
@ -89,7 +75,8 @@ func (statement *Statement) Init() {
statement.HavingStr = ""
statement.ColumnStr = ""
statement.OmitStr = ""
statement.columnMap = make(map[string]bool)
statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{}
statement.AltTableName = ""
statement.tableName = ""
statement.idParam = nil
@ -113,6 +100,7 @@ func (statement *Statement) Init() {
statement.exprColumns = make(map[string]exprParam)
statement.cond = builder.NewCond()
statement.bufferSize = 0
statement.context = nil
}
// NoAutoCondition if you do not want convert bean's field as query condition, then use this function
@ -160,6 +148,9 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
case string:
cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := builder.Eq(query.(map[string]interface{}))
statement.cond = statement.cond.And(cond)
case builder.Cond:
cond := query.(builder.Cond)
statement.cond = statement.cond.And(cond)
@ -181,6 +172,9 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
case string:
cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := builder.Eq(query.(map[string]interface{}))
statement.cond = statement.cond.Or(cond)
case builder.Cond:
cond := query.(builder.Cond)
statement.cond = statement.cond.Or(cond)
@ -215,34 +209,33 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil {
return err
}
statement.tableName = statement.Engine.tbName(v)
statement.tableName = statement.Engine.TableName(v, true)
return nil
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.AltTableName = statement.Engine.tbName(v)
func (statement *Statement) setRefBean(bean interface{}) error {
var err error
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
if err != nil {
return err
}
return statement
statement.tableName = statement.Engine.TableName(bean, true)
return nil
}
// Auto generating update columnes and values according a struct
func buildUpdates(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool,
mustColumnMap map[string]bool, nullableMap map[string]bool,
columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) {
func (statement *Statement) buildUpdates(bean interface{},
includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) ([]string, []interface{}) {
engine := statement.Engine
table := statement.RefTable
allUseBool := statement.allUseBool
useAllCols := statement.useAllCols
mustColumnMap := statement.mustColumnMap
nullableMap := statement.nullableMap
columnMap := statement.columnMap
omitColumnMap := statement.omitColumnMap
unscoped := statement.unscoped
var colNames = make([]string, 0)
var args = make([]interface{}, 0)
@ -262,7 +255,14 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if col.IsDeleted && !unscoped {
continue
}
if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use {
if omitColumnMap.contain(col.Name) {
continue
}
if len(columnMap) > 0 && !columnMap.contain(col.Name) {
continue
}
if col.MapType == core.ONLYFROMDB {
continue
}
@ -598,17 +598,10 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
}
func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, 0, len(statement.columnMap))
for col := range statement.columnMap {
fields := strings.Split(strings.TrimSpace(col), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
copy(newColumns, statement.columnMap)
for i := 0; i < len(statement.columnMap); i++ {
newColumns[i] = statement.Engine.Quote(newColumns[i])
}
return newColumns
}
@ -636,10 +629,11 @@ func (statement *Statement) Select(str string) *Statement {
func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...)
for _, nc := range cols {
statement.columnMap[strings.ToLower(nc)] = true
statement.columnMap.add(nc)
}
newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
return statement
@ -674,7 +668,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement {
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.columnMap[strings.ToLower(nc)] = false
statement.omitColumnMap = append(statement.omitColumnMap, nc)
}
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
}
@ -713,10 +707,9 @@ func (statement *Statement) OrderBy(order string) *Statement {
// Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement {
var buf bytes.Buffer
fmt.Fprintf(&buf, statement.OrderStr)
var buf builder.StringBuilder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, ", ")
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
newColNames := statement.col2NewColsWithQuote(colNames...)
fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
@ -726,10 +719,9 @@ func (statement *Statement) Desc(colNames ...string) *Statement {
// Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement {
var buf bytes.Buffer
fmt.Fprintf(&buf, statement.OrderStr)
var buf builder.StringBuilder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, ", ")
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
newColNames := statement.col2NewColsWithQuote(colNames...)
fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
@ -737,48 +729,35 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
}
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
return statement
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf bytes.Buffer
var buf builder.StringBuilder
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.String {
table = f.(string)
} else if t.Kind() == reflect.Struct {
table = statement.Engine.tbName(v)
}
}
if l > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(table))
}
default:
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
}
tbName := statement.Engine.TableName(tablename, true)
fmt.Fprintf(&buf, " ON %v", condition)
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
@ -803,18 +782,20 @@ func (statement *Statement) Unscoped() *Statement {
}
func (statement *Statement) genColumnStr() string {
var buf bytes.Buffer
if statement.RefTable == nil {
return ""
}
var buf builder.StringBuilder
columns := statement.RefTable.Columns()
for _, col := range columns {
if statement.OmitStr != "" {
if _, ok := getFlagForColumn(statement.columnMap, col); ok {
continue
}
if statement.omitColumnMap.contain(col.Name) {
continue
}
if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
continue
}
if col.MapType == core.ONLYTODB {
@ -825,10 +806,6 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(", ")
}
if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
buf.WriteString("id() AS ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
@ -853,11 +830,13 @@ func (statement *Statement) genCreateTableSQL() string {
func (statement *Statement) genIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
quote := statement.Engine.Quote
for idxName, index := range statement.RefTable.Indexes {
for _, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))
sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sqls = append(sqls, sql)
}
}
@ -883,16 +862,18 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes {
var rIdxName string
if index.Type == core.UniqueType {
rIdxName = uniqueName(tbName, idxName)
rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == core.IndexType {
rIdxName = indexName(tbName, idxName)
rIdxName = indexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
}
sqls = append(sqls, sql)
}
@ -901,8 +882,12 @@ func (statement *Statement) genDelIndexSQL() []string {
func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
quote := statement.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
col.String(statement.Engine.dialect))
if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ";"
return sql, []interface{}{}
}
@ -939,7 +924,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
statement.setRefValue(v)
statement.setRefBean(bean)
}
var columnStr = statement.ColumnStr
@ -950,7 +935,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
@ -958,7 +943,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
} else {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
}
}
}
@ -972,13 +957,17 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
} else {
if err := statement.processIDParam(); err != nil {
return "", nil, err
}
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
@ -991,7 +980,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
var condArgs []interface{}
var err error
if len(beans) > 0 {
statement.setRefValue(rValue(beans[0]))
statement.setRefBean(beans[0])
condSQL, condArgs, err = statement.genConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
@ -1008,7 +997,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)"
}
}
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
if err != nil {
return "", nil, err
}
@ -1017,7 +1006,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
}
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
statement.setRefValue(rValue(bean))
statement.setRefBean(bean)
var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
@ -1033,7 +1022,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err
}
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
if err != nil {
return "", nil, err
}
@ -1041,27 +1030,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
var distinct string
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
var (
distinct string
dialect = statement.Engine.Dialect()
quote = statement.Engine.Quote
fromStr = " FROM "
top, mssqlCondi, whereStr string
)
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
}
var dialect = statement.Engine.Dialect()
var quote = statement.Engine.Quote
var top string
var mssqlCondi string
if err := statement.processIDParam(); err != nil {
return "", err
}
var buf bytes.Buffer
if len(condSQL) > 0 {
fmt.Fprintf(&buf, " WHERE %v", condSQL)
whereStr = " WHERE " + condSQL
}
var whereStr = buf.String()
var fromStr = " FROM "
if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName()
@ -1108,9 +1090,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
}
var orderStr string
if len(statement.OrderStr) > 0 {
if needOrderBy && len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr
}
var groupStr string
if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr
@ -1120,45 +1103,50 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
}
}
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
var buf builder.StringBuilder
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if len(mssqlCondi) > 0 {
if len(whereStr) > 0 {
a += " AND " + mssqlCondi
fmt.Fprint(&buf, " AND ", mssqlCondi)
} else {
a += " WHERE " + mssqlCondi
fmt.Fprint(&buf, " WHERE ", mssqlCondi)
}
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
}
if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr)
fmt.Fprint(&buf, " ", statement.HavingStr)
}
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
if needOrderBy && statement.OrderStr != "" {
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
}
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
oldString := buf.String()
buf.Reset()
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start)
}
}
}
if statement.IsForUpdate {
a = dialect.ForUpdateSql(a)
return dialect.ForUpdateSql(buf.String()), nil
}
return
return buf.String(), nil
}
func (statement *Statement) processIDParam() error {
if statement.idParam == nil {
if statement.idParam == nil || statement.RefTable == nil {
return nil
}

26
vendor/github.com/go-xorm/xorm/transaction.go generated vendored Normal file
View file

@ -0,0 +1,26 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
// Transaction Execute sql wrapped in a transaction(abbr as tx), tx will automatic commit if no errors occurred
func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interface{}, error) {
session := engine.NewSession()
defer session.Close()
if err := session.Begin(); err != nil {
return nil, err
}
result, err := f(session)
if err != nil {
return nil, err
}
if err := session.Commit(); err != nil {
return nil, err
}
return result, nil
}

View file

@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package xorm
import (
@ -17,7 +19,7 @@ import (
const (
// Version show the xorm's version
Version string = "0.6.4.0910"
Version string = "0.7.0.0504"
)
func regDrvsNDialects() bool {
@ -31,7 +33,7 @@ func regDrvsNDialects() bool {
"mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
"pgx": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
"pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
"goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
@ -90,6 +92,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
TagIdentifier: "xorm",
TZLocation: time.Local,
tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher),
}
if uri.DbType == core.SQLITE {
@ -108,6 +111,13 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return engine, nil
}
// NewEngineWithParams new a db manager with params. The params will be passed to dialect.
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) {
engine, err := NewEngine(driverName, dataSourceName)
engine.dialect.SetParams(params)
return engine, err
}
// Clone clone an engine
func (engine *Engine) Clone() (*Engine, error) {
return NewEngine(engine.DriverName(), engine.DataSourceName())

Some files were not shown because too many files have changed in this diff Show more