diff --git a/go.mod b/go.mod index 5990fce7..dd82a0d5 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ module code.vikunja.io/api require ( 4d63.com/tz v1.2.0 - code.vikunja.io/web v0.0.0-20200809154828-8767618f181f + code.vikunja.io/web v0.0.0-20201223143420-588abb73703a dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 // indirect gitea.com/xorm/xorm-redis-cache v0.2.0 github.com/adlio/trello v1.8.0 @@ -41,6 +41,7 @@ require ( github.com/go-sql-driver/mysql v1.5.0 github.com/go-testfixtures/testfixtures/v3 v3.4.1 github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 + github.com/golang/snappy v0.0.2 // indirect github.com/gordonklaus/ineffassign v0.0.0-20201107091007-3b93a8888063 github.com/iancoleman/strcase v0.1.2 github.com/imdario/mergo v0.3.11 @@ -52,6 +53,7 @@ require ( github.com/lib/pq v1.9.0 github.com/magefile/mage v1.10.0 github.com/mailru/easyjson v0.7.6 // indirect + github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-sqlite3 v1.14.5 github.com/mitchellh/mapstructure v1.3.2 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect @@ -76,8 +78,10 @@ require ( golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad golang.org/x/image v0.0.0-20201208152932-35266b937fa6 golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 + golang.org/x/net v0.0.0-20201216054612-986b41b23924 // indirect golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a + golang.org/x/sys v0.0.0-20201223074533-0d417f636930 // indirect golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect @@ -91,7 +95,7 @@ require ( src.techknowlogick.com/xormigrate v1.4.0 xorm.io/builder v0.3.7 xorm.io/core v0.7.3 - xorm.io/xorm v1.0.2 + xorm.io/xorm v1.0.5 ) replace ( diff --git a/go.sum b/go.sum index 7457923a..50006f83 100644 --- a/go.sum +++ b/go.sum @@ -38,8 +38,12 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= -code.vikunja.io/web v0.0.0-20200809154828-8767618f181f h1:Zgtk9lbJkGbKjdTC78mg/c2uNkesxDJs1YUIL9zGvco= -code.vikunja.io/web v0.0.0-20200809154828-8767618f181f/go.mod h1:vDWiCtftF6LNCCrem7mjstPWMgzLUvMW/L4YwIQ1Voo= +code.vikunja.io/web v0.0.0-20201218134444-505d0e77fac7 h1:iS3TFA+y1If6DEbqzad5Ge7TI1NxZr9BevC/dU4ygEo= +code.vikunja.io/web v0.0.0-20201218134444-505d0e77fac7/go.mod h1:vDWiCtftF6LNCCrem7mjstPWMgzLUvMW/L4YwIQ1Voo= +code.vikunja.io/web v0.0.0-20201222144643-6fa2fb587215 h1:O5zMWgcnVDVLaQUawgdsv/jX/4SUUAvSedvRR+5+x2o= +code.vikunja.io/web v0.0.0-20201222144643-6fa2fb587215/go.mod h1:OgFO06HN1KpA4S7Dw/QAIeygiUPSeGJJn1ykz/sjZdU= +code.vikunja.io/web v0.0.0-20201223143420-588abb73703a h1:LaWCucY5Pp30EIMgGOvdVFNss5OhIAwrAO8PuFVRUfw= +code.vikunja.io/web v0.0.0-20201223143420-588abb73703a/go.mod h1:OgFO06HN1KpA4S7Dw/QAIeygiUPSeGJJn1ykz/sjZdU= dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 h1:o4lAkfETerCnr1kF9/qwkwjICnU+YLHNDCM8h2xj7as= dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363/go.mod h1:WG7q7swWsS2f9PYpt5DoEP/EBYWx8We5UoRltn9vJl8= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= @@ -152,6 +156,7 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 h1:YcpmyvADGYw5LqMnHqSkyIELsHCGF6PkrmM31V8rF7o= github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204 h1:tI48fqaIkxxYuIylVv1tdDfBp6836GKSfmmzgSyP1CY= github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= @@ -293,6 +298,8 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pO github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw= +github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -491,6 +498,7 @@ github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.8.0 h1:9xohqzkUwzR4Ga4ivdTcawVS89YSDVxXMa3xJX3cGzg= github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8= @@ -516,6 +524,8 @@ github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+v github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= +github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4 h1:bnP0vzxcAdeI1zdubAl5PjU6zsERjGZb7raWodagDYs= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= @@ -852,8 +862,6 @@ golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de h1:ikNHVSjEfnvz6sxdSPCaPt golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201217014255-9d1352758620 h1:3wPMTskHO3+O6jqTEXyFcsnuxMQOqYSaHsDxcbUXpqA= -golang.org/x/crypto v0.0.0-20201217014255-9d1352758620/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -944,6 +952,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2l golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201216054612-986b41b23924 h1:QsnDpLLOKwHBBDa8nDws4DYNc/ryVW2vCpxCs09d4PY= +golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0= @@ -1028,8 +1038,13 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuF golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e h1:AyodaIpKjppX+cBfTASF2E1US3H2JFBj920Ot3rtDjs= golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201221093633-bc327ba9c2f0 h1:n+DPcgTwkgWzIFpLmoimYR2K2b0Ga5+Os4kayIN0vGo= +golang.org/x/sys v0.0.0-20201221093633-bc327ba9c2f0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201223074533-0d417f636930 h1:vRgIt+nup/B/BwIS0g2oC0haq0iqbV3ZA+u6+0TlNCo= +golang.org/x/sys v0.0.0-20201223074533-0d417f636930/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1284,3 +1299,5 @@ xorm.io/xorm v1.0.1 h1:/lITxpJtkZauNpdzj+L9CN/3OQxZaABrbergMcJu+Cw= xorm.io/xorm v1.0.1/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY= xorm.io/xorm v1.0.2 h1:kZlCh9rqd1AzGwWitcrEEqHE1h1eaZE/ujU5/2tWEtg= xorm.io/xorm v1.0.2/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY= +xorm.io/xorm v1.0.5 h1:LRr5PfOUb4ODPR63YwbowkNDwcolT2LnkwP/TUaMaB0= +xorm.io/xorm v1.0.5/go.mod h1:uF9EtbhODq5kNWxMbnBEj8hRRZnlcNSz2t2N7HW/+A4= diff --git a/pkg/cmd/user.go b/pkg/cmd/user.go index 88a82771..9237fc94 100644 --- a/pkg/cmd/user.go +++ b/pkg/cmd/user.go @@ -24,6 +24,7 @@ import ( "strings" "time" + "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/initialize" "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" @@ -31,6 +32,7 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/term" + "xorm.io/xorm" ) var ( @@ -91,13 +93,13 @@ func getPasswordFromFlagOrInput() (pw string) { return } -func getUserFromArg(arg string) *user.User { +func getUserFromArg(s *xorm.Session, arg string) *user.User { id, err := strconv.ParseInt(arg, 10, 64) if err != nil { log.Fatalf("Invalid user id: %s", err) } - u, err := user.GetUserByID(id) + u, err := user.GetUserByID(s, id) if err != nil { log.Fatalf("Could not get user: %s", err) } @@ -116,8 +118,16 @@ var userListCmd = &cobra.Command{ initialize.FullInit() }, Run: func(cmd *cobra.Command, args []string) { - users, err := user.ListUsers("") + s := db.NewSession() + defer s.Close() + + users, err := user.ListUsers(s, "") if err != nil { + _ = s.Rollback() + log.Fatalf("Error getting users: %s", err) + } + + if err := s.Commit(); err != nil { log.Fatalf("Error getting users: %s", err) } @@ -153,21 +163,30 @@ var userCreateCmd = &cobra.Command{ initialize.FullInit() }, Run: func(cmd *cobra.Command, args []string) { + s := db.NewSession() + defer s.Close() + u := &user.User{ Username: userFlagUsername, Email: userFlagEmail, Password: getPasswordFromFlagOrInput(), } - newUser, err := user.CreateUser(u) + newUser, err := user.CreateUser(s, u) if err != nil { + _ = s.Rollback() log.Fatalf("Error creating new user: %s", err) } - err = models.CreateNewNamespaceForUser(newUser) + err = models.CreateNewNamespaceForUser(s, newUser) if err != nil { + _ = s.Rollback() log.Fatalf("Error creating new namespace for user: %s", err) } + if err := s.Commit(); err != nil { + log.Fatalf("Error saving everything: %s", err) + } + fmt.Printf("\nUser was created successfully.\n") }, } @@ -180,7 +199,10 @@ var userUpdateCmd = &cobra.Command{ initialize.FullInit() }, Run: func(cmd *cobra.Command, args []string) { - u := getUserFromArg(args[0]) + s := db.NewSession() + defer s.Close() + + u := getUserFromArg(s, args[0]) if userFlagUsername != "" { u.Username = userFlagUsername @@ -192,11 +214,16 @@ var userUpdateCmd = &cobra.Command{ u.AvatarProvider = userFlagAvatar } - _, err := user.UpdateUser(u) + _, err := user.UpdateUser(s, u) if err != nil { + _ = s.Rollback() log.Fatalf("Error updating the user: %s", err) } + if err := s.Commit(); err != nil { + log.Fatalf("Error saving everything: %s", err) + } + fmt.Println("User updated successfully.") }, } @@ -209,22 +236,31 @@ var userResetPasswordCmd = &cobra.Command{ }, Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - u := getUserFromArg(args[0]) + s := db.NewSession() + defer s.Close() + + u := getUserFromArg(s, args[0]) // By default we reset as usual, only with specific flag directly. if userFlagResetPasswordDirectly { - err := user.UpdateUserPassword(u, getPasswordFromFlagOrInput()) + err := user.UpdateUserPassword(s, u, getPasswordFromFlagOrInput()) if err != nil { + _ = s.Rollback() log.Fatalf("Could not update user password: %s", err) } fmt.Println("Password updated successfully.") } else { - err := user.RequestUserPasswordResetToken(u) + err := user.RequestUserPasswordResetToken(s, u) if err != nil { + _ = s.Rollback() log.Fatalf("Could not send password reset email: %s", err) } fmt.Println("Password reset email sent successfully.") } + + if err := s.Commit(); err != nil { + log.Fatalf("Could not send password reset email: %s", err) + } }, } @@ -236,7 +272,10 @@ var userChangeEnabledCmd = &cobra.Command{ }, Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - u := getUserFromArg(args[0]) + s := db.NewSession() + defer s.Close() + + u := getUserFromArg(s, args[0]) if userFlagEnableUser { u.IsActive = true @@ -245,11 +284,16 @@ var userChangeEnabledCmd = &cobra.Command{ } else { u.IsActive = !u.IsActive } - _, err := user.UpdateUser(u) + _, err := user.UpdateUser(s, u) if err != nil { + _ = s.Rollback() log.Fatalf("Could not enable the user") } + if err := s.Commit(); err != nil { + log.Fatalf("Error saving everything: %s", err) + } + fmt.Printf("User status successfully changed, user is now active: %t.\n", u.IsActive) }, } diff --git a/pkg/db/db.go b/pkg/db/db.go index 542e0cb6..706456d2 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -31,6 +31,7 @@ import ( "xorm.io/core" "xorm.io/xorm" "xorm.io/xorm/caches" + "xorm.io/xorm/schemas" _ "github.com/go-sql-driver/mysql" // Because. _ "github.com/lib/pq" // Because. @@ -211,3 +212,13 @@ func WipeEverything() error { return nil } + +// NewSession creates a new xorm session +func NewSession() *xorm.Session { + return x.NewSession() +} + +// Type returns the db type of the currently configured db +func Type() schemas.DBType { + return x.Dialect().URI().DBType +} diff --git a/pkg/files/files.go b/pkg/files/files.go index d2c290fd..da2cca4a 100644 --- a/pkg/files/files.go +++ b/pkg/files/files.go @@ -22,6 +22,7 @@ import ( "time" "code.vikunja.io/api/pkg/config" + "code.vikunja.io/api/pkg/db" "code.vikunja.io/web" "github.com/c2h5oh/datasize" "github.com/spf13/afero" @@ -93,27 +94,44 @@ func CreateWithMime(f io.Reader, realname string, realsize uint64, a web.Auth, m Mime: mime, } - _, err = x.Insert(file) + s := db.NewSession() + defer s.Close() + + _, err = s.Insert(file) if err != nil { + _ = s.Rollback() return } // Save the file to storage with its new ID as path err = file.Save(f) + if err != nil { + _ = s.Rollback() + return + } return } // Delete removes a file from the DB and the file system func (f *File) Delete() (err error) { - deleted, err := x.Where("id = ?", f.ID).Delete(f) + s := db.NewSession() + defer s.Close() + + deleted, err := s.Where("id = ?", f.ID).Delete(f) if err != nil { + _ = s.Rollback() return err } if deleted == 0 { + _ = s.Rollback() return ErrFileDoesNotExist{FileID: f.ID} } err = afs.Remove(f.getFileName()) + if err != nil { + _ = s.Rollback() + return err + } return } diff --git a/pkg/models/bulk_task.go b/pkg/models/bulk_task.go index d5541ff9..e1a42e0e 100644 --- a/pkg/models/bulk_task.go +++ b/pkg/models/bulk_task.go @@ -19,6 +19,7 @@ package models import ( "code.vikunja.io/web" "github.com/imdario/mergo" + "xorm.io/xorm" ) // BulkTask is the definition of a bulk update task @@ -29,9 +30,9 @@ type BulkTask struct { Task } -func (bt *BulkTask) checkIfTasksAreOnTheSameList() (err error) { +func (bt *BulkTask) checkIfTasksAreOnTheSameList(s *xorm.Session) (err error) { // Get the tasks - err = bt.GetTasksByIDs() + err = bt.GetTasksByIDs(s) if err != nil { return err } @@ -52,16 +53,16 @@ func (bt *BulkTask) checkIfTasksAreOnTheSameList() (err error) { } // CanUpdate checks if a user is allowed to update a task -func (bt *BulkTask) CanUpdate(a web.Auth) (bool, error) { +func (bt *BulkTask) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { - err := bt.checkIfTasksAreOnTheSameList() + err := bt.checkIfTasksAreOnTheSameList(s) if err != nil { return false, err } // A user can update an task if he has write acces to its list l := &List{ID: bt.Tasks[0].ListID} - return l.CanWrite(a) + return l.CanWrite(s, a) } // Update updates a bunch of tasks at once @@ -77,23 +78,14 @@ func (bt *BulkTask) CanUpdate(a web.Auth) (bool, error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its list)" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/bulk [post] -func (bt *BulkTask) Update() (err error) { - - sess := x.NewSession() - defer sess.Close() - - err = sess.Begin() - if err != nil { - return - } - +func (bt *BulkTask) Update(s *xorm.Session) (err error) { for _, oldtask := range bt.Tasks { // When a repeating task is marked as done, we update all deadlines and reminders and set it as undone updateDone(oldtask, &bt.Task) // Update the assignees - if err := oldtask.updateTaskAssignees(sess, bt.Assignees); err != nil { + if err := oldtask.updateTaskAssignees(s, bt.Assignees); err != nil { return err } @@ -109,7 +101,7 @@ func (bt *BulkTask) Update() (err error) { oldtask.Done = false } - _, err = sess.ID(oldtask.ID). + _, err = s.ID(oldtask.ID). Cols("title", "description", "done", @@ -121,15 +113,9 @@ func (bt *BulkTask) Update() (err error) { "end_date"). Update(oldtask) if err != nil { - _ = sess.Rollback() return err } } - err = sess.Commit() - if err != nil { - return - } - return } diff --git a/pkg/models/bulk_task_test.go b/pkg/models/bulk_task_test.go index 1a02ed06..7ead77c5 100644 --- a/pkg/models/bulk_task_test.go +++ b/pkg/models/bulk_task_test.go @@ -57,18 +57,22 @@ func TestBulkTask_Update(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + bt := &BulkTask{ IDs: tt.fields.IDs, Tasks: tt.fields.Tasks, Task: tt.fields.Task, } - allowed, _ := bt.CanUpdate(tt.fields.User) + allowed, _ := bt.CanUpdate(s, tt.fields.User) if !allowed != tt.wantForbidden { t.Errorf("BulkTask.Update() want forbidden, got %v, want %v", allowed, tt.wantForbidden) } - if err := bt.Update(); (err != nil) != tt.wantErr { + if err := bt.Update(s); (err != nil) != tt.wantErr { t.Errorf("BulkTask.Update() error = %v, wantErr %v", err, tt.wantErr) } + + s.Close() }) } } diff --git a/pkg/models/kanban.go b/pkg/models/kanban.go index 47d05231..e24f9361 100644 --- a/pkg/models/kanban.go +++ b/pkg/models/kanban.go @@ -97,14 +97,14 @@ func getDefaultBucket(s *xorm.Session, listID int64) (bucket *Bucket, err error) // @Success 200 {array} models.Bucket "The buckets with their tasks" // @Failure 500 {object} models.Message "Internal server error" // @Router /lists/{id}/buckets [get] -func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (b *Bucket) ReadAll(s *xorm.Session, auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { // Note: I'm ignoring pagination for now since I've yet to figure out a way on how to make it work // I'll probably just don't do it and instead make individual tasks archivable. // Get all buckets for this list buckets := []*Bucket{} - err = x.Where("list_id = ?", b.ListID).Find(&buckets) + err = s.Where("list_id = ?", b.ListID).Find(&buckets) if err != nil { return } @@ -119,7 +119,7 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r // Get all users users := make(map[int64]*user.User) - err = x.In("id", userIDs).Find(&users) + err = s.In("id", userIDs).Find(&users) if err != nil { return } @@ -132,7 +132,7 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r b.TaskCollection.ListID = b.ListID b.TaskCollection.OrderBy = []string{string(orderAscending)} b.TaskCollection.SortBy = []string{taskPropertyPosition} - ts, _, _, err := b.TaskCollection.ReadAll(auth, "", -1, 0) + ts, _, _, err := b.TaskCollection.ReadAll(s, auth, "", -1, 0) if err != nil { return } @@ -168,10 +168,10 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r // @Failure 404 {object} web.HTTPError "The list does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id}/buckets [put] -func (b *Bucket) Create(a web.Auth) (err error) { +func (b *Bucket) Create(s *xorm.Session, a web.Auth) (err error) { b.CreatedByID = a.GetID() - _, err = x.Insert(b) + _, err = s.Insert(b) return } @@ -190,8 +190,8 @@ func (b *Bucket) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "The bucket does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/buckets/{bucketID} [post] -func (b *Bucket) Update() (err error) { - _, err = x.Where("id = ?", b.ID).Update(b) +func (b *Bucket) Update(s *xorm.Session) (err error) { + _, err = s.Where("id = ?", b.ID).Update(b) return } @@ -208,14 +208,11 @@ func (b *Bucket) Update() (err error) { // @Failure 404 {object} web.HTTPError "The bucket does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/buckets/{bucketID} [delete] -func (b *Bucket) Delete() (err error) { - - s := x.NewSession() +func (b *Bucket) Delete(s *xorm.Session) (err error) { // Prevent removing the last bucket total, err := s.Where("list_id = ?", b.ListID).Count(&Bucket{}) if err != nil { - _ = s.Rollback() return } if total <= 1 { @@ -228,23 +225,19 @@ func (b *Bucket) Delete() (err error) { // Remove the bucket itself _, err = s.Where("id = ?", b.ID).Delete(&Bucket{}) if err != nil { - _ = s.Rollback() return } // Get the default bucket defaultBucket, err := getDefaultBucket(s, b.ListID) if err != nil { - _ = s.Rollback() return } // Remove all associations of tasks to that bucket - _, err = s.Where("bucket_id = ?", b.ID).Cols("bucket_id").Update(&Task{BucketID: defaultBucket.ID}) - if err != nil { - _ = s.Rollback() - return - } - - return s.Commit() + _, err = s. + Where("bucket_id = ?", b.ID). + Cols("bucket_id"). + Update(&Task{BucketID: defaultBucket.ID}) + return } diff --git a/pkg/models/kanban_rights.go b/pkg/models/kanban_rights.go index 83881f1e..acf10bb3 100644 --- a/pkg/models/kanban_rights.go +++ b/pkg/models/kanban_rights.go @@ -16,30 +16,33 @@ package models -import "code.vikunja.io/web" +import ( + "code.vikunja.io/web" + "xorm.io/xorm" +) // CanCreate checks if a user can create a new bucket -func (b *Bucket) CanCreate(a web.Auth) (bool, error) { +func (b *Bucket) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { l := &List{ID: b.ListID} - return l.CanWrite(a) + return l.CanWrite(s, a) } // CanUpdate checks if a user can update an existing bucket -func (b *Bucket) CanUpdate(a web.Auth) (bool, error) { - return b.canDoBucket(a) +func (b *Bucket) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return b.canDoBucket(s, a) } // CanDelete checks if a user can delete an existing bucket -func (b *Bucket) CanDelete(a web.Auth) (bool, error) { - return b.canDoBucket(a) +func (b *Bucket) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return b.canDoBucket(s, a) } // canDoBucket checks if the bucket exists and if the user has the right to act on it -func (b *Bucket) canDoBucket(a web.Auth) (bool, error) { - bb, err := getBucketByID(x.NewSession(), b.ID) +func (b *Bucket) canDoBucket(s *xorm.Session, a web.Auth) (bool, error) { + bb, err := getBucketByID(s, b.ID) if err != nil { return false, err } l := &List{ID: bb.ListID} - return l.CanWrite(a) + return l.CanWrite(s, a) } diff --git a/pkg/models/kanban_test.go b/pkg/models/kanban_test.go index 58d5e61d..59fab2d0 100644 --- a/pkg/models/kanban_test.go +++ b/pkg/models/kanban_test.go @@ -27,10 +27,12 @@ import ( func TestBucket_ReadAll(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() testuser := &user.User{ID: 1} b := &Bucket{ListID: 1} - bucketsInterface, _, _, err := b.ReadAll(testuser, "", 0, 0) + bucketsInterface, _, _, err := b.ReadAll(s, testuser, "", 0, 0) assert.NoError(t, err) buckets, is := bucketsInterface.([]*Bucket) @@ -66,6 +68,8 @@ func TestBucket_ReadAll(t *testing.T) { }) t.Run("filtered", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() testuser := &user.User{ID: 1} b := &Bucket{ @@ -76,7 +80,7 @@ func TestBucket_ReadAll(t *testing.T) { FilterValue: []string{"done"}, }, } - bucketsInterface, _, _, err := b.ReadAll(testuser, "", 0, 0) + bucketsInterface, _, _, err := b.ReadAll(s, testuser, "", 0, 0) assert.NoError(t, err) buckets := bucketsInterface.([]*Bucket) @@ -88,16 +92,21 @@ func TestBucket_ReadAll(t *testing.T) { func TestBucket_Delete(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + b := &Bucket{ ID: 2, // The second bucket only has 3 tasks ListID: 1, } - err := b.Delete() + err := b.Delete(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) // Assert all tasks have been moved to bucket 1 as that one is the first tasks := []*Task{} - err = x.Where("bucket_id = ?", 1).Find(&tasks) + err = s.Where("bucket_id = ?", 1).Find(&tasks) assert.NoError(t, err) assert.Len(t, tasks, 15) db.AssertMissing(t, "buckets", map[string]interface{}{ @@ -107,13 +116,19 @@ func TestBucket_Delete(t *testing.T) { }) t.Run("last bucket in list", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + b := &Bucket{ ID: 34, ListID: 18, } - err := b.Delete() + err := b.Delete(s) assert.Error(t, err) assert.True(t, IsErrCannotRemoveLastBucket(err)) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "buckets", map[string]interface{}{ "id": 34, "list_id": 18, diff --git a/pkg/models/label.go b/pkg/models/label.go index 9786c044..d6697862 100644 --- a/pkg/models/label.go +++ b/pkg/models/label.go @@ -21,6 +21,7 @@ import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "xorm.io/xorm" ) // Label represents a label @@ -64,7 +65,7 @@ func (Label) TableName() string { // @Failure 400 {object} web.HTTPError "Invalid label object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /labels [put] -func (l *Label) Create(a web.Auth) (err error) { +func (l *Label) Create(s *xorm.Session, a web.Auth) (err error) { u, err := user.GetFromAuth(a) if err != nil { return @@ -73,7 +74,7 @@ func (l *Label) Create(a web.Auth) (err error) { l.CreatedBy = u l.CreatedByID = u.ID - _, err = x.Insert(l) + _, err = s.Insert(l) return } @@ -92,8 +93,8 @@ func (l *Label) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "Label not found." // @Failure 500 {object} models.Message "Internal error" // @Router /labels/{id} [put] -func (l *Label) Update() (err error) { - _, err = x. +func (l *Label) Update(s *xorm.Session) (err error) { + _, err = s. ID(l.ID). Cols( "title", @@ -105,7 +106,7 @@ func (l *Label) Update() (err error) { return } - err = l.ReadOne() + err = l.ReadOne(s) return } @@ -122,8 +123,8 @@ func (l *Label) Update() (err error) { // @Failure 404 {object} web.HTTPError "Label not found." // @Failure 500 {object} models.Message "Internal error" // @Router /labels/{id} [delete] -func (l *Label) Delete() (err error) { - _, err = x.ID(l.ID).Delete(&Label{}) +func (l *Label) Delete(s *xorm.Session) (err error) { + _, err = s.ID(l.ID).Delete(&Label{}) return err } @@ -140,7 +141,7 @@ func (l *Label) Delete() (err error) { // @Success 200 {array} models.Label "The labels" // @Failure 500 {object} models.Message "Internal error" // @Router /labels [get] -func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls interface{}, resultCount int, numberOfEntries int64, err error) { +func (l *Label) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (ls interface{}, resultCount int, numberOfEntries int64, err error) { if _, is := a.(*LinkSharing); is { return nil, 0, 0, ErrGenericForbidden{} } @@ -148,12 +149,12 @@ func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls in u := &user.User{ID: a.GetID()} // Get all tasks - taskIDs, err := getUserTaskIDs(u) + taskIDs, err := getUserTaskIDs(s, u) if err != nil { return nil, 0, 0, err } - return getLabelsByTaskIDs(&LabelByTaskIDsOptions{ + return getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{ Search: search, User: u, TaskIDs: taskIDs, @@ -177,25 +178,25 @@ func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls in // @Failure 404 {object} web.HTTPError "Label not found" // @Failure 500 {object} models.Message "Internal error" // @Router /labels/{id} [get] -func (l *Label) ReadOne() (err error) { - label, err := getLabelByIDSimple(l.ID) +func (l *Label) ReadOne(s *xorm.Session) (err error) { + label, err := getLabelByIDSimple(s, l.ID) if err != nil { return err } *l = *label - user, err := user.GetUserByID(l.CreatedByID) + u, err := user.GetUserByID(s, l.CreatedByID) if err != nil { return err } - l.CreatedBy = user + l.CreatedBy = u return } -func getLabelByIDSimple(labelID int64) (*Label, error) { +func getLabelByIDSimple(s *xorm.Session, labelID int64) (*Label, error) { label := Label{} - exists, err := x.ID(labelID).Get(&label) + exists, err := s.ID(labelID).Get(&label) if err != nil { return &label, err } @@ -207,18 +208,21 @@ func getLabelByIDSimple(labelID int64) (*Label, error) { } // Helper method to get all task ids a user has -func getUserTaskIDs(u *user.User) (taskIDs []int64, err error) { +func getUserTaskIDs(s *xorm.Session, u *user.User) (taskIDs []int64, err error) { // Get all lists - lists, _, _, err := getRawListsForUser(&listOptions{ - user: u, - page: -1, - }) + lists, _, _, err := getRawListsForUser( + s, + &listOptions{ + user: u, + page: -1, + }, + ) if err != nil { return nil, err } - tasks, _, _, err := getRawTasksForLists(lists, u, &taskOptions{ + tasks, _, _, err := getRawTasksForLists(s, lists, u, &taskOptions{ page: -1, perPage: 0, }) diff --git a/pkg/models/label_rights.go b/pkg/models/label_rights.go index 73eb1197..c6919301 100644 --- a/pkg/models/label_rights.go +++ b/pkg/models/label_rights.go @@ -20,26 +20,27 @@ import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" "xorm.io/builder" + "xorm.io/xorm" ) // CanUpdate checks if a user can update a label -func (l *Label) CanUpdate(a web.Auth) (bool, error) { - return l.isLabelOwner(a) // Only owners should be allowed to update a label +func (l *Label) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return l.isLabelOwner(s, a) // Only owners should be allowed to update a label } // CanDelete checks if a user can delete a label -func (l *Label) CanDelete(a web.Auth) (bool, error) { - return l.isLabelOwner(a) // Only owners should be allowed to delete a label +func (l *Label) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return l.isLabelOwner(s, a) // Only owners should be allowed to delete a label } // CanRead checks if a user can read a label -func (l *Label) CanRead(a web.Auth) (bool, int, error) { - return l.hasAccessToLabel(a) +func (l *Label) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) { + return l.hasAccessToLabel(s, a) } // CanCreate checks if the user can create a label // Currently a dummy. -func (l *Label) CanCreate(a web.Auth) (bool, error) { +func (l *Label) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { if _, is := a.(*LinkSharing); is { return false, nil } @@ -47,13 +48,13 @@ func (l *Label) CanCreate(a web.Auth) (bool, error) { return true, nil } -func (l *Label) isLabelOwner(a web.Auth) (bool, error) { +func (l *Label) isLabelOwner(s *xorm.Session, a web.Auth) (bool, error) { if _, is := a.(*LinkSharing); is { return false, nil } - lorig, err := getLabelByIDSimple(l.ID) + lorig, err := getLabelByIDSimple(s, l.ID) if err != nil { return false, err } @@ -61,19 +62,19 @@ func (l *Label) isLabelOwner(a web.Auth) (bool, error) { } // Helper method to check if a user can see a specific label -func (l *Label) hasAccessToLabel(a web.Auth) (has bool, maxRight int, err error) { +func (l *Label) hasAccessToLabel(s *xorm.Session, a web.Auth) (has bool, maxRight int, err error) { // TODO: add an extra check for link share handling // Get all tasks - taskIDs, err := getUserTaskIDs(&user.User{ID: a.GetID()}) + taskIDs, err := getUserTaskIDs(s, &user.User{ID: a.GetID()}) if err != nil { return false, 0, err } // Get all labels associated with these tasks ll := &LabelTask{} - has, err = x.Table("labels"). + has, err = s.Table("labels"). Select("label_task.*"). Join("LEFT", "label_task", "label_task.label_id = labels.id"). Where("label_task.label_id is not null OR labels.created_by_id = ?", a.GetID()). @@ -87,7 +88,7 @@ func (l *Label) hasAccessToLabel(a web.Auth) (has bool, maxRight int, err error) // Since the right depends on the task the label is associated with, we need to check that too. if ll.TaskID > 0 { t := &Task{ID: ll.TaskID} - _, maxRight, err = t.CanRead(a) + _, maxRight, err = t.CanRead(s, a) if err != nil { return } diff --git a/pkg/models/label_task.go b/pkg/models/label_task.go index b631ed83..9c0af12a 100644 --- a/pkg/models/label_task.go +++ b/pkg/models/label_task.go @@ -22,10 +22,10 @@ import ( "time" "code.vikunja.io/api/pkg/log" - "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" "xorm.io/builder" + "xorm.io/xorm" ) // LabelTask represents a relation between a label and a task @@ -61,8 +61,8 @@ func (LabelTask) TableName() string { // @Failure 404 {object} web.HTTPError "Label not found." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{task}/labels/{label} [delete] -func (lt *LabelTask) Delete() (err error) { - _, err = x.Delete(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) +func (lt *LabelTask) Delete(s *xorm.Session) (err error) { + _, err = s.Delete(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) return err } @@ -81,9 +81,9 @@ func (lt *LabelTask) Delete() (err error) { // @Failure 404 {object} web.HTTPError "The label does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{task}/labels [put] -func (lt *LabelTask) Create(a web.Auth) (err error) { +func (lt *LabelTask) Create(s *xorm.Session, a web.Auth) (err error) { // Check if the label is already added - exists, err := x.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) + exists, err := s.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) if err != nil { return err } @@ -92,12 +92,12 @@ func (lt *LabelTask) Create(a web.Auth) (err error) { } // Insert it - _, err = x.Insert(lt) + _, err = s.Insert(lt) if err != nil { return err } - err = updateListByTaskID(lt.TaskID) + err = updateListByTaskID(s, lt.TaskID) return } @@ -115,10 +115,10 @@ func (lt *LabelTask) Create(a web.Auth) (err error) { // @Success 200 {array} models.Label "The labels" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{task}/labels [get] -func (lt *LabelTask) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (lt *LabelTask) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { // Check if the user has the right to see the task task := Task{ID: lt.TaskID} - canRead, _, err := task.CanRead(a) + canRead, _, err := task.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -126,7 +126,7 @@ func (lt *LabelTask) ReadAll(a web.Auth, search string, page int, perPage int) ( return nil, 0, 0, ErrNoRightToSeeTask{lt.TaskID, a.GetID()} } - return getLabelsByTaskIDs(&LabelByTaskIDsOptions{ + return getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{ User: &user.User{ID: a.GetID()}, Search: search, Page: page, @@ -153,7 +153,7 @@ type LabelByTaskIDsOptions struct { // Helper function to get all labels for a set of tasks // Used when getting all labels for one task as well when getting all lables -func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, resultCount int, totalEntries int64, err error) { +func getLabelsByTaskIDs(s *xorm.Session, opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, resultCount int, totalEntries int64, err error) { // We still need the task ID when we want to get all labels for a task, but because of this, we get the same label // multiple times when it is associated to more than one task. // Because of this whole thing, we need this extra switch here to only group by Task IDs if needed. @@ -194,7 +194,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res limit, start := getLimitFromPageIndex(opts.Page, opts.PerPage) - query := x.Table("labels"). + query := s.Table("labels"). Select(selectStmt). Join("LEFT", "label_task", "label_task.label_id = labels.id"). Where(cond). @@ -214,7 +214,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res userids = append(userids, l.CreatedByID) } users := make(map[int64]*user.User) - err = x.In("id", userids).Find(&users) + err = s.In("id", userids).Find(&users) if err != nil { return nil, 0, 0, err } @@ -230,7 +230,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res } // Get the total number of entries - totalEntries, err = x.Table("labels"). + totalEntries, err = s.Table("labels"). Select("count(DISTINCT labels.id)"). Join("LEFT", "label_task", "label_task.label_id = labels.id"). Where(cond). @@ -244,11 +244,11 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res } // Create or update a bunch of task labels -func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) { +func (t *Task) updateTaskLabels(s *xorm.Session, creator web.Auth, labels []*Label) (err error) { // If we don't have any new labels, delete everything right away. Saves us some hassle. if len(labels) == 0 && len(t.Labels) > 0 { - _, err = x.Where("task_id = ?", t.ID). + _, err = s.Where("task_id = ?", t.ID). Delete(LabelTask{}) return err } @@ -289,7 +289,7 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) { // Delete all labels not passed if len(labelsToDelete) > 0 { - _, err = x.In("label_id", labelsToDelete). + _, err = s.In("label_id", labelsToDelete). And("task_id = ?", t.ID). Delete(LabelTask{}) if err != nil { @@ -306,13 +306,13 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) { } // Add the new label - label, err := getLabelByIDSimple(l.ID) + label, err := getLabelByIDSimple(s, l.ID) if err != nil { return err } // Check if the user has the rights to see the label he is about to add - hasAccessToLabel, _, err := label.hasAccessToLabel(creator) + hasAccessToLabel, _, err := label.hasAccessToLabel(s, creator) if err != nil { return err } @@ -322,14 +322,14 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) { } // Insert it - _, err = x.Insert(&LabelTask{LabelID: l.ID, TaskID: t.ID}) + _, err = s.Insert(&LabelTask{LabelID: l.ID, TaskID: t.ID}) if err != nil { return err } t.Labels = append(t.Labels, label) } - err = updateListLastUpdated(&List{ID: t.ListID}) + err = updateListLastUpdated(s, &List{ID: t.ListID}) return } @@ -356,12 +356,12 @@ type LabelTaskBulk struct { // @Failure 400 {object} web.HTTPError "Invalid label object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/labels/bulk [post] -func (ltb *LabelTaskBulk) Create(a web.Auth) (err error) { - task, err := GetTaskByIDSimple(ltb.TaskID) +func (ltb *LabelTaskBulk) Create(s *xorm.Session, a web.Auth) (err error) { + task, err := GetTaskByIDSimple(s, ltb.TaskID) if err != nil { return } - labels, _, _, err := getLabelsByTaskIDs(&LabelByTaskIDsOptions{ + labels, _, _, err := getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{ TaskIDs: []int64{ltb.TaskID}, }) if err != nil { @@ -370,5 +370,5 @@ func (ltb *LabelTaskBulk) Create(a web.Auth) (err error) { for _, l := range labels { task.Labels = append(task.Labels, &l.Label) } - return task.updateTaskLabels(a, ltb.Labels) + return task.updateTaskLabels(s, a, ltb.Labels) } diff --git a/pkg/models/label_task_rights.go b/pkg/models/label_task_rights.go index f3d24c18..300969e0 100644 --- a/pkg/models/label_task_rights.go +++ b/pkg/models/label_task_rights.go @@ -18,21 +18,22 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if a user can add a label to a task -func (lt *LabelTask) CanCreate(a web.Auth) (bool, error) { - label, err := getLabelByIDSimple(lt.LabelID) +func (lt *LabelTask) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + label, err := getLabelByIDSimple(s, lt.LabelID) if err != nil { return false, err } - hasAccessTolabel, _, err := label.hasAccessToLabel(a) + hasAccessTolabel, _, err := label.hasAccessToLabel(s, a) if err != nil || !hasAccessTolabel { // If the user doesn't have access to the label, we can error out here return false, err } - canDoLabelTask, err := canDoLabelTask(lt.TaskID, a) + canDoLabelTask, err := canDoLabelTask(s, lt.TaskID, a) if err != nil { return false, err } @@ -41,8 +42,8 @@ func (lt *LabelTask) CanCreate(a web.Auth) (bool, error) { } // CanDelete checks if a user can delete a label from a task -func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) { - canDoLabelTask, err := canDoLabelTask(lt.TaskID, a) +func (lt *LabelTask) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + canDoLabelTask, err := canDoLabelTask(s, lt.TaskID, a) if err != nil { return false, err } @@ -52,7 +53,7 @@ func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) { // We don't care here if the label exists or not. The only relevant thing here is if the relation already exists, // throw an error. - exists, err := x.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) + exists, err := s.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) if err != nil { return false, err } @@ -60,18 +61,18 @@ func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) { } // CanCreate determines if a user can update a labeltask -func (ltb *LabelTaskBulk) CanCreate(a web.Auth) (bool, error) { - return canDoLabelTask(ltb.TaskID, a) +func (ltb *LabelTaskBulk) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return canDoLabelTask(s, ltb.TaskID, a) } // Helper function to check if a user can write to a task // + is able to see the label // always the same check for either deleting or adding a label to a task -func canDoLabelTask(taskID int64, a web.Auth) (bool, error) { +func canDoLabelTask(s *xorm.Session, taskID int64, a web.Auth) (bool, error) { // A user can add a label to a task if he can write to the task - task, err := GetTaskByIDSimple(taskID) + task, err := GetTaskByIDSimple(s, taskID) if err != nil { return false, err } - return task.CanUpdate(a) + return task.CanUpdate(s, a) } diff --git a/pkg/models/label_task_test.go b/pkg/models/label_task_test.go index 66cc79c9..1a3adf7e 100644 --- a/pkg/models/label_task_test.go +++ b/pkg/models/label_task_test.go @@ -91,6 +91,7 @@ func TestLabelTask_ReadAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() l := &LabelTask{ ID: tt.fields.ID, @@ -100,7 +101,7 @@ func TestLabelTask_ReadAll(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - gotLabels, _, _, err := l.ReadAll(tt.args.a, tt.args.search, tt.args.page, 0) + gotLabels, _, _, err := l.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 0) if (err != nil) != tt.wantErr { t.Errorf("LabelTask.ReadAll() error = %v, wantErr %v", err, tt.wantErr) return @@ -111,6 +112,8 @@ func TestLabelTask_ReadAll(t *testing.T) { if diff, equal := messagediff.PrettyDiff(gotLabels, tt.wantLabels); !equal { t.Errorf("LabelTask.ReadAll() = %v, want %v, diff: %v", l, tt.wantLabels, diff) } + + s.Close() }) } } @@ -186,6 +189,8 @@ func TestLabelTask_Create(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + l := &LabelTask{ ID: tt.fields.ID, TaskID: tt.fields.TaskID, @@ -194,11 +199,11 @@ func TestLabelTask_Create(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - allowed, _ := l.CanCreate(tt.args.a) + allowed, _ := l.CanCreate(s, tt.args.a) if !allowed && !tt.wantForbidden { t.Errorf("LabelTask.CanCreate() forbidden, want %v", tt.wantForbidden) } - err := l.Create(tt.args.a) + err := l.Create(s, tt.args.a) if (err != nil) != tt.wantErr { t.Errorf("LabelTask.Create() error = %v, wantErr %v", err, tt.wantErr) } @@ -212,6 +217,7 @@ func TestLabelTask_Create(t *testing.T) { "label_id": l.LabelID, }, false) } + s.Close() }) } } @@ -282,6 +288,8 @@ func TestLabelTask_Delete(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + l := &LabelTask{ ID: tt.fields.ID, TaskID: tt.fields.TaskID, @@ -290,11 +298,11 @@ func TestLabelTask_Delete(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - allowed, _ := l.CanDelete(tt.auth) + allowed, _ := l.CanDelete(s, tt.auth) if !allowed && !tt.wantForbidden { t.Errorf("LabelTask.CanDelete() forbidden, want %v", tt.wantForbidden) } - err := l.Delete() + err := l.Delete(s) if (err != nil) != tt.wantErr { t.Errorf("LabelTask.Delete() error = %v, wantErr %v", err, tt.wantErr) } @@ -307,6 +315,7 @@ func TestLabelTask_Delete(t *testing.T) { "task_id": l.TaskID, }) } + s.Close() }) } } diff --git a/pkg/models/label_test.go b/pkg/models/label_test.go index 3139fa64..64175f76 100644 --- a/pkg/models/label_test.go +++ b/pkg/models/label_test.go @@ -133,7 +133,8 @@ func TestLabel_ReadAll(t *testing.T) { Rights: tt.fields.Rights, } db.LoadAndAssertFixtures(t) - gotLs, _, _, err := l.ReadAll(tt.args.a, tt.args.search, tt.args.page, 0) + s := db.NewSession() + gotLs, _, _, err := l.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 0) if (err != nil) != tt.wantErr { t.Errorf("Label.ReadAll() error = %v, wantErr %v", err, tt.wantErr) return @@ -141,6 +142,7 @@ func TestLabel_ReadAll(t *testing.T) { if diff, equal := messagediff.PrettyDiff(gotLs, tt.wantLs); !equal { t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", gotLs, tt.wantLs, diff) } + s.Close() }) } } @@ -249,11 +251,13 @@ func TestLabel_ReadOne(t *testing.T) { Rights: tt.fields.Rights, } - allowed, _, _ := l.CanRead(tt.auth) + s := db.NewSession() + + allowed, _, _ := l.CanRead(s, tt.auth) if !allowed && !tt.wantForbidden { t.Errorf("Label.CanRead() forbidden, want %v", tt.wantForbidden) } - err := l.ReadOne() + err := l.ReadOne(s) if (err != nil) != tt.wantErr { t.Errorf("Label.ReadOne() error = %v, wantErr %v", err, tt.wantErr) } @@ -263,6 +267,8 @@ func TestLabel_ReadOne(t *testing.T) { if diff, equal := messagediff.PrettyDiff(l, tt.want); !equal && !tt.wantErr && !tt.wantForbidden { t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", l, tt.want, diff) } + + s.Close() }) } } @@ -316,11 +322,12 @@ func TestLabel_Create(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - allowed, _ := l.CanCreate(tt.args.a) + s := db.NewSession() + allowed, _ := l.CanCreate(s, tt.args.a) if !allowed && !tt.wantForbidden { t.Errorf("Label.CanCreate() forbidden, want %v", tt.wantForbidden) } - if err := l.Create(tt.args.a); (err != nil) != tt.wantErr { + if err := l.Create(s, tt.args.a); (err != nil) != tt.wantErr { t.Errorf("Label.Create() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr { @@ -331,6 +338,7 @@ func TestLabel_Create(t *testing.T) { "hex_color": l.HexColor, }, false) } + _ = s.Close() }) } } @@ -406,11 +414,12 @@ func TestLabel_Update(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - allowed, _ := l.CanUpdate(tt.auth) + s := db.NewSession() + allowed, _ := l.CanUpdate(s, tt.auth) if !allowed && !tt.wantForbidden { t.Errorf("Label.CanUpdate() forbidden, want %v", tt.wantForbidden) } - if err := l.Update(); (err != nil) != tt.wantErr { + if err := l.Update(s); (err != nil) != tt.wantErr { t.Errorf("Label.Update() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && !tt.wantForbidden { @@ -419,6 +428,7 @@ func TestLabel_Update(t *testing.T) { "title": tt.fields.Title, }, false) } + _ = s.Close() }) } } @@ -490,11 +500,12 @@ func TestLabel_Delete(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - allowed, _ := l.CanDelete(tt.auth) + s := db.NewSession() + allowed, _ := l.CanDelete(s, tt.auth) if !allowed && !tt.wantForbidden { t.Errorf("Label.CanDelete() forbidden, want %v", tt.wantForbidden) } - if err := l.Delete(); (err != nil) != tt.wantErr { + if err := l.Delete(s); (err != nil) != tt.wantErr { t.Errorf("Label.Delete() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && !tt.wantForbidden { @@ -502,6 +513,7 @@ func TestLabel_Delete(t *testing.T) { "id": l.ID, }) } + _ = s.Close() }) } } diff --git a/pkg/models/link_sharing.go b/pkg/models/link_sharing.go index 608240fb..51f321cf 100644 --- a/pkg/models/link_sharing.go +++ b/pkg/models/link_sharing.go @@ -24,6 +24,7 @@ import ( "code.vikunja.io/api/pkg/utils" "code.vikunja.io/web" "github.com/dgrijalva/jwt-go" + "xorm.io/xorm" ) // SharingType holds the sharing type @@ -99,7 +100,7 @@ func GetLinkShareFromClaims(claims jwt.MapClaims) (share *LinkSharing, err error // @Failure 404 {object} web.HTTPError "The list does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{list}/shares [put] -func (share *LinkSharing) Create(a web.Auth) (err error) { +func (share *LinkSharing) Create(s *xorm.Session, a web.Auth) (err error) { err = share.Right.isValid() if err != nil { @@ -108,7 +109,7 @@ func (share *LinkSharing) Create(a web.Auth) (err error) { share.SharedByID = a.GetID() share.Hash = utils.MakeRandomString(40) - _, err = x.Insert(share) + _, err = s.Insert(share) share.SharedBy, _ = user.GetFromAuth(a) return } @@ -127,8 +128,8 @@ func (share *LinkSharing) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "Share Link not found." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{list}/shares/{share} [get] -func (share *LinkSharing) ReadOne() (err error) { - exists, err := x.Where("id = ?", share.ID).Get(share) +func (share *LinkSharing) ReadOne(s *xorm.Session) (err error) { + exists, err := s.Where("id = ?", share.ID).Get(share) if err != nil { return err } @@ -152,9 +153,9 @@ func (share *LinkSharing) ReadOne() (err error) { // @Success 200 {array} models.LinkSharing "The share links" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{list}/shares [get] -func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { +func (share *LinkSharing) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { list := &List{ID: share.ListID} - can, _, err := list.CanRead(a) + can, _, err := list.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -165,7 +166,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i limit, start := getLimitFromPageIndex(page, perPage) var shares []*LinkSharing - query := x. + query := s. Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%") if limit > 0 { query = query.Limit(limit, start) @@ -182,7 +183,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i } users := make(map[int64]*user.User) - err = x.In("id", userIDs).Find(&users) + err = s.In("id", userIDs).Find(&users) if err != nil { return nil, 0, 0, err } @@ -192,7 +193,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i } // Total count - totalItems, err = x. + totalItems, err = s. Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%"). Count(&LinkSharing{}) if err != nil { @@ -216,15 +217,15 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i // @Failure 404 {object} web.HTTPError "Share Link not found." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{list}/shares/{share} [delete] -func (share *LinkSharing) Delete() (err error) { - _, err = x.Where("id = ?", share.ID).Delete(share) +func (share *LinkSharing) Delete(s *xorm.Session) (err error) { + _, err = s.Where("id = ?", share.ID).Delete(share) return } // GetLinkShareByHash returns a link share by hash -func GetLinkShareByHash(hash string) (share *LinkSharing, err error) { +func GetLinkShareByHash(s *xorm.Session, hash string) (share *LinkSharing, err error) { share = &LinkSharing{} - has, err := x.Where("hash = ?", hash).Get(share) + has, err := s.Where("hash = ?", hash).Get(share) if err != nil { return } @@ -235,13 +236,12 @@ func GetLinkShareByHash(hash string) (share *LinkSharing, err error) { } // GetListByShareHash returns a link share by its hash -func GetListByShareHash(hash string) (list *List, err error) { - share, err := GetLinkShareByHash(hash) +func GetListByShareHash(s *xorm.Session, hash string) (list *List, err error) { + share, err := GetLinkShareByHash(s, hash) if err != nil { return } - list = &List{ID: share.ListID} - err = list.GetSimpleByID() + list, err = GetListSimpleByID(s, share.ListID) return } diff --git a/pkg/models/link_sharing_rights.go b/pkg/models/link_sharing_rights.go index 5ffdcf6c..1d9f0d40 100644 --- a/pkg/models/link_sharing_rights.go +++ b/pkg/models/link_sharing_rights.go @@ -16,53 +16,55 @@ package models -import "code.vikunja.io/web" +import ( + "code.vikunja.io/web" + "xorm.io/xorm" +) // CanRead implements the read right check for a link share -func (share *LinkSharing) CanRead(a web.Auth) (bool, int, error) { +func (share *LinkSharing) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) { // Don't allow creating link shares if the user itself authenticated with a link share if _, is := a.(*LinkSharing); is { return false, 0, nil } - l, err := GetListByShareHash(share.Hash) + l, err := GetListByShareHash(s, share.Hash) if err != nil { return false, 0, err } - return l.CanRead(a) + return l.CanRead(s, a) } // CanDelete implements the delete right check for a link share -func (share *LinkSharing) CanDelete(a web.Auth) (bool, error) { - return share.canDoLinkShare(a) +func (share *LinkSharing) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return share.canDoLinkShare(s, a) } // CanUpdate implements the update right check for a link share -func (share *LinkSharing) CanUpdate(a web.Auth) (bool, error) { - return share.canDoLinkShare(a) +func (share *LinkSharing) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return share.canDoLinkShare(s, a) } // CanCreate implements the create right check for a link share -func (share *LinkSharing) CanCreate(a web.Auth) (bool, error) { - return share.canDoLinkShare(a) +func (share *LinkSharing) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return share.canDoLinkShare(s, a) } -func (share *LinkSharing) canDoLinkShare(a web.Auth) (bool, error) { +func (share *LinkSharing) canDoLinkShare(s *xorm.Session, a web.Auth) (bool, error) { // Don't allow creating link shares if the user itself authenticated with a link share if _, is := a.(*LinkSharing); is { return false, nil } - l := &List{ID: share.ListID} - err := l.GetSimpleByID() + l, err := GetListSimpleByID(s, share.ListID) if err != nil { return false, err } // Check if the user is admin when the link right is admin if share.Right == RightAdmin { - return l.IsAdmin(a) + return l.IsAdmin(s, a) } - return l.CanWrite(a) + return l.CanWrite(s, a) } diff --git a/pkg/models/list.go b/pkg/models/list.go index 21ffa9f0..a50a98a2 100644 --- a/pkg/models/list.go +++ b/pkg/models/list.go @@ -96,9 +96,9 @@ var FavoritesPseudoList = List{ } // GetListsByNamespaceID gets all lists in a namespace -func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error) { +func GetListsByNamespaceID(s *xorm.Session, nID int64, doer *user.User) (lists []*List, err error) { if nID == -1 { - err = x.Select("l.*"). + err = s.Select("l.*"). Table("list"). Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id"). Join("LEFT", []string{"team_members", "tm"}, "tm.team_id = tl.team_id"). @@ -111,7 +111,7 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error GroupBy("l.id"). Find(&lists) } else { - err = x.Select("l.*"). + err = s.Select("l.*"). Alias("l"). Join("LEFT", []string{"namespaces", "n"}, "l.namespace_id = n.id"). Where("l.is_archived = false"). @@ -124,7 +124,7 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error } // get more list details - err = AddListDetails(lists) + err = addListDetails(s, lists) return lists, err } @@ -143,33 +143,34 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists [get] -func (l *List) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { +func (l *List) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { // Check if we're dealing with a share auth shareAuth, ok := a.(*LinkSharing) if ok { - list := &List{ID: shareAuth.ListID} - err := list.GetSimpleByID() + list, err := GetListSimpleByID(s, shareAuth.ListID) if err != nil { return nil, 0, 0, err } lists := []*List{list} - err = AddListDetails(lists) + err = addListDetails(s, lists) return lists, 0, 0, err } - lists, resultCount, totalItems, err := getRawListsForUser(&listOptions{ - search: search, - user: &user.User{ID: a.GetID()}, - page: page, - perPage: perPage, - isArchived: l.IsArchived, - }) + lists, resultCount, totalItems, err := getRawListsForUser( + s, + &listOptions{ + search: search, + user: &user.User{ID: a.GetID()}, + page: page, + perPage: perPage, + isArchived: l.IsArchived, + }) if err != nil { return nil, 0, 0, err } // Add more list details - err = AddListDetails(lists) + err = addListDetails(s, lists) return lists, resultCount, totalItems, err } @@ -185,7 +186,7 @@ func (l *List) ReadAll(a web.Auth, search string, page int, perPage int) (result // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id} [get] -func (l *List) ReadOne() (err error) { +func (l *List) ReadOne(s *xorm.Session) (err error) { if l.ID == FavoritesPseudoList.ID { // Already "built" the list in CanRead @@ -194,7 +195,7 @@ func (l *List) ReadOne() (err error) { // Check for saved filters if getSavedFilterIDFromListID(l.ID) > 0 { - sf, err := getSavedFilterSimpleByID(getSavedFilterIDFromListID(l.ID)) + sf, err := getSavedFilterSimpleByID(s, getSavedFilterIDFromListID(l.ID)) if err != nil { return err } @@ -206,13 +207,13 @@ func (l *List) ReadOne() (err error) { } // Get list owner - l.Owner, err = user.GetUserByID(l.OwnerID) + l.Owner, err = user.GetUserByID(s, l.OwnerID) if err != nil { return err } // Check if the namespace is archived and set the namespace to archived if it is not already archived individually. if !l.IsArchived { - err = l.CheckIsArchived() + err = l.CheckIsArchived(s) if err != nil { if !IsErrNamespaceIsArchived(err) && !IsErrListIsArchived(err) { return @@ -224,7 +225,7 @@ func (l *List) ReadOne() (err error) { // Get any background information if there is one set if l.BackgroundFileID != 0 { // Unsplash image - l.BackgroundInformation, err = GetUnsplashPhotoByFileID(l.BackgroundFileID) + l.BackgroundInformation, err = GetUnsplashPhotoByFileID(s, l.BackgroundFileID) if err != nil && !files.IsErrFileIsNotUnsplashFile(err) { return } @@ -237,44 +238,33 @@ func (l *List) ReadOne() (err error) { return nil } -// GetSimpleByID gets a list with only the basic items, aka no tasks or user objects. Returns an error if the list does not exist. -func (l *List) GetSimpleByID() (err error) { - s := x.NewSession() - err = l.getSimpleByID(s) - if err != nil { - _ = s.Rollback() - return err - } - return nil -} +// GetListSimpleByID gets a list with only the basic items, aka no tasks or user objects. Returns an error if the list does not exist. +func GetListSimpleByID(s *xorm.Session, listID int64) (list *List, err error) { -func (l *List) getSimpleByID(s *xorm.Session) (err error) { - if l.ID < 1 { - return ErrListDoesNotExist{ID: l.ID} + list = &List{} + + if listID < 1 { + return nil, ErrListDoesNotExist{ID: listID} } - // We need to re-init our list object, because otherwise xorm creates a "where for every item in that list object, - // leading to not finding anything if the id is good, but for example the title is different. - id := l.ID - *l = List{} - exists, err := s.Where("id = ?", id).Get(l) + exists, err := s.Where("id = ?", listID).Get(list) if err != nil { return } if !exists { - return ErrListDoesNotExist{ID: l.ID} + return nil, ErrListDoesNotExist{ID: listID} } return } // GetListSimplByTaskID gets a list by a task id -func GetListSimplByTaskID(taskID int64) (l *List, err error) { +func GetListSimplByTaskID(s *xorm.Session, taskID int64) (l *List, err error) { // We need to re-init our list object, because otherwise xorm creates a "where for every item in that list object, // leading to not finding anything if the id is good, but for example the title is different. var list List - exists, err := x. + exists, err := s. Select("list.*"). Table(List{}). Join("INNER", "tasks", "list.id = tasks.list_id"). @@ -292,9 +282,9 @@ func GetListSimplByTaskID(taskID int64) (l *List, err error) { } // GetListsByIDs returns a map of lists from a slice with list ids -func GetListsByIDs(listIDs []int64) (lists map[int64]*List, err error) { +func GetListsByIDs(s *xorm.Session, listIDs []int64) (lists map[int64]*List, err error) { lists = make(map[int64]*List, len(listIDs)) - err = x.In("id", listIDs).Find(&lists) + err = s.In("id", listIDs).Find(&lists) return } @@ -307,8 +297,8 @@ type listOptions struct { } // Gets the lists only, without any tasks or so -func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) { - fullUser, err := user.GetUserByID(opts.user.ID) +func getRawListsForUser(s *xorm.Session, opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) { + fullUser, err := user.GetUserByID(s, opts.user.ID) if err != nil { return nil, 0, 0, err } @@ -344,7 +334,7 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota // Gets all Lists where the user is either owner or in a team which has access to the list // Or in a team which has namespace read access - query := x.Select("l.*"). + query := s.Select("l.*"). Table("list"). Alias("l"). Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id"). @@ -372,7 +362,7 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota return nil, 0, 0, err } - totalItems, err = x. + totalItems, err = s. Table("list"). Alias("l"). Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id"). @@ -396,8 +386,8 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota return lists, len(lists), totalItems, err } -// AddListDetails adds owner user objects and list tasks to all lists in the slice -func AddListDetails(lists []*List) (err error) { +// addListDetails adds owner user objects and list tasks to all lists in the slice +func addListDetails(s *xorm.Session, lists []*List) (err error) { var ownerIDs []int64 for _, l := range lists { ownerIDs = append(ownerIDs, l.OwnerID) @@ -405,7 +395,7 @@ func AddListDetails(lists []*List) (err error) { // Get all list owners owners := map[int64]*user.User{} - err = x.In("id", ownerIDs).Find(&owners) + err = s.In("id", ownerIDs).Find(&owners) if err != nil { return } @@ -423,7 +413,7 @@ func AddListDetails(lists []*List) (err error) { // Unsplash background file info us := []*UnsplashPhoto{} - err = x.In("file_id", fileIDs).Find(&us) + err = s.In("file_id", fileIDs).Find(&us) if err != nil { return } @@ -450,15 +440,15 @@ type NamespaceList struct { } // CheckIsArchived returns an ErrListIsArchived or ErrNamespaceIsArchived if the list or its namespace is archived. -func (l *List) CheckIsArchived() (err error) { +func (l *List) CheckIsArchived(s *xorm.Session) (err error) { // When creating a new list, we check if the namespace is archived if l.ID == 0 { n := &Namespace{ID: l.NamespaceID} - return n.CheckIsArchived() + return n.CheckIsArchived(s) } nl := &NamespaceList{} - exists, err := x. + exists, err := s. Table("list"). Join("LEFT", "namespaces", "list.namespace_id = namespaces.id"). Where("list.id = ? AND (list.is_archived = true OR namespaces.is_archived = true)", l.ID). @@ -476,11 +466,11 @@ func (l *List) CheckIsArchived() (err error) { } // CreateOrUpdateList updates a list or creates it if it doesn't exist -func CreateOrUpdateList(list *List) (err error) { +func CreateOrUpdateList(s *xorm.Session, list *List) (err error) { // Check if the namespace exists if list.NamespaceID != 0 && list.NamespaceID != FavoritesPseudoNamespace.ID { - _, err = GetNamespaceByID(list.NamespaceID) + _, err = GetNamespaceByID(s, list.NamespaceID) if err != nil { return err } @@ -488,7 +478,7 @@ func CreateOrUpdateList(list *List) (err error) { // Check if the identifier is unique and not empty if list.Identifier != "" { - exists, err := x. + exists, err := s. Where("identifier = ?", list.Identifier). And("id != ?", list.ID). Exist(&List{}) @@ -501,7 +491,7 @@ func CreateOrUpdateList(list *List) (err error) { } if list.ID == 0 { - _, err = x.Insert(list) + _, err = s.Insert(list) metrics.UpdateCount(1, metrics.ListCountKey) } else { // We need to specify the cols we want to update here to be able to un-archive lists @@ -516,7 +506,7 @@ func CreateOrUpdateList(list *List) (err error) { colsToUpdate = append(colsToUpdate, "description") } - _, err = x. + _, err = s. ID(list.ID). Cols(colsToUpdate...). Update(list) @@ -526,12 +516,13 @@ func CreateOrUpdateList(list *List) (err error) { return } - err = list.GetSimpleByID() + l, err := GetListSimpleByID(s, list.ID) if err != nil { - return + return err } - err = list.ReadOne() + *list = *l + err = list.ReadOne(s) return } @@ -550,33 +541,23 @@ func CreateOrUpdateList(list *List) (err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id} [post] -func (l *List) Update() (err error) { - return CreateOrUpdateList(l) +func (l *List) Update(s *xorm.Session) (err error) { + return CreateOrUpdateList(s, l) } -func updateListLastUpdated(list *List) (err error) { - s := x.NewSession() - err = updateListLastUpdatedS(s, list) - if err != nil { - _ = s.Rollback() - return err - } - return nil -} - -func updateListLastUpdatedS(s *xorm.Session, list *List) error { +func updateListLastUpdated(s *xorm.Session, list *List) error { _, err := s.ID(list.ID).Cols("updated").Update(list) return err } -func updateListByTaskID(taskID int64) (err error) { +func updateListByTaskID(s *xorm.Session, taskID int64) (err error) { // need to get the task to update the list last updated timestamp - task, err := GetTaskByIDSimple(taskID) + task, err := GetTaskByIDSimple(s, taskID) if err != nil { return err } - return updateListLastUpdated(&List{ID: task.ListID}) + return updateListLastUpdated(s, &List{ID: task.ListID}) } // Create implements the create method of CRUDable @@ -593,8 +574,8 @@ func updateListByTaskID(taskID int64) (err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{namespaceID}/lists [put] -func (l *List) Create(a web.Auth) (err error) { - err = l.CheckIsArchived() +func (l *List) Create(s *xorm.Session, a web.Auth) (err error) { + err = l.CheckIsArchived(s) if err != nil { return err } @@ -608,7 +589,7 @@ func (l *List) Create(a web.Auth) (err error) { l.Owner = doer l.ID = 0 // Otherwise only the first time a new list would be created - err = CreateOrUpdateList(l) + err = CreateOrUpdateList(s, l) if err != nil { return } @@ -618,7 +599,7 @@ func (l *List) Create(a web.Auth) (err error) { ListID: l.ID, Title: "New Bucket", } - return b.Create(a) + return b.Create(s, a) } // Delete implements the delete method of CRUDable @@ -633,27 +614,27 @@ func (l *List) Create(a web.Auth) (err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id} [delete] -func (l *List) Delete() (err error) { +func (l *List) Delete(s *xorm.Session) (err error) { // Delete the list - _, err = x.ID(l.ID).Delete(&List{}) + _, err = s.ID(l.ID).Delete(&List{}) if err != nil { return } metrics.UpdateCount(-1, metrics.ListCountKey) // Delete all todotasks on that list - _, err = x.Where("list_id = ?", l.ID).Delete(&Task{}) + _, err = s.Where("list_id = ?", l.ID).Delete(&Task{}) return } // SetListBackground sets a background file as list background in the db -func SetListBackground(listID int64, background *files.File) (err error) { +func SetListBackground(s *xorm.Session, listID int64, background *files.File) (err error) { l := &List{ ID: listID, BackgroundFileID: background.ID, } - _, err = x. + _, err = s. Where("id = ?", l.ID). Cols("background_file_id"). Update(l) diff --git a/pkg/models/list_duplicate.go b/pkg/models/list_duplicate.go index e723f333..7de3007d 100644 --- a/pkg/models/list_duplicate.go +++ b/pkg/models/list_duplicate.go @@ -21,6 +21,7 @@ import ( "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/utils" "code.vikunja.io/web" + "xorm.io/xorm" ) // ListDuplicate holds everything needed to duplicate a list @@ -38,17 +39,17 @@ type ListDuplicate struct { } // CanCreate checks if a user has the right to duplicate a list -func (ld *ListDuplicate) CanCreate(a web.Auth) (canCreate bool, err error) { +func (ld *ListDuplicate) CanCreate(s *xorm.Session, a web.Auth) (canCreate bool, err error) { // List Exists + user has read access to list ld.List = &List{ID: ld.ListID} - canRead, _, err := ld.List.CanRead(a) + canRead, _, err := ld.List.CanRead(s, a) if err != nil || !canRead { return canRead, err } // Namespace exists + user has write access to is (-> can create new lists) ld.List.NamespaceID = ld.NamespaceID - return ld.List.CanCreate(a) + return ld.List.CanCreate(s, a) } // Create duplicates a list @@ -66,7 +67,7 @@ func (ld *ListDuplicate) CanCreate(a web.Auth) (canCreate bool, err error) { // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/duplicate [put] //nolint:gocyclo -func (ld *ListDuplicate) Create(a web.Auth) (err error) { +func (ld *ListDuplicate) Create(s *xorm.Session, a web.Auth) (err error) { log.Debugf("Duplicating list %d", ld.ListID) @@ -74,7 +75,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { ld.List.Identifier = "" // Reset the identifier to trigger regenerating a new one // Set the owner to the current user ld.List.OwnerID = a.GetID() - if err := CreateOrUpdateList(ld.List); err != nil { + if err := CreateOrUpdateList(s, ld.List); err != nil { // If there is no available unique list identifier, just reset it. if IsErrListIdentifierIsNotUnique(err) { ld.List.Identifier = "" @@ -90,7 +91,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { // Used to map the newly created tasks to their new buckets bucketMap := make(map[int64]int64) buckets := []*Bucket{} - err = x.Where("list_id = ?", ld.ListID).Find(&buckets) + err = s.Where("list_id = ?", ld.ListID).Find(&buckets) if err != nil { return } @@ -98,7 +99,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { oldID := b.ID b.ID = 0 b.ListID = ld.List.ID - if err := b.Create(a); err != nil { + if err := b.Create(s, a); err != nil { return err } bucketMap[oldID] = b.ID @@ -107,7 +108,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { log.Debugf("Duplicated all buckets from list %d into %d", ld.ListID, ld.List.ID) // Get all tasks + all task details - tasks, _, _, err := getTasksForLists([]*List{{ID: ld.ListID}}, a, &taskOptions{}) + tasks, _, _, err := getTasksForLists(s, []*List{{ID: ld.ListID}}, a, &taskOptions{}) if err != nil { return err } @@ -123,10 +124,8 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { t.ListID = ld.List.ID t.BucketID = bucketMap[t.BucketID] t.UID = "" - s := x.NewSession() err := createTask(s, t, a, false) if err != nil { - _ = s.Rollback() return err } taskMap[oldID] = t.ID @@ -138,7 +137,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { // Save all attachments // We also duplicate all underlying files since they could be modified in one list which would result in // file changes in the other list which is not something we want. - attachments, err := getTaskAttachmentsByTaskIDs(oldTaskIDs) + attachments, err := getTaskAttachmentsByTaskIDs(s, oldTaskIDs) if err != nil { return err } @@ -164,7 +163,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { return err } - err := attachment.NewAttachment(attachment.File.File, attachment.File.Name, attachment.File.Size, a) + err := attachment.NewAttachment(s, attachment.File.File, attachment.File.Name, attachment.File.Size, a) if err != nil { return err } @@ -180,7 +179,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { // Copy label tasks (not the labels) labelTasks := []*LabelTask{} - err = x.In("task_id", oldTaskIDs).Find(&labelTasks) + err = s.In("task_id", oldTaskIDs).Find(&labelTasks) if err != nil { return } @@ -188,7 +187,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { for _, lt := range labelTasks { lt.ID = 0 lt.TaskID = taskMap[lt.TaskID] - if _, err := x.Insert(lt); err != nil { + if _, err := s.Insert(lt); err != nil { return err } } @@ -198,7 +197,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { // Assignees // Only copy those assignees who have access to the task assignees := []*TaskAssginee{} - err = x.In("task_id", oldTaskIDs).Find(&assignees) + err = s.In("task_id", oldTaskIDs).Find(&assignees) if err != nil { return } @@ -207,7 +206,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { ID: taskMap[a.TaskID], ListID: ld.List.ID, } - if err := t.addNewAssigneeByID(a.UserID, ld.List); err != nil { + if err := t.addNewAssigneeByID(s, a.UserID, ld.List); err != nil { if IsErrUserDoesNotHaveAccessToList(err) { continue } @@ -219,14 +218,14 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { // Comments comments := []*TaskComment{} - err = x.In("task_id", oldTaskIDs).Find(&comments) + err = s.In("task_id", oldTaskIDs).Find(&comments) if err != nil { return } for _, c := range comments { c.ID = 0 c.TaskID = taskMap[c.TaskID] - if _, err := x.Insert(c); err != nil { + if _, err := s.Insert(c); err != nil { return err } } @@ -237,7 +236,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { // Low-Effort: Only copy those relations which are between tasks in the same list // because we can do that without a lot of hassle relations := []*TaskRelation{} - err = x.In("task_id", oldTaskIDs).Find(&relations) + err = s.In("task_id", oldTaskIDs).Find(&relations) if err != nil { return } @@ -249,7 +248,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { r.ID = 0 r.OtherTaskID = otherTaskID r.TaskID = taskMap[r.TaskID] - if _, err := x.Insert(r); err != nil { + if _, err := s.Insert(r); err != nil { return err } } @@ -276,19 +275,19 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { } // Get unsplash info if applicable - up, err := GetUnsplashPhotoByFileID(ld.List.BackgroundFileID) + up, err := GetUnsplashPhotoByFileID(s, ld.List.BackgroundFileID) if err != nil && files.IsErrFileIsNotUnsplashFile(err) { return err } if up != nil { up.ID = 0 up.FileID = file.ID - if err := up.Save(); err != nil { + if err := up.Save(s); err != nil { return err } } - if err := SetListBackground(ld.List.ID, file); err != nil { + if err := SetListBackground(s, ld.List.ID, file); err != nil { return err } @@ -298,14 +297,14 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { // Rights / Shares // To keep it simple(r) we will only copy rights which are directly used with the list, no namespace changes. users := []*ListUser{} - err = x.Where("list_id = ?", ld.ListID).Find(&users) + err = s.Where("list_id = ?", ld.ListID).Find(&users) if err != nil { return } for _, u := range users { u.ID = 0 u.ListID = ld.List.ID - if _, err := x.Insert(u); err != nil { + if _, err := s.Insert(u); err != nil { return err } } @@ -313,21 +312,21 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { log.Debugf("Duplicated user shares from list %d into %d", ld.ListID, ld.List.ID) teams := []*TeamList{} - err = x.Where("list_id = ?", ld.ListID).Find(&teams) + err = s.Where("list_id = ?", ld.ListID).Find(&teams) if err != nil { return } for _, t := range teams { t.ID = 0 t.ListID = ld.List.ID - if _, err := x.Insert(t); err != nil { + if _, err := s.Insert(t); err != nil { return err } } // Generate new link shares if any are available linkShares := []*LinkSharing{} - err = x.Where("list_id = ?", ld.ListID).Find(&linkShares) + err = s.Where("list_id = ?", ld.ListID).Find(&linkShares) if err != nil { return } @@ -335,7 +334,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) { share.ID = 0 share.ListID = ld.List.ID share.Hash = utils.MakeRandomString(40) - if _, err := x.Insert(share); err != nil { + if _, err := s.Insert(share); err != nil { return err } } diff --git a/pkg/models/list_duplicate_test.go b/pkg/models/list_duplicate_test.go index 17385bf6..c1d5c3ae 100644 --- a/pkg/models/list_duplicate_test.go +++ b/pkg/models/list_duplicate_test.go @@ -29,6 +29,8 @@ func TestListDuplicate(t *testing.T) { db.LoadAndAssertFixtures(t) files.InitTestFileFixtures(t) + s := db.NewSession() + defer s.Close() u := &user.User{ ID: 1, @@ -38,10 +40,10 @@ func TestListDuplicate(t *testing.T) { ListID: 1, NamespaceID: 1, } - can, err := l.CanCreate(u) + can, err := l.CanCreate(s, u) assert.NoError(t, err) assert.True(t, can) - err = l.Create(u) + err = l.Create(s, u) assert.NoError(t, err) // To make this test 100% useful, it would need to assert a lot more stuff, but it is good enough for now. // Also, we're lacking utility functions to do all needed assertions. diff --git a/pkg/models/list_rights.go b/pkg/models/list_rights.go index fddcbf8d..aed3f0b7 100644 --- a/pkg/models/list_rights.go +++ b/pkg/models/list_rights.go @@ -20,10 +20,11 @@ import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" "xorm.io/builder" + "xorm.io/xorm" ) // CanWrite return whether the user can write on that list or not -func (l *List) CanWrite(a web.Auth) (bool, error) { +func (l *List) CanWrite(s *xorm.Session, a web.Auth) (bool, error) { // The favorite list can't be edited if l.ID == FavoritesPseudoList.ID { @@ -31,15 +32,14 @@ func (l *List) CanWrite(a web.Auth) (bool, error) { } // Get the list and check the right - originalList := &List{ID: l.ID} - err := originalList.GetSimpleByID() + originalList, err := GetListSimpleByID(s, l.ID) if err != nil { return false, err } // We put the result of the is archived check in a separate variable to be able to return it later without // needing to recheck it again - errIsArchived := originalList.CheckIsArchived() + errIsArchived := originalList.CheckIsArchived(s) var canWrite bool @@ -59,7 +59,7 @@ func (l *List) CanWrite(a web.Auth) (bool, error) { return canWrite, errIsArchived } - canWrite, _, err = originalList.checkRight(a, RightWrite, RightAdmin) + canWrite, _, err = originalList.checkRight(s, a, RightWrite, RightAdmin) if err != nil { return false, err } @@ -67,7 +67,7 @@ func (l *List) CanWrite(a web.Auth) (bool, error) { } // CanRead checks if a user has read access to a list -func (l *List) CanRead(a web.Auth) (bool, int, error) { +func (l *List) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) { // The favorite list needs a special treatment if l.ID == FavoritesPseudoList.ID { @@ -84,14 +84,18 @@ func (l *List) CanRead(a web.Auth) (bool, int, error) { // Saved Filter Lists need a special case if getSavedFilterIDFromListID(l.ID) > 0 { sf := &SavedFilter{ID: getSavedFilterIDFromListID(l.ID)} - return sf.CanRead(a) + return sf.CanRead(s, a) } // Check if the user is either owner or can read - if err := l.GetSimpleByID(); err != nil { + var err error + originalList, err := GetListSimpleByID(s, l.ID) + if err != nil { return false, 0, err } + *l = *originalList + // Check if we're dealing with a share auth shareAuth, ok := a.(*LinkSharing) if ok { @@ -102,16 +106,16 @@ func (l *List) CanRead(a web.Auth) (bool, int, error) { if l.isOwner(&user.User{ID: a.GetID()}) { return true, int(RightAdmin), nil } - return l.checkRight(a, RightRead, RightWrite, RightAdmin) + return l.checkRight(s, a, RightRead, RightWrite, RightAdmin) } // CanUpdate checks if the user can update a list -func (l *List) CanUpdate(a web.Auth) (canUpdate bool, err error) { +func (l *List) CanUpdate(s *xorm.Session, a web.Auth) (canUpdate bool, err error) { // The favorite list can't be edited if l.ID == FavoritesPseudoList.ID { return false, nil } - canUpdate, err = l.CanWrite(a) + canUpdate, err = l.CanWrite(s, a) // If the list is archived and the user tries to un-archive it, let the request through if IsErrListIsArchived(err) && !l.IsArchived { err = nil @@ -120,26 +124,25 @@ func (l *List) CanUpdate(a web.Auth) (canUpdate bool, err error) { } // CanDelete checks if the user can delete a list -func (l *List) CanDelete(a web.Auth) (bool, error) { - return l.IsAdmin(a) +func (l *List) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return l.IsAdmin(s, a) } // CanCreate checks if the user can create a list -func (l *List) CanCreate(a web.Auth) (bool, error) { +func (l *List) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { // A user can create a list if they have write access to the namespace n := &Namespace{ID: l.NamespaceID} - return n.CanWrite(a) + return n.CanWrite(s, a) } // IsAdmin returns whether the user has admin rights on the list or not -func (l *List) IsAdmin(a web.Auth) (bool, error) { +func (l *List) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) { // The favorite list can't be edited if l.ID == FavoritesPseudoList.ID { return false, nil } - originalList := &List{ID: l.ID} - err := originalList.GetSimpleByID() + originalList, err := GetListSimpleByID(s, l.ID) if err != nil { return false, err } @@ -156,7 +159,7 @@ func (l *List) IsAdmin(a web.Auth) (bool, error) { if originalList.isOwner(&user.User{ID: a.GetID()}) { return true, nil } - is, _, err := originalList.checkRight(a, RightAdmin) + is, _, err := originalList.checkRight(s, a, RightAdmin) return is, err } @@ -166,7 +169,7 @@ func (l *List) isOwner(u *user.User) bool { } // Checks n different rights for any given user -func (l *List) checkRight(a web.Auth, rights ...Right) (bool, int, error) { +func (l *List) checkRight(s *xorm.Session, a web.Auth, rights ...Right) (bool, int, error) { /* The following loop creates an sql condition like this one: @@ -218,7 +221,7 @@ func (l *List) checkRight(a web.Auth, rights ...Right) (bool, int, error) { r := &allListRights{} var maxRight = 0 - exists, err := x. + exists, err := s. Table("list"). Alias("l"). // User stuff diff --git a/pkg/models/list_team.go b/pkg/models/list_team.go index 371e267c..82619d2d 100644 --- a/pkg/models/list_team.go +++ b/pkg/models/list_team.go @@ -20,6 +20,7 @@ import ( "time" "code.vikunja.io/web" + "xorm.io/xorm" ) // TeamList defines the relation between a team and a list @@ -68,7 +69,7 @@ type TeamWithRight struct { // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id}/teams [put] -func (tl *TeamList) Create(a web.Auth) (err error) { +func (tl *TeamList) Create(s *xorm.Session, a web.Auth) (err error) { // Check if the rights are valid if err = tl.Right.isValid(); err != nil { @@ -76,19 +77,19 @@ func (tl *TeamList) Create(a web.Auth) (err error) { } // Check if the team exists - _, err = GetTeamByID(tl.TeamID) + _, err = GetTeamByID(s, tl.TeamID) if err != nil { return } // Check if the list exists - l := &List{ID: tl.ListID} - if err := l.GetSimpleByID(); err != nil { + l, err := GetListSimpleByID(s, tl.ListID) + if err != nil { return err } // Check if the team is already on the list - exists, err := x.Where("team_id = ?", tl.TeamID). + exists, err := s.Where("team_id = ?", tl.TeamID). And("list_id = ?", tl.ListID). Get(&TeamList{}) if err != nil { @@ -99,12 +100,12 @@ func (tl *TeamList) Create(a web.Auth) (err error) { } // Insert the new team - _, err = x.Insert(tl) + _, err = s.Insert(tl) if err != nil { return err } - err = updateListLastUpdated(l) + err = updateListLastUpdated(s, l) return } @@ -121,16 +122,17 @@ func (tl *TeamList) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "Team or list does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/teams/{teamID} [delete] -func (tl *TeamList) Delete() (err error) { +func (tl *TeamList) Delete(s *xorm.Session) (err error) { // Check if the team exists - _, err = GetTeamByID(tl.TeamID) + _, err = GetTeamByID(s, tl.TeamID) if err != nil { return } // Check if the team has access to the list - has, err := x.Where("team_id = ? AND list_id = ?", tl.TeamID, tl.ListID). + has, err := s. + Where("team_id = ? AND list_id = ?", tl.TeamID, tl.ListID). Get(&TeamList{}) if err != nil { return @@ -140,14 +142,14 @@ func (tl *TeamList) Delete() (err error) { } // Delete the relation - _, err = x.Where("team_id = ?", tl.TeamID). + _, err = s.Where("team_id = ?", tl.TeamID). And("list_id = ?", tl.ListID). Delete(TeamList{}) if err != nil { return err } - err = updateListLastUpdated(&List{ID: tl.ListID}) + err = updateListLastUpdated(s, &List{ID: tl.ListID}) return } @@ -166,10 +168,10 @@ func (tl *TeamList) Delete() (err error) { // @Failure 403 {object} web.HTTPError "No right to see the list." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id}/teams [get] -func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { +func (tl *TeamList) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { // Check if the user can read the namespace l := &List{ID: tl.ListID} - canRead, _, err := l.CanRead(a) + canRead, _, err := l.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -181,7 +183,7 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r // Get the teams all := []*TeamWithRight{} - query := x. + query := s. Table("teams"). Join("INNER", "team_list", "team_id = teams.id"). Where("team_list.list_id = ?", tl.ListID). @@ -199,12 +201,12 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r teams = append(teams, &t.Team) } - err = addMoreInfoToTeams(teams) + err = addMoreInfoToTeams(s, teams) if err != nil { return } - totalItems, err = x. + totalItems, err = s. Table("teams"). Join("INNER", "team_list", "team_id = teams.id"). Where("team_list.list_id = ?", tl.ListID). @@ -232,14 +234,14 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r // @Failure 404 {object} web.HTTPError "Team or list does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/teams/{teamID} [post] -func (tl *TeamList) Update() (err error) { +func (tl *TeamList) Update(s *xorm.Session) (err error) { // Check if the right is valid if err := tl.Right.isValid(); err != nil { return err } - _, err = x. + _, err = s. Where("list_id = ? AND team_id = ?", tl.ListID, tl.TeamID). Cols("right"). Update(tl) @@ -247,6 +249,6 @@ func (tl *TeamList) Update() (err error) { return err } - err = updateListLastUpdated(&List{ID: tl.ListID}) + err = updateListLastUpdated(s, &List{ID: tl.ListID}) return } diff --git a/pkg/models/list_team_rights.go b/pkg/models/list_team_rights.go index 923c8757..f90f9de5 100644 --- a/pkg/models/list_team_rights.go +++ b/pkg/models/list_team_rights.go @@ -18,29 +18,30 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if the user can create a team <-> list relation -func (tl *TeamList) CanCreate(a web.Auth) (bool, error) { - return tl.canDoTeamList(a) +func (tl *TeamList) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return tl.canDoTeamList(s, a) } // CanDelete checks if the user can delete a team <-> list relation -func (tl *TeamList) CanDelete(a web.Auth) (bool, error) { - return tl.canDoTeamList(a) +func (tl *TeamList) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return tl.canDoTeamList(s, a) } // CanUpdate checks if the user can update a team <-> list relation -func (tl *TeamList) CanUpdate(a web.Auth) (bool, error) { - return tl.canDoTeamList(a) +func (tl *TeamList) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return tl.canDoTeamList(s, a) } -func (tl *TeamList) canDoTeamList(a web.Auth) (bool, error) { +func (tl *TeamList) canDoTeamList(s *xorm.Session, a web.Auth) (bool, error) { // Link shares aren't allowed to do anything if _, is := a.(*LinkSharing); is { return false, nil } l := List{ID: tl.ListID} - return l.IsAdmin(a) + return l.IsAdmin(s, a) } diff --git a/pkg/models/list_team_test.go b/pkg/models/list_team_test.go index 6ff5c695..494f4414 100644 --- a/pkg/models/list_team_test.go +++ b/pkg/models/list_team_test.go @@ -37,20 +37,24 @@ func TestTeamList_ReadAll(t *testing.T) { ListID: 3, } db.LoadAndAssertFixtures(t) - teams, _, _, err := tl.ReadAll(u, "", 1, 50) + s := db.NewSession() + teams, _, _, err := tl.ReadAll(s, u, "", 1, 50) assert.NoError(t, err) assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice) - s := reflect.ValueOf(teams) - assert.Equal(t, s.Len(), 1) + ts := reflect.ValueOf(teams) + assert.Equal(t, ts.Len(), 1) + _ = s.Close() }) t.Run("nonexistant list", func(t *testing.T) { tl := TeamList{ ListID: 99999, } db.LoadAndAssertFixtures(t) - _, _, _, err := tl.ReadAll(u, "", 1, 50) + s := db.NewSession() + _, _, _, err := tl.ReadAll(s, u, "", 1, 50) assert.Error(t, err) assert.True(t, IsErrListDoesNotExist(err)) + _ = s.Close() }) t.Run("namespace owner", func(t *testing.T) { tl := TeamList{ @@ -59,8 +63,10 @@ func TestTeamList_ReadAll(t *testing.T) { Right: RightAdmin, } db.LoadAndAssertFixtures(t) - _, _, _, err := tl.ReadAll(u, "", 1, 50) + s := db.NewSession() + _, _, _, err := tl.ReadAll(s, u, "", 1, 50) assert.NoError(t, err) + _ = s.Close() }) t.Run("no access", func(t *testing.T) { tl := TeamList{ @@ -69,9 +75,11 @@ func TestTeamList_ReadAll(t *testing.T) { Right: RightAdmin, } db.LoadAndAssertFixtures(t) - _, _, _, err := tl.ReadAll(u, "", 1, 50) + s := db.NewSession() + _, _, _, err := tl.ReadAll(s, u, "", 1, 50) assert.Error(t, err) assert.True(t, IsErrNeedToHaveListReadAccess(err)) + _ = s.Close() }) } @@ -79,14 +87,17 @@ func TestTeamList_Create(t *testing.T) { u := &user.User{ID: 1} t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 1, ListID: 1, Right: RightAdmin, } - allowed, _ := tl.CanCreate(u) + allowed, _ := tl.CanCreate(s, u) assert.True(t, allowed) - err := tl.Create(u) + err := tl.Create(s, u) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "team_list", map[string]interface{}{ "team_id": 1, @@ -96,56 +107,67 @@ func TestTeamList_Create(t *testing.T) { }) t.Run("team already has access", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 1, ListID: 3, Right: RightAdmin, } - err := tl.Create(u) + err := tl.Create(s, u) assert.Error(t, err) assert.True(t, IsErrTeamAlreadyHasAccess(err)) + _ = s.Close() }) t.Run("wrong rights", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 1, ListID: 1, Right: RightUnknown, } - err := tl.Create(u) + err := tl.Create(s, u) assert.Error(t, err) assert.True(t, IsErrInvalidRight(err)) + _ = s.Close() }) t.Run("nonexistant team", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 9999, ListID: 1, } - err := tl.Create(u) + err := tl.Create(s, u) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) + _ = s.Close() }) t.Run("nonexistant list", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 1, ListID: 9999, } - err := tl.Create(u) + err := tl.Create(s, u) assert.Error(t, err) assert.True(t, IsErrListDoesNotExist(err)) + _ = s.Close() }) } func TestTeamList_Delete(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 1, ListID: 3, } - err := tl.Delete() + err := tl.Delete(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertMissing(t, "team_list", map[string]interface{}{ "team_id": 1, @@ -154,23 +176,27 @@ func TestTeamList_Delete(t *testing.T) { }) t.Run("nonexistant team", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 9999, ListID: 1, } - err := tl.Delete() + err := tl.Delete(s) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) + _ = s.Close() }) t.Run("nonexistant list", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := TeamList{ TeamID: 1, ListID: 9999, } - err := tl.Delete() + err := tl.Delete(s) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotHaveAccessToList(err)) + _ = s.Close() }) } @@ -229,6 +255,7 @@ func TestTeamList_Update(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := &TeamList{ ID: tt.fields.ID, @@ -240,13 +267,15 @@ func TestTeamList_Update(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := tl.Update() + err := tl.Update(s) if (err != nil) != tt.wantErr { t.Errorf("TeamList.Update() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("TeamList.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + err = s.Commit() + assert.NoError(t, err) if !tt.wantErr { db.AssertExists(t, "team_list", map[string]interface{}{ "list_id": tt.fields.ListID, diff --git a/pkg/models/list_test.go b/pkg/models/list_test.go index ab763dd9..450d50d2 100644 --- a/pkg/models/list_test.go +++ b/pkg/models/list_test.go @@ -35,12 +35,15 @@ func TestList_CreateOrUpdate(t *testing.T) { t.Run("create", func(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ Title: "test", Description: "Lorem Ipsum", NamespaceID: 1, } - err := list.Create(usr) + err := list.Create(s, usr) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "list", map[string]interface{}{ "id": list.ID, @@ -51,49 +54,56 @@ func TestList_CreateOrUpdate(t *testing.T) { }) t.Run("nonexistant namespace", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ Title: "test", Description: "Lorem Ipsum", NamespaceID: 999999, } - - err := list.Create(usr) + err := list.Create(s, usr) assert.Error(t, err) assert.True(t, IsErrNamespaceDoesNotExist(err)) + _ = s.Close() }) t.Run("nonexistant owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() usr := &user.User{ID: 9482385} list := List{ Title: "test", Description: "Lorem Ipsum", NamespaceID: 1, } - err := list.Create(usr) + err := list.Create(s, usr) assert.Error(t, err) assert.True(t, user.IsErrUserDoesNotExist(err)) + _ = s.Close() }) t.Run("existing identifier", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ Title: "test", Description: "Lorem Ipsum", Identifier: "test1", NamespaceID: 1, } - - err := list.Create(usr) + err := list.Create(s, usr) assert.Error(t, err) assert.True(t, IsErrListIdentifierIsNotUnique(err)) + _ = s.Close() }) t.Run("non ascii characters", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ Title: "приффки фсем", Description: "Lorem Ipsum", NamespaceID: 1, } - err := list.Create(usr) + err := list.Create(s, usr) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "list", map[string]interface{}{ "id": list.ID, @@ -107,6 +117,7 @@ func TestList_CreateOrUpdate(t *testing.T) { t.Run("update", func(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ ID: 1, Title: "test", @@ -114,7 +125,9 @@ func TestList_CreateOrUpdate(t *testing.T) { NamespaceID: 1, } list.Description = "Lorem Ipsum dolor sit amet." - err := list.Update() + err := list.Update(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "list", map[string]interface{}{ "id": list.ID, @@ -125,37 +138,43 @@ func TestList_CreateOrUpdate(t *testing.T) { }) t.Run("nonexistant", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ ID: 99999999, Title: "test", } - err := list.Update() + err := list.Update(s) assert.Error(t, err) assert.True(t, IsErrListDoesNotExist(err)) + _ = s.Close() }) t.Run("existing identifier", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ Title: "test", Description: "Lorem Ipsum", Identifier: "test1", NamespaceID: 1, } - - err := list.Create(usr) + err := list.Create(s, usr) assert.Error(t, err) assert.True(t, IsErrListIdentifierIsNotUnique(err)) + _ = s.Close() }) }) } func TestList_Delete(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() list := List{ ID: 1, } - err := list.Delete() + err := list.Delete(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertMissing(t, "list", map[string]interface{}{ "id": 1, @@ -165,30 +184,34 @@ func TestList_Delete(t *testing.T) { func TestList_ReadAll(t *testing.T) { t.Run("all in namespace", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() // Get all lists for our namespace - lists, err := GetListsByNamespaceID(1, &user.User{}) + lists, err := GetListsByNamespaceID(s, 1, &user.User{}) assert.NoError(t, err) assert.Equal(t, len(lists), 2) + _ = s.Close() }) t.Run("all lists for user", func(t *testing.T) { db.LoadAndAssertFixtures(t) - + s := db.NewSession() u := &user.User{ID: 1} list := List{} - lists3, _, _, err := list.ReadAll(u, "", 1, 50) + lists3, _, _, err := list.ReadAll(s, u, "", 1, 50) assert.NoError(t, err) assert.Equal(t, reflect.TypeOf(lists3).Kind(), reflect.Slice) - s := reflect.ValueOf(lists3) - assert.Equal(t, 16, s.Len()) + ls := reflect.ValueOf(lists3) + assert.Equal(t, 16, ls.Len()) + _ = s.Close() }) t.Run("lists for nonexistant user", func(t *testing.T) { db.LoadAndAssertFixtures(t) - + s := db.NewSession() usr := &user.User{ID: 999999} list := List{} - _, _, _, err := list.ReadAll(usr, "", 1, 50) + _, _, _, err := list.ReadAll(s, usr, "", 1, 50) assert.Error(t, err) assert.True(t, user.IsErrUserDoesNotExist(err)) + _ = s.Close() }) } diff --git a/pkg/models/list_users.go b/pkg/models/list_users.go index b9ada17f..3c55a67d 100644 --- a/pkg/models/list_users.go +++ b/pkg/models/list_users.go @@ -21,6 +21,7 @@ import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "xorm.io/xorm" ) // ListUser represents a list <-> user relation @@ -71,7 +72,7 @@ type UserWithRight struct { // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id}/users [put] -func (lu *ListUser) Create(a web.Auth) (err error) { +func (lu *ListUser) Create(s *xorm.Session, a web.Auth) (err error) { // Check if the right is valid if err := lu.Right.isValid(); err != nil { @@ -79,17 +80,17 @@ func (lu *ListUser) Create(a web.Auth) (err error) { } // Check if the list exists - l := &List{ID: lu.ListID} - if err = l.GetSimpleByID(); err != nil { + l, err := GetListSimpleByID(s, lu.ListID) + if err != nil { return } // Check if the user exists - user, err := user.GetUserByUsername(lu.Username) + u, err := user.GetUserByUsername(s, lu.Username) if err != nil { return err } - lu.UserID = user.ID + lu.UserID = u.ID // Check if the user already has access or is owner of that list // We explicitly DONT check for teams here @@ -97,7 +98,7 @@ func (lu *ListUser) Create(a web.Auth) (err error) { return ErrUserAlreadyHasAccess{UserID: lu.UserID, ListID: lu.ListID} } - exist, err := x.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{}) + exist, err := s.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{}) if err != nil { return } @@ -106,12 +107,12 @@ func (lu *ListUser) Create(a web.Auth) (err error) { } // Insert user <-> list relation - _, err = x.Insert(lu) + _, err = s.Insert(lu) if err != nil { return err } - err = updateListLastUpdated(l) + err = updateListLastUpdated(s, l) return } @@ -128,17 +129,18 @@ func (lu *ListUser) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "user or list does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/users/{userID} [delete] -func (lu *ListUser) Delete() (err error) { +func (lu *ListUser) Delete(s *xorm.Session) (err error) { // Check if the user exists - user, err := user.GetUserByUsername(lu.Username) + u, err := user.GetUserByUsername(s, lu.Username) if err != nil { return } - lu.UserID = user.ID + lu.UserID = u.ID // Check if the user has access to the list - has, err := x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID). + has, err := s. + Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID). Get(&ListUser{}) if err != nil { return @@ -147,13 +149,14 @@ func (lu *ListUser) Delete() (err error) { return ErrUserDoesNotHaveAccessToList{ListID: lu.ListID, UserID: lu.UserID} } - _, err = x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID). + _, err = s. + Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID). Delete(&ListUser{}) if err != nil { return err } - err = updateListLastUpdated(&List{ID: lu.ListID}) + err = updateListLastUpdated(s, &List{ID: lu.ListID}) return } @@ -172,10 +175,10 @@ func (lu *ListUser) Delete() (err error) { // @Failure 403 {object} web.HTTPError "No right to see the list." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id}/users [get] -func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (lu *ListUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { // Check if the user has access to the list l := &List{ID: lu.ListID} - canRead, _, err := l.CanRead(a) + canRead, _, err := l.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -187,7 +190,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r // Get all users all := []*UserWithRight{} - query := x. + query := s. Join("INNER", "users_list", "user_id = users.id"). Where("users_list.list_id = ?", lu.ListID). Where("users.username LIKE ?", "%"+search+"%") @@ -204,7 +207,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r u.Email = "" } - numberOfTotalItems, err = x. + numberOfTotalItems, err = s. Join("INNER", "users_list", "user_id = users.id"). Where("users_list.list_id = ?", lu.ListID). Where("users.username LIKE ?", "%"+search+"%"). @@ -228,7 +231,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r // @Failure 404 {object} web.HTTPError "User or list does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/users/{userID} [post] -func (lu *ListUser) Update() (err error) { +func (lu *ListUser) Update(s *xorm.Session) (err error) { // Check if the right is valid if err := lu.Right.isValid(); err != nil { @@ -236,13 +239,13 @@ func (lu *ListUser) Update() (err error) { } // Check if the user exists - u, err := user.GetUserByUsername(lu.Username) + u, err := user.GetUserByUsername(s, lu.Username) if err != nil { return err } lu.UserID = u.ID - _, err = x. + _, err = s. Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID). Cols("right"). Update(lu) @@ -250,6 +253,6 @@ func (lu *ListUser) Update() (err error) { return err } - err = updateListLastUpdated(&List{ID: lu.ListID}) + err = updateListLastUpdated(s, &List{ID: lu.ListID}) return } diff --git a/pkg/models/list_users_rights.go b/pkg/models/list_users_rights.go index b80fd1d9..057b70db 100644 --- a/pkg/models/list_users_rights.go +++ b/pkg/models/list_users_rights.go @@ -18,24 +18,25 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if the user can create a new user <-> list relation -func (lu *ListUser) CanCreate(a web.Auth) (bool, error) { - return lu.canDoListUser(a) +func (lu *ListUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return lu.canDoListUser(s, a) } // CanDelete checks if the user can delete a user <-> list relation -func (lu *ListUser) CanDelete(a web.Auth) (bool, error) { - return lu.canDoListUser(a) +func (lu *ListUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return lu.canDoListUser(s, a) } // CanUpdate checks if the user can update a user <-> list relation -func (lu *ListUser) CanUpdate(a web.Auth) (bool, error) { - return lu.canDoListUser(a) +func (lu *ListUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return lu.canDoListUser(s, a) } -func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) { +func (lu *ListUser) canDoListUser(s *xorm.Session, a web.Auth) (bool, error) { // Link shares aren't allowed to do anything if _, is := a.(*LinkSharing); is { return false, nil @@ -43,5 +44,5 @@ func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) { // Get the list and check if the user has write access on it l := List{ID: lu.ListID} - return l.IsAdmin(a) + return l.IsAdmin(s, a) } diff --git a/pkg/models/list_users_rights_test.go b/pkg/models/list_users_rights_test.go index 0ddb712d..359e38c4 100644 --- a/pkg/models/list_users_rights_test.go +++ b/pkg/models/list_users_rights_test.go @@ -80,6 +80,7 @@ func TestListUser_CanDoSomething(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() lu := &ListUser{ ID: tt.fields.ID, @@ -91,15 +92,16 @@ func TestListUser_CanDoSomething(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - if got, _ := lu.CanCreate(tt.args.a); got != tt.want["CanCreate"] { + if got, _ := lu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] { t.Errorf("ListUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) } - if got, _ := lu.CanDelete(tt.args.a); got != tt.want["CanDelete"] { + if got, _ := lu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] { t.Errorf("ListUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) } - if got, _ := lu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { + if got, _ := lu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] { t.Errorf("ListUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) } + _ = s.Close() }) } } diff --git a/pkg/models/list_users_test.go b/pkg/models/list_users_test.go index 9ff75abe..27b6477c 100644 --- a/pkg/models/list_users_test.go +++ b/pkg/models/list_users_test.go @@ -24,9 +24,9 @@ import ( "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/user" - "gopkg.in/d4l3k/messagediff.v1" - "code.vikunja.io/web" + "github.com/stretchr/testify/assert" + "gopkg.in/d4l3k/messagediff.v1" ) func TestListUser_Create(t *testing.T) { @@ -108,6 +108,7 @@ func TestListUser_Create(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() ul := &ListUser{ ID: tt.fields.ID, @@ -120,13 +121,17 @@ func TestListUser_Create(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := ul.Create(tt.args.a) + err := ul.Create(s, tt.args.a) if (err != nil) != tt.wantErr { t.Errorf("ListUser.Create() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("ListUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + + err = s.Commit() + assert.NoError(t, err) + if !tt.wantErr { db.AssertExists(t, "users_list", map[string]interface{}{ "user_id": ul.UserID, @@ -212,6 +217,7 @@ func TestListUser_ReadAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() ul := &ListUser{ ID: tt.fields.ID, @@ -223,7 +229,7 @@ func TestListUser_ReadAll(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - got, _, _, err := ul.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50) + got, _, _, err := ul.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50) if (err != nil) != tt.wantErr { t.Errorf("ListUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr) } @@ -233,6 +239,7 @@ func TestListUser_ReadAll(t *testing.T) { if diff, equal := messagediff.PrettyDiff(got, tt.want); !equal { t.Errorf("ListUser.ReadAll() = %v, want %v, diff: %v", got, tt.want, diff) } + _ = s.Close() }) } } @@ -292,6 +299,7 @@ func TestListUser_Update(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() lu := &ListUser{ ID: tt.fields.ID, @@ -303,13 +311,17 @@ func TestListUser_Update(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := lu.Update() + err := lu.Update(s) if (err != nil) != tt.wantErr { t.Errorf("ListUser.Update() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("ListUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + + err = s.Commit() + assert.NoError(t, err) + if !tt.wantErr { db.AssertExists(t, "users_list", map[string]interface{}{ "list_id": tt.fields.ListID, @@ -369,6 +381,7 @@ func TestListUser_Delete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() lu := &ListUser{ ID: tt.fields.ID, @@ -380,13 +393,17 @@ func TestListUser_Delete(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := lu.Delete() + err := lu.Delete(s) if (err != nil) != tt.wantErr { t.Errorf("ListUser.Delete() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("ListUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + + err = s.Commit() + assert.NoError(t, err) + if !tt.wantErr { db.AssertMissing(t, "users_list", map[string]interface{}{ "user_id": tt.fields.UserID, diff --git a/pkg/models/namespace.go b/pkg/models/namespace.go index 4d4fb472..d0107e10 100644 --- a/pkg/models/namespace.go +++ b/pkg/models/namespace.go @@ -23,12 +23,11 @@ import ( "time" "code.vikunja.io/api/pkg/log" - "code.vikunja.io/api/pkg/metrics" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" - "github.com/imdario/mergo" "xorm.io/builder" + "xorm.io/xorm" ) // Namespace holds informations about a namespace @@ -95,55 +94,48 @@ func (Namespace) TableName() string { } // GetSimpleByID gets a namespace without things like the owner, it more or less only checks if it exists. -func (n *Namespace) GetSimpleByID() (err error) { - if n.ID == 0 { - return ErrNamespaceDoesNotExist{ID: n.ID} +func getNamespaceSimpleByID(s *xorm.Session, id int64) (namespace *Namespace, err error) { + if id == 0 { + return nil, ErrNamespaceDoesNotExist{ID: id} } // Get the namesapce with shared lists - if n.ID == -1 { - *n = SharedListsPseudoNamespace - return + if id == -1 { + return &SharedListsPseudoNamespace, nil } - if n.ID == FavoritesPseudoNamespace.ID { - *n = FavoritesPseudoNamespace - return + if id == FavoritesPseudoNamespace.ID { + return &FavoritesPseudoNamespace, nil } - namespaceFromDB := &Namespace{} - exists, err := x.Where("id = ?", n.ID).Get(namespaceFromDB) + namespace = &Namespace{} + + exists, err := s.Where("id = ?", id).Get(namespace) if err != nil { return } if !exists { - return ErrNamespaceDoesNotExist{ID: n.ID} + return nil, ErrNamespaceDoesNotExist{ID: id} } - // We don't want to override the provided user struct because this would break updating, so we have to merge it - if err := mergo.Merge(namespaceFromDB, n, mergo.WithOverride); err != nil { - return err - } - *n = *namespaceFromDB return } // GetNamespaceByID returns a namespace object by its ID -func GetNamespaceByID(id int64) (namespace Namespace, err error) { - namespace = Namespace{ID: id} - err = namespace.GetSimpleByID() +func GetNamespaceByID(s *xorm.Session, id int64) (namespace *Namespace, err error) { + namespace, err = getNamespaceSimpleByID(s, id) if err != nil { return } // Get the namespace Owner - namespace.Owner, err = user.GetUserByID(namespace.OwnerID) + namespace.Owner, err = user.GetUserByID(s, namespace.OwnerID) return } // CheckIsArchived returns an ErrNamespaceIsArchived if the namepace is archived. -func (n *Namespace) CheckIsArchived() error { - exists, err := x. +func (n *Namespace) CheckIsArchived(s *xorm.Session) error { + exists, err := s. Where("id = ? AND is_archived = true", n.ID). Exist(&Namespace{}) if err != nil { @@ -167,8 +159,12 @@ func (n *Namespace) CheckIsArchived() error { // @Failure 403 {object} web.HTTPError "The user does not have access to that namespace." // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{id} [get] -func (n *Namespace) ReadOne() (err error) { - *n, err = GetNamespaceByID(n.ID) +func (n *Namespace) ReadOne(s *xorm.Session) (err error) { + nn, err := GetNamespaceByID(s, n.ID) + if err != nil { + return err + } + *n = *nn return } @@ -207,7 +203,7 @@ func makeNamespaceSliceFromMap(namespaces map[int64]*NamespaceWithLists, userMap // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces [get] //nolint:gocyclo -func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (n *Namespace) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { if _, is := a.(*LinkSharing); is { return nil, 0, 0, ErrGenericForbidden{} } @@ -249,7 +245,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r } limit, start := getLimitFromPageIndex(page, perPage) - query := x.Select("namespaces.*"). + query := s.Select("namespaces.*"). Table("namespaces"). Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id"). Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id"). @@ -268,7 +264,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r return nil, 0, 0, err } - numberOfTotalItems, err = x. + numberOfTotalItems, err = s. Table("namespaces"). Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id"). Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id"). @@ -294,7 +290,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r // Get all owners userMap := make(map[int64]*user.User) - err = x.In("id", userIDs).Find(&userMap) + err = s.In("id", userIDs).Find(&userMap) if err != nil { return nil, 0, 0, err } @@ -306,7 +302,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r // Get all lists lists := []*List{} - listQuery := x. + listQuery := s. In("namespace_id", namespaceids) if !n.IsArchived { @@ -330,7 +326,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r // Get all lists individually shared with our user (not via a namespace) individualLists := []*List{} - iListQuery := x.Select("l.*"). + iListQuery := s.Select("l.*"). Table("list"). Alias("l"). Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id"). @@ -360,7 +356,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r } // More details for the lists - err = AddListDetails(lists) + err = addListDetails(s, lists) if err != nil { return nil, 0, 0, err } @@ -386,7 +382,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r // Check if we have any favorites or favorited lists and remove the favorites namespace from the list if not var favoriteCount int64 - favoriteCount, err = x. + favoriteCount, err = s. Join("INNER", "list", "tasks.list_id = list.id"). Join("INNER", "namespaces", "list.namespace_id = namespaces.id"). Where(builder.And(builder.Eq{"tasks.is_favorite": true}, builder.In("namespaces.id", namespaceids))). @@ -413,7 +409,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r ///////////////// // Saved Filters - savedFilters, err := getSavedFiltersForUser(a) + savedFilters, err := getSavedFiltersForUser(s, a) if err != nil { return nil, 0, 0, err } @@ -457,7 +453,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces [put] -func (n *Namespace) Create(a web.Auth) (err error) { +func (n *Namespace) Create(s *xorm.Session, a web.Auth) (err error) { // Check if we have at least a name if n.Title == "" { return ErrNamespaceNameCannotBeEmpty{NamespaceID: 0, UserID: a.GetID()} @@ -465,14 +461,14 @@ func (n *Namespace) Create(a web.Auth) (err error) { n.ID = 0 // This would otherwise prevent the creation of new lists after one was created // Check if the User exists - n.Owner, err = user.GetUserByID(a.GetID()) + n.Owner, err = user.GetUserByID(s, a.GetID()) if err != nil { return } n.OwnerID = n.Owner.ID // Insert - if _, err = x.Insert(n); err != nil { + if _, err = s.Insert(n); err != nil { return err } @@ -482,12 +478,12 @@ func (n *Namespace) Create(a web.Auth) (err error) { // CreateNewNamespaceForUser creates a new namespace for a user. To prevent import cycles, we can't do that // directly in the user.Create function. -func CreateNewNamespaceForUser(user *user.User) (err error) { +func CreateNewNamespaceForUser(s *xorm.Session, user *user.User) (err error) { newN := &Namespace{ Title: user.Username, Description: user.Username + "'s namespace.", } - return newN.Create(user) + return newN.Create(s, user) } // Delete deletes a namespace @@ -502,22 +498,22 @@ func CreateNewNamespaceForUser(user *user.User) (err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{id} [delete] -func (n *Namespace) Delete() (err error) { +func (n *Namespace) Delete(s *xorm.Session) (err error) { // Check if the namespace exists - _, err = GetNamespaceByID(n.ID) + _, err = GetNamespaceByID(s, n.ID) if err != nil { return } // Delete the namespace - _, err = x.ID(n.ID).Delete(&Namespace{}) + _, err = s.ID(n.ID).Delete(&Namespace{}) if err != nil { return } // Delete all lists with their tasks - lists, err := GetListsByNamespaceID(n.ID, &user.User{}) + lists, err := GetListsByNamespaceID(s, n.ID, &user.User{}) if err != nil { return } @@ -530,13 +526,13 @@ func (n *Namespace) Delete() (err error) { } // Delete tasks - _, err = x.In("list_id", listIDs).Delete(&Task{}) + _, err = s.In("list_id", listIDs).Delete(&Task{}) if err != nil { return } // Delete the lists - _, err = x.In("id", listIDs).Delete(&List{}) + _, err = s.In("id", listIDs).Delete(&List{}) if err != nil { return } @@ -560,14 +556,14 @@ func (n *Namespace) Delete() (err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 500 {object} models.Message "Internal error" // @Router /namespace/{id} [post] -func (n *Namespace) Update() (err error) { +func (n *Namespace) Update(s *xorm.Session) (err error) { // Check if we have at least a name if n.Title == "" { return ErrNamespaceNameCannotBeEmpty{NamespaceID: n.ID} } // Check if the namespace exists - currentNamespace, err := GetNamespaceByID(n.ID) + currentNamespace, err := GetNamespaceByID(s, n.ID) if err != nil { return } @@ -581,7 +577,7 @@ func (n *Namespace) Update() (err error) { if n.Owner != nil { n.OwnerID = n.Owner.ID if currentNamespace.OwnerID != n.OwnerID { - n.Owner, err = user.GetUserByID(n.OwnerID) + n.Owner, err = user.GetUserByID(s, n.OwnerID) if err != nil { return } @@ -599,7 +595,7 @@ func (n *Namespace) Update() (err error) { } // Do the actual update - _, err = x. + _, err = s. ID(currentNamespace.ID). Cols(colsToUpdate...). Update(n) diff --git a/pkg/models/namespace_rights.go b/pkg/models/namespace_rights.go index 066fbcfc..539e790a 100644 --- a/pkg/models/namespace_rights.go +++ b/pkg/models/namespace_rights.go @@ -19,37 +19,38 @@ package models import ( "code.vikunja.io/web" "xorm.io/builder" + "xorm.io/xorm" ) // CanWrite checks if a user has write access to a namespace -func (n *Namespace) CanWrite(a web.Auth) (bool, error) { - can, _, err := n.checkRight(a, RightWrite, RightAdmin) +func (n *Namespace) CanWrite(s *xorm.Session, a web.Auth) (bool, error) { + can, _, err := n.checkRight(s, a, RightWrite, RightAdmin) return can, err } // IsAdmin returns true or false if the user is admin on that namespace or not -func (n *Namespace) IsAdmin(a web.Auth) (bool, error) { - is, _, err := n.checkRight(a, RightAdmin) +func (n *Namespace) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) { + is, _, err := n.checkRight(s, a, RightAdmin) return is, err } // CanRead checks if a user has read access to that namespace -func (n *Namespace) CanRead(a web.Auth) (bool, int, error) { - return n.checkRight(a, RightRead, RightWrite, RightAdmin) +func (n *Namespace) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) { + return n.checkRight(s, a, RightRead, RightWrite, RightAdmin) } // CanUpdate checks if the user can update the namespace -func (n *Namespace) CanUpdate(a web.Auth) (bool, error) { - return n.IsAdmin(a) +func (n *Namespace) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return n.IsAdmin(s, a) } // CanDelete checks if the user can delete a namespace -func (n *Namespace) CanDelete(a web.Auth) (bool, error) { - return n.IsAdmin(a) +func (n *Namespace) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return n.IsAdmin(s, a) } // CanCreate checks if the user can create a new namespace -func (n *Namespace) CanCreate(a web.Auth) (bool, error) { +func (n *Namespace) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { if _, is := a.(*LinkSharing); is { return false, nil } @@ -58,7 +59,7 @@ func (n *Namespace) CanCreate(a web.Auth) (bool, error) { return true, nil } -func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) { +func (n *Namespace) checkRight(s *xorm.Session, a web.Auth, rights ...Right) (bool, int, error) { // If the auth is a link share, don't do anything if _, is := a.(*LinkSharing); is { @@ -66,13 +67,12 @@ func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) { } // Get the namespace and check the right - nn := &Namespace{ID: n.ID} - err := nn.GetSimpleByID() + nn, err := getNamespaceSimpleByID(s, n.ID) if err != nil { return false, 0, err } - if a.GetID() == n.OwnerID { + if a.GetID() == nn.OwnerID { return true, int(RightAdmin), nil } @@ -113,7 +113,8 @@ func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) { var maxRights = 0 r := &allRights{} - exists, err := x.Select("*"). + exists, err := s. + Select("*"). Table("namespaces"). // User stuff Join("LEFT", "users_namespace", "users_namespace.namespace_id = namespaces.id"). diff --git a/pkg/models/namespace_team.go b/pkg/models/namespace_team.go index f249e2e5..ff643f81 100644 --- a/pkg/models/namespace_team.go +++ b/pkg/models/namespace_team.go @@ -20,6 +20,7 @@ import ( "time" "code.vikunja.io/web" + "xorm.io/xorm" ) // TeamNamespace defines the relationship between a Team and a Namespace @@ -62,7 +63,7 @@ func (TeamNamespace) TableName() string { // @Failure 403 {object} web.HTTPError "The team does not have access to the namespace" // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{id}/teams [put] -func (tn *TeamNamespace) Create(a web.Auth) (err error) { +func (tn *TeamNamespace) Create(s *xorm.Session, a web.Auth) (err error) { // Check if the rights are valid if err = tn.Right.isValid(); err != nil { @@ -70,19 +71,20 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) { } // Check if the team exists - _, err = GetTeamByID(tn.TeamID) + _, err = GetTeamByID(s, tn.TeamID) if err != nil { return } // Check if the namespace exists - _, err = GetNamespaceByID(tn.NamespaceID) + _, err = GetNamespaceByID(s, tn.NamespaceID) if err != nil { return } // Check if the team already has access to the namespace - exists, err := x.Where("team_id = ?", tn.TeamID). + exists, err := s. + Where("team_id = ?", tn.TeamID). And("namespace_id = ?", tn.NamespaceID). Get(&TeamNamespace{}) if err != nil { @@ -93,7 +95,7 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) { } // Insert the new team - _, err = x.Insert(tn) + _, err = s.Insert(tn) return } @@ -110,16 +112,17 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "team or namespace does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{namespaceID}/teams/{teamID} [delete] -func (tn *TeamNamespace) Delete() (err error) { +func (tn *TeamNamespace) Delete(s *xorm.Session) (err error) { // Check if the team exists - _, err = GetTeamByID(tn.TeamID) + _, err = GetTeamByID(s, tn.TeamID) if err != nil { return } // Check if the team has access to the namespace - has, err := x.Where("team_id = ? AND namespace_id = ?", tn.TeamID, tn.NamespaceID). + has, err := s. + Where("team_id = ? AND namespace_id = ?", tn.TeamID, tn.NamespaceID). Get(&TeamNamespace{}) if err != nil { return @@ -129,7 +132,8 @@ func (tn *TeamNamespace) Delete() (err error) { } // Delete the relation - _, err = x.Where("team_id = ?", tn.TeamID). + _, err = s. + Where("team_id = ?", tn.TeamID). And("namespace_id = ?", tn.NamespaceID). Delete(TeamNamespace{}) @@ -151,10 +155,10 @@ func (tn *TeamNamespace) Delete() (err error) { // @Failure 403 {object} web.HTTPError "No right to see the namespace." // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{id}/teams [get] -func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (tn *TeamNamespace) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { // Check if the user can read the namespace n := Namespace{ID: tn.NamespaceID} - canRead, _, err := n.CanRead(a) + canRead, _, err := n.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -167,7 +171,8 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in limit, start := getLimitFromPageIndex(page, perPage) - query := x.Table("teams"). + query := s. + Table("teams"). Join("INNER", "team_namespaces", "team_id = teams.id"). Where("team_namespaces.namespace_id = ?", tn.NamespaceID). Where("teams.name LIKE ?", "%"+search+"%") @@ -184,12 +189,13 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in teams = append(teams, &t.Team) } - err = addMoreInfoToTeams(teams) + err = addMoreInfoToTeams(s, teams) if err != nil { return } - numberOfTotalItems, err = x.Table("teams"). + numberOfTotalItems, err = s. + Table("teams"). Join("INNER", "team_namespaces", "team_id = teams.id"). Where("team_namespaces.namespace_id = ?", tn.NamespaceID). Where("teams.name LIKE ?", "%"+search+"%"). @@ -213,14 +219,14 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in // @Failure 404 {object} web.HTTPError "Team or namespace does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{namespaceID}/teams/{teamID} [post] -func (tn *TeamNamespace) Update() (err error) { +func (tn *TeamNamespace) Update(s *xorm.Session) (err error) { // Check if the right is valid if err := tn.Right.isValid(); err != nil { return err } - _, err = x. + _, err = s. Where("namespace_id = ? AND team_id = ?", tn.NamespaceID, tn.TeamID). Cols("right"). Update(tn) diff --git a/pkg/models/namespace_team_rights.go b/pkg/models/namespace_team_rights.go index f62d5588..199cb8a7 100644 --- a/pkg/models/namespace_team_rights.go +++ b/pkg/models/namespace_team_rights.go @@ -18,22 +18,23 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if one can create a new team <-> namespace relation -func (tn *TeamNamespace) CanCreate(a web.Auth) (bool, error) { +func (tn *TeamNamespace) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { n := &Namespace{ID: tn.NamespaceID} - return n.IsAdmin(a) + return n.IsAdmin(s, a) } // CanDelete checks if a user can remove a team from a namespace. Only namespace admins can do that. -func (tn *TeamNamespace) CanDelete(a web.Auth) (bool, error) { +func (tn *TeamNamespace) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { n := &Namespace{ID: tn.NamespaceID} - return n.IsAdmin(a) + return n.IsAdmin(s, a) } // CanUpdate checks if a user can update a team from a Only namespace admins can do that. -func (tn *TeamNamespace) CanUpdate(a web.Auth) (bool, error) { +func (tn *TeamNamespace) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { n := &Namespace{ID: tn.NamespaceID} - return n.IsAdmin(a) + return n.IsAdmin(s, a) } diff --git a/pkg/models/namespace_team_rights_test.go b/pkg/models/namespace_team_rights_test.go index f50946e8..6619337b 100644 --- a/pkg/models/namespace_team_rights_test.go +++ b/pkg/models/namespace_team_rights_test.go @@ -80,6 +80,7 @@ func TestTeamNamespace_CanDoSomething(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tn := &TeamNamespace{ ID: tt.fields.ID, @@ -91,15 +92,16 @@ func TestTeamNamespace_CanDoSomething(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - if got, _ := tn.CanCreate(tt.args.a); got != tt.want["CanCreate"] { + if got, _ := tn.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] { t.Errorf("TeamNamespace.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) } - if got, _ := tn.CanDelete(tt.args.a); got != tt.want["CanDelete"] { + if got, _ := tn.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] { t.Errorf("TeamNamespace.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) } - if got, _ := tn.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { + if got, _ := tn.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] { t.Errorf("TeamNamespace.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) } + _ = s.Close() }) } } diff --git a/pkg/models/namespace_team_test.go b/pkg/models/namespace_team_test.go index b0d59575..22fd64f4 100644 --- a/pkg/models/namespace_team_test.go +++ b/pkg/models/namespace_team_test.go @@ -36,29 +36,35 @@ func TestTeamNamespace_ReadAll(t *testing.T) { NamespaceID: 3, } db.LoadAndAssertFixtures(t) - teams, _, _, err := tn.ReadAll(u, "", 1, 50) + s := db.NewSession() + teams, _, _, err := tn.ReadAll(s, u, "", 1, 50) assert.NoError(t, err) assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice) - s := reflect.ValueOf(teams) - assert.Equal(t, s.Len(), 2) + ts := reflect.ValueOf(teams) + assert.Equal(t, ts.Len(), 2) + _ = s.Close() }) t.Run("nonexistant namespace", func(t *testing.T) { tn := TeamNamespace{ NamespaceID: 9999, } db.LoadAndAssertFixtures(t) - _, _, _, err := tn.ReadAll(u, "", 1, 50) + s := db.NewSession() + _, _, _, err := tn.ReadAll(s, u, "", 1, 50) assert.Error(t, err) assert.True(t, IsErrNamespaceDoesNotExist(err)) + _ = s.Close() }) t.Run("no right for namespace", func(t *testing.T) { tn := TeamNamespace{ NamespaceID: 17, } db.LoadAndAssertFixtures(t) - _, _, _, err := tn.ReadAll(u, "", 1, 50) + s := db.NewSession() + _, _, _, err := tn.ReadAll(s, u, "", 1, 50) assert.Error(t, err) assert.True(t, IsErrNeedToHaveNamespaceReadAccess(err)) + _ = s.Close() }) } @@ -72,10 +78,15 @@ func TestTeamNamespace_Create(t *testing.T) { Right: RightAdmin, } db.LoadAndAssertFixtures(t) - allowed, _ := tn.CanCreate(u) + s := db.NewSession() + allowed, _ := tn.CanCreate(s, u) assert.True(t, allowed) - err := tn.Create(u) + err := tn.Create(s, u) assert.NoError(t, err) + + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "team_namespaces", map[string]interface{}{ "team_id": 1, "namespace_id": 1, @@ -89,9 +100,11 @@ func TestTeamNamespace_Create(t *testing.T) { Right: RightRead, } db.LoadAndAssertFixtures(t) - err := tn.Create(u) + s := db.NewSession() + err := tn.Create(s, u) assert.Error(t, err) assert.True(t, IsErrTeamAlreadyHasAccess(err)) + _ = s.Close() }) t.Run("invalid team right", func(t *testing.T) { tn := TeamNamespace{ @@ -100,9 +113,11 @@ func TestTeamNamespace_Create(t *testing.T) { Right: RightUnknown, } db.LoadAndAssertFixtures(t) - err := tn.Create(u) + s := db.NewSession() + err := tn.Create(s, u) assert.Error(t, err) assert.True(t, IsErrInvalidRight(err)) + _ = s.Close() }) t.Run("nonexistant team", func(t *testing.T) { tn := TeamNamespace{ @@ -110,9 +125,11 @@ func TestTeamNamespace_Create(t *testing.T) { NamespaceID: 1, } db.LoadAndAssertFixtures(t) - err := tn.Create(u) + s := db.NewSession() + err := tn.Create(s, u) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) + _ = s.Close() }) t.Run("nonexistant namespace", func(t *testing.T) { tn := TeamNamespace{ @@ -120,9 +137,11 @@ func TestTeamNamespace_Create(t *testing.T) { NamespaceID: 9999, } db.LoadAndAssertFixtures(t) - err := tn.Create(u) + s := db.NewSession() + err := tn.Create(s, u) assert.Error(t, err) assert.True(t, IsErrNamespaceDoesNotExist(err)) + _ = s.Close() }) } @@ -135,10 +154,14 @@ func TestTeamNamespace_Delete(t *testing.T) { NamespaceID: 9, } db.LoadAndAssertFixtures(t) - allowed, _ := tn.CanDelete(u) + s := db.NewSession() + allowed, _ := tn.CanDelete(s, u) assert.True(t, allowed) - err := tn.Delete() + err := tn.Delete(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertMissing(t, "team_namespaces", map[string]interface{}{ "team_id": 7, "namespace_id": 9, @@ -150,9 +173,11 @@ func TestTeamNamespace_Delete(t *testing.T) { NamespaceID: 3, } db.LoadAndAssertFixtures(t) - err := tn.Delete() + s := db.NewSession() + err := tn.Delete(s) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) + _ = s.Close() }) t.Run("nonexistant namespace", func(t *testing.T) { tn := TeamNamespace{ @@ -160,9 +185,11 @@ func TestTeamNamespace_Delete(t *testing.T) { NamespaceID: 9999, } db.LoadAndAssertFixtures(t) - err := tn.Delete() + s := db.NewSession() + err := tn.Delete(s) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotHaveAccessToNamespace(err)) + _ = s.Close() }) } @@ -221,6 +248,7 @@ func TestTeamNamespace_Update(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() tl := &TeamNamespace{ ID: tt.fields.ID, @@ -232,13 +260,17 @@ func TestTeamNamespace_Update(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := tl.Update() + err := tl.Update(s) if (err != nil) != tt.wantErr { t.Errorf("TeamNamespace.Update() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("TeamNamespace.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + + err = s.Commit() + assert.NoError(t, err) + if !tt.wantErr { db.AssertExists(t, "team_namespaces", map[string]interface{}{ "team_id": tt.fields.TeamID, diff --git a/pkg/models/namespace_test.go b/pkg/models/namespace_test.go index b7a93d83..5e0bc213 100644 --- a/pkg/models/namespace_test.go +++ b/pkg/models/namespace_test.go @@ -36,8 +36,12 @@ func TestNamespace_Create(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) - err := dummynamespace.Create(user1) + s := db.NewSession() + err := dummynamespace.Create(s, user1) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "namespaces", map[string]interface{}{ "title": "Test", "description": "Lorem Ipsum", @@ -45,18 +49,22 @@ func TestNamespace_Create(t *testing.T) { }) t.Run("no title", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() n2 := Namespace{} - err := n2.Create(user1) + err := n2.Create(s, user1) assert.Error(t, err) assert.True(t, IsErrNamespaceNameCannotBeEmpty(err)) + _ = s.Close() }) t.Run("nonexistant user", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() nUser := &user.User{ID: 9482385} dnsp2 := dummynamespace - err := dnsp2.Create(nUser) + err := dnsp2.Create(s, nUser) assert.Error(t, err) assert.True(t, user.IsErrUserDoesNotExist(err)) + _ = s.Close() }) } @@ -64,28 +72,36 @@ func TestNamespace_ReadOne(t *testing.T) { t.Run("normal", func(t *testing.T) { n := &Namespace{ID: 1} db.LoadAndAssertFixtures(t) - err := n.ReadOne() + s := db.NewSession() + err := n.ReadOne(s) assert.NoError(t, err) assert.Equal(t, n.Title, "testnamespace") + _ = s.Close() }) t.Run("nonexistant", func(t *testing.T) { n := &Namespace{ID: 99999} db.LoadAndAssertFixtures(t) - err := n.ReadOne() + s := db.NewSession() + err := n.ReadOne(s) assert.Error(t, err) assert.True(t, IsErrNamespaceDoesNotExist(err)) + _ = s.Close() }) } func TestNamespace_Update(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() n := &Namespace{ ID: 1, Title: "Lorem Ipsum", } - err := n.Update() + err := n.Update(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "namespaces", map[string]interface{}{ "id": 1, "title": "Lorem Ipsum", @@ -93,56 +109,68 @@ func TestNamespace_Update(t *testing.T) { }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() n := &Namespace{ ID: 99999, Title: "Lorem Ipsum", } - err := n.Update() + err := n.Update(s) assert.Error(t, err) assert.True(t, IsErrNamespaceDoesNotExist(err)) + _ = s.Close() }) t.Run("nonexisting owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() n := &Namespace{ ID: 1, Title: "Lorem Ipsum", Owner: &user.User{ID: 99999}, } - err := n.Update() + err := n.Update(s) assert.Error(t, err) assert.True(t, user.IsErrUserDoesNotExist(err)) + _ = s.Close() }) t.Run("no title", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() n := &Namespace{ ID: 1, } - err := n.Update() + err := n.Update(s) assert.Error(t, err) assert.True(t, IsErrNamespaceNameCannotBeEmpty(err)) + _ = s.Close() }) } func TestNamespace_Delete(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() n := &Namespace{ ID: 1, } - err := n.Delete() + err := n.Delete(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertMissing(t, "namespaces", map[string]interface{}{ "id": 1, }) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() n := &Namespace{ ID: 9999, } - err := n.Delete() + err := n.Delete(s) assert.Error(t, err) assert.True(t, IsErrNamespaceDoesNotExist(err)) + _ = s.Close() }) } @@ -152,9 +180,12 @@ func TestNamespace_ReadAll(t *testing.T) { user11 := &user.User{ID: 11} user12 := &user.User{ID: 12} + s := db.NewSession() + defer s.Close() + t.Run("normal", func(t *testing.T) { n := &Namespace{} - nn, _, _, err := n.ReadAll(user1, "", 1, -1) + nn, _, _, err := n.ReadAll(s, user1, "", 1, -1) assert.NoError(t, err) namespaces := nn.([]*NamespaceWithLists) assert.NotNil(t, namespaces) @@ -174,7 +205,7 @@ func TestNamespace_ReadAll(t *testing.T) { n := &Namespace{ NamespacesOnly: true, } - nn, _, _, err := n.ReadAll(user1, "", 1, -1) + nn, _, _, err := n.ReadAll(s, user1, "", 1, -1) assert.NoError(t, err) namespaces := nn.([]*NamespaceWithLists) assert.NotNil(t, namespaces) @@ -188,7 +219,7 @@ func TestNamespace_ReadAll(t *testing.T) { n := &Namespace{ NamespacesOnly: true, } - nn, _, _, err := n.ReadAll(user7, "13,14", 1, -1) + nn, _, _, err := n.ReadAll(s, user7, "13,14", 1, -1) assert.NoError(t, err) namespaces := nn.([]*NamespaceWithLists) assert.NotNil(t, namespaces) @@ -200,7 +231,7 @@ func TestNamespace_ReadAll(t *testing.T) { n := &Namespace{ NamespacesOnly: true, } - nn, _, _, err := n.ReadAll(user1, "1,w", 1, -1) + nn, _, _, err := n.ReadAll(s, user1, "1,w", 1, -1) assert.NoError(t, err) namespaces := nn.([]*NamespaceWithLists) assert.NotNil(t, namespaces) @@ -211,7 +242,7 @@ func TestNamespace_ReadAll(t *testing.T) { n := &Namespace{ IsArchived: true, } - nn, _, _, err := n.ReadAll(user1, "", 1, -1) + nn, _, _, err := n.ReadAll(s, user1, "", 1, -1) namespaces := nn.([]*NamespaceWithLists) assert.NoError(t, err) assert.NotNil(t, namespaces) @@ -222,7 +253,7 @@ func TestNamespace_ReadAll(t *testing.T) { }) t.Run("no favorites", func(t *testing.T) { n := &Namespace{} - nn, _, _, err := n.ReadAll(user11, "", 1, -1) + nn, _, _, err := n.ReadAll(s, user11, "", 1, -1) namespaces := nn.([]*NamespaceWithLists) assert.NoError(t, err) // Assert the first namespace is not the favorites namespace @@ -230,7 +261,7 @@ func TestNamespace_ReadAll(t *testing.T) { }) t.Run("no favorite tasks but namespace", func(t *testing.T) { n := &Namespace{} - nn, _, _, err := n.ReadAll(user12, "", 1, -1) + nn, _, _, err := n.ReadAll(s, user12, "", 1, -1) namespaces := nn.([]*NamespaceWithLists) assert.NoError(t, err) // Assert the first namespace is the favorites namespace and contains lists @@ -239,7 +270,7 @@ func TestNamespace_ReadAll(t *testing.T) { }) t.Run("no saved filters", func(t *testing.T) { n := &Namespace{} - nn, _, _, err := n.ReadAll(user11, "", 1, -1) + nn, _, _, err := n.ReadAll(s, user11, "", 1, -1) namespaces := nn.([]*NamespaceWithLists) assert.NoError(t, err) // Assert the first namespace is not the favorites namespace diff --git a/pkg/models/namespace_users.go b/pkg/models/namespace_users.go index 119d0f60..ad58fcf7 100644 --- a/pkg/models/namespace_users.go +++ b/pkg/models/namespace_users.go @@ -21,6 +21,7 @@ import ( user2 "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "xorm.io/xorm" ) // NamespaceUser represents a namespace <-> user relation @@ -64,7 +65,7 @@ func (NamespaceUser) TableName() string { // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{id}/users [put] -func (nu *NamespaceUser) Create(a web.Auth) (err error) { +func (nu *NamespaceUser) Create(s *xorm.Session, a web.Auth) (err error) { // Reset the id nu.ID = 0 @@ -74,13 +75,13 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) { } // Check if the namespace exists - l, err := GetNamespaceByID(nu.NamespaceID) + l, err := GetNamespaceByID(s, nu.NamespaceID) if err != nil { return } // Check if the user exists - user, err := user2.GetUserByUsername(nu.Username) + user, err := user2.GetUserByUsername(s, nu.Username) if err != nil { return err } @@ -92,7 +93,9 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) { return ErrUserAlreadyHasNamespaceAccess{UserID: nu.UserID, NamespaceID: nu.NamespaceID} } - exist, err := x.Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID).Get(&NamespaceUser{}) + exist, err := s. + Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID). + Get(&NamespaceUser{}) if err != nil { return } @@ -101,7 +104,7 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) { } // Insert user <-> namespace relation - _, err = x.Insert(nu) + _, err = s.Insert(nu) return } @@ -119,17 +122,18 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "user or namespace does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{namespaceID}/users/{userID} [delete] -func (nu *NamespaceUser) Delete() (err error) { +func (nu *NamespaceUser) Delete(s *xorm.Session) (err error) { // Check if the user exists - user, err := user2.GetUserByUsername(nu.Username) + user, err := user2.GetUserByUsername(s, nu.Username) if err != nil { return } nu.UserID = user.ID // Check if the user has access to the namespace - has, err := x.Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID). + has, err := s. + Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID). Get(&NamespaceUser{}) if err != nil { return @@ -138,7 +142,8 @@ func (nu *NamespaceUser) Delete() (err error) { return ErrUserDoesNotHaveAccessToNamespace{NamespaceID: nu.NamespaceID, UserID: nu.UserID} } - _, err = x.Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID). + _, err = s. + Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID). Delete(&NamespaceUser{}) return } @@ -158,10 +163,10 @@ func (nu *NamespaceUser) Delete() (err error) { // @Failure 403 {object} web.HTTPError "No right to see the namespace." // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{id}/users [get] -func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (nu *NamespaceUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { // Check if the user has access to the namespace l := Namespace{ID: nu.NamespaceID} - canRead, _, err := l.CanRead(a) + canRead, _, err := l.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -174,7 +179,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in limit, start := getLimitFromPageIndex(page, perPage) - query := x. + query := s. Join("INNER", "users_namespace", "user_id = users.id"). Where("users_namespace.namespace_id = ?", nu.NamespaceID). Where("users.username LIKE ?", "%"+search+"%") @@ -191,7 +196,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in u.Email = "" } - numberOfTotalItems, err = x. + numberOfTotalItems, err = s. Join("INNER", "users_namespace", "user_id = users.id"). Where("users_namespace.namespace_id = ?", nu.NamespaceID). Where("users.username LIKE ?", "%"+search+"%"). @@ -215,7 +220,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in // @Failure 404 {object} web.HTTPError "User or namespace does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{namespaceID}/users/{userID} [post] -func (nu *NamespaceUser) Update() (err error) { +func (nu *NamespaceUser) Update(s *xorm.Session) (err error) { // Check if the right is valid if err := nu.Right.isValid(); err != nil { @@ -223,13 +228,13 @@ func (nu *NamespaceUser) Update() (err error) { } // Check if the user exists - user, err := user2.GetUserByUsername(nu.Username) + user, err := user2.GetUserByUsername(s, nu.Username) if err != nil { return err } nu.UserID = user.ID - _, err = x. + _, err = s. Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID). Cols("right"). Update(nu) diff --git a/pkg/models/namespace_users_rights.go b/pkg/models/namespace_users_rights.go index d8a539c1..a0a7f262 100644 --- a/pkg/models/namespace_users_rights.go +++ b/pkg/models/namespace_users_rights.go @@ -18,24 +18,25 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if the user can create a new user <-> namespace relation -func (nu *NamespaceUser) CanCreate(a web.Auth) (bool, error) { - return nu.canDoNamespaceUser(a) +func (nu *NamespaceUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return nu.canDoNamespaceUser(s, a) } // CanDelete checks if the user can delete a user <-> namespace relation -func (nu *NamespaceUser) CanDelete(a web.Auth) (bool, error) { - return nu.canDoNamespaceUser(a) +func (nu *NamespaceUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return nu.canDoNamespaceUser(s, a) } // CanUpdate checks if the user can update a user <-> namespace relation -func (nu *NamespaceUser) CanUpdate(a web.Auth) (bool, error) { - return nu.canDoNamespaceUser(a) +func (nu *NamespaceUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return nu.canDoNamespaceUser(s, a) } -func (nu *NamespaceUser) canDoNamespaceUser(a web.Auth) (bool, error) { +func (nu *NamespaceUser) canDoNamespaceUser(s *xorm.Session, a web.Auth) (bool, error) { n := &Namespace{ID: nu.NamespaceID} - return n.IsAdmin(a) + return n.IsAdmin(s, a) } diff --git a/pkg/models/namespace_users_rights_test.go b/pkg/models/namespace_users_rights_test.go index 44c31a94..e8211cfa 100644 --- a/pkg/models/namespace_users_rights_test.go +++ b/pkg/models/namespace_users_rights_test.go @@ -80,6 +80,8 @@ func TestNamespaceUser_CanDoSomething(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() nu := &NamespaceUser{ ID: tt.fields.ID, @@ -91,13 +93,13 @@ func TestNamespaceUser_CanDoSomething(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - if got, _ := nu.CanCreate(tt.args.a); got != tt.want["CanCreate"] { + if got, _ := nu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] { t.Errorf("NamespaceUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) } - if got, _ := nu.CanDelete(tt.args.a); got != tt.want["CanDelete"] { + if got, _ := nu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] { t.Errorf("NamespaceUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) } - if got, _ := nu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { + if got, _ := nu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] { t.Errorf("NamespaceUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) } }) diff --git a/pkg/models/namespace_users_test.go b/pkg/models/namespace_users_test.go index 7afcc41c..b49d1376 100644 --- a/pkg/models/namespace_users_test.go +++ b/pkg/models/namespace_users_test.go @@ -25,6 +25,7 @@ import ( "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "github.com/stretchr/testify/assert" "gopkg.in/d4l3k/messagediff.v1" ) @@ -108,6 +109,7 @@ func TestNamespaceUser_Create(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() un := &NamespaceUser{ ID: tt.fields.ID, @@ -119,13 +121,16 @@ func TestNamespaceUser_Create(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := un.Create(tt.args.a) + err := un.Create(s, tt.args.a) if (err != nil) != tt.wantErr { t.Errorf("NamespaceUser.Create() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("NamespaceUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + err = s.Commit() + assert.NoError(t, err) + if !tt.wantErr { db.AssertExists(t, "users_namespace", map[string]interface{}{ "user_id": tt.fields.UserID, @@ -211,6 +216,8 @@ func TestNamespaceUser_ReadAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() un := &NamespaceUser{ ID: tt.fields.ID, @@ -222,7 +229,7 @@ func TestNamespaceUser_ReadAll(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - got, _, _, err := un.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50) + got, _, _, err := un.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50) if (err != nil) != tt.wantErr { t.Errorf("NamespaceUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr) return @@ -296,6 +303,7 @@ func TestNamespaceUser_Update(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() nu := &NamespaceUser{ ID: tt.fields.ID, @@ -307,13 +315,16 @@ func TestNamespaceUser_Update(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := nu.Update() + err := nu.Update(s) if (err != nil) != tt.wantErr { t.Errorf("NamespaceUser.Update() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("NamespaceUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + err = s.Commit() + assert.NoError(t, err) + if !tt.wantErr { db.AssertExists(t, "users_namespace", map[string]interface{}{ "user_id": tt.fields.UserID, @@ -373,6 +384,7 @@ func TestNamespaceUser_Delete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() nu := &NamespaceUser{ ID: tt.fields.ID, @@ -384,13 +396,16 @@ func TestNamespaceUser_Delete(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - err := nu.Delete() + err := nu.Delete(s) if (err != nil) != tt.wantErr { t.Errorf("NamespaceUser.Delete() error = %v, wantErr %v", err, tt.wantErr) } if (err != nil) && tt.wantErr && !tt.errType(err) { t.Errorf("NamespaceUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) } + err = s.Commit() + assert.NoError(t, err) + if !tt.wantErr { db.AssertMissing(t, "users_namespace", map[string]interface{}{ "user_id": tt.fields.UserID, diff --git a/pkg/models/saved_filters.go b/pkg/models/saved_filters.go index 97ab4f85..22ed0f1d 100644 --- a/pkg/models/saved_filters.go +++ b/pkg/models/saved_filters.go @@ -21,6 +21,7 @@ import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "xorm.io/xorm" ) // SavedFilter represents a saved bunch of filters @@ -48,14 +49,14 @@ type SavedFilter struct { } // TableName returns a better table name for saved filters -func (s *SavedFilter) TableName() string { +func (sf *SavedFilter) TableName() string { return "saved_filters" } -func (s *SavedFilter) getTaskCollection() *TaskCollection { +func (sf *SavedFilter) getTaskCollection() *TaskCollection { // We're resetting the listID to return tasks from all lists - s.Filters.ListID = 0 - return s.Filters + sf.Filters.ListID = 0 + return sf.Filters } // Returns the saved filter ID from a list ID. Will not check if the filter actually exists. @@ -79,13 +80,13 @@ func getListIDFromSavedFilterID(filterID int64) (listID int64) { return } -func getSavedFiltersForUser(auth web.Auth) (filters []*SavedFilter, err error) { +func getSavedFiltersForUser(s *xorm.Session, auth web.Auth) (filters []*SavedFilter, err error) { // Link shares can't view or modify saved filters, therefore we can error out right away if _, is := auth.(*LinkSharing); is { return nil, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID()} } - err = x.Where("owner_id = ?", auth.GetID()).Find(&filters) + err = s.Where("owner_id = ?", auth.GetID()).Find(&filters) return } @@ -100,17 +101,17 @@ func getSavedFiltersForUser(auth web.Auth) (filters []*SavedFilter, err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter." // @Failure 500 {object} models.Message "Internal error" // @Router /filters [put] -func (s *SavedFilter) Create(auth web.Auth) error { - s.OwnerID = auth.GetID() - _, err := x.Insert(s) +func (sf *SavedFilter) Create(s *xorm.Session, auth web.Auth) error { + sf.OwnerID = auth.GetID() + _, err := s.Insert(sf) return err } -func getSavedFilterSimpleByID(id int64) (s *SavedFilter, err error) { - s = &SavedFilter{} - exists, err := x. +func getSavedFilterSimpleByID(s *xorm.Session, id int64) (sf *SavedFilter, err error) { + sf = &SavedFilter{} + exists, err := s. Where("id = ?", id). - Get(s) + Get(sf) if err != nil { return nil, err } @@ -132,10 +133,10 @@ func getSavedFilterSimpleByID(id int64) (s *SavedFilter, err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter." // @Failure 500 {object} models.Message "Internal error" // @Router /filters/{id} [get] -func (s *SavedFilter) ReadOne() error { +func (sf *SavedFilter) ReadOne(s *xorm.Session) error { // s already contains almost the full saved filter from the rights check, we only need to add the user - u, err := user.GetUserByID(s.OwnerID) - s.Owner = u + u, err := user.GetUserByID(s, sf.OwnerID) + sf.Owner = u return err } @@ -152,15 +153,15 @@ func (s *SavedFilter) ReadOne() error { // @Failure 404 {object} web.HTTPError "The saved filter does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /filters/{id} [post] -func (s *SavedFilter) Update() error { - _, err := x. - Where("id = ?", s.ID). +func (sf *SavedFilter) Update(s *xorm.Session) error { + _, err := s. + Where("id = ?", sf.ID). Cols( "title", "description", "filters", ). - Update(s) + Update(sf) return err } @@ -177,7 +178,9 @@ func (s *SavedFilter) Update() error { // @Failure 404 {object} web.HTTPError "The saved filter does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /filters/{id} [delete] -func (s *SavedFilter) Delete() error { - _, err := x.Where("id = ?", s.ID).Delete(s) +func (sf *SavedFilter) Delete(s *xorm.Session) error { + _, err := s. + Where("id = ?", sf.ID). + Delete(sf) return err } diff --git a/pkg/models/saved_filters_rights.go b/pkg/models/saved_filters_rights.go index 2e869fe8..7866958e 100644 --- a/pkg/models/saved_filters_rights.go +++ b/pkg/models/saved_filters_rights.go @@ -16,28 +16,31 @@ package models -import "code.vikunja.io/web" +import ( + "code.vikunja.io/web" + "xorm.io/xorm" +) // CanRead checks if a user has the right to read a saved filter -func (s *SavedFilter) CanRead(auth web.Auth) (bool, int, error) { - can, err := s.canDoFilter(auth) +func (sf *SavedFilter) CanRead(s *xorm.Session, auth web.Auth) (bool, int, error) { + can, err := sf.canDoFilter(s, auth) return can, int(RightAdmin), err } // CanDelete checks if a user has the right to delete a saved filter -func (s *SavedFilter) CanDelete(auth web.Auth) (bool, error) { - return s.canDoFilter(auth) +func (sf *SavedFilter) CanDelete(s *xorm.Session, auth web.Auth) (bool, error) { + return sf.canDoFilter(s, auth) } // CanUpdate checks if a user has the right to update a saved filter -func (s *SavedFilter) CanUpdate(auth web.Auth) (bool, error) { +func (sf *SavedFilter) CanUpdate(s *xorm.Session, auth web.Auth) (bool, error) { // A normal check would replace the passed struct which in our case would override the values we want to update. - sf := &SavedFilter{ID: s.ID} - return sf.canDoFilter(auth) + sff := &SavedFilter{ID: sf.ID} + return sff.canDoFilter(s, auth) } // CanCreate checks if a user has the right to update a saved filter -func (s *SavedFilter) CanCreate(auth web.Auth) (bool, error) { +func (sf *SavedFilter) CanCreate(s *xorm.Session, auth web.Auth) (bool, error) { if _, is := auth.(*LinkSharing); is { return false, nil } @@ -46,23 +49,23 @@ func (s *SavedFilter) CanCreate(auth web.Auth) (bool, error) { } // Helper function to check saved filter rights sind they all have the same logic -func (s *SavedFilter) canDoFilter(auth web.Auth) (can bool, err error) { +func (sf *SavedFilter) canDoFilter(s *xorm.Session, auth web.Auth) (can bool, err error) { // Link shares can't view or modify saved filters, therefore we can error out right away if _, is := auth.(*LinkSharing); is { - return false, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID(), SavedFilterID: s.ID} + return false, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID(), SavedFilterID: sf.ID} } - sf, err := getSavedFilterSimpleByID(s.ID) + sff, err := getSavedFilterSimpleByID(s, sf.ID) if err != nil { return false, err } // Only owners are allowed to do something with a saved filter - if sf.OwnerID != auth.GetID() { + if sff.OwnerID != auth.GetID() { return false, nil } - *s = *sf + *sf = *sff return true, nil } diff --git a/pkg/models/saved_filters_test.go b/pkg/models/saved_filters_test.go index 7c6fb441..ea5afc7f 100644 --- a/pkg/models/saved_filters_test.go +++ b/pkg/models/saved_filters_test.go @@ -45,6 +45,9 @@ func TestSavedFilter_getFilterIDFromListID(t *testing.T) { func TestSavedFilter_Create(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ Title: "test", Description: "Lorem Ipsum dolor sit amet", @@ -52,9 +55,11 @@ func TestSavedFilter_Create(t *testing.T) { } u := &user.User{ID: 1} - err := sf.Create(u) + err := sf.Create(s, u) assert.NoError(t, err) assert.Equal(t, u.ID, sf.OwnerID) + err = s.Commit() + assert.NoError(t, err) vals := map[string]interface{}{ "title": "'test'", "description": "'Lorem Ipsum dolor sit amet'", @@ -62,7 +67,7 @@ func TestSavedFilter_Create(t *testing.T) { "owner_id": 1, } // Postgres can't compare json values directly, see https://dba.stackexchange.com/a/106290/210721 - if x.Dialect().URI().DBType == schemas.POSTGRES { + if db.Type() == schemas.POSTGRES { vals["filters::jsonb"] = vals["filters"].(string) + "::jsonb" delete(vals, "filters") } @@ -72,26 +77,34 @@ func TestSavedFilter_Create(t *testing.T) { func TestSavedFilter_ReadOne(t *testing.T) { user1 := &user.User{ID: 1} db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, } // canRead pre-populates the struct - _, _, err := sf.CanRead(user1) + _, _, err := sf.CanRead(s, user1) assert.NoError(t, err) - err = sf.ReadOne() + err = sf.ReadOne(s) assert.NoError(t, err) assert.NotNil(t, sf.Owner) } func TestSavedFilter_Update(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "NewTitle", Description: "", // Explicitly reset the description Filters: &TaskCollection{}, } - err := sf.Update() + err := sf.Update(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "saved_filters", map[string]interface{}{ "id": 1, @@ -102,10 +115,15 @@ func TestSavedFilter_Update(t *testing.T) { func TestSavedFilter_Delete(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, } - err := sf.Delete() + err := sf.Delete(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertMissing(t, "saved_filters", map[string]interface{}{ "id": 1, @@ -120,50 +138,65 @@ func TestSavedFilter_Rights(t *testing.T) { t.Run("create", func(t *testing.T) { // Should always be true db.LoadAndAssertFixtures(t) - can, err := (&SavedFilter{}).CanCreate(user1) + s := db.NewSession() + defer s.Close() + + can, err := (&SavedFilter{}).CanCreate(s, user1) assert.NoError(t, err) assert.True(t, can) }) t.Run("read", func(t *testing.T) { t.Run("owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "Lorem", } - can, max, err := sf.CanRead(user1) + can, max, err := sf.CanRead(s, user1) assert.NoError(t, err) assert.Equal(t, int(RightAdmin), max) assert.True(t, can) }) t.Run("not owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "Lorem", } - can, _, err := sf.CanRead(user2) + can, _, err := sf.CanRead(s, user2) assert.NoError(t, err) assert.False(t, can) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 9999, Title: "Lorem", } - can, _, err := sf.CanRead(user1) + can, _, err := sf.CanRead(s, user1) assert.Error(t, err) assert.True(t, IsErrSavedFilterDoesNotExist(err)) assert.False(t, can) }) t.Run("link share", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "Lorem", } - can, _, err := sf.CanRead(ls) + can, _, err := sf.CanRead(s, ls) assert.Error(t, err) assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err)) assert.False(t, can) @@ -172,42 +205,54 @@ func TestSavedFilter_Rights(t *testing.T) { t.Run("update", func(t *testing.T) { t.Run("owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "Lorem", } - can, err := sf.CanUpdate(user1) + can, err := sf.CanUpdate(s, user1) assert.NoError(t, err) assert.True(t, can) }) t.Run("not owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "Lorem", } - can, err := sf.CanUpdate(user2) + can, err := sf.CanUpdate(s, user2) assert.NoError(t, err) assert.False(t, can) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 9999, Title: "Lorem", } - can, err := sf.CanUpdate(user1) + can, err := sf.CanUpdate(s, user1) assert.Error(t, err) assert.True(t, IsErrSavedFilterDoesNotExist(err)) assert.False(t, can) }) t.Run("link share", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "Lorem", } - can, err := sf.CanUpdate(ls) + can, err := sf.CanUpdate(s, ls) assert.Error(t, err) assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err)) assert.False(t, can) @@ -216,40 +261,52 @@ func TestSavedFilter_Rights(t *testing.T) { t.Run("delete", func(t *testing.T) { t.Run("owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, } - can, err := sf.CanDelete(user1) + can, err := sf.CanDelete(s, user1) assert.NoError(t, err) assert.True(t, can) }) t.Run("not owner", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, } - can, err := sf.CanDelete(user2) + can, err := sf.CanDelete(s, user2) assert.NoError(t, err) assert.False(t, can) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 9999, Title: "Lorem", } - can, err := sf.CanDelete(user1) + can, err := sf.CanDelete(s, user1) assert.Error(t, err) assert.True(t, IsErrSavedFilterDoesNotExist(err)) assert.False(t, can) }) t.Run("link share", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + sf := &SavedFilter{ ID: 1, Title: "Lorem", } - can, err := sf.CanDelete(ls) + can, err := sf.CanDelete(s, ls) assert.Error(t, err) assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err)) assert.False(t, can) diff --git a/pkg/models/task_assignees.go b/pkg/models/task_assignees.go index cbf0f7fe..ce35b26f 100644 --- a/pkg/models/task_assignees.go +++ b/pkg/models/task_assignees.go @@ -46,9 +46,9 @@ type TaskAssigneeWithUser struct { user.User `xorm:"extends"` } -func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*TaskAssigneeWithUser, err error) { +func getRawTaskAssigneesForTasks(s *xorm.Session, taskIDs []int64) (taskAssignees []*TaskAssigneeWithUser, err error) { taskAssignees = []*TaskAssigneeWithUser{} - err = x.Table("task_assignees"). + err = s.Table("task_assignees"). Select("task_id, users.*"). In("task_id", taskIDs). Join("INNER", "users", "task_assignees.user_id = users.id"). @@ -60,7 +60,7 @@ func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*TaskAssignee func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err error) { // Load the current assignees - currentAssignees, err := getRawTaskAssigneesForTasks([]int64{t.ID}) + currentAssignees, err := getRawTaskAssigneesForTasks(s, []int64{t.ID}) if err != nil { return err } @@ -118,8 +118,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err } // Get the list to perform later checks - list := List{ID: t.ListID} - err = list.GetSimpleByID() + list, err := GetListSimpleByID(s, t.ListID) if err != nil { return } @@ -133,7 +132,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err } // Add the new assignee - err = t.addNewAssigneeByID(u.ID, &list) + err = t.addNewAssigneeByID(s, u.ID, list) if err != nil { return err } @@ -141,7 +140,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err t.setTaskAssignees(assignees) - err = updateListLastUpdated(&List{ID: t.ListID}) + err = updateListLastUpdated(s, &List{ID: t.ListID}) return } @@ -167,13 +166,13 @@ func (t *Task) setTaskAssignees(assignees []*user.User) { // @Failure 403 {object} web.HTTPError "Not allowed to delete the assignee." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/assignees/{userID} [delete] -func (la *TaskAssginee) Delete() (err error) { - _, err = x.Delete(&TaskAssginee{TaskID: la.TaskID, UserID: la.UserID}) +func (la *TaskAssginee) Delete(s *xorm.Session) (err error) { + _, err = s.Delete(&TaskAssginee{TaskID: la.TaskID, UserID: la.UserID}) if err != nil { return err } - err = updateListByTaskID(la.TaskID) + err = updateListByTaskID(s, la.TaskID) return } @@ -190,25 +189,25 @@ func (la *TaskAssginee) Delete() (err error) { // @Failure 400 {object} web.HTTPError "Invalid assignee object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/assignees [put] -func (la *TaskAssginee) Create(a web.Auth) (err error) { +func (la *TaskAssginee) Create(s *xorm.Session, a web.Auth) (err error) { // Get the list to perform later checks - list, err := GetListSimplByTaskID(la.TaskID) + list, err := GetListSimplByTaskID(s, la.TaskID) if err != nil { return } task := &Task{ID: la.TaskID} - return task.addNewAssigneeByID(la.UserID, list) + return task.addNewAssigneeByID(s, la.UserID, list) } -func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) { +func (t *Task) addNewAssigneeByID(s *xorm.Session, newAssigneeID int64, list *List) (err error) { // Check if the user exists and has access to the list - newAssignee, err := user.GetUserByID(newAssigneeID) + newAssignee, err := user.GetUserByID(s, newAssigneeID) if err != nil { return err } - canRead, _, err := list.CanRead(newAssignee) + canRead, _, err := list.CanRead(s, newAssignee) if err != nil { return err } @@ -216,7 +215,7 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) { return ErrUserDoesNotHaveAccessToList{list.ID, newAssigneeID} } - _, err = x.Insert(TaskAssginee{ + _, err = s.Insert(TaskAssginee{ TaskID: t.ID, UserID: newAssigneeID, }) @@ -224,7 +223,7 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) { return err } - err = updateListLastUpdated(&List{ID: t.ListID}) + err = updateListLastUpdated(s, &List{ID: t.ListID}) return } @@ -242,13 +241,13 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) { // @Success 200 {array} user.User "The assignees" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/assignees [get] -func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { - task, err := GetListSimplByTaskID(la.TaskID) +func (la *TaskAssginee) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { + task, err := GetListSimplByTaskID(s, la.TaskID) if err != nil { return nil, 0, 0, err } - can, _, err := task.CanRead(a) + can, _, err := task.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -258,7 +257,7 @@ func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int limit, start := getLimitFromPageIndex(page, perPage) var taskAssignees []*user.User - query := x.Table("task_assignees"). + query := s.Table("task_assignees"). Select("users.*"). Join("INNER", "users", "task_assignees.user_id = users.id"). Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%") @@ -270,7 +269,7 @@ func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int return nil, 0, 0, err } - numberOfTotalItems, err = x.Table("task_assignees"). + numberOfTotalItems, err = s.Table("task_assignees"). Select("users.*"). Join("INNER", "users", "task_assignees.user_id = users.id"). Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%"). @@ -301,14 +300,12 @@ type BulkAssignees struct { // @Failure 400 {object} web.HTTPError "Invalid assignee object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/assignees/bulk [post] -func (ba *BulkAssignees) Create(a web.Auth) (err error) { - s := x.NewSession() - - task, err := GetTaskByIDSimple(ba.TaskID) +func (ba *BulkAssignees) Create(s *xorm.Session, a web.Auth) (err error) { + task, err := GetTaskByIDSimple(s, ba.TaskID) if err != nil { return } - assignees, err := getRawTaskAssigneesForTasks([]int64{task.ID}) + assignees, err := getRawTaskAssigneesForTasks(s, []int64{task.ID}) if err != nil { return err } @@ -317,10 +314,5 @@ func (ba *BulkAssignees) Create(a web.Auth) (err error) { } err = task.updateTaskAssignees(s, ba.Assignees) - if err != nil { - _ = s.Rollback() - return err - } - - return s.Commit() + return } diff --git a/pkg/models/task_assignees_rights.go b/pkg/models/task_assignees_rights.go index ea8cf13f..1385ea08 100644 --- a/pkg/models/task_assignees_rights.go +++ b/pkg/models/task_assignees_rights.go @@ -18,28 +18,29 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if a user can add a new assignee -func (la *TaskAssginee) CanCreate(a web.Auth) (bool, error) { - return canDoTaskAssingee(la.TaskID, a) +func (la *TaskAssginee) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return canDoTaskAssingee(s, la.TaskID, a) } // CanCreate checks if a user can add a new assignee -func (ba *BulkAssignees) CanCreate(a web.Auth) (bool, error) { - return canDoTaskAssingee(ba.TaskID, a) +func (ba *BulkAssignees) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return canDoTaskAssingee(s, ba.TaskID, a) } // CanDelete checks if a user can delete an assignee -func (la *TaskAssginee) CanDelete(a web.Auth) (bool, error) { - return canDoTaskAssingee(la.TaskID, a) +func (la *TaskAssginee) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return canDoTaskAssingee(s, la.TaskID, a) } -func canDoTaskAssingee(taskID int64, a web.Auth) (bool, error) { +func canDoTaskAssingee(s *xorm.Session, taskID int64, a web.Auth) (bool, error) { // Check if the current user can edit the list - list, err := GetListSimplByTaskID(taskID) + list, err := GetListSimplByTaskID(s, taskID) if err != nil { return false, err } - return list.CanUpdate(a) + return list.CanUpdate(s, a) } diff --git a/pkg/models/task_attachment.go b/pkg/models/task_attachment.go index a55fd611..df438ece 100644 --- a/pkg/models/task_attachment.go +++ b/pkg/models/task_attachment.go @@ -23,6 +23,7 @@ import ( "code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "xorm.io/xorm" ) // TaskAttachment is the definition of a task attachment @@ -49,7 +50,7 @@ func (TaskAttachment) TableName() string { // NewAttachment creates a new task attachment // Note: I'm not sure if only accepting an io.ReadCloser and not an afero.File or os.File instead is a good way of doing things. -func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsize uint64, a web.Auth) error { +func (ta *TaskAttachment) NewAttachment(s *xorm.Session, f io.ReadCloser, realname string, realsize uint64, a web.Auth) error { // Store the file file, err := files.Create(f, realname, realsize, a) @@ -64,7 +65,7 @@ func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsi // Add an entry to the db ta.FileID = file.ID ta.CreatedByID = a.GetID() - _, err = x.Insert(ta) + _, err = s.Insert(ta) if err != nil { // remove the uploaded file if adding it to the db fails if err2 := file.Delete(); err2 != nil { @@ -77,8 +78,8 @@ func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsi } // ReadOne returns a task attachment -func (ta *TaskAttachment) ReadOne() (err error) { - exists, err := x.Where("id = ?", ta.ID).Get(ta) +func (ta *TaskAttachment) ReadOne(s *xorm.Session) (err error) { + exists, err := s.Where("id = ?", ta.ID).Get(ta) if err != nil { return } @@ -110,12 +111,12 @@ func (ta *TaskAttachment) ReadOne() (err error) { // @Failure 404 {object} models.Message "The task does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{id}/attachments [get] -func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (ta *TaskAttachment) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { attachments := []*TaskAttachment{} limit, start := getLimitFromPageIndex(page, perPage) - query := x. + query := s. Where("task_id = ?", ta.TaskID) if limit > 0 { query = query.Limit(limit, start) @@ -133,13 +134,13 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i } fs := make(map[int64]*files.File) - err = x.In("id", fileIDs).Find(&fs) + err = s.In("id", fileIDs).Find(&fs) if err != nil { return nil, 0, 0, err } us := make(map[int64]*user.User) - err = x.In("id", userIDs).Find(&us) + err = s.In("id", userIDs).Find(&us) if err != nil { return nil, 0, 0, err } @@ -153,7 +154,7 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i r.CreatedBy = us[r.CreatedByID] } - numberOfTotalItems, err = x. + numberOfTotalItems, err = s. Where("task_id = ?", ta.TaskID). Count(&TaskAttachment{}) return attachments, len(attachments), numberOfTotalItems, err @@ -173,15 +174,17 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i // @Failure 404 {object} models.Message "The task does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{id}/attachments/{attachmentID} [delete] -func (ta *TaskAttachment) Delete() error { +func (ta *TaskAttachment) Delete(s *xorm.Session) error { // Load the attachment - err := ta.ReadOne() + err := ta.ReadOne(s) if err != nil && !files.IsErrFileDoesNotExist(err) { return err } // Delete it - _, err = x.Where("task_id = ? AND id = ?", ta.TaskID, ta.ID).Delete(ta) + _, err = s. + Where("task_id = ? AND id = ?", ta.TaskID, ta.ID). + Delete(ta) if err != nil { return err } @@ -195,9 +198,9 @@ func (ta *TaskAttachment) Delete() error { return err } -func getTaskAttachmentsByTaskIDs(taskIDs []int64) (attachments []*TaskAttachment, err error) { +func getTaskAttachmentsByTaskIDs(s *xorm.Session, taskIDs []int64) (attachments []*TaskAttachment, err error) { attachments = []*TaskAttachment{} - err = x. + err = s. In("task_id", taskIDs). Find(&attachments) if err != nil { @@ -213,13 +216,13 @@ func getTaskAttachmentsByTaskIDs(taskIDs []int64) (attachments []*TaskAttachment // Get all files fs := make(map[int64]*files.File) - err = x.In("id", fileIDs).Find(&fs) + err = s.In("id", fileIDs).Find(&fs) if err != nil { return } users := make(map[int64]*user.User) - err = x.In("id", userIDs).Find(&users) + err = s.In("id", userIDs).Find(&users) if err != nil { return } diff --git a/pkg/models/task_attachment_rights.go b/pkg/models/task_attachment_rights.go index 5e9bbb6f..0bc4248f 100644 --- a/pkg/models/task_attachment_rights.go +++ b/pkg/models/task_attachment_rights.go @@ -16,25 +16,28 @@ package models -import "code.vikunja.io/web" +import ( + "code.vikunja.io/web" + "xorm.io/xorm" +) // CanRead checks if the user can see an attachment -func (ta *TaskAttachment) CanRead(a web.Auth) (bool, int, error) { +func (ta *TaskAttachment) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) { t := &Task{ID: ta.TaskID} - return t.CanRead(a) + return t.CanRead(s, a) } // CanDelete checks if the user can delete an attachment -func (ta *TaskAttachment) CanDelete(a web.Auth) (bool, error) { +func (ta *TaskAttachment) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { t := &Task{ID: ta.TaskID} - return t.CanWrite(a) + return t.CanWrite(s, a) } // CanCreate checks if the user can create an attachment -func (ta *TaskAttachment) CanCreate(a web.Auth) (bool, error) { - t, err := GetTaskByIDSimple(ta.TaskID) +func (ta *TaskAttachment) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + t, err := GetTaskByIDSimple(s, ta.TaskID) if err != nil { return false, err } - return t.CanCreate(a) + return t.CanCreate(s, a) } diff --git a/pkg/models/task_attachment_test.go b/pkg/models/task_attachment_test.go index 8f198535..286bad13 100644 --- a/pkg/models/task_attachment_test.go +++ b/pkg/models/task_attachment_test.go @@ -33,11 +33,14 @@ import ( func TestTaskAttachment_ReadOne(t *testing.T) { t.Run("Normal File", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + files.InitTestFileFixtures(t) ta := &TaskAttachment{ ID: 1, } - err := ta.ReadOne() + err := ta.ReadOne(s) assert.NoError(t, err) assert.NotNil(t, ta.File) assert.True(t, ta.File.ID == ta.FileID && ta.FileID != 0) @@ -54,21 +57,27 @@ func TestTaskAttachment_ReadOne(t *testing.T) { }) t.Run("Nonexisting Attachment", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + files.InitTestFileFixtures(t) ta := &TaskAttachment{ ID: 9999, } - err := ta.ReadOne() + err := ta.ReadOne(s) assert.Error(t, err) assert.True(t, IsErrTaskAttachmentDoesNotExist(err)) }) t.Run("Existing Attachment, Nonexisting File", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + files.InitTestFileFixtures(t) ta := &TaskAttachment{ ID: 2, } - err := ta.ReadOne() + err := ta.ReadOne(s) assert.Error(t, err) assert.EqualError(t, err, "file 9999 does not exist") }) @@ -94,6 +103,9 @@ func (t *testfile) Close() error { func TestTaskAttachment_NewAttachment(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + files.InitTestFileFixtures(t) // Assert the file is being stored correctly ta := TaskAttachment{ @@ -104,7 +116,7 @@ func TestTaskAttachment_NewAttachment(t *testing.T) { } testuser := &user.User{ID: 1} - err := ta.NewAttachment(tf, "testfile", 100, testuser) + err := ta.NewAttachment(s, tf, "testfile", 100, testuser) assert.NoError(t, err) assert.NotEqual(t, 0, ta.FileID) _, err = files.FileStat("files/" + strconv.FormatInt(ta.FileID, 10)) @@ -125,9 +137,12 @@ func TestTaskAttachment_NewAttachment(t *testing.T) { func TestTaskAttachment_ReadAll(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + files.InitTestFileFixtures(t) ta := &TaskAttachment{TaskID: 1} - as, _, _, err := ta.ReadAll(&user.User{ID: 1}, "", 0, 50) + as, _, _, err := ta.ReadAll(s, &user.User{ID: 1}, "", 0, 50) attachments, _ := as.([]*TaskAttachment) assert.NoError(t, err) assert.Len(t, attachments, 2) @@ -136,10 +151,13 @@ func TestTaskAttachment_ReadAll(t *testing.T) { func TestTaskAttachment_Delete(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + files.InitTestFileFixtures(t) t.Run("Normal", func(t *testing.T) { ta := &TaskAttachment{ID: 1} - err := ta.Delete() + err := ta.Delete(s) assert.NoError(t, err) // Check if the file itself was deleted _, err = files.FileStat("/1") // The new file has the id 2 since it's the second attachment @@ -148,14 +166,14 @@ func TestTaskAttachment_Delete(t *testing.T) { t.Run("Nonexisting", func(t *testing.T) { files.InitTestFileFixtures(t) ta := &TaskAttachment{ID: 9999} - err := ta.Delete() + err := ta.Delete(s) assert.Error(t, err) assert.True(t, IsErrTaskAttachmentDoesNotExist(err)) }) t.Run("Existing attachment, nonexisting file", func(t *testing.T) { files.InitTestFileFixtures(t) ta := &TaskAttachment{ID: 2} - err := ta.Delete() + err := ta.Delete(s) assert.NoError(t, err) }) } @@ -165,15 +183,21 @@ func TestTaskAttachment_Rights(t *testing.T) { t.Run("Can Read", func(t *testing.T) { t.Run("Allowed", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 1} - can, _, err := ta.CanRead(u) + can, _, err := ta.CanRead(s, u) assert.NoError(t, err) assert.True(t, can) }) t.Run("Forbidden", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 14} - can, _, err := ta.CanRead(u) + can, _, err := ta.CanRead(s, u) assert.NoError(t, err) assert.False(t, can) }) @@ -181,22 +205,31 @@ func TestTaskAttachment_Rights(t *testing.T) { t.Run("Can Delete", func(t *testing.T) { t.Run("Allowed", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 1} - can, err := ta.CanDelete(u) + can, err := ta.CanDelete(s, u) assert.NoError(t, err) assert.True(t, can) }) t.Run("Forbidden, no access", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 14} - can, err := ta.CanDelete(u) + can, err := ta.CanDelete(s, u) assert.NoError(t, err) assert.False(t, can) }) t.Run("Forbidden, shared read only", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 15} - can, err := ta.CanDelete(u) + can, err := ta.CanDelete(s, u) assert.NoError(t, err) assert.False(t, can) }) @@ -204,22 +237,31 @@ func TestTaskAttachment_Rights(t *testing.T) { t.Run("Can Create", func(t *testing.T) { t.Run("Allowed", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 1} - can, err := ta.CanCreate(u) + can, err := ta.CanCreate(s, u) assert.NoError(t, err) assert.True(t, can) }) t.Run("Forbidden, no access", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 14} - can, err := ta.CanCreate(u) + can, err := ta.CanCreate(s, u) assert.NoError(t, err) assert.False(t, can) }) t.Run("Forbidden, shared read only", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + ta := &TaskAttachment{TaskID: 15} - can, err := ta.CanCreate(u) + can, err := ta.CanCreate(s, u) assert.NoError(t, err) assert.False(t, can) }) diff --git a/pkg/models/task_collection.go b/pkg/models/task_collection.go index a0c56a15..fe777724 100644 --- a/pkg/models/task_collection.go +++ b/pkg/models/task_collection.go @@ -20,6 +20,7 @@ package models import ( "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "xorm.io/xorm" ) // TaskCollection is a struct used to hold filter details and not clutter the Task struct with information not related to actual tasks. @@ -100,17 +101,17 @@ func validateTaskField(fieldName string) error { // @Success 200 {array} models.Task "The tasks" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{listID}/tasks [get] -func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { +func (tf *TaskCollection) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { // If the list id is < -1 this means we're dealing with a saved filter - in that case we get and populate the filter // -1 is the favorites list which works as intended if tf.ListID < -1 { - s, err := getSavedFilterSimpleByID(getSavedFilterIDFromListID(tf.ListID)) + sf, err := getSavedFilterSimpleByID(s, getSavedFilterIDFromListID(tf.ListID)) if err != nil { return nil, 0, 0, err } - return s.getTaskCollection().ReadAll(a, search, page, perPage) + return sf.getTaskCollection().ReadAll(s, a, search, page, perPage) } if len(tf.SortByArr) > 0 { @@ -156,28 +157,30 @@ func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage i shareAuth, is := a.(*LinkSharing) if is { - list := &List{ID: shareAuth.ListID} - err := list.GetSimpleByID() + list, err := GetListSimpleByID(s, shareAuth.ListID) if err != nil { return nil, 0, 0, err } - return getTasksForLists([]*List{list}, a, taskopts) + return getTasksForLists(s, []*List{list}, a, taskopts) } // If the list ID is not set, we get all tasks for the user. // This allows to use this function in Task.ReadAll with a possibility to deprecate the latter at some point. if tf.ListID == 0 { - tf.Lists, _, _, err = getRawListsForUser(&listOptions{ - user: &user.User{ID: a.GetID()}, - page: -1, - }) + tf.Lists, _, _, err = getRawListsForUser( + s, + &listOptions{ + user: &user.User{ID: a.GetID()}, + page: -1, + }, + ) if err != nil { return nil, 0, 0, err } } else { // Check the list exists and the user has acess on it list := &List{ID: tf.ListID} - canRead, _, err := list.CanRead(a) + canRead, _, err := list.CanRead(s, a) if err != nil { return nil, 0, 0, err } @@ -187,5 +190,5 @@ func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage i tf.Lists = []*List{{ID: tf.ListID}} } - return getTasksForLists(tf.Lists, a, taskopts) + return getTasksForLists(s, tf.Lists, a, taskopts) } diff --git a/pkg/models/task_collection_test.go b/pkg/models/task_collection_test.go index b5925ab9..29c8fc7c 100644 --- a/pkg/models/task_collection_test.go +++ b/pkg/models/task_collection_test.go @@ -986,6 +986,8 @@ func TestTaskCollection_ReadAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() lt := &TaskCollection{ ListID: tt.fields.ListID, @@ -1000,7 +1002,7 @@ func TestTaskCollection_ReadAll(t *testing.T) { CRUDable: tt.fields.CRUDable, Rights: tt.fields.Rights, } - got, _, _, err := lt.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50) + got, _, _, err := lt.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50) if (err != nil) != tt.wantErr { t.Errorf("Test %s, Task.ReadAll() error = %v, wantErr %v", tt.name, err, tt.wantErr) return diff --git a/pkg/models/task_comment_rights.go b/pkg/models/task_comment_rights.go index b81f72d2..9ff80871 100644 --- a/pkg/models/task_comment_rights.go +++ b/pkg/models/task_comment_rights.go @@ -17,28 +17,31 @@ package models -import "code.vikunja.io/web" +import ( + "code.vikunja.io/web" + "xorm.io/xorm" +) // CanRead checks if a user can read a comment -func (tc *TaskComment) CanRead(a web.Auth) (bool, int, error) { +func (tc *TaskComment) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) { t := Task{ID: tc.TaskID} - return t.CanRead(a) + return t.CanRead(s, a) } // CanDelete checks if a user can delete a comment -func (tc *TaskComment) CanDelete(a web.Auth) (bool, error) { +func (tc *TaskComment) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { t := Task{ID: tc.TaskID} - return t.CanWrite(a) + return t.CanWrite(s, a) } // CanUpdate checks if a user can update a comment -func (tc *TaskComment) CanUpdate(a web.Auth) (bool, error) { +func (tc *TaskComment) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { t := Task{ID: tc.TaskID} - return t.CanWrite(a) + return t.CanWrite(s, a) } // CanCreate checks if a user can create a new comment -func (tc *TaskComment) CanCreate(a web.Auth) (bool, error) { +func (tc *TaskComment) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { t := Task{ID: tc.TaskID} - return t.CanWrite(a) + return t.CanWrite(s, a) } diff --git a/pkg/models/task_comments.go b/pkg/models/task_comments.go index a0c7ae3a..adf619c2 100644 --- a/pkg/models/task_comments.go +++ b/pkg/models/task_comments.go @@ -20,6 +20,8 @@ package models import ( "time" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" ) @@ -57,19 +59,19 @@ func (tc *TaskComment) TableName() string { // @Failure 400 {object} web.HTTPError "Invalid task comment object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/comments [put] -func (tc *TaskComment) Create(a web.Auth) (err error) { +func (tc *TaskComment) Create(s *xorm.Session, a web.Auth) (err error) { // Check if the task exists - _, err = GetTaskSimple(&Task{ID: tc.TaskID}) + _, err = GetTaskSimple(s, &Task{ID: tc.TaskID}) if err != nil { return err } tc.AuthorID = a.GetID() - _, err = x.Insert(tc) + _, err = s.Insert(tc) if err != nil { return } - tc.Author, err = user.GetUserByID(a.GetID()) + tc.Author, err = user.GetUserByID(s, a.GetID()) return } @@ -87,8 +89,11 @@ func (tc *TaskComment) Create(a web.Auth) (err error) { // @Failure 404 {object} web.HTTPError "The task comment was not found." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/comments/{commentID} [delete] -func (tc *TaskComment) Delete() error { - deleted, err := x.ID(tc.ID).NoAutoCondition().Delete(tc) +func (tc *TaskComment) Delete(s *xorm.Session) error { + deleted, err := s. + ID(tc.ID). + NoAutoCondition(). + Delete(tc) if deleted == 0 { return ErrTaskCommentDoesNotExist{ID: tc.ID} } @@ -109,8 +114,11 @@ func (tc *TaskComment) Delete() error { // @Failure 404 {object} web.HTTPError "The task comment was not found." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/comments/{commentID} [post] -func (tc *TaskComment) Update() error { - updated, err := x.ID(tc.ID).Cols("comment").Update(tc) +func (tc *TaskComment) Update(s *xorm.Session) error { + updated, err := s. + ID(tc.ID). + Cols("comment"). + Update(tc) if updated == 0 { return ErrTaskCommentDoesNotExist{ID: tc.ID} } @@ -131,8 +139,8 @@ func (tc *TaskComment) Update() error { // @Failure 404 {object} web.HTTPError "The task comment was not found." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/comments/{commentID} [get] -func (tc *TaskComment) ReadOne() (err error) { - exists, err := x.Get(tc) +func (tc *TaskComment) ReadOne(s *xorm.Session) (err error) { + exists, err := s.Get(tc) if err != nil { return } @@ -145,7 +153,7 @@ func (tc *TaskComment) ReadOne() (err error) { // Get the author author := &user.User{} - _, err = x. + _, err = s. Where("id = ?", tc.AuthorID). Get(author) tc.Author = author @@ -163,10 +171,10 @@ func (tc *TaskComment) ReadOne() (err error) { // @Success 200 {array} models.TaskComment "The array with all task comments" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/comments [get] -func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (tc *TaskComment) ReadAll(s *xorm.Session, auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { // Check if the user has access to the task - canRead, _, err := tc.CanRead(auth) + canRead, _, err := tc.CanRead(s, auth) if err != nil { return nil, 0, 0, err } @@ -184,7 +192,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i limit, start := getLimitFromPageIndex(page, perPage) comments := []*TaskComment{} - query := x. + query := s. Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%"). Join("LEFT", "users", "users.id = task_comments.author_id") if limit > 0 { @@ -197,7 +205,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i // Get all authors authors := make(map[int64]*user.User) - err = x. + err = s. Select("users.*"). Table("task_comments"). Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%"). @@ -211,7 +219,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i comment.Author = authors[comment.AuthorID] } - numberOfTotalItems, err = x. + numberOfTotalItems, err = s. Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%"). Count(&TaskCommentWithAuthor{}) return comments, len(comments), numberOfTotalItems, err diff --git a/pkg/models/task_comments_test.go b/pkg/models/task_comments_test.go index 41839dee..0d8bc33f 100644 --- a/pkg/models/task_comments_test.go +++ b/pkg/models/task_comments_test.go @@ -28,14 +28,20 @@ func TestTaskComment_Create(t *testing.T) { u := &user.User{ID: 1} t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ Comment: "test", TaskID: 1, } - err := tc.Create(u) + err := tc.Create(s, u) assert.NoError(t, err) assert.Equal(t, "test", tc.Comment) assert.Equal(t, int64(1), tc.Author.ID) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "task_comments", map[string]interface{}{ "id": tc.ID, "author_id": u.ID, @@ -45,11 +51,14 @@ func TestTaskComment_Create(t *testing.T) { }) t.Run("nonexisting task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ Comment: "test", TaskID: 99999, } - err := tc.Create(u) + err := tc.Create(s, u) assert.Error(t, err) assert.True(t, IsErrTaskDoesNotExist(err)) }) @@ -58,17 +67,26 @@ func TestTaskComment_Create(t *testing.T) { func TestTaskComment_Delete(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ID: 1} - err := tc.Delete() + err := tc.Delete(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertMissing(t, "task_comments", map[string]interface{}{ "id": 1, }) }) t.Run("nonexisting comment", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ID: 9999} - err := tc.Delete() + err := tc.Delete(s) assert.Error(t, err) assert.True(t, IsErrTaskCommentDoesNotExist(err)) }) @@ -77,12 +95,18 @@ func TestTaskComment_Delete(t *testing.T) { func TestTaskComment_Update(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ ID: 1, Comment: "testing", } - err := tc.Update() + err := tc.Update(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "task_comments", map[string]interface{}{ "id": 1, "comment": "testing", @@ -90,10 +114,13 @@ func TestTaskComment_Update(t *testing.T) { }) t.Run("nonexisting comment", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ ID: 9999, } - err := tc.Update() + err := tc.Update(s) assert.Error(t, err) assert.True(t, IsErrTaskCommentDoesNotExist(err)) }) @@ -102,16 +129,22 @@ func TestTaskComment_Update(t *testing.T) { func TestTaskComment_ReadOne(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ID: 1} - err := tc.ReadOne() + err := tc.ReadOne(s) assert.NoError(t, err) assert.Equal(t, "Lorem Ipsum Dolor Sit Amet", tc.Comment) assert.NotEmpty(t, tc.Author.ID) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{ID: 9999} - err := tc.ReadOne() + err := tc.ReadOne(s) assert.Error(t, err) assert.True(t, IsErrTaskCommentDoesNotExist(err)) }) @@ -120,9 +153,12 @@ func TestTaskComment_ReadOne(t *testing.T) { func TestTaskComment_ReadAll(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{TaskID: 1} u := &user.User{ID: 1} - result, resultCount, total, err := tc.ReadAll(u, "", 0, -1) + result, resultCount, total, err := tc.ReadAll(s, u, "", 0, -1) resultComment := result.([]*TaskComment) assert.NoError(t, err) assert.Equal(t, 1, resultCount) @@ -133,9 +169,12 @@ func TestTaskComment_ReadAll(t *testing.T) { }) t.Run("no access to task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tc := &TaskComment{TaskID: 14} u := &user.User{ID: 1} - _, _, _, err := tc.ReadAll(u, "", 0, -1) + _, _, _, err := tc.ReadAll(s, u, "", 0, -1) assert.Error(t, err) assert.True(t, IsErrGenericForbidden(err)) }) diff --git a/pkg/models/task_relation.go b/pkg/models/task_relation.go index 6ab162ff..0b4d826c 100644 --- a/pkg/models/task_relation.go +++ b/pkg/models/task_relation.go @@ -20,6 +20,8 @@ package models import ( "time" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" ) @@ -117,7 +119,7 @@ type RelatedTaskMap map[RelationKind][]*Task // @Failure 400 {object} web.HTTPError "Invalid task relation object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/relations [put] -func (rel *TaskRelation) Create(a web.Auth) error { +func (rel *TaskRelation) Create(s *xorm.Session, a web.Auth) error { // Check if both tasks are the same if rel.TaskID == rel.OtherTaskID { @@ -128,7 +130,7 @@ func (rel *TaskRelation) Create(a web.Auth) error { } // Check if the relation already exists, in one form or the other. - exists, err := x. + exists, err := s. Where("(task_id = ? AND other_task_id = ? AND relation_kind = ?) OR (task_id = ? AND other_task_id = ? AND relation_kind = ?)", rel.TaskID, rel.OtherTaskID, rel.RelationKind, rel.TaskID, rel.OtherTaskID, rel.RelationKind). Exist(rel) @@ -180,7 +182,7 @@ func (rel *TaskRelation) Create(a web.Auth) error { } // Finally insert everything - _, err = x.Insert(&[]*TaskRelation{ + _, err = s.Insert(&[]*TaskRelation{ rel, otherRelation, }) @@ -200,9 +202,9 @@ func (rel *TaskRelation) Create(a web.Auth) error { // @Failure 404 {object} web.HTTPError "The task relation was not found." // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{taskID}/relations [delete] -func (rel *TaskRelation) Delete() error { +func (rel *TaskRelation) Delete(s *xorm.Session) error { // Check if the relation exists - exists, err := x. + exists, err := s. Cols("task_id", "other_task_id", "relation_kind"). Get(rel) if err != nil { @@ -216,6 +218,6 @@ func (rel *TaskRelation) Delete() error { } } - _, err = x.Delete(rel) + _, err = s.Delete(rel) return err } diff --git a/pkg/models/task_relation_rights.go b/pkg/models/task_relation_rights.go index ca4e00b7..e68fa1a0 100644 --- a/pkg/models/task_relation_rights.go +++ b/pkg/models/task_relation_rights.go @@ -17,17 +17,20 @@ package models -import "code.vikunja.io/web" +import ( + "code.vikunja.io/web" + "xorm.io/xorm" +) // CanDelete checks if a user can delete a task relation -func (rel *TaskRelation) CanDelete(a web.Auth) (bool, error) { +func (rel *TaskRelation) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { // A user can delete a relation if it can update the base task baseTask := &Task{ID: rel.TaskID} - return baseTask.CanUpdate(a) + return baseTask.CanUpdate(s, a) } // CanCreate checks if a user can create a new relation between two relations -func (rel *TaskRelation) CanCreate(a web.Auth) (bool, error) { +func (rel *TaskRelation) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { // Check if the relation kind is valid if !rel.RelationKind.isValid() { return false, ErrInvalidRelationKind{Kind: rel.RelationKind} @@ -35,14 +38,14 @@ func (rel *TaskRelation) CanCreate(a web.Auth) (bool, error) { // Needs have write access to the base task and at least read access to the other task baseTask := &Task{ID: rel.TaskID} - has, err := baseTask.CanUpdate(a) + has, err := baseTask.CanUpdate(s, a) if err != nil || !has { return false, err } // We explicitly don't check if the two tasks are on the same list. otherTask := &Task{ID: rel.OtherTaskID} - has, _, err = otherTask.CanRead(a) + has, _, err = otherTask.CanRead(s, a) if err != nil { return false, err } diff --git a/pkg/models/task_relation_test.go b/pkg/models/task_relation_test.go index 6987c5cb..eb053e7e 100644 --- a/pkg/models/task_relation_test.go +++ b/pkg/models/task_relation_test.go @@ -28,13 +28,17 @@ import ( func TestTaskRelation_Create(t *testing.T) { t.Run("Normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 2, RelationKind: RelationKindSubtask, } - err := rel.Create(&user.User{ID: 1}) + err := rel.Create(s, &user.User{ID: 1}) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "task_relations", map[string]interface{}{ "task_id": 1, @@ -45,13 +49,17 @@ func TestTaskRelation_Create(t *testing.T) { }) t.Run("Two Tasks In Different Lists", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 13, RelationKind: RelationKindSubtask, } - err := rel.Create(&user.User{ID: 1}) + err := rel.Create(s, &user.User{ID: 1}) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "task_relations", map[string]interface{}{ "task_id": 1, @@ -62,24 +70,28 @@ func TestTaskRelation_Create(t *testing.T) { }) t.Run("Already Existing", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 29, RelationKind: RelationKindSubtask, } - err := rel.Create(&user.User{ID: 1}) + err := rel.Create(s, &user.User{ID: 1}) assert.Error(t, err) assert.True(t, IsErrRelationAlreadyExists(err)) }) t.Run("Same Task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 1, } - err := rel.Create(&user.User{ID: 1}) + err := rel.Create(s, &user.User{ID: 1}) assert.Error(t, err) assert.True(t, IsErrRelationTasksCannotBeTheSame(err)) }) @@ -88,13 +100,17 @@ func TestTaskRelation_Create(t *testing.T) { func TestTaskRelation_Delete(t *testing.T) { t.Run("Normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 29, RelationKind: RelationKindSubtask, } - err := rel.Delete() + err := rel.Delete(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertMissing(t, "task_relations", map[string]interface{}{ "task_id": 1, @@ -104,13 +120,15 @@ func TestTaskRelation_Delete(t *testing.T) { }) t.Run("Not existing", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 9999, OtherTaskID: 3, RelationKind: RelationKindSubtask, } - err := rel.Delete() + err := rel.Delete(s) assert.Error(t, err) assert.True(t, IsErrRelationDoesNotExist(err)) }) @@ -119,86 +137,100 @@ func TestTaskRelation_Delete(t *testing.T) { func TestTaskRelation_CanCreate(t *testing.T) { t.Run("Normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 2, RelationKind: RelationKindSubtask, } - can, err := rel.CanCreate(&user.User{ID: 1}) + can, err := rel.CanCreate(s, &user.User{ID: 1}) assert.NoError(t, err) assert.True(t, can) }) t.Run("Two tasks on different lists", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 13, RelationKind: RelationKindSubtask, } - can, err := rel.CanCreate(&user.User{ID: 1}) + can, err := rel.CanCreate(s, &user.User{ID: 1}) assert.NoError(t, err) assert.True(t, can) }) t.Run("No update rights on base task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 14, OtherTaskID: 1, RelationKind: RelationKindSubtask, } - can, err := rel.CanCreate(&user.User{ID: 1}) + can, err := rel.CanCreate(s, &user.User{ID: 1}) assert.NoError(t, err) assert.False(t, can) }) t.Run("No update rights on base task, but read rights", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 15, OtherTaskID: 1, RelationKind: RelationKindSubtask, } - can, err := rel.CanCreate(&user.User{ID: 1}) + can, err := rel.CanCreate(s, &user.User{ID: 1}) assert.NoError(t, err) assert.False(t, can) }) t.Run("No read rights on other task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 14, RelationKind: RelationKindSubtask, } - can, err := rel.CanCreate(&user.User{ID: 1}) + can, err := rel.CanCreate(s, &user.User{ID: 1}) assert.NoError(t, err) assert.False(t, can) }) t.Run("Nonexisting base task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 999999, OtherTaskID: 1, RelationKind: RelationKindSubtask, } - can, err := rel.CanCreate(&user.User{ID: 1}) + can, err := rel.CanCreate(s, &user.User{ID: 1}) assert.Error(t, err) assert.True(t, IsErrTaskDoesNotExist(err)) assert.False(t, can) }) t.Run("Nonexisting other task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() rel := TaskRelation{ TaskID: 1, OtherTaskID: 999999, RelationKind: RelationKindSubtask, } - can, err := rel.CanCreate(&user.User{ID: 1}) + can, err := rel.CanCreate(s, &user.User{ID: 1}) assert.Error(t, err) assert.True(t, IsErrTaskDoesNotExist(err)) assert.False(t, can) diff --git a/pkg/models/task_reminder.go b/pkg/models/task_reminder.go index 876d53d9..15a9d6c9 100644 --- a/pkg/models/task_reminder.go +++ b/pkg/models/task_reminder.go @@ -19,6 +19,9 @@ package models import ( "time" + "code.vikunja.io/api/pkg/db" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/cron" "code.vikunja.io/api/pkg/log" @@ -44,10 +47,10 @@ type taskUser struct { User *user.User `xorm:"extends"` } -func getTaskUsersForTasks(taskIDs []int64) (taskUsers []*taskUser, err error) { +func getTaskUsersForTasks(s *xorm.Session, taskIDs []int64) (taskUsers []*taskUser, err error) { // Get all creators of tasks creators := make(map[int64]*user.User, len(taskIDs)) - err = x. + err = s. Select("users.id, users.username, users.email, users.name"). Join("LEFT", "tasks", "tasks.created_by_id = users.id"). In("tasks.id", taskIDs). @@ -58,13 +61,13 @@ func getTaskUsersForTasks(taskIDs []int64) (taskUsers []*taskUser, err error) { return } - assignees, err := getRawTaskAssigneesForTasks(taskIDs) + assignees, err := getRawTaskAssigneesForTasks(s, taskIDs) if err != nil { return } taskMap := make(map[int64]*Task, len(taskIDs)) - err = x.In("id", taskIDs).Find(&taskMap) + err = s.In("id", taskIDs).Find(&taskMap) if err != nil { return } @@ -106,6 +109,8 @@ func RegisterReminderCron() { log.Debugf("[Task Reminder Cron] Timezone is %s", tz) + s := db.NewSession() + err := cron.Schedule("* * * * *", func() { // By default, time.Now() includes nanoseconds which we don't save. That results in getting the wrong dates, // so we make sure the time we use to get the reminders don't contain nanoseconds. @@ -116,7 +121,7 @@ func RegisterReminderCron() { log.Debugf("[Task Reminder Cron] Looking for reminders between %s and %s to send...", now, nextMinute) reminders := []*TaskReminder{} - err := x. + err := s. Where("reminder >= ? and reminder < ?", now.Format(dbFormat), nextMinute.Format(dbFormat)). Find(&reminders) if err != nil { @@ -136,7 +141,7 @@ func RegisterReminderCron() { taskIDs = append(taskIDs, r.TaskID) } - users, err := getTaskUsersForTasks(taskIDs) + users, err := getTaskUsersForTasks(s, taskIDs) if err != nil { log.Errorf("[Task Reminder Cron] Could not get task users to send them reminders: %s", err) return diff --git a/pkg/models/tasks.go b/pkg/models/tasks.go index dab48ac8..fb49dc6d 100644 --- a/pkg/models/tasks.go +++ b/pkg/models/tasks.go @@ -22,6 +22,8 @@ import ( "strconv" "time" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/metrics" "code.vikunja.io/api/pkg/user" @@ -153,7 +155,7 @@ type taskOptions struct { // @Success 200 {array} models.Task "The tasks" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/all [get] -func (t *Task) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { +func (t *Task) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { return nil, 0, 0, nil } @@ -209,7 +211,7 @@ func getFilterCondForSeparateTable(table string, concat taskFilterConcatinator, } //nolint:gocyclo -func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) { +func getRawTasksForLists(s *xorm.Session, lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) { // If the user does not have any lists, don't try to get any tasks if len(lists) == 0 { @@ -253,7 +255,7 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks [] // Postgres sorts by default entries with null values after ones with values. // To make that consistent with the sort order we have and other dbms, we're adding a separate clause here. - if x.Dialect().URI().DBType == schemas.POSTGRES { + if db.Type() == schemas.POSTGRES { if param.orderBy == orderAscending { orderby += " NULLS FIRST" } @@ -324,9 +326,7 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks [] } // Then return all tasks for that lists - query := x.NewSession(). - OrderBy(orderby) - queryCount := x.NewSession() + var where builder.Cond if len(opts.search) > 0 { // Postgres' is case sensitive by default. @@ -335,11 +335,9 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks [] // See https://stackoverflow.com/q/7005302/10924593 // Seems okay to use that now, we may need to find a better solution overall in the future. if config.DatabaseType.GetString() == "postgres" { - query = query.Where("title ILIKE ?", "%"+opts.search+"%") - queryCount = queryCount.Where("title ILIKE ?", "%"+opts.search+"%") + where = builder.Expr("title ILIKE ?", "%"+opts.search+"%") } else { - query = query.Where("title LIKE ?", "%"+opts.search+"%") - queryCount = queryCount.Where("title LIKE ?", "%"+opts.search+"%") + where = &builder.Like{"title", "%" + opts.search + "%"} } } @@ -352,10 +350,13 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks [] if hasFavoriteLists { // Make sure users can only see their favorites - userLists, _, _, err := getRawListsForUser(&listOptions{ - user: &user.User{ID: a.GetID()}, - page: -1, - }) + userLists, _, _, err := getRawListsForUser( + s, + &listOptions{ + user: &user.User{ID: a.GetID()}, + page: -1, + }, + ) if err != nil { return nil, 0, 0, err } @@ -399,32 +400,31 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks [] filters = append(filters, cond) } - query = query.Where(listCond) - queryCount = queryCount.Where(listCond) - + var filterCond builder.Cond if len(filters) > 0 { if opts.filterConcat == filterConcatOr { - query = query.Where(builder.Or(filters...)) - queryCount = queryCount.Where(builder.Or(filters...)) + filterCond = builder.Or(filters...) } if opts.filterConcat == filterConcatAnd { - query = query.Where(builder.And(filters...)) - queryCount = queryCount.Where(builder.And(filters...)) + filterCond = builder.And(filters...) } } limit, start := getLimitFromPageIndex(opts.page, opts.perPage) + cond := builder.And(listCond, where, filterCond) + query := s.Where(cond) if limit > 0 { query = query.Limit(limit, start) } tasks = []*Task{} - err = query.Find(&tasks) + err = query.OrderBy(orderby).Find(&tasks) if err != nil { return nil, 0, 0, err } + queryCount := s.Where(cond) totalItems, err = queryCount. Count(&Task{}) if err != nil { @@ -434,9 +434,9 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks [] return tasks, len(tasks), totalItems, nil } -func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) { +func getTasksForLists(s *xorm.Session, lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) { - tasks, resultCount, totalItems, err = getRawTasksForLists(lists, a, opts) + tasks, resultCount, totalItems, err = getRawTasksForLists(s, lists, a, opts) if err != nil { return nil, 0, 0, err } @@ -446,7 +446,7 @@ func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Ta taskMap[t.ID] = t } - err = addMoreInfoToTasks(taskMap) + err = addMoreInfoToTasks(s, taskMap) if err != nil { return nil, 0, 0, err } @@ -455,18 +455,18 @@ func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Ta } // GetTaskByIDSimple returns a raw task without extra data by the task ID -func GetTaskByIDSimple(taskID int64) (task Task, err error) { +func GetTaskByIDSimple(s *xorm.Session, taskID int64) (task Task, err error) { if taskID < 1 { return Task{}, ErrTaskDoesNotExist{taskID} } - return GetTaskSimple(&Task{ID: taskID}) + return GetTaskSimple(s, &Task{ID: taskID}) } // GetTaskSimple returns a raw task without extra data -func GetTaskSimple(t *Task) (task Task, err error) { +func GetTaskSimple(s *xorm.Session, t *Task) (task Task, err error) { task = *t - exists, err := x.Get(&task) + exists, err := s.Get(&task) if err != nil { return Task{}, err } @@ -478,14 +478,14 @@ func GetTaskSimple(t *Task) (task Task, err error) { } // GetTasksByIDs returns all tasks for a list of ids -func (bt *BulkTask) GetTasksByIDs() (err error) { +func (bt *BulkTask) GetTasksByIDs(s *xorm.Session) (err error) { for _, id := range bt.IDs { if id < 1 { return ErrTaskDoesNotExist{id} } } - err = x.In("id", bt.IDs).Find(&bt.Tasks) + err = s.In("id", bt.IDs).Find(&bt.Tasks) if err != nil { return } @@ -494,9 +494,9 @@ func (bt *BulkTask) GetTasksByIDs() (err error) { } // GetTasksByUIDs gets all tasks from a bunch of uids -func GetTasksByUIDs(uids []string) (tasks []*Task, err error) { +func GetTasksByUIDs(s *xorm.Session, uids []string) (tasks []*Task, err error) { tasks = []*Task{} - err = x.In("uid", uids).Find(&tasks) + err = s.In("uid", uids).Find(&tasks) if err != nil { return } @@ -506,13 +506,13 @@ func GetTasksByUIDs(uids []string) (tasks []*Task, err error) { taskMap[t.ID] = t } - err = addMoreInfoToTasks(taskMap) + err = addMoreInfoToTasks(s, taskMap) return } -func getRemindersForTasks(taskIDs []int64) (reminders []*TaskReminder, err error) { +func getRemindersForTasks(s *xorm.Session, taskIDs []int64) (reminders []*TaskReminder, err error) { reminders = []*TaskReminder{} - err = x.In("task_id", taskIDs).Find(&reminders) + err = s.In("task_id", taskIDs).Find(&reminders) return } @@ -521,8 +521,8 @@ func (t *Task) setIdentifier(list *List) { } // Get all assignees -func addAssigneesToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { - taskAssignees, err := getRawTaskAssigneesForTasks(taskIDs) +func addAssigneesToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) { + taskAssignees, err := getRawTaskAssigneesForTasks(s, taskIDs) if err != nil { return } @@ -538,8 +538,8 @@ func addAssigneesToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { } // Get all labels for all the tasks -func addLabelsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { - labels, _, _, err := getLabelsByTaskIDs(&LabelByTaskIDsOptions{ +func addLabelsToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) { + labels, _, _, err := getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{ TaskIDs: taskIDs, Page: -1, }) @@ -556,8 +556,8 @@ func addLabelsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { } // Get task attachments -func addAttachmentsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { - attachments, err := getTaskAttachmentsByTaskIDs(taskIDs) +func addAttachmentsToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) { + attachments, err := getTaskAttachmentsByTaskIDs(s, taskIDs) if err != nil { return } @@ -568,11 +568,11 @@ func addAttachmentsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) return } -func getTaskReminderMap(taskIDs []int64) (taskReminders map[int64][]time.Time, err error) { +func getTaskReminderMap(s *xorm.Session, taskIDs []int64) (taskReminders map[int64][]time.Time, err error) { taskReminders = make(map[int64][]time.Time) // Get all reminders and put them in a map to have it easier later - reminders, err := getRemindersForTasks(taskIDs) + reminders, err := getRemindersForTasks(s, taskIDs) if err != nil { return } @@ -584,9 +584,9 @@ func getTaskReminderMap(taskIDs []int64) (taskReminders map[int64][]time.Time, e return } -func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { +func addRelatedTasksToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) { relatedTasks := []*TaskRelation{} - err = x.In("task_id", taskIDs).Find(&relatedTasks) + err = s.In("task_id", taskIDs).Find(&relatedTasks) if err != nil { return } @@ -597,7 +597,7 @@ func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error relatedTaskIDs = append(relatedTaskIDs, rt.OtherTaskID) } fullRelatedTasks := make(map[int64]*Task) - err = x.In("id", relatedTaskIDs).Find(&fullRelatedTasks) + err = s.In("id", relatedTaskIDs).Find(&fullRelatedTasks) if err != nil { return } @@ -614,7 +614,7 @@ func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error // This function takes a map with pointers and returns a slice with pointers to tasks // It adds more stuff like assignees/labels/etc to a bunch of tasks -func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) { +func addMoreInfoToTasks(s *xorm.Session, taskMap map[int64]*Task) (err error) { // No need to iterate over users and stuff if the list doesn't have tasks if len(taskMap) == 0 { @@ -631,33 +631,33 @@ func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) { listIDs = append(listIDs, i.ListID) } - err = addAssigneesToTasks(taskIDs, taskMap) + err = addAssigneesToTasks(s, taskIDs, taskMap) if err != nil { return } - err = addLabelsToTasks(taskIDs, taskMap) + err = addLabelsToTasks(s, taskIDs, taskMap) if err != nil { return } - err = addAttachmentsToTasks(taskIDs, taskMap) + err = addAttachmentsToTasks(s, taskIDs, taskMap) if err != nil { return } - users, err := user.GetUsersByIDs(userIDs) + users, err := user.GetUsersByIDs(s, userIDs) if err != nil { return } - taskReminders, err := getTaskReminderMap(taskIDs) + taskReminders, err := getTaskReminderMap(s, taskIDs) if err != nil { return err } // Get all identifiers - lists, err := GetListsByIDs(listIDs) + lists, err := GetListsByIDs(s, listIDs) if err != nil { return err } @@ -679,7 +679,7 @@ func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) { } // Get all related tasks - err = addRelatedTasksToTasks(taskIDs, taskMap) + err = addRelatedTasksToTasks(s, taskIDs, taskMap) return } @@ -739,14 +739,8 @@ func checkBucketLimit(s *xorm.Session, t *Task, bucket *Bucket) (err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id} [put] -func (t *Task) Create(a web.Auth) (err error) { - s := x.NewSession() - err = createTask(s, t, a, true) - if err != nil { - _ = s.Rollback() - return err - } - return s.Commit() +func (t *Task) Create(s *xorm.Session, a web.Auth) (err error) { + return createTask(s, t, a, true) } func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err error) { @@ -759,16 +753,16 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err } // Check if the list exists - l := &List{ID: t.ListID} - if err = l.getSimpleByID(s); err != nil { - return + l, err := GetListSimpleByID(s, t.ListID) + if err != nil { + return err } if _, is := a.(*LinkSharing); is { // A negative user id indicates user share links t.CreatedByID = a.GetID() * -1 } else { - u, err := user.GetUserByID(a.GetID()) + u, err := user.GetUserByID(s, a.GetID()) if err != nil { return err } @@ -834,7 +828,7 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err t.setIdentifier(l) - err = updateListLastUpdatedS(s, &List{ID: t.ListID}) + err = updateListLastUpdated(s, &List{ID: t.ListID}) return } @@ -853,21 +847,17 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{id} [post] //nolint:gocyclo -func (t *Task) Update() (err error) { - - s := x.NewSession() +func (t *Task) Update(s *xorm.Session) (err error) { // Check if the task exists and get the old values - ot, err := GetTaskByIDSimple(t.ID) + ot, err := GetTaskByIDSimple(s, t.ID) if err != nil { - _ = s.Rollback() return } // Get the reminders - reminders, err := getRemindersForTasks([]int64{t.ID}) + reminders, err := getRemindersForTasks(s, []int64{t.ID}) if err != nil { - _ = s.Rollback() return } @@ -881,20 +871,17 @@ func (t *Task) Update() (err error) { // Update the assignees if err := ot.updateTaskAssignees(s, t.Assignees); err != nil { - _ = s.Rollback() return err } // Update the reminders if err := ot.updateReminders(s, t.Reminders); err != nil { - _ = s.Rollback() return err } // If there is a bucket set, make sure they belong to the same list as the task err = checkBucketAndTaskBelongToSameList(s, &ot, t.BucketID) if err != nil { - _ = s.Rollback() return } @@ -923,7 +910,6 @@ func (t *Task) Update() (err error) { if t.BucketID == 0 || (t.ListID != 0 && ot.ListID != t.ListID) { bucket, err = getDefaultBucket(s, t.ListID) if err != nil { - _ = s.Rollback() return err } t.BucketID = bucket.ID @@ -934,7 +920,6 @@ func (t *Task) Update() (err error) { latestTask := &Task{} _, err = s.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask) if err != nil { - _ = s.Rollback() return err } @@ -946,7 +931,6 @@ func (t *Task) Update() (err error) { // Only check the bucket limit if the task is being moved between buckets, allow reordering the task within a bucket if t.BucketID != ot.BucketID { if err := checkBucketLimit(s, t, bucket); err != nil { - _ = s.Rollback() return err } } @@ -972,7 +956,6 @@ func (t *Task) Update() (err error) { // Which is why we merge the actual task struct with the one we got from the db // The user struct overrides values in the actual one. if err := mergo.Merge(&ot, t, mergo.WithOverride); err != nil { - _ = s.Rollback() return err } @@ -1034,7 +1017,6 @@ func (t *Task) Update() (err error) { Update(ot) *t = ot if err != nil { - _ = s.Rollback() return err } // Get the task updated timestamp in a new struct - if we'd just try to put it into t which we already have, it @@ -1042,17 +1024,11 @@ func (t *Task) Update() (err error) { nt := &Task{} _, err = s.ID(t.ID).Get(nt) if err != nil { - _ = s.Rollback() return err } t.Updated = nt.Updated - err = updateListLastUpdatedS(s, &List{ID: t.ListID}) - if err != nil { - _ = s.Rollback() - return err - } - return s.Commit() + return updateListLastUpdated(s, &List{ID: t.ListID}) } // This helper function updates the reminders, doneAt, start and end dates of the *old* task @@ -1174,7 +1150,7 @@ func (t *Task) updateReminders(s *xorm.Session, reminders []time.Time) (err erro t.Reminders = nil } - err = updateListLastUpdatedS(s, &List{ID: t.ListID}) + err = updateListLastUpdated(s, &List{ID: t.ListID}) return } @@ -1190,20 +1166,20 @@ func (t *Task) updateReminders(s *xorm.Session, reminders []time.Time) (err erro // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{id} [delete] -func (t *Task) Delete() (err error) { +func (t *Task) Delete(s *xorm.Session) (err error) { - if _, err = x.ID(t.ID).Delete(Task{}); err != nil { + if _, err = s.ID(t.ID).Delete(Task{}); err != nil { return err } // Delete assignees - if _, err = x.Where("task_id = ?", t.ID).Delete(TaskAssginee{}); err != nil { + if _, err = s.Where("task_id = ?", t.ID).Delete(TaskAssginee{}); err != nil { return err } metrics.UpdateCount(-1, metrics.TaskCountKey) - err = updateListLastUpdated(&List{ID: t.ListID}) + err = updateListLastUpdated(s, &List{ID: t.ListID}) return } @@ -1219,16 +1195,16 @@ func (t *Task) Delete() (err error) { // @Failure 404 {object} models.Message "Task not found" // @Failure 500 {object} models.Message "Internal error" // @Router /tasks/{ID} [get] -func (t *Task) ReadOne() (err error) { +func (t *Task) ReadOne(s *xorm.Session) (err error) { taskMap := make(map[int64]*Task, 1) taskMap[t.ID] = &Task{} - *taskMap[t.ID], err = GetTaskByIDSimple(t.ID) + *taskMap[t.ID], err = GetTaskByIDSimple(s, t.ID) if err != nil { return } - err = addMoreInfoToTasks(taskMap) + err = addMoreInfoToTasks(s, taskMap) if err != nil { return } diff --git a/pkg/models/tasks_rights.go b/pkg/models/tasks_rights.go index efc0f4b8..da044ac7 100644 --- a/pkg/models/tasks_rights.go +++ b/pkg/models/tasks_rights.go @@ -18,47 +18,48 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanDelete checks if the user can delete an task -func (t *Task) CanDelete(a web.Auth) (bool, error) { - return t.canDoTask(a) +func (t *Task) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return t.canDoTask(s, a) } // CanUpdate determines if a user has the right to update a list task -func (t *Task) CanUpdate(a web.Auth) (bool, error) { - return t.canDoTask(a) +func (t *Task) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return t.canDoTask(s, a) } // CanCreate determines if a user has the right to create a list task -func (t *Task) CanCreate(a web.Auth) (bool, error) { +func (t *Task) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { // A user can do a task if he has write acces to its list l := &List{ID: t.ListID} - return l.CanWrite(a) + return l.CanWrite(s, a) } // CanRead determines if a user can read a task -func (t *Task) CanRead(a web.Auth) (canRead bool, maxRight int, err error) { +func (t *Task) CanRead(s *xorm.Session, a web.Auth) (canRead bool, maxRight int, err error) { // Get the task, error out if it doesn't exist - *t, err = GetTaskByIDSimple(t.ID) + *t, err = GetTaskByIDSimple(s, t.ID) if err != nil { return } // A user can read a task if it has access to the list l := &List{ID: t.ListID} - return l.CanRead(a) + return l.CanRead(s, a) } // CanWrite checks if a user has write access to a task -func (t *Task) CanWrite(a web.Auth) (canWrite bool, err error) { - return t.canDoTask(a) +func (t *Task) CanWrite(s *xorm.Session, a web.Auth) (canWrite bool, err error) { + return t.canDoTask(s, a) } // Helper function to check if a user can do stuff on a list task -func (t *Task) canDoTask(a web.Auth) (bool, error) { +func (t *Task) canDoTask(s *xorm.Session, a web.Auth) (bool, error) { // Get the task - ot, err := GetTaskByIDSimple(t.ID) + ot, err := GetTaskByIDSimple(s, t.ID) if err != nil { return false, err } @@ -66,7 +67,7 @@ func (t *Task) canDoTask(a web.Auth) (bool, error) { // Check if we're moving the task into a different list to check if the user has sufficient rights for that on the new list if t.ListID != 0 && t.ListID != ot.ListID { newList := &List{ID: t.ListID} - can, err := newList.CanWrite(a) + can, err := newList.CanWrite(s, a) if err != nil { return false, err } @@ -77,5 +78,5 @@ func (t *Task) canDoTask(a web.Auth) (bool, error) { // A user can do a task if it has write acces to its list l := &List{ID: ot.ListID} - return l.CanWrite(a) + return l.CanWrite(s, a) } diff --git a/pkg/models/tasks_test.go b/pkg/models/tasks_test.go index b350f73f..e3ab6c6a 100644 --- a/pkg/models/tasks_test.go +++ b/pkg/models/tasks_test.go @@ -36,12 +36,15 @@ func TestTask_Create(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ Title: "Lorem", Description: "Lorem Ipsum Dolor", ListID: 1, } - err := task.Create(usr) + err := task.Create(s, usr) assert.NoError(t, err) // Assert getting a uid assert.NotEmpty(t, task.UID) @@ -50,6 +53,9 @@ func TestTask_Create(t *testing.T) { assert.Equal(t, int64(18), task.Index) // Assert moving it into the default bucket assert.Equal(t, int64(1), task.BucketID) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "tasks", map[string]interface{}{ "id": task.ID, "title": "Lorem", @@ -62,47 +68,59 @@ func TestTask_Create(t *testing.T) { }) t.Run("empty title", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ Title: "", Description: "Lorem Ipsum Dolor", ListID: 1, } - err := task.Create(usr) + err := task.Create(s, usr) assert.Error(t, err) assert.True(t, IsErrTaskCannotBeEmpty(err)) }) t.Run("nonexistant list", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ Title: "Test", Description: "Lorem Ipsum Dolor", ListID: 9999999, } - err := task.Create(usr) + err := task.Create(s, usr) assert.Error(t, err) assert.True(t, IsErrListDoesNotExist(err)) }) t.Run("noneixtant user", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + nUser := &user.User{ID: 99999999} task := &Task{ Title: "Test", Description: "Lorem Ipsum Dolor", ListID: 1, } - err := task.Create(nUser) + err := task.Create(s, nUser) assert.Error(t, err) assert.True(t, user.IsErrUserDoesNotExist(err)) }) t.Run("full bucket", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ Title: "Lorem", Description: "Lorem Ipsum Dolor", ListID: 1, BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3 } - err := task.Create(usr) + err := task.Create(s, usr) assert.Error(t, err) assert.True(t, IsErrBucketLimitExceeded(err)) }) @@ -111,14 +129,20 @@ func TestTask_Create(t *testing.T) { func TestTask_Update(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ ID: 1, Title: "test10000", Description: "Lorem Ipsum Dolor", ListID: 1, } - err := task.Update() + err := task.Update(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "tasks", map[string]interface{}{ "id": 1, "title": "test10000", @@ -128,18 +152,24 @@ func TestTask_Update(t *testing.T) { }) t.Run("nonexistant task", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ ID: 9999999, Title: "test10000", Description: "Lorem Ipsum Dolor", ListID: 1, } - err := task.Update() + err := task.Update(s) assert.Error(t, err) assert.True(t, IsErrTaskDoesNotExist(err)) }) t.Run("full bucket", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ ID: 1, Title: "test10000", @@ -147,12 +177,15 @@ func TestTask_Update(t *testing.T) { ListID: 1, BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3 } - err := task.Update() + err := task.Update(s) assert.Error(t, err) assert.True(t, IsErrBucketLimitExceeded(err)) }) t.Run("full bucket but not changing the bucket", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ ID: 4, Title: "test10000", @@ -161,7 +194,7 @@ func TestTask_Update(t *testing.T) { ListID: 1, BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3 } - err := task.Update() + err := task.Update(s) assert.NoError(t, err) }) } @@ -169,11 +202,17 @@ func TestTask_Update(t *testing.T) { func TestTask_Delete(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ ID: 1, } - err := task.Delete() + err := task.Delete(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertMissing(t, "tasks", map[string]interface{}{ "id": 1, }) @@ -183,6 +222,9 @@ func TestTask_Delete(t *testing.T) { func TestUpdateDone(t *testing.T) { t.Run("marking a task as done", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + oldTask := &Task{Done: false} newTask := &Task{Done: true} updateDone(oldTask, newTask) @@ -190,6 +232,9 @@ func TestUpdateDone(t *testing.T) { }) t.Run("unmarking a task as done", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + oldTask := &Task{Done: true} newTask := &Task{Done: false} updateDone(oldTask, newTask) @@ -397,15 +442,21 @@ func TestUpdateDone(t *testing.T) { func TestTask_ReadOne(t *testing.T) { t.Run("default", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ID: 1} - err := task.ReadOne() + err := task.ReadOne(s) assert.NoError(t, err) assert.Equal(t, "task #1", task.Title) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + task := &Task{ID: 99999} - err := task.ReadOne() + err := task.ReadOne(s) assert.Error(t, err) assert.True(t, IsErrTaskDoesNotExist(err)) }) diff --git a/pkg/models/team_members.go b/pkg/models/team_members.go index 0739a05c..3e930731 100644 --- a/pkg/models/team_members.go +++ b/pkg/models/team_members.go @@ -19,6 +19,7 @@ package models import ( user2 "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" + "xorm.io/xorm" ) // Create implements the create method to assign a user to a team @@ -35,23 +36,24 @@ import ( // @Failure 403 {object} web.HTTPError "The user does not have access to the team" // @Failure 500 {object} models.Message "Internal error" // @Router /teams/{id}/members [put] -func (tm *TeamMember) Create(a web.Auth) (err error) { +func (tm *TeamMember) Create(s *xorm.Session, a web.Auth) (err error) { // Check if the team extst - _, err = GetTeamByID(tm.TeamID) + _, err = GetTeamByID(s, tm.TeamID) if err != nil { return } // Check if the user exists - user, err := user2.GetUserByUsername(tm.Username) + user, err := user2.GetUserByUsername(s, tm.Username) if err != nil { return } tm.UserID = user.ID // Check if that user is already part of the team - exists, err := x.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID). + exists, err := s. + Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID). Get(&TeamMember{}) if err != nil { return @@ -61,7 +63,7 @@ func (tm *TeamMember) Create(a web.Auth) (err error) { } // Insert the user - _, err = x.Insert(tm) + _, err = s.Insert(tm) return } @@ -76,9 +78,9 @@ func (tm *TeamMember) Create(a web.Auth) (err error) { // @Success 200 {object} models.Message "The user was successfully removed from the team." // @Failure 500 {object} models.Message "Internal error" // @Router /teams/{id}/members/{userID} [delete] -func (tm *TeamMember) Delete() (err error) { +func (tm *TeamMember) Delete(s *xorm.Session) (err error) { - total, err := x.Where("team_id = ?", tm.TeamID).Count(&TeamMember{}) + total, err := s.Where("team_id = ?", tm.TeamID).Count(&TeamMember{}) if err != nil { return } @@ -87,13 +89,13 @@ func (tm *TeamMember) Delete() (err error) { } // Find the numeric user id - user, err := user2.GetUserByUsername(tm.Username) + user, err := user2.GetUserByUsername(s, tm.Username) if err != nil { return } tm.UserID = user.ID - _, err = x.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).Delete(&TeamMember{}) + _, err = s.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).Delete(&TeamMember{}) return } @@ -108,9 +110,9 @@ func (tm *TeamMember) Delete() (err error) { // @Success 200 {object} models.Message "The member right was successfully changed." // @Failure 500 {object} models.Message "Internal error" // @Router /teams/{id}/members/{userID}/admin [post] -func (tm *TeamMember) Update() (err error) { +func (tm *TeamMember) Update(s *xorm.Session) (err error) { // Find the numeric user id - user, err := user2.GetUserByUsername(tm.Username) + user, err := user2.GetUserByUsername(s, tm.Username) if err != nil { return } @@ -118,7 +120,7 @@ func (tm *TeamMember) Update() (err error) { // Get the full member object and change the admin right ttm := &TeamMember{} - _, err = x. + _, err = s. Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID). Get(ttm) if err != nil { @@ -127,7 +129,7 @@ func (tm *TeamMember) Update() (err error) { ttm.Admin = !ttm.Admin // Do the update - _, err = x. + _, err = s. Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID). Cols("admin"). Update(ttm) diff --git a/pkg/models/team_members_rights.go b/pkg/models/team_members_rights.go index 872c0f70..3037c83d 100644 --- a/pkg/models/team_members_rights.go +++ b/pkg/models/team_members_rights.go @@ -18,32 +18,34 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if the user can add a new tem member -func (tm *TeamMember) CanCreate(a web.Auth) (bool, error) { - return tm.IsAdmin(a) +func (tm *TeamMember) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { + return tm.IsAdmin(s, a) } // CanDelete checks if the user can delete a new team member -func (tm *TeamMember) CanDelete(a web.Auth) (bool, error) { - return tm.IsAdmin(a) +func (tm *TeamMember) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return tm.IsAdmin(s, a) } // CanUpdate checks if the user can modify a team member's right -func (tm *TeamMember) CanUpdate(a web.Auth) (bool, error) { - return tm.IsAdmin(a) +func (tm *TeamMember) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return tm.IsAdmin(s, a) } // IsAdmin checks if the user is team admin -func (tm *TeamMember) IsAdmin(a web.Auth) (bool, error) { +func (tm *TeamMember) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) { // Don't allow anything if we're dealing with a list share here if _, is := a.(*LinkSharing); is { return false, nil } // A user can add a member to a team if he is admin of that team - exists, err := x.Where("user_id = ? AND team_id = ? AND admin = ?", a.GetID(), tm.TeamID, true). + exists, err := s. + Where("user_id = ? AND team_id = ? AND admin = ?", a.GetID(), tm.TeamID, true). Get(&TeamMember{}) return exists, err } diff --git a/pkg/models/team_members_test.go b/pkg/models/team_members_test.go index 51b2bc83..b682577d 100644 --- a/pkg/models/team_members_test.go +++ b/pkg/models/team_members_test.go @@ -32,12 +32,18 @@ func TestTeamMember_Create(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tm := &TeamMember{ TeamID: 1, Username: "user3", } - err := tm.Create(doer) + err := tm.Create(s, doer) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "team_members", map[string]interface{}{ "id": tm.ID, "team_id": 1, @@ -46,31 +52,40 @@ func TestTeamMember_Create(t *testing.T) { }) t.Run("already existing", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tm := &TeamMember{ TeamID: 1, Username: "user1", } - err := tm.Create(doer) + err := tm.Create(s, doer) assert.Error(t, err) assert.True(t, IsErrUserIsMemberOfTeam(err)) }) t.Run("nonexisting user", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tm := &TeamMember{ TeamID: 1, Username: "nonexistinguser", } - err := tm.Create(doer) + err := tm.Create(s, doer) assert.Error(t, err) assert.True(t, user.IsErrUserDoesNotExist(err)) }) t.Run("nonexisting team", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tm := &TeamMember{ TeamID: 9999999, Username: "user1", } - err := tm.Create(doer) + err := tm.Create(s, doer) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) }) @@ -79,12 +94,18 @@ func TestTeamMember_Create(t *testing.T) { func TestTeamMember_Delete(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tm := &TeamMember{ TeamID: 1, Username: "user1", } - err := tm.Delete() + err := tm.Delete(s) assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertMissing(t, "team_members", map[string]interface{}{ "team_id": 1, "user_id": 1, @@ -95,14 +116,20 @@ func TestTeamMember_Delete(t *testing.T) { func TestTeamMember_Update(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tm := &TeamMember{ TeamID: 1, Username: "user1", Admin: true, } - err := tm.Update() + err := tm.Update(s) assert.NoError(t, err) assert.False(t, tm.Admin) // Since this endpoint toggles the right, we should get a false for admin back. + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "team_members", map[string]interface{}{ "team_id": 1, "user_id": 1, @@ -113,14 +140,20 @@ func TestTeamMember_Update(t *testing.T) { // should ignore what was passed. t.Run("explicitly false in payload", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + tm := &TeamMember{ TeamID: 1, Username: "user1", Admin: true, } - err := tm.Update() + err := tm.Update(s) assert.NoError(t, err) assert.False(t, tm.Admin) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "team_members", map[string]interface{}{ "team_id": 1, "user_id": 1, diff --git a/pkg/models/teams.go b/pkg/models/teams.go index b7c40c9e..27611a4d 100644 --- a/pkg/models/teams.go +++ b/pkg/models/teams.go @@ -19,6 +19,8 @@ package models import ( "time" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/metrics" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" @@ -54,10 +56,6 @@ func (Team) TableName() string { return "teams" } -// AfterLoad gets the created by user object -func (t *Team) AfterLoad() { -} - // TeamMember defines the relationship between a user and a team type TeamMember struct { // The unique, numeric id of this team member relation. @@ -92,14 +90,14 @@ type TeamUser struct { } // GetTeamByID gets a team by its ID -func GetTeamByID(id int64) (team *Team, err error) { +func GetTeamByID(s *xorm.Session, id int64) (team *Team, err error) { if id < 1 { return team, ErrTeamDoesNotExist{id} } t := Team{} - exists, err := x. + exists, err := s. Where("id = ?", id). Get(&t) if err != nil { @@ -110,7 +108,7 @@ func GetTeamByID(id int64) (team *Team, err error) { } teamSlice := []*Team{&t} - err = addMoreInfoToTeams(teamSlice) + err = addMoreInfoToTeams(s, teamSlice) if err != nil { return } @@ -120,7 +118,7 @@ func GetTeamByID(id int64) (team *Team, err error) { return } -func addMoreInfoToTeams(teams []*Team) (err error) { +func addMoreInfoToTeams(s *xorm.Session, teams []*Team) (err error) { // Put the teams in a map to make assigning more info to it more efficient teamMap := make(map[int64]*Team, len(teams)) var teamIDs []int64 @@ -133,7 +131,8 @@ func addMoreInfoToTeams(teams []*Team) (err error) { // Get all owners and team members users := make(map[int64]*TeamUser) - err = x.Select("*"). + err = s. + Select("*"). Table("users"). Join("LEFT", "team_members", "team_members.user_id = users.id"). Join("LEFT", "teams", "team_members.team_id = teams.id"). @@ -178,8 +177,8 @@ func addMoreInfoToTeams(teams []*Team) (err error) { // @Failure 403 {object} web.HTTPError "The user does not have access to the team" // @Failure 500 {object} models.Message "Internal error" // @Router /teams/{id} [get] -func (t *Team) ReadOne() (err error) { - team, err := GetTeamByID(t.ID) +func (t *Team) ReadOne(s *xorm.Session) (err error) { + team, err := GetTeamByID(s, t.ID) if team != nil { *t = *team } @@ -199,7 +198,7 @@ func (t *Team) ReadOne() (err error) { // @Success 200 {array} models.Team "The teams." // @Failure 500 {object} models.Message "Internal error" // @Router /teams [get] -func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { +func (t *Team) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { if _, is := a.(*LinkSharing); is { return nil, 0, 0, ErrGenericForbidden{} } @@ -207,7 +206,7 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result limit, start := getLimitFromPageIndex(page, perPage) all := []*Team{} - query := x.Select("teams.*"). + query := s.Select("teams.*"). Table("teams"). Join("INNER", "team_members", "team_members.team_id = teams.id"). Where("team_members.user_id = ?", a.GetID()). @@ -220,12 +219,12 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result return nil, 0, 0, err } - err = addMoreInfoToTeams(all) + err = addMoreInfoToTeams(s, all) if err != nil { return nil, 0, 0, err } - numberOfTotalItems, err = x. + numberOfTotalItems, err = s. Table("teams"). Join("INNER", "team_members", "team_members.team_id = teams.id"). Where("team_members.user_id = ?", a.GetID()). @@ -246,7 +245,7 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result // @Failure 400 {object} web.HTTPError "Invalid team object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /teams [put] -func (t *Team) Create(a web.Auth) (err error) { +func (t *Team) Create(s *xorm.Session, a web.Auth) (err error) { doer, err := user.GetFromAuth(a) if err != nil { return err @@ -260,14 +259,14 @@ func (t *Team) Create(a web.Auth) (err error) { t.CreatedByID = doer.ID t.CreatedBy = doer - _, err = x.Insert(t) + _, err = s.Insert(t) if err != nil { return } // Insert the current user as member and admin tm := TeamMember{TeamID: t.ID, Username: doer.Username, Admin: true} - if err = tm.Create(doer); err != nil { + if err = tm.Create(s, doer); err != nil { return err } @@ -286,28 +285,28 @@ func (t *Team) Create(a web.Auth) (err error) { // @Failure 400 {object} web.HTTPError "Invalid team object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /teams/{id} [delete] -func (t *Team) Delete() (err error) { +func (t *Team) Delete(s *xorm.Session) (err error) { // Delete the team - _, err = x.ID(t.ID).Delete(&Team{}) + _, err = s.ID(t.ID).Delete(&Team{}) if err != nil { return } // Delete team members - _, err = x.Where("team_id = ?", t.ID).Delete(&TeamMember{}) + _, err = s.Where("team_id = ?", t.ID).Delete(&TeamMember{}) if err != nil { return } // Delete team <-> namespace relations - _, err = x.Where("team_id = ?", t.ID).Delete(&TeamNamespace{}) + _, err = s.Where("team_id = ?", t.ID).Delete(&TeamNamespace{}) if err != nil { return } // Delete team <-> lists relations - _, err = x.Where("team_id = ?", t.ID).Delete(&TeamList{}) + _, err = s.Where("team_id = ?", t.ID).Delete(&TeamList{}) if err != nil { return } @@ -329,25 +328,25 @@ func (t *Team) Delete() (err error) { // @Failure 400 {object} web.HTTPError "Invalid team object provided." // @Failure 500 {object} models.Message "Internal error" // @Router /teams/{id} [post] -func (t *Team) Update() (err error) { +func (t *Team) Update(s *xorm.Session) (err error) { // Check if we have a name if t.Name == "" { return ErrTeamNameCannotBeEmpty{} } // Check if the team exists - _, err = GetTeamByID(t.ID) + _, err = GetTeamByID(s, t.ID) if err != nil { return } - _, err = x.ID(t.ID).Update(t) + _, err = s.ID(t.ID).Update(t) if err != nil { return } // Get the newly updated team - team, err := GetTeamByID(t.ID) + team, err := GetTeamByID(s, t.ID) if team != nil { *t = *team } diff --git a/pkg/models/teams_rights.go b/pkg/models/teams_rights.go index 0a926fa1..0285209d 100644 --- a/pkg/models/teams_rights.go +++ b/pkg/models/teams_rights.go @@ -18,10 +18,11 @@ package models import ( "code.vikunja.io/web" + "xorm.io/xorm" ) // CanCreate checks if the user can create a new team -func (t *Team) CanCreate(a web.Auth) (bool, error) { +func (t *Team) CanCreate(s *xorm.Session, a web.Auth) (bool, error) { if _, is := a.(*LinkSharing); is { return false, nil } @@ -31,39 +32,40 @@ func (t *Team) CanCreate(a web.Auth) (bool, error) { } // CanUpdate checks if the user can update a team -func (t *Team) CanUpdate(a web.Auth) (bool, error) { - return t.IsAdmin(a) +func (t *Team) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) { + return t.IsAdmin(s, a) } // CanDelete checks if a user can delete a team -func (t *Team) CanDelete(a web.Auth) (bool, error) { - return t.IsAdmin(a) +func (t *Team) CanDelete(s *xorm.Session, a web.Auth) (bool, error) { + return t.IsAdmin(s, a) } // IsAdmin returns true when the user is admin of a team -func (t *Team) IsAdmin(a web.Auth) (bool, error) { +func (t *Team) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) { // Don't do anything if we're deadling with a link share auth here if _, is := a.(*LinkSharing); is { return false, nil } // Check if the team exists to be able to return a proper error message if not - _, err := GetTeamByID(t.ID) + _, err := GetTeamByID(s, t.ID) if err != nil { return false, err } - return x.Where("team_id = ?", t.ID). + return s.Where("team_id = ?", t.ID). And("user_id = ?", a.GetID()). And("admin = ?", true). Get(&TeamMember{}) } // CanRead returns true if the user has read access to the team -func (t *Team) CanRead(a web.Auth) (bool, int, error) { +func (t *Team) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) { // Check if the user is in the team tm := &TeamMember{} - can, err := x.Where("team_id = ?", t.ID). + can, err := s. + Where("team_id = ?", t.ID). And("user_id = ?", a.GetID()). Get(tm) diff --git a/pkg/models/teams_rights_test.go b/pkg/models/teams_rights_test.go index 72bb4886..59d5ec02 100644 --- a/pkg/models/teams_rights_test.go +++ b/pkg/models/teams_rights_test.go @@ -82,6 +82,8 @@ func TestTeam_CanDoSomething(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() tm := &Team{ ID: tt.fields.ID, @@ -96,19 +98,19 @@ func TestTeam_CanDoSomething(t *testing.T) { Rights: tt.fields.Rights, } - if got, _ := tm.CanCreate(tt.args.a); got != tt.want["CanCreate"] { // CanCreate is currently always true + if got, _ := tm.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] { // CanCreate is currently always true t.Errorf("Team.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) } - if got, _ := tm.CanDelete(tt.args.a); got != tt.want["CanDelete"] { + if got, _ := tm.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] { t.Errorf("Team.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) } - if got, _ := tm.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { + if got, _ := tm.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] { t.Errorf("Team.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) } - if got, _, _ := tm.CanRead(tt.args.a); got != tt.want["CanRead"] { + if got, _, _ := tm.CanRead(s, tt.args.a); got != tt.want["CanRead"] { t.Errorf("Team.CanRead() = %v, want %v", got, tt.want["CanRead"]) } - if got, _ := tm.IsAdmin(tt.args.a); got != tt.want["IsAdmin"] { + if got, _ := tm.IsAdmin(s, tt.args.a); got != tt.want["IsAdmin"] { t.Errorf("Team.IsAdmin() = %v, want %v", got, tt.want["IsAdmin"]) } }) diff --git a/pkg/models/teams_test.go b/pkg/models/teams_test.go index 4d40575b..00fa6d0d 100644 --- a/pkg/models/teams_test.go +++ b/pkg/models/teams_test.go @@ -32,11 +32,16 @@ func TestTeam_Create(t *testing.T) { } t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ Name: "Testteam293", Description: "Lorem Ispum", } - err := team.Create(doer) + err := team.Create(s, doer) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "teams", map[string]interface{}{ "id": team.ID, @@ -46,8 +51,11 @@ func TestTeam_Create(t *testing.T) { }) t.Run("empty name", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{} - err := team.Create(doer) + err := team.Create(s, doer) assert.Error(t, err) assert.True(t, IsErrTeamNameCannotBeEmpty(err)) }) @@ -56,8 +64,11 @@ func TestTeam_Create(t *testing.T) { func TestTeam_ReadOne(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ID: 1} - err := team.ReadOne() + err := team.ReadOne(s) assert.NoError(t, err) assert.Equal(t, "testteam1", team.Name) assert.Equal(t, "Lorem Ipsum", team.Description) @@ -66,15 +77,21 @@ func TestTeam_ReadOne(t *testing.T) { }) t.Run("invalid id", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ID: -1} - err := team.ReadOne() + err := team.ReadOne(s) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ID: 99999} - err := team.ReadOne() + err := team.ReadOne(s) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) }) @@ -83,23 +100,31 @@ func TestTeam_ReadOne(t *testing.T) { func TestTeam_ReadAll(t *testing.T) { doer := &user.User{ID: 1} t.Run("normal", func(t *testing.T) { + s := db.NewSession() + defer s.Close() + team := &Team{} - ts, _, _, err := team.ReadAll(doer, "", 1, 50) + teams, _, _, err := team.ReadAll(s, doer, "", 1, 50) assert.NoError(t, err) - assert.Equal(t, reflect.TypeOf(ts).Kind(), reflect.Slice) - s := reflect.ValueOf(ts) - assert.Equal(t, 8, s.Len()) + assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice) + ts := reflect.ValueOf(teams) + assert.Equal(t, 8, ts.Len()) }) } func TestTeam_Update(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ ID: 1, Name: "SomethingNew", } - err := team.Update() + err := team.Update(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertExists(t, "teams", map[string]interface{}{ "id": team.ID, @@ -108,21 +133,27 @@ func TestTeam_Update(t *testing.T) { }) t.Run("empty name", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ ID: 1, Name: "", } - err := team.Update() + err := team.Update(s) assert.Error(t, err) assert.True(t, IsErrTeamNameCannotBeEmpty(err)) }) t.Run("nonexisting", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ ID: 9999, Name: "SomethingNew", } - err := team.Update() + err := team.Update(s) assert.Error(t, err) assert.True(t, IsErrTeamDoesNotExist(err)) }) @@ -131,10 +162,15 @@ func TestTeam_Update(t *testing.T) { func TestTeam_Delete(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + team := &Team{ ID: 1, } - err := team.Delete() + err := team.Delete(s) + assert.NoError(t, err) + err = s.Commit() assert.NoError(t, err) db.AssertMissing(t, "teams", map[string]interface{}{ "id": 1, diff --git a/pkg/models/unsplash.go b/pkg/models/unsplash.go index 8b084748..1820dccf 100644 --- a/pkg/models/unsplash.go +++ b/pkg/models/unsplash.go @@ -16,7 +16,10 @@ package models -import "code.vikunja.io/api/pkg/files" +import ( + "code.vikunja.io/api/pkg/files" + "xorm.io/xorm" +) // Unsplash requires us to do pingbacks to their site and also name the image author. // To do this properly, we need to save these information somewhere. @@ -36,15 +39,15 @@ func (u *UnsplashPhoto) TableName() string { } // Save persists an unsplash photo to the db -func (u *UnsplashPhoto) Save() error { - _, err := x.Insert(u) +func (u *UnsplashPhoto) Save(s *xorm.Session) error { + _, err := s.Insert(u) return err } // GetUnsplashPhotoByFileID returns an unsplash photo by its saved file id -func GetUnsplashPhotoByFileID(fileID int64) (u *UnsplashPhoto, err error) { +func GetUnsplashPhotoByFileID(s *xorm.Session, fileID int64) (u *UnsplashPhoto, err error) { u = &UnsplashPhoto{} - exists, err := x.Where("file_id = ?", fileID).Get(u) + exists, err := s.Where("file_id = ?", fileID).Get(u) if err != nil { return } @@ -55,10 +58,10 @@ func GetUnsplashPhotoByFileID(fileID int64) (u *UnsplashPhoto, err error) { } // RemoveUnsplashPhoto removes an unsplash photo from the db -func RemoveUnsplashPhoto(fileID int64) (err error) { +func RemoveUnsplashPhoto(s *xorm.Session, fileID int64) (err error) { // This is intentionally "fire and forget" which is why we don't check if we have an // unsplash entry for that file at all. If there is one, it will be deleted. // We do this to keep the function simple. - _, err = x.Where("file_id = ?", fileID).Delete(&UnsplashPhoto{}) + _, err = s.Where("file_id = ?", fileID).Delete(&UnsplashPhoto{}) return } diff --git a/pkg/models/user_list.go b/pkg/models/user_list.go index cc150998..50552414 100644 --- a/pkg/models/user_list.go +++ b/pkg/models/user_list.go @@ -20,6 +20,7 @@ package models import ( "code.vikunja.io/api/pkg/user" "xorm.io/builder" + "xorm.io/xorm" ) // ListUIDs hold all kinds of user IDs from accounts who have somehow access to a list @@ -33,11 +34,11 @@ type ListUIDs struct { } // ListUsersFromList returns a list with all users who have access to a list, regardless of the method which gave them access -func ListUsersFromList(l *List, search string) (users []*user.User, err error) { +func ListUsersFromList(s *xorm.Session, l *List, search string) (users []*user.User, err error) { userids := []*ListUIDs{} - err = x. + err = s. Select(`l.owner_id as listOwner, un.user_id as unID, ul.user_id as ulID, @@ -97,7 +98,7 @@ func ListUsersFromList(l *List, search string) (users []*user.User, err error) { } // Get all users - err = x. + err = s. Table("users"). Select("*"). In("id", uids). diff --git a/pkg/models/users_list_test.go b/pkg/models/users_list_test.go index 16201aa0..0e513a9c 100644 --- a/pkg/models/users_list_test.go +++ b/pkg/models/users_list_test.go @@ -201,8 +201,10 @@ func TestListUsersFromList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() - gotUsers, err := ListUsersFromList(tt.args.l, tt.args.search) + gotUsers, err := ListUsersFromList(s, tt.args.l, tt.args.search) if (err != nil) != tt.wantErr { t.Errorf("ListUsersFromList() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/modules/auth/openid/openid.go b/pkg/modules/auth/openid/openid.go index 15b68fe1..ae407006 100644 --- a/pkg/modules/auth/openid/openid.go +++ b/pkg/modules/auth/openid/openid.go @@ -23,6 +23,9 @@ import ( "net/http" "time" + "code.vikunja.io/api/pkg/db" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/modules/auth" @@ -130,8 +133,17 @@ func HandleCallback(c echo.Context) error { return err } + s := db.NewSession() + defer s.Close() + // Check if we have seen this user before - u, err := getOrCreateUser(cl, idToken.Issuer, idToken.Subject) + u, err := getOrCreateUser(s, cl, idToken.Issuer, idToken.Subject) + if err != nil { + _ = s.Rollback() + return err + } + + err = s.Commit() if err != nil { return err } @@ -140,9 +152,9 @@ func HandleCallback(c echo.Context) error { return auth.NewUserAuthTokenResponse(u, c) } -func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err error) { +func getOrCreateUser(s *xorm.Session, cl *claims, issuer, subject string) (u *user.User, err error) { // Check if the user exists for that issuer and subject - u, err = user.GetUserWithEmail(&user.User{ + u, err = user.GetUserWithEmail(s, &user.User{ Issuer: issuer, Subject: subject, }) @@ -165,7 +177,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro uu.Username = petname.Generate(3, "-") } - u, err = user.CreateUser(uu) + u, err = user.CreateUser(s, uu) if err != nil && !user.IsErrUsernameExists(err) { return nil, err } @@ -173,14 +185,14 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro // If their preferred username is already taken, create some random one from the email and subject if user.IsErrUsernameExists(err) { uu.Username = petname.Generate(3, "-") - u, err = user.CreateUser(uu) + u, err = user.CreateUser(s, uu) if err != nil { return nil, err } } // And create its namespace - err = models.CreateNewNamespaceForUser(u) + err = models.CreateNewNamespaceForUser(s, u) if err != nil { return nil, err } @@ -196,7 +208,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro if cl.Name != u.Name { u.Name = cl.Name } - u, err = user.UpdateUser(&user.User{ + u, err = user.UpdateUser(s, &user.User{ ID: u.ID, Email: u.Email, Name: u.Name, diff --git a/pkg/modules/auth/openid/openid_test.go b/pkg/modules/auth/openid/openid_test.go index 4e29727b..ba1d8b07 100644 --- a/pkg/modules/auth/openid/openid_test.go +++ b/pkg/modules/auth/openid/openid_test.go @@ -26,12 +26,18 @@ import ( func TestGetOrCreateUser(t *testing.T) { t.Run("new user", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + cl := &claims{ Email: "test@example.com", PreferredUsername: "someUserWhoDoesNotExistYet", } - u, err := getOrCreateUser(cl, "https://some.issuer", "12345") + u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345") assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "users", map[string]interface{}{ "id": u.ID, "email": cl.Email, @@ -40,13 +46,19 @@ func TestGetOrCreateUser(t *testing.T) { }) t.Run("new user, no username provided", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + cl := &claims{ Email: "test@example.com", PreferredUsername: "", } - u, err := getOrCreateUser(cl, "https://some.issuer", "12345") + u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345") assert.NoError(t, err) assert.NotEmpty(t, u.Username) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "users", map[string]interface{}{ "id": u.ID, "email": cl.Email, @@ -54,19 +66,28 @@ func TestGetOrCreateUser(t *testing.T) { }) t.Run("new user, no email address", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + cl := &claims{ Email: "", } - _, err := getOrCreateUser(cl, "https://some.issuer", "12345") + _, err := getOrCreateUser(s, cl, "https://some.issuer", "12345") assert.Error(t, err) }) t.Run("existing user, different email address", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + cl := &claims{ Email: "other-email-address@some.service.com", } - u, err := getOrCreateUser(cl, "https://some.service.com", "12345") + u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345") assert.NoError(t, err) + err = s.Commit() + assert.NoError(t, err) + db.AssertExists(t, "users", map[string]interface{}{ "id": u.ID, "email": cl.Email, diff --git a/pkg/modules/background/background.go b/pkg/modules/background/background.go index bb24b0fc..f60b6da0 100644 --- a/pkg/modules/background/background.go +++ b/pkg/modules/background/background.go @@ -19,6 +19,7 @@ package background import ( "code.vikunja.io/api/pkg/models" "code.vikunja.io/web" + "xorm.io/xorm" ) // Image represents an image which can be used as a list background @@ -33,7 +34,7 @@ type Image struct { // Provider represents something that is able to get a list of images and set one of them as background type Provider interface { // Search is used to either return a pre-defined list of Image or let the user search for an image - Search(search string, page int64) (result []*Image, err error) + Search(s *xorm.Session, search string, page int64) (result []*Image, err error) // Set sets an image which was most likely previously obtained by Search as list background - Set(image *Image, list *models.List, auth web.Auth) (err error) + Set(s *xorm.Session, image *Image, list *models.List, auth web.Auth) (err error) } diff --git a/pkg/modules/background/handler/background.go b/pkg/modules/background/handler/background.go index 2c1921c4..b8439107 100644 --- a/pkg/modules/background/handler/background.go +++ b/pkg/modules/background/handler/background.go @@ -22,6 +22,9 @@ import ( "strconv" "strings" + "code.vikunja.io/api/pkg/db" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" @@ -59,8 +62,17 @@ func (bp *BackgroundProvider) SearchBackgrounds(c echo.Context) error { } } - result, err := p.Search(search, page) + s := db.NewSession() + defer s.Close() + + result, err := p.Search(s, search, page) if err != nil { + _ = s.Rollback() + return echo.NewHTTPError(http.StatusBadRequest, "An error occurred: "+err.Error()) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return echo.NewHTTPError(http.StatusBadRequest, "An error occurred: "+err.Error()) } @@ -68,7 +80,7 @@ func (bp *BackgroundProvider) SearchBackgrounds(c echo.Context) error { } // This function does all kinds of preparations for setting and uploading a background -func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *models.List, auth web.Auth, err error) { +func (bp *BackgroundProvider) setBackgroundPreparations(s *xorm.Session, c echo.Context) (list *models.List, auth web.Auth, err error) { auth, err = auth2.GetAuthFromClaims(c) if err != nil { return nil, nil, echo.NewHTTPError(http.StatusBadRequest, "Invalid auth token: "+err.Error()) @@ -81,7 +93,7 @@ func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *m // Check if the user has the right to change the list background list = &models.List{ID: listID} - can, err := list.CanUpdate(auth) + can, err := list.CanUpdate(s, auth) if err != nil { return } @@ -90,14 +102,18 @@ func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *m return list, auth, models.ErrGenericForbidden{} } // Load the list - err = list.GetSimpleByID() + list, err = models.GetListSimpleByID(s, list.ID) return } // SetBackground sets an Image as list background func (bp *BackgroundProvider) SetBackground(c echo.Context) error { - list, auth, err := bp.setBackgroundPreparations(c) + s := db.NewSession() + defer s.Close() + + list, auth, err := bp.setBackgroundPreparations(s, c) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -106,11 +122,13 @@ func (bp *BackgroundProvider) SetBackground(c echo.Context) error { image := &background.Image{} err = c.Bind(image) if err != nil { + _ = s.Rollback() return echo.NewHTTPError(http.StatusBadRequest, "No or invalid model provided: "+err.Error()) } - err = p.Set(image, list, auth) + err = p.Set(s, image, list, auth) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } return c.JSON(http.StatusOK, list) @@ -118,8 +136,12 @@ func (bp *BackgroundProvider) SetBackground(c echo.Context) error { // UploadBackground uploads a background and passes the id of the uploaded file as an Image to the Set function of the BackgroundProvider. func (bp *BackgroundProvider) UploadBackground(c echo.Context) error { - list, auth, err := bp.setBackgroundPreparations(c) + s := db.NewSession() + defer s.Close() + + list, auth, err := bp.setBackgroundPreparations(s, c) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -128,10 +150,12 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error { // Get + upload the image file, err := c.FormFile("background") if err != nil { + _ = s.Rollback() return err } src, err := file.Open() if err != nil { + _ = s.Rollback() return err } defer src.Close() @@ -139,9 +163,11 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error { // Validate we're dealing with an image mime, err := mimetype.DetectReader(src) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } if !strings.HasPrefix(mime.String(), "image") { + _ = s.Rollback() return c.JSON(http.StatusBadRequest, models.Message{Message: "Uploaded file is no image."}) } _, _ = src.Seek(0, io.SeekStart) @@ -149,6 +175,7 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error { // Save the file f, err := files.CreateWithMime(src, file.Filename, uint64(file.Size), auth, mime.String()) if err != nil { + _ = s.Rollback() if files.IsErrFileIsTooLarge(err) { return echo.ErrBadRequest } @@ -158,10 +185,17 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error { image := &background.Image{ID: strconv.FormatInt(f.ID, 10)} - err = p.Set(image, list, auth) + err = p.Set(s, image, list, auth) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } + + if err := s.Commit(); err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + return c.JSON(http.StatusOK, list) } @@ -190,17 +224,23 @@ func GetListBackground(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "Invalid list ID: "+err.Error()) } + s := db.NewSession() + defer s.Close() + // Check if a background for this list exists + Rights list := &models.List{ID: listID} - can, _, err := list.CanRead(auth) + can, _, err := list.CanRead(s, auth) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } if !can { + _ = s.Rollback() log.Infof("Tried to get list background of list %d while not having the rights for it (User: %v)", listID, auth) return echo.NewHTTPError(http.StatusForbidden) } if list.BackgroundFileID == 0 { + _ = s.Rollback() return echo.NotFoundHandler(c) } @@ -209,13 +249,19 @@ func GetListBackground(c echo.Context) error { ID: list.BackgroundFileID, } if err := bgFile.LoadFileByID(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } // Unsplash requires pingbacks as per their api usage guidelines. // To do this in a privacy-preserving manner, we do the ping from inside of Vikunja to not expose any user details. // FIXME: This should use an event once we have events - unsplash.Pingback(bgFile) + unsplash.Pingback(s, bgFile) + + if err := s.Commit(); err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } // Serve the file return c.Stream(http.StatusOK, "image/jpg", bgFile.File) diff --git a/pkg/modules/background/unsplash/unsplash.go b/pkg/modules/background/unsplash/unsplash.go index d7952e1b..963dcb9c 100644 --- a/pkg/modules/background/unsplash/unsplash.go +++ b/pkg/modules/background/unsplash/unsplash.go @@ -26,6 +26,8 @@ import ( "strings" "time" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/log" @@ -150,7 +152,7 @@ func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) { // @Success 200 {array} background.Image "An array with photos" // @Failure 500 {object} models.Message "Internal error" // @Router /backgrounds/unsplash/search [get] -func (p *Provider) Search(search string, page int64) (result []*background.Image, err error) { +func (p *Provider) Search(s *xorm.Session, search string, page int64) (result []*background.Image, err error) { // If we don't have a search query, return results from the unsplash featured collection if search == "" { @@ -243,7 +245,7 @@ func (p *Provider) Search(search string, page int64) (result []*background.Image // @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id}/backgrounds/unsplash [post] -func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth) (err error) { +func (p *Provider) Set(s *xorm.Session, image *background.Image, list *models.List, auth web.Auth) (err error) { // Find the photo photo, err := getUnsplashPhotoInfoByID(image.ID) @@ -292,7 +294,7 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth return err } - if err := models.RemoveUnsplashPhoto(list.BackgroundFileID); err != nil { + if err := models.RemoveUnsplashPhoto(s, list.BackgroundFileID); err != nil { return err } } @@ -304,7 +306,7 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth Author: photo.User.Username, AuthorName: photo.User.Name, } - err = unsplashPhoto.Save() + err = unsplashPhoto.Save(s) if err != nil { return } @@ -315,13 +317,13 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth list.BackgroundInformation = unsplashPhoto // Set it as the list background - return models.SetListBackground(list.ID, file) + return models.SetListBackground(s, list.ID, file) } // Pingback pings the unsplash api if an unsplash photo has been accessed. -func Pingback(f *files.File) { +func Pingback(s *xorm.Session, f *files.File) { // Check if the file is actually downloaded from unsplash - unsplashPhoto, err := models.GetUnsplashPhotoByFileID(f.ID) + unsplashPhoto, err := models.GetUnsplashPhotoByFileID(s, f.ID) if err != nil { if files.IsErrFileIsNotUnsplashFile(err) { return diff --git a/pkg/modules/background/upload/upload.go b/pkg/modules/background/upload/upload.go index 9f39bee3..3e5d280c 100644 --- a/pkg/modules/background/upload/upload.go +++ b/pkg/modules/background/upload/upload.go @@ -19,6 +19,8 @@ package upload import ( "strconv" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/modules/background" @@ -30,7 +32,7 @@ type Provider struct { } // Search is only used to implement the interface -func (p *Provider) Search(search string, page int64) (result []*background.Image, err error) { +func (p *Provider) Search(s *xorm.Session, search string, page int64) (result []*background.Image, err error) { return } @@ -50,7 +52,7 @@ func (p *Provider) Search(search string, page int64) (result []*background.Image // @Failure 404 {object} models.Message "The list does not exist." // @Failure 500 {object} models.Message "Internal error" // @Router /lists/{id}/backgrounds/upload [put] -func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth) (err error) { +func (p *Provider) Set(s *xorm.Session, image *background.Image, list *models.List, auth web.Auth) (err error) { // Remove the old background if one exists if list.BackgroundFileID != 0 { file := files.File{ID: list.BackgroundFileID} @@ -67,5 +69,5 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth list.BackgroundInformation = &models.ListBackgroundType{Type: models.ListBackgroundUpload} - return models.SetListBackground(list.ID, file) + return models.SetListBackground(s, list.ID, file) } diff --git a/pkg/modules/migration/create_from_structure.go b/pkg/modules/migration/create_from_structure.go index b6f69503..28211925 100644 --- a/pkg/modules/migration/create_from_structure.go +++ b/pkg/modules/migration/create_from_structure.go @@ -20,6 +20,8 @@ import ( "bytes" "io/ioutil" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" @@ -34,10 +36,14 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err labels := make(map[string]*models.Label) + s := db.NewSession() + defer s.Close() + // Create all namespaces for _, n := range str { - err = n.Create(user) + err = n.Create(s, user) if err != nil { + _ = s.Rollback() return } @@ -54,8 +60,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err needsDefaultBucket := false l.NamespaceID = n.ID - err = l.Create(user) + err = l.Create(s, user) if err != nil { + _ = s.Rollback() return } @@ -67,11 +74,13 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err file, err := files.Create(backgroundFile, "", uint64(backgroundFile.Len()), user) if err != nil { + _ = s.Rollback() return err } - err = models.SetListBackground(l.ID, file) + err = models.SetListBackground(s, l.ID, file) if err != nil { + _ = s.Rollback() return err } @@ -87,8 +96,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err oldID := bucket.ID bucket.ID = 0 // We want a new id bucket.ListID = l.ID - err = bucket.Create(user) + err = bucket.Create(s, user) if err != nil { + _ = s.Rollback() return } buckets[oldID] = bucket @@ -111,8 +121,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err } t.ListID = l.ID - err = t.Create(user) + err = t.Create(s, user) if err != nil { + _ = s.Rollback() return } @@ -132,8 +143,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err // First create the related tasks if they do not exist if rt.ID == 0 { rt.ListID = t.ListID - err = rt.Create(user) + err = rt.Create(s, user) if err != nil { + _ = s.Rollback() return } log.Debugf("[creating structure] Created related task %d", rt.ID) @@ -145,8 +157,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err OtherTaskID: rt.ID, RelationKind: kind, } - err = taskRel.Create(user) + err = taskRel.Create(s, user) if err != nil { + _ = s.Rollback() return } @@ -164,8 +177,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err if len(a.File.FileContent) > 0 { a.TaskID = t.ID fr := ioutil.NopCloser(bytes.NewReader(a.File.FileContent)) - err = a.NewAttachment(fr, a.File.Name, a.File.Size, user) + err = a.NewAttachment(s, fr, a.File.Name, a.File.Size, user) if err != nil { + _ = s.Rollback() return } log.Debugf("[creating structure] Created new attachment %d", a.ID) @@ -180,8 +194,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err var exists bool lb, exists = labels[label.Title+label.HexColor] if !exists { - err = label.Create(user) + err = label.Create(s, user) if err != nil { + _ = s.Rollback() return err } log.Debugf("[creating structure] Created new label %d", label.ID) @@ -193,8 +208,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err LabelID: lb.ID, TaskID: t.ID, } - err = lt.Create(user) + err = lt.Create(s, user) if err != nil { + _ = s.Rollback() return err } log.Debugf("[creating structure] Associated task %d with label %d", t.ID, lb.ID) @@ -204,13 +220,15 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err // All tasks brought their own bucket with them, therefore the newly created default bucket is just extra space if !needsDefaultBucket { b := &models.Bucket{ListID: l.ID} - bucketsIn, _, _, err := b.ReadAll(user, "", 1, 1) + bucketsIn, _, _, err := b.ReadAll(s, user, "", 1, 1) if err != nil { + _ = s.Rollback() return err } buckets := bucketsIn.([]*models.Bucket) - err = buckets[0].Delete() + err = buckets[0].Delete(s) if err != nil { + _ = s.Rollback() return err } } @@ -222,5 +240,5 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err log.Debugf("[creating structure] Done inserting new task structure") - return nil + return s.Commit() } diff --git a/pkg/modules/migration/db.go b/pkg/modules/migration/db.go index 3e0168d8..778f62bd 100644 --- a/pkg/modules/migration/db.go +++ b/pkg/modules/migration/db.go @@ -19,20 +19,10 @@ package migration import ( "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/db" - "code.vikunja.io/api/pkg/log" - "xorm.io/xorm" ) -var x *xorm.Engine - // InitDB sets up the database connection to use in this module func InitDB() (err error) { - x, err = db.CreateDBEngine() - if err != nil { - log.Criticalf("Could not connect to db: %v", err.Error()) - return - } - // Cache if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" { db.RegisterTableStructsForCache(GetTables()) diff --git a/pkg/modules/migration/migration_status.go b/pkg/modules/migration/migration_status.go index be893db3..08d082eb 100644 --- a/pkg/modules/migration/migration_status.go +++ b/pkg/modules/migration/migration_status.go @@ -19,6 +19,7 @@ package migration import ( "time" + "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/user" ) @@ -37,17 +38,26 @@ func (s *Status) TableName() string { // SetMigrationStatus sets the migration status for a user func SetMigrationStatus(m Migrator, u *user.User) (err error) { + s := db.NewSession() + defer s.Close() + status := &Status{ UserID: u.ID, MigratorName: m.Name(), } - _, err = x.Insert(status) + _, err = s.Insert(status) return } // GetMigrationStatus returns the migration status for a migration and a user func GetMigrationStatus(m Migrator, u *user.User) (status *Status, err error) { + s := db.NewSession() + defer s.Close() + status = &Status{} - _, err = x.Where("user_id = ? and migrator_name = ?", u.ID, m.Name()).Desc("id").Get(status) + _, err = s. + Where("user_id = ? and migrator_name = ?", u.ID, m.Name()). + Desc("id"). + Get(status) return } diff --git a/pkg/routes/api/v1/avatar.go b/pkg/routes/api/v1/avatar.go index 8e8d4a1e..7eec8f78 100644 --- a/pkg/routes/api/v1/avatar.go +++ b/pkg/routes/api/v1/avatar.go @@ -17,6 +17,7 @@ package v1 import ( + "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" @@ -56,8 +57,11 @@ func GetAvatar(c echo.Context) error { // Get the username username := c.Param("username") + s := db.NewSession() + defer s.Close() + // Get the user - u, err := user.GetUserWithEmail(&user.User{Username: username}) + u, err := user.GetUserWithEmail(s, &user.User{Username: username}) if err != nil { log.Errorf("Error getting user for avatar: %v", err) return handler.HandleHTTPError(err, c) @@ -113,22 +117,28 @@ func GetAvatar(c echo.Context) error { // @Router /user/settings/avatar/upload [put] func UploadAvatar(c echo.Context) (err error) { + s := db.NewSession() + defer s.Close() + uc, err := user.GetCurrentUser(c) if err != nil { return handler.HandleHTTPError(err, c) } - u, err := user.GetUserByID(uc.ID) + u, err := user.GetUserByID(s, uc.ID) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } // Get + upload the image file, err := c.FormFile("avatar") if err != nil { + _ = s.Rollback() return err } src, err := file.Open() if err != nil { + _ = s.Rollback() return err } defer src.Close() @@ -136,6 +146,7 @@ func UploadAvatar(c echo.Context) (err error) { // Validate we're dealing with an image mime, err := mimetype.DetectReader(src) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } if !strings.HasPrefix(mime.String(), "image") { @@ -148,6 +159,7 @@ func UploadAvatar(c echo.Context) (err error) { f := &files.File{ID: u.AvatarFileID} if err := f.Delete(); err != nil { if !files.IsErrFileDoesNotExist(err) { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } } @@ -157,11 +169,13 @@ func UploadAvatar(c echo.Context) (err error) { // Resize the new file to a max height of 1024 img, _, err := image.Decode(src) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } resizedImg := imaging.Resize(img, 0, 1024, imaging.Lanczos) buf := &bytes.Buffer{} if err := png.Encode(buf, resizedImg); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -170,6 +184,7 @@ func UploadAvatar(c echo.Context) (err error) { // Save the file f, err := files.CreateWithMime(buf, file.Filename, uint64(file.Size), u, "image/png") if err != nil { + _ = s.Rollback() if files.IsErrFileIsTooLarge(err) { return echo.ErrBadRequest } @@ -180,7 +195,13 @@ func UploadAvatar(c echo.Context) (err error) { u.AvatarFileID = f.ID u.AvatarProvider = "upload" - if _, err := user.UpdateUser(u); err != nil { + if _, err := user.UpdateUser(s, u); err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/link_sharing_auth.go b/pkg/routes/api/v1/link_sharing_auth.go index 0040cfec..a457fe25 100644 --- a/pkg/routes/api/v1/link_sharing_auth.go +++ b/pkg/routes/api/v1/link_sharing_auth.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/web/handler" @@ -45,8 +47,18 @@ type LinkShareToken struct { // @Router /shares/{share}/auth [post] func AuthenticateLinkShare(c echo.Context) error { hash := c.Param("share") - share, err := models.GetLinkShareByHash(hash) + + s := db.NewSession() + defer s.Close() + + share, err := models.GetLinkShareByHash(s, hash) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/list_by_namespace.go b/pkg/routes/api/v1/list_by_namespace.go index 1d6f021b..2b668da0 100644 --- a/pkg/routes/api/v1/list_by_namespace.go +++ b/pkg/routes/api/v1/list_by_namespace.go @@ -20,6 +20,9 @@ import ( "net/http" "strconv" + "code.vikunja.io/api/pkg/db" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web/handler" @@ -41,8 +44,11 @@ import ( // @Failure 500 {object} models.Message "Internal error" // @Router /namespaces/{id}/lists [get] func GetListsByNamespaceID(c echo.Context) error { + s := db.NewSession() + defer s.Close() + // Get our namespace - namespace, err := getNamespace(c) + namespace, err := getNamespace(s, c) if err != nil { return handler.HandleHTTPError(err, c) } @@ -53,14 +59,14 @@ func GetListsByNamespaceID(c echo.Context) error { return handler.HandleHTTPError(err, c) } - lists, err := models.GetListsByNamespaceID(namespace.ID, doer) + lists, err := models.GetListsByNamespaceID(s, namespace.ID, doer) if err != nil { return handler.HandleHTTPError(err, c) } return c.JSON(http.StatusOK, lists) } -func getNamespace(c echo.Context) (namespace *models.Namespace, err error) { +func getNamespace(s *xorm.Session, c echo.Context) (namespace *models.Namespace, err error) { // Check if we have our ID id := c.Param("namespace") // Make int @@ -75,12 +81,12 @@ func getNamespace(c echo.Context) (namespace *models.Namespace, err error) { } // Check if the user has acces to that namespace - user, err := user.GetCurrentUser(c) + u, err := user.GetCurrentUser(c) if err != nil { return } namespace = &models.Namespace{ID: namespaceID} - canRead, _, err := namespace.CanRead(user) + canRead, _, err := namespace.CanRead(s, u) if err != nil { return namespace, err } diff --git a/pkg/routes/api/v1/login.go b/pkg/routes/api/v1/login.go index 56dcc084..b7ac1cfe 100644 --- a/pkg/routes/api/v1/login.go +++ b/pkg/routes/api/v1/login.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/modules/auth" user2 "code.vikunja.io/api/pkg/user" @@ -45,27 +47,38 @@ func Login(c echo.Context) error { return c.JSON(http.StatusBadRequest, models.Message{Message: "Please provide a username and password."}) } + s := db.NewSession() + defer s.Close() + // Check user - user, err := user2.CheckUserCredentials(&u) + user, err := user2.CheckUserCredentials(s, &u) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } - totpEnabled, err := user2.TOTPEnabledForUser(user) + totpEnabled, err := user2.TOTPEnabledForUser(s, user) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } if totpEnabled { - _, err = user2.ValidateTOTPPasscode(&user2.TOTPPasscode{ + _, err = user2.ValidateTOTPPasscode(s, &user2.TOTPPasscode{ User: user, Passcode: u.TOTPPasscode, }) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } } + if err := s.Commit(); err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + // Create token return auth.NewUserAuthTokenResponse(user, c) } @@ -82,18 +95,23 @@ func Login(c echo.Context) error { // @Router /user/token [post] func RenewToken(c echo.Context) (err error) { + s := db.NewSession() + defer s.Close() + jwtinf := c.Get("user").(*jwt.Token) claims := jwtinf.Claims.(jwt.MapClaims) typ := int(claims["type"].(float64)) if typ == auth.AuthTypeLinkShare { share := &models.LinkSharing{} share.ID = int64(claims["id"].(float64)) - err := share.ReadOne() + err := share.ReadOne(s) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } t, err := auth.NewLinkShareJWTAuthtoken(share) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } return c.JSON(http.StatusOK, auth.Token{Token: t}) @@ -101,11 +119,18 @@ func RenewToken(c echo.Context) (err error) { u, err := user2.GetUserFromClaims(claims) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } - user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) + user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID}) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/task_attachment.go b/pkg/routes/api/v1/task_attachment.go index 1d3cc61f..bd4e99ed 100644 --- a/pkg/routes/api/v1/task_attachment.go +++ b/pkg/routes/api/v1/task_attachment.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" auth2 "code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/web/handler" @@ -52,8 +54,12 @@ func UploadTaskAttachment(c echo.Context) error { return handler.HandleHTTPError(err, c) } - can, err := taskAttachment.CanCreate(auth) + s := db.NewSession() + defer s.Close() + + can, err := taskAttachment.CanCreate(s, auth) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } if !can { @@ -63,6 +69,7 @@ func UploadTaskAttachment(c echo.Context) error { // Multipart form form, err := c.MultipartForm() if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -85,7 +92,7 @@ func UploadTaskAttachment(c echo.Context) error { } defer f.Close() - err = ta.NewAttachment(f, file.Filename, uint64(file.Size), auth) + err = ta.NewAttachment(s, f, file.Filename, uint64(file.Size), auth) if err != nil { r.Errors = append(r.Errors, handler.HandleHTTPError(err, c)) continue @@ -93,6 +100,11 @@ func UploadTaskAttachment(c echo.Context) error { r.Success = append(r.Success, ta) } + if err := s.Commit(); err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + return c.JSON(http.StatusOK, r) } @@ -121,8 +133,13 @@ func GetTaskAttachment(c echo.Context) error { if err != nil { return handler.HandleHTTPError(err, c) } - can, _, err := taskAttachment.CanRead(auth) + + s := db.NewSession() + defer s.Close() + + can, _, err := taskAttachment.CanRead(s, auth) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } if !can { @@ -130,14 +147,21 @@ func GetTaskAttachment(c echo.Context) error { } // Get the attachment incl file - err = taskAttachment.ReadOne() + err = taskAttachment.ReadOne(s) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } // Open an send the file to the client err = taskAttachment.File.LoadFileByID() if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_confirm_email.go b/pkg/routes/api/v1/user_confirm_email.go index cde7db7b..c99bb1c9 100644 --- a/pkg/routes/api/v1/user_confirm_email.go +++ b/pkg/routes/api/v1/user_confirm_email.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web/handler" @@ -43,8 +45,17 @@ func UserConfirmEmail(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "No token provided.") } - err := user.ConfirmEmail(&emailConfirm) + s := db.NewSession() + defer s.Close() + + err := user.ConfirmEmail(s, &emailConfirm) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_list.go b/pkg/routes/api/v1/user_list.go index ca925516..2893125b 100644 --- a/pkg/routes/api/v1/user_list.go +++ b/pkg/routes/api/v1/user_list.go @@ -20,6 +20,8 @@ import ( "net/http" "strconv" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" auth2 "code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/api/pkg/user" @@ -40,9 +42,19 @@ import ( // @Failure 500 {object} models.Message "Internal server error." // @Router /users [get] func UserList(c echo.Context) error { - s := c.QueryParam("s") - users, err := user.ListUsers(s) + search := c.QueryParam("s") + + s := db.NewSession() + defer s.Close() + + users, err := user.ListUsers(s, search) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -80,17 +92,27 @@ func ListUsersForList(c echo.Context) error { return handler.HandleHTTPError(err, c) } - canRead, _, err := list.CanRead(auth) + s := db.NewSession() + defer s.Close() + + canRead, _, err := list.CanRead(s, auth) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } if !canRead { return echo.ErrForbidden } - s := c.QueryParam("s") - users, err := models.ListUsersFromList(&list, s) + search := c.QueryParam("s") + users, err := models.ListUsersFromList(s, &list, search) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_password_reset.go b/pkg/routes/api/v1/user_password_reset.go index 383332af..d24262b2 100644 --- a/pkg/routes/api/v1/user_password_reset.go +++ b/pkg/routes/api/v1/user_password_reset.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web/handler" @@ -43,8 +45,17 @@ func UserResetPassword(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "No password provided.") } - err := user.ResetPassword(&pwReset) + s := db.NewSession() + defer s.Close() + + err := user.ResetPassword(s, &pwReset) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -73,8 +84,17 @@ func UserRequestResetPasswordToken(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err) } - err := user.RequestUserPasswordResetTokenByEmail(&pwTokenReset) + s := db.NewSession() + defer s.Close() + + err := user.RequestUserPasswordResetTokenByEmail(s, &pwTokenReset) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_register.go b/pkg/routes/api/v1/user_register.go index 4c19b205..1030149e 100644 --- a/pkg/routes/api/v1/user_register.go +++ b/pkg/routes/api/v1/user_register.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" @@ -50,15 +52,25 @@ func RegisterUser(c echo.Context) error { return c.JSON(http.StatusBadRequest, models.Message{Message: "No or invalid user model provided."}) } + s := db.NewSession() + defer s.Close() + // Insert the user - newUser, err := user.CreateUser(datUser.APIFormat()) + newUser, err := user.CreateUser(s, datUser.APIFormat()) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } // Add its namespace - err = models.CreateNewNamespaceForUser(newUser) + err = models.CreateNewNamespaceForUser(s, newUser) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_settings.go b/pkg/routes/api/v1/user_settings.go index 853c2612..59c44610 100644 --- a/pkg/routes/api/v1/user_settings.go +++ b/pkg/routes/api/v1/user_settings.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" user2 "code.vikunja.io/api/pkg/user" "code.vikunja.io/web/handler" @@ -57,8 +59,17 @@ func GetUserAvatarProvider(c echo.Context) error { return handler.HandleHTTPError(err, c) } - user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) + s := db.NewSession() + defer s.Close() + + user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID}) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -91,15 +102,25 @@ func ChangeUserAvatarProvider(c echo.Context) error { return handler.HandleHTTPError(err, c) } - user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) + s := db.NewSession() + defer s.Close() + + user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID}) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } user.AvatarProvider = uap.AvatarProvider - _, err = user2.UpdateUser(user) + _, err = user2.UpdateUser(s, user) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -129,16 +150,26 @@ func UpdateGeneralUserSettings(c echo.Context) error { return handler.HandleHTTPError(err, c) } - user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) + s := db.NewSession() + defer s.Close() + + user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID}) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } user.Name = us.Name user.EmailRemindersEnabled = us.EmailRemindersEnabled - _, err = user2.UpdateUser(user) + _, err = user2.UpdateUser(s, user) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_show.go b/pkg/routes/api/v1/user_show.go index 7118e97e..d5953e1e 100644 --- a/pkg/routes/api/v1/user_show.go +++ b/pkg/routes/api/v1/user_show.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + user2 "code.vikunja.io/api/pkg/user" "code.vikunja.io/web/handler" "github.com/labstack/echo/v4" @@ -41,8 +43,17 @@ func UserShow(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, "Error getting current user.") } - user, err := user2.GetUserByID(userInfos.ID) + s := db.NewSession() + defer s.Close() + + user, err := user2.GetUserByID(s, userInfos.ID) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_totp.go b/pkg/routes/api/v1/user_totp.go index c61a069f..097a5fc1 100644 --- a/pkg/routes/api/v1/user_totp.go +++ b/pkg/routes/api/v1/user_totp.go @@ -22,6 +22,8 @@ import ( "image/jpeg" "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" @@ -47,8 +49,17 @@ func UserTOTPEnroll(c echo.Context) error { return handler.HandleHTTPError(err, c) } - t, err := user.EnrollTOTP(u) + s := db.NewSession() + defer s.Close() + + t, err := user.EnrollTOTP(s, u) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -86,8 +97,17 @@ func UserTOTPEnable(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "Invalid model provided.") } - err = user.EnableTOTP(passcode) + s := db.NewSession() + defer s.Close() + + err = user.EnableTOTP(s, passcode) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -122,18 +142,29 @@ func UserTOTPDisable(c echo.Context) error { return handler.HandleHTTPError(err, c) } - u, err = user.GetUserByID(u.ID) + s := db.NewSession() + defer s.Close() + + u, err = user.GetUserByID(s, u.ID) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } err = user.CheckUserPassword(u, login.Password) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } - err = user.DisableTOTP(u) + err = user.DisableTOTP(s, u) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -156,14 +187,24 @@ func UserTOTPQrCode(c echo.Context) error { return handler.HandleHTTPError(err, c) } - qrcode, err := user.GetTOTPQrCodeForUser(u) + s := db.NewSession() + defer s.Close() + + qrcode, err := user.GetTOTPQrCodeForUser(s, u) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } buff := &bytes.Buffer{} err = jpeg.Encode(buff, qrcode, nil) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } @@ -186,8 +227,17 @@ func UserTOTP(c echo.Context) error { return handler.HandleHTTPError(err, c) } - t, err := user.GetTOTPForUser(u) + s := db.NewSession() + defer s.Close() + + t, err := user.GetTOTPForUser(s, u) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_update_email.go b/pkg/routes/api/v1/user_update_email.go index 3aca2d02..cba4a1c0 100644 --- a/pkg/routes/api/v1/user_update_email.go +++ b/pkg/routes/api/v1/user_update_email.go @@ -20,6 +20,8 @@ import ( "fmt" "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" @@ -56,16 +58,26 @@ func UpdateUserEmail(c echo.Context) (err error) { return handler.HandleHTTPError(err, c) } - emailUpdate.User, err = user.CheckUserCredentials(&user.Login{ + s := db.NewSession() + defer s.Close() + + emailUpdate.User, err = user.CheckUserCredentials(s, &user.Login{ Username: emailUpdate.User.Username, Password: emailUpdate.Password, }) if err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } - err = user.UpdateEmail(emailUpdate) + err = user.UpdateEmail(s, emailUpdate) if err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/api/v1/user_update_password.go b/pkg/routes/api/v1/user_update_password.go index 8a65fae7..fa87a57d 100644 --- a/pkg/routes/api/v1/user_update_password.go +++ b/pkg/routes/api/v1/user_update_password.go @@ -19,6 +19,8 @@ package v1 import ( "net/http" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web/handler" @@ -61,13 +63,23 @@ func UserChangePassword(c echo.Context) error { return handler.HandleHTTPError(user.ErrEmptyOldPassword{}, c) } + s := db.NewSession() + defer s.Close() + // Check the current password - if _, err = user.CheckUserCredentials(&user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil { + if _, err = user.CheckUserCredentials(s, &user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } // Update the password - if err = user.UpdateUserPassword(doer, newPW.NewPassword); err != nil { + if err = user.UpdateUserPassword(s, doer, newPW.NewPassword); err != nil { + _ = s.Rollback() + return handler.HandleHTTPError(err, c) + } + + if err := s.Commit(); err != nil { + _ = s.Rollback() return handler.HandleHTTPError(err, c) } diff --git a/pkg/routes/caldav/listStorageProvider.go b/pkg/routes/caldav/listStorageProvider.go index 0892bb0b..9e36bfca 100644 --- a/pkg/routes/caldav/listStorageProvider.go +++ b/pkg/routes/caldav/listStorageProvider.go @@ -21,6 +21,8 @@ import ( "strings" "time" + "code.vikunja.io/api/pkg/db" + "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" user2 "code.vikunja.io/api/pkg/user" @@ -90,9 +92,16 @@ func (vcls *VikunjaCaldavListStorage) GetResources(rpath string, withChildren bo return []data.Resource{r}, nil } + s := db.NewSession() + defer s.Close() + // Otherwise get all lists - thelists, _, _, err := vcls.list.ReadAll(vcls.user, "", -1, 50) + thelists, _, _, err := vcls.list.ReadAll(s, vcls.user, "", -1, 50) if err != nil { + _ = s.Rollback() + return nil, err + } + if err := s.Commit(); err != nil { return nil, err } lists := thelists.([]*models.List) @@ -125,10 +134,17 @@ func (vcls *VikunjaCaldavListStorage) GetResourcesByList(rpaths []string) ([]dat uids = append(uids, string(uid[:endlen])) } + s := db.NewSession() + defer s.Close() + // GetTasksByUIDs... // Parse these into ressources... - tasks, err := models.GetTasksByUIDs(uids) + tasks, err := models.GetTasksByUIDs(s, uids) if err != nil { + _ = s.Rollback() + return nil, err + } + if err := s.Commit(); err != nil { return nil, err } @@ -187,15 +203,22 @@ func (vcls *VikunjaCaldavListStorage) GetResource(rpath string) (*data.Resource, // If the task is not nil, we need to get the task and not the list if vcls.task != nil { + s := db.NewSession() + defer s.Close() + // save and override the updated unix date to not break any later etag checks updated := vcls.task.Updated - task, err := models.GetTaskSimple(&models.Task{ID: vcls.task.ID, UID: vcls.task.UID}) + task, err := models.GetTaskSimple(s, &models.Task{ID: vcls.task.ID, UID: vcls.task.UID}) if err != nil { + _ = s.Rollback() if models.IsErrTaskDoesNotExist(err) { return nil, false, errs.ResourceNotFoundError } return nil, false, err } + if err := s.Commit(); err != nil { + return nil, false, err + } vcls.task = &task if updated.Unix() > 0 { @@ -230,6 +253,9 @@ func (vcls *VikunjaCaldavListStorage) GetShallowResource(rpath string) (*data.Re // CreateResource creates a new resource func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*data.Resource, error) { + s := db.NewSession() + defer s.Close() + vTask, err := parseTaskFromVTODO(content) if err != nil { return nil, err @@ -238,7 +264,7 @@ func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*da vTask.ListID = vcls.list.ID // Check the rights - canCreate, err := vTask.CanCreate(vcls.user) + canCreate, err := vTask.CanCreate(s, vcls.user) if err != nil { return nil, err } @@ -247,8 +273,13 @@ func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*da } // Create the task - err = vTask.Create(vcls.user) + err = vTask.Create(s, vcls.user) if err != nil { + _ = s.Rollback() + return nil, err + } + + if err := s.Commit(); err != nil { return nil, err } @@ -272,18 +303,28 @@ func (vcls *VikunjaCaldavListStorage) UpdateResource(rpath, content string) (*da // At this point, we already have the right task in vcls.task, so we can use that ID directly vTask.ID = vcls.task.ID + s := db.NewSession() + defer s.Close() + // Check the rights - canUpdate, err := vTask.CanUpdate(vcls.user) + canUpdate, err := vTask.CanUpdate(s, vcls.user) if err != nil { + _ = s.Rollback() return nil, err } if !canUpdate { + _ = s.Rollback() return nil, errs.ForbiddenError } // Update the task - err = vTask.Update() + err = vTask.Update(s) if err != nil { + _ = s.Rollback() + return nil, err + } + + if err := s.Commit(); err != nil { return nil, err } @@ -299,9 +340,13 @@ func (vcls *VikunjaCaldavListStorage) UpdateResource(rpath, content string) (*da // DeleteResource deletes a resource func (vcls *VikunjaCaldavListStorage) DeleteResource(rpath string) error { if vcls.task != nil { + s := db.NewSession() + defer s.Close() + // Check the rights - canDelete, err := vcls.task.CanDelete(vcls.user) + canDelete, err := vcls.task.CanDelete(s, vcls.user) if err != nil { + _ = s.Rollback() return err } if !canDelete { @@ -309,7 +354,13 @@ func (vcls *VikunjaCaldavListStorage) DeleteResource(rpath string) error { } // Delete it - return vcls.task.Delete() + err = vcls.task.Delete(s) + if err != nil { + _ = s.Rollback() + return err + } + + return s.Commit() } return nil @@ -385,16 +436,22 @@ func (vlra *VikunjaListResourceAdapter) GetModTime() time.Time { } func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr VikunjaListResourceAdapter, err error) { - can, _, err := vcls.list.CanRead(vcls.user) + s := db.NewSession() + defer s.Close() + + can, _, err := vcls.list.CanRead(s, vcls.user) if err != nil { + _ = s.Rollback() return } if !can { + _ = s.Rollback() log.Errorf("User %v tried to access a caldav resource (List %v) which they are not allowed to access", vcls.user.Username, vcls.list.ID) return rr, models.ErrUserDoesNotHaveAccessToList{ListID: vcls.list.ID} } - err = vcls.list.ReadOne() + err = vcls.list.ReadOne(s) if err != nil { + _ = s.Rollback() return } @@ -403,8 +460,9 @@ func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr Vi tk := models.TaskCollection{ ListID: vcls.list.ID, } - iface, _, _, err := tk.ReadAll(vcls.user, "", 1, 1000) + iface, _, _, err := tk.ReadAll(s, vcls.user, "", 1, 1000) if err != nil { + _ = s.Rollback() return rr, err } tasks, ok := iface.([]*models.Task) @@ -416,6 +474,10 @@ func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr Vi vcls.list.Tasks = tasks } + if err := s.Commit(); err != nil { + return rr, err + } + rr = VikunjaListResourceAdapter{ list: vcls.list, listTasks: listTasks, diff --git a/pkg/routes/routes.go b/pkg/routes/routes.go index c46ad316..e99157e8 100644 --- a/pkg/routes/routes.go +++ b/pkg/routes/routes.go @@ -50,11 +50,8 @@ import ( "strings" "time" - microsofttodo "code.vikunja.io/api/pkg/modules/migration/microsoft-todo" - - "code.vikunja.io/api/pkg/modules/migration/trello" - "code.vikunja.io/api/pkg/config" + "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/modules/auth" @@ -65,7 +62,9 @@ import ( "code.vikunja.io/api/pkg/modules/background/upload" "code.vikunja.io/api/pkg/modules/migration" migrationHandler "code.vikunja.io/api/pkg/modules/migration/handler" + microsofttodo "code.vikunja.io/api/pkg/modules/migration/microsoft-todo" "code.vikunja.io/api/pkg/modules/migration/todoist" + "code.vikunja.io/api/pkg/modules/migration/trello" "code.vikunja.io/api/pkg/modules/migration/wunderlist" apiv1 "code.vikunja.io/api/pkg/routes/api/v1" "code.vikunja.io/api/pkg/routes/caldav" @@ -175,6 +174,7 @@ func NewEcho() *echo.Echo { }) handler.SetLoggingProvider(log.GetLogger()) handler.SetMaxItemsPerPage(config.ServiceMaxItemsPerPage.GetInt()) + handler.SetSessionFactory(db.NewSession) return e } @@ -601,11 +601,19 @@ func caldavBasicAuth(username, password string, c echo.Context) (bool, error) { Username: username, Password: password, } - u, err := user.CheckUserCredentials(creds) + s := db.NewSession() + defer s.Close() + u, err := user.CheckUserCredentials(s, creds) if err != nil { + _ = s.Rollback() log.Errorf("Error during basic auth for caldav: %v", err) return false, nil } + + if err := s.Commit(); err != nil { + return false, err + } + // Save the user in echo context for later use c.Set("userBasicAuth", u) return true, nil diff --git a/pkg/user/db.go b/pkg/user/db.go index 54d312a4..86c5200e 100644 --- a/pkg/user/db.go +++ b/pkg/user/db.go @@ -20,20 +20,10 @@ package user import ( "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/db" - "code.vikunja.io/api/pkg/log" - "xorm.io/xorm" ) -var x *xorm.Engine - // InitDB sets up the database connection to use in this module func InitDB() (err error) { - x, err = db.CreateDBEngine() - if err != nil { - log.Criticalf("Could not connect to db: %v", err.Error()) - return - } - // Cache if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" { db.RegisterTableStructsForCache(GetTables()) diff --git a/pkg/user/test.go b/pkg/user/test.go index f2a61cd2..1fcd3452 100644 --- a/pkg/user/test.go +++ b/pkg/user/test.go @@ -24,8 +24,7 @@ import ( // InitTests handles the actual bootstrapping of the test env func InitTests() { - var err error - x, err = db.CreateTestEngine() + x, err := db.CreateTestEngine() if err != nil { log.Fatal(err) } diff --git a/pkg/user/totp.go b/pkg/user/totp.go index c8b6a163..b1f836a6 100644 --- a/pkg/user/totp.go +++ b/pkg/user/totp.go @@ -19,6 +19,8 @@ package user import ( "image" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/config" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" @@ -47,19 +49,19 @@ type TOTPPasscode struct { } // TOTPEnabledForUser checks if totp is enabled for a user - not if it is activated, use GetTOTPForUser to check that. -func TOTPEnabledForUser(user *User) (bool, error) { +func TOTPEnabledForUser(s *xorm.Session, user *User) (bool, error) { if !config.ServiceEnableTotp.GetBool() { return false, nil } t := &TOTP{} - _, err := x.Where("user_id = ?", user.ID).Get(t) + _, err := s.Where("user_id = ?", user.ID).Get(t) return t.Enabled, err } // GetTOTPForUser returns the current state of totp settings for the user. -func GetTOTPForUser(user *User) (t *TOTP, err error) { +func GetTOTPForUser(s *xorm.Session, user *User) (t *TOTP, err error) { t = &TOTP{} - exists, err := x.Where("user_id = ?", user.ID).Get(t) + exists, err := s.Where("user_id = ?", user.ID).Get(t) if err != nil { return } @@ -71,8 +73,8 @@ func GetTOTPForUser(user *User) (t *TOTP, err error) { } // EnrollTOTP creates a new TOTP entry for the user - it does not enable it yet. -func EnrollTOTP(user *User) (t *TOTP, err error) { - isEnrolled, err := x.Where("user_id = ?", user.ID).Exist(&TOTP{}) +func EnrollTOTP(s *xorm.Session, user *User) (t *TOTP, err error) { + isEnrolled, err := s.Where("user_id = ?", user.ID).Exist(&TOTP{}) if err != nil { return } @@ -94,18 +96,18 @@ func EnrollTOTP(user *User) (t *TOTP, err error) { Enabled: false, URL: key.URL(), } - _, err = x.Insert(t) + _, err = s.Insert(t) return } // EnableTOTP enables totp for a user. The provided passcode is used to verify the user has a working totp setup. -func EnableTOTP(passcode *TOTPPasscode) (err error) { - t, err := ValidateTOTPPasscode(passcode) +func EnableTOTP(s *xorm.Session, passcode *TOTPPasscode) (err error) { + t, err := ValidateTOTPPasscode(s, passcode) if err != nil { return } - _, err = x. + _, err = s. Where("id = ?", t.ID). Cols("enabled"). Update(&TOTP{Enabled: true}) @@ -113,14 +115,16 @@ func EnableTOTP(passcode *TOTPPasscode) (err error) { } // DisableTOTP removes all totp settings for a user. -func DisableTOTP(user *User) (err error) { - _, err = x.Where("user_id = ?", user.ID).Delete(&TOTP{}) +func DisableTOTP(s *xorm.Session, user *User) (err error) { + _, err = s. + Where("user_id = ?", user.ID). + Delete(&TOTP{}) return } // ValidateTOTPPasscode validated totp codes of users. -func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) { - t, err = GetTOTPForUser(passcode.User) +func ValidateTOTPPasscode(s *xorm.Session, passcode *TOTPPasscode) (t *TOTP, err error) { + t, err = GetTOTPForUser(s, passcode.User) if err != nil { return } @@ -133,8 +137,8 @@ func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) { } // GetTOTPQrCodeForUser returns a qrcode for a user's totp setting -func GetTOTPQrCodeForUser(user *User) (qrcode image.Image, err error) { - t, err := GetTOTPForUser(user) +func GetTOTPQrCodeForUser(s *xorm.Session, user *User) (qrcode image.Image, err error) { + t, err := GetTOTPForUser(s, user) if err != nil { return } diff --git a/pkg/user/update_email.go b/pkg/user/update_email.go index 9af5c8c6..2b464676 100644 --- a/pkg/user/update_email.go +++ b/pkg/user/update_email.go @@ -20,6 +20,7 @@ import ( "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/mail" "code.vikunja.io/api/pkg/utils" + "xorm.io/xorm" ) // EmailUpdate is the data structure to update a user's email address @@ -32,11 +33,11 @@ type EmailUpdate struct { } // UpdateEmail lets a user update their email address -func UpdateEmail(update *EmailUpdate) (err error) { +func UpdateEmail(s *xorm.Session, update *EmailUpdate) (err error) { // Check the email is not already used user := &User{} - has, err := x.Where("email = ?", update.NewEmail).Get(user) + has, err := s.Where("email = ?", update.NewEmail).Get(user) if err != nil { return } @@ -46,7 +47,7 @@ func UpdateEmail(update *EmailUpdate) (err error) { } // Set the user as unconfirmed and the new email address - update.User, err = GetUserWithEmail(&User{ID: update.User.ID}) + update.User, err = GetUserWithEmail(s, &User{ID: update.User.ID}) if err != nil { return } @@ -54,7 +55,7 @@ func UpdateEmail(update *EmailUpdate) (err error) { update.User.IsActive = false update.User.Email = update.NewEmail update.User.EmailConfirmToken = utils.MakeRandomString(64) - _, err = x. + _, err = s. Where("id = ?", update.User.ID). Cols("email", "is_active", "email_confirm_token"). Update(update.User) diff --git a/pkg/user/user.go b/pkg/user/user.go index d7450191..f658b316 100644 --- a/pkg/user/user.go +++ b/pkg/user/user.go @@ -23,6 +23,8 @@ import ( "reflect" "time" + "xorm.io/xorm" + "code.vikunja.io/web" "github.com/dgrijalva/jwt-go" "github.com/labstack/echo/v4" @@ -116,38 +118,33 @@ func (apiUser *APIUserPassword) APIFormat() *User { } // GetUserByID gets informations about a user by its ID -func GetUserByID(id int64) (user *User, err error) { +func GetUserByID(s *xorm.Session, id int64) (user *User, err error) { // Apparently xorm does otherwise look for all users but return only one, which leads to returing one even if the ID is 0 if id < 1 { return &User{}, ErrUserDoesNotExist{} } - return GetUser(&User{ID: id}) + return getUser(s, &User{ID: id}, false) } // GetUserByUsername gets a user from its user name. This is an extra function to be able to add an extra error check. -func GetUserByUsername(username string) (user *User, err error) { +func GetUserByUsername(s *xorm.Session, username string) (user *User, err error) { if username == "" { return &User{}, ErrUserDoesNotExist{} } - return GetUser(&User{Username: username}) -} - -// GetUser gets a user object -func GetUser(user *User) (userOut *User, err error) { - return getUser(user, false) + return getUser(s, &User{Username: username}, false) } // GetUserWithEmail returns a user object with email -func GetUserWithEmail(user *User) (userOut *User, err error) { - return getUser(user, true) +func GetUserWithEmail(s *xorm.Session, user *User) (userOut *User, err error) { + return getUser(s, user, true) } // GetUsersByIDs returns a map of users from a slice of user ids -func GetUsersByIDs(userIDs []int64) (users map[int64]*User, err error) { +func GetUsersByIDs(s *xorm.Session, userIDs []int64) (users map[int64]*User, err error) { users = make(map[int64]*User) - err = x.In("id", userIDs).Find(&users) + err = s.In("id", userIDs).Find(&users) if err != nil { return } @@ -161,10 +158,10 @@ func GetUsersByIDs(userIDs []int64) (users map[int64]*User, err error) { } // getUser is a small helper function to avoid having duplicated code for almost the same use case -func getUser(user *User, withEmail bool) (userOut *User, err error) { +func getUser(s *xorm.Session, user *User, withEmail bool) (userOut *User, err error) { userOut = &User{} // To prevent a panic if user is nil *userOut = *user - exists, err := x.Get(userOut) + exists, err := s.Get(userOut) if err != nil { return nil, err } @@ -179,9 +176,9 @@ func getUser(user *User, withEmail bool) (userOut *User, err error) { return userOut, err } -func getUserByUsernameOrEmail(usernameOrEmail string) (u *User, err error) { +func getUserByUsernameOrEmail(s *xorm.Session, usernameOrEmail string) (u *User, err error) { u = &User{} - exists, err := x. + exists, err := s. Where("username = ? OR email = ?", usernameOrEmail, usernameOrEmail). Get(u) if err != nil { @@ -196,14 +193,14 @@ func getUserByUsernameOrEmail(usernameOrEmail string) (u *User, err error) { } // CheckUserCredentials checks user credentials -func CheckUserCredentials(u *Login) (*User, error) { +func CheckUserCredentials(s *xorm.Session, u *Login) (*User, error) { // Check if we have any credentials if u.Password == "" || u.Username == "" { return nil, ErrNoUsernamePassword{} } // Check if the user exists - user, err := getUserByUsernameOrEmail(u.Username) + user, err := getUserByUsernameOrEmail(s, u.Username) if err != nil { // hashing the password takes a long time, so we hash something to not make it clear if the username was wrong _, _ = bcrypt.GenerateFromPassword([]byte(u.Username), 14) @@ -261,10 +258,10 @@ func GetUserFromClaims(claims jwt.MapClaims) (user *User, err error) { } // UpdateUser updates a user -func UpdateUser(user *User) (updatedUser *User, err error) { +func UpdateUser(s *xorm.Session, user *User) (updatedUser *User, err error) { // Check if it exists - theUser, err := GetUserWithEmail(&User{ID: user.ID}) + theUser, err := GetUserWithEmail(s, &User{ID: user.ID}) if err != nil { return &User{}, err } @@ -274,7 +271,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) { user.Username = theUser.Username // Dont change the username if we dont have one } else { // Check if the new username already exists - uu, err := GetUserByUsername(user.Username) + uu, err := GetUserByUsername(s, user.Username) if err != nil && !IsErrUserDoesNotExist(err) { return nil, err } @@ -292,7 +289,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) { if user.Email == "" { user.Email = theUser.Email } else { - uu, err := getUser(&User{ + uu, err := getUser(s, &User{ Email: user.Email, Issuer: user.Issuer, Subject: user.Subject, @@ -316,7 +313,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) { } // Update it - _, err = x. + _, err = s. ID(user.ID). Cols( "username", @@ -333,7 +330,7 @@ func UpdateUser(user *User) (updatedUser *User, err error) { } // Get the newly updated user - updatedUser, err = GetUserByID(user.ID) + updatedUser, err = GetUserByID(s, user.ID) if err != nil { return &User{}, err } @@ -342,14 +339,14 @@ func UpdateUser(user *User) (updatedUser *User, err error) { } // UpdateUserPassword updates the password of a user -func UpdateUserPassword(user *User, newPassword string) (err error) { +func UpdateUserPassword(s *xorm.Session, user *User, newPassword string) (err error) { if newPassword == "" { return ErrEmptyNewPassword{} } // Get all user details - theUser, err := GetUserByID(user.ID) + theUser, err := GetUserByID(s, user.ID) if err != nil { return err } @@ -362,7 +359,7 @@ func UpdateUserPassword(user *User, newPassword string) (err error) { theUser.Password = hashed // Update it - _, err = x.ID(user.ID).Update(theUser) + _, err = s.ID(user.ID).Update(theUser) if err != nil { return err } diff --git a/pkg/user/user_create.go b/pkg/user/user_create.go index 74354c31..d592e80e 100644 --- a/pkg/user/user_create.go +++ b/pkg/user/user_create.go @@ -22,12 +22,13 @@ import ( "code.vikunja.io/api/pkg/metrics" "code.vikunja.io/api/pkg/utils" "golang.org/x/crypto/bcrypt" + "xorm.io/xorm" ) const issuerLocal = `local` // CreateUser creates a new user and inserts it into the database -func CreateUser(user *User) (newUser *User, err error) { +func CreateUser(s *xorm.Session, user *User) (newUser *User, err error) { if user.Issuer == "" { user.Issuer = issuerLocal @@ -40,7 +41,7 @@ func CreateUser(user *User) (newUser *User, err error) { } // Check if the user already exists with that username - err = checkIfUserExists(user) + err = checkIfUserExists(s, user) if err != nil { return nil, err } @@ -64,7 +65,7 @@ func CreateUser(user *User) (newUser *User, err error) { user.AvatarProvider = "initials" // Insert it - _, err = x.Insert(user) + _, err = s.Insert(user) if err != nil { return nil, err } @@ -73,7 +74,7 @@ func CreateUser(user *User) (newUser *User, err error) { metrics.UpdateCount(1, metrics.ActiveUsersKey) // Get the full new User - newUserOut, err := GetUserByID(user.ID) + newUserOut, err := GetUserByID(s, user.ID) if err != nil { return nil, err } @@ -100,9 +101,9 @@ func checkIfUserIsValid(user *User) error { return nil } -func checkIfUserExists(user *User) (err error) { +func checkIfUserExists(s *xorm.Session, user *User) (err error) { exists := true - _, err = GetUserByUsername(user.Username) + _, err = GetUserByUsername(s, user.Username) if err != nil { if IsErrUserDoesNotExist(err) { exists = false @@ -126,7 +127,7 @@ func checkIfUserExists(user *User) (err error) { userToCheck.Email = "" } - _, err = GetUser(userToCheck) + _, err = getUser(s, userToCheck, false) if err != nil { if IsErrUserDoesNotExist(err) { exists = false diff --git a/pkg/user/user_email_confirm.go b/pkg/user/user_email_confirm.go index 68378c2b..7bcb5c2d 100644 --- a/pkg/user/user_email_confirm.go +++ b/pkg/user/user_email_confirm.go @@ -17,6 +17,8 @@ package user +import "xorm.io/xorm" + // EmailConfirm holds the token to confirm a mail address type EmailConfirm struct { // The email confirm token sent via email. @@ -24,7 +26,7 @@ type EmailConfirm struct { } // ConfirmEmail handles the confirmation of an email address -func ConfirmEmail(c *EmailConfirm) (err error) { +func ConfirmEmail(s *xorm.Session, c *EmailConfirm) (err error) { // Check if we have an email confirm token if c.Token == "" { @@ -33,7 +35,9 @@ func ConfirmEmail(c *EmailConfirm) (err error) { // Check if the token is valid user := User{} - has, err := x.Where("email_confirm_token = ?", c.Token).Get(&user) + has, err := s. + Where("email_confirm_token = ?", c.Token). + Get(&user) if err != nil { return } @@ -44,6 +48,9 @@ func ConfirmEmail(c *EmailConfirm) (err error) { user.IsActive = true user.EmailConfirmToken = "" - _, err = x.Where("id = ?", user.ID).Cols("is_active", "email_confirm_token").Update(&user) + _, err = s. + Where("id = ?", user.ID). + Cols("is_active", "email_confirm_token"). + Update(&user) return } diff --git a/pkg/user/user_email_confirm_test.go b/pkg/user/user_email_confirm_test.go index 4964c3ca..ccebeda2 100644 --- a/pkg/user/user_email_confirm_test.go +++ b/pkg/user/user_email_confirm_test.go @@ -65,7 +65,10 @@ func TestUserEmailConfirm(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db.LoadAndAssertFixtures(t) - if err := ConfirmEmail(tt.args.c); (err != nil) != tt.wantErr { + s := db.NewSession() + defer s.Close() + + if err := ConfirmEmail(s, tt.args.c); (err != nil) != tt.wantErr { t.Errorf("ConfirmEmail() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/pkg/user/user_password_reset.go b/pkg/user/user_password_reset.go index 4ec892e0..3d395893 100644 --- a/pkg/user/user_password_reset.go +++ b/pkg/user/user_password_reset.go @@ -21,6 +21,7 @@ import ( "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/mail" "code.vikunja.io/api/pkg/utils" + "xorm.io/xorm" ) // PasswordReset holds the data to reset a password @@ -32,7 +33,7 @@ type PasswordReset struct { } // ResetPassword resets a users password -func ResetPassword(reset *PasswordReset) (err error) { +func ResetPassword(s *xorm.Session, reset *PasswordReset) (err error) { // Check if the password is not empty if reset.NewPassword == "" { @@ -41,7 +42,9 @@ func ResetPassword(reset *PasswordReset) (err error) { // Check if we have a token var user User - exists, err := x.Where("password_reset_token = ?", reset.Token).Get(&user) + exists, err := s. + Where("password_reset_token = ?", reset.Token). + Get(&user) if err != nil { return } @@ -57,7 +60,9 @@ func ResetPassword(reset *PasswordReset) (err error) { } // Save it - _, err = x.Where("id = ?", user.ID).Update(&user) + _, err = s. + Where("id = ?", user.ID). + Update(&user) if err != nil { return } @@ -83,27 +88,29 @@ type PasswordTokenRequest struct { } // RequestUserPasswordResetTokenByEmail inserts a random token to reset a users password into the databsse -func RequestUserPasswordResetTokenByEmail(tr *PasswordTokenRequest) (err error) { +func RequestUserPasswordResetTokenByEmail(s *xorm.Session, tr *PasswordTokenRequest) (err error) { if tr.Email == "" { return ErrNoUsernamePassword{} } // Check if the user exists - user, err := GetUserWithEmail(&User{Email: tr.Email}) + user, err := GetUserWithEmail(s, &User{Email: tr.Email}) if err != nil { return } - return RequestUserPasswordResetToken(user) + return RequestUserPasswordResetToken(s, user) } // RequestUserPasswordResetToken sends a user a password reset email. -func RequestUserPasswordResetToken(user *User) (err error) { +func RequestUserPasswordResetToken(s *xorm.Session, user *User) (err error) { // Generate a token and save it user.PasswordResetToken = utils.MakeRandomString(400) // Save it - _, err = x.Where("id = ?", user.ID).Update(user) + _, err = s. + Where("id = ?", user.ID). + Update(user) if err != nil { return } diff --git a/pkg/user/user_test.go b/pkg/user/user_test.go index 472bff40..cda18374 100644 --- a/pkg/user/user_test.go +++ b/pkg/user/user_test.go @@ -34,13 +34,19 @@ func TestCreateUser(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) - createdUser, err := CreateUser(dummyuser) + s := db.NewSession() + defer s.Close() + + createdUser, err := CreateUser(s, dummyuser) assert.NoError(t, err) assert.NotZero(t, createdUser.Created) }) t.Run("already existing", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CreateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := CreateUser(s, &User{ Username: "user1", Password: "12345", Email: "email@example.com", @@ -50,7 +56,10 @@ func TestCreateUser(t *testing.T) { }) t.Run("same email", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CreateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := CreateUser(s, &User{ Username: "testuser", Password: "12345", Email: "user1@example.com", @@ -60,7 +69,10 @@ func TestCreateUser(t *testing.T) { }) t.Run("no username", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CreateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := CreateUser(s, &User{ Username: "", Password: "12345", Email: "user1@example.com", @@ -70,7 +82,10 @@ func TestCreateUser(t *testing.T) { }) t.Run("no password", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CreateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := CreateUser(s, &User{ Username: "testuser", Password: "", Email: "user1@example.com", @@ -80,7 +95,10 @@ func TestCreateUser(t *testing.T) { }) t.Run("no email", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CreateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := CreateUser(s, &User{ Username: "testuser", Password: "12345", Email: "", @@ -90,7 +108,10 @@ func TestCreateUser(t *testing.T) { }) t.Run("same email but different issuer", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CreateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := CreateUser(s, &User{ Username: "somenewuser", Email: "user1@example.com", Issuer: "https://some.site", @@ -100,7 +121,10 @@ func TestCreateUser(t *testing.T) { }) t.Run("same subject but different issuer", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CreateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := CreateUser(s, &User{ Username: "somenewuser", Email: "somenewuser@example.com", Issuer: "https://some.site", @@ -113,25 +137,41 @@ func TestCreateUser(t *testing.T) { func TestGetUser(t *testing.T) { t.Run("by name", func(t *testing.T) { db.LoadAndAssertFixtures(t) - theuser, err := GetUser(&User{ - Username: "user1", - }) + s := db.NewSession() + defer s.Close() + + theuser, err := getUser( + s, + &User{ + Username: "user1", + }, + false, + ) assert.NoError(t, err) assert.Equal(t, theuser.ID, int64(1)) assert.Empty(t, theuser.Email) }) t.Run("by email", func(t *testing.T) { db.LoadAndAssertFixtures(t) - theuser, err := GetUser(&User{ - Email: "user1@example.com", - }) + s := db.NewSession() + defer s.Close() + + theuser, err := getUser( + s, + &User{ + Email: "user1@example.com", + }, + false) assert.NoError(t, err) assert.Equal(t, theuser.ID, int64(1)) assert.Empty(t, theuser.Email) }) t.Run("by id", func(t *testing.T) { db.LoadAndAssertFixtures(t) - theuser, err := GetUserByID(1) + s := db.NewSession() + defer s.Close() + + theuser, err := GetUserByID(s, 1) assert.NoError(t, err) assert.Equal(t, theuser.ID, int64(1)) assert.Equal(t, theuser.Username, "user1") @@ -139,25 +179,37 @@ func TestGetUser(t *testing.T) { }) t.Run("invalid id", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := GetUserByID(99999) + s := db.NewSession() + defer s.Close() + + _, err := GetUserByID(s, 99999) assert.Error(t, err) assert.True(t, IsErrUserDoesNotExist(err)) }) t.Run("nonexistant", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := GetUserByID(0) + s := db.NewSession() + defer s.Close() + + _, err := GetUserByID(s, 0) assert.Error(t, err) assert.True(t, IsErrUserDoesNotExist(err)) }) t.Run("empty name", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := GetUserByUsername("") + s := db.NewSession() + defer s.Close() + + _, err := GetUserByUsername(s, "") assert.Error(t, err) assert.True(t, IsErrUserDoesNotExist(err)) }) t.Run("with email", func(t *testing.T) { db.LoadAndAssertFixtures(t) - theuser, err := GetUserWithEmail(&User{ID: 1}) + s := db.NewSession() + defer s.Close() + + theuser, err := GetUserWithEmail(s, &User{ID: 1}) assert.NoError(t, err) assert.Equal(t, theuser.ID, int64(1)) assert.Equal(t, theuser.Username, "user1") @@ -168,42 +220,63 @@ func TestGetUser(t *testing.T) { func TestCheckUserCredentials(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CheckUserCredentials(&Login{Username: "user1", Password: "1234"}) + s := db.NewSession() + defer s.Close() + + _, err := CheckUserCredentials(s, &Login{Username: "user1", Password: "1234"}) assert.NoError(t, err) }) t.Run("unverified email", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CheckUserCredentials(&Login{Username: "user5", Password: "1234"}) + s := db.NewSession() + defer s.Close() + + _, err := CheckUserCredentials(s, &Login{Username: "user5", Password: "1234"}) assert.Error(t, err) assert.True(t, IsErrEmailNotConfirmed(err)) }) t.Run("wrong password", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CheckUserCredentials(&Login{Username: "user1", Password: "12345"}) + s := db.NewSession() + defer s.Close() + + _, err := CheckUserCredentials(s, &Login{Username: "user1", Password: "12345"}) assert.Error(t, err) assert.True(t, IsErrWrongUsernameOrPassword(err)) }) t.Run("nonexistant user", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CheckUserCredentials(&Login{Username: "dfstestuu", Password: "1234"}) + s := db.NewSession() + defer s.Close() + + _, err := CheckUserCredentials(s, &Login{Username: "dfstestuu", Password: "1234"}) assert.Error(t, err) assert.True(t, IsErrWrongUsernameOrPassword(err)) }) t.Run("empty password", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CheckUserCredentials(&Login{Username: "user1"}) + s := db.NewSession() + defer s.Close() + + _, err := CheckUserCredentials(s, &Login{Username: "user1"}) assert.Error(t, err) assert.True(t, IsErrNoUsernamePassword(err)) }) t.Run("empty username", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CheckUserCredentials(&Login{Password: "1234"}) + s := db.NewSession() + defer s.Close() + + _, err := CheckUserCredentials(s, &Login{Password: "1234"}) assert.Error(t, err) assert.True(t, IsErrNoUsernamePassword(err)) }) t.Run("email", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := CheckUserCredentials(&Login{Username: "user1@example.com", Password: "1234"}) + s := db.NewSession() + defer s.Close() + + _, err := CheckUserCredentials(s, &Login{Username: "user1@example.com", Password: "1234"}) assert.NoError(t, err) }) } @@ -211,7 +284,10 @@ func TestCheckUserCredentials(t *testing.T) { func TestUpdateUser(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) - uuser, err := UpdateUser(&User{ + s := db.NewSession() + defer s.Close() + + uuser, err := UpdateUser(s, &User{ ID: 1, Password: "LoremIpsum", Email: "testing@example.com", @@ -222,7 +298,10 @@ func TestUpdateUser(t *testing.T) { }) t.Run("change username", func(t *testing.T) { db.LoadAndAssertFixtures(t) - uuser, err := UpdateUser(&User{ + s := db.NewSession() + defer s.Close() + + uuser, err := UpdateUser(s, &User{ ID: 1, Username: "changedname", }) @@ -232,7 +311,10 @@ func TestUpdateUser(t *testing.T) { }) t.Run("nonexistant", func(t *testing.T) { db.LoadAndAssertFixtures(t) - _, err := UpdateUser(&User{ + s := db.NewSession() + defer s.Close() + + _, err := UpdateUser(s, &User{ ID: 99999, }) assert.Error(t, err) @@ -244,15 +326,20 @@ func TestUpdateUserPassword(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) - err := UpdateUserPassword(&User{ + s := db.NewSession() + defer s.Close() + + err := UpdateUserPassword(s, &User{ ID: 1, - }, "12345", - ) + }, "12345") assert.NoError(t, err) }) t.Run("nonexistant user", func(t *testing.T) { db.LoadAndAssertFixtures(t) - err := UpdateUserPassword(&User{ + s := db.NewSession() + defer s.Close() + + err := UpdateUserPassword(s, &User{ ID: 9999, }, "12345") assert.Error(t, err) @@ -260,10 +347,12 @@ func TestUpdateUserPassword(t *testing.T) { }) t.Run("empty password", func(t *testing.T) { db.LoadAndAssertFixtures(t) - err := UpdateUserPassword(&User{ + s := db.NewSession() + defer s.Close() + + err := UpdateUserPassword(s, &User{ ID: 1, - }, "", - ) + }, "") assert.Error(t, err) assert.True(t, IsErrEmptyNewPassword(err)) }) @@ -272,14 +361,20 @@ func TestUpdateUserPassword(t *testing.T) { func TestListUsers(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) - all, err := ListUsers("user1") + s := db.NewSession() + defer s.Close() + + all, err := ListUsers(s, "user1") assert.NoError(t, err) assert.True(t, len(all) > 0) assert.Equal(t, all[0].Username, "user1") }) t.Run("all users", func(t *testing.T) { db.LoadAndAssertFixtures(t) - all, err := ListUsers("") + s := db.NewSession() + defer s.Close() + + all, err := ListUsers(s, "") assert.NoError(t, err) assert.Len(t, all, 14) }) @@ -288,39 +383,51 @@ func TestListUsers(t *testing.T) { func TestUserPasswordReset(t *testing.T) { t.Run("normal", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + reset := &PasswordReset{ Token: "passwordresettesttoken", NewPassword: "12345", } - err := ResetPassword(reset) + err := ResetPassword(s, reset) assert.NoError(t, err) }) t.Run("without password", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + reset := &PasswordReset{ Token: "passwordresettesttoken", } - err := ResetPassword(reset) + err := ResetPassword(s, reset) assert.Error(t, err) assert.True(t, IsErrNoUsernamePassword(err)) }) t.Run("empty token", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + reset := &PasswordReset{ Token: "somethingsomething", NewPassword: "12345", } - err := ResetPassword(reset) + err := ResetPassword(s, reset) assert.Error(t, err) assert.True(t, IsErrInvalidPasswordResetToken(err)) }) t.Run("wrong token", func(t *testing.T) { db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + reset := &PasswordReset{ Token: "somethingsomething", NewPassword: "12345", } - err := ResetPassword(reset) + err := ResetPassword(s, reset) assert.Error(t, err) assert.True(t, IsErrInvalidPasswordResetToken(err)) }) diff --git a/pkg/user/users_list.go b/pkg/user/users_list.go index 750d17e1..0f7e12b5 100644 --- a/pkg/user/users_list.go +++ b/pkg/user/users_list.go @@ -21,11 +21,13 @@ import ( "strconv" "strings" + "xorm.io/xorm" + "code.vikunja.io/api/pkg/log" ) // ListUsers returns a list with all users, filtered by an optional searchstring -func ListUsers(searchterm string) (users []*User, err error) { +func ListUsers(s *xorm.Session, searchterm string) (users []*User, err error) { vals := strings.Split(searchterm, ",") ids := []int64{} @@ -39,18 +41,18 @@ func ListUsers(searchterm string) (users []*User, err error) { } if len(ids) > 0 { - err = x. + err = s. In("id", ids). Find(&users) return } if searchterm == "" { - err = x.Find(&users) + err = s.Find(&users) return } - err = x. + err = s. Where("username LIKE ?", "%"+searchterm+"%"). Find(&users) return