Use db sessions everywere (#750)

Fix lint

Fix lint

Fix loading tasks with search

Fix loading lists

Fix loading task

Fix loading lists and namespaces

Fix tests

Fix user commands

Fix upload

Fix migration handlers

Fix all manual root handlers

Fix session in avatar

Fix session in list duplication & routes

Use sessions in migration code

Make sure the openid stuff uses a session

Add alias for db type in db package

Use sessions for file

Use a session for everything in users

Use a session for everything in users

Make sure to use a session everywhere in models

Create new session from db

Add session handling for user list

Add session handling for unsplash

Add session handling for teams and related

Add session handling for tasks and related entities

Add session handling for task reminders

Add session handling for task relations

Add session handling for task comments

Add session handling for task collections

Add session handling for task attachments

Add session handling for task assignees

Add session handling for saved filters

Add session handling for namespace and related types

Add session handling for namespace and related types

Add session handling for list users

Add session handling for list tests

Add session handling to list teams and related entities

Add session handling for link shares and related entities

Add session handling for labels and related entities

Add session handling for kanban and related entities

Add session handling for bulk task and related entities

Add session handling for lists and related entities

Add session configuration for web handler

Update web handler

Co-authored-by: kolaente <k@knt.li>
Reviewed-on: https://kolaente.dev/vikunja/api/pulls/750
Co-Authored-By: konrad <konrad@kola-entertainments.de>
Co-Committed-By: konrad <konrad@kola-entertainments.de>
This commit is contained in:
konrad 2020-12-23 15:32:28 +00:00
parent fa68e89c04
commit 8d1a09b5a2
107 changed files with 2428 additions and 1279 deletions

8
go.mod
View file

@ -18,7 +18,7 @@ module code.vikunja.io/api
require ( require (
4d63.com/tz v1.2.0 4d63.com/tz v1.2.0
code.vikunja.io/web v0.0.0-20200809154828-8767618f181f code.vikunja.io/web v0.0.0-20201223143420-588abb73703a
dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 // indirect dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 // indirect
gitea.com/xorm/xorm-redis-cache v0.2.0 gitea.com/xorm/xorm-redis-cache v0.2.0
github.com/adlio/trello v1.8.0 github.com/adlio/trello v1.8.0
@ -41,6 +41,7 @@ require (
github.com/go-sql-driver/mysql v1.5.0 github.com/go-sql-driver/mysql v1.5.0
github.com/go-testfixtures/testfixtures/v3 v3.4.1 github.com/go-testfixtures/testfixtures/v3 v3.4.1
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
github.com/golang/snappy v0.0.2 // indirect
github.com/gordonklaus/ineffassign v0.0.0-20201107091007-3b93a8888063 github.com/gordonklaus/ineffassign v0.0.0-20201107091007-3b93a8888063
github.com/iancoleman/strcase v0.1.2 github.com/iancoleman/strcase v0.1.2
github.com/imdario/mergo v0.3.11 github.com/imdario/mergo v0.3.11
@ -52,6 +53,7 @@ require (
github.com/lib/pq v1.9.0 github.com/lib/pq v1.9.0
github.com/magefile/mage v1.10.0 github.com/magefile/mage v1.10.0
github.com/mailru/easyjson v0.7.6 // indirect github.com/mailru/easyjson v0.7.6 // indirect
github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-sqlite3 v1.14.5 github.com/mattn/go-sqlite3 v1.14.5
github.com/mitchellh/mapstructure v1.3.2 // indirect github.com/mitchellh/mapstructure v1.3.2 // indirect
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
@ -76,8 +78,10 @@ require (
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad
golang.org/x/image v0.0.0-20201208152932-35266b937fa6 golang.org/x/image v0.0.0-20201208152932-35266b937fa6
golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5
golang.org/x/net v0.0.0-20201216054612-986b41b23924 // indirect
golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5 golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
golang.org/x/sys v0.0.0-20201223074533-0d417f636930 // indirect
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
@ -91,7 +95,7 @@ require (
src.techknowlogick.com/xormigrate v1.4.0 src.techknowlogick.com/xormigrate v1.4.0
xorm.io/builder v0.3.7 xorm.io/builder v0.3.7
xorm.io/core v0.7.3 xorm.io/core v0.7.3
xorm.io/xorm v1.0.2 xorm.io/xorm v1.0.5
) )
replace ( replace (

25
go.sum
View file

@ -38,8 +38,12 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo
cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
code.vikunja.io/web v0.0.0-20200809154828-8767618f181f h1:Zgtk9lbJkGbKjdTC78mg/c2uNkesxDJs1YUIL9zGvco= code.vikunja.io/web v0.0.0-20201218134444-505d0e77fac7 h1:iS3TFA+y1If6DEbqzad5Ge7TI1NxZr9BevC/dU4ygEo=
code.vikunja.io/web v0.0.0-20200809154828-8767618f181f/go.mod h1:vDWiCtftF6LNCCrem7mjstPWMgzLUvMW/L4YwIQ1Voo= code.vikunja.io/web v0.0.0-20201218134444-505d0e77fac7/go.mod h1:vDWiCtftF6LNCCrem7mjstPWMgzLUvMW/L4YwIQ1Voo=
code.vikunja.io/web v0.0.0-20201222144643-6fa2fb587215 h1:O5zMWgcnVDVLaQUawgdsv/jX/4SUUAvSedvRR+5+x2o=
code.vikunja.io/web v0.0.0-20201222144643-6fa2fb587215/go.mod h1:OgFO06HN1KpA4S7Dw/QAIeygiUPSeGJJn1ykz/sjZdU=
code.vikunja.io/web v0.0.0-20201223143420-588abb73703a h1:LaWCucY5Pp30EIMgGOvdVFNss5OhIAwrAO8PuFVRUfw=
code.vikunja.io/web v0.0.0-20201223143420-588abb73703a/go.mod h1:OgFO06HN1KpA4S7Dw/QAIeygiUPSeGJJn1ykz/sjZdU=
dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 h1:o4lAkfETerCnr1kF9/qwkwjICnU+YLHNDCM8h2xj7as= dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363 h1:o4lAkfETerCnr1kF9/qwkwjICnU+YLHNDCM8h2xj7as=
dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363/go.mod h1:WG7q7swWsS2f9PYpt5DoEP/EBYWx8We5UoRltn9vJl8= dmitri.shuralyov.com/go/generated v0.0.0-20170818220700-b1254a446363/go.mod h1:WG7q7swWsS2f9PYpt5DoEP/EBYWx8We5UoRltn9vJl8=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
@ -152,6 +156,7 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 h1:YcpmyvADGYw5LqMnHqSkyIELsHCGF6PkrmM31V8rF7o= github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 h1:YcpmyvADGYw5LqMnHqSkyIELsHCGF6PkrmM31V8rF7o=
github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM=
github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204 h1:tI48fqaIkxxYuIylVv1tdDfBp6836GKSfmmzgSyP1CY= github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204 h1:tI48fqaIkxxYuIylVv1tdDfBp6836GKSfmmzgSyP1CY=
github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
@ -293,6 +298,8 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pO
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw=
github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
@ -491,6 +498,7 @@ github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU=
github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.8.0 h1:9xohqzkUwzR4Ga4ivdTcawVS89YSDVxXMa3xJX3cGzg= github.com/lib/pq v1.8.0 h1:9xohqzkUwzR4Ga4ivdTcawVS89YSDVxXMa3xJX3cGzg=
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8= github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8=
@ -516,6 +524,8 @@ github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+v
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw=
github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.4 h1:bnP0vzxcAdeI1zdubAl5PjU6zsERjGZb7raWodagDYs= github.com/mattn/go-isatty v0.0.4 h1:bnP0vzxcAdeI1zdubAl5PjU6zsERjGZb7raWodagDYs=
github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
@ -852,8 +862,6 @@ golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de h1:ikNHVSjEfnvz6sxdSPCaPt
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201217014255-9d1352758620 h1:3wPMTskHO3+O6jqTEXyFcsnuxMQOqYSaHsDxcbUXpqA=
golang.org/x/crypto v0.0.0-20201217014255-9d1352758620/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -944,6 +952,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2l
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201216054612-986b41b23924 h1:QsnDpLLOKwHBBDa8nDws4DYNc/ryVW2vCpxCs09d4PY=
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
@ -1028,8 +1038,13 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuF
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e h1:AyodaIpKjppX+cBfTASF2E1US3H2JFBj920Ot3rtDjs= golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e h1:AyodaIpKjppX+cBfTASF2E1US3H2JFBj920Ot3rtDjs=
golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201221093633-bc327ba9c2f0 h1:n+DPcgTwkgWzIFpLmoimYR2K2b0Ga5+Os4kayIN0vGo=
golang.org/x/sys v0.0.0-20201221093633-bc327ba9c2f0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201223074533-0d417f636930 h1:vRgIt+nup/B/BwIS0g2oC0haq0iqbV3ZA+u6+0TlNCo=
golang.org/x/sys v0.0.0-20201223074533-0d417f636930/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M=
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -1284,3 +1299,5 @@ xorm.io/xorm v1.0.1 h1:/lITxpJtkZauNpdzj+L9CN/3OQxZaABrbergMcJu+Cw=
xorm.io/xorm v1.0.1/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY= xorm.io/xorm v1.0.1/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY=
xorm.io/xorm v1.0.2 h1:kZlCh9rqd1AzGwWitcrEEqHE1h1eaZE/ujU5/2tWEtg= xorm.io/xorm v1.0.2 h1:kZlCh9rqd1AzGwWitcrEEqHE1h1eaZE/ujU5/2tWEtg=
xorm.io/xorm v1.0.2/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY= xorm.io/xorm v1.0.2/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY=
xorm.io/xorm v1.0.5 h1:LRr5PfOUb4ODPR63YwbowkNDwcolT2LnkwP/TUaMaB0=
xorm.io/xorm v1.0.5/go.mod h1:uF9EtbhODq5kNWxMbnBEj8hRRZnlcNSz2t2N7HW/+A4=

View file

@ -24,6 +24,7 @@ import (
"strings" "strings"
"time" "time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/initialize" "code.vikunja.io/api/pkg/initialize"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
@ -31,6 +32,7 @@ import (
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term" "golang.org/x/term"
"xorm.io/xorm"
) )
var ( var (
@ -91,13 +93,13 @@ func getPasswordFromFlagOrInput() (pw string) {
return return
} }
func getUserFromArg(arg string) *user.User { func getUserFromArg(s *xorm.Session, arg string) *user.User {
id, err := strconv.ParseInt(arg, 10, 64) id, err := strconv.ParseInt(arg, 10, 64)
if err != nil { if err != nil {
log.Fatalf("Invalid user id: %s", err) log.Fatalf("Invalid user id: %s", err)
} }
u, err := user.GetUserByID(id) u, err := user.GetUserByID(s, id)
if err != nil { if err != nil {
log.Fatalf("Could not get user: %s", err) log.Fatalf("Could not get user: %s", err)
} }
@ -116,8 +118,16 @@ var userListCmd = &cobra.Command{
initialize.FullInit() initialize.FullInit()
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
users, err := user.ListUsers("") s := db.NewSession()
defer s.Close()
users, err := user.ListUsers(s, "")
if err != nil { if err != nil {
_ = s.Rollback()
log.Fatalf("Error getting users: %s", err)
}
if err := s.Commit(); err != nil {
log.Fatalf("Error getting users: %s", err) log.Fatalf("Error getting users: %s", err)
} }
@ -153,21 +163,30 @@ var userCreateCmd = &cobra.Command{
initialize.FullInit() initialize.FullInit()
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
s := db.NewSession()
defer s.Close()
u := &user.User{ u := &user.User{
Username: userFlagUsername, Username: userFlagUsername,
Email: userFlagEmail, Email: userFlagEmail,
Password: getPasswordFromFlagOrInput(), Password: getPasswordFromFlagOrInput(),
} }
newUser, err := user.CreateUser(u) newUser, err := user.CreateUser(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
log.Fatalf("Error creating new user: %s", err) log.Fatalf("Error creating new user: %s", err)
} }
err = models.CreateNewNamespaceForUser(newUser) err = models.CreateNewNamespaceForUser(s, newUser)
if err != nil { if err != nil {
_ = s.Rollback()
log.Fatalf("Error creating new namespace for user: %s", err) log.Fatalf("Error creating new namespace for user: %s", err)
} }
if err := s.Commit(); err != nil {
log.Fatalf("Error saving everything: %s", err)
}
fmt.Printf("\nUser was created successfully.\n") fmt.Printf("\nUser was created successfully.\n")
}, },
} }
@ -180,7 +199,10 @@ var userUpdateCmd = &cobra.Command{
initialize.FullInit() initialize.FullInit()
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
u := getUserFromArg(args[0]) s := db.NewSession()
defer s.Close()
u := getUserFromArg(s, args[0])
if userFlagUsername != "" { if userFlagUsername != "" {
u.Username = userFlagUsername u.Username = userFlagUsername
@ -192,11 +214,16 @@ var userUpdateCmd = &cobra.Command{
u.AvatarProvider = userFlagAvatar u.AvatarProvider = userFlagAvatar
} }
_, err := user.UpdateUser(u) _, err := user.UpdateUser(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
log.Fatalf("Error updating the user: %s", err) log.Fatalf("Error updating the user: %s", err)
} }
if err := s.Commit(); err != nil {
log.Fatalf("Error saving everything: %s", err)
}
fmt.Println("User updated successfully.") fmt.Println("User updated successfully.")
}, },
} }
@ -209,22 +236,31 @@ var userResetPasswordCmd = &cobra.Command{
}, },
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
u := getUserFromArg(args[0]) s := db.NewSession()
defer s.Close()
u := getUserFromArg(s, args[0])
// By default we reset as usual, only with specific flag directly. // By default we reset as usual, only with specific flag directly.
if userFlagResetPasswordDirectly { if userFlagResetPasswordDirectly {
err := user.UpdateUserPassword(u, getPasswordFromFlagOrInput()) err := user.UpdateUserPassword(s, u, getPasswordFromFlagOrInput())
if err != nil { if err != nil {
_ = s.Rollback()
log.Fatalf("Could not update user password: %s", err) log.Fatalf("Could not update user password: %s", err)
} }
fmt.Println("Password updated successfully.") fmt.Println("Password updated successfully.")
} else { } else {
err := user.RequestUserPasswordResetToken(u) err := user.RequestUserPasswordResetToken(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
log.Fatalf("Could not send password reset email: %s", err) log.Fatalf("Could not send password reset email: %s", err)
} }
fmt.Println("Password reset email sent successfully.") fmt.Println("Password reset email sent successfully.")
} }
if err := s.Commit(); err != nil {
log.Fatalf("Could not send password reset email: %s", err)
}
}, },
} }
@ -236,7 +272,10 @@ var userChangeEnabledCmd = &cobra.Command{
}, },
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
u := getUserFromArg(args[0]) s := db.NewSession()
defer s.Close()
u := getUserFromArg(s, args[0])
if userFlagEnableUser { if userFlagEnableUser {
u.IsActive = true u.IsActive = true
@ -245,11 +284,16 @@ var userChangeEnabledCmd = &cobra.Command{
} else { } else {
u.IsActive = !u.IsActive u.IsActive = !u.IsActive
} }
_, err := user.UpdateUser(u) _, err := user.UpdateUser(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
log.Fatalf("Could not enable the user") log.Fatalf("Could not enable the user")
} }
if err := s.Commit(); err != nil {
log.Fatalf("Error saving everything: %s", err)
}
fmt.Printf("User status successfully changed, user is now active: %t.\n", u.IsActive) fmt.Printf("User status successfully changed, user is now active: %t.\n", u.IsActive)
}, },
} }

View file

@ -31,6 +31,7 @@ import (
"xorm.io/core" "xorm.io/core"
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
_ "github.com/go-sql-driver/mysql" // Because. _ "github.com/go-sql-driver/mysql" // Because.
_ "github.com/lib/pq" // Because. _ "github.com/lib/pq" // Because.
@ -211,3 +212,13 @@ func WipeEverything() error {
return nil return nil
} }
// NewSession creates a new xorm session
func NewSession() *xorm.Session {
return x.NewSession()
}
// Type returns the db type of the currently configured db
func Type() schemas.DBType {
return x.Dialect().URI().DBType
}

View file

@ -22,6 +22,7 @@ import (
"time" "time"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/web" "code.vikunja.io/web"
"github.com/c2h5oh/datasize" "github.com/c2h5oh/datasize"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -93,27 +94,44 @@ func CreateWithMime(f io.Reader, realname string, realsize uint64, a web.Auth, m
Mime: mime, Mime: mime,
} }
_, err = x.Insert(file) s := db.NewSession()
defer s.Close()
_, err = s.Insert(file)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
// Save the file to storage with its new ID as path // Save the file to storage with its new ID as path
err = file.Save(f) err = file.Save(f)
if err != nil {
_ = s.Rollback()
return
}
return return
} }
// Delete removes a file from the DB and the file system // Delete removes a file from the DB and the file system
func (f *File) Delete() (err error) { func (f *File) Delete() (err error) {
deleted, err := x.Where("id = ?", f.ID).Delete(f) s := db.NewSession()
defer s.Close()
deleted, err := s.Where("id = ?", f.ID).Delete(f)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
if deleted == 0 { if deleted == 0 {
_ = s.Rollback()
return ErrFileDoesNotExist{FileID: f.ID} return ErrFileDoesNotExist{FileID: f.ID}
} }
err = afs.Remove(f.getFileName()) err = afs.Remove(f.getFileName())
if err != nil {
_ = s.Rollback()
return err
}
return return
} }

View file

@ -19,6 +19,7 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"github.com/imdario/mergo" "github.com/imdario/mergo"
"xorm.io/xorm"
) )
// BulkTask is the definition of a bulk update task // BulkTask is the definition of a bulk update task
@ -29,9 +30,9 @@ type BulkTask struct {
Task Task
} }
func (bt *BulkTask) checkIfTasksAreOnTheSameList() (err error) { func (bt *BulkTask) checkIfTasksAreOnTheSameList(s *xorm.Session) (err error) {
// Get the tasks // Get the tasks
err = bt.GetTasksByIDs() err = bt.GetTasksByIDs(s)
if err != nil { if err != nil {
return err return err
} }
@ -52,16 +53,16 @@ func (bt *BulkTask) checkIfTasksAreOnTheSameList() (err error) {
} }
// CanUpdate checks if a user is allowed to update a task // CanUpdate checks if a user is allowed to update a task
func (bt *BulkTask) CanUpdate(a web.Auth) (bool, error) { func (bt *BulkTask) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
err := bt.checkIfTasksAreOnTheSameList() err := bt.checkIfTasksAreOnTheSameList(s)
if err != nil { if err != nil {
return false, err return false, err
} }
// A user can update an task if he has write acces to its list // A user can update an task if he has write acces to its list
l := &List{ID: bt.Tasks[0].ListID} l := &List{ID: bt.Tasks[0].ListID}
return l.CanWrite(a) return l.CanWrite(s, a)
} }
// Update updates a bunch of tasks at once // Update updates a bunch of tasks at once
@ -77,23 +78,14 @@ func (bt *BulkTask) CanUpdate(a web.Auth) (bool, error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its list)" // @Failure 403 {object} web.HTTPError "The user does not have access to the task (aka its list)"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/bulk [post] // @Router /tasks/bulk [post]
func (bt *BulkTask) Update() (err error) { func (bt *BulkTask) Update(s *xorm.Session) (err error) {
sess := x.NewSession()
defer sess.Close()
err = sess.Begin()
if err != nil {
return
}
for _, oldtask := range bt.Tasks { for _, oldtask := range bt.Tasks {
// When a repeating task is marked as done, we update all deadlines and reminders and set it as undone // When a repeating task is marked as done, we update all deadlines and reminders and set it as undone
updateDone(oldtask, &bt.Task) updateDone(oldtask, &bt.Task)
// Update the assignees // Update the assignees
if err := oldtask.updateTaskAssignees(sess, bt.Assignees); err != nil { if err := oldtask.updateTaskAssignees(s, bt.Assignees); err != nil {
return err return err
} }
@ -109,7 +101,7 @@ func (bt *BulkTask) Update() (err error) {
oldtask.Done = false oldtask.Done = false
} }
_, err = sess.ID(oldtask.ID). _, err = s.ID(oldtask.ID).
Cols("title", Cols("title",
"description", "description",
"done", "done",
@ -121,15 +113,9 @@ func (bt *BulkTask) Update() (err error) {
"end_date"). "end_date").
Update(oldtask) Update(oldtask)
if err != nil { if err != nil {
_ = sess.Rollback()
return err return err
} }
} }
err = sess.Commit()
if err != nil {
return
}
return return
} }

View file

@ -57,18 +57,22 @@ func TestBulkTask_Update(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
bt := &BulkTask{ bt := &BulkTask{
IDs: tt.fields.IDs, IDs: tt.fields.IDs,
Tasks: tt.fields.Tasks, Tasks: tt.fields.Tasks,
Task: tt.fields.Task, Task: tt.fields.Task,
} }
allowed, _ := bt.CanUpdate(tt.fields.User) allowed, _ := bt.CanUpdate(s, tt.fields.User)
if !allowed != tt.wantForbidden { if !allowed != tt.wantForbidden {
t.Errorf("BulkTask.Update() want forbidden, got %v, want %v", allowed, tt.wantForbidden) t.Errorf("BulkTask.Update() want forbidden, got %v, want %v", allowed, tt.wantForbidden)
} }
if err := bt.Update(); (err != nil) != tt.wantErr { if err := bt.Update(s); (err != nil) != tt.wantErr {
t.Errorf("BulkTask.Update() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("BulkTask.Update() error = %v, wantErr %v", err, tt.wantErr)
} }
s.Close()
}) })
} }
} }

View file

@ -97,14 +97,14 @@ func getDefaultBucket(s *xorm.Session, listID int64) (bucket *Bucket, err error)
// @Success 200 {array} models.Bucket "The buckets with their tasks" // @Success 200 {array} models.Bucket "The buckets with their tasks"
// @Failure 500 {object} models.Message "Internal server error" // @Failure 500 {object} models.Message "Internal server error"
// @Router /lists/{id}/buckets [get] // @Router /lists/{id}/buckets [get]
func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (b *Bucket) ReadAll(s *xorm.Session, auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Note: I'm ignoring pagination for now since I've yet to figure out a way on how to make it work // Note: I'm ignoring pagination for now since I've yet to figure out a way on how to make it work
// I'll probably just don't do it and instead make individual tasks archivable. // I'll probably just don't do it and instead make individual tasks archivable.
// Get all buckets for this list // Get all buckets for this list
buckets := []*Bucket{} buckets := []*Bucket{}
err = x.Where("list_id = ?", b.ListID).Find(&buckets) err = s.Where("list_id = ?", b.ListID).Find(&buckets)
if err != nil { if err != nil {
return return
} }
@ -119,7 +119,7 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r
// Get all users // Get all users
users := make(map[int64]*user.User) users := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&users) err = s.In("id", userIDs).Find(&users)
if err != nil { if err != nil {
return return
} }
@ -132,7 +132,7 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r
b.TaskCollection.ListID = b.ListID b.TaskCollection.ListID = b.ListID
b.TaskCollection.OrderBy = []string{string(orderAscending)} b.TaskCollection.OrderBy = []string{string(orderAscending)}
b.TaskCollection.SortBy = []string{taskPropertyPosition} b.TaskCollection.SortBy = []string{taskPropertyPosition}
ts, _, _, err := b.TaskCollection.ReadAll(auth, "", -1, 0) ts, _, _, err := b.TaskCollection.ReadAll(s, auth, "", -1, 0)
if err != nil { if err != nil {
return return
} }
@ -168,10 +168,10 @@ func (b *Bucket) ReadAll(auth web.Auth, search string, page int, perPage int) (r
// @Failure 404 {object} web.HTTPError "The list does not exist." // @Failure 404 {object} web.HTTPError "The list does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/buckets [put] // @Router /lists/{id}/buckets [put]
func (b *Bucket) Create(a web.Auth) (err error) { func (b *Bucket) Create(s *xorm.Session, a web.Auth) (err error) {
b.CreatedByID = a.GetID() b.CreatedByID = a.GetID()
_, err = x.Insert(b) _, err = s.Insert(b)
return return
} }
@ -190,8 +190,8 @@ func (b *Bucket) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "The bucket does not exist." // @Failure 404 {object} web.HTTPError "The bucket does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/buckets/{bucketID} [post] // @Router /lists/{listID}/buckets/{bucketID} [post]
func (b *Bucket) Update() (err error) { func (b *Bucket) Update(s *xorm.Session) (err error) {
_, err = x.Where("id = ?", b.ID).Update(b) _, err = s.Where("id = ?", b.ID).Update(b)
return return
} }
@ -208,14 +208,11 @@ func (b *Bucket) Update() (err error) {
// @Failure 404 {object} web.HTTPError "The bucket does not exist." // @Failure 404 {object} web.HTTPError "The bucket does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/buckets/{bucketID} [delete] // @Router /lists/{listID}/buckets/{bucketID} [delete]
func (b *Bucket) Delete() (err error) { func (b *Bucket) Delete(s *xorm.Session) (err error) {
s := x.NewSession()
// Prevent removing the last bucket // Prevent removing the last bucket
total, err := s.Where("list_id = ?", b.ListID).Count(&Bucket{}) total, err := s.Where("list_id = ?", b.ListID).Count(&Bucket{})
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
if total <= 1 { if total <= 1 {
@ -228,23 +225,19 @@ func (b *Bucket) Delete() (err error) {
// Remove the bucket itself // Remove the bucket itself
_, err = s.Where("id = ?", b.ID).Delete(&Bucket{}) _, err = s.Where("id = ?", b.ID).Delete(&Bucket{})
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
// Get the default bucket // Get the default bucket
defaultBucket, err := getDefaultBucket(s, b.ListID) defaultBucket, err := getDefaultBucket(s, b.ListID)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
// Remove all associations of tasks to that bucket // Remove all associations of tasks to that bucket
_, err = s.Where("bucket_id = ?", b.ID).Cols("bucket_id").Update(&Task{BucketID: defaultBucket.ID}) _, err = s.
if err != nil { Where("bucket_id = ?", b.ID).
_ = s.Rollback() Cols("bucket_id").
Update(&Task{BucketID: defaultBucket.ID})
return return
}
return s.Commit()
} }

View file

@ -16,30 +16,33 @@
package models package models
import "code.vikunja.io/web" import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanCreate checks if a user can create a new bucket // CanCreate checks if a user can create a new bucket
func (b *Bucket) CanCreate(a web.Auth) (bool, error) { func (b *Bucket) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
l := &List{ID: b.ListID} l := &List{ID: b.ListID}
return l.CanWrite(a) return l.CanWrite(s, a)
} }
// CanUpdate checks if a user can update an existing bucket // CanUpdate checks if a user can update an existing bucket
func (b *Bucket) CanUpdate(a web.Auth) (bool, error) { func (b *Bucket) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return b.canDoBucket(a) return b.canDoBucket(s, a)
} }
// CanDelete checks if a user can delete an existing bucket // CanDelete checks if a user can delete an existing bucket
func (b *Bucket) CanDelete(a web.Auth) (bool, error) { func (b *Bucket) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return b.canDoBucket(a) return b.canDoBucket(s, a)
} }
// canDoBucket checks if the bucket exists and if the user has the right to act on it // canDoBucket checks if the bucket exists and if the user has the right to act on it
func (b *Bucket) canDoBucket(a web.Auth) (bool, error) { func (b *Bucket) canDoBucket(s *xorm.Session, a web.Auth) (bool, error) {
bb, err := getBucketByID(x.NewSession(), b.ID) bb, err := getBucketByID(s, b.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
l := &List{ID: bb.ListID} l := &List{ID: bb.ListID}
return l.CanWrite(a) return l.CanWrite(s, a)
} }

View file

@ -27,10 +27,12 @@ import (
func TestBucket_ReadAll(t *testing.T) { func TestBucket_ReadAll(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
testuser := &user.User{ID: 1} testuser := &user.User{ID: 1}
b := &Bucket{ListID: 1} b := &Bucket{ListID: 1}
bucketsInterface, _, _, err := b.ReadAll(testuser, "", 0, 0) bucketsInterface, _, _, err := b.ReadAll(s, testuser, "", 0, 0)
assert.NoError(t, err) assert.NoError(t, err)
buckets, is := bucketsInterface.([]*Bucket) buckets, is := bucketsInterface.([]*Bucket)
@ -66,6 +68,8 @@ func TestBucket_ReadAll(t *testing.T) {
}) })
t.Run("filtered", func(t *testing.T) { t.Run("filtered", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
testuser := &user.User{ID: 1} testuser := &user.User{ID: 1}
b := &Bucket{ b := &Bucket{
@ -76,7 +80,7 @@ func TestBucket_ReadAll(t *testing.T) {
FilterValue: []string{"done"}, FilterValue: []string{"done"},
}, },
} }
bucketsInterface, _, _, err := b.ReadAll(testuser, "", 0, 0) bucketsInterface, _, _, err := b.ReadAll(s, testuser, "", 0, 0)
assert.NoError(t, err) assert.NoError(t, err)
buckets := bucketsInterface.([]*Bucket) buckets := bucketsInterface.([]*Bucket)
@ -88,16 +92,21 @@ func TestBucket_ReadAll(t *testing.T) {
func TestBucket_Delete(t *testing.T) { func TestBucket_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
b := &Bucket{ b := &Bucket{
ID: 2, // The second bucket only has 3 tasks ID: 2, // The second bucket only has 3 tasks
ListID: 1, ListID: 1,
} }
err := b.Delete() err := b.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
// Assert all tasks have been moved to bucket 1 as that one is the first // Assert all tasks have been moved to bucket 1 as that one is the first
tasks := []*Task{} tasks := []*Task{}
err = x.Where("bucket_id = ?", 1).Find(&tasks) err = s.Where("bucket_id = ?", 1).Find(&tasks)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, tasks, 15) assert.Len(t, tasks, 15)
db.AssertMissing(t, "buckets", map[string]interface{}{ db.AssertMissing(t, "buckets", map[string]interface{}{
@ -107,13 +116,19 @@ func TestBucket_Delete(t *testing.T) {
}) })
t.Run("last bucket in list", func(t *testing.T) { t.Run("last bucket in list", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
b := &Bucket{ b := &Bucket{
ID: 34, ID: 34,
ListID: 18, ListID: 18,
} }
err := b.Delete() err := b.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrCannotRemoveLastBucket(err)) assert.True(t, IsErrCannotRemoveLastBucket(err))
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "buckets", map[string]interface{}{ db.AssertExists(t, "buckets", map[string]interface{}{
"id": 34, "id": 34,
"list_id": 18, "list_id": 18,

View file

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// Label represents a label // Label represents a label
@ -64,7 +65,7 @@ func (Label) TableName() string {
// @Failure 400 {object} web.HTTPError "Invalid label object provided." // @Failure 400 {object} web.HTTPError "Invalid label object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /labels [put] // @Router /labels [put]
func (l *Label) Create(a web.Auth) (err error) { func (l *Label) Create(s *xorm.Session, a web.Auth) (err error) {
u, err := user.GetFromAuth(a) u, err := user.GetFromAuth(a)
if err != nil { if err != nil {
return return
@ -73,7 +74,7 @@ func (l *Label) Create(a web.Auth) (err error) {
l.CreatedBy = u l.CreatedBy = u
l.CreatedByID = u.ID l.CreatedByID = u.ID
_, err = x.Insert(l) _, err = s.Insert(l)
return return
} }
@ -92,8 +93,8 @@ func (l *Label) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "Label not found." // @Failure 404 {object} web.HTTPError "Label not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /labels/{id} [put] // @Router /labels/{id} [put]
func (l *Label) Update() (err error) { func (l *Label) Update(s *xorm.Session) (err error) {
_, err = x. _, err = s.
ID(l.ID). ID(l.ID).
Cols( Cols(
"title", "title",
@ -105,7 +106,7 @@ func (l *Label) Update() (err error) {
return return
} }
err = l.ReadOne() err = l.ReadOne(s)
return return
} }
@ -122,8 +123,8 @@ func (l *Label) Update() (err error) {
// @Failure 404 {object} web.HTTPError "Label not found." // @Failure 404 {object} web.HTTPError "Label not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /labels/{id} [delete] // @Router /labels/{id} [delete]
func (l *Label) Delete() (err error) { func (l *Label) Delete(s *xorm.Session) (err error) {
_, err = x.ID(l.ID).Delete(&Label{}) _, err = s.ID(l.ID).Delete(&Label{})
return err return err
} }
@ -140,7 +141,7 @@ func (l *Label) Delete() (err error) {
// @Success 200 {array} models.Label "The labels" // @Success 200 {array} models.Label "The labels"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /labels [get] // @Router /labels [get]
func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls interface{}, resultCount int, numberOfEntries int64, err error) { func (l *Label) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (ls interface{}, resultCount int, numberOfEntries int64, err error) {
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return nil, 0, 0, ErrGenericForbidden{} return nil, 0, 0, ErrGenericForbidden{}
} }
@ -148,12 +149,12 @@ func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls in
u := &user.User{ID: a.GetID()} u := &user.User{ID: a.GetID()}
// Get all tasks // Get all tasks
taskIDs, err := getUserTaskIDs(u) taskIDs, err := getUserTaskIDs(s, u)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
return getLabelsByTaskIDs(&LabelByTaskIDsOptions{ return getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
Search: search, Search: search,
User: u, User: u,
TaskIDs: taskIDs, TaskIDs: taskIDs,
@ -177,25 +178,25 @@ func (l *Label) ReadAll(a web.Auth, search string, page int, perPage int) (ls in
// @Failure 404 {object} web.HTTPError "Label not found" // @Failure 404 {object} web.HTTPError "Label not found"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /labels/{id} [get] // @Router /labels/{id} [get]
func (l *Label) ReadOne() (err error) { func (l *Label) ReadOne(s *xorm.Session) (err error) {
label, err := getLabelByIDSimple(l.ID) label, err := getLabelByIDSimple(s, l.ID)
if err != nil { if err != nil {
return err return err
} }
*l = *label *l = *label
user, err := user.GetUserByID(l.CreatedByID) u, err := user.GetUserByID(s, l.CreatedByID)
if err != nil { if err != nil {
return err return err
} }
l.CreatedBy = user l.CreatedBy = u
return return
} }
func getLabelByIDSimple(labelID int64) (*Label, error) { func getLabelByIDSimple(s *xorm.Session, labelID int64) (*Label, error) {
label := Label{} label := Label{}
exists, err := x.ID(labelID).Get(&label) exists, err := s.ID(labelID).Get(&label)
if err != nil { if err != nil {
return &label, err return &label, err
} }
@ -207,18 +208,21 @@ func getLabelByIDSimple(labelID int64) (*Label, error) {
} }
// Helper method to get all task ids a user has // Helper method to get all task ids a user has
func getUserTaskIDs(u *user.User) (taskIDs []int64, err error) { func getUserTaskIDs(s *xorm.Session, u *user.User) (taskIDs []int64, err error) {
// Get all lists // Get all lists
lists, _, _, err := getRawListsForUser(&listOptions{ lists, _, _, err := getRawListsForUser(
s,
&listOptions{
user: u, user: u,
page: -1, page: -1,
}) },
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tasks, _, _, err := getRawTasksForLists(lists, u, &taskOptions{ tasks, _, _, err := getRawTasksForLists(s, lists, u, &taskOptions{
page: -1, page: -1,
perPage: 0, perPage: 0,
}) })

View file

@ -20,26 +20,27 @@ import (
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm"
) )
// CanUpdate checks if a user can update a label // CanUpdate checks if a user can update a label
func (l *Label) CanUpdate(a web.Auth) (bool, error) { func (l *Label) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return l.isLabelOwner(a) // Only owners should be allowed to update a label return l.isLabelOwner(s, a) // Only owners should be allowed to update a label
} }
// CanDelete checks if a user can delete a label // CanDelete checks if a user can delete a label
func (l *Label) CanDelete(a web.Auth) (bool, error) { func (l *Label) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return l.isLabelOwner(a) // Only owners should be allowed to delete a label return l.isLabelOwner(s, a) // Only owners should be allowed to delete a label
} }
// CanRead checks if a user can read a label // CanRead checks if a user can read a label
func (l *Label) CanRead(a web.Auth) (bool, int, error) { func (l *Label) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
return l.hasAccessToLabel(a) return l.hasAccessToLabel(s, a)
} }
// CanCreate checks if the user can create a label // CanCreate checks if the user can create a label
// Currently a dummy. // Currently a dummy.
func (l *Label) CanCreate(a web.Auth) (bool, error) { func (l *Label) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
@ -47,13 +48,13 @@ func (l *Label) CanCreate(a web.Auth) (bool, error) {
return true, nil return true, nil
} }
func (l *Label) isLabelOwner(a web.Auth) (bool, error) { func (l *Label) isLabelOwner(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
lorig, err := getLabelByIDSimple(l.ID) lorig, err := getLabelByIDSimple(s, l.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -61,19 +62,19 @@ func (l *Label) isLabelOwner(a web.Auth) (bool, error) {
} }
// Helper method to check if a user can see a specific label // Helper method to check if a user can see a specific label
func (l *Label) hasAccessToLabel(a web.Auth) (has bool, maxRight int, err error) { func (l *Label) hasAccessToLabel(s *xorm.Session, a web.Auth) (has bool, maxRight int, err error) {
// TODO: add an extra check for link share handling // TODO: add an extra check for link share handling
// Get all tasks // Get all tasks
taskIDs, err := getUserTaskIDs(&user.User{ID: a.GetID()}) taskIDs, err := getUserTaskIDs(s, &user.User{ID: a.GetID()})
if err != nil { if err != nil {
return false, 0, err return false, 0, err
} }
// Get all labels associated with these tasks // Get all labels associated with these tasks
ll := &LabelTask{} ll := &LabelTask{}
has, err = x.Table("labels"). has, err = s.Table("labels").
Select("label_task.*"). Select("label_task.*").
Join("LEFT", "label_task", "label_task.label_id = labels.id"). Join("LEFT", "label_task", "label_task.label_id = labels.id").
Where("label_task.label_id is not null OR labels.created_by_id = ?", a.GetID()). Where("label_task.label_id is not null OR labels.created_by_id = ?", a.GetID()).
@ -87,7 +88,7 @@ func (l *Label) hasAccessToLabel(a web.Auth) (has bool, maxRight int, err error)
// Since the right depends on the task the label is associated with, we need to check that too. // Since the right depends on the task the label is associated with, we need to check that too.
if ll.TaskID > 0 { if ll.TaskID > 0 {
t := &Task{ID: ll.TaskID} t := &Task{ID: ll.TaskID}
_, maxRight, err = t.CanRead(a) _, maxRight, err = t.CanRead(s, a)
if err != nil { if err != nil {
return return
} }

View file

@ -22,10 +22,10 @@ import (
"time" "time"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm"
) )
// LabelTask represents a relation between a label and a task // LabelTask represents a relation between a label and a task
@ -61,8 +61,8 @@ func (LabelTask) TableName() string {
// @Failure 404 {object} web.HTTPError "Label not found." // @Failure 404 {object} web.HTTPError "Label not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{task}/labels/{label} [delete] // @Router /tasks/{task}/labels/{label} [delete]
func (lt *LabelTask) Delete() (err error) { func (lt *LabelTask) Delete(s *xorm.Session) (err error) {
_, err = x.Delete(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) _, err = s.Delete(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
return err return err
} }
@ -81,9 +81,9 @@ func (lt *LabelTask) Delete() (err error) {
// @Failure 404 {object} web.HTTPError "The label does not exist." // @Failure 404 {object} web.HTTPError "The label does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{task}/labels [put] // @Router /tasks/{task}/labels [put]
func (lt *LabelTask) Create(a web.Auth) (err error) { func (lt *LabelTask) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the label is already added // Check if the label is already added
exists, err := x.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) exists, err := s.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
if err != nil { if err != nil {
return err return err
} }
@ -92,12 +92,12 @@ func (lt *LabelTask) Create(a web.Auth) (err error) {
} }
// Insert it // Insert it
_, err = x.Insert(lt) _, err = s.Insert(lt)
if err != nil { if err != nil {
return err return err
} }
err = updateListByTaskID(lt.TaskID) err = updateListByTaskID(s, lt.TaskID)
return return
} }
@ -115,10 +115,10 @@ func (lt *LabelTask) Create(a web.Auth) (err error) {
// @Success 200 {array} models.Label "The labels" // @Success 200 {array} models.Label "The labels"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{task}/labels [get] // @Router /tasks/{task}/labels [get]
func (lt *LabelTask) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (lt *LabelTask) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has the right to see the task // Check if the user has the right to see the task
task := Task{ID: lt.TaskID} task := Task{ID: lt.TaskID}
canRead, _, err := task.CanRead(a) canRead, _, err := task.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -126,7 +126,7 @@ func (lt *LabelTask) ReadAll(a web.Auth, search string, page int, perPage int) (
return nil, 0, 0, ErrNoRightToSeeTask{lt.TaskID, a.GetID()} return nil, 0, 0, ErrNoRightToSeeTask{lt.TaskID, a.GetID()}
} }
return getLabelsByTaskIDs(&LabelByTaskIDsOptions{ return getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
User: &user.User{ID: a.GetID()}, User: &user.User{ID: a.GetID()},
Search: search, Search: search,
Page: page, Page: page,
@ -153,7 +153,7 @@ type LabelByTaskIDsOptions struct {
// Helper function to get all labels for a set of tasks // Helper function to get all labels for a set of tasks
// Used when getting all labels for one task as well when getting all lables // Used when getting all labels for one task as well when getting all lables
func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, resultCount int, totalEntries int64, err error) { func getLabelsByTaskIDs(s *xorm.Session, opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, resultCount int, totalEntries int64, err error) {
// We still need the task ID when we want to get all labels for a task, but because of this, we get the same label // We still need the task ID when we want to get all labels for a task, but because of this, we get the same label
// multiple times when it is associated to more than one task. // multiple times when it is associated to more than one task.
// Because of this whole thing, we need this extra switch here to only group by Task IDs if needed. // Because of this whole thing, we need this extra switch here to only group by Task IDs if needed.
@ -194,7 +194,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
limit, start := getLimitFromPageIndex(opts.Page, opts.PerPage) limit, start := getLimitFromPageIndex(opts.Page, opts.PerPage)
query := x.Table("labels"). query := s.Table("labels").
Select(selectStmt). Select(selectStmt).
Join("LEFT", "label_task", "label_task.label_id = labels.id"). Join("LEFT", "label_task", "label_task.label_id = labels.id").
Where(cond). Where(cond).
@ -214,7 +214,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
userids = append(userids, l.CreatedByID) userids = append(userids, l.CreatedByID)
} }
users := make(map[int64]*user.User) users := make(map[int64]*user.User)
err = x.In("id", userids).Find(&users) err = s.In("id", userids).Find(&users)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -230,7 +230,7 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
} }
// Get the total number of entries // Get the total number of entries
totalEntries, err = x.Table("labels"). totalEntries, err = s.Table("labels").
Select("count(DISTINCT labels.id)"). Select("count(DISTINCT labels.id)").
Join("LEFT", "label_task", "label_task.label_id = labels.id"). Join("LEFT", "label_task", "label_task.label_id = labels.id").
Where(cond). Where(cond).
@ -244,11 +244,11 @@ func getLabelsByTaskIDs(opts *LabelByTaskIDsOptions) (ls []*labelWithTaskID, res
} }
// Create or update a bunch of task labels // Create or update a bunch of task labels
func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) { func (t *Task) updateTaskLabels(s *xorm.Session, creator web.Auth, labels []*Label) (err error) {
// If we don't have any new labels, delete everything right away. Saves us some hassle. // If we don't have any new labels, delete everything right away. Saves us some hassle.
if len(labels) == 0 && len(t.Labels) > 0 { if len(labels) == 0 && len(t.Labels) > 0 {
_, err = x.Where("task_id = ?", t.ID). _, err = s.Where("task_id = ?", t.ID).
Delete(LabelTask{}) Delete(LabelTask{})
return err return err
} }
@ -289,7 +289,7 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) {
// Delete all labels not passed // Delete all labels not passed
if len(labelsToDelete) > 0 { if len(labelsToDelete) > 0 {
_, err = x.In("label_id", labelsToDelete). _, err = s.In("label_id", labelsToDelete).
And("task_id = ?", t.ID). And("task_id = ?", t.ID).
Delete(LabelTask{}) Delete(LabelTask{})
if err != nil { if err != nil {
@ -306,13 +306,13 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) {
} }
// Add the new label // Add the new label
label, err := getLabelByIDSimple(l.ID) label, err := getLabelByIDSimple(s, l.ID)
if err != nil { if err != nil {
return err return err
} }
// Check if the user has the rights to see the label he is about to add // Check if the user has the rights to see the label he is about to add
hasAccessToLabel, _, err := label.hasAccessToLabel(creator) hasAccessToLabel, _, err := label.hasAccessToLabel(s, creator)
if err != nil { if err != nil {
return err return err
} }
@ -322,14 +322,14 @@ func (t *Task) updateTaskLabels(creator web.Auth, labels []*Label) (err error) {
} }
// Insert it // Insert it
_, err = x.Insert(&LabelTask{LabelID: l.ID, TaskID: t.ID}) _, err = s.Insert(&LabelTask{LabelID: l.ID, TaskID: t.ID})
if err != nil { if err != nil {
return err return err
} }
t.Labels = append(t.Labels, label) t.Labels = append(t.Labels, label)
} }
err = updateListLastUpdated(&List{ID: t.ListID}) err = updateListLastUpdated(s, &List{ID: t.ListID})
return return
} }
@ -356,12 +356,12 @@ type LabelTaskBulk struct {
// @Failure 400 {object} web.HTTPError "Invalid label object provided." // @Failure 400 {object} web.HTTPError "Invalid label object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/labels/bulk [post] // @Router /tasks/{taskID}/labels/bulk [post]
func (ltb *LabelTaskBulk) Create(a web.Auth) (err error) { func (ltb *LabelTaskBulk) Create(s *xorm.Session, a web.Auth) (err error) {
task, err := GetTaskByIDSimple(ltb.TaskID) task, err := GetTaskByIDSimple(s, ltb.TaskID)
if err != nil { if err != nil {
return return
} }
labels, _, _, err := getLabelsByTaskIDs(&LabelByTaskIDsOptions{ labels, _, _, err := getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
TaskIDs: []int64{ltb.TaskID}, TaskIDs: []int64{ltb.TaskID},
}) })
if err != nil { if err != nil {
@ -370,5 +370,5 @@ func (ltb *LabelTaskBulk) Create(a web.Auth) (err error) {
for _, l := range labels { for _, l := range labels {
task.Labels = append(task.Labels, &l.Label) task.Labels = append(task.Labels, &l.Label)
} }
return task.updateTaskLabels(a, ltb.Labels) return task.updateTaskLabels(s, a, ltb.Labels)
} }

View file

@ -18,21 +18,22 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if a user can add a label to a task // CanCreate checks if a user can add a label to a task
func (lt *LabelTask) CanCreate(a web.Auth) (bool, error) { func (lt *LabelTask) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
label, err := getLabelByIDSimple(lt.LabelID) label, err := getLabelByIDSimple(s, lt.LabelID)
if err != nil { if err != nil {
return false, err return false, err
} }
hasAccessTolabel, _, err := label.hasAccessToLabel(a) hasAccessTolabel, _, err := label.hasAccessToLabel(s, a)
if err != nil || !hasAccessTolabel { // If the user doesn't have access to the label, we can error out here if err != nil || !hasAccessTolabel { // If the user doesn't have access to the label, we can error out here
return false, err return false, err
} }
canDoLabelTask, err := canDoLabelTask(lt.TaskID, a) canDoLabelTask, err := canDoLabelTask(s, lt.TaskID, a)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -41,8 +42,8 @@ func (lt *LabelTask) CanCreate(a web.Auth) (bool, error) {
} }
// CanDelete checks if a user can delete a label from a task // CanDelete checks if a user can delete a label from a task
func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) { func (lt *LabelTask) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
canDoLabelTask, err := canDoLabelTask(lt.TaskID, a) canDoLabelTask, err := canDoLabelTask(s, lt.TaskID, a)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -52,7 +53,7 @@ func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) {
// We don't care here if the label exists or not. The only relevant thing here is if the relation already exists, // We don't care here if the label exists or not. The only relevant thing here is if the relation already exists,
// throw an error. // throw an error.
exists, err := x.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID}) exists, err := s.Exist(&LabelTask{LabelID: lt.LabelID, TaskID: lt.TaskID})
if err != nil { if err != nil {
return false, err return false, err
} }
@ -60,18 +61,18 @@ func (lt *LabelTask) CanDelete(a web.Auth) (bool, error) {
} }
// CanCreate determines if a user can update a labeltask // CanCreate determines if a user can update a labeltask
func (ltb *LabelTaskBulk) CanCreate(a web.Auth) (bool, error) { func (ltb *LabelTaskBulk) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return canDoLabelTask(ltb.TaskID, a) return canDoLabelTask(s, ltb.TaskID, a)
} }
// Helper function to check if a user can write to a task // Helper function to check if a user can write to a task
// + is able to see the label // + is able to see the label
// always the same check for either deleting or adding a label to a task // always the same check for either deleting or adding a label to a task
func canDoLabelTask(taskID int64, a web.Auth) (bool, error) { func canDoLabelTask(s *xorm.Session, taskID int64, a web.Auth) (bool, error) {
// A user can add a label to a task if he can write to the task // A user can add a label to a task if he can write to the task
task, err := GetTaskByIDSimple(taskID) task, err := GetTaskByIDSimple(s, taskID)
if err != nil { if err != nil {
return false, err return false, err
} }
return task.CanUpdate(a) return task.CanUpdate(s, a)
} }

View file

@ -91,6 +91,7 @@ func TestLabelTask_ReadAll(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
l := &LabelTask{ l := &LabelTask{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -100,7 +101,7 @@ func TestLabelTask_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
gotLabels, _, _, err := l.ReadAll(tt.args.a, tt.args.search, tt.args.page, 0) gotLabels, _, _, err := l.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 0)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("LabelTask.ReadAll() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("LabelTask.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -111,6 +112,8 @@ func TestLabelTask_ReadAll(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(gotLabels, tt.wantLabels); !equal { if diff, equal := messagediff.PrettyDiff(gotLabels, tt.wantLabels); !equal {
t.Errorf("LabelTask.ReadAll() = %v, want %v, diff: %v", l, tt.wantLabels, diff) t.Errorf("LabelTask.ReadAll() = %v, want %v, diff: %v", l, tt.wantLabels, diff)
} }
s.Close()
}) })
} }
} }
@ -186,6 +189,8 @@ func TestLabelTask_Create(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
l := &LabelTask{ l := &LabelTask{
ID: tt.fields.ID, ID: tt.fields.ID,
TaskID: tt.fields.TaskID, TaskID: tt.fields.TaskID,
@ -194,11 +199,11 @@ func TestLabelTask_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
allowed, _ := l.CanCreate(tt.args.a) allowed, _ := l.CanCreate(s, tt.args.a)
if !allowed && !tt.wantForbidden { if !allowed && !tt.wantForbidden {
t.Errorf("LabelTask.CanCreate() forbidden, want %v", tt.wantForbidden) t.Errorf("LabelTask.CanCreate() forbidden, want %v", tt.wantForbidden)
} }
err := l.Create(tt.args.a) err := l.Create(s, tt.args.a)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("LabelTask.Create() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("LabelTask.Create() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -212,6 +217,7 @@ func TestLabelTask_Create(t *testing.T) {
"label_id": l.LabelID, "label_id": l.LabelID,
}, false) }, false)
} }
s.Close()
}) })
} }
} }
@ -282,6 +288,8 @@ func TestLabelTask_Delete(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
l := &LabelTask{ l := &LabelTask{
ID: tt.fields.ID, ID: tt.fields.ID,
TaskID: tt.fields.TaskID, TaskID: tt.fields.TaskID,
@ -290,11 +298,11 @@ func TestLabelTask_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
allowed, _ := l.CanDelete(tt.auth) allowed, _ := l.CanDelete(s, tt.auth)
if !allowed && !tt.wantForbidden { if !allowed && !tt.wantForbidden {
t.Errorf("LabelTask.CanDelete() forbidden, want %v", tt.wantForbidden) t.Errorf("LabelTask.CanDelete() forbidden, want %v", tt.wantForbidden)
} }
err := l.Delete() err := l.Delete(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("LabelTask.Delete() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("LabelTask.Delete() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -307,6 +315,7 @@ func TestLabelTask_Delete(t *testing.T) {
"task_id": l.TaskID, "task_id": l.TaskID,
}) })
} }
s.Close()
}) })
} }
} }

View file

@ -133,7 +133,8 @@ func TestLabel_ReadAll(t *testing.T) {
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
gotLs, _, _, err := l.ReadAll(tt.args.a, tt.args.search, tt.args.page, 0) s := db.NewSession()
gotLs, _, _, err := l.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 0)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Label.ReadAll() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Label.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -141,6 +142,7 @@ func TestLabel_ReadAll(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(gotLs, tt.wantLs); !equal { if diff, equal := messagediff.PrettyDiff(gotLs, tt.wantLs); !equal {
t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", gotLs, tt.wantLs, diff) t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", gotLs, tt.wantLs, diff)
} }
s.Close()
}) })
} }
} }
@ -249,11 +251,13 @@ func TestLabel_ReadOne(t *testing.T) {
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
allowed, _, _ := l.CanRead(tt.auth) s := db.NewSession()
allowed, _, _ := l.CanRead(s, tt.auth)
if !allowed && !tt.wantForbidden { if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanRead() forbidden, want %v", tt.wantForbidden) t.Errorf("Label.CanRead() forbidden, want %v", tt.wantForbidden)
} }
err := l.ReadOne() err := l.ReadOne(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Label.ReadOne() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Label.ReadOne() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -263,6 +267,8 @@ func TestLabel_ReadOne(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(l, tt.want); !equal && !tt.wantErr && !tt.wantForbidden { if diff, equal := messagediff.PrettyDiff(l, tt.want); !equal && !tt.wantErr && !tt.wantForbidden {
t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", l, tt.want, diff) t.Errorf("Label.ReadAll() = %v, want %v, diff: %v", l, tt.want, diff)
} }
s.Close()
}) })
} }
} }
@ -316,11 +322,12 @@ func TestLabel_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
allowed, _ := l.CanCreate(tt.args.a) s := db.NewSession()
allowed, _ := l.CanCreate(s, tt.args.a)
if !allowed && !tt.wantForbidden { if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanCreate() forbidden, want %v", tt.wantForbidden) t.Errorf("Label.CanCreate() forbidden, want %v", tt.wantForbidden)
} }
if err := l.Create(tt.args.a); (err != nil) != tt.wantErr { if err := l.Create(s, tt.args.a); (err != nil) != tt.wantErr {
t.Errorf("Label.Create() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Label.Create() error = %v, wantErr %v", err, tt.wantErr)
} }
if !tt.wantErr { if !tt.wantErr {
@ -331,6 +338,7 @@ func TestLabel_Create(t *testing.T) {
"hex_color": l.HexColor, "hex_color": l.HexColor,
}, false) }, false)
} }
_ = s.Close()
}) })
} }
} }
@ -406,11 +414,12 @@ func TestLabel_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
allowed, _ := l.CanUpdate(tt.auth) s := db.NewSession()
allowed, _ := l.CanUpdate(s, tt.auth)
if !allowed && !tt.wantForbidden { if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanUpdate() forbidden, want %v", tt.wantForbidden) t.Errorf("Label.CanUpdate() forbidden, want %v", tt.wantForbidden)
} }
if err := l.Update(); (err != nil) != tt.wantErr { if err := l.Update(s); (err != nil) != tt.wantErr {
t.Errorf("Label.Update() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Label.Update() error = %v, wantErr %v", err, tt.wantErr)
} }
if !tt.wantErr && !tt.wantForbidden { if !tt.wantErr && !tt.wantForbidden {
@ -419,6 +428,7 @@ func TestLabel_Update(t *testing.T) {
"title": tt.fields.Title, "title": tt.fields.Title,
}, false) }, false)
} }
_ = s.Close()
}) })
} }
} }
@ -490,11 +500,12 @@ func TestLabel_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
allowed, _ := l.CanDelete(tt.auth) s := db.NewSession()
allowed, _ := l.CanDelete(s, tt.auth)
if !allowed && !tt.wantForbidden { if !allowed && !tt.wantForbidden {
t.Errorf("Label.CanDelete() forbidden, want %v", tt.wantForbidden) t.Errorf("Label.CanDelete() forbidden, want %v", tt.wantForbidden)
} }
if err := l.Delete(); (err != nil) != tt.wantErr { if err := l.Delete(s); (err != nil) != tt.wantErr {
t.Errorf("Label.Delete() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Label.Delete() error = %v, wantErr %v", err, tt.wantErr)
} }
if !tt.wantErr && !tt.wantForbidden { if !tt.wantErr && !tt.wantForbidden {
@ -502,6 +513,7 @@ func TestLabel_Delete(t *testing.T) {
"id": l.ID, "id": l.ID,
}) })
} }
_ = s.Close()
}) })
} }
} }

View file

@ -24,6 +24,7 @@ import (
"code.vikunja.io/api/pkg/utils" "code.vikunja.io/api/pkg/utils"
"code.vikunja.io/web" "code.vikunja.io/web"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"xorm.io/xorm"
) )
// SharingType holds the sharing type // SharingType holds the sharing type
@ -99,7 +100,7 @@ func GetLinkShareFromClaims(claims jwt.MapClaims) (share *LinkSharing, err error
// @Failure 404 {object} web.HTTPError "The list does not exist." // @Failure 404 {object} web.HTTPError "The list does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares [put] // @Router /lists/{list}/shares [put]
func (share *LinkSharing) Create(a web.Auth) (err error) { func (share *LinkSharing) Create(s *xorm.Session, a web.Auth) (err error) {
err = share.Right.isValid() err = share.Right.isValid()
if err != nil { if err != nil {
@ -108,7 +109,7 @@ func (share *LinkSharing) Create(a web.Auth) (err error) {
share.SharedByID = a.GetID() share.SharedByID = a.GetID()
share.Hash = utils.MakeRandomString(40) share.Hash = utils.MakeRandomString(40)
_, err = x.Insert(share) _, err = s.Insert(share)
share.SharedBy, _ = user.GetFromAuth(a) share.SharedBy, _ = user.GetFromAuth(a)
return return
} }
@ -127,8 +128,8 @@ func (share *LinkSharing) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "Share Link not found." // @Failure 404 {object} web.HTTPError "Share Link not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares/{share} [get] // @Router /lists/{list}/shares/{share} [get]
func (share *LinkSharing) ReadOne() (err error) { func (share *LinkSharing) ReadOne(s *xorm.Session) (err error) {
exists, err := x.Where("id = ?", share.ID).Get(share) exists, err := s.Where("id = ?", share.ID).Get(share)
if err != nil { if err != nil {
return err return err
} }
@ -152,9 +153,9 @@ func (share *LinkSharing) ReadOne() (err error) {
// @Success 200 {array} models.LinkSharing "The share links" // @Success 200 {array} models.LinkSharing "The share links"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares [get] // @Router /lists/{list}/shares [get]
func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { func (share *LinkSharing) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
list := &List{ID: share.ListID} list := &List{ID: share.ListID}
can, _, err := list.CanRead(a) can, _, err := list.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -165,7 +166,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
var shares []*LinkSharing var shares []*LinkSharing
query := x. query := s.
Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%") Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%")
if limit > 0 { if limit > 0 {
query = query.Limit(limit, start) query = query.Limit(limit, start)
@ -182,7 +183,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
} }
users := make(map[int64]*user.User) users := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&users) err = s.In("id", userIDs).Find(&users)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -192,7 +193,7 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
} }
// Total count // Total count
totalItems, err = x. totalItems, err = s.
Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%"). Where("list_id = ? AND hash LIKE ?", share.ListID, "%"+search+"%").
Count(&LinkSharing{}) Count(&LinkSharing{})
if err != nil { if err != nil {
@ -216,15 +217,15 @@ func (share *LinkSharing) ReadAll(a web.Auth, search string, page int, perPage i
// @Failure 404 {object} web.HTTPError "Share Link not found." // @Failure 404 {object} web.HTTPError "Share Link not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{list}/shares/{share} [delete] // @Router /lists/{list}/shares/{share} [delete]
func (share *LinkSharing) Delete() (err error) { func (share *LinkSharing) Delete(s *xorm.Session) (err error) {
_, err = x.Where("id = ?", share.ID).Delete(share) _, err = s.Where("id = ?", share.ID).Delete(share)
return return
} }
// GetLinkShareByHash returns a link share by hash // GetLinkShareByHash returns a link share by hash
func GetLinkShareByHash(hash string) (share *LinkSharing, err error) { func GetLinkShareByHash(s *xorm.Session, hash string) (share *LinkSharing, err error) {
share = &LinkSharing{} share = &LinkSharing{}
has, err := x.Where("hash = ?", hash).Get(share) has, err := s.Where("hash = ?", hash).Get(share)
if err != nil { if err != nil {
return return
} }
@ -235,13 +236,12 @@ func GetLinkShareByHash(hash string) (share *LinkSharing, err error) {
} }
// GetListByShareHash returns a link share by its hash // GetListByShareHash returns a link share by its hash
func GetListByShareHash(hash string) (list *List, err error) { func GetListByShareHash(s *xorm.Session, hash string) (list *List, err error) {
share, err := GetLinkShareByHash(hash) share, err := GetLinkShareByHash(s, hash)
if err != nil { if err != nil {
return return
} }
list = &List{ID: share.ListID} list, err = GetListSimpleByID(s, share.ListID)
err = list.GetSimpleByID()
return return
} }

View file

@ -16,53 +16,55 @@
package models package models
import "code.vikunja.io/web" import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead implements the read right check for a link share // CanRead implements the read right check for a link share
func (share *LinkSharing) CanRead(a web.Auth) (bool, int, error) { func (share *LinkSharing) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
// Don't allow creating link shares if the user itself authenticated with a link share // Don't allow creating link shares if the user itself authenticated with a link share
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, 0, nil return false, 0, nil
} }
l, err := GetListByShareHash(share.Hash) l, err := GetListByShareHash(s, share.Hash)
if err != nil { if err != nil {
return false, 0, err return false, 0, err
} }
return l.CanRead(a) return l.CanRead(s, a)
} }
// CanDelete implements the delete right check for a link share // CanDelete implements the delete right check for a link share
func (share *LinkSharing) CanDelete(a web.Auth) (bool, error) { func (share *LinkSharing) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return share.canDoLinkShare(a) return share.canDoLinkShare(s, a)
} }
// CanUpdate implements the update right check for a link share // CanUpdate implements the update right check for a link share
func (share *LinkSharing) CanUpdate(a web.Auth) (bool, error) { func (share *LinkSharing) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return share.canDoLinkShare(a) return share.canDoLinkShare(s, a)
} }
// CanCreate implements the create right check for a link share // CanCreate implements the create right check for a link share
func (share *LinkSharing) CanCreate(a web.Auth) (bool, error) { func (share *LinkSharing) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return share.canDoLinkShare(a) return share.canDoLinkShare(s, a)
} }
func (share *LinkSharing) canDoLinkShare(a web.Auth) (bool, error) { func (share *LinkSharing) canDoLinkShare(s *xorm.Session, a web.Auth) (bool, error) {
// Don't allow creating link shares if the user itself authenticated with a link share // Don't allow creating link shares if the user itself authenticated with a link share
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
l := &List{ID: share.ListID} l, err := GetListSimpleByID(s, share.ListID)
err := l.GetSimpleByID()
if err != nil { if err != nil {
return false, err return false, err
} }
// Check if the user is admin when the link right is admin // Check if the user is admin when the link right is admin
if share.Right == RightAdmin { if share.Right == RightAdmin {
return l.IsAdmin(a) return l.IsAdmin(s, a)
} }
return l.CanWrite(a) return l.CanWrite(s, a)
} }

View file

@ -96,9 +96,9 @@ var FavoritesPseudoList = List{
} }
// GetListsByNamespaceID gets all lists in a namespace // GetListsByNamespaceID gets all lists in a namespace
func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error) { func GetListsByNamespaceID(s *xorm.Session, nID int64, doer *user.User) (lists []*List, err error) {
if nID == -1 { if nID == -1 {
err = x.Select("l.*"). err = s.Select("l.*").
Table("list"). Table("list").
Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id"). Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id").
Join("LEFT", []string{"team_members", "tm"}, "tm.team_id = tl.team_id"). Join("LEFT", []string{"team_members", "tm"}, "tm.team_id = tl.team_id").
@ -111,7 +111,7 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error
GroupBy("l.id"). GroupBy("l.id").
Find(&lists) Find(&lists)
} else { } else {
err = x.Select("l.*"). err = s.Select("l.*").
Alias("l"). Alias("l").
Join("LEFT", []string{"namespaces", "n"}, "l.namespace_id = n.id"). Join("LEFT", []string{"namespaces", "n"}, "l.namespace_id = n.id").
Where("l.is_archived = false"). Where("l.is_archived = false").
@ -124,7 +124,7 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error
} }
// get more list details // get more list details
err = AddListDetails(lists) err = addListDetails(s, lists)
return lists, err return lists, err
} }
@ -143,21 +143,22 @@ func GetListsByNamespaceID(nID int64, doer *user.User) (lists []*List, err error
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists [get] // @Router /lists [get]
func (l *List) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { func (l *List) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
// Check if we're dealing with a share auth // Check if we're dealing with a share auth
shareAuth, ok := a.(*LinkSharing) shareAuth, ok := a.(*LinkSharing)
if ok { if ok {
list := &List{ID: shareAuth.ListID} list, err := GetListSimpleByID(s, shareAuth.ListID)
err := list.GetSimpleByID()
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
lists := []*List{list} lists := []*List{list}
err = AddListDetails(lists) err = addListDetails(s, lists)
return lists, 0, 0, err return lists, 0, 0, err
} }
lists, resultCount, totalItems, err := getRawListsForUser(&listOptions{ lists, resultCount, totalItems, err := getRawListsForUser(
s,
&listOptions{
search: search, search: search,
user: &user.User{ID: a.GetID()}, user: &user.User{ID: a.GetID()},
page: page, page: page,
@ -169,7 +170,7 @@ func (l *List) ReadAll(a web.Auth, search string, page int, perPage int) (result
} }
// Add more list details // Add more list details
err = AddListDetails(lists) err = addListDetails(s, lists)
return lists, resultCount, totalItems, err return lists, resultCount, totalItems, err
} }
@ -185,7 +186,7 @@ func (l *List) ReadAll(a web.Auth, search string, page int, perPage int) (result
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [get] // @Router /lists/{id} [get]
func (l *List) ReadOne() (err error) { func (l *List) ReadOne(s *xorm.Session) (err error) {
if l.ID == FavoritesPseudoList.ID { if l.ID == FavoritesPseudoList.ID {
// Already "built" the list in CanRead // Already "built" the list in CanRead
@ -194,7 +195,7 @@ func (l *List) ReadOne() (err error) {
// Check for saved filters // Check for saved filters
if getSavedFilterIDFromListID(l.ID) > 0 { if getSavedFilterIDFromListID(l.ID) > 0 {
sf, err := getSavedFilterSimpleByID(getSavedFilterIDFromListID(l.ID)) sf, err := getSavedFilterSimpleByID(s, getSavedFilterIDFromListID(l.ID))
if err != nil { if err != nil {
return err return err
} }
@ -206,13 +207,13 @@ func (l *List) ReadOne() (err error) {
} }
// Get list owner // Get list owner
l.Owner, err = user.GetUserByID(l.OwnerID) l.Owner, err = user.GetUserByID(s, l.OwnerID)
if err != nil { if err != nil {
return err return err
} }
// Check if the namespace is archived and set the namespace to archived if it is not already archived individually. // Check if the namespace is archived and set the namespace to archived if it is not already archived individually.
if !l.IsArchived { if !l.IsArchived {
err = l.CheckIsArchived() err = l.CheckIsArchived(s)
if err != nil { if err != nil {
if !IsErrNamespaceIsArchived(err) && !IsErrListIsArchived(err) { if !IsErrNamespaceIsArchived(err) && !IsErrListIsArchived(err) {
return return
@ -224,7 +225,7 @@ func (l *List) ReadOne() (err error) {
// Get any background information if there is one set // Get any background information if there is one set
if l.BackgroundFileID != 0 { if l.BackgroundFileID != 0 {
// Unsplash image // Unsplash image
l.BackgroundInformation, err = GetUnsplashPhotoByFileID(l.BackgroundFileID) l.BackgroundInformation, err = GetUnsplashPhotoByFileID(s, l.BackgroundFileID)
if err != nil && !files.IsErrFileIsNotUnsplashFile(err) { if err != nil && !files.IsErrFileIsNotUnsplashFile(err) {
return return
} }
@ -237,44 +238,33 @@ func (l *List) ReadOne() (err error) {
return nil return nil
} }
// GetSimpleByID gets a list with only the basic items, aka no tasks or user objects. Returns an error if the list does not exist. // GetListSimpleByID gets a list with only the basic items, aka no tasks or user objects. Returns an error if the list does not exist.
func (l *List) GetSimpleByID() (err error) { func GetListSimpleByID(s *xorm.Session, listID int64) (list *List, err error) {
s := x.NewSession()
err = l.getSimpleByID(s)
if err != nil {
_ = s.Rollback()
return err
}
return nil
}
func (l *List) getSimpleByID(s *xorm.Session) (err error) { list = &List{}
if l.ID < 1 {
return ErrListDoesNotExist{ID: l.ID} if listID < 1 {
return nil, ErrListDoesNotExist{ID: listID}
} }
// We need to re-init our list object, because otherwise xorm creates a "where for every item in that list object, exists, err := s.Where("id = ?", listID).Get(list)
// leading to not finding anything if the id is good, but for example the title is different.
id := l.ID
*l = List{}
exists, err := s.Where("id = ?", id).Get(l)
if err != nil { if err != nil {
return return
} }
if !exists { if !exists {
return ErrListDoesNotExist{ID: l.ID} return nil, ErrListDoesNotExist{ID: listID}
} }
return return
} }
// GetListSimplByTaskID gets a list by a task id // GetListSimplByTaskID gets a list by a task id
func GetListSimplByTaskID(taskID int64) (l *List, err error) { func GetListSimplByTaskID(s *xorm.Session, taskID int64) (l *List, err error) {
// We need to re-init our list object, because otherwise xorm creates a "where for every item in that list object, // We need to re-init our list object, because otherwise xorm creates a "where for every item in that list object,
// leading to not finding anything if the id is good, but for example the title is different. // leading to not finding anything if the id is good, but for example the title is different.
var list List var list List
exists, err := x. exists, err := s.
Select("list.*"). Select("list.*").
Table(List{}). Table(List{}).
Join("INNER", "tasks", "list.id = tasks.list_id"). Join("INNER", "tasks", "list.id = tasks.list_id").
@ -292,9 +282,9 @@ func GetListSimplByTaskID(taskID int64) (l *List, err error) {
} }
// GetListsByIDs returns a map of lists from a slice with list ids // GetListsByIDs returns a map of lists from a slice with list ids
func GetListsByIDs(listIDs []int64) (lists map[int64]*List, err error) { func GetListsByIDs(s *xorm.Session, listIDs []int64) (lists map[int64]*List, err error) {
lists = make(map[int64]*List, len(listIDs)) lists = make(map[int64]*List, len(listIDs))
err = x.In("id", listIDs).Find(&lists) err = s.In("id", listIDs).Find(&lists)
return return
} }
@ -307,8 +297,8 @@ type listOptions struct {
} }
// Gets the lists only, without any tasks or so // Gets the lists only, without any tasks or so
func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) { func getRawListsForUser(s *xorm.Session, opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) {
fullUser, err := user.GetUserByID(opts.user.ID) fullUser, err := user.GetUserByID(s, opts.user.ID)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -344,7 +334,7 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota
// Gets all Lists where the user is either owner or in a team which has access to the list // Gets all Lists where the user is either owner or in a team which has access to the list
// Or in a team which has namespace read access // Or in a team which has namespace read access
query := x.Select("l.*"). query := s.Select("l.*").
Table("list"). Table("list").
Alias("l"). Alias("l").
Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id"). Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id").
@ -372,7 +362,7 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota
return nil, 0, 0, err return nil, 0, 0, err
} }
totalItems, err = x. totalItems, err = s.
Table("list"). Table("list").
Alias("l"). Alias("l").
Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id"). Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id").
@ -396,8 +386,8 @@ func getRawListsForUser(opts *listOptions) (lists []*List, resultCount int, tota
return lists, len(lists), totalItems, err return lists, len(lists), totalItems, err
} }
// AddListDetails adds owner user objects and list tasks to all lists in the slice // addListDetails adds owner user objects and list tasks to all lists in the slice
func AddListDetails(lists []*List) (err error) { func addListDetails(s *xorm.Session, lists []*List) (err error) {
var ownerIDs []int64 var ownerIDs []int64
for _, l := range lists { for _, l := range lists {
ownerIDs = append(ownerIDs, l.OwnerID) ownerIDs = append(ownerIDs, l.OwnerID)
@ -405,7 +395,7 @@ func AddListDetails(lists []*List) (err error) {
// Get all list owners // Get all list owners
owners := map[int64]*user.User{} owners := map[int64]*user.User{}
err = x.In("id", ownerIDs).Find(&owners) err = s.In("id", ownerIDs).Find(&owners)
if err != nil { if err != nil {
return return
} }
@ -423,7 +413,7 @@ func AddListDetails(lists []*List) (err error) {
// Unsplash background file info // Unsplash background file info
us := []*UnsplashPhoto{} us := []*UnsplashPhoto{}
err = x.In("file_id", fileIDs).Find(&us) err = s.In("file_id", fileIDs).Find(&us)
if err != nil { if err != nil {
return return
} }
@ -450,15 +440,15 @@ type NamespaceList struct {
} }
// CheckIsArchived returns an ErrListIsArchived or ErrNamespaceIsArchived if the list or its namespace is archived. // CheckIsArchived returns an ErrListIsArchived or ErrNamespaceIsArchived if the list or its namespace is archived.
func (l *List) CheckIsArchived() (err error) { func (l *List) CheckIsArchived(s *xorm.Session) (err error) {
// When creating a new list, we check if the namespace is archived // When creating a new list, we check if the namespace is archived
if l.ID == 0 { if l.ID == 0 {
n := &Namespace{ID: l.NamespaceID} n := &Namespace{ID: l.NamespaceID}
return n.CheckIsArchived() return n.CheckIsArchived(s)
} }
nl := &NamespaceList{} nl := &NamespaceList{}
exists, err := x. exists, err := s.
Table("list"). Table("list").
Join("LEFT", "namespaces", "list.namespace_id = namespaces.id"). Join("LEFT", "namespaces", "list.namespace_id = namespaces.id").
Where("list.id = ? AND (list.is_archived = true OR namespaces.is_archived = true)", l.ID). Where("list.id = ? AND (list.is_archived = true OR namespaces.is_archived = true)", l.ID).
@ -476,11 +466,11 @@ func (l *List) CheckIsArchived() (err error) {
} }
// CreateOrUpdateList updates a list or creates it if it doesn't exist // CreateOrUpdateList updates a list or creates it if it doesn't exist
func CreateOrUpdateList(list *List) (err error) { func CreateOrUpdateList(s *xorm.Session, list *List) (err error) {
// Check if the namespace exists // Check if the namespace exists
if list.NamespaceID != 0 && list.NamespaceID != FavoritesPseudoNamespace.ID { if list.NamespaceID != 0 && list.NamespaceID != FavoritesPseudoNamespace.ID {
_, err = GetNamespaceByID(list.NamespaceID) _, err = GetNamespaceByID(s, list.NamespaceID)
if err != nil { if err != nil {
return err return err
} }
@ -488,7 +478,7 @@ func CreateOrUpdateList(list *List) (err error) {
// Check if the identifier is unique and not empty // Check if the identifier is unique and not empty
if list.Identifier != "" { if list.Identifier != "" {
exists, err := x. exists, err := s.
Where("identifier = ?", list.Identifier). Where("identifier = ?", list.Identifier).
And("id != ?", list.ID). And("id != ?", list.ID).
Exist(&List{}) Exist(&List{})
@ -501,7 +491,7 @@ func CreateOrUpdateList(list *List) (err error) {
} }
if list.ID == 0 { if list.ID == 0 {
_, err = x.Insert(list) _, err = s.Insert(list)
metrics.UpdateCount(1, metrics.ListCountKey) metrics.UpdateCount(1, metrics.ListCountKey)
} else { } else {
// We need to specify the cols we want to update here to be able to un-archive lists // We need to specify the cols we want to update here to be able to un-archive lists
@ -516,7 +506,7 @@ func CreateOrUpdateList(list *List) (err error) {
colsToUpdate = append(colsToUpdate, "description") colsToUpdate = append(colsToUpdate, "description")
} }
_, err = x. _, err = s.
ID(list.ID). ID(list.ID).
Cols(colsToUpdate...). Cols(colsToUpdate...).
Update(list) Update(list)
@ -526,12 +516,13 @@ func CreateOrUpdateList(list *List) (err error) {
return return
} }
err = list.GetSimpleByID() l, err := GetListSimpleByID(s, list.ID)
if err != nil { if err != nil {
return return err
} }
err = list.ReadOne() *list = *l
err = list.ReadOne(s)
return return
} }
@ -550,33 +541,23 @@ func CreateOrUpdateList(list *List) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [post] // @Router /lists/{id} [post]
func (l *List) Update() (err error) { func (l *List) Update(s *xorm.Session) (err error) {
return CreateOrUpdateList(l) return CreateOrUpdateList(s, l)
} }
func updateListLastUpdated(list *List) (err error) { func updateListLastUpdated(s *xorm.Session, list *List) error {
s := x.NewSession()
err = updateListLastUpdatedS(s, list)
if err != nil {
_ = s.Rollback()
return err
}
return nil
}
func updateListLastUpdatedS(s *xorm.Session, list *List) error {
_, err := s.ID(list.ID).Cols("updated").Update(list) _, err := s.ID(list.ID).Cols("updated").Update(list)
return err return err
} }
func updateListByTaskID(taskID int64) (err error) { func updateListByTaskID(s *xorm.Session, taskID int64) (err error) {
// need to get the task to update the list last updated timestamp // need to get the task to update the list last updated timestamp
task, err := GetTaskByIDSimple(taskID) task, err := GetTaskByIDSimple(s, taskID)
if err != nil { if err != nil {
return err return err
} }
return updateListLastUpdated(&List{ID: task.ListID}) return updateListLastUpdated(s, &List{ID: task.ListID})
} }
// Create implements the create method of CRUDable // Create implements the create method of CRUDable
@ -593,8 +574,8 @@ func updateListByTaskID(taskID int64) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/lists [put] // @Router /namespaces/{namespaceID}/lists [put]
func (l *List) Create(a web.Auth) (err error) { func (l *List) Create(s *xorm.Session, a web.Auth) (err error) {
err = l.CheckIsArchived() err = l.CheckIsArchived(s)
if err != nil { if err != nil {
return err return err
} }
@ -608,7 +589,7 @@ func (l *List) Create(a web.Auth) (err error) {
l.Owner = doer l.Owner = doer
l.ID = 0 // Otherwise only the first time a new list would be created l.ID = 0 // Otherwise only the first time a new list would be created
err = CreateOrUpdateList(l) err = CreateOrUpdateList(s, l)
if err != nil { if err != nil {
return return
} }
@ -618,7 +599,7 @@ func (l *List) Create(a web.Auth) (err error) {
ListID: l.ID, ListID: l.ID,
Title: "New Bucket", Title: "New Bucket",
} }
return b.Create(a) return b.Create(s, a)
} }
// Delete implements the delete method of CRUDable // Delete implements the delete method of CRUDable
@ -633,27 +614,27 @@ func (l *List) Create(a web.Auth) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [delete] // @Router /lists/{id} [delete]
func (l *List) Delete() (err error) { func (l *List) Delete(s *xorm.Session) (err error) {
// Delete the list // Delete the list
_, err = x.ID(l.ID).Delete(&List{}) _, err = s.ID(l.ID).Delete(&List{})
if err != nil { if err != nil {
return return
} }
metrics.UpdateCount(-1, metrics.ListCountKey) metrics.UpdateCount(-1, metrics.ListCountKey)
// Delete all todotasks on that list // Delete all todotasks on that list
_, err = x.Where("list_id = ?", l.ID).Delete(&Task{}) _, err = s.Where("list_id = ?", l.ID).Delete(&Task{})
return return
} }
// SetListBackground sets a background file as list background in the db // SetListBackground sets a background file as list background in the db
func SetListBackground(listID int64, background *files.File) (err error) { func SetListBackground(s *xorm.Session, listID int64, background *files.File) (err error) {
l := &List{ l := &List{
ID: listID, ID: listID,
BackgroundFileID: background.ID, BackgroundFileID: background.ID,
} }
_, err = x. _, err = s.
Where("id = ?", l.ID). Where("id = ?", l.ID).
Cols("background_file_id"). Cols("background_file_id").
Update(l) Update(l)

View file

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/utils" "code.vikunja.io/api/pkg/utils"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// ListDuplicate holds everything needed to duplicate a list // ListDuplicate holds everything needed to duplicate a list
@ -38,17 +39,17 @@ type ListDuplicate struct {
} }
// CanCreate checks if a user has the right to duplicate a list // CanCreate checks if a user has the right to duplicate a list
func (ld *ListDuplicate) CanCreate(a web.Auth) (canCreate bool, err error) { func (ld *ListDuplicate) CanCreate(s *xorm.Session, a web.Auth) (canCreate bool, err error) {
// List Exists + user has read access to list // List Exists + user has read access to list
ld.List = &List{ID: ld.ListID} ld.List = &List{ID: ld.ListID}
canRead, _, err := ld.List.CanRead(a) canRead, _, err := ld.List.CanRead(s, a)
if err != nil || !canRead { if err != nil || !canRead {
return canRead, err return canRead, err
} }
// Namespace exists + user has write access to is (-> can create new lists) // Namespace exists + user has write access to is (-> can create new lists)
ld.List.NamespaceID = ld.NamespaceID ld.List.NamespaceID = ld.NamespaceID
return ld.List.CanCreate(a) return ld.List.CanCreate(s, a)
} }
// Create duplicates a list // Create duplicates a list
@ -66,7 +67,7 @@ func (ld *ListDuplicate) CanCreate(a web.Auth) (canCreate bool, err error) {
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/duplicate [put] // @Router /lists/{listID}/duplicate [put]
//nolint:gocyclo //nolint:gocyclo
func (ld *ListDuplicate) Create(a web.Auth) (err error) { func (ld *ListDuplicate) Create(s *xorm.Session, a web.Auth) (err error) {
log.Debugf("Duplicating list %d", ld.ListID) log.Debugf("Duplicating list %d", ld.ListID)
@ -74,7 +75,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
ld.List.Identifier = "" // Reset the identifier to trigger regenerating a new one ld.List.Identifier = "" // Reset the identifier to trigger regenerating a new one
// Set the owner to the current user // Set the owner to the current user
ld.List.OwnerID = a.GetID() ld.List.OwnerID = a.GetID()
if err := CreateOrUpdateList(ld.List); err != nil { if err := CreateOrUpdateList(s, ld.List); err != nil {
// If there is no available unique list identifier, just reset it. // If there is no available unique list identifier, just reset it.
if IsErrListIdentifierIsNotUnique(err) { if IsErrListIdentifierIsNotUnique(err) {
ld.List.Identifier = "" ld.List.Identifier = ""
@ -90,7 +91,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Used to map the newly created tasks to their new buckets // Used to map the newly created tasks to their new buckets
bucketMap := make(map[int64]int64) bucketMap := make(map[int64]int64)
buckets := []*Bucket{} buckets := []*Bucket{}
err = x.Where("list_id = ?", ld.ListID).Find(&buckets) err = s.Where("list_id = ?", ld.ListID).Find(&buckets)
if err != nil { if err != nil {
return return
} }
@ -98,7 +99,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
oldID := b.ID oldID := b.ID
b.ID = 0 b.ID = 0
b.ListID = ld.List.ID b.ListID = ld.List.ID
if err := b.Create(a); err != nil { if err := b.Create(s, a); err != nil {
return err return err
} }
bucketMap[oldID] = b.ID bucketMap[oldID] = b.ID
@ -107,7 +108,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
log.Debugf("Duplicated all buckets from list %d into %d", ld.ListID, ld.List.ID) log.Debugf("Duplicated all buckets from list %d into %d", ld.ListID, ld.List.ID)
// Get all tasks + all task details // Get all tasks + all task details
tasks, _, _, err := getTasksForLists([]*List{{ID: ld.ListID}}, a, &taskOptions{}) tasks, _, _, err := getTasksForLists(s, []*List{{ID: ld.ListID}}, a, &taskOptions{})
if err != nil { if err != nil {
return err return err
} }
@ -123,10 +124,8 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
t.ListID = ld.List.ID t.ListID = ld.List.ID
t.BucketID = bucketMap[t.BucketID] t.BucketID = bucketMap[t.BucketID]
t.UID = "" t.UID = ""
s := x.NewSession()
err := createTask(s, t, a, false) err := createTask(s, t, a, false)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
taskMap[oldID] = t.ID taskMap[oldID] = t.ID
@ -138,7 +137,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Save all attachments // Save all attachments
// We also duplicate all underlying files since they could be modified in one list which would result in // We also duplicate all underlying files since they could be modified in one list which would result in
// file changes in the other list which is not something we want. // file changes in the other list which is not something we want.
attachments, err := getTaskAttachmentsByTaskIDs(oldTaskIDs) attachments, err := getTaskAttachmentsByTaskIDs(s, oldTaskIDs)
if err != nil { if err != nil {
return err return err
} }
@ -164,7 +163,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
return err return err
} }
err := attachment.NewAttachment(attachment.File.File, attachment.File.Name, attachment.File.Size, a) err := attachment.NewAttachment(s, attachment.File.File, attachment.File.Name, attachment.File.Size, a)
if err != nil { if err != nil {
return err return err
} }
@ -180,7 +179,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Copy label tasks (not the labels) // Copy label tasks (not the labels)
labelTasks := []*LabelTask{} labelTasks := []*LabelTask{}
err = x.In("task_id", oldTaskIDs).Find(&labelTasks) err = s.In("task_id", oldTaskIDs).Find(&labelTasks)
if err != nil { if err != nil {
return return
} }
@ -188,7 +187,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
for _, lt := range labelTasks { for _, lt := range labelTasks {
lt.ID = 0 lt.ID = 0
lt.TaskID = taskMap[lt.TaskID] lt.TaskID = taskMap[lt.TaskID]
if _, err := x.Insert(lt); err != nil { if _, err := s.Insert(lt); err != nil {
return err return err
} }
} }
@ -198,7 +197,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Assignees // Assignees
// Only copy those assignees who have access to the task // Only copy those assignees who have access to the task
assignees := []*TaskAssginee{} assignees := []*TaskAssginee{}
err = x.In("task_id", oldTaskIDs).Find(&assignees) err = s.In("task_id", oldTaskIDs).Find(&assignees)
if err != nil { if err != nil {
return return
} }
@ -207,7 +206,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
ID: taskMap[a.TaskID], ID: taskMap[a.TaskID],
ListID: ld.List.ID, ListID: ld.List.ID,
} }
if err := t.addNewAssigneeByID(a.UserID, ld.List); err != nil { if err := t.addNewAssigneeByID(s, a.UserID, ld.List); err != nil {
if IsErrUserDoesNotHaveAccessToList(err) { if IsErrUserDoesNotHaveAccessToList(err) {
continue continue
} }
@ -219,14 +218,14 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Comments // Comments
comments := []*TaskComment{} comments := []*TaskComment{}
err = x.In("task_id", oldTaskIDs).Find(&comments) err = s.In("task_id", oldTaskIDs).Find(&comments)
if err != nil { if err != nil {
return return
} }
for _, c := range comments { for _, c := range comments {
c.ID = 0 c.ID = 0
c.TaskID = taskMap[c.TaskID] c.TaskID = taskMap[c.TaskID]
if _, err := x.Insert(c); err != nil { if _, err := s.Insert(c); err != nil {
return err return err
} }
} }
@ -237,7 +236,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Low-Effort: Only copy those relations which are between tasks in the same list // Low-Effort: Only copy those relations which are between tasks in the same list
// because we can do that without a lot of hassle // because we can do that without a lot of hassle
relations := []*TaskRelation{} relations := []*TaskRelation{}
err = x.In("task_id", oldTaskIDs).Find(&relations) err = s.In("task_id", oldTaskIDs).Find(&relations)
if err != nil { if err != nil {
return return
} }
@ -249,7 +248,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
r.ID = 0 r.ID = 0
r.OtherTaskID = otherTaskID r.OtherTaskID = otherTaskID
r.TaskID = taskMap[r.TaskID] r.TaskID = taskMap[r.TaskID]
if _, err := x.Insert(r); err != nil { if _, err := s.Insert(r); err != nil {
return err return err
} }
} }
@ -276,19 +275,19 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
} }
// Get unsplash info if applicable // Get unsplash info if applicable
up, err := GetUnsplashPhotoByFileID(ld.List.BackgroundFileID) up, err := GetUnsplashPhotoByFileID(s, ld.List.BackgroundFileID)
if err != nil && files.IsErrFileIsNotUnsplashFile(err) { if err != nil && files.IsErrFileIsNotUnsplashFile(err) {
return err return err
} }
if up != nil { if up != nil {
up.ID = 0 up.ID = 0
up.FileID = file.ID up.FileID = file.ID
if err := up.Save(); err != nil { if err := up.Save(s); err != nil {
return err return err
} }
} }
if err := SetListBackground(ld.List.ID, file); err != nil { if err := SetListBackground(s, ld.List.ID, file); err != nil {
return err return err
} }
@ -298,14 +297,14 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
// Rights / Shares // Rights / Shares
// To keep it simple(r) we will only copy rights which are directly used with the list, no namespace changes. // To keep it simple(r) we will only copy rights which are directly used with the list, no namespace changes.
users := []*ListUser{} users := []*ListUser{}
err = x.Where("list_id = ?", ld.ListID).Find(&users) err = s.Where("list_id = ?", ld.ListID).Find(&users)
if err != nil { if err != nil {
return return
} }
for _, u := range users { for _, u := range users {
u.ID = 0 u.ID = 0
u.ListID = ld.List.ID u.ListID = ld.List.ID
if _, err := x.Insert(u); err != nil { if _, err := s.Insert(u); err != nil {
return err return err
} }
} }
@ -313,21 +312,21 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
log.Debugf("Duplicated user shares from list %d into %d", ld.ListID, ld.List.ID) log.Debugf("Duplicated user shares from list %d into %d", ld.ListID, ld.List.ID)
teams := []*TeamList{} teams := []*TeamList{}
err = x.Where("list_id = ?", ld.ListID).Find(&teams) err = s.Where("list_id = ?", ld.ListID).Find(&teams)
if err != nil { if err != nil {
return return
} }
for _, t := range teams { for _, t := range teams {
t.ID = 0 t.ID = 0
t.ListID = ld.List.ID t.ListID = ld.List.ID
if _, err := x.Insert(t); err != nil { if _, err := s.Insert(t); err != nil {
return err return err
} }
} }
// Generate new link shares if any are available // Generate new link shares if any are available
linkShares := []*LinkSharing{} linkShares := []*LinkSharing{}
err = x.Where("list_id = ?", ld.ListID).Find(&linkShares) err = s.Where("list_id = ?", ld.ListID).Find(&linkShares)
if err != nil { if err != nil {
return return
} }
@ -335,7 +334,7 @@ func (ld *ListDuplicate) Create(a web.Auth) (err error) {
share.ID = 0 share.ID = 0
share.ListID = ld.List.ID share.ListID = ld.List.ID
share.Hash = utils.MakeRandomString(40) share.Hash = utils.MakeRandomString(40)
if _, err := x.Insert(share); err != nil { if _, err := s.Insert(share); err != nil {
return err return err
} }
} }

View file

@ -29,6 +29,8 @@ func TestListDuplicate(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
s := db.NewSession()
defer s.Close()
u := &user.User{ u := &user.User{
ID: 1, ID: 1,
@ -38,10 +40,10 @@ func TestListDuplicate(t *testing.T) {
ListID: 1, ListID: 1,
NamespaceID: 1, NamespaceID: 1,
} }
can, err := l.CanCreate(u) can, err := l.CanCreate(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
err = l.Create(u) err = l.Create(s, u)
assert.NoError(t, err) assert.NoError(t, err)
// To make this test 100% useful, it would need to assert a lot more stuff, but it is good enough for now. // To make this test 100% useful, it would need to assert a lot more stuff, but it is good enough for now.
// Also, we're lacking utility functions to do all needed assertions. // Also, we're lacking utility functions to do all needed assertions.

View file

@ -20,10 +20,11 @@ import (
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm"
) )
// CanWrite return whether the user can write on that list or not // CanWrite return whether the user can write on that list or not
func (l *List) CanWrite(a web.Auth) (bool, error) { func (l *List) CanWrite(s *xorm.Session, a web.Auth) (bool, error) {
// The favorite list can't be edited // The favorite list can't be edited
if l.ID == FavoritesPseudoList.ID { if l.ID == FavoritesPseudoList.ID {
@ -31,15 +32,14 @@ func (l *List) CanWrite(a web.Auth) (bool, error) {
} }
// Get the list and check the right // Get the list and check the right
originalList := &List{ID: l.ID} originalList, err := GetListSimpleByID(s, l.ID)
err := originalList.GetSimpleByID()
if err != nil { if err != nil {
return false, err return false, err
} }
// We put the result of the is archived check in a separate variable to be able to return it later without // We put the result of the is archived check in a separate variable to be able to return it later without
// needing to recheck it again // needing to recheck it again
errIsArchived := originalList.CheckIsArchived() errIsArchived := originalList.CheckIsArchived(s)
var canWrite bool var canWrite bool
@ -59,7 +59,7 @@ func (l *List) CanWrite(a web.Auth) (bool, error) {
return canWrite, errIsArchived return canWrite, errIsArchived
} }
canWrite, _, err = originalList.checkRight(a, RightWrite, RightAdmin) canWrite, _, err = originalList.checkRight(s, a, RightWrite, RightAdmin)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -67,7 +67,7 @@ func (l *List) CanWrite(a web.Auth) (bool, error) {
} }
// CanRead checks if a user has read access to a list // CanRead checks if a user has read access to a list
func (l *List) CanRead(a web.Auth) (bool, int, error) { func (l *List) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
// The favorite list needs a special treatment // The favorite list needs a special treatment
if l.ID == FavoritesPseudoList.ID { if l.ID == FavoritesPseudoList.ID {
@ -84,14 +84,18 @@ func (l *List) CanRead(a web.Auth) (bool, int, error) {
// Saved Filter Lists need a special case // Saved Filter Lists need a special case
if getSavedFilterIDFromListID(l.ID) > 0 { if getSavedFilterIDFromListID(l.ID) > 0 {
sf := &SavedFilter{ID: getSavedFilterIDFromListID(l.ID)} sf := &SavedFilter{ID: getSavedFilterIDFromListID(l.ID)}
return sf.CanRead(a) return sf.CanRead(s, a)
} }
// Check if the user is either owner or can read // Check if the user is either owner or can read
if err := l.GetSimpleByID(); err != nil { var err error
originalList, err := GetListSimpleByID(s, l.ID)
if err != nil {
return false, 0, err return false, 0, err
} }
*l = *originalList
// Check if we're dealing with a share auth // Check if we're dealing with a share auth
shareAuth, ok := a.(*LinkSharing) shareAuth, ok := a.(*LinkSharing)
if ok { if ok {
@ -102,16 +106,16 @@ func (l *List) CanRead(a web.Auth) (bool, int, error) {
if l.isOwner(&user.User{ID: a.GetID()}) { if l.isOwner(&user.User{ID: a.GetID()}) {
return true, int(RightAdmin), nil return true, int(RightAdmin), nil
} }
return l.checkRight(a, RightRead, RightWrite, RightAdmin) return l.checkRight(s, a, RightRead, RightWrite, RightAdmin)
} }
// CanUpdate checks if the user can update a list // CanUpdate checks if the user can update a list
func (l *List) CanUpdate(a web.Auth) (canUpdate bool, err error) { func (l *List) CanUpdate(s *xorm.Session, a web.Auth) (canUpdate bool, err error) {
// The favorite list can't be edited // The favorite list can't be edited
if l.ID == FavoritesPseudoList.ID { if l.ID == FavoritesPseudoList.ID {
return false, nil return false, nil
} }
canUpdate, err = l.CanWrite(a) canUpdate, err = l.CanWrite(s, a)
// If the list is archived and the user tries to un-archive it, let the request through // If the list is archived and the user tries to un-archive it, let the request through
if IsErrListIsArchived(err) && !l.IsArchived { if IsErrListIsArchived(err) && !l.IsArchived {
err = nil err = nil
@ -120,26 +124,25 @@ func (l *List) CanUpdate(a web.Auth) (canUpdate bool, err error) {
} }
// CanDelete checks if the user can delete a list // CanDelete checks if the user can delete a list
func (l *List) CanDelete(a web.Auth) (bool, error) { func (l *List) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return l.IsAdmin(a) return l.IsAdmin(s, a)
} }
// CanCreate checks if the user can create a list // CanCreate checks if the user can create a list
func (l *List) CanCreate(a web.Auth) (bool, error) { func (l *List) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
// A user can create a list if they have write access to the namespace // A user can create a list if they have write access to the namespace
n := &Namespace{ID: l.NamespaceID} n := &Namespace{ID: l.NamespaceID}
return n.CanWrite(a) return n.CanWrite(s, a)
} }
// IsAdmin returns whether the user has admin rights on the list or not // IsAdmin returns whether the user has admin rights on the list or not
func (l *List) IsAdmin(a web.Auth) (bool, error) { func (l *List) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
// The favorite list can't be edited // The favorite list can't be edited
if l.ID == FavoritesPseudoList.ID { if l.ID == FavoritesPseudoList.ID {
return false, nil return false, nil
} }
originalList := &List{ID: l.ID} originalList, err := GetListSimpleByID(s, l.ID)
err := originalList.GetSimpleByID()
if err != nil { if err != nil {
return false, err return false, err
} }
@ -156,7 +159,7 @@ func (l *List) IsAdmin(a web.Auth) (bool, error) {
if originalList.isOwner(&user.User{ID: a.GetID()}) { if originalList.isOwner(&user.User{ID: a.GetID()}) {
return true, nil return true, nil
} }
is, _, err := originalList.checkRight(a, RightAdmin) is, _, err := originalList.checkRight(s, a, RightAdmin)
return is, err return is, err
} }
@ -166,7 +169,7 @@ func (l *List) isOwner(u *user.User) bool {
} }
// Checks n different rights for any given user // Checks n different rights for any given user
func (l *List) checkRight(a web.Auth, rights ...Right) (bool, int, error) { func (l *List) checkRight(s *xorm.Session, a web.Auth, rights ...Right) (bool, int, error) {
/* /*
The following loop creates an sql condition like this one: The following loop creates an sql condition like this one:
@ -218,7 +221,7 @@ func (l *List) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
r := &allListRights{} r := &allListRights{}
var maxRight = 0 var maxRight = 0
exists, err := x. exists, err := s.
Table("list"). Table("list").
Alias("l"). Alias("l").
// User stuff // User stuff

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// TeamList defines the relation between a team and a list // TeamList defines the relation between a team and a list
@ -68,7 +69,7 @@ type TeamWithRight struct {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/teams [put] // @Router /lists/{id}/teams [put]
func (tl *TeamList) Create(a web.Auth) (err error) { func (tl *TeamList) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the rights are valid // Check if the rights are valid
if err = tl.Right.isValid(); err != nil { if err = tl.Right.isValid(); err != nil {
@ -76,19 +77,19 @@ func (tl *TeamList) Create(a web.Auth) (err error) {
} }
// Check if the team exists // Check if the team exists
_, err = GetTeamByID(tl.TeamID) _, err = GetTeamByID(s, tl.TeamID)
if err != nil { if err != nil {
return return
} }
// Check if the list exists // Check if the list exists
l := &List{ID: tl.ListID} l, err := GetListSimpleByID(s, tl.ListID)
if err := l.GetSimpleByID(); err != nil { if err != nil {
return err return err
} }
// Check if the team is already on the list // Check if the team is already on the list
exists, err := x.Where("team_id = ?", tl.TeamID). exists, err := s.Where("team_id = ?", tl.TeamID).
And("list_id = ?", tl.ListID). And("list_id = ?", tl.ListID).
Get(&TeamList{}) Get(&TeamList{})
if err != nil { if err != nil {
@ -99,12 +100,12 @@ func (tl *TeamList) Create(a web.Auth) (err error) {
} }
// Insert the new team // Insert the new team
_, err = x.Insert(tl) _, err = s.Insert(tl)
if err != nil { if err != nil {
return err return err
} }
err = updateListLastUpdated(l) err = updateListLastUpdated(s, l)
return return
} }
@ -121,16 +122,17 @@ func (tl *TeamList) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "Team or list does not exist." // @Failure 404 {object} web.HTTPError "Team or list does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/teams/{teamID} [delete] // @Router /lists/{listID}/teams/{teamID} [delete]
func (tl *TeamList) Delete() (err error) { func (tl *TeamList) Delete(s *xorm.Session) (err error) {
// Check if the team exists // Check if the team exists
_, err = GetTeamByID(tl.TeamID) _, err = GetTeamByID(s, tl.TeamID)
if err != nil { if err != nil {
return return
} }
// Check if the team has access to the list // Check if the team has access to the list
has, err := x.Where("team_id = ? AND list_id = ?", tl.TeamID, tl.ListID). has, err := s.
Where("team_id = ? AND list_id = ?", tl.TeamID, tl.ListID).
Get(&TeamList{}) Get(&TeamList{})
if err != nil { if err != nil {
return return
@ -140,14 +142,14 @@ func (tl *TeamList) Delete() (err error) {
} }
// Delete the relation // Delete the relation
_, err = x.Where("team_id = ?", tl.TeamID). _, err = s.Where("team_id = ?", tl.TeamID).
And("list_id = ?", tl.ListID). And("list_id = ?", tl.ListID).
Delete(TeamList{}) Delete(TeamList{})
if err != nil { if err != nil {
return err return err
} }
err = updateListLastUpdated(&List{ID: tl.ListID}) err = updateListLastUpdated(s, &List{ID: tl.ListID})
return return
} }
@ -166,10 +168,10 @@ func (tl *TeamList) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the list." // @Failure 403 {object} web.HTTPError "No right to see the list."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/teams [get] // @Router /lists/{id}/teams [get]
func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { func (tl *TeamList) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
// Check if the user can read the namespace // Check if the user can read the namespace
l := &List{ID: tl.ListID} l := &List{ID: tl.ListID}
canRead, _, err := l.CanRead(a) canRead, _, err := l.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -181,7 +183,7 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get the teams // Get the teams
all := []*TeamWithRight{} all := []*TeamWithRight{}
query := x. query := s.
Table("teams"). Table("teams").
Join("INNER", "team_list", "team_id = teams.id"). Join("INNER", "team_list", "team_id = teams.id").
Where("team_list.list_id = ?", tl.ListID). Where("team_list.list_id = ?", tl.ListID).
@ -199,12 +201,12 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r
teams = append(teams, &t.Team) teams = append(teams, &t.Team)
} }
err = addMoreInfoToTeams(teams) err = addMoreInfoToTeams(s, teams)
if err != nil { if err != nil {
return return
} }
totalItems, err = x. totalItems, err = s.
Table("teams"). Table("teams").
Join("INNER", "team_list", "team_id = teams.id"). Join("INNER", "team_list", "team_id = teams.id").
Where("team_list.list_id = ?", tl.ListID). Where("team_list.list_id = ?", tl.ListID).
@ -232,14 +234,14 @@ func (tl *TeamList) ReadAll(a web.Auth, search string, page int, perPage int) (r
// @Failure 404 {object} web.HTTPError "Team or list does not exist." // @Failure 404 {object} web.HTTPError "Team or list does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/teams/{teamID} [post] // @Router /lists/{listID}/teams/{teamID} [post]
func (tl *TeamList) Update() (err error) { func (tl *TeamList) Update(s *xorm.Session) (err error) {
// Check if the right is valid // Check if the right is valid
if err := tl.Right.isValid(); err != nil { if err := tl.Right.isValid(); err != nil {
return err return err
} }
_, err = x. _, err = s.
Where("list_id = ? AND team_id = ?", tl.ListID, tl.TeamID). Where("list_id = ? AND team_id = ?", tl.ListID, tl.TeamID).
Cols("right"). Cols("right").
Update(tl) Update(tl)
@ -247,6 +249,6 @@ func (tl *TeamList) Update() (err error) {
return err return err
} }
err = updateListLastUpdated(&List{ID: tl.ListID}) err = updateListLastUpdated(s, &List{ID: tl.ListID})
return return
} }

View file

@ -18,29 +18,30 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if the user can create a team <-> list relation // CanCreate checks if the user can create a team <-> list relation
func (tl *TeamList) CanCreate(a web.Auth) (bool, error) { func (tl *TeamList) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return tl.canDoTeamList(a) return tl.canDoTeamList(s, a)
} }
// CanDelete checks if the user can delete a team <-> list relation // CanDelete checks if the user can delete a team <-> list relation
func (tl *TeamList) CanDelete(a web.Auth) (bool, error) { func (tl *TeamList) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return tl.canDoTeamList(a) return tl.canDoTeamList(s, a)
} }
// CanUpdate checks if the user can update a team <-> list relation // CanUpdate checks if the user can update a team <-> list relation
func (tl *TeamList) CanUpdate(a web.Auth) (bool, error) { func (tl *TeamList) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return tl.canDoTeamList(a) return tl.canDoTeamList(s, a)
} }
func (tl *TeamList) canDoTeamList(a web.Auth) (bool, error) { func (tl *TeamList) canDoTeamList(s *xorm.Session, a web.Auth) (bool, error) {
// Link shares aren't allowed to do anything // Link shares aren't allowed to do anything
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
l := List{ID: tl.ListID} l := List{ID: tl.ListID}
return l.IsAdmin(a) return l.IsAdmin(s, a)
} }

View file

@ -37,20 +37,24 @@ func TestTeamList_ReadAll(t *testing.T) {
ListID: 3, ListID: 3,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
teams, _, _, err := tl.ReadAll(u, "", 1, 50) s := db.NewSession()
teams, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice) assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice)
s := reflect.ValueOf(teams) ts := reflect.ValueOf(teams)
assert.Equal(t, s.Len(), 1) assert.Equal(t, ts.Len(), 1)
_ = s.Close()
}) })
t.Run("nonexistant list", func(t *testing.T) { t.Run("nonexistant list", func(t *testing.T) {
tl := TeamList{ tl := TeamList{
ListID: 99999, ListID: 99999,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
_, _, _, err := tl.ReadAll(u, "", 1, 50) s := db.NewSession()
_, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err)) assert.True(t, IsErrListDoesNotExist(err))
_ = s.Close()
}) })
t.Run("namespace owner", func(t *testing.T) { t.Run("namespace owner", func(t *testing.T) {
tl := TeamList{ tl := TeamList{
@ -59,8 +63,10 @@ func TestTeamList_ReadAll(t *testing.T) {
Right: RightAdmin, Right: RightAdmin,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
_, _, _, err := tl.ReadAll(u, "", 1, 50) s := db.NewSession()
_, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err) assert.NoError(t, err)
_ = s.Close()
}) })
t.Run("no access", func(t *testing.T) { t.Run("no access", func(t *testing.T) {
tl := TeamList{ tl := TeamList{
@ -69,9 +75,11 @@ func TestTeamList_ReadAll(t *testing.T) {
Right: RightAdmin, Right: RightAdmin,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
_, _, _, err := tl.ReadAll(u, "", 1, 50) s := db.NewSession()
_, _, _, err := tl.ReadAll(s, u, "", 1, 50)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNeedToHaveListReadAccess(err)) assert.True(t, IsErrNeedToHaveListReadAccess(err))
_ = s.Close()
}) })
} }
@ -79,14 +87,17 @@ func TestTeamList_Create(t *testing.T) {
u := &user.User{ID: 1} u := &user.User{ID: 1}
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 1, TeamID: 1,
ListID: 1, ListID: 1,
Right: RightAdmin, Right: RightAdmin,
} }
allowed, _ := tl.CanCreate(u) allowed, _ := tl.CanCreate(s, u)
assert.True(t, allowed) assert.True(t, allowed)
err := tl.Create(u) err := tl.Create(s, u)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "team_list", map[string]interface{}{ db.AssertExists(t, "team_list", map[string]interface{}{
"team_id": 1, "team_id": 1,
@ -96,56 +107,67 @@ func TestTeamList_Create(t *testing.T) {
}) })
t.Run("team already has access", func(t *testing.T) { t.Run("team already has access", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 1, TeamID: 1,
ListID: 3, ListID: 3,
Right: RightAdmin, Right: RightAdmin,
} }
err := tl.Create(u) err := tl.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamAlreadyHasAccess(err)) assert.True(t, IsErrTeamAlreadyHasAccess(err))
_ = s.Close()
}) })
t.Run("wrong rights", func(t *testing.T) { t.Run("wrong rights", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 1, TeamID: 1,
ListID: 1, ListID: 1,
Right: RightUnknown, Right: RightUnknown,
} }
err := tl.Create(u) err := tl.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrInvalidRight(err)) assert.True(t, IsErrInvalidRight(err))
_ = s.Close()
}) })
t.Run("nonexistant team", func(t *testing.T) { t.Run("nonexistant team", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 9999, TeamID: 9999,
ListID: 1, ListID: 1,
} }
err := tl.Create(u) err := tl.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
}) })
t.Run("nonexistant list", func(t *testing.T) { t.Run("nonexistant list", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 1, TeamID: 1,
ListID: 9999, ListID: 9999,
} }
err := tl.Create(u) err := tl.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err)) assert.True(t, IsErrListDoesNotExist(err))
_ = s.Close()
}) })
} }
func TestTeamList_Delete(t *testing.T) { func TestTeamList_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 1, TeamID: 1,
ListID: 3, ListID: 3,
} }
err := tl.Delete() err := tl.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertMissing(t, "team_list", map[string]interface{}{ db.AssertMissing(t, "team_list", map[string]interface{}{
"team_id": 1, "team_id": 1,
@ -154,23 +176,27 @@ func TestTeamList_Delete(t *testing.T) {
}) })
t.Run("nonexistant team", func(t *testing.T) { t.Run("nonexistant team", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 9999, TeamID: 9999,
ListID: 1, ListID: 1,
} }
err := tl.Delete() err := tl.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
}) })
t.Run("nonexistant list", func(t *testing.T) { t.Run("nonexistant list", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := TeamList{ tl := TeamList{
TeamID: 1, TeamID: 1,
ListID: 9999, ListID: 9999,
} }
err := tl.Delete() err := tl.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotHaveAccessToList(err)) assert.True(t, IsErrTeamDoesNotHaveAccessToList(err))
_ = s.Close()
}) })
} }
@ -229,6 +255,7 @@ func TestTeamList_Update(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := &TeamList{ tl := &TeamList{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -240,13 +267,15 @@ func TestTeamList_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := tl.Update() err := tl.Update(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("TeamList.Update() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("TeamList.Update() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("TeamList.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("TeamList.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertExists(t, "team_list", map[string]interface{}{ db.AssertExists(t, "team_list", map[string]interface{}{
"list_id": tt.fields.ListID, "list_id": tt.fields.ListID,

View file

@ -35,12 +35,15 @@ func TestList_CreateOrUpdate(t *testing.T) {
t.Run("create", func(t *testing.T) { t.Run("create", func(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
Title: "test", Title: "test",
Description: "Lorem Ipsum", Description: "Lorem Ipsum",
NamespaceID: 1, NamespaceID: 1,
} }
err := list.Create(usr) err := list.Create(s, usr)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "list", map[string]interface{}{ db.AssertExists(t, "list", map[string]interface{}{
"id": list.ID, "id": list.ID,
@ -51,49 +54,56 @@ func TestList_CreateOrUpdate(t *testing.T) {
}) })
t.Run("nonexistant namespace", func(t *testing.T) { t.Run("nonexistant namespace", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
Title: "test", Title: "test",
Description: "Lorem Ipsum", Description: "Lorem Ipsum",
NamespaceID: 999999, NamespaceID: 999999,
} }
err := list.Create(s, usr)
err := list.Create(usr)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err)) assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
}) })
t.Run("nonexistant owner", func(t *testing.T) { t.Run("nonexistant owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
usr := &user.User{ID: 9482385} usr := &user.User{ID: 9482385}
list := List{ list := List{
Title: "test", Title: "test",
Description: "Lorem Ipsum", Description: "Lorem Ipsum",
NamespaceID: 1, NamespaceID: 1,
} }
err := list.Create(usr) err := list.Create(s, usr)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err)) assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
}) })
t.Run("existing identifier", func(t *testing.T) { t.Run("existing identifier", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
Title: "test", Title: "test",
Description: "Lorem Ipsum", Description: "Lorem Ipsum",
Identifier: "test1", Identifier: "test1",
NamespaceID: 1, NamespaceID: 1,
} }
err := list.Create(s, usr)
err := list.Create(usr)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrListIdentifierIsNotUnique(err)) assert.True(t, IsErrListIdentifierIsNotUnique(err))
_ = s.Close()
}) })
t.Run("non ascii characters", func(t *testing.T) { t.Run("non ascii characters", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
Title: "приффки фсем", Title: "приффки фсем",
Description: "Lorem Ipsum", Description: "Lorem Ipsum",
NamespaceID: 1, NamespaceID: 1,
} }
err := list.Create(usr) err := list.Create(s, usr)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "list", map[string]interface{}{ db.AssertExists(t, "list", map[string]interface{}{
"id": list.ID, "id": list.ID,
@ -107,6 +117,7 @@ func TestList_CreateOrUpdate(t *testing.T) {
t.Run("update", func(t *testing.T) { t.Run("update", func(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
ID: 1, ID: 1,
Title: "test", Title: "test",
@ -114,7 +125,9 @@ func TestList_CreateOrUpdate(t *testing.T) {
NamespaceID: 1, NamespaceID: 1,
} }
list.Description = "Lorem Ipsum dolor sit amet." list.Description = "Lorem Ipsum dolor sit amet."
err := list.Update() err := list.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "list", map[string]interface{}{ db.AssertExists(t, "list", map[string]interface{}{
"id": list.ID, "id": list.ID,
@ -125,37 +138,43 @@ func TestList_CreateOrUpdate(t *testing.T) {
}) })
t.Run("nonexistant", func(t *testing.T) { t.Run("nonexistant", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
ID: 99999999, ID: 99999999,
Title: "test", Title: "test",
} }
err := list.Update() err := list.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err)) assert.True(t, IsErrListDoesNotExist(err))
_ = s.Close()
}) })
t.Run("existing identifier", func(t *testing.T) { t.Run("existing identifier", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
Title: "test", Title: "test",
Description: "Lorem Ipsum", Description: "Lorem Ipsum",
Identifier: "test1", Identifier: "test1",
NamespaceID: 1, NamespaceID: 1,
} }
err := list.Create(s, usr)
err := list.Create(usr)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrListIdentifierIsNotUnique(err)) assert.True(t, IsErrListIdentifierIsNotUnique(err))
_ = s.Close()
}) })
}) })
} }
func TestList_Delete(t *testing.T) { func TestList_Delete(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
list := List{ list := List{
ID: 1, ID: 1,
} }
err := list.Delete() err := list.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertMissing(t, "list", map[string]interface{}{ db.AssertMissing(t, "list", map[string]interface{}{
"id": 1, "id": 1,
@ -165,30 +184,34 @@ func TestList_Delete(t *testing.T) {
func TestList_ReadAll(t *testing.T) { func TestList_ReadAll(t *testing.T) {
t.Run("all in namespace", func(t *testing.T) { t.Run("all in namespace", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
// Get all lists for our namespace // Get all lists for our namespace
lists, err := GetListsByNamespaceID(1, &user.User{}) lists, err := GetListsByNamespaceID(s, 1, &user.User{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, len(lists), 2) assert.Equal(t, len(lists), 2)
_ = s.Close()
}) })
t.Run("all lists for user", func(t *testing.T) { t.Run("all lists for user", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
u := &user.User{ID: 1} u := &user.User{ID: 1}
list := List{} list := List{}
lists3, _, _, err := list.ReadAll(u, "", 1, 50) lists3, _, _, err := list.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(lists3).Kind(), reflect.Slice) assert.Equal(t, reflect.TypeOf(lists3).Kind(), reflect.Slice)
s := reflect.ValueOf(lists3) ls := reflect.ValueOf(lists3)
assert.Equal(t, 16, s.Len()) assert.Equal(t, 16, ls.Len())
_ = s.Close()
}) })
t.Run("lists for nonexistant user", func(t *testing.T) { t.Run("lists for nonexistant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
usr := &user.User{ID: 999999} usr := &user.User{ID: 999999}
list := List{} list := List{}
_, _, _, err := list.ReadAll(usr, "", 1, 50) _, _, _, err := list.ReadAll(s, usr, "", 1, 50)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err)) assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
}) })
} }

View file

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// ListUser represents a list <-> user relation // ListUser represents a list <-> user relation
@ -71,7 +72,7 @@ type UserWithRight struct {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/users [put] // @Router /lists/{id}/users [put]
func (lu *ListUser) Create(a web.Auth) (err error) { func (lu *ListUser) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the right is valid // Check if the right is valid
if err := lu.Right.isValid(); err != nil { if err := lu.Right.isValid(); err != nil {
@ -79,17 +80,17 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
} }
// Check if the list exists // Check if the list exists
l := &List{ID: lu.ListID} l, err := GetListSimpleByID(s, lu.ListID)
if err = l.GetSimpleByID(); err != nil { if err != nil {
return return
} }
// Check if the user exists // Check if the user exists
user, err := user.GetUserByUsername(lu.Username) u, err := user.GetUserByUsername(s, lu.Username)
if err != nil { if err != nil {
return err return err
} }
lu.UserID = user.ID lu.UserID = u.ID
// Check if the user already has access or is owner of that list // Check if the user already has access or is owner of that list
// We explicitly DONT check for teams here // We explicitly DONT check for teams here
@ -97,7 +98,7 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
return ErrUserAlreadyHasAccess{UserID: lu.UserID, ListID: lu.ListID} return ErrUserAlreadyHasAccess{UserID: lu.UserID, ListID: lu.ListID}
} }
exist, err := x.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{}) exist, err := s.Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).Get(&ListUser{})
if err != nil { if err != nil {
return return
} }
@ -106,12 +107,12 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
} }
// Insert user <-> list relation // Insert user <-> list relation
_, err = x.Insert(lu) _, err = s.Insert(lu)
if err != nil { if err != nil {
return err return err
} }
err = updateListLastUpdated(l) err = updateListLastUpdated(s, l)
return return
} }
@ -128,17 +129,18 @@ func (lu *ListUser) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "user or list does not exist." // @Failure 404 {object} web.HTTPError "user or list does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/users/{userID} [delete] // @Router /lists/{listID}/users/{userID} [delete]
func (lu *ListUser) Delete() (err error) { func (lu *ListUser) Delete(s *xorm.Session) (err error) {
// Check if the user exists // Check if the user exists
user, err := user.GetUserByUsername(lu.Username) u, err := user.GetUserByUsername(s, lu.Username)
if err != nil { if err != nil {
return return
} }
lu.UserID = user.ID lu.UserID = u.ID
// Check if the user has access to the list // Check if the user has access to the list
has, err := x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID). has, err := s.
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
Get(&ListUser{}) Get(&ListUser{})
if err != nil { if err != nil {
return return
@ -147,13 +149,14 @@ func (lu *ListUser) Delete() (err error) {
return ErrUserDoesNotHaveAccessToList{ListID: lu.ListID, UserID: lu.UserID} return ErrUserDoesNotHaveAccessToList{ListID: lu.ListID, UserID: lu.UserID}
} }
_, err = x.Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID). _, err = s.
Where("user_id = ? AND list_id = ?", lu.UserID, lu.ListID).
Delete(&ListUser{}) Delete(&ListUser{})
if err != nil { if err != nil {
return err return err
} }
err = updateListLastUpdated(&List{ID: lu.ListID}) err = updateListLastUpdated(s, &List{ID: lu.ListID})
return return
} }
@ -172,10 +175,10 @@ func (lu *ListUser) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the list." // @Failure 403 {object} web.HTTPError "No right to see the list."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/users [get] // @Router /lists/{id}/users [get]
func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (lu *ListUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has access to the list // Check if the user has access to the list
l := &List{ID: lu.ListID} l := &List{ID: lu.ListID}
canRead, _, err := l.CanRead(a) canRead, _, err := l.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -187,7 +190,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all users // Get all users
all := []*UserWithRight{} all := []*UserWithRight{}
query := x. query := s.
Join("INNER", "users_list", "user_id = users.id"). Join("INNER", "users_list", "user_id = users.id").
Where("users_list.list_id = ?", lu.ListID). Where("users_list.list_id = ?", lu.ListID).
Where("users.username LIKE ?", "%"+search+"%") Where("users.username LIKE ?", "%"+search+"%")
@ -204,7 +207,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
u.Email = "" u.Email = ""
} }
numberOfTotalItems, err = x. numberOfTotalItems, err = s.
Join("INNER", "users_list", "user_id = users.id"). Join("INNER", "users_list", "user_id = users.id").
Where("users_list.list_id = ?", lu.ListID). Where("users_list.list_id = ?", lu.ListID).
Where("users.username LIKE ?", "%"+search+"%"). Where("users.username LIKE ?", "%"+search+"%").
@ -228,7 +231,7 @@ func (lu *ListUser) ReadAll(a web.Auth, search string, page int, perPage int) (r
// @Failure 404 {object} web.HTTPError "User or list does not exist." // @Failure 404 {object} web.HTTPError "User or list does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/users/{userID} [post] // @Router /lists/{listID}/users/{userID} [post]
func (lu *ListUser) Update() (err error) { func (lu *ListUser) Update(s *xorm.Session) (err error) {
// Check if the right is valid // Check if the right is valid
if err := lu.Right.isValid(); err != nil { if err := lu.Right.isValid(); err != nil {
@ -236,13 +239,13 @@ func (lu *ListUser) Update() (err error) {
} }
// Check if the user exists // Check if the user exists
u, err := user.GetUserByUsername(lu.Username) u, err := user.GetUserByUsername(s, lu.Username)
if err != nil { if err != nil {
return err return err
} }
lu.UserID = u.ID lu.UserID = u.ID
_, err = x. _, err = s.
Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID). Where("list_id = ? AND user_id = ?", lu.ListID, lu.UserID).
Cols("right"). Cols("right").
Update(lu) Update(lu)
@ -250,6 +253,6 @@ func (lu *ListUser) Update() (err error) {
return err return err
} }
err = updateListLastUpdated(&List{ID: lu.ListID}) err = updateListLastUpdated(s, &List{ID: lu.ListID})
return return
} }

View file

@ -18,24 +18,25 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if the user can create a new user <-> list relation // CanCreate checks if the user can create a new user <-> list relation
func (lu *ListUser) CanCreate(a web.Auth) (bool, error) { func (lu *ListUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(a) return lu.canDoListUser(s, a)
} }
// CanDelete checks if the user can delete a user <-> list relation // CanDelete checks if the user can delete a user <-> list relation
func (lu *ListUser) CanDelete(a web.Auth) (bool, error) { func (lu *ListUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(a) return lu.canDoListUser(s, a)
} }
// CanUpdate checks if the user can update a user <-> list relation // CanUpdate checks if the user can update a user <-> list relation
func (lu *ListUser) CanUpdate(a web.Auth) (bool, error) { func (lu *ListUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return lu.canDoListUser(a) return lu.canDoListUser(s, a)
} }
func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) { func (lu *ListUser) canDoListUser(s *xorm.Session, a web.Auth) (bool, error) {
// Link shares aren't allowed to do anything // Link shares aren't allowed to do anything
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
@ -43,5 +44,5 @@ func (lu *ListUser) canDoListUser(a web.Auth) (bool, error) {
// Get the list and check if the user has write access on it // Get the list and check if the user has write access on it
l := List{ID: lu.ListID} l := List{ID: lu.ListID}
return l.IsAdmin(a) return l.IsAdmin(s, a)
} }

View file

@ -80,6 +80,7 @@ func TestListUser_CanDoSomething(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
lu := &ListUser{ lu := &ListUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -91,15 +92,16 @@ func TestListUser_CanDoSomething(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
if got, _ := lu.CanCreate(tt.args.a); got != tt.want["CanCreate"] { if got, _ := lu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
t.Errorf("ListUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) t.Errorf("ListUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
} }
if got, _ := lu.CanDelete(tt.args.a); got != tt.want["CanDelete"] { if got, _ := lu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("ListUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) t.Errorf("ListUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
} }
if got, _ := lu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { if got, _ := lu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("ListUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) t.Errorf("ListUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
} }
_ = s.Close()
}) })
} }
} }

View file

@ -24,9 +24,9 @@ import (
"code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"gopkg.in/d4l3k/messagediff.v1"
"code.vikunja.io/web" "code.vikunja.io/web"
"github.com/stretchr/testify/assert"
"gopkg.in/d4l3k/messagediff.v1"
) )
func TestListUser_Create(t *testing.T) { func TestListUser_Create(t *testing.T) {
@ -108,6 +108,7 @@ func TestListUser_Create(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
ul := &ListUser{ ul := &ListUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -120,13 +121,17 @@ func TestListUser_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := ul.Create(tt.args.a) err := ul.Create(s, tt.args.a)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Create() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ListUser.Create() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("ListUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertExists(t, "users_list", map[string]interface{}{ db.AssertExists(t, "users_list", map[string]interface{}{
"user_id": ul.UserID, "user_id": ul.UserID,
@ -212,6 +217,7 @@ func TestListUser_ReadAll(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
ul := &ListUser{ ul := &ListUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -223,7 +229,7 @@ func TestListUser_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
got, _, _, err := ul.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50) got, _, _, err := ul.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ListUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ListUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -233,6 +239,7 @@ func TestListUser_ReadAll(t *testing.T) {
if diff, equal := messagediff.PrettyDiff(got, tt.want); !equal { if diff, equal := messagediff.PrettyDiff(got, tt.want); !equal {
t.Errorf("ListUser.ReadAll() = %v, want %v, diff: %v", got, tt.want, diff) t.Errorf("ListUser.ReadAll() = %v, want %v, diff: %v", got, tt.want, diff)
} }
_ = s.Close()
}) })
} }
} }
@ -292,6 +299,7 @@ func TestListUser_Update(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
lu := &ListUser{ lu := &ListUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -303,13 +311,17 @@ func TestListUser_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := lu.Update() err := lu.Update(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Update() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ListUser.Update() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("ListUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertExists(t, "users_list", map[string]interface{}{ db.AssertExists(t, "users_list", map[string]interface{}{
"list_id": tt.fields.ListID, "list_id": tt.fields.ListID,
@ -369,6 +381,7 @@ func TestListUser_Delete(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
lu := &ListUser{ lu := &ListUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -380,13 +393,17 @@ func TestListUser_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := lu.Delete() err := lu.Delete(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ListUser.Delete() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ListUser.Delete() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("ListUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("ListUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertMissing(t, "users_list", map[string]interface{}{ db.AssertMissing(t, "users_list", map[string]interface{}{
"user_id": tt.fields.UserID, "user_id": tt.fields.UserID,

View file

@ -23,12 +23,11 @@ import (
"time" "time"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/metrics" "code.vikunja.io/api/pkg/metrics"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"github.com/imdario/mergo"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm"
) )
// Namespace holds informations about a namespace // Namespace holds informations about a namespace
@ -95,55 +94,48 @@ func (Namespace) TableName() string {
} }
// GetSimpleByID gets a namespace without things like the owner, it more or less only checks if it exists. // GetSimpleByID gets a namespace without things like the owner, it more or less only checks if it exists.
func (n *Namespace) GetSimpleByID() (err error) { func getNamespaceSimpleByID(s *xorm.Session, id int64) (namespace *Namespace, err error) {
if n.ID == 0 { if id == 0 {
return ErrNamespaceDoesNotExist{ID: n.ID} return nil, ErrNamespaceDoesNotExist{ID: id}
} }
// Get the namesapce with shared lists // Get the namesapce with shared lists
if n.ID == -1 { if id == -1 {
*n = SharedListsPseudoNamespace return &SharedListsPseudoNamespace, nil
return
} }
if n.ID == FavoritesPseudoNamespace.ID { if id == FavoritesPseudoNamespace.ID {
*n = FavoritesPseudoNamespace return &FavoritesPseudoNamespace, nil
return
} }
namespaceFromDB := &Namespace{} namespace = &Namespace{}
exists, err := x.Where("id = ?", n.ID).Get(namespaceFromDB)
exists, err := s.Where("id = ?", id).Get(namespace)
if err != nil { if err != nil {
return return
} }
if !exists { if !exists {
return ErrNamespaceDoesNotExist{ID: n.ID} return nil, ErrNamespaceDoesNotExist{ID: id}
} }
// We don't want to override the provided user struct because this would break updating, so we have to merge it
if err := mergo.Merge(namespaceFromDB, n, mergo.WithOverride); err != nil {
return err
}
*n = *namespaceFromDB
return return
} }
// GetNamespaceByID returns a namespace object by its ID // GetNamespaceByID returns a namespace object by its ID
func GetNamespaceByID(id int64) (namespace Namespace, err error) { func GetNamespaceByID(s *xorm.Session, id int64) (namespace *Namespace, err error) {
namespace = Namespace{ID: id} namespace, err = getNamespaceSimpleByID(s, id)
err = namespace.GetSimpleByID()
if err != nil { if err != nil {
return return
} }
// Get the namespace Owner // Get the namespace Owner
namespace.Owner, err = user.GetUserByID(namespace.OwnerID) namespace.Owner, err = user.GetUserByID(s, namespace.OwnerID)
return return
} }
// CheckIsArchived returns an ErrNamespaceIsArchived if the namepace is archived. // CheckIsArchived returns an ErrNamespaceIsArchived if the namepace is archived.
func (n *Namespace) CheckIsArchived() error { func (n *Namespace) CheckIsArchived(s *xorm.Session) error {
exists, err := x. exists, err := s.
Where("id = ? AND is_archived = true", n.ID). Where("id = ? AND is_archived = true", n.ID).
Exist(&Namespace{}) Exist(&Namespace{})
if err != nil { if err != nil {
@ -167,8 +159,12 @@ func (n *Namespace) CheckIsArchived() error {
// @Failure 403 {object} web.HTTPError "The user does not have access to that namespace." // @Failure 403 {object} web.HTTPError "The user does not have access to that namespace."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id} [get] // @Router /namespaces/{id} [get]
func (n *Namespace) ReadOne() (err error) { func (n *Namespace) ReadOne(s *xorm.Session) (err error) {
*n, err = GetNamespaceByID(n.ID) nn, err := GetNamespaceByID(s, n.ID)
if err != nil {
return err
}
*n = *nn
return return
} }
@ -207,7 +203,7 @@ func makeNamespaceSliceFromMap(namespaces map[int64]*NamespaceWithLists, userMap
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces [get] // @Router /namespaces [get]
//nolint:gocyclo //nolint:gocyclo
func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (n *Namespace) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return nil, 0, 0, ErrGenericForbidden{} return nil, 0, 0, ErrGenericForbidden{}
} }
@ -249,7 +245,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
} }
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
query := x.Select("namespaces.*"). query := s.Select("namespaces.*").
Table("namespaces"). Table("namespaces").
Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id"). Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id").
Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id"). Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id").
@ -268,7 +264,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
return nil, 0, 0, err return nil, 0, 0, err
} }
numberOfTotalItems, err = x. numberOfTotalItems, err = s.
Table("namespaces"). Table("namespaces").
Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id"). Join("LEFT", "team_namespaces", "namespaces.id = team_namespaces.namespace_id").
Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id"). Join("LEFT", "team_members", "team_members.team_id = team_namespaces.team_id").
@ -294,7 +290,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all owners // Get all owners
userMap := make(map[int64]*user.User) userMap := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&userMap) err = s.In("id", userIDs).Find(&userMap)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -306,7 +302,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all lists // Get all lists
lists := []*List{} lists := []*List{}
listQuery := x. listQuery := s.
In("namespace_id", namespaceids) In("namespace_id", namespaceids)
if !n.IsArchived { if !n.IsArchived {
@ -330,7 +326,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Get all lists individually shared with our user (not via a namespace) // Get all lists individually shared with our user (not via a namespace)
individualLists := []*List{} individualLists := []*List{}
iListQuery := x.Select("l.*"). iListQuery := s.Select("l.*").
Table("list"). Table("list").
Alias("l"). Alias("l").
Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id"). Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id").
@ -360,7 +356,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
} }
// More details for the lists // More details for the lists
err = AddListDetails(lists) err = addListDetails(s, lists)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -386,7 +382,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// Check if we have any favorites or favorited lists and remove the favorites namespace from the list if not // Check if we have any favorites or favorited lists and remove the favorites namespace from the list if not
var favoriteCount int64 var favoriteCount int64
favoriteCount, err = x. favoriteCount, err = s.
Join("INNER", "list", "tasks.list_id = list.id"). Join("INNER", "list", "tasks.list_id = list.id").
Join("INNER", "namespaces", "list.namespace_id = namespaces.id"). Join("INNER", "namespaces", "list.namespace_id = namespaces.id").
Where(builder.And(builder.Eq{"tasks.is_favorite": true}, builder.In("namespaces.id", namespaceids))). Where(builder.And(builder.Eq{"tasks.is_favorite": true}, builder.In("namespaces.id", namespaceids))).
@ -413,7 +409,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
///////////////// /////////////////
// Saved Filters // Saved Filters
savedFilters, err := getSavedFiltersForUser(a) savedFilters, err := getSavedFiltersForUser(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -457,7 +453,7 @@ func (n *Namespace) ReadAll(a web.Auth, search string, page int, perPage int) (r
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces [put] // @Router /namespaces [put]
func (n *Namespace) Create(a web.Auth) (err error) { func (n *Namespace) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if we have at least a name // Check if we have at least a name
if n.Title == "" { if n.Title == "" {
return ErrNamespaceNameCannotBeEmpty{NamespaceID: 0, UserID: a.GetID()} return ErrNamespaceNameCannotBeEmpty{NamespaceID: 0, UserID: a.GetID()}
@ -465,14 +461,14 @@ func (n *Namespace) Create(a web.Auth) (err error) {
n.ID = 0 // This would otherwise prevent the creation of new lists after one was created n.ID = 0 // This would otherwise prevent the creation of new lists after one was created
// Check if the User exists // Check if the User exists
n.Owner, err = user.GetUserByID(a.GetID()) n.Owner, err = user.GetUserByID(s, a.GetID())
if err != nil { if err != nil {
return return
} }
n.OwnerID = n.Owner.ID n.OwnerID = n.Owner.ID
// Insert // Insert
if _, err = x.Insert(n); err != nil { if _, err = s.Insert(n); err != nil {
return err return err
} }
@ -482,12 +478,12 @@ func (n *Namespace) Create(a web.Auth) (err error) {
// CreateNewNamespaceForUser creates a new namespace for a user. To prevent import cycles, we can't do that // CreateNewNamespaceForUser creates a new namespace for a user. To prevent import cycles, we can't do that
// directly in the user.Create function. // directly in the user.Create function.
func CreateNewNamespaceForUser(user *user.User) (err error) { func CreateNewNamespaceForUser(s *xorm.Session, user *user.User) (err error) {
newN := &Namespace{ newN := &Namespace{
Title: user.Username, Title: user.Username,
Description: user.Username + "'s namespace.", Description: user.Username + "'s namespace.",
} }
return newN.Create(user) return newN.Create(s, user)
} }
// Delete deletes a namespace // Delete deletes a namespace
@ -502,22 +498,22 @@ func CreateNewNamespaceForUser(user *user.User) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id} [delete] // @Router /namespaces/{id} [delete]
func (n *Namespace) Delete() (err error) { func (n *Namespace) Delete(s *xorm.Session) (err error) {
// Check if the namespace exists // Check if the namespace exists
_, err = GetNamespaceByID(n.ID) _, err = GetNamespaceByID(s, n.ID)
if err != nil { if err != nil {
return return
} }
// Delete the namespace // Delete the namespace
_, err = x.ID(n.ID).Delete(&Namespace{}) _, err = s.ID(n.ID).Delete(&Namespace{})
if err != nil { if err != nil {
return return
} }
// Delete all lists with their tasks // Delete all lists with their tasks
lists, err := GetListsByNamespaceID(n.ID, &user.User{}) lists, err := GetListsByNamespaceID(s, n.ID, &user.User{})
if err != nil { if err != nil {
return return
} }
@ -530,13 +526,13 @@ func (n *Namespace) Delete() (err error) {
} }
// Delete tasks // Delete tasks
_, err = x.In("list_id", listIDs).Delete(&Task{}) _, err = s.In("list_id", listIDs).Delete(&Task{})
if err != nil { if err != nil {
return return
} }
// Delete the lists // Delete the lists
_, err = x.In("id", listIDs).Delete(&List{}) _, err = s.In("id", listIDs).Delete(&List{})
if err != nil { if err != nil {
return return
} }
@ -560,14 +556,14 @@ func (n *Namespace) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespace/{id} [post] // @Router /namespace/{id} [post]
func (n *Namespace) Update() (err error) { func (n *Namespace) Update(s *xorm.Session) (err error) {
// Check if we have at least a name // Check if we have at least a name
if n.Title == "" { if n.Title == "" {
return ErrNamespaceNameCannotBeEmpty{NamespaceID: n.ID} return ErrNamespaceNameCannotBeEmpty{NamespaceID: n.ID}
} }
// Check if the namespace exists // Check if the namespace exists
currentNamespace, err := GetNamespaceByID(n.ID) currentNamespace, err := GetNamespaceByID(s, n.ID)
if err != nil { if err != nil {
return return
} }
@ -581,7 +577,7 @@ func (n *Namespace) Update() (err error) {
if n.Owner != nil { if n.Owner != nil {
n.OwnerID = n.Owner.ID n.OwnerID = n.Owner.ID
if currentNamespace.OwnerID != n.OwnerID { if currentNamespace.OwnerID != n.OwnerID {
n.Owner, err = user.GetUserByID(n.OwnerID) n.Owner, err = user.GetUserByID(s, n.OwnerID)
if err != nil { if err != nil {
return return
} }
@ -599,7 +595,7 @@ func (n *Namespace) Update() (err error) {
} }
// Do the actual update // Do the actual update
_, err = x. _, err = s.
ID(currentNamespace.ID). ID(currentNamespace.ID).
Cols(colsToUpdate...). Cols(colsToUpdate...).
Update(n) Update(n)

View file

@ -19,37 +19,38 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm"
) )
// CanWrite checks if a user has write access to a namespace // CanWrite checks if a user has write access to a namespace
func (n *Namespace) CanWrite(a web.Auth) (bool, error) { func (n *Namespace) CanWrite(s *xorm.Session, a web.Auth) (bool, error) {
can, _, err := n.checkRight(a, RightWrite, RightAdmin) can, _, err := n.checkRight(s, a, RightWrite, RightAdmin)
return can, err return can, err
} }
// IsAdmin returns true or false if the user is admin on that namespace or not // IsAdmin returns true or false if the user is admin on that namespace or not
func (n *Namespace) IsAdmin(a web.Auth) (bool, error) { func (n *Namespace) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
is, _, err := n.checkRight(a, RightAdmin) is, _, err := n.checkRight(s, a, RightAdmin)
return is, err return is, err
} }
// CanRead checks if a user has read access to that namespace // CanRead checks if a user has read access to that namespace
func (n *Namespace) CanRead(a web.Auth) (bool, int, error) { func (n *Namespace) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
return n.checkRight(a, RightRead, RightWrite, RightAdmin) return n.checkRight(s, a, RightRead, RightWrite, RightAdmin)
} }
// CanUpdate checks if the user can update the namespace // CanUpdate checks if the user can update the namespace
func (n *Namespace) CanUpdate(a web.Auth) (bool, error) { func (n *Namespace) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return n.IsAdmin(a) return n.IsAdmin(s, a)
} }
// CanDelete checks if the user can delete a namespace // CanDelete checks if the user can delete a namespace
func (n *Namespace) CanDelete(a web.Auth) (bool, error) { func (n *Namespace) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return n.IsAdmin(a) return n.IsAdmin(s, a)
} }
// CanCreate checks if the user can create a new namespace // CanCreate checks if the user can create a new namespace
func (n *Namespace) CanCreate(a web.Auth) (bool, error) { func (n *Namespace) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
@ -58,7 +59,7 @@ func (n *Namespace) CanCreate(a web.Auth) (bool, error) {
return true, nil return true, nil
} }
func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) { func (n *Namespace) checkRight(s *xorm.Session, a web.Auth, rights ...Right) (bool, int, error) {
// If the auth is a link share, don't do anything // If the auth is a link share, don't do anything
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
@ -66,13 +67,12 @@ func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
} }
// Get the namespace and check the right // Get the namespace and check the right
nn := &Namespace{ID: n.ID} nn, err := getNamespaceSimpleByID(s, n.ID)
err := nn.GetSimpleByID()
if err != nil { if err != nil {
return false, 0, err return false, 0, err
} }
if a.GetID() == n.OwnerID { if a.GetID() == nn.OwnerID {
return true, int(RightAdmin), nil return true, int(RightAdmin), nil
} }
@ -113,7 +113,8 @@ func (n *Namespace) checkRight(a web.Auth, rights ...Right) (bool, int, error) {
var maxRights = 0 var maxRights = 0
r := &allRights{} r := &allRights{}
exists, err := x.Select("*"). exists, err := s.
Select("*").
Table("namespaces"). Table("namespaces").
// User stuff // User stuff
Join("LEFT", "users_namespace", "users_namespace.namespace_id = namespaces.id"). Join("LEFT", "users_namespace", "users_namespace.namespace_id = namespaces.id").

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// TeamNamespace defines the relationship between a Team and a Namespace // TeamNamespace defines the relationship between a Team and a Namespace
@ -62,7 +63,7 @@ func (TeamNamespace) TableName() string {
// @Failure 403 {object} web.HTTPError "The team does not have access to the namespace" // @Failure 403 {object} web.HTTPError "The team does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/teams [put] // @Router /namespaces/{id}/teams [put]
func (tn *TeamNamespace) Create(a web.Auth) (err error) { func (tn *TeamNamespace) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the rights are valid // Check if the rights are valid
if err = tn.Right.isValid(); err != nil { if err = tn.Right.isValid(); err != nil {
@ -70,19 +71,20 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) {
} }
// Check if the team exists // Check if the team exists
_, err = GetTeamByID(tn.TeamID) _, err = GetTeamByID(s, tn.TeamID)
if err != nil { if err != nil {
return return
} }
// Check if the namespace exists // Check if the namespace exists
_, err = GetNamespaceByID(tn.NamespaceID) _, err = GetNamespaceByID(s, tn.NamespaceID)
if err != nil { if err != nil {
return return
} }
// Check if the team already has access to the namespace // Check if the team already has access to the namespace
exists, err := x.Where("team_id = ?", tn.TeamID). exists, err := s.
Where("team_id = ?", tn.TeamID).
And("namespace_id = ?", tn.NamespaceID). And("namespace_id = ?", tn.NamespaceID).
Get(&TeamNamespace{}) Get(&TeamNamespace{})
if err != nil { if err != nil {
@ -93,7 +95,7 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) {
} }
// Insert the new team // Insert the new team
_, err = x.Insert(tn) _, err = s.Insert(tn)
return return
} }
@ -110,16 +112,17 @@ func (tn *TeamNamespace) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "team or namespace does not exist." // @Failure 404 {object} web.HTTPError "team or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/teams/{teamID} [delete] // @Router /namespaces/{namespaceID}/teams/{teamID} [delete]
func (tn *TeamNamespace) Delete() (err error) { func (tn *TeamNamespace) Delete(s *xorm.Session) (err error) {
// Check if the team exists // Check if the team exists
_, err = GetTeamByID(tn.TeamID) _, err = GetTeamByID(s, tn.TeamID)
if err != nil { if err != nil {
return return
} }
// Check if the team has access to the namespace // Check if the team has access to the namespace
has, err := x.Where("team_id = ? AND namespace_id = ?", tn.TeamID, tn.NamespaceID). has, err := s.
Where("team_id = ? AND namespace_id = ?", tn.TeamID, tn.NamespaceID).
Get(&TeamNamespace{}) Get(&TeamNamespace{})
if err != nil { if err != nil {
return return
@ -129,7 +132,8 @@ func (tn *TeamNamespace) Delete() (err error) {
} }
// Delete the relation // Delete the relation
_, err = x.Where("team_id = ?", tn.TeamID). _, err = s.
Where("team_id = ?", tn.TeamID).
And("namespace_id = ?", tn.NamespaceID). And("namespace_id = ?", tn.NamespaceID).
Delete(TeamNamespace{}) Delete(TeamNamespace{})
@ -151,10 +155,10 @@ func (tn *TeamNamespace) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the namespace." // @Failure 403 {object} web.HTTPError "No right to see the namespace."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/teams [get] // @Router /namespaces/{id}/teams [get]
func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (tn *TeamNamespace) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user can read the namespace // Check if the user can read the namespace
n := Namespace{ID: tn.NamespaceID} n := Namespace{ID: tn.NamespaceID}
canRead, _, err := n.CanRead(a) canRead, _, err := n.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -167,7 +171,8 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
query := x.Table("teams"). query := s.
Table("teams").
Join("INNER", "team_namespaces", "team_id = teams.id"). Join("INNER", "team_namespaces", "team_id = teams.id").
Where("team_namespaces.namespace_id = ?", tn.NamespaceID). Where("team_namespaces.namespace_id = ?", tn.NamespaceID).
Where("teams.name LIKE ?", "%"+search+"%") Where("teams.name LIKE ?", "%"+search+"%")
@ -184,12 +189,13 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in
teams = append(teams, &t.Team) teams = append(teams, &t.Team)
} }
err = addMoreInfoToTeams(teams) err = addMoreInfoToTeams(s, teams)
if err != nil { if err != nil {
return return
} }
numberOfTotalItems, err = x.Table("teams"). numberOfTotalItems, err = s.
Table("teams").
Join("INNER", "team_namespaces", "team_id = teams.id"). Join("INNER", "team_namespaces", "team_id = teams.id").
Where("team_namespaces.namespace_id = ?", tn.NamespaceID). Where("team_namespaces.namespace_id = ?", tn.NamespaceID).
Where("teams.name LIKE ?", "%"+search+"%"). Where("teams.name LIKE ?", "%"+search+"%").
@ -213,14 +219,14 @@ func (tn *TeamNamespace) ReadAll(a web.Auth, search string, page int, perPage in
// @Failure 404 {object} web.HTTPError "Team or namespace does not exist." // @Failure 404 {object} web.HTTPError "Team or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/teams/{teamID} [post] // @Router /namespaces/{namespaceID}/teams/{teamID} [post]
func (tn *TeamNamespace) Update() (err error) { func (tn *TeamNamespace) Update(s *xorm.Session) (err error) {
// Check if the right is valid // Check if the right is valid
if err := tn.Right.isValid(); err != nil { if err := tn.Right.isValid(); err != nil {
return err return err
} }
_, err = x. _, err = s.
Where("namespace_id = ? AND team_id = ?", tn.NamespaceID, tn.TeamID). Where("namespace_id = ? AND team_id = ?", tn.NamespaceID, tn.TeamID).
Cols("right"). Cols("right").
Update(tn) Update(tn)

View file

@ -18,22 +18,23 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if one can create a new team <-> namespace relation // CanCreate checks if one can create a new team <-> namespace relation
func (tn *TeamNamespace) CanCreate(a web.Auth) (bool, error) { func (tn *TeamNamespace) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: tn.NamespaceID} n := &Namespace{ID: tn.NamespaceID}
return n.IsAdmin(a) return n.IsAdmin(s, a)
} }
// CanDelete checks if a user can remove a team from a namespace. Only namespace admins can do that. // CanDelete checks if a user can remove a team from a namespace. Only namespace admins can do that.
func (tn *TeamNamespace) CanDelete(a web.Auth) (bool, error) { func (tn *TeamNamespace) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: tn.NamespaceID} n := &Namespace{ID: tn.NamespaceID}
return n.IsAdmin(a) return n.IsAdmin(s, a)
} }
// CanUpdate checks if a user can update a team from a Only namespace admins can do that. // CanUpdate checks if a user can update a team from a Only namespace admins can do that.
func (tn *TeamNamespace) CanUpdate(a web.Auth) (bool, error) { func (tn *TeamNamespace) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: tn.NamespaceID} n := &Namespace{ID: tn.NamespaceID}
return n.IsAdmin(a) return n.IsAdmin(s, a)
} }

View file

@ -80,6 +80,7 @@ func TestTeamNamespace_CanDoSomething(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tn := &TeamNamespace{ tn := &TeamNamespace{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -91,15 +92,16 @@ func TestTeamNamespace_CanDoSomething(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
if got, _ := tn.CanCreate(tt.args.a); got != tt.want["CanCreate"] { if got, _ := tn.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
t.Errorf("TeamNamespace.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) t.Errorf("TeamNamespace.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
} }
if got, _ := tn.CanDelete(tt.args.a); got != tt.want["CanDelete"] { if got, _ := tn.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("TeamNamespace.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) t.Errorf("TeamNamespace.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
} }
if got, _ := tn.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { if got, _ := tn.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("TeamNamespace.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) t.Errorf("TeamNamespace.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
} }
_ = s.Close()
}) })
} }
} }

View file

@ -36,29 +36,35 @@ func TestTeamNamespace_ReadAll(t *testing.T) {
NamespaceID: 3, NamespaceID: 3,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
teams, _, _, err := tn.ReadAll(u, "", 1, 50) s := db.NewSession()
teams, _, _, err := tn.ReadAll(s, u, "", 1, 50)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice) assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice)
s := reflect.ValueOf(teams) ts := reflect.ValueOf(teams)
assert.Equal(t, s.Len(), 2) assert.Equal(t, ts.Len(), 2)
_ = s.Close()
}) })
t.Run("nonexistant namespace", func(t *testing.T) { t.Run("nonexistant namespace", func(t *testing.T) {
tn := TeamNamespace{ tn := TeamNamespace{
NamespaceID: 9999, NamespaceID: 9999,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
_, _, _, err := tn.ReadAll(u, "", 1, 50) s := db.NewSession()
_, _, _, err := tn.ReadAll(s, u, "", 1, 50)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err)) assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
}) })
t.Run("no right for namespace", func(t *testing.T) { t.Run("no right for namespace", func(t *testing.T) {
tn := TeamNamespace{ tn := TeamNamespace{
NamespaceID: 17, NamespaceID: 17,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
_, _, _, err := tn.ReadAll(u, "", 1, 50) s := db.NewSession()
_, _, _, err := tn.ReadAll(s, u, "", 1, 50)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNeedToHaveNamespaceReadAccess(err)) assert.True(t, IsErrNeedToHaveNamespaceReadAccess(err))
_ = s.Close()
}) })
} }
@ -72,10 +78,15 @@ func TestTeamNamespace_Create(t *testing.T) {
Right: RightAdmin, Right: RightAdmin,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
allowed, _ := tn.CanCreate(u) s := db.NewSession()
allowed, _ := tn.CanCreate(s, u)
assert.True(t, allowed) assert.True(t, allowed)
err := tn.Create(u) err := tn.Create(s, u)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_namespaces", map[string]interface{}{ db.AssertExists(t, "team_namespaces", map[string]interface{}{
"team_id": 1, "team_id": 1,
"namespace_id": 1, "namespace_id": 1,
@ -89,9 +100,11 @@ func TestTeamNamespace_Create(t *testing.T) {
Right: RightRead, Right: RightRead,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := tn.Create(u) s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamAlreadyHasAccess(err)) assert.True(t, IsErrTeamAlreadyHasAccess(err))
_ = s.Close()
}) })
t.Run("invalid team right", func(t *testing.T) { t.Run("invalid team right", func(t *testing.T) {
tn := TeamNamespace{ tn := TeamNamespace{
@ -100,9 +113,11 @@ func TestTeamNamespace_Create(t *testing.T) {
Right: RightUnknown, Right: RightUnknown,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := tn.Create(u) s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrInvalidRight(err)) assert.True(t, IsErrInvalidRight(err))
_ = s.Close()
}) })
t.Run("nonexistant team", func(t *testing.T) { t.Run("nonexistant team", func(t *testing.T) {
tn := TeamNamespace{ tn := TeamNamespace{
@ -110,9 +125,11 @@ func TestTeamNamespace_Create(t *testing.T) {
NamespaceID: 1, NamespaceID: 1,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := tn.Create(u) s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
}) })
t.Run("nonexistant namespace", func(t *testing.T) { t.Run("nonexistant namespace", func(t *testing.T) {
tn := TeamNamespace{ tn := TeamNamespace{
@ -120,9 +137,11 @@ func TestTeamNamespace_Create(t *testing.T) {
NamespaceID: 9999, NamespaceID: 9999,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := tn.Create(u) s := db.NewSession()
err := tn.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err)) assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
}) })
} }
@ -135,10 +154,14 @@ func TestTeamNamespace_Delete(t *testing.T) {
NamespaceID: 9, NamespaceID: 9,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
allowed, _ := tn.CanDelete(u) s := db.NewSession()
allowed, _ := tn.CanDelete(s, u)
assert.True(t, allowed) assert.True(t, allowed)
err := tn.Delete() err := tn.Delete(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "team_namespaces", map[string]interface{}{ db.AssertMissing(t, "team_namespaces", map[string]interface{}{
"team_id": 7, "team_id": 7,
"namespace_id": 9, "namespace_id": 9,
@ -150,9 +173,11 @@ func TestTeamNamespace_Delete(t *testing.T) {
NamespaceID: 3, NamespaceID: 3,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := tn.Delete() s := db.NewSession()
err := tn.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
_ = s.Close()
}) })
t.Run("nonexistant namespace", func(t *testing.T) { t.Run("nonexistant namespace", func(t *testing.T) {
tn := TeamNamespace{ tn := TeamNamespace{
@ -160,9 +185,11 @@ func TestTeamNamespace_Delete(t *testing.T) {
NamespaceID: 9999, NamespaceID: 9999,
} }
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := tn.Delete() s := db.NewSession()
err := tn.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotHaveAccessToNamespace(err)) assert.True(t, IsErrTeamDoesNotHaveAccessToNamespace(err))
_ = s.Close()
}) })
} }
@ -221,6 +248,7 @@ func TestTeamNamespace_Update(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
tl := &TeamNamespace{ tl := &TeamNamespace{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -232,13 +260,17 @@ func TestTeamNamespace_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := tl.Update() err := tl.Update(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("TeamNamespace.Update() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("TeamNamespace.Update() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("TeamNamespace.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("TeamNamespace.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertExists(t, "team_namespaces", map[string]interface{}{ db.AssertExists(t, "team_namespaces", map[string]interface{}{
"team_id": tt.fields.TeamID, "team_id": tt.fields.TeamID,

View file

@ -36,8 +36,12 @@ func TestNamespace_Create(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := dummynamespace.Create(user1) s := db.NewSession()
err := dummynamespace.Create(s, user1)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "namespaces", map[string]interface{}{ db.AssertExists(t, "namespaces", map[string]interface{}{
"title": "Test", "title": "Test",
"description": "Lorem Ipsum", "description": "Lorem Ipsum",
@ -45,18 +49,22 @@ func TestNamespace_Create(t *testing.T) {
}) })
t.Run("no title", func(t *testing.T) { t.Run("no title", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
n2 := Namespace{} n2 := Namespace{}
err := n2.Create(user1) err := n2.Create(s, user1)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceNameCannotBeEmpty(err)) assert.True(t, IsErrNamespaceNameCannotBeEmpty(err))
_ = s.Close()
}) })
t.Run("nonexistant user", func(t *testing.T) { t.Run("nonexistant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
nUser := &user.User{ID: 9482385} nUser := &user.User{ID: 9482385}
dnsp2 := dummynamespace dnsp2 := dummynamespace
err := dnsp2.Create(nUser) err := dnsp2.Create(s, nUser)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err)) assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
}) })
} }
@ -64,28 +72,36 @@ func TestNamespace_ReadOne(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
n := &Namespace{ID: 1} n := &Namespace{ID: 1}
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := n.ReadOne() s := db.NewSession()
err := n.ReadOne(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, n.Title, "testnamespace") assert.Equal(t, n.Title, "testnamespace")
_ = s.Close()
}) })
t.Run("nonexistant", func(t *testing.T) { t.Run("nonexistant", func(t *testing.T) {
n := &Namespace{ID: 99999} n := &Namespace{ID: 99999}
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
err := n.ReadOne() s := db.NewSession()
err := n.ReadOne(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err)) assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
}) })
} }
func TestNamespace_Update(t *testing.T) { func TestNamespace_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{ n := &Namespace{
ID: 1, ID: 1,
Title: "Lorem Ipsum", Title: "Lorem Ipsum",
} }
err := n.Update() err := n.Update(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "namespaces", map[string]interface{}{ db.AssertExists(t, "namespaces", map[string]interface{}{
"id": 1, "id": 1,
"title": "Lorem Ipsum", "title": "Lorem Ipsum",
@ -93,56 +109,68 @@ func TestNamespace_Update(t *testing.T) {
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{ n := &Namespace{
ID: 99999, ID: 99999,
Title: "Lorem Ipsum", Title: "Lorem Ipsum",
} }
err := n.Update() err := n.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err)) assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
}) })
t.Run("nonexisting owner", func(t *testing.T) { t.Run("nonexisting owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{ n := &Namespace{
ID: 1, ID: 1,
Title: "Lorem Ipsum", Title: "Lorem Ipsum",
Owner: &user.User{ID: 99999}, Owner: &user.User{ID: 99999},
} }
err := n.Update() err := n.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err)) assert.True(t, user.IsErrUserDoesNotExist(err))
_ = s.Close()
}) })
t.Run("no title", func(t *testing.T) { t.Run("no title", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{ n := &Namespace{
ID: 1, ID: 1,
} }
err := n.Update() err := n.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceNameCannotBeEmpty(err)) assert.True(t, IsErrNamespaceNameCannotBeEmpty(err))
_ = s.Close()
}) })
} }
func TestNamespace_Delete(t *testing.T) { func TestNamespace_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{ n := &Namespace{
ID: 1, ID: 1,
} }
err := n.Delete() err := n.Delete(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "namespaces", map[string]interface{}{ db.AssertMissing(t, "namespaces", map[string]interface{}{
"id": 1, "id": 1,
}) })
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
n := &Namespace{ n := &Namespace{
ID: 9999, ID: 9999,
} }
err := n.Delete() err := n.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrNamespaceDoesNotExist(err)) assert.True(t, IsErrNamespaceDoesNotExist(err))
_ = s.Close()
}) })
} }
@ -152,9 +180,12 @@ func TestNamespace_ReadAll(t *testing.T) {
user11 := &user.User{ID: 11} user11 := &user.User{ID: 11}
user12 := &user.User{ID: 12} user12 := &user.User{ID: 12}
s := db.NewSession()
defer s.Close()
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
n := &Namespace{} n := &Namespace{}
nn, _, _, err := n.ReadAll(user1, "", 1, -1) nn, _, _, err := n.ReadAll(s, user1, "", 1, -1)
assert.NoError(t, err) assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces) assert.NotNil(t, namespaces)
@ -174,7 +205,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{ n := &Namespace{
NamespacesOnly: true, NamespacesOnly: true,
} }
nn, _, _, err := n.ReadAll(user1, "", 1, -1) nn, _, _, err := n.ReadAll(s, user1, "", 1, -1)
assert.NoError(t, err) assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces) assert.NotNil(t, namespaces)
@ -188,7 +219,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{ n := &Namespace{
NamespacesOnly: true, NamespacesOnly: true,
} }
nn, _, _, err := n.ReadAll(user7, "13,14", 1, -1) nn, _, _, err := n.ReadAll(s, user7, "13,14", 1, -1)
assert.NoError(t, err) assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces) assert.NotNil(t, namespaces)
@ -200,7 +231,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{ n := &Namespace{
NamespacesOnly: true, NamespacesOnly: true,
} }
nn, _, _, err := n.ReadAll(user1, "1,w", 1, -1) nn, _, _, err := n.ReadAll(s, user1, "1,w", 1, -1)
assert.NoError(t, err) assert.NoError(t, err)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NotNil(t, namespaces) assert.NotNil(t, namespaces)
@ -211,7 +242,7 @@ func TestNamespace_ReadAll(t *testing.T) {
n := &Namespace{ n := &Namespace{
IsArchived: true, IsArchived: true,
} }
nn, _, _, err := n.ReadAll(user1, "", 1, -1) nn, _, _, err := n.ReadAll(s, user1, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, namespaces) assert.NotNil(t, namespaces)
@ -222,7 +253,7 @@ func TestNamespace_ReadAll(t *testing.T) {
}) })
t.Run("no favorites", func(t *testing.T) { t.Run("no favorites", func(t *testing.T) {
n := &Namespace{} n := &Namespace{}
nn, _, _, err := n.ReadAll(user11, "", 1, -1) nn, _, _, err := n.ReadAll(s, user11, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err) assert.NoError(t, err)
// Assert the first namespace is not the favorites namespace // Assert the first namespace is not the favorites namespace
@ -230,7 +261,7 @@ func TestNamespace_ReadAll(t *testing.T) {
}) })
t.Run("no favorite tasks but namespace", func(t *testing.T) { t.Run("no favorite tasks but namespace", func(t *testing.T) {
n := &Namespace{} n := &Namespace{}
nn, _, _, err := n.ReadAll(user12, "", 1, -1) nn, _, _, err := n.ReadAll(s, user12, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err) assert.NoError(t, err)
// Assert the first namespace is the favorites namespace and contains lists // Assert the first namespace is the favorites namespace and contains lists
@ -239,7 +270,7 @@ func TestNamespace_ReadAll(t *testing.T) {
}) })
t.Run("no saved filters", func(t *testing.T) { t.Run("no saved filters", func(t *testing.T) {
n := &Namespace{} n := &Namespace{}
nn, _, _, err := n.ReadAll(user11, "", 1, -1) nn, _, _, err := n.ReadAll(s, user11, "", 1, -1)
namespaces := nn.([]*NamespaceWithLists) namespaces := nn.([]*NamespaceWithLists)
assert.NoError(t, err) assert.NoError(t, err)
// Assert the first namespace is not the favorites namespace // Assert the first namespace is not the favorites namespace

View file

@ -21,6 +21,7 @@ import (
user2 "code.vikunja.io/api/pkg/user" user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// NamespaceUser represents a namespace <-> user relation // NamespaceUser represents a namespace <-> user relation
@ -64,7 +65,7 @@ func (NamespaceUser) TableName() string {
// @Failure 403 {object} web.HTTPError "The user does not have access to the namespace" // @Failure 403 {object} web.HTTPError "The user does not have access to the namespace"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/users [put] // @Router /namespaces/{id}/users [put]
func (nu *NamespaceUser) Create(a web.Auth) (err error) { func (nu *NamespaceUser) Create(s *xorm.Session, a web.Auth) (err error) {
// Reset the id // Reset the id
nu.ID = 0 nu.ID = 0
@ -74,13 +75,13 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
} }
// Check if the namespace exists // Check if the namespace exists
l, err := GetNamespaceByID(nu.NamespaceID) l, err := GetNamespaceByID(s, nu.NamespaceID)
if err != nil { if err != nil {
return return
} }
// Check if the user exists // Check if the user exists
user, err := user2.GetUserByUsername(nu.Username) user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil { if err != nil {
return err return err
} }
@ -92,7 +93,9 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
return ErrUserAlreadyHasNamespaceAccess{UserID: nu.UserID, NamespaceID: nu.NamespaceID} return ErrUserAlreadyHasNamespaceAccess{UserID: nu.UserID, NamespaceID: nu.NamespaceID}
} }
exist, err := x.Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID).Get(&NamespaceUser{}) exist, err := s.
Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID).
Get(&NamespaceUser{})
if err != nil { if err != nil {
return return
} }
@ -101,7 +104,7 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
} }
// Insert user <-> namespace relation // Insert user <-> namespace relation
_, err = x.Insert(nu) _, err = s.Insert(nu)
return return
} }
@ -119,17 +122,18 @@ func (nu *NamespaceUser) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "user or namespace does not exist." // @Failure 404 {object} web.HTTPError "user or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/users/{userID} [delete] // @Router /namespaces/{namespaceID}/users/{userID} [delete]
func (nu *NamespaceUser) Delete() (err error) { func (nu *NamespaceUser) Delete(s *xorm.Session) (err error) {
// Check if the user exists // Check if the user exists
user, err := user2.GetUserByUsername(nu.Username) user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil { if err != nil {
return return
} }
nu.UserID = user.ID nu.UserID = user.ID
// Check if the user has access to the namespace // Check if the user has access to the namespace
has, err := x.Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID). has, err := s.
Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID).
Get(&NamespaceUser{}) Get(&NamespaceUser{})
if err != nil { if err != nil {
return return
@ -138,7 +142,8 @@ func (nu *NamespaceUser) Delete() (err error) {
return ErrUserDoesNotHaveAccessToNamespace{NamespaceID: nu.NamespaceID, UserID: nu.UserID} return ErrUserDoesNotHaveAccessToNamespace{NamespaceID: nu.NamespaceID, UserID: nu.UserID}
} }
_, err = x.Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID). _, err = s.
Where("user_id = ? AND namespace_id = ?", nu.UserID, nu.NamespaceID).
Delete(&NamespaceUser{}) Delete(&NamespaceUser{})
return return
} }
@ -158,10 +163,10 @@ func (nu *NamespaceUser) Delete() (err error) {
// @Failure 403 {object} web.HTTPError "No right to see the namespace." // @Failure 403 {object} web.HTTPError "No right to see the namespace."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/users [get] // @Router /namespaces/{id}/users [get]
func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (nu *NamespaceUser) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has access to the namespace // Check if the user has access to the namespace
l := Namespace{ID: nu.NamespaceID} l := Namespace{ID: nu.NamespaceID}
canRead, _, err := l.CanRead(a) canRead, _, err := l.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -174,7 +179,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
query := x. query := s.
Join("INNER", "users_namespace", "user_id = users.id"). Join("INNER", "users_namespace", "user_id = users.id").
Where("users_namespace.namespace_id = ?", nu.NamespaceID). Where("users_namespace.namespace_id = ?", nu.NamespaceID).
Where("users.username LIKE ?", "%"+search+"%") Where("users.username LIKE ?", "%"+search+"%")
@ -191,7 +196,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in
u.Email = "" u.Email = ""
} }
numberOfTotalItems, err = x. numberOfTotalItems, err = s.
Join("INNER", "users_namespace", "user_id = users.id"). Join("INNER", "users_namespace", "user_id = users.id").
Where("users_namespace.namespace_id = ?", nu.NamespaceID). Where("users_namespace.namespace_id = ?", nu.NamespaceID).
Where("users.username LIKE ?", "%"+search+"%"). Where("users.username LIKE ?", "%"+search+"%").
@ -215,7 +220,7 @@ func (nu *NamespaceUser) ReadAll(a web.Auth, search string, page int, perPage in
// @Failure 404 {object} web.HTTPError "User or namespace does not exist." // @Failure 404 {object} web.HTTPError "User or namespace does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{namespaceID}/users/{userID} [post] // @Router /namespaces/{namespaceID}/users/{userID} [post]
func (nu *NamespaceUser) Update() (err error) { func (nu *NamespaceUser) Update(s *xorm.Session) (err error) {
// Check if the right is valid // Check if the right is valid
if err := nu.Right.isValid(); err != nil { if err := nu.Right.isValid(); err != nil {
@ -223,13 +228,13 @@ func (nu *NamespaceUser) Update() (err error) {
} }
// Check if the user exists // Check if the user exists
user, err := user2.GetUserByUsername(nu.Username) user, err := user2.GetUserByUsername(s, nu.Username)
if err != nil { if err != nil {
return err return err
} }
nu.UserID = user.ID nu.UserID = user.ID
_, err = x. _, err = s.
Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID). Where("namespace_id = ? AND user_id = ?", nu.NamespaceID, nu.UserID).
Cols("right"). Cols("right").
Update(nu) Update(nu)

View file

@ -18,24 +18,25 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if the user can create a new user <-> namespace relation // CanCreate checks if the user can create a new user <-> namespace relation
func (nu *NamespaceUser) CanCreate(a web.Auth) (bool, error) { func (nu *NamespaceUser) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(a) return nu.canDoNamespaceUser(s, a)
} }
// CanDelete checks if the user can delete a user <-> namespace relation // CanDelete checks if the user can delete a user <-> namespace relation
func (nu *NamespaceUser) CanDelete(a web.Auth) (bool, error) { func (nu *NamespaceUser) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(a) return nu.canDoNamespaceUser(s, a)
} }
// CanUpdate checks if the user can update a user <-> namespace relation // CanUpdate checks if the user can update a user <-> namespace relation
func (nu *NamespaceUser) CanUpdate(a web.Auth) (bool, error) { func (nu *NamespaceUser) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return nu.canDoNamespaceUser(a) return nu.canDoNamespaceUser(s, a)
} }
func (nu *NamespaceUser) canDoNamespaceUser(a web.Auth) (bool, error) { func (nu *NamespaceUser) canDoNamespaceUser(s *xorm.Session, a web.Auth) (bool, error) {
n := &Namespace{ID: nu.NamespaceID} n := &Namespace{ID: nu.NamespaceID}
return n.IsAdmin(a) return n.IsAdmin(s, a)
} }

View file

@ -80,6 +80,8 @@ func TestNamespaceUser_CanDoSomething(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
nu := &NamespaceUser{ nu := &NamespaceUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -91,13 +93,13 @@ func TestNamespaceUser_CanDoSomething(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
if got, _ := nu.CanCreate(tt.args.a); got != tt.want["CanCreate"] { if got, _ := nu.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] {
t.Errorf("NamespaceUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) t.Errorf("NamespaceUser.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
} }
if got, _ := nu.CanDelete(tt.args.a); got != tt.want["CanDelete"] { if got, _ := nu.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("NamespaceUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) t.Errorf("NamespaceUser.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
} }
if got, _ := nu.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { if got, _ := nu.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("NamespaceUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) t.Errorf("NamespaceUser.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
} }
}) })

View file

@ -25,6 +25,7 @@ import (
"code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"github.com/stretchr/testify/assert"
"gopkg.in/d4l3k/messagediff.v1" "gopkg.in/d4l3k/messagediff.v1"
) )
@ -108,6 +109,7 @@ func TestNamespaceUser_Create(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
un := &NamespaceUser{ un := &NamespaceUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -119,13 +121,16 @@ func TestNamespaceUser_Create(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := un.Create(tt.args.a) err := un.Create(s, tt.args.a)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.Create() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NamespaceUser.Create() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("NamespaceUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("NamespaceUser.Create() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertExists(t, "users_namespace", map[string]interface{}{ db.AssertExists(t, "users_namespace", map[string]interface{}{
"user_id": tt.fields.UserID, "user_id": tt.fields.UserID,
@ -211,6 +216,8 @@ func TestNamespaceUser_ReadAll(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
un := &NamespaceUser{ un := &NamespaceUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -222,7 +229,7 @@ func TestNamespaceUser_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
got, _, _, err := un.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50) got, _, _, err := un.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NamespaceUser.ReadAll() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -296,6 +303,7 @@ func TestNamespaceUser_Update(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
nu := &NamespaceUser{ nu := &NamespaceUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -307,13 +315,16 @@ func TestNamespaceUser_Update(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := nu.Update() err := nu.Update(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.Update() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NamespaceUser.Update() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("NamespaceUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("NamespaceUser.Update() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertExists(t, "users_namespace", map[string]interface{}{ db.AssertExists(t, "users_namespace", map[string]interface{}{
"user_id": tt.fields.UserID, "user_id": tt.fields.UserID,
@ -373,6 +384,7 @@ func TestNamespaceUser_Delete(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
nu := &NamespaceUser{ nu := &NamespaceUser{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -384,13 +396,16 @@ func TestNamespaceUser_Delete(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
err := nu.Delete() err := nu.Delete(s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NamespaceUser.Delete() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NamespaceUser.Delete() error = %v, wantErr %v", err, tt.wantErr)
} }
if (err != nil) && tt.wantErr && !tt.errType(err) { if (err != nil) && tt.wantErr && !tt.errType(err) {
t.Errorf("NamespaceUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name()) t.Errorf("NamespaceUser.Delete() Wrong error type! Error = %v, want = %v", err, runtime.FuncForPC(reflect.ValueOf(tt.errType).Pointer()).Name())
} }
err = s.Commit()
assert.NoError(t, err)
if !tt.wantErr { if !tt.wantErr {
db.AssertMissing(t, "users_namespace", map[string]interface{}{ db.AssertMissing(t, "users_namespace", map[string]interface{}{
"user_id": tt.fields.UserID, "user_id": tt.fields.UserID,

View file

@ -21,6 +21,7 @@ import (
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// SavedFilter represents a saved bunch of filters // SavedFilter represents a saved bunch of filters
@ -48,14 +49,14 @@ type SavedFilter struct {
} }
// TableName returns a better table name for saved filters // TableName returns a better table name for saved filters
func (s *SavedFilter) TableName() string { func (sf *SavedFilter) TableName() string {
return "saved_filters" return "saved_filters"
} }
func (s *SavedFilter) getTaskCollection() *TaskCollection { func (sf *SavedFilter) getTaskCollection() *TaskCollection {
// We're resetting the listID to return tasks from all lists // We're resetting the listID to return tasks from all lists
s.Filters.ListID = 0 sf.Filters.ListID = 0
return s.Filters return sf.Filters
} }
// Returns the saved filter ID from a list ID. Will not check if the filter actually exists. // Returns the saved filter ID from a list ID. Will not check if the filter actually exists.
@ -79,13 +80,13 @@ func getListIDFromSavedFilterID(filterID int64) (listID int64) {
return return
} }
func getSavedFiltersForUser(auth web.Auth) (filters []*SavedFilter, err error) { func getSavedFiltersForUser(s *xorm.Session, auth web.Auth) (filters []*SavedFilter, err error) {
// Link shares can't view or modify saved filters, therefore we can error out right away // Link shares can't view or modify saved filters, therefore we can error out right away
if _, is := auth.(*LinkSharing); is { if _, is := auth.(*LinkSharing); is {
return nil, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID()} return nil, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID()}
} }
err = x.Where("owner_id = ?", auth.GetID()).Find(&filters) err = s.Where("owner_id = ?", auth.GetID()).Find(&filters)
return return
} }
@ -100,17 +101,17 @@ func getSavedFiltersForUser(auth web.Auth) (filters []*SavedFilter, err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter." // @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /filters [put] // @Router /filters [put]
func (s *SavedFilter) Create(auth web.Auth) error { func (sf *SavedFilter) Create(s *xorm.Session, auth web.Auth) error {
s.OwnerID = auth.GetID() sf.OwnerID = auth.GetID()
_, err := x.Insert(s) _, err := s.Insert(sf)
return err return err
} }
func getSavedFilterSimpleByID(id int64) (s *SavedFilter, err error) { func getSavedFilterSimpleByID(s *xorm.Session, id int64) (sf *SavedFilter, err error) {
s = &SavedFilter{} sf = &SavedFilter{}
exists, err := x. exists, err := s.
Where("id = ?", id). Where("id = ?", id).
Get(s) Get(sf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -132,10 +133,10 @@ func getSavedFilterSimpleByID(id int64) (s *SavedFilter, err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter." // @Failure 403 {object} web.HTTPError "The user does not have access to that saved filter."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /filters/{id} [get] // @Router /filters/{id} [get]
func (s *SavedFilter) ReadOne() error { func (sf *SavedFilter) ReadOne(s *xorm.Session) error {
// s already contains almost the full saved filter from the rights check, we only need to add the user // s already contains almost the full saved filter from the rights check, we only need to add the user
u, err := user.GetUserByID(s.OwnerID) u, err := user.GetUserByID(s, sf.OwnerID)
s.Owner = u sf.Owner = u
return err return err
} }
@ -152,15 +153,15 @@ func (s *SavedFilter) ReadOne() error {
// @Failure 404 {object} web.HTTPError "The saved filter does not exist." // @Failure 404 {object} web.HTTPError "The saved filter does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /filters/{id} [post] // @Router /filters/{id} [post]
func (s *SavedFilter) Update() error { func (sf *SavedFilter) Update(s *xorm.Session) error {
_, err := x. _, err := s.
Where("id = ?", s.ID). Where("id = ?", sf.ID).
Cols( Cols(
"title", "title",
"description", "description",
"filters", "filters",
). ).
Update(s) Update(sf)
return err return err
} }
@ -177,7 +178,9 @@ func (s *SavedFilter) Update() error {
// @Failure 404 {object} web.HTTPError "The saved filter does not exist." // @Failure 404 {object} web.HTTPError "The saved filter does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /filters/{id} [delete] // @Router /filters/{id} [delete]
func (s *SavedFilter) Delete() error { func (sf *SavedFilter) Delete(s *xorm.Session) error {
_, err := x.Where("id = ?", s.ID).Delete(s) _, err := s.
Where("id = ?", sf.ID).
Delete(sf)
return err return err
} }

View file

@ -16,28 +16,31 @@
package models package models
import "code.vikunja.io/web" import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead checks if a user has the right to read a saved filter // CanRead checks if a user has the right to read a saved filter
func (s *SavedFilter) CanRead(auth web.Auth) (bool, int, error) { func (sf *SavedFilter) CanRead(s *xorm.Session, auth web.Auth) (bool, int, error) {
can, err := s.canDoFilter(auth) can, err := sf.canDoFilter(s, auth)
return can, int(RightAdmin), err return can, int(RightAdmin), err
} }
// CanDelete checks if a user has the right to delete a saved filter // CanDelete checks if a user has the right to delete a saved filter
func (s *SavedFilter) CanDelete(auth web.Auth) (bool, error) { func (sf *SavedFilter) CanDelete(s *xorm.Session, auth web.Auth) (bool, error) {
return s.canDoFilter(auth) return sf.canDoFilter(s, auth)
} }
// CanUpdate checks if a user has the right to update a saved filter // CanUpdate checks if a user has the right to update a saved filter
func (s *SavedFilter) CanUpdate(auth web.Auth) (bool, error) { func (sf *SavedFilter) CanUpdate(s *xorm.Session, auth web.Auth) (bool, error) {
// A normal check would replace the passed struct which in our case would override the values we want to update. // A normal check would replace the passed struct which in our case would override the values we want to update.
sf := &SavedFilter{ID: s.ID} sff := &SavedFilter{ID: sf.ID}
return sf.canDoFilter(auth) return sff.canDoFilter(s, auth)
} }
// CanCreate checks if a user has the right to update a saved filter // CanCreate checks if a user has the right to update a saved filter
func (s *SavedFilter) CanCreate(auth web.Auth) (bool, error) { func (sf *SavedFilter) CanCreate(s *xorm.Session, auth web.Auth) (bool, error) {
if _, is := auth.(*LinkSharing); is { if _, is := auth.(*LinkSharing); is {
return false, nil return false, nil
} }
@ -46,23 +49,23 @@ func (s *SavedFilter) CanCreate(auth web.Auth) (bool, error) {
} }
// Helper function to check saved filter rights sind they all have the same logic // Helper function to check saved filter rights sind they all have the same logic
func (s *SavedFilter) canDoFilter(auth web.Auth) (can bool, err error) { func (sf *SavedFilter) canDoFilter(s *xorm.Session, auth web.Auth) (can bool, err error) {
// Link shares can't view or modify saved filters, therefore we can error out right away // Link shares can't view or modify saved filters, therefore we can error out right away
if _, is := auth.(*LinkSharing); is { if _, is := auth.(*LinkSharing); is {
return false, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID(), SavedFilterID: s.ID} return false, ErrSavedFilterNotAvailableForLinkShare{LinkShareID: auth.GetID(), SavedFilterID: sf.ID}
} }
sf, err := getSavedFilterSimpleByID(s.ID) sff, err := getSavedFilterSimpleByID(s, sf.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
// Only owners are allowed to do something with a saved filter // Only owners are allowed to do something with a saved filter
if sf.OwnerID != auth.GetID() { if sff.OwnerID != auth.GetID() {
return false, nil return false, nil
} }
*s = *sf *sf = *sff
return true, nil return true, nil
} }

View file

@ -45,6 +45,9 @@ func TestSavedFilter_getFilterIDFromListID(t *testing.T) {
func TestSavedFilter_Create(t *testing.T) { func TestSavedFilter_Create(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
Title: "test", Title: "test",
Description: "Lorem Ipsum dolor sit amet", Description: "Lorem Ipsum dolor sit amet",
@ -52,9 +55,11 @@ func TestSavedFilter_Create(t *testing.T) {
} }
u := &user.User{ID: 1} u := &user.User{ID: 1}
err := sf.Create(u) err := sf.Create(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, u.ID, sf.OwnerID) assert.Equal(t, u.ID, sf.OwnerID)
err = s.Commit()
assert.NoError(t, err)
vals := map[string]interface{}{ vals := map[string]interface{}{
"title": "'test'", "title": "'test'",
"description": "'Lorem Ipsum dolor sit amet'", "description": "'Lorem Ipsum dolor sit amet'",
@ -62,7 +67,7 @@ func TestSavedFilter_Create(t *testing.T) {
"owner_id": 1, "owner_id": 1,
} }
// Postgres can't compare json values directly, see https://dba.stackexchange.com/a/106290/210721 // Postgres can't compare json values directly, see https://dba.stackexchange.com/a/106290/210721
if x.Dialect().URI().DBType == schemas.POSTGRES { if db.Type() == schemas.POSTGRES {
vals["filters::jsonb"] = vals["filters"].(string) + "::jsonb" vals["filters::jsonb"] = vals["filters"].(string) + "::jsonb"
delete(vals, "filters") delete(vals, "filters")
} }
@ -72,26 +77,34 @@ func TestSavedFilter_Create(t *testing.T) {
func TestSavedFilter_ReadOne(t *testing.T) { func TestSavedFilter_ReadOne(t *testing.T) {
user1 := &user.User{ID: 1} user1 := &user.User{ID: 1}
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
} }
// canRead pre-populates the struct // canRead pre-populates the struct
_, _, err := sf.CanRead(user1) _, _, err := sf.CanRead(s, user1)
assert.NoError(t, err) assert.NoError(t, err)
err = sf.ReadOne() err = sf.ReadOne(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, sf.Owner) assert.NotNil(t, sf.Owner)
} }
func TestSavedFilter_Update(t *testing.T) { func TestSavedFilter_Update(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "NewTitle", Title: "NewTitle",
Description: "", // Explicitly reset the description Description: "", // Explicitly reset the description
Filters: &TaskCollection{}, Filters: &TaskCollection{},
} }
err := sf.Update() err := sf.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "saved_filters", map[string]interface{}{ db.AssertExists(t, "saved_filters", map[string]interface{}{
"id": 1, "id": 1,
@ -102,10 +115,15 @@ func TestSavedFilter_Update(t *testing.T) {
func TestSavedFilter_Delete(t *testing.T) { func TestSavedFilter_Delete(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
} }
err := sf.Delete() err := sf.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertMissing(t, "saved_filters", map[string]interface{}{ db.AssertMissing(t, "saved_filters", map[string]interface{}{
"id": 1, "id": 1,
@ -120,50 +138,65 @@ func TestSavedFilter_Rights(t *testing.T) {
t.Run("create", func(t *testing.T) { t.Run("create", func(t *testing.T) {
// Should always be true // Should always be true
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
can, err := (&SavedFilter{}).CanCreate(user1) s := db.NewSession()
defer s.Close()
can, err := (&SavedFilter{}).CanCreate(s, user1)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("read", func(t *testing.T) { t.Run("read", func(t *testing.T) {
t.Run("owner", func(t *testing.T) { t.Run("owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "Lorem", Title: "Lorem",
} }
can, max, err := sf.CanRead(user1) can, max, err := sf.CanRead(s, user1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int(RightAdmin), max) assert.Equal(t, int(RightAdmin), max)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("not owner", func(t *testing.T) { t.Run("not owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "Lorem", Title: "Lorem",
} }
can, _, err := sf.CanRead(user2) can, _, err := sf.CanRead(s, user2)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 9999, ID: 9999,
Title: "Lorem", Title: "Lorem",
} }
can, _, err := sf.CanRead(user1) can, _, err := sf.CanRead(s, user1)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrSavedFilterDoesNotExist(err)) assert.True(t, IsErrSavedFilterDoesNotExist(err))
assert.False(t, can) assert.False(t, can)
}) })
t.Run("link share", func(t *testing.T) { t.Run("link share", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "Lorem", Title: "Lorem",
} }
can, _, err := sf.CanRead(ls) can, _, err := sf.CanRead(s, ls)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err)) assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err))
assert.False(t, can) assert.False(t, can)
@ -172,42 +205,54 @@ func TestSavedFilter_Rights(t *testing.T) {
t.Run("update", func(t *testing.T) { t.Run("update", func(t *testing.T) {
t.Run("owner", func(t *testing.T) { t.Run("owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "Lorem", Title: "Lorem",
} }
can, err := sf.CanUpdate(user1) can, err := sf.CanUpdate(s, user1)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("not owner", func(t *testing.T) { t.Run("not owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "Lorem", Title: "Lorem",
} }
can, err := sf.CanUpdate(user2) can, err := sf.CanUpdate(s, user2)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 9999, ID: 9999,
Title: "Lorem", Title: "Lorem",
} }
can, err := sf.CanUpdate(user1) can, err := sf.CanUpdate(s, user1)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrSavedFilterDoesNotExist(err)) assert.True(t, IsErrSavedFilterDoesNotExist(err))
assert.False(t, can) assert.False(t, can)
}) })
t.Run("link share", func(t *testing.T) { t.Run("link share", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "Lorem", Title: "Lorem",
} }
can, err := sf.CanUpdate(ls) can, err := sf.CanUpdate(s, ls)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err)) assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err))
assert.False(t, can) assert.False(t, can)
@ -216,40 +261,52 @@ func TestSavedFilter_Rights(t *testing.T) {
t.Run("delete", func(t *testing.T) { t.Run("delete", func(t *testing.T) {
t.Run("owner", func(t *testing.T) { t.Run("owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
} }
can, err := sf.CanDelete(user1) can, err := sf.CanDelete(s, user1)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("not owner", func(t *testing.T) { t.Run("not owner", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
} }
can, err := sf.CanDelete(user2) can, err := sf.CanDelete(s, user2)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 9999, ID: 9999,
Title: "Lorem", Title: "Lorem",
} }
can, err := sf.CanDelete(user1) can, err := sf.CanDelete(s, user1)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrSavedFilterDoesNotExist(err)) assert.True(t, IsErrSavedFilterDoesNotExist(err))
assert.False(t, can) assert.False(t, can)
}) })
t.Run("link share", func(t *testing.T) { t.Run("link share", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sf := &SavedFilter{ sf := &SavedFilter{
ID: 1, ID: 1,
Title: "Lorem", Title: "Lorem",
} }
can, err := sf.CanDelete(ls) can, err := sf.CanDelete(s, ls)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err)) assert.True(t, IsErrSavedFilterNotAvailableForLinkShare(err))
assert.False(t, can) assert.False(t, can)

View file

@ -46,9 +46,9 @@ type TaskAssigneeWithUser struct {
user.User `xorm:"extends"` user.User `xorm:"extends"`
} }
func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*TaskAssigneeWithUser, err error) { func getRawTaskAssigneesForTasks(s *xorm.Session, taskIDs []int64) (taskAssignees []*TaskAssigneeWithUser, err error) {
taskAssignees = []*TaskAssigneeWithUser{} taskAssignees = []*TaskAssigneeWithUser{}
err = x.Table("task_assignees"). err = s.Table("task_assignees").
Select("task_id, users.*"). Select("task_id, users.*").
In("task_id", taskIDs). In("task_id", taskIDs).
Join("INNER", "users", "task_assignees.user_id = users.id"). Join("INNER", "users", "task_assignees.user_id = users.id").
@ -60,7 +60,7 @@ func getRawTaskAssigneesForTasks(taskIDs []int64) (taskAssignees []*TaskAssignee
func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err error) { func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err error) {
// Load the current assignees // Load the current assignees
currentAssignees, err := getRawTaskAssigneesForTasks([]int64{t.ID}) currentAssignees, err := getRawTaskAssigneesForTasks(s, []int64{t.ID})
if err != nil { if err != nil {
return err return err
} }
@ -118,8 +118,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err
} }
// Get the list to perform later checks // Get the list to perform later checks
list := List{ID: t.ListID} list, err := GetListSimpleByID(s, t.ListID)
err = list.GetSimpleByID()
if err != nil { if err != nil {
return return
} }
@ -133,7 +132,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err
} }
// Add the new assignee // Add the new assignee
err = t.addNewAssigneeByID(u.ID, &list) err = t.addNewAssigneeByID(s, u.ID, list)
if err != nil { if err != nil {
return err return err
} }
@ -141,7 +140,7 @@ func (t *Task) updateTaskAssignees(s *xorm.Session, assignees []*user.User) (err
t.setTaskAssignees(assignees) t.setTaskAssignees(assignees)
err = updateListLastUpdated(&List{ID: t.ListID}) err = updateListLastUpdated(s, &List{ID: t.ListID})
return return
} }
@ -167,13 +166,13 @@ func (t *Task) setTaskAssignees(assignees []*user.User) {
// @Failure 403 {object} web.HTTPError "Not allowed to delete the assignee." // @Failure 403 {object} web.HTTPError "Not allowed to delete the assignee."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees/{userID} [delete] // @Router /tasks/{taskID}/assignees/{userID} [delete]
func (la *TaskAssginee) Delete() (err error) { func (la *TaskAssginee) Delete(s *xorm.Session) (err error) {
_, err = x.Delete(&TaskAssginee{TaskID: la.TaskID, UserID: la.UserID}) _, err = s.Delete(&TaskAssginee{TaskID: la.TaskID, UserID: la.UserID})
if err != nil { if err != nil {
return err return err
} }
err = updateListByTaskID(la.TaskID) err = updateListByTaskID(s, la.TaskID)
return return
} }
@ -190,25 +189,25 @@ func (la *TaskAssginee) Delete() (err error) {
// @Failure 400 {object} web.HTTPError "Invalid assignee object provided." // @Failure 400 {object} web.HTTPError "Invalid assignee object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees [put] // @Router /tasks/{taskID}/assignees [put]
func (la *TaskAssginee) Create(a web.Auth) (err error) { func (la *TaskAssginee) Create(s *xorm.Session, a web.Auth) (err error) {
// Get the list to perform later checks // Get the list to perform later checks
list, err := GetListSimplByTaskID(la.TaskID) list, err := GetListSimplByTaskID(s, la.TaskID)
if err != nil { if err != nil {
return return
} }
task := &Task{ID: la.TaskID} task := &Task{ID: la.TaskID}
return task.addNewAssigneeByID(la.UserID, list) return task.addNewAssigneeByID(s, la.UserID, list)
} }
func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) { func (t *Task) addNewAssigneeByID(s *xorm.Session, newAssigneeID int64, list *List) (err error) {
// Check if the user exists and has access to the list // Check if the user exists and has access to the list
newAssignee, err := user.GetUserByID(newAssigneeID) newAssignee, err := user.GetUserByID(s, newAssigneeID)
if err != nil { if err != nil {
return err return err
} }
canRead, _, err := list.CanRead(newAssignee) canRead, _, err := list.CanRead(s, newAssignee)
if err != nil { if err != nil {
return err return err
} }
@ -216,7 +215,7 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) {
return ErrUserDoesNotHaveAccessToList{list.ID, newAssigneeID} return ErrUserDoesNotHaveAccessToList{list.ID, newAssigneeID}
} }
_, err = x.Insert(TaskAssginee{ _, err = s.Insert(TaskAssginee{
TaskID: t.ID, TaskID: t.ID,
UserID: newAssigneeID, UserID: newAssigneeID,
}) })
@ -224,7 +223,7 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) {
return err return err
} }
err = updateListLastUpdated(&List{ID: t.ListID}) err = updateListLastUpdated(s, &List{ID: t.ListID})
return return
} }
@ -242,13 +241,13 @@ func (t *Task) addNewAssigneeByID(newAssigneeID int64, list *List) (err error) {
// @Success 200 {array} user.User "The assignees" // @Success 200 {array} user.User "The assignees"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees [get] // @Router /tasks/{taskID}/assignees [get]
func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (la *TaskAssginee) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
task, err := GetListSimplByTaskID(la.TaskID) task, err := GetListSimplByTaskID(s, la.TaskID)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
can, _, err := task.CanRead(a) can, _, err := task.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -258,7 +257,7 @@ func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
var taskAssignees []*user.User var taskAssignees []*user.User
query := x.Table("task_assignees"). query := s.Table("task_assignees").
Select("users.*"). Select("users.*").
Join("INNER", "users", "task_assignees.user_id = users.id"). Join("INNER", "users", "task_assignees.user_id = users.id").
Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%") Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%")
@ -270,7 +269,7 @@ func (la *TaskAssginee) ReadAll(a web.Auth, search string, page int, perPage int
return nil, 0, 0, err return nil, 0, 0, err
} }
numberOfTotalItems, err = x.Table("task_assignees"). numberOfTotalItems, err = s.Table("task_assignees").
Select("users.*"). Select("users.*").
Join("INNER", "users", "task_assignees.user_id = users.id"). Join("INNER", "users", "task_assignees.user_id = users.id").
Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%"). Where("task_id = ? AND users.username LIKE ?", la.TaskID, "%"+search+"%").
@ -301,14 +300,12 @@ type BulkAssignees struct {
// @Failure 400 {object} web.HTTPError "Invalid assignee object provided." // @Failure 400 {object} web.HTTPError "Invalid assignee object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/assignees/bulk [post] // @Router /tasks/{taskID}/assignees/bulk [post]
func (ba *BulkAssignees) Create(a web.Auth) (err error) { func (ba *BulkAssignees) Create(s *xorm.Session, a web.Auth) (err error) {
s := x.NewSession() task, err := GetTaskByIDSimple(s, ba.TaskID)
task, err := GetTaskByIDSimple(ba.TaskID)
if err != nil { if err != nil {
return return
} }
assignees, err := getRawTaskAssigneesForTasks([]int64{task.ID}) assignees, err := getRawTaskAssigneesForTasks(s, []int64{task.ID})
if err != nil { if err != nil {
return err return err
} }
@ -317,10 +314,5 @@ func (ba *BulkAssignees) Create(a web.Auth) (err error) {
} }
err = task.updateTaskAssignees(s, ba.Assignees) err = task.updateTaskAssignees(s, ba.Assignees)
if err != nil { return
_ = s.Rollback()
return err
}
return s.Commit()
} }

View file

@ -18,28 +18,29 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if a user can add a new assignee // CanCreate checks if a user can add a new assignee
func (la *TaskAssginee) CanCreate(a web.Auth) (bool, error) { func (la *TaskAssginee) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return canDoTaskAssingee(la.TaskID, a) return canDoTaskAssingee(s, la.TaskID, a)
} }
// CanCreate checks if a user can add a new assignee // CanCreate checks if a user can add a new assignee
func (ba *BulkAssignees) CanCreate(a web.Auth) (bool, error) { func (ba *BulkAssignees) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return canDoTaskAssingee(ba.TaskID, a) return canDoTaskAssingee(s, ba.TaskID, a)
} }
// CanDelete checks if a user can delete an assignee // CanDelete checks if a user can delete an assignee
func (la *TaskAssginee) CanDelete(a web.Auth) (bool, error) { func (la *TaskAssginee) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return canDoTaskAssingee(la.TaskID, a) return canDoTaskAssingee(s, la.TaskID, a)
} }
func canDoTaskAssingee(taskID int64, a web.Auth) (bool, error) { func canDoTaskAssingee(s *xorm.Session, taskID int64, a web.Auth) (bool, error) {
// Check if the current user can edit the list // Check if the current user can edit the list
list, err := GetListSimplByTaskID(taskID) list, err := GetListSimplByTaskID(s, taskID)
if err != nil { if err != nil {
return false, err return false, err
} }
return list.CanUpdate(a) return list.CanUpdate(s, a)
} }

View file

@ -23,6 +23,7 @@ import (
"code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// TaskAttachment is the definition of a task attachment // TaskAttachment is the definition of a task attachment
@ -49,7 +50,7 @@ func (TaskAttachment) TableName() string {
// NewAttachment creates a new task attachment // NewAttachment creates a new task attachment
// Note: I'm not sure if only accepting an io.ReadCloser and not an afero.File or os.File instead is a good way of doing things. // Note: I'm not sure if only accepting an io.ReadCloser and not an afero.File or os.File instead is a good way of doing things.
func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsize uint64, a web.Auth) error { func (ta *TaskAttachment) NewAttachment(s *xorm.Session, f io.ReadCloser, realname string, realsize uint64, a web.Auth) error {
// Store the file // Store the file
file, err := files.Create(f, realname, realsize, a) file, err := files.Create(f, realname, realsize, a)
@ -64,7 +65,7 @@ func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsi
// Add an entry to the db // Add an entry to the db
ta.FileID = file.ID ta.FileID = file.ID
ta.CreatedByID = a.GetID() ta.CreatedByID = a.GetID()
_, err = x.Insert(ta) _, err = s.Insert(ta)
if err != nil { if err != nil {
// remove the uploaded file if adding it to the db fails // remove the uploaded file if adding it to the db fails
if err2 := file.Delete(); err2 != nil { if err2 := file.Delete(); err2 != nil {
@ -77,8 +78,8 @@ func (ta *TaskAttachment) NewAttachment(f io.ReadCloser, realname string, realsi
} }
// ReadOne returns a task attachment // ReadOne returns a task attachment
func (ta *TaskAttachment) ReadOne() (err error) { func (ta *TaskAttachment) ReadOne(s *xorm.Session) (err error) {
exists, err := x.Where("id = ?", ta.ID).Get(ta) exists, err := s.Where("id = ?", ta.ID).Get(ta)
if err != nil { if err != nil {
return return
} }
@ -110,12 +111,12 @@ func (ta *TaskAttachment) ReadOne() (err error) {
// @Failure 404 {object} models.Message "The task does not exist." // @Failure 404 {object} models.Message "The task does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id}/attachments [get] // @Router /tasks/{id}/attachments [get]
func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (ta *TaskAttachment) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
attachments := []*TaskAttachment{} attachments := []*TaskAttachment{}
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
query := x. query := s.
Where("task_id = ?", ta.TaskID) Where("task_id = ?", ta.TaskID)
if limit > 0 { if limit > 0 {
query = query.Limit(limit, start) query = query.Limit(limit, start)
@ -133,13 +134,13 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i
} }
fs := make(map[int64]*files.File) fs := make(map[int64]*files.File)
err = x.In("id", fileIDs).Find(&fs) err = s.In("id", fileIDs).Find(&fs)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
us := make(map[int64]*user.User) us := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&us) err = s.In("id", userIDs).Find(&us)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -153,7 +154,7 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i
r.CreatedBy = us[r.CreatedByID] r.CreatedBy = us[r.CreatedByID]
} }
numberOfTotalItems, err = x. numberOfTotalItems, err = s.
Where("task_id = ?", ta.TaskID). Where("task_id = ?", ta.TaskID).
Count(&TaskAttachment{}) Count(&TaskAttachment{})
return attachments, len(attachments), numberOfTotalItems, err return attachments, len(attachments), numberOfTotalItems, err
@ -173,15 +174,17 @@ func (ta *TaskAttachment) ReadAll(a web.Auth, search string, page int, perPage i
// @Failure 404 {object} models.Message "The task does not exist." // @Failure 404 {object} models.Message "The task does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id}/attachments/{attachmentID} [delete] // @Router /tasks/{id}/attachments/{attachmentID} [delete]
func (ta *TaskAttachment) Delete() error { func (ta *TaskAttachment) Delete(s *xorm.Session) error {
// Load the attachment // Load the attachment
err := ta.ReadOne() err := ta.ReadOne(s)
if err != nil && !files.IsErrFileDoesNotExist(err) { if err != nil && !files.IsErrFileDoesNotExist(err) {
return err return err
} }
// Delete it // Delete it
_, err = x.Where("task_id = ? AND id = ?", ta.TaskID, ta.ID).Delete(ta) _, err = s.
Where("task_id = ? AND id = ?", ta.TaskID, ta.ID).
Delete(ta)
if err != nil { if err != nil {
return err return err
} }
@ -195,9 +198,9 @@ func (ta *TaskAttachment) Delete() error {
return err return err
} }
func getTaskAttachmentsByTaskIDs(taskIDs []int64) (attachments []*TaskAttachment, err error) { func getTaskAttachmentsByTaskIDs(s *xorm.Session, taskIDs []int64) (attachments []*TaskAttachment, err error) {
attachments = []*TaskAttachment{} attachments = []*TaskAttachment{}
err = x. err = s.
In("task_id", taskIDs). In("task_id", taskIDs).
Find(&attachments) Find(&attachments)
if err != nil { if err != nil {
@ -213,13 +216,13 @@ func getTaskAttachmentsByTaskIDs(taskIDs []int64) (attachments []*TaskAttachment
// Get all files // Get all files
fs := make(map[int64]*files.File) fs := make(map[int64]*files.File)
err = x.In("id", fileIDs).Find(&fs) err = s.In("id", fileIDs).Find(&fs)
if err != nil { if err != nil {
return return
} }
users := make(map[int64]*user.User) users := make(map[int64]*user.User)
err = x.In("id", userIDs).Find(&users) err = s.In("id", userIDs).Find(&users)
if err != nil { if err != nil {
return return
} }

View file

@ -16,25 +16,28 @@
package models package models
import "code.vikunja.io/web" import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead checks if the user can see an attachment // CanRead checks if the user can see an attachment
func (ta *TaskAttachment) CanRead(a web.Auth) (bool, int, error) { func (ta *TaskAttachment) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
t := &Task{ID: ta.TaskID} t := &Task{ID: ta.TaskID}
return t.CanRead(a) return t.CanRead(s, a)
} }
// CanDelete checks if the user can delete an attachment // CanDelete checks if the user can delete an attachment
func (ta *TaskAttachment) CanDelete(a web.Auth) (bool, error) { func (ta *TaskAttachment) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
t := &Task{ID: ta.TaskID} t := &Task{ID: ta.TaskID}
return t.CanWrite(a) return t.CanWrite(s, a)
} }
// CanCreate checks if the user can create an attachment // CanCreate checks if the user can create an attachment
func (ta *TaskAttachment) CanCreate(a web.Auth) (bool, error) { func (ta *TaskAttachment) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
t, err := GetTaskByIDSimple(ta.TaskID) t, err := GetTaskByIDSimple(s, ta.TaskID)
if err != nil { if err != nil {
return false, err return false, err
} }
return t.CanCreate(a) return t.CanCreate(s, a)
} }

View file

@ -33,11 +33,14 @@ import (
func TestTaskAttachment_ReadOne(t *testing.T) { func TestTaskAttachment_ReadOne(t *testing.T) {
t.Run("Normal File", func(t *testing.T) { t.Run("Normal File", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
ta := &TaskAttachment{ ta := &TaskAttachment{
ID: 1, ID: 1,
} }
err := ta.ReadOne() err := ta.ReadOne(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, ta.File) assert.NotNil(t, ta.File)
assert.True(t, ta.File.ID == ta.FileID && ta.FileID != 0) assert.True(t, ta.File.ID == ta.FileID && ta.FileID != 0)
@ -54,21 +57,27 @@ func TestTaskAttachment_ReadOne(t *testing.T) {
}) })
t.Run("Nonexisting Attachment", func(t *testing.T) { t.Run("Nonexisting Attachment", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
ta := &TaskAttachment{ ta := &TaskAttachment{
ID: 9999, ID: 9999,
} }
err := ta.ReadOne() err := ta.ReadOne(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskAttachmentDoesNotExist(err)) assert.True(t, IsErrTaskAttachmentDoesNotExist(err))
}) })
t.Run("Existing Attachment, Nonexisting File", func(t *testing.T) { t.Run("Existing Attachment, Nonexisting File", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
ta := &TaskAttachment{ ta := &TaskAttachment{
ID: 2, ID: 2,
} }
err := ta.ReadOne() err := ta.ReadOne(s)
assert.Error(t, err) assert.Error(t, err)
assert.EqualError(t, err, "file 9999 does not exist") assert.EqualError(t, err, "file 9999 does not exist")
}) })
@ -94,6 +103,9 @@ func (t *testfile) Close() error {
func TestTaskAttachment_NewAttachment(t *testing.T) { func TestTaskAttachment_NewAttachment(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
// Assert the file is being stored correctly // Assert the file is being stored correctly
ta := TaskAttachment{ ta := TaskAttachment{
@ -104,7 +116,7 @@ func TestTaskAttachment_NewAttachment(t *testing.T) {
} }
testuser := &user.User{ID: 1} testuser := &user.User{ID: 1}
err := ta.NewAttachment(tf, "testfile", 100, testuser) err := ta.NewAttachment(s, tf, "testfile", 100, testuser)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, 0, ta.FileID) assert.NotEqual(t, 0, ta.FileID)
_, err = files.FileStat("files/" + strconv.FormatInt(ta.FileID, 10)) _, err = files.FileStat("files/" + strconv.FormatInt(ta.FileID, 10))
@ -125,9 +137,12 @@ func TestTaskAttachment_NewAttachment(t *testing.T) {
func TestTaskAttachment_ReadAll(t *testing.T) { func TestTaskAttachment_ReadAll(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
ta := &TaskAttachment{TaskID: 1} ta := &TaskAttachment{TaskID: 1}
as, _, _, err := ta.ReadAll(&user.User{ID: 1}, "", 0, 50) as, _, _, err := ta.ReadAll(s, &user.User{ID: 1}, "", 0, 50)
attachments, _ := as.([]*TaskAttachment) attachments, _ := as.([]*TaskAttachment)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, attachments, 2) assert.Len(t, attachments, 2)
@ -136,10 +151,13 @@ func TestTaskAttachment_ReadAll(t *testing.T) {
func TestTaskAttachment_Delete(t *testing.T) { func TestTaskAttachment_Delete(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
t.Run("Normal", func(t *testing.T) { t.Run("Normal", func(t *testing.T) {
ta := &TaskAttachment{ID: 1} ta := &TaskAttachment{ID: 1}
err := ta.Delete() err := ta.Delete(s)
assert.NoError(t, err) assert.NoError(t, err)
// Check if the file itself was deleted // Check if the file itself was deleted
_, err = files.FileStat("/1") // The new file has the id 2 since it's the second attachment _, err = files.FileStat("/1") // The new file has the id 2 since it's the second attachment
@ -148,14 +166,14 @@ func TestTaskAttachment_Delete(t *testing.T) {
t.Run("Nonexisting", func(t *testing.T) { t.Run("Nonexisting", func(t *testing.T) {
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
ta := &TaskAttachment{ID: 9999} ta := &TaskAttachment{ID: 9999}
err := ta.Delete() err := ta.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskAttachmentDoesNotExist(err)) assert.True(t, IsErrTaskAttachmentDoesNotExist(err))
}) })
t.Run("Existing attachment, nonexisting file", func(t *testing.T) { t.Run("Existing attachment, nonexisting file", func(t *testing.T) {
files.InitTestFileFixtures(t) files.InitTestFileFixtures(t)
ta := &TaskAttachment{ID: 2} ta := &TaskAttachment{ID: 2}
err := ta.Delete() err := ta.Delete(s)
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
@ -165,15 +183,21 @@ func TestTaskAttachment_Rights(t *testing.T) {
t.Run("Can Read", func(t *testing.T) { t.Run("Can Read", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) { t.Run("Allowed", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 1} ta := &TaskAttachment{TaskID: 1}
can, _, err := ta.CanRead(u) can, _, err := ta.CanRead(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("Forbidden", func(t *testing.T) { t.Run("Forbidden", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 14} ta := &TaskAttachment{TaskID: 14}
can, _, err := ta.CanRead(u) can, _, err := ta.CanRead(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
@ -181,22 +205,31 @@ func TestTaskAttachment_Rights(t *testing.T) {
t.Run("Can Delete", func(t *testing.T) { t.Run("Can Delete", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) { t.Run("Allowed", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 1} ta := &TaskAttachment{TaskID: 1}
can, err := ta.CanDelete(u) can, err := ta.CanDelete(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("Forbidden, no access", func(t *testing.T) { t.Run("Forbidden, no access", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 14} ta := &TaskAttachment{TaskID: 14}
can, err := ta.CanDelete(u) can, err := ta.CanDelete(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("Forbidden, shared read only", func(t *testing.T) { t.Run("Forbidden, shared read only", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 15} ta := &TaskAttachment{TaskID: 15}
can, err := ta.CanDelete(u) can, err := ta.CanDelete(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
@ -204,22 +237,31 @@ func TestTaskAttachment_Rights(t *testing.T) {
t.Run("Can Create", func(t *testing.T) { t.Run("Can Create", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) { t.Run("Allowed", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 1} ta := &TaskAttachment{TaskID: 1}
can, err := ta.CanCreate(u) can, err := ta.CanCreate(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("Forbidden, no access", func(t *testing.T) { t.Run("Forbidden, no access", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 14} ta := &TaskAttachment{TaskID: 14}
can, err := ta.CanCreate(u) can, err := ta.CanCreate(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("Forbidden, shared read only", func(t *testing.T) { t.Run("Forbidden, shared read only", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
ta := &TaskAttachment{TaskID: 15} ta := &TaskAttachment{TaskID: 15}
can, err := ta.CanCreate(u) can, err := ta.CanCreate(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })

View file

@ -20,6 +20,7 @@ package models
import ( import (
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// TaskCollection is a struct used to hold filter details and not clutter the Task struct with information not related to actual tasks. // TaskCollection is a struct used to hold filter details and not clutter the Task struct with information not related to actual tasks.
@ -100,17 +101,17 @@ func validateTaskField(fieldName string) error {
// @Success 200 {array} models.Task "The tasks" // @Success 200 {array} models.Task "The tasks"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{listID}/tasks [get] // @Router /lists/{listID}/tasks [get]
func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { func (tf *TaskCollection) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
// If the list id is < -1 this means we're dealing with a saved filter - in that case we get and populate the filter // If the list id is < -1 this means we're dealing with a saved filter - in that case we get and populate the filter
// -1 is the favorites list which works as intended // -1 is the favorites list which works as intended
if tf.ListID < -1 { if tf.ListID < -1 {
s, err := getSavedFilterSimpleByID(getSavedFilterIDFromListID(tf.ListID)) sf, err := getSavedFilterSimpleByID(s, getSavedFilterIDFromListID(tf.ListID))
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
return s.getTaskCollection().ReadAll(a, search, page, perPage) return sf.getTaskCollection().ReadAll(s, a, search, page, perPage)
} }
if len(tf.SortByArr) > 0 { if len(tf.SortByArr) > 0 {
@ -156,28 +157,30 @@ func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage i
shareAuth, is := a.(*LinkSharing) shareAuth, is := a.(*LinkSharing)
if is { if is {
list := &List{ID: shareAuth.ListID} list, err := GetListSimpleByID(s, shareAuth.ListID)
err := list.GetSimpleByID()
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
return getTasksForLists([]*List{list}, a, taskopts) return getTasksForLists(s, []*List{list}, a, taskopts)
} }
// If the list ID is not set, we get all tasks for the user. // If the list ID is not set, we get all tasks for the user.
// This allows to use this function in Task.ReadAll with a possibility to deprecate the latter at some point. // This allows to use this function in Task.ReadAll with a possibility to deprecate the latter at some point.
if tf.ListID == 0 { if tf.ListID == 0 {
tf.Lists, _, _, err = getRawListsForUser(&listOptions{ tf.Lists, _, _, err = getRawListsForUser(
s,
&listOptions{
user: &user.User{ID: a.GetID()}, user: &user.User{ID: a.GetID()},
page: -1, page: -1,
}) },
)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
} else { } else {
// Check the list exists and the user has acess on it // Check the list exists and the user has acess on it
list := &List{ID: tf.ListID} list := &List{ID: tf.ListID}
canRead, _, err := list.CanRead(a) canRead, _, err := list.CanRead(s, a)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -187,5 +190,5 @@ func (tf *TaskCollection) ReadAll(a web.Auth, search string, page int, perPage i
tf.Lists = []*List{{ID: tf.ListID}} tf.Lists = []*List{{ID: tf.ListID}}
} }
return getTasksForLists(tf.Lists, a, taskopts) return getTasksForLists(s, tf.Lists, a, taskopts)
} }

View file

@ -986,6 +986,8 @@ func TestTaskCollection_ReadAll(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
lt := &TaskCollection{ lt := &TaskCollection{
ListID: tt.fields.ListID, ListID: tt.fields.ListID,
@ -1000,7 +1002,7 @@ func TestTaskCollection_ReadAll(t *testing.T) {
CRUDable: tt.fields.CRUDable, CRUDable: tt.fields.CRUDable,
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
got, _, _, err := lt.ReadAll(tt.args.a, tt.args.search, tt.args.page, 50) got, _, _, err := lt.ReadAll(s, tt.args.a, tt.args.search, tt.args.page, 50)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Test %s, Task.ReadAll() error = %v, wantErr %v", tt.name, err, tt.wantErr) t.Errorf("Test %s, Task.ReadAll() error = %v, wantErr %v", tt.name, err, tt.wantErr)
return return

View file

@ -17,28 +17,31 @@
package models package models
import "code.vikunja.io/web" import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanRead checks if a user can read a comment // CanRead checks if a user can read a comment
func (tc *TaskComment) CanRead(a web.Auth) (bool, int, error) { func (tc *TaskComment) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
t := Task{ID: tc.TaskID} t := Task{ID: tc.TaskID}
return t.CanRead(a) return t.CanRead(s, a)
} }
// CanDelete checks if a user can delete a comment // CanDelete checks if a user can delete a comment
func (tc *TaskComment) CanDelete(a web.Auth) (bool, error) { func (tc *TaskComment) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
t := Task{ID: tc.TaskID} t := Task{ID: tc.TaskID}
return t.CanWrite(a) return t.CanWrite(s, a)
} }
// CanUpdate checks if a user can update a comment // CanUpdate checks if a user can update a comment
func (tc *TaskComment) CanUpdate(a web.Auth) (bool, error) { func (tc *TaskComment) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
t := Task{ID: tc.TaskID} t := Task{ID: tc.TaskID}
return t.CanWrite(a) return t.CanWrite(s, a)
} }
// CanCreate checks if a user can create a new comment // CanCreate checks if a user can create a new comment
func (tc *TaskComment) CanCreate(a web.Auth) (bool, error) { func (tc *TaskComment) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
t := Task{ID: tc.TaskID} t := Task{ID: tc.TaskID}
return t.CanWrite(a) return t.CanWrite(s, a)
} }

View file

@ -20,6 +20,8 @@ package models
import ( import (
"time" "time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
) )
@ -57,19 +59,19 @@ func (tc *TaskComment) TableName() string {
// @Failure 400 {object} web.HTTPError "Invalid task comment object provided." // @Failure 400 {object} web.HTTPError "Invalid task comment object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments [put] // @Router /tasks/{taskID}/comments [put]
func (tc *TaskComment) Create(a web.Auth) (err error) { func (tc *TaskComment) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the task exists // Check if the task exists
_, err = GetTaskSimple(&Task{ID: tc.TaskID}) _, err = GetTaskSimple(s, &Task{ID: tc.TaskID})
if err != nil { if err != nil {
return err return err
} }
tc.AuthorID = a.GetID() tc.AuthorID = a.GetID()
_, err = x.Insert(tc) _, err = s.Insert(tc)
if err != nil { if err != nil {
return return
} }
tc.Author, err = user.GetUserByID(a.GetID()) tc.Author, err = user.GetUserByID(s, a.GetID())
return return
} }
@ -87,8 +89,11 @@ func (tc *TaskComment) Create(a web.Auth) (err error) {
// @Failure 404 {object} web.HTTPError "The task comment was not found." // @Failure 404 {object} web.HTTPError "The task comment was not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments/{commentID} [delete] // @Router /tasks/{taskID}/comments/{commentID} [delete]
func (tc *TaskComment) Delete() error { func (tc *TaskComment) Delete(s *xorm.Session) error {
deleted, err := x.ID(tc.ID).NoAutoCondition().Delete(tc) deleted, err := s.
ID(tc.ID).
NoAutoCondition().
Delete(tc)
if deleted == 0 { if deleted == 0 {
return ErrTaskCommentDoesNotExist{ID: tc.ID} return ErrTaskCommentDoesNotExist{ID: tc.ID}
} }
@ -109,8 +114,11 @@ func (tc *TaskComment) Delete() error {
// @Failure 404 {object} web.HTTPError "The task comment was not found." // @Failure 404 {object} web.HTTPError "The task comment was not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments/{commentID} [post] // @Router /tasks/{taskID}/comments/{commentID} [post]
func (tc *TaskComment) Update() error { func (tc *TaskComment) Update(s *xorm.Session) error {
updated, err := x.ID(tc.ID).Cols("comment").Update(tc) updated, err := s.
ID(tc.ID).
Cols("comment").
Update(tc)
if updated == 0 { if updated == 0 {
return ErrTaskCommentDoesNotExist{ID: tc.ID} return ErrTaskCommentDoesNotExist{ID: tc.ID}
} }
@ -131,8 +139,8 @@ func (tc *TaskComment) Update() error {
// @Failure 404 {object} web.HTTPError "The task comment was not found." // @Failure 404 {object} web.HTTPError "The task comment was not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments/{commentID} [get] // @Router /tasks/{taskID}/comments/{commentID} [get]
func (tc *TaskComment) ReadOne() (err error) { func (tc *TaskComment) ReadOne(s *xorm.Session) (err error) {
exists, err := x.Get(tc) exists, err := s.Get(tc)
if err != nil { if err != nil {
return return
} }
@ -145,7 +153,7 @@ func (tc *TaskComment) ReadOne() (err error) {
// Get the author // Get the author
author := &user.User{} author := &user.User{}
_, err = x. _, err = s.
Where("id = ?", tc.AuthorID). Where("id = ?", tc.AuthorID).
Get(author) Get(author)
tc.Author = author tc.Author = author
@ -163,10 +171,10 @@ func (tc *TaskComment) ReadOne() (err error) {
// @Success 200 {array} models.TaskComment "The array with all task comments" // @Success 200 {array} models.TaskComment "The array with all task comments"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/comments [get] // @Router /tasks/{taskID}/comments [get]
func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (tc *TaskComment) ReadAll(s *xorm.Session, auth web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
// Check if the user has access to the task // Check if the user has access to the task
canRead, _, err := tc.CanRead(auth) canRead, _, err := tc.CanRead(s, auth)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -184,7 +192,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
comments := []*TaskComment{} comments := []*TaskComment{}
query := x. query := s.
Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%"). Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%").
Join("LEFT", "users", "users.id = task_comments.author_id") Join("LEFT", "users", "users.id = task_comments.author_id")
if limit > 0 { if limit > 0 {
@ -197,7 +205,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i
// Get all authors // Get all authors
authors := make(map[int64]*user.User) authors := make(map[int64]*user.User)
err = x. err = s.
Select("users.*"). Select("users.*").
Table("task_comments"). Table("task_comments").
Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%"). Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%").
@ -211,7 +219,7 @@ func (tc *TaskComment) ReadAll(auth web.Auth, search string, page int, perPage i
comment.Author = authors[comment.AuthorID] comment.Author = authors[comment.AuthorID]
} }
numberOfTotalItems, err = x. numberOfTotalItems, err = s.
Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%"). Where("task_id = ? AND comment like ?", tc.TaskID, "%"+search+"%").
Count(&TaskCommentWithAuthor{}) Count(&TaskCommentWithAuthor{})
return comments, len(comments), numberOfTotalItems, err return comments, len(comments), numberOfTotalItems, err

View file

@ -28,14 +28,20 @@ func TestTaskComment_Create(t *testing.T) {
u := &user.User{ID: 1} u := &user.User{ID: 1}
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ tc := &TaskComment{
Comment: "test", Comment: "test",
TaskID: 1, TaskID: 1,
} }
err := tc.Create(u) err := tc.Create(s, u)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "test", tc.Comment) assert.Equal(t, "test", tc.Comment)
assert.Equal(t, int64(1), tc.Author.ID) assert.Equal(t, int64(1), tc.Author.ID)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "task_comments", map[string]interface{}{ db.AssertExists(t, "task_comments", map[string]interface{}{
"id": tc.ID, "id": tc.ID,
"author_id": u.ID, "author_id": u.ID,
@ -45,11 +51,14 @@ func TestTaskComment_Create(t *testing.T) {
}) })
t.Run("nonexisting task", func(t *testing.T) { t.Run("nonexisting task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ tc := &TaskComment{
Comment: "test", Comment: "test",
TaskID: 99999, TaskID: 99999,
} }
err := tc.Create(u) err := tc.Create(s, u)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err)) assert.True(t, IsErrTaskDoesNotExist(err))
}) })
@ -58,17 +67,26 @@ func TestTaskComment_Create(t *testing.T) {
func TestTaskComment_Delete(t *testing.T) { func TestTaskComment_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 1} tc := &TaskComment{ID: 1}
err := tc.Delete() err := tc.Delete(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "task_comments", map[string]interface{}{ db.AssertMissing(t, "task_comments", map[string]interface{}{
"id": 1, "id": 1,
}) })
}) })
t.Run("nonexisting comment", func(t *testing.T) { t.Run("nonexisting comment", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 9999} tc := &TaskComment{ID: 9999}
err := tc.Delete() err := tc.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskCommentDoesNotExist(err)) assert.True(t, IsErrTaskCommentDoesNotExist(err))
}) })
@ -77,12 +95,18 @@ func TestTaskComment_Delete(t *testing.T) {
func TestTaskComment_Update(t *testing.T) { func TestTaskComment_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ tc := &TaskComment{
ID: 1, ID: 1,
Comment: "testing", Comment: "testing",
} }
err := tc.Update() err := tc.Update(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "task_comments", map[string]interface{}{ db.AssertExists(t, "task_comments", map[string]interface{}{
"id": 1, "id": 1,
"comment": "testing", "comment": "testing",
@ -90,10 +114,13 @@ func TestTaskComment_Update(t *testing.T) {
}) })
t.Run("nonexisting comment", func(t *testing.T) { t.Run("nonexisting comment", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ tc := &TaskComment{
ID: 9999, ID: 9999,
} }
err := tc.Update() err := tc.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskCommentDoesNotExist(err)) assert.True(t, IsErrTaskCommentDoesNotExist(err))
}) })
@ -102,16 +129,22 @@ func TestTaskComment_Update(t *testing.T) {
func TestTaskComment_ReadOne(t *testing.T) { func TestTaskComment_ReadOne(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 1} tc := &TaskComment{ID: 1}
err := tc.ReadOne() err := tc.ReadOne(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "Lorem Ipsum Dolor Sit Amet", tc.Comment) assert.Equal(t, "Lorem Ipsum Dolor Sit Amet", tc.Comment)
assert.NotEmpty(t, tc.Author.ID) assert.NotEmpty(t, tc.Author.ID)
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{ID: 9999} tc := &TaskComment{ID: 9999}
err := tc.ReadOne() err := tc.ReadOne(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskCommentDoesNotExist(err)) assert.True(t, IsErrTaskCommentDoesNotExist(err))
}) })
@ -120,9 +153,12 @@ func TestTaskComment_ReadOne(t *testing.T) {
func TestTaskComment_ReadAll(t *testing.T) { func TestTaskComment_ReadAll(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{TaskID: 1} tc := &TaskComment{TaskID: 1}
u := &user.User{ID: 1} u := &user.User{ID: 1}
result, resultCount, total, err := tc.ReadAll(u, "", 0, -1) result, resultCount, total, err := tc.ReadAll(s, u, "", 0, -1)
resultComment := result.([]*TaskComment) resultComment := result.([]*TaskComment)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, resultCount) assert.Equal(t, 1, resultCount)
@ -133,9 +169,12 @@ func TestTaskComment_ReadAll(t *testing.T) {
}) })
t.Run("no access to task", func(t *testing.T) { t.Run("no access to task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tc := &TaskComment{TaskID: 14} tc := &TaskComment{TaskID: 14}
u := &user.User{ID: 1} u := &user.User{ID: 1}
_, _, _, err := tc.ReadAll(u, "", 0, -1) _, _, _, err := tc.ReadAll(s, u, "", 0, -1)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrGenericForbidden(err)) assert.True(t, IsErrGenericForbidden(err))
}) })

View file

@ -20,6 +20,8 @@ package models
import ( import (
"time" "time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
) )
@ -117,7 +119,7 @@ type RelatedTaskMap map[RelationKind][]*Task
// @Failure 400 {object} web.HTTPError "Invalid task relation object provided." // @Failure 400 {object} web.HTTPError "Invalid task relation object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/relations [put] // @Router /tasks/{taskID}/relations [put]
func (rel *TaskRelation) Create(a web.Auth) error { func (rel *TaskRelation) Create(s *xorm.Session, a web.Auth) error {
// Check if both tasks are the same // Check if both tasks are the same
if rel.TaskID == rel.OtherTaskID { if rel.TaskID == rel.OtherTaskID {
@ -128,7 +130,7 @@ func (rel *TaskRelation) Create(a web.Auth) error {
} }
// Check if the relation already exists, in one form or the other. // Check if the relation already exists, in one form or the other.
exists, err := x. exists, err := s.
Where("(task_id = ? AND other_task_id = ? AND relation_kind = ?) OR (task_id = ? AND other_task_id = ? AND relation_kind = ?)", Where("(task_id = ? AND other_task_id = ? AND relation_kind = ?) OR (task_id = ? AND other_task_id = ? AND relation_kind = ?)",
rel.TaskID, rel.OtherTaskID, rel.RelationKind, rel.TaskID, rel.OtherTaskID, rel.RelationKind). rel.TaskID, rel.OtherTaskID, rel.RelationKind, rel.TaskID, rel.OtherTaskID, rel.RelationKind).
Exist(rel) Exist(rel)
@ -180,7 +182,7 @@ func (rel *TaskRelation) Create(a web.Auth) error {
} }
// Finally insert everything // Finally insert everything
_, err = x.Insert(&[]*TaskRelation{ _, err = s.Insert(&[]*TaskRelation{
rel, rel,
otherRelation, otherRelation,
}) })
@ -200,9 +202,9 @@ func (rel *TaskRelation) Create(a web.Auth) error {
// @Failure 404 {object} web.HTTPError "The task relation was not found." // @Failure 404 {object} web.HTTPError "The task relation was not found."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{taskID}/relations [delete] // @Router /tasks/{taskID}/relations [delete]
func (rel *TaskRelation) Delete() error { func (rel *TaskRelation) Delete(s *xorm.Session) error {
// Check if the relation exists // Check if the relation exists
exists, err := x. exists, err := s.
Cols("task_id", "other_task_id", "relation_kind"). Cols("task_id", "other_task_id", "relation_kind").
Get(rel) Get(rel)
if err != nil { if err != nil {
@ -216,6 +218,6 @@ func (rel *TaskRelation) Delete() error {
} }
} }
_, err = x.Delete(rel) _, err = s.Delete(rel)
return err return err
} }

View file

@ -17,17 +17,20 @@
package models package models
import "code.vikunja.io/web" import (
"code.vikunja.io/web"
"xorm.io/xorm"
)
// CanDelete checks if a user can delete a task relation // CanDelete checks if a user can delete a task relation
func (rel *TaskRelation) CanDelete(a web.Auth) (bool, error) { func (rel *TaskRelation) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
// A user can delete a relation if it can update the base task // A user can delete a relation if it can update the base task
baseTask := &Task{ID: rel.TaskID} baseTask := &Task{ID: rel.TaskID}
return baseTask.CanUpdate(a) return baseTask.CanUpdate(s, a)
} }
// CanCreate checks if a user can create a new relation between two relations // CanCreate checks if a user can create a new relation between two relations
func (rel *TaskRelation) CanCreate(a web.Auth) (bool, error) { func (rel *TaskRelation) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
// Check if the relation kind is valid // Check if the relation kind is valid
if !rel.RelationKind.isValid() { if !rel.RelationKind.isValid() {
return false, ErrInvalidRelationKind{Kind: rel.RelationKind} return false, ErrInvalidRelationKind{Kind: rel.RelationKind}
@ -35,14 +38,14 @@ func (rel *TaskRelation) CanCreate(a web.Auth) (bool, error) {
// Needs have write access to the base task and at least read access to the other task // Needs have write access to the base task and at least read access to the other task
baseTask := &Task{ID: rel.TaskID} baseTask := &Task{ID: rel.TaskID}
has, err := baseTask.CanUpdate(a) has, err := baseTask.CanUpdate(s, a)
if err != nil || !has { if err != nil || !has {
return false, err return false, err
} }
// We explicitly don't check if the two tasks are on the same list. // We explicitly don't check if the two tasks are on the same list.
otherTask := &Task{ID: rel.OtherTaskID} otherTask := &Task{ID: rel.OtherTaskID}
has, _, err = otherTask.CanRead(a) has, _, err = otherTask.CanRead(s, a)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -28,13 +28,17 @@ import (
func TestTaskRelation_Create(t *testing.T) { func TestTaskRelation_Create(t *testing.T) {
t.Run("Normal", func(t *testing.T) { t.Run("Normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 2, OtherTaskID: 2,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
err := rel.Create(&user.User{ID: 1}) err := rel.Create(s, &user.User{ID: 1})
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "task_relations", map[string]interface{}{ db.AssertExists(t, "task_relations", map[string]interface{}{
"task_id": 1, "task_id": 1,
@ -45,13 +49,17 @@ func TestTaskRelation_Create(t *testing.T) {
}) })
t.Run("Two Tasks In Different Lists", func(t *testing.T) { t.Run("Two Tasks In Different Lists", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 13, OtherTaskID: 13,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
err := rel.Create(&user.User{ID: 1}) err := rel.Create(s, &user.User{ID: 1})
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "task_relations", map[string]interface{}{ db.AssertExists(t, "task_relations", map[string]interface{}{
"task_id": 1, "task_id": 1,
@ -62,24 +70,28 @@ func TestTaskRelation_Create(t *testing.T) {
}) })
t.Run("Already Existing", func(t *testing.T) { t.Run("Already Existing", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 29, OtherTaskID: 29,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
err := rel.Create(&user.User{ID: 1}) err := rel.Create(s, &user.User{ID: 1})
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrRelationAlreadyExists(err)) assert.True(t, IsErrRelationAlreadyExists(err))
}) })
t.Run("Same Task", func(t *testing.T) { t.Run("Same Task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 1, OtherTaskID: 1,
} }
err := rel.Create(&user.User{ID: 1}) err := rel.Create(s, &user.User{ID: 1})
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrRelationTasksCannotBeTheSame(err)) assert.True(t, IsErrRelationTasksCannotBeTheSame(err))
}) })
@ -88,13 +100,17 @@ func TestTaskRelation_Create(t *testing.T) {
func TestTaskRelation_Delete(t *testing.T) { func TestTaskRelation_Delete(t *testing.T) {
t.Run("Normal", func(t *testing.T) { t.Run("Normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 29, OtherTaskID: 29,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
err := rel.Delete() err := rel.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertMissing(t, "task_relations", map[string]interface{}{ db.AssertMissing(t, "task_relations", map[string]interface{}{
"task_id": 1, "task_id": 1,
@ -104,13 +120,15 @@ func TestTaskRelation_Delete(t *testing.T) {
}) })
t.Run("Not existing", func(t *testing.T) { t.Run("Not existing", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 9999, TaskID: 9999,
OtherTaskID: 3, OtherTaskID: 3,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
err := rel.Delete() err := rel.Delete(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrRelationDoesNotExist(err)) assert.True(t, IsErrRelationDoesNotExist(err))
}) })
@ -119,86 +137,100 @@ func TestTaskRelation_Delete(t *testing.T) {
func TestTaskRelation_CanCreate(t *testing.T) { func TestTaskRelation_CanCreate(t *testing.T) {
t.Run("Normal", func(t *testing.T) { t.Run("Normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 2, OtherTaskID: 2,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
can, err := rel.CanCreate(&user.User{ID: 1}) can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("Two tasks on different lists", func(t *testing.T) { t.Run("Two tasks on different lists", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 13, OtherTaskID: 13,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
can, err := rel.CanCreate(&user.User{ID: 1}) can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, can) assert.True(t, can)
}) })
t.Run("No update rights on base task", func(t *testing.T) { t.Run("No update rights on base task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 14, TaskID: 14,
OtherTaskID: 1, OtherTaskID: 1,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
can, err := rel.CanCreate(&user.User{ID: 1}) can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("No update rights on base task, but read rights", func(t *testing.T) { t.Run("No update rights on base task, but read rights", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 15, TaskID: 15,
OtherTaskID: 1, OtherTaskID: 1,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
can, err := rel.CanCreate(&user.User{ID: 1}) can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("No read rights on other task", func(t *testing.T) { t.Run("No read rights on other task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 14, OtherTaskID: 14,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
can, err := rel.CanCreate(&user.User{ID: 1}) can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, can) assert.False(t, can)
}) })
t.Run("Nonexisting base task", func(t *testing.T) { t.Run("Nonexisting base task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 999999, TaskID: 999999,
OtherTaskID: 1, OtherTaskID: 1,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
can, err := rel.CanCreate(&user.User{ID: 1}) can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err)) assert.True(t, IsErrTaskDoesNotExist(err))
assert.False(t, can) assert.False(t, can)
}) })
t.Run("Nonexisting other task", func(t *testing.T) { t.Run("Nonexisting other task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
rel := TaskRelation{ rel := TaskRelation{
TaskID: 1, TaskID: 1,
OtherTaskID: 999999, OtherTaskID: 999999,
RelationKind: RelationKindSubtask, RelationKind: RelationKindSubtask,
} }
can, err := rel.CanCreate(&user.User{ID: 1}) can, err := rel.CanCreate(s, &user.User{ID: 1})
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err)) assert.True(t, IsErrTaskDoesNotExist(err))
assert.False(t, can) assert.False(t, can)

View file

@ -19,6 +19,9 @@ package models
import ( import (
"time" "time"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/cron" "code.vikunja.io/api/pkg/cron"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
@ -44,10 +47,10 @@ type taskUser struct {
User *user.User `xorm:"extends"` User *user.User `xorm:"extends"`
} }
func getTaskUsersForTasks(taskIDs []int64) (taskUsers []*taskUser, err error) { func getTaskUsersForTasks(s *xorm.Session, taskIDs []int64) (taskUsers []*taskUser, err error) {
// Get all creators of tasks // Get all creators of tasks
creators := make(map[int64]*user.User, len(taskIDs)) creators := make(map[int64]*user.User, len(taskIDs))
err = x. err = s.
Select("users.id, users.username, users.email, users.name"). Select("users.id, users.username, users.email, users.name").
Join("LEFT", "tasks", "tasks.created_by_id = users.id"). Join("LEFT", "tasks", "tasks.created_by_id = users.id").
In("tasks.id", taskIDs). In("tasks.id", taskIDs).
@ -58,13 +61,13 @@ func getTaskUsersForTasks(taskIDs []int64) (taskUsers []*taskUser, err error) {
return return
} }
assignees, err := getRawTaskAssigneesForTasks(taskIDs) assignees, err := getRawTaskAssigneesForTasks(s, taskIDs)
if err != nil { if err != nil {
return return
} }
taskMap := make(map[int64]*Task, len(taskIDs)) taskMap := make(map[int64]*Task, len(taskIDs))
err = x.In("id", taskIDs).Find(&taskMap) err = s.In("id", taskIDs).Find(&taskMap)
if err != nil { if err != nil {
return return
} }
@ -106,6 +109,8 @@ func RegisterReminderCron() {
log.Debugf("[Task Reminder Cron] Timezone is %s", tz) log.Debugf("[Task Reminder Cron] Timezone is %s", tz)
s := db.NewSession()
err := cron.Schedule("* * * * *", func() { err := cron.Schedule("* * * * *", func() {
// By default, time.Now() includes nanoseconds which we don't save. That results in getting the wrong dates, // By default, time.Now() includes nanoseconds which we don't save. That results in getting the wrong dates,
// so we make sure the time we use to get the reminders don't contain nanoseconds. // so we make sure the time we use to get the reminders don't contain nanoseconds.
@ -116,7 +121,7 @@ func RegisterReminderCron() {
log.Debugf("[Task Reminder Cron] Looking for reminders between %s and %s to send...", now, nextMinute) log.Debugf("[Task Reminder Cron] Looking for reminders between %s and %s to send...", now, nextMinute)
reminders := []*TaskReminder{} reminders := []*TaskReminder{}
err := x. err := s.
Where("reminder >= ? and reminder < ?", now.Format(dbFormat), nextMinute.Format(dbFormat)). Where("reminder >= ? and reminder < ?", now.Format(dbFormat), nextMinute.Format(dbFormat)).
Find(&reminders) Find(&reminders)
if err != nil { if err != nil {
@ -136,7 +141,7 @@ func RegisterReminderCron() {
taskIDs = append(taskIDs, r.TaskID) taskIDs = append(taskIDs, r.TaskID)
} }
users, err := getTaskUsersForTasks(taskIDs) users, err := getTaskUsersForTasks(s, taskIDs)
if err != nil { if err != nil {
log.Errorf("[Task Reminder Cron] Could not get task users to send them reminders: %s", err) log.Errorf("[Task Reminder Cron] Could not get task users to send them reminders: %s", err)
return return

View file

@ -22,6 +22,8 @@ import (
"strconv" "strconv"
"time" "time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/metrics" "code.vikunja.io/api/pkg/metrics"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
@ -153,7 +155,7 @@ type taskOptions struct {
// @Success 200 {array} models.Task "The tasks" // @Success 200 {array} models.Task "The tasks"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/all [get] // @Router /tasks/all [get]
func (t *Task) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) { func (t *Task) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, totalItems int64, err error) {
return nil, 0, 0, nil return nil, 0, 0, nil
} }
@ -209,7 +211,7 @@ func getFilterCondForSeparateTable(table string, concat taskFilterConcatinator,
} }
//nolint:gocyclo //nolint:gocyclo
func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) { func getRawTasksForLists(s *xorm.Session, lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) {
// If the user does not have any lists, don't try to get any tasks // If the user does not have any lists, don't try to get any tasks
if len(lists) == 0 { if len(lists) == 0 {
@ -253,7 +255,7 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
// Postgres sorts by default entries with null values after ones with values. // Postgres sorts by default entries with null values after ones with values.
// To make that consistent with the sort order we have and other dbms, we're adding a separate clause here. // To make that consistent with the sort order we have and other dbms, we're adding a separate clause here.
if x.Dialect().URI().DBType == schemas.POSTGRES { if db.Type() == schemas.POSTGRES {
if param.orderBy == orderAscending { if param.orderBy == orderAscending {
orderby += " NULLS FIRST" orderby += " NULLS FIRST"
} }
@ -324,9 +326,7 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
} }
// Then return all tasks for that lists // Then return all tasks for that lists
query := x.NewSession(). var where builder.Cond
OrderBy(orderby)
queryCount := x.NewSession()
if len(opts.search) > 0 { if len(opts.search) > 0 {
// Postgres' is case sensitive by default. // Postgres' is case sensitive by default.
@ -335,11 +335,9 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
// See https://stackoverflow.com/q/7005302/10924593 // See https://stackoverflow.com/q/7005302/10924593
// Seems okay to use that now, we may need to find a better solution overall in the future. // Seems okay to use that now, we may need to find a better solution overall in the future.
if config.DatabaseType.GetString() == "postgres" { if config.DatabaseType.GetString() == "postgres" {
query = query.Where("title ILIKE ?", "%"+opts.search+"%") where = builder.Expr("title ILIKE ?", "%"+opts.search+"%")
queryCount = queryCount.Where("title ILIKE ?", "%"+opts.search+"%")
} else { } else {
query = query.Where("title LIKE ?", "%"+opts.search+"%") where = &builder.Like{"title", "%" + opts.search + "%"}
queryCount = queryCount.Where("title LIKE ?", "%"+opts.search+"%")
} }
} }
@ -352,10 +350,13 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
if hasFavoriteLists { if hasFavoriteLists {
// Make sure users can only see their favorites // Make sure users can only see their favorites
userLists, _, _, err := getRawListsForUser(&listOptions{ userLists, _, _, err := getRawListsForUser(
s,
&listOptions{
user: &user.User{ID: a.GetID()}, user: &user.User{ID: a.GetID()},
page: -1, page: -1,
}) },
)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -399,32 +400,31 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
filters = append(filters, cond) filters = append(filters, cond)
} }
query = query.Where(listCond) var filterCond builder.Cond
queryCount = queryCount.Where(listCond)
if len(filters) > 0 { if len(filters) > 0 {
if opts.filterConcat == filterConcatOr { if opts.filterConcat == filterConcatOr {
query = query.Where(builder.Or(filters...)) filterCond = builder.Or(filters...)
queryCount = queryCount.Where(builder.Or(filters...))
} }
if opts.filterConcat == filterConcatAnd { if opts.filterConcat == filterConcatAnd {
query = query.Where(builder.And(filters...)) filterCond = builder.And(filters...)
queryCount = queryCount.Where(builder.And(filters...))
} }
} }
limit, start := getLimitFromPageIndex(opts.page, opts.perPage) limit, start := getLimitFromPageIndex(opts.page, opts.perPage)
cond := builder.And(listCond, where, filterCond)
query := s.Where(cond)
if limit > 0 { if limit > 0 {
query = query.Limit(limit, start) query = query.Limit(limit, start)
} }
tasks = []*Task{} tasks = []*Task{}
err = query.Find(&tasks) err = query.OrderBy(orderby).Find(&tasks)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
queryCount := s.Where(cond)
totalItems, err = queryCount. totalItems, err = queryCount.
Count(&Task{}) Count(&Task{})
if err != nil { if err != nil {
@ -434,9 +434,9 @@ func getRawTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []
return tasks, len(tasks), totalItems, nil return tasks, len(tasks), totalItems, nil
} }
func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) { func getTasksForLists(s *xorm.Session, lists []*List, a web.Auth, opts *taskOptions) (tasks []*Task, resultCount int, totalItems int64, err error) {
tasks, resultCount, totalItems, err = getRawTasksForLists(lists, a, opts) tasks, resultCount, totalItems, err = getRawTasksForLists(s, lists, a, opts)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -446,7 +446,7 @@ func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Ta
taskMap[t.ID] = t taskMap[t.ID] = t
} }
err = addMoreInfoToTasks(taskMap) err = addMoreInfoToTasks(s, taskMap)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -455,18 +455,18 @@ func getTasksForLists(lists []*List, a web.Auth, opts *taskOptions) (tasks []*Ta
} }
// GetTaskByIDSimple returns a raw task without extra data by the task ID // GetTaskByIDSimple returns a raw task without extra data by the task ID
func GetTaskByIDSimple(taskID int64) (task Task, err error) { func GetTaskByIDSimple(s *xorm.Session, taskID int64) (task Task, err error) {
if taskID < 1 { if taskID < 1 {
return Task{}, ErrTaskDoesNotExist{taskID} return Task{}, ErrTaskDoesNotExist{taskID}
} }
return GetTaskSimple(&Task{ID: taskID}) return GetTaskSimple(s, &Task{ID: taskID})
} }
// GetTaskSimple returns a raw task without extra data // GetTaskSimple returns a raw task without extra data
func GetTaskSimple(t *Task) (task Task, err error) { func GetTaskSimple(s *xorm.Session, t *Task) (task Task, err error) {
task = *t task = *t
exists, err := x.Get(&task) exists, err := s.Get(&task)
if err != nil { if err != nil {
return Task{}, err return Task{}, err
} }
@ -478,14 +478,14 @@ func GetTaskSimple(t *Task) (task Task, err error) {
} }
// GetTasksByIDs returns all tasks for a list of ids // GetTasksByIDs returns all tasks for a list of ids
func (bt *BulkTask) GetTasksByIDs() (err error) { func (bt *BulkTask) GetTasksByIDs(s *xorm.Session) (err error) {
for _, id := range bt.IDs { for _, id := range bt.IDs {
if id < 1 { if id < 1 {
return ErrTaskDoesNotExist{id} return ErrTaskDoesNotExist{id}
} }
} }
err = x.In("id", bt.IDs).Find(&bt.Tasks) err = s.In("id", bt.IDs).Find(&bt.Tasks)
if err != nil { if err != nil {
return return
} }
@ -494,9 +494,9 @@ func (bt *BulkTask) GetTasksByIDs() (err error) {
} }
// GetTasksByUIDs gets all tasks from a bunch of uids // GetTasksByUIDs gets all tasks from a bunch of uids
func GetTasksByUIDs(uids []string) (tasks []*Task, err error) { func GetTasksByUIDs(s *xorm.Session, uids []string) (tasks []*Task, err error) {
tasks = []*Task{} tasks = []*Task{}
err = x.In("uid", uids).Find(&tasks) err = s.In("uid", uids).Find(&tasks)
if err != nil { if err != nil {
return return
} }
@ -506,13 +506,13 @@ func GetTasksByUIDs(uids []string) (tasks []*Task, err error) {
taskMap[t.ID] = t taskMap[t.ID] = t
} }
err = addMoreInfoToTasks(taskMap) err = addMoreInfoToTasks(s, taskMap)
return return
} }
func getRemindersForTasks(taskIDs []int64) (reminders []*TaskReminder, err error) { func getRemindersForTasks(s *xorm.Session, taskIDs []int64) (reminders []*TaskReminder, err error) {
reminders = []*TaskReminder{} reminders = []*TaskReminder{}
err = x.In("task_id", taskIDs).Find(&reminders) err = s.In("task_id", taskIDs).Find(&reminders)
return return
} }
@ -521,8 +521,8 @@ func (t *Task) setIdentifier(list *List) {
} }
// Get all assignees // Get all assignees
func addAssigneesToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { func addAssigneesToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
taskAssignees, err := getRawTaskAssigneesForTasks(taskIDs) taskAssignees, err := getRawTaskAssigneesForTasks(s, taskIDs)
if err != nil { if err != nil {
return return
} }
@ -538,8 +538,8 @@ func addAssigneesToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
} }
// Get all labels for all the tasks // Get all labels for all the tasks
func addLabelsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { func addLabelsToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
labels, _, _, err := getLabelsByTaskIDs(&LabelByTaskIDsOptions{ labels, _, _, err := getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{
TaskIDs: taskIDs, TaskIDs: taskIDs,
Page: -1, Page: -1,
}) })
@ -556,8 +556,8 @@ func addLabelsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) {
} }
// Get task attachments // Get task attachments
func addAttachmentsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { func addAttachmentsToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
attachments, err := getTaskAttachmentsByTaskIDs(taskIDs) attachments, err := getTaskAttachmentsByTaskIDs(s, taskIDs)
if err != nil { if err != nil {
return return
} }
@ -568,11 +568,11 @@ func addAttachmentsToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error)
return return
} }
func getTaskReminderMap(taskIDs []int64) (taskReminders map[int64][]time.Time, err error) { func getTaskReminderMap(s *xorm.Session, taskIDs []int64) (taskReminders map[int64][]time.Time, err error) {
taskReminders = make(map[int64][]time.Time) taskReminders = make(map[int64][]time.Time)
// Get all reminders and put them in a map to have it easier later // Get all reminders and put them in a map to have it easier later
reminders, err := getRemindersForTasks(taskIDs) reminders, err := getRemindersForTasks(s, taskIDs)
if err != nil { if err != nil {
return return
} }
@ -584,9 +584,9 @@ func getTaskReminderMap(taskIDs []int64) (taskReminders map[int64][]time.Time, e
return return
} }
func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error) { func addRelatedTasksToTasks(s *xorm.Session, taskIDs []int64, taskMap map[int64]*Task) (err error) {
relatedTasks := []*TaskRelation{} relatedTasks := []*TaskRelation{}
err = x.In("task_id", taskIDs).Find(&relatedTasks) err = s.In("task_id", taskIDs).Find(&relatedTasks)
if err != nil { if err != nil {
return return
} }
@ -597,7 +597,7 @@ func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error
relatedTaskIDs = append(relatedTaskIDs, rt.OtherTaskID) relatedTaskIDs = append(relatedTaskIDs, rt.OtherTaskID)
} }
fullRelatedTasks := make(map[int64]*Task) fullRelatedTasks := make(map[int64]*Task)
err = x.In("id", relatedTaskIDs).Find(&fullRelatedTasks) err = s.In("id", relatedTaskIDs).Find(&fullRelatedTasks)
if err != nil { if err != nil {
return return
} }
@ -614,7 +614,7 @@ func addRelatedTasksToTasks(taskIDs []int64, taskMap map[int64]*Task) (err error
// This function takes a map with pointers and returns a slice with pointers to tasks // This function takes a map with pointers and returns a slice with pointers to tasks
// It adds more stuff like assignees/labels/etc to a bunch of tasks // It adds more stuff like assignees/labels/etc to a bunch of tasks
func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) { func addMoreInfoToTasks(s *xorm.Session, taskMap map[int64]*Task) (err error) {
// No need to iterate over users and stuff if the list doesn't have tasks // No need to iterate over users and stuff if the list doesn't have tasks
if len(taskMap) == 0 { if len(taskMap) == 0 {
@ -631,33 +631,33 @@ func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) {
listIDs = append(listIDs, i.ListID) listIDs = append(listIDs, i.ListID)
} }
err = addAssigneesToTasks(taskIDs, taskMap) err = addAssigneesToTasks(s, taskIDs, taskMap)
if err != nil { if err != nil {
return return
} }
err = addLabelsToTasks(taskIDs, taskMap) err = addLabelsToTasks(s, taskIDs, taskMap)
if err != nil { if err != nil {
return return
} }
err = addAttachmentsToTasks(taskIDs, taskMap) err = addAttachmentsToTasks(s, taskIDs, taskMap)
if err != nil { if err != nil {
return return
} }
users, err := user.GetUsersByIDs(userIDs) users, err := user.GetUsersByIDs(s, userIDs)
if err != nil { if err != nil {
return return
} }
taskReminders, err := getTaskReminderMap(taskIDs) taskReminders, err := getTaskReminderMap(s, taskIDs)
if err != nil { if err != nil {
return err return err
} }
// Get all identifiers // Get all identifiers
lists, err := GetListsByIDs(listIDs) lists, err := GetListsByIDs(s, listIDs)
if err != nil { if err != nil {
return err return err
} }
@ -679,7 +679,7 @@ func addMoreInfoToTasks(taskMap map[int64]*Task) (err error) {
} }
// Get all related tasks // Get all related tasks
err = addRelatedTasksToTasks(taskIDs, taskMap) err = addRelatedTasksToTasks(s, taskIDs, taskMap)
return return
} }
@ -739,14 +739,8 @@ func checkBucketLimit(s *xorm.Session, t *Task, bucket *Bucket) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id} [put] // @Router /lists/{id} [put]
func (t *Task) Create(a web.Auth) (err error) { func (t *Task) Create(s *xorm.Session, a web.Auth) (err error) {
s := x.NewSession() return createTask(s, t, a, true)
err = createTask(s, t, a, true)
if err != nil {
_ = s.Rollback()
return err
}
return s.Commit()
} }
func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err error) { func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err error) {
@ -759,16 +753,16 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err
} }
// Check if the list exists // Check if the list exists
l := &List{ID: t.ListID} l, err := GetListSimpleByID(s, t.ListID)
if err = l.getSimpleByID(s); err != nil { if err != nil {
return return err
} }
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
// A negative user id indicates user share links // A negative user id indicates user share links
t.CreatedByID = a.GetID() * -1 t.CreatedByID = a.GetID() * -1
} else { } else {
u, err := user.GetUserByID(a.GetID()) u, err := user.GetUserByID(s, a.GetID())
if err != nil { if err != nil {
return err return err
} }
@ -834,7 +828,7 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err
t.setIdentifier(l) t.setIdentifier(l)
err = updateListLastUpdatedS(s, &List{ID: t.ListID}) err = updateListLastUpdated(s, &List{ID: t.ListID})
return return
} }
@ -853,21 +847,17 @@ func createTask(s *xorm.Session, t *Task, a web.Auth, updateAssignees bool) (err
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id} [post] // @Router /tasks/{id} [post]
//nolint:gocyclo //nolint:gocyclo
func (t *Task) Update() (err error) { func (t *Task) Update(s *xorm.Session) (err error) {
s := x.NewSession()
// Check if the task exists and get the old values // Check if the task exists and get the old values
ot, err := GetTaskByIDSimple(t.ID) ot, err := GetTaskByIDSimple(s, t.ID)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
// Get the reminders // Get the reminders
reminders, err := getRemindersForTasks([]int64{t.ID}) reminders, err := getRemindersForTasks(s, []int64{t.ID})
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
@ -881,20 +871,17 @@ func (t *Task) Update() (err error) {
// Update the assignees // Update the assignees
if err := ot.updateTaskAssignees(s, t.Assignees); err != nil { if err := ot.updateTaskAssignees(s, t.Assignees); err != nil {
_ = s.Rollback()
return err return err
} }
// Update the reminders // Update the reminders
if err := ot.updateReminders(s, t.Reminders); err != nil { if err := ot.updateReminders(s, t.Reminders); err != nil {
_ = s.Rollback()
return err return err
} }
// If there is a bucket set, make sure they belong to the same list as the task // If there is a bucket set, make sure they belong to the same list as the task
err = checkBucketAndTaskBelongToSameList(s, &ot, t.BucketID) err = checkBucketAndTaskBelongToSameList(s, &ot, t.BucketID)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
@ -923,7 +910,6 @@ func (t *Task) Update() (err error) {
if t.BucketID == 0 || (t.ListID != 0 && ot.ListID != t.ListID) { if t.BucketID == 0 || (t.ListID != 0 && ot.ListID != t.ListID) {
bucket, err = getDefaultBucket(s, t.ListID) bucket, err = getDefaultBucket(s, t.ListID)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
t.BucketID = bucket.ID t.BucketID = bucket.ID
@ -934,7 +920,6 @@ func (t *Task) Update() (err error) {
latestTask := &Task{} latestTask := &Task{}
_, err = s.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask) _, err = s.Where("list_id = ?", t.ListID).OrderBy("id desc").Get(latestTask)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
@ -946,7 +931,6 @@ func (t *Task) Update() (err error) {
// Only check the bucket limit if the task is being moved between buckets, allow reordering the task within a bucket // Only check the bucket limit if the task is being moved between buckets, allow reordering the task within a bucket
if t.BucketID != ot.BucketID { if t.BucketID != ot.BucketID {
if err := checkBucketLimit(s, t, bucket); err != nil { if err := checkBucketLimit(s, t, bucket); err != nil {
_ = s.Rollback()
return err return err
} }
} }
@ -972,7 +956,6 @@ func (t *Task) Update() (err error) {
// Which is why we merge the actual task struct with the one we got from the db // Which is why we merge the actual task struct with the one we got from the db
// The user struct overrides values in the actual one. // The user struct overrides values in the actual one.
if err := mergo.Merge(&ot, t, mergo.WithOverride); err != nil { if err := mergo.Merge(&ot, t, mergo.WithOverride); err != nil {
_ = s.Rollback()
return err return err
} }
@ -1034,7 +1017,6 @@ func (t *Task) Update() (err error) {
Update(ot) Update(ot)
*t = ot *t = ot
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
// Get the task updated timestamp in a new struct - if we'd just try to put it into t which we already have, it // Get the task updated timestamp in a new struct - if we'd just try to put it into t which we already have, it
@ -1042,17 +1024,11 @@ func (t *Task) Update() (err error) {
nt := &Task{} nt := &Task{}
_, err = s.ID(t.ID).Get(nt) _, err = s.ID(t.ID).Get(nt)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
t.Updated = nt.Updated t.Updated = nt.Updated
err = updateListLastUpdatedS(s, &List{ID: t.ListID}) return updateListLastUpdated(s, &List{ID: t.ListID})
if err != nil {
_ = s.Rollback()
return err
}
return s.Commit()
} }
// This helper function updates the reminders, doneAt, start and end dates of the *old* task // This helper function updates the reminders, doneAt, start and end dates of the *old* task
@ -1174,7 +1150,7 @@ func (t *Task) updateReminders(s *xorm.Session, reminders []time.Time) (err erro
t.Reminders = nil t.Reminders = nil
} }
err = updateListLastUpdatedS(s, &List{ID: t.ListID}) err = updateListLastUpdated(s, &List{ID: t.ListID})
return return
} }
@ -1190,20 +1166,20 @@ func (t *Task) updateReminders(s *xorm.Session, reminders []time.Time) (err erro
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{id} [delete] // @Router /tasks/{id} [delete]
func (t *Task) Delete() (err error) { func (t *Task) Delete(s *xorm.Session) (err error) {
if _, err = x.ID(t.ID).Delete(Task{}); err != nil { if _, err = s.ID(t.ID).Delete(Task{}); err != nil {
return err return err
} }
// Delete assignees // Delete assignees
if _, err = x.Where("task_id = ?", t.ID).Delete(TaskAssginee{}); err != nil { if _, err = s.Where("task_id = ?", t.ID).Delete(TaskAssginee{}); err != nil {
return err return err
} }
metrics.UpdateCount(-1, metrics.TaskCountKey) metrics.UpdateCount(-1, metrics.TaskCountKey)
err = updateListLastUpdated(&List{ID: t.ListID}) err = updateListLastUpdated(s, &List{ID: t.ListID})
return return
} }
@ -1219,16 +1195,16 @@ func (t *Task) Delete() (err error) {
// @Failure 404 {object} models.Message "Task not found" // @Failure 404 {object} models.Message "Task not found"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /tasks/{ID} [get] // @Router /tasks/{ID} [get]
func (t *Task) ReadOne() (err error) { func (t *Task) ReadOne(s *xorm.Session) (err error) {
taskMap := make(map[int64]*Task, 1) taskMap := make(map[int64]*Task, 1)
taskMap[t.ID] = &Task{} taskMap[t.ID] = &Task{}
*taskMap[t.ID], err = GetTaskByIDSimple(t.ID) *taskMap[t.ID], err = GetTaskByIDSimple(s, t.ID)
if err != nil { if err != nil {
return return
} }
err = addMoreInfoToTasks(taskMap) err = addMoreInfoToTasks(s, taskMap)
if err != nil { if err != nil {
return return
} }

View file

@ -18,47 +18,48 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanDelete checks if the user can delete an task // CanDelete checks if the user can delete an task
func (t *Task) CanDelete(a web.Auth) (bool, error) { func (t *Task) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return t.canDoTask(a) return t.canDoTask(s, a)
} }
// CanUpdate determines if a user has the right to update a list task // CanUpdate determines if a user has the right to update a list task
func (t *Task) CanUpdate(a web.Auth) (bool, error) { func (t *Task) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return t.canDoTask(a) return t.canDoTask(s, a)
} }
// CanCreate determines if a user has the right to create a list task // CanCreate determines if a user has the right to create a list task
func (t *Task) CanCreate(a web.Auth) (bool, error) { func (t *Task) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
// A user can do a task if he has write acces to its list // A user can do a task if he has write acces to its list
l := &List{ID: t.ListID} l := &List{ID: t.ListID}
return l.CanWrite(a) return l.CanWrite(s, a)
} }
// CanRead determines if a user can read a task // CanRead determines if a user can read a task
func (t *Task) CanRead(a web.Auth) (canRead bool, maxRight int, err error) { func (t *Task) CanRead(s *xorm.Session, a web.Auth) (canRead bool, maxRight int, err error) {
// Get the task, error out if it doesn't exist // Get the task, error out if it doesn't exist
*t, err = GetTaskByIDSimple(t.ID) *t, err = GetTaskByIDSimple(s, t.ID)
if err != nil { if err != nil {
return return
} }
// A user can read a task if it has access to the list // A user can read a task if it has access to the list
l := &List{ID: t.ListID} l := &List{ID: t.ListID}
return l.CanRead(a) return l.CanRead(s, a)
} }
// CanWrite checks if a user has write access to a task // CanWrite checks if a user has write access to a task
func (t *Task) CanWrite(a web.Auth) (canWrite bool, err error) { func (t *Task) CanWrite(s *xorm.Session, a web.Auth) (canWrite bool, err error) {
return t.canDoTask(a) return t.canDoTask(s, a)
} }
// Helper function to check if a user can do stuff on a list task // Helper function to check if a user can do stuff on a list task
func (t *Task) canDoTask(a web.Auth) (bool, error) { func (t *Task) canDoTask(s *xorm.Session, a web.Auth) (bool, error) {
// Get the task // Get the task
ot, err := GetTaskByIDSimple(t.ID) ot, err := GetTaskByIDSimple(s, t.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -66,7 +67,7 @@ func (t *Task) canDoTask(a web.Auth) (bool, error) {
// Check if we're moving the task into a different list to check if the user has sufficient rights for that on the new list // Check if we're moving the task into a different list to check if the user has sufficient rights for that on the new list
if t.ListID != 0 && t.ListID != ot.ListID { if t.ListID != 0 && t.ListID != ot.ListID {
newList := &List{ID: t.ListID} newList := &List{ID: t.ListID}
can, err := newList.CanWrite(a) can, err := newList.CanWrite(s, a)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -77,5 +78,5 @@ func (t *Task) canDoTask(a web.Auth) (bool, error) {
// A user can do a task if it has write acces to its list // A user can do a task if it has write acces to its list
l := &List{ID: ot.ListID} l := &List{ID: ot.ListID}
return l.CanWrite(a) return l.CanWrite(s, a)
} }

View file

@ -36,12 +36,15 @@ func TestTask_Create(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
Title: "Lorem", Title: "Lorem",
Description: "Lorem Ipsum Dolor", Description: "Lorem Ipsum Dolor",
ListID: 1, ListID: 1,
} }
err := task.Create(usr) err := task.Create(s, usr)
assert.NoError(t, err) assert.NoError(t, err)
// Assert getting a uid // Assert getting a uid
assert.NotEmpty(t, task.UID) assert.NotEmpty(t, task.UID)
@ -50,6 +53,9 @@ func TestTask_Create(t *testing.T) {
assert.Equal(t, int64(18), task.Index) assert.Equal(t, int64(18), task.Index)
// Assert moving it into the default bucket // Assert moving it into the default bucket
assert.Equal(t, int64(1), task.BucketID) assert.Equal(t, int64(1), task.BucketID)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "tasks", map[string]interface{}{ db.AssertExists(t, "tasks", map[string]interface{}{
"id": task.ID, "id": task.ID,
"title": "Lorem", "title": "Lorem",
@ -62,47 +68,59 @@ func TestTask_Create(t *testing.T) {
}) })
t.Run("empty title", func(t *testing.T) { t.Run("empty title", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
Title: "", Title: "",
Description: "Lorem Ipsum Dolor", Description: "Lorem Ipsum Dolor",
ListID: 1, ListID: 1,
} }
err := task.Create(usr) err := task.Create(s, usr)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskCannotBeEmpty(err)) assert.True(t, IsErrTaskCannotBeEmpty(err))
}) })
t.Run("nonexistant list", func(t *testing.T) { t.Run("nonexistant list", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
Title: "Test", Title: "Test",
Description: "Lorem Ipsum Dolor", Description: "Lorem Ipsum Dolor",
ListID: 9999999, ListID: 9999999,
} }
err := task.Create(usr) err := task.Create(s, usr)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrListDoesNotExist(err)) assert.True(t, IsErrListDoesNotExist(err))
}) })
t.Run("noneixtant user", func(t *testing.T) { t.Run("noneixtant user", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
nUser := &user.User{ID: 99999999} nUser := &user.User{ID: 99999999}
task := &Task{ task := &Task{
Title: "Test", Title: "Test",
Description: "Lorem Ipsum Dolor", Description: "Lorem Ipsum Dolor",
ListID: 1, ListID: 1,
} }
err := task.Create(nUser) err := task.Create(s, nUser)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err)) assert.True(t, user.IsErrUserDoesNotExist(err))
}) })
t.Run("full bucket", func(t *testing.T) { t.Run("full bucket", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
Title: "Lorem", Title: "Lorem",
Description: "Lorem Ipsum Dolor", Description: "Lorem Ipsum Dolor",
ListID: 1, ListID: 1,
BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3 BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3
} }
err := task.Create(usr) err := task.Create(s, usr)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrBucketLimitExceeded(err)) assert.True(t, IsErrBucketLimitExceeded(err))
}) })
@ -111,14 +129,20 @@ func TestTask_Create(t *testing.T) {
func TestTask_Update(t *testing.T) { func TestTask_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
ID: 1, ID: 1,
Title: "test10000", Title: "test10000",
Description: "Lorem Ipsum Dolor", Description: "Lorem Ipsum Dolor",
ListID: 1, ListID: 1,
} }
err := task.Update() err := task.Update(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "tasks", map[string]interface{}{ db.AssertExists(t, "tasks", map[string]interface{}{
"id": 1, "id": 1,
"title": "test10000", "title": "test10000",
@ -128,18 +152,24 @@ func TestTask_Update(t *testing.T) {
}) })
t.Run("nonexistant task", func(t *testing.T) { t.Run("nonexistant task", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
ID: 9999999, ID: 9999999,
Title: "test10000", Title: "test10000",
Description: "Lorem Ipsum Dolor", Description: "Lorem Ipsum Dolor",
ListID: 1, ListID: 1,
} }
err := task.Update() err := task.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err)) assert.True(t, IsErrTaskDoesNotExist(err))
}) })
t.Run("full bucket", func(t *testing.T) { t.Run("full bucket", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
ID: 1, ID: 1,
Title: "test10000", Title: "test10000",
@ -147,12 +177,15 @@ func TestTask_Update(t *testing.T) {
ListID: 1, ListID: 1,
BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3 BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3
} }
err := task.Update() err := task.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrBucketLimitExceeded(err)) assert.True(t, IsErrBucketLimitExceeded(err))
}) })
t.Run("full bucket but not changing the bucket", func(t *testing.T) { t.Run("full bucket but not changing the bucket", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
ID: 4, ID: 4,
Title: "test10000", Title: "test10000",
@ -161,7 +194,7 @@ func TestTask_Update(t *testing.T) {
ListID: 1, ListID: 1,
BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3 BucketID: 2, // Bucket 2 already has 3 tasks and a limit of 3
} }
err := task.Update() err := task.Update(s)
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
@ -169,11 +202,17 @@ func TestTask_Update(t *testing.T) {
func TestTask_Delete(t *testing.T) { func TestTask_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ task := &Task{
ID: 1, ID: 1,
} }
err := task.Delete() err := task.Delete(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "tasks", map[string]interface{}{ db.AssertMissing(t, "tasks", map[string]interface{}{
"id": 1, "id": 1,
}) })
@ -183,6 +222,9 @@ func TestTask_Delete(t *testing.T) {
func TestUpdateDone(t *testing.T) { func TestUpdateDone(t *testing.T) {
t.Run("marking a task as done", func(t *testing.T) { t.Run("marking a task as done", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
oldTask := &Task{Done: false} oldTask := &Task{Done: false}
newTask := &Task{Done: true} newTask := &Task{Done: true}
updateDone(oldTask, newTask) updateDone(oldTask, newTask)
@ -190,6 +232,9 @@ func TestUpdateDone(t *testing.T) {
}) })
t.Run("unmarking a task as done", func(t *testing.T) { t.Run("unmarking a task as done", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
oldTask := &Task{Done: true} oldTask := &Task{Done: true}
newTask := &Task{Done: false} newTask := &Task{Done: false}
updateDone(oldTask, newTask) updateDone(oldTask, newTask)
@ -397,15 +442,21 @@ func TestUpdateDone(t *testing.T) {
func TestTask_ReadOne(t *testing.T) { func TestTask_ReadOne(t *testing.T) {
t.Run("default", func(t *testing.T) { t.Run("default", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ID: 1} task := &Task{ID: 1}
err := task.ReadOne() err := task.ReadOne(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "task #1", task.Title) assert.Equal(t, "task #1", task.Title)
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
task := &Task{ID: 99999} task := &Task{ID: 99999}
err := task.ReadOne() err := task.ReadOne(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTaskDoesNotExist(err)) assert.True(t, IsErrTaskDoesNotExist(err))
}) })

View file

@ -19,6 +19,7 @@ package models
import ( import (
user2 "code.vikunja.io/api/pkg/user" user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// Create implements the create method to assign a user to a team // Create implements the create method to assign a user to a team
@ -35,23 +36,24 @@ import (
// @Failure 403 {object} web.HTTPError "The user does not have access to the team" // @Failure 403 {object} web.HTTPError "The user does not have access to the team"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id}/members [put] // @Router /teams/{id}/members [put]
func (tm *TeamMember) Create(a web.Auth) (err error) { func (tm *TeamMember) Create(s *xorm.Session, a web.Auth) (err error) {
// Check if the team extst // Check if the team extst
_, err = GetTeamByID(tm.TeamID) _, err = GetTeamByID(s, tm.TeamID)
if err != nil { if err != nil {
return return
} }
// Check if the user exists // Check if the user exists
user, err := user2.GetUserByUsername(tm.Username) user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil { if err != nil {
return return
} }
tm.UserID = user.ID tm.UserID = user.ID
// Check if that user is already part of the team // Check if that user is already part of the team
exists, err := x.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID). exists, err := s.
Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).
Get(&TeamMember{}) Get(&TeamMember{})
if err != nil { if err != nil {
return return
@ -61,7 +63,7 @@ func (tm *TeamMember) Create(a web.Auth) (err error) {
} }
// Insert the user // Insert the user
_, err = x.Insert(tm) _, err = s.Insert(tm)
return return
} }
@ -76,9 +78,9 @@ func (tm *TeamMember) Create(a web.Auth) (err error) {
// @Success 200 {object} models.Message "The user was successfully removed from the team." // @Success 200 {object} models.Message "The user was successfully removed from the team."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id}/members/{userID} [delete] // @Router /teams/{id}/members/{userID} [delete]
func (tm *TeamMember) Delete() (err error) { func (tm *TeamMember) Delete(s *xorm.Session) (err error) {
total, err := x.Where("team_id = ?", tm.TeamID).Count(&TeamMember{}) total, err := s.Where("team_id = ?", tm.TeamID).Count(&TeamMember{})
if err != nil { if err != nil {
return return
} }
@ -87,13 +89,13 @@ func (tm *TeamMember) Delete() (err error) {
} }
// Find the numeric user id // Find the numeric user id
user, err := user2.GetUserByUsername(tm.Username) user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil { if err != nil {
return return
} }
tm.UserID = user.ID tm.UserID = user.ID
_, err = x.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).Delete(&TeamMember{}) _, err = s.Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).Delete(&TeamMember{})
return return
} }
@ -108,9 +110,9 @@ func (tm *TeamMember) Delete() (err error) {
// @Success 200 {object} models.Message "The member right was successfully changed." // @Success 200 {object} models.Message "The member right was successfully changed."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id}/members/{userID}/admin [post] // @Router /teams/{id}/members/{userID}/admin [post]
func (tm *TeamMember) Update() (err error) { func (tm *TeamMember) Update(s *xorm.Session) (err error) {
// Find the numeric user id // Find the numeric user id
user, err := user2.GetUserByUsername(tm.Username) user, err := user2.GetUserByUsername(s, tm.Username)
if err != nil { if err != nil {
return return
} }
@ -118,7 +120,7 @@ func (tm *TeamMember) Update() (err error) {
// Get the full member object and change the admin right // Get the full member object and change the admin right
ttm := &TeamMember{} ttm := &TeamMember{}
_, err = x. _, err = s.
Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID). Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).
Get(ttm) Get(ttm)
if err != nil { if err != nil {
@ -127,7 +129,7 @@ func (tm *TeamMember) Update() (err error) {
ttm.Admin = !ttm.Admin ttm.Admin = !ttm.Admin
// Do the update // Do the update
_, err = x. _, err = s.
Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID). Where("team_id = ? AND user_id = ?", tm.TeamID, tm.UserID).
Cols("admin"). Cols("admin").
Update(ttm) Update(ttm)

View file

@ -18,32 +18,34 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if the user can add a new tem member // CanCreate checks if the user can add a new tem member
func (tm *TeamMember) CanCreate(a web.Auth) (bool, error) { func (tm *TeamMember) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
return tm.IsAdmin(a) return tm.IsAdmin(s, a)
} }
// CanDelete checks if the user can delete a new team member // CanDelete checks if the user can delete a new team member
func (tm *TeamMember) CanDelete(a web.Auth) (bool, error) { func (tm *TeamMember) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return tm.IsAdmin(a) return tm.IsAdmin(s, a)
} }
// CanUpdate checks if the user can modify a team member's right // CanUpdate checks if the user can modify a team member's right
func (tm *TeamMember) CanUpdate(a web.Auth) (bool, error) { func (tm *TeamMember) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return tm.IsAdmin(a) return tm.IsAdmin(s, a)
} }
// IsAdmin checks if the user is team admin // IsAdmin checks if the user is team admin
func (tm *TeamMember) IsAdmin(a web.Auth) (bool, error) { func (tm *TeamMember) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
// Don't allow anything if we're dealing with a list share here // Don't allow anything if we're dealing with a list share here
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
// A user can add a member to a team if he is admin of that team // A user can add a member to a team if he is admin of that team
exists, err := x.Where("user_id = ? AND team_id = ? AND admin = ?", a.GetID(), tm.TeamID, true). exists, err := s.
Where("user_id = ? AND team_id = ? AND admin = ?", a.GetID(), tm.TeamID, true).
Get(&TeamMember{}) Get(&TeamMember{})
return exists, err return exists, err
} }

View file

@ -32,12 +32,18 @@ func TestTeamMember_Create(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{ tm := &TeamMember{
TeamID: 1, TeamID: 1,
Username: "user3", Username: "user3",
} }
err := tm.Create(doer) err := tm.Create(s, doer)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_members", map[string]interface{}{ db.AssertExists(t, "team_members", map[string]interface{}{
"id": tm.ID, "id": tm.ID,
"team_id": 1, "team_id": 1,
@ -46,31 +52,40 @@ func TestTeamMember_Create(t *testing.T) {
}) })
t.Run("already existing", func(t *testing.T) { t.Run("already existing", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{ tm := &TeamMember{
TeamID: 1, TeamID: 1,
Username: "user1", Username: "user1",
} }
err := tm.Create(doer) err := tm.Create(s, doer)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrUserIsMemberOfTeam(err)) assert.True(t, IsErrUserIsMemberOfTeam(err))
}) })
t.Run("nonexisting user", func(t *testing.T) { t.Run("nonexisting user", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{ tm := &TeamMember{
TeamID: 1, TeamID: 1,
Username: "nonexistinguser", Username: "nonexistinguser",
} }
err := tm.Create(doer) err := tm.Create(s, doer)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, user.IsErrUserDoesNotExist(err)) assert.True(t, user.IsErrUserDoesNotExist(err))
}) })
t.Run("nonexisting team", func(t *testing.T) { t.Run("nonexisting team", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{ tm := &TeamMember{
TeamID: 9999999, TeamID: 9999999,
Username: "user1", Username: "user1",
} }
err := tm.Create(doer) err := tm.Create(s, doer)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
}) })
@ -79,12 +94,18 @@ func TestTeamMember_Create(t *testing.T) {
func TestTeamMember_Delete(t *testing.T) { func TestTeamMember_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{ tm := &TeamMember{
TeamID: 1, TeamID: 1,
Username: "user1", Username: "user1",
} }
err := tm.Delete() err := tm.Delete(s)
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertMissing(t, "team_members", map[string]interface{}{ db.AssertMissing(t, "team_members", map[string]interface{}{
"team_id": 1, "team_id": 1,
"user_id": 1, "user_id": 1,
@ -95,14 +116,20 @@ func TestTeamMember_Delete(t *testing.T) {
func TestTeamMember_Update(t *testing.T) { func TestTeamMember_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{ tm := &TeamMember{
TeamID: 1, TeamID: 1,
Username: "user1", Username: "user1",
Admin: true, Admin: true,
} }
err := tm.Update() err := tm.Update(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, tm.Admin) // Since this endpoint toggles the right, we should get a false for admin back. assert.False(t, tm.Admin) // Since this endpoint toggles the right, we should get a false for admin back.
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_members", map[string]interface{}{ db.AssertExists(t, "team_members", map[string]interface{}{
"team_id": 1, "team_id": 1,
"user_id": 1, "user_id": 1,
@ -113,14 +140,20 @@ func TestTeamMember_Update(t *testing.T) {
// should ignore what was passed. // should ignore what was passed.
t.Run("explicitly false in payload", func(t *testing.T) { t.Run("explicitly false in payload", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &TeamMember{ tm := &TeamMember{
TeamID: 1, TeamID: 1,
Username: "user1", Username: "user1",
Admin: true, Admin: true,
} }
err := tm.Update() err := tm.Update(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, tm.Admin) assert.False(t, tm.Admin)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "team_members", map[string]interface{}{ db.AssertExists(t, "team_members", map[string]interface{}{
"team_id": 1, "team_id": 1,
"user_id": 1, "user_id": 1,

View file

@ -19,6 +19,8 @@ package models
import ( import (
"time" "time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/metrics" "code.vikunja.io/api/pkg/metrics"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web" "code.vikunja.io/web"
@ -54,10 +56,6 @@ func (Team) TableName() string {
return "teams" return "teams"
} }
// AfterLoad gets the created by user object
func (t *Team) AfterLoad() {
}
// TeamMember defines the relationship between a user and a team // TeamMember defines the relationship between a user and a team
type TeamMember struct { type TeamMember struct {
// The unique, numeric id of this team member relation. // The unique, numeric id of this team member relation.
@ -92,14 +90,14 @@ type TeamUser struct {
} }
// GetTeamByID gets a team by its ID // GetTeamByID gets a team by its ID
func GetTeamByID(id int64) (team *Team, err error) { func GetTeamByID(s *xorm.Session, id int64) (team *Team, err error) {
if id < 1 { if id < 1 {
return team, ErrTeamDoesNotExist{id} return team, ErrTeamDoesNotExist{id}
} }
t := Team{} t := Team{}
exists, err := x. exists, err := s.
Where("id = ?", id). Where("id = ?", id).
Get(&t) Get(&t)
if err != nil { if err != nil {
@ -110,7 +108,7 @@ func GetTeamByID(id int64) (team *Team, err error) {
} }
teamSlice := []*Team{&t} teamSlice := []*Team{&t}
err = addMoreInfoToTeams(teamSlice) err = addMoreInfoToTeams(s, teamSlice)
if err != nil { if err != nil {
return return
} }
@ -120,7 +118,7 @@ func GetTeamByID(id int64) (team *Team, err error) {
return return
} }
func addMoreInfoToTeams(teams []*Team) (err error) { func addMoreInfoToTeams(s *xorm.Session, teams []*Team) (err error) {
// Put the teams in a map to make assigning more info to it more efficient // Put the teams in a map to make assigning more info to it more efficient
teamMap := make(map[int64]*Team, len(teams)) teamMap := make(map[int64]*Team, len(teams))
var teamIDs []int64 var teamIDs []int64
@ -133,7 +131,8 @@ func addMoreInfoToTeams(teams []*Team) (err error) {
// Get all owners and team members // Get all owners and team members
users := make(map[int64]*TeamUser) users := make(map[int64]*TeamUser)
err = x.Select("*"). err = s.
Select("*").
Table("users"). Table("users").
Join("LEFT", "team_members", "team_members.user_id = users.id"). Join("LEFT", "team_members", "team_members.user_id = users.id").
Join("LEFT", "teams", "team_members.team_id = teams.id"). Join("LEFT", "teams", "team_members.team_id = teams.id").
@ -178,8 +177,8 @@ func addMoreInfoToTeams(teams []*Team) (err error) {
// @Failure 403 {object} web.HTTPError "The user does not have access to the team" // @Failure 403 {object} web.HTTPError "The user does not have access to the team"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id} [get] // @Router /teams/{id} [get]
func (t *Team) ReadOne() (err error) { func (t *Team) ReadOne(s *xorm.Session) (err error) {
team, err := GetTeamByID(t.ID) team, err := GetTeamByID(s, t.ID)
if team != nil { if team != nil {
*t = *team *t = *team
} }
@ -199,7 +198,7 @@ func (t *Team) ReadOne() (err error) {
// @Success 200 {array} models.Team "The teams." // @Success 200 {array} models.Team "The teams."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams [get] // @Router /teams [get]
func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) { func (t *Team) ReadAll(s *xorm.Session, a web.Auth, search string, page int, perPage int) (result interface{}, resultCount int, numberOfTotalItems int64, err error) {
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return nil, 0, 0, ErrGenericForbidden{} return nil, 0, 0, ErrGenericForbidden{}
} }
@ -207,7 +206,7 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result
limit, start := getLimitFromPageIndex(page, perPage) limit, start := getLimitFromPageIndex(page, perPage)
all := []*Team{} all := []*Team{}
query := x.Select("teams.*"). query := s.Select("teams.*").
Table("teams"). Table("teams").
Join("INNER", "team_members", "team_members.team_id = teams.id"). Join("INNER", "team_members", "team_members.team_id = teams.id").
Where("team_members.user_id = ?", a.GetID()). Where("team_members.user_id = ?", a.GetID()).
@ -220,12 +219,12 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result
return nil, 0, 0, err return nil, 0, 0, err
} }
err = addMoreInfoToTeams(all) err = addMoreInfoToTeams(s, all)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
numberOfTotalItems, err = x. numberOfTotalItems, err = s.
Table("teams"). Table("teams").
Join("INNER", "team_members", "team_members.team_id = teams.id"). Join("INNER", "team_members", "team_members.team_id = teams.id").
Where("team_members.user_id = ?", a.GetID()). Where("team_members.user_id = ?", a.GetID()).
@ -246,7 +245,7 @@ func (t *Team) ReadAll(a web.Auth, search string, page int, perPage int) (result
// @Failure 400 {object} web.HTTPError "Invalid team object provided." // @Failure 400 {object} web.HTTPError "Invalid team object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams [put] // @Router /teams [put]
func (t *Team) Create(a web.Auth) (err error) { func (t *Team) Create(s *xorm.Session, a web.Auth) (err error) {
doer, err := user.GetFromAuth(a) doer, err := user.GetFromAuth(a)
if err != nil { if err != nil {
return err return err
@ -260,14 +259,14 @@ func (t *Team) Create(a web.Auth) (err error) {
t.CreatedByID = doer.ID t.CreatedByID = doer.ID
t.CreatedBy = doer t.CreatedBy = doer
_, err = x.Insert(t) _, err = s.Insert(t)
if err != nil { if err != nil {
return return
} }
// Insert the current user as member and admin // Insert the current user as member and admin
tm := TeamMember{TeamID: t.ID, Username: doer.Username, Admin: true} tm := TeamMember{TeamID: t.ID, Username: doer.Username, Admin: true}
if err = tm.Create(doer); err != nil { if err = tm.Create(s, doer); err != nil {
return err return err
} }
@ -286,28 +285,28 @@ func (t *Team) Create(a web.Auth) (err error) {
// @Failure 400 {object} web.HTTPError "Invalid team object provided." // @Failure 400 {object} web.HTTPError "Invalid team object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id} [delete] // @Router /teams/{id} [delete]
func (t *Team) Delete() (err error) { func (t *Team) Delete(s *xorm.Session) (err error) {
// Delete the team // Delete the team
_, err = x.ID(t.ID).Delete(&Team{}) _, err = s.ID(t.ID).Delete(&Team{})
if err != nil { if err != nil {
return return
} }
// Delete team members // Delete team members
_, err = x.Where("team_id = ?", t.ID).Delete(&TeamMember{}) _, err = s.Where("team_id = ?", t.ID).Delete(&TeamMember{})
if err != nil { if err != nil {
return return
} }
// Delete team <-> namespace relations // Delete team <-> namespace relations
_, err = x.Where("team_id = ?", t.ID).Delete(&TeamNamespace{}) _, err = s.Where("team_id = ?", t.ID).Delete(&TeamNamespace{})
if err != nil { if err != nil {
return return
} }
// Delete team <-> lists relations // Delete team <-> lists relations
_, err = x.Where("team_id = ?", t.ID).Delete(&TeamList{}) _, err = s.Where("team_id = ?", t.ID).Delete(&TeamList{})
if err != nil { if err != nil {
return return
} }
@ -329,25 +328,25 @@ func (t *Team) Delete() (err error) {
// @Failure 400 {object} web.HTTPError "Invalid team object provided." // @Failure 400 {object} web.HTTPError "Invalid team object provided."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /teams/{id} [post] // @Router /teams/{id} [post]
func (t *Team) Update() (err error) { func (t *Team) Update(s *xorm.Session) (err error) {
// Check if we have a name // Check if we have a name
if t.Name == "" { if t.Name == "" {
return ErrTeamNameCannotBeEmpty{} return ErrTeamNameCannotBeEmpty{}
} }
// Check if the team exists // Check if the team exists
_, err = GetTeamByID(t.ID) _, err = GetTeamByID(s, t.ID)
if err != nil { if err != nil {
return return
} }
_, err = x.ID(t.ID).Update(t) _, err = s.ID(t.ID).Update(t)
if err != nil { if err != nil {
return return
} }
// Get the newly updated team // Get the newly updated team
team, err := GetTeamByID(t.ID) team, err := GetTeamByID(s, t.ID)
if team != nil { if team != nil {
*t = *team *t = *team
} }

View file

@ -18,10 +18,11 @@ package models
import ( import (
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// CanCreate checks if the user can create a new team // CanCreate checks if the user can create a new team
func (t *Team) CanCreate(a web.Auth) (bool, error) { func (t *Team) CanCreate(s *xorm.Session, a web.Auth) (bool, error) {
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
@ -31,39 +32,40 @@ func (t *Team) CanCreate(a web.Auth) (bool, error) {
} }
// CanUpdate checks if the user can update a team // CanUpdate checks if the user can update a team
func (t *Team) CanUpdate(a web.Auth) (bool, error) { func (t *Team) CanUpdate(s *xorm.Session, a web.Auth) (bool, error) {
return t.IsAdmin(a) return t.IsAdmin(s, a)
} }
// CanDelete checks if a user can delete a team // CanDelete checks if a user can delete a team
func (t *Team) CanDelete(a web.Auth) (bool, error) { func (t *Team) CanDelete(s *xorm.Session, a web.Auth) (bool, error) {
return t.IsAdmin(a) return t.IsAdmin(s, a)
} }
// IsAdmin returns true when the user is admin of a team // IsAdmin returns true when the user is admin of a team
func (t *Team) IsAdmin(a web.Auth) (bool, error) { func (t *Team) IsAdmin(s *xorm.Session, a web.Auth) (bool, error) {
// Don't do anything if we're deadling with a link share auth here // Don't do anything if we're deadling with a link share auth here
if _, is := a.(*LinkSharing); is { if _, is := a.(*LinkSharing); is {
return false, nil return false, nil
} }
// Check if the team exists to be able to return a proper error message if not // Check if the team exists to be able to return a proper error message if not
_, err := GetTeamByID(t.ID) _, err := GetTeamByID(s, t.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
return x.Where("team_id = ?", t.ID). return s.Where("team_id = ?", t.ID).
And("user_id = ?", a.GetID()). And("user_id = ?", a.GetID()).
And("admin = ?", true). And("admin = ?", true).
Get(&TeamMember{}) Get(&TeamMember{})
} }
// CanRead returns true if the user has read access to the team // CanRead returns true if the user has read access to the team
func (t *Team) CanRead(a web.Auth) (bool, int, error) { func (t *Team) CanRead(s *xorm.Session, a web.Auth) (bool, int, error) {
// Check if the user is in the team // Check if the user is in the team
tm := &TeamMember{} tm := &TeamMember{}
can, err := x.Where("team_id = ?", t.ID). can, err := s.
Where("team_id = ?", t.ID).
And("user_id = ?", a.GetID()). And("user_id = ?", a.GetID()).
Get(tm) Get(tm)

View file

@ -82,6 +82,8 @@ func TestTeam_CanDoSomething(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
tm := &Team{ tm := &Team{
ID: tt.fields.ID, ID: tt.fields.ID,
@ -96,19 +98,19 @@ func TestTeam_CanDoSomething(t *testing.T) {
Rights: tt.fields.Rights, Rights: tt.fields.Rights,
} }
if got, _ := tm.CanCreate(tt.args.a); got != tt.want["CanCreate"] { // CanCreate is currently always true if got, _ := tm.CanCreate(s, tt.args.a); got != tt.want["CanCreate"] { // CanCreate is currently always true
t.Errorf("Team.CanCreate() = %v, want %v", got, tt.want["CanCreate"]) t.Errorf("Team.CanCreate() = %v, want %v", got, tt.want["CanCreate"])
} }
if got, _ := tm.CanDelete(tt.args.a); got != tt.want["CanDelete"] { if got, _ := tm.CanDelete(s, tt.args.a); got != tt.want["CanDelete"] {
t.Errorf("Team.CanDelete() = %v, want %v", got, tt.want["CanDelete"]) t.Errorf("Team.CanDelete() = %v, want %v", got, tt.want["CanDelete"])
} }
if got, _ := tm.CanUpdate(tt.args.a); got != tt.want["CanUpdate"] { if got, _ := tm.CanUpdate(s, tt.args.a); got != tt.want["CanUpdate"] {
t.Errorf("Team.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"]) t.Errorf("Team.CanUpdate() = %v, want %v", got, tt.want["CanUpdate"])
} }
if got, _, _ := tm.CanRead(tt.args.a); got != tt.want["CanRead"] { if got, _, _ := tm.CanRead(s, tt.args.a); got != tt.want["CanRead"] {
t.Errorf("Team.CanRead() = %v, want %v", got, tt.want["CanRead"]) t.Errorf("Team.CanRead() = %v, want %v", got, tt.want["CanRead"])
} }
if got, _ := tm.IsAdmin(tt.args.a); got != tt.want["IsAdmin"] { if got, _ := tm.IsAdmin(s, tt.args.a); got != tt.want["IsAdmin"] {
t.Errorf("Team.IsAdmin() = %v, want %v", got, tt.want["IsAdmin"]) t.Errorf("Team.IsAdmin() = %v, want %v", got, tt.want["IsAdmin"])
} }
}) })

View file

@ -32,11 +32,16 @@ func TestTeam_Create(t *testing.T) {
} }
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ team := &Team{
Name: "Testteam293", Name: "Testteam293",
Description: "Lorem Ispum", Description: "Lorem Ispum",
} }
err := team.Create(doer) err := team.Create(s, doer)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "teams", map[string]interface{}{ db.AssertExists(t, "teams", map[string]interface{}{
"id": team.ID, "id": team.ID,
@ -46,8 +51,11 @@ func TestTeam_Create(t *testing.T) {
}) })
t.Run("empty name", func(t *testing.T) { t.Run("empty name", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{} team := &Team{}
err := team.Create(doer) err := team.Create(s, doer)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamNameCannotBeEmpty(err)) assert.True(t, IsErrTeamNameCannotBeEmpty(err))
}) })
@ -56,8 +64,11 @@ func TestTeam_Create(t *testing.T) {
func TestTeam_ReadOne(t *testing.T) { func TestTeam_ReadOne(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ID: 1} team := &Team{ID: 1}
err := team.ReadOne() err := team.ReadOne(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "testteam1", team.Name) assert.Equal(t, "testteam1", team.Name)
assert.Equal(t, "Lorem Ipsum", team.Description) assert.Equal(t, "Lorem Ipsum", team.Description)
@ -66,15 +77,21 @@ func TestTeam_ReadOne(t *testing.T) {
}) })
t.Run("invalid id", func(t *testing.T) { t.Run("invalid id", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ID: -1} team := &Team{ID: -1}
err := team.ReadOne() err := team.ReadOne(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ID: 99999} team := &Team{ID: 99999}
err := team.ReadOne() err := team.ReadOne(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
}) })
@ -83,23 +100,31 @@ func TestTeam_ReadOne(t *testing.T) {
func TestTeam_ReadAll(t *testing.T) { func TestTeam_ReadAll(t *testing.T) {
doer := &user.User{ID: 1} doer := &user.User{ID: 1}
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
s := db.NewSession()
defer s.Close()
team := &Team{} team := &Team{}
ts, _, _, err := team.ReadAll(doer, "", 1, 50) teams, _, _, err := team.ReadAll(s, doer, "", 1, 50)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, reflect.TypeOf(ts).Kind(), reflect.Slice) assert.Equal(t, reflect.TypeOf(teams).Kind(), reflect.Slice)
s := reflect.ValueOf(ts) ts := reflect.ValueOf(teams)
assert.Equal(t, 8, s.Len()) assert.Equal(t, 8, ts.Len())
}) })
} }
func TestTeam_Update(t *testing.T) { func TestTeam_Update(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ team := &Team{
ID: 1, ID: 1,
Name: "SomethingNew", Name: "SomethingNew",
} }
err := team.Update() err := team.Update(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertExists(t, "teams", map[string]interface{}{ db.AssertExists(t, "teams", map[string]interface{}{
"id": team.ID, "id": team.ID,
@ -108,21 +133,27 @@ func TestTeam_Update(t *testing.T) {
}) })
t.Run("empty name", func(t *testing.T) { t.Run("empty name", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ team := &Team{
ID: 1, ID: 1,
Name: "", Name: "",
} }
err := team.Update() err := team.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamNameCannotBeEmpty(err)) assert.True(t, IsErrTeamNameCannotBeEmpty(err))
}) })
t.Run("nonexisting", func(t *testing.T) { t.Run("nonexisting", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ team := &Team{
ID: 9999, ID: 9999,
Name: "SomethingNew", Name: "SomethingNew",
} }
err := team.Update() err := team.Update(s)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, IsErrTeamDoesNotExist(err)) assert.True(t, IsErrTeamDoesNotExist(err))
}) })
@ -131,10 +162,15 @@ func TestTeam_Update(t *testing.T) {
func TestTeam_Delete(t *testing.T) { func TestTeam_Delete(t *testing.T) {
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
team := &Team{ team := &Team{
ID: 1, ID: 1,
} }
err := team.Delete() err := team.Delete(s)
assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err) assert.NoError(t, err)
db.AssertMissing(t, "teams", map[string]interface{}{ db.AssertMissing(t, "teams", map[string]interface{}{
"id": 1, "id": 1,

View file

@ -16,7 +16,10 @@
package models package models
import "code.vikunja.io/api/pkg/files" import (
"code.vikunja.io/api/pkg/files"
"xorm.io/xorm"
)
// Unsplash requires us to do pingbacks to their site and also name the image author. // Unsplash requires us to do pingbacks to their site and also name the image author.
// To do this properly, we need to save these information somewhere. // To do this properly, we need to save these information somewhere.
@ -36,15 +39,15 @@ func (u *UnsplashPhoto) TableName() string {
} }
// Save persists an unsplash photo to the db // Save persists an unsplash photo to the db
func (u *UnsplashPhoto) Save() error { func (u *UnsplashPhoto) Save(s *xorm.Session) error {
_, err := x.Insert(u) _, err := s.Insert(u)
return err return err
} }
// GetUnsplashPhotoByFileID returns an unsplash photo by its saved file id // GetUnsplashPhotoByFileID returns an unsplash photo by its saved file id
func GetUnsplashPhotoByFileID(fileID int64) (u *UnsplashPhoto, err error) { func GetUnsplashPhotoByFileID(s *xorm.Session, fileID int64) (u *UnsplashPhoto, err error) {
u = &UnsplashPhoto{} u = &UnsplashPhoto{}
exists, err := x.Where("file_id = ?", fileID).Get(u) exists, err := s.Where("file_id = ?", fileID).Get(u)
if err != nil { if err != nil {
return return
} }
@ -55,10 +58,10 @@ func GetUnsplashPhotoByFileID(fileID int64) (u *UnsplashPhoto, err error) {
} }
// RemoveUnsplashPhoto removes an unsplash photo from the db // RemoveUnsplashPhoto removes an unsplash photo from the db
func RemoveUnsplashPhoto(fileID int64) (err error) { func RemoveUnsplashPhoto(s *xorm.Session, fileID int64) (err error) {
// This is intentionally "fire and forget" which is why we don't check if we have an // This is intentionally "fire and forget" which is why we don't check if we have an
// unsplash entry for that file at all. If there is one, it will be deleted. // unsplash entry for that file at all. If there is one, it will be deleted.
// We do this to keep the function simple. // We do this to keep the function simple.
_, err = x.Where("file_id = ?", fileID).Delete(&UnsplashPhoto{}) _, err = s.Where("file_id = ?", fileID).Delete(&UnsplashPhoto{})
return return
} }

View file

@ -20,6 +20,7 @@ package models
import ( import (
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm"
) )
// ListUIDs hold all kinds of user IDs from accounts who have somehow access to a list // ListUIDs hold all kinds of user IDs from accounts who have somehow access to a list
@ -33,11 +34,11 @@ type ListUIDs struct {
} }
// ListUsersFromList returns a list with all users who have access to a list, regardless of the method which gave them access // ListUsersFromList returns a list with all users who have access to a list, regardless of the method which gave them access
func ListUsersFromList(l *List, search string) (users []*user.User, err error) { func ListUsersFromList(s *xorm.Session, l *List, search string) (users []*user.User, err error) {
userids := []*ListUIDs{} userids := []*ListUIDs{}
err = x. err = s.
Select(`l.owner_id as listOwner, Select(`l.owner_id as listOwner,
un.user_id as unID, un.user_id as unID,
ul.user_id as ulID, ul.user_id as ulID,
@ -97,7 +98,7 @@ func ListUsersFromList(l *List, search string) (users []*user.User, err error) {
} }
// Get all users // Get all users
err = x. err = s.
Table("users"). Table("users").
Select("*"). Select("*").
In("id", uids). In("id", uids).

View file

@ -201,8 +201,10 @@ func TestListUsersFromList(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
gotUsers, err := ListUsersFromList(tt.args.l, tt.args.search) gotUsers, err := ListUsersFromList(s, tt.args.l, tt.args.search)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ListUsersFromList() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ListUsersFromList() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -23,6 +23,9 @@ import (
"net/http" "net/http"
"time" "time"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/api/pkg/modules/auth"
@ -130,8 +133,17 @@ func HandleCallback(c echo.Context) error {
return err return err
} }
s := db.NewSession()
defer s.Close()
// Check if we have seen this user before // Check if we have seen this user before
u, err := getOrCreateUser(cl, idToken.Issuer, idToken.Subject) u, err := getOrCreateUser(s, cl, idToken.Issuer, idToken.Subject)
if err != nil {
_ = s.Rollback()
return err
}
err = s.Commit()
if err != nil { if err != nil {
return err return err
} }
@ -140,9 +152,9 @@ func HandleCallback(c echo.Context) error {
return auth.NewUserAuthTokenResponse(u, c) return auth.NewUserAuthTokenResponse(u, c)
} }
func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err error) { func getOrCreateUser(s *xorm.Session, cl *claims, issuer, subject string) (u *user.User, err error) {
// Check if the user exists for that issuer and subject // Check if the user exists for that issuer and subject
u, err = user.GetUserWithEmail(&user.User{ u, err = user.GetUserWithEmail(s, &user.User{
Issuer: issuer, Issuer: issuer,
Subject: subject, Subject: subject,
}) })
@ -165,7 +177,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
uu.Username = petname.Generate(3, "-") uu.Username = petname.Generate(3, "-")
} }
u, err = user.CreateUser(uu) u, err = user.CreateUser(s, uu)
if err != nil && !user.IsErrUsernameExists(err) { if err != nil && !user.IsErrUsernameExists(err) {
return nil, err return nil, err
} }
@ -173,14 +185,14 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
// If their preferred username is already taken, create some random one from the email and subject // If their preferred username is already taken, create some random one from the email and subject
if user.IsErrUsernameExists(err) { if user.IsErrUsernameExists(err) {
uu.Username = petname.Generate(3, "-") uu.Username = petname.Generate(3, "-")
u, err = user.CreateUser(uu) u, err = user.CreateUser(s, uu)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// And create its namespace // And create its namespace
err = models.CreateNewNamespaceForUser(u) err = models.CreateNewNamespaceForUser(s, u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -196,7 +208,7 @@ func getOrCreateUser(cl *claims, issuer, subject string) (u *user.User, err erro
if cl.Name != u.Name { if cl.Name != u.Name {
u.Name = cl.Name u.Name = cl.Name
} }
u, err = user.UpdateUser(&user.User{ u, err = user.UpdateUser(s, &user.User{
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Name: u.Name, Name: u.Name,

View file

@ -26,12 +26,18 @@ import (
func TestGetOrCreateUser(t *testing.T) { func TestGetOrCreateUser(t *testing.T) {
t.Run("new user", func(t *testing.T) { t.Run("new user", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{ cl := &claims{
Email: "test@example.com", Email: "test@example.com",
PreferredUsername: "someUserWhoDoesNotExistYet", PreferredUsername: "someUserWhoDoesNotExistYet",
} }
u, err := getOrCreateUser(cl, "https://some.issuer", "12345") u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{ db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID, "id": u.ID,
"email": cl.Email, "email": cl.Email,
@ -40,13 +46,19 @@ func TestGetOrCreateUser(t *testing.T) {
}) })
t.Run("new user, no username provided", func(t *testing.T) { t.Run("new user, no username provided", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{ cl := &claims{
Email: "test@example.com", Email: "test@example.com",
PreferredUsername: "", PreferredUsername: "",
} }
u, err := getOrCreateUser(cl, "https://some.issuer", "12345") u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, u.Username) assert.NotEmpty(t, u.Username)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{ db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID, "id": u.ID,
"email": cl.Email, "email": cl.Email,
@ -54,19 +66,28 @@ func TestGetOrCreateUser(t *testing.T) {
}) })
t.Run("new user, no email address", func(t *testing.T) { t.Run("new user, no email address", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{ cl := &claims{
Email: "", Email: "",
} }
_, err := getOrCreateUser(cl, "https://some.issuer", "12345") _, err := getOrCreateUser(s, cl, "https://some.issuer", "12345")
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("existing user, different email address", func(t *testing.T) { t.Run("existing user, different email address", func(t *testing.T) {
db.LoadAndAssertFixtures(t) db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
cl := &claims{ cl := &claims{
Email: "other-email-address@some.service.com", Email: "other-email-address@some.service.com",
} }
u, err := getOrCreateUser(cl, "https://some.service.com", "12345") u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345")
assert.NoError(t, err) assert.NoError(t, err)
err = s.Commit()
assert.NoError(t, err)
db.AssertExists(t, "users", map[string]interface{}{ db.AssertExists(t, "users", map[string]interface{}{
"id": u.ID, "id": u.ID,
"email": cl.Email, "email": cl.Email,

View file

@ -19,6 +19,7 @@ package background
import ( import (
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/web" "code.vikunja.io/web"
"xorm.io/xorm"
) )
// Image represents an image which can be used as a list background // Image represents an image which can be used as a list background
@ -33,7 +34,7 @@ type Image struct {
// Provider represents something that is able to get a list of images and set one of them as background // Provider represents something that is able to get a list of images and set one of them as background
type Provider interface { type Provider interface {
// Search is used to either return a pre-defined list of Image or let the user search for an image // Search is used to either return a pre-defined list of Image or let the user search for an image
Search(search string, page int64) (result []*Image, err error) Search(s *xorm.Session, search string, page int64) (result []*Image, err error)
// Set sets an image which was most likely previously obtained by Search as list background // Set sets an image which was most likely previously obtained by Search as list background
Set(image *Image, list *models.List, auth web.Auth) (err error) Set(s *xorm.Session, image *Image, list *models.List, auth web.Auth) (err error)
} }

View file

@ -22,6 +22,9 @@ import (
"strconv" "strconv"
"strings" "strings"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
@ -59,8 +62,17 @@ func (bp *BackgroundProvider) SearchBackgrounds(c echo.Context) error {
} }
} }
result, err := p.Search(search, page) s := db.NewSession()
defer s.Close()
result, err := p.Search(s, search, page)
if err != nil { if err != nil {
_ = s.Rollback()
return echo.NewHTTPError(http.StatusBadRequest, "An error occurred: "+err.Error())
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return echo.NewHTTPError(http.StatusBadRequest, "An error occurred: "+err.Error()) return echo.NewHTTPError(http.StatusBadRequest, "An error occurred: "+err.Error())
} }
@ -68,7 +80,7 @@ func (bp *BackgroundProvider) SearchBackgrounds(c echo.Context) error {
} }
// This function does all kinds of preparations for setting and uploading a background // This function does all kinds of preparations for setting and uploading a background
func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *models.List, auth web.Auth, err error) { func (bp *BackgroundProvider) setBackgroundPreparations(s *xorm.Session, c echo.Context) (list *models.List, auth web.Auth, err error) {
auth, err = auth2.GetAuthFromClaims(c) auth, err = auth2.GetAuthFromClaims(c)
if err != nil { if err != nil {
return nil, nil, echo.NewHTTPError(http.StatusBadRequest, "Invalid auth token: "+err.Error()) return nil, nil, echo.NewHTTPError(http.StatusBadRequest, "Invalid auth token: "+err.Error())
@ -81,7 +93,7 @@ func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *m
// Check if the user has the right to change the list background // Check if the user has the right to change the list background
list = &models.List{ID: listID} list = &models.List{ID: listID}
can, err := list.CanUpdate(auth) can, err := list.CanUpdate(s, auth)
if err != nil { if err != nil {
return return
} }
@ -90,14 +102,18 @@ func (bp *BackgroundProvider) setBackgroundPreparations(c echo.Context) (list *m
return list, auth, models.ErrGenericForbidden{} return list, auth, models.ErrGenericForbidden{}
} }
// Load the list // Load the list
err = list.GetSimpleByID() list, err = models.GetListSimpleByID(s, list.ID)
return return
} }
// SetBackground sets an Image as list background // SetBackground sets an Image as list background
func (bp *BackgroundProvider) SetBackground(c echo.Context) error { func (bp *BackgroundProvider) SetBackground(c echo.Context) error {
list, auth, err := bp.setBackgroundPreparations(c) s := db.NewSession()
defer s.Close()
list, auth, err := bp.setBackgroundPreparations(s, c)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -106,11 +122,13 @@ func (bp *BackgroundProvider) SetBackground(c echo.Context) error {
image := &background.Image{} image := &background.Image{}
err = c.Bind(image) err = c.Bind(image)
if err != nil { if err != nil {
_ = s.Rollback()
return echo.NewHTTPError(http.StatusBadRequest, "No or invalid model provided: "+err.Error()) return echo.NewHTTPError(http.StatusBadRequest, "No or invalid model provided: "+err.Error())
} }
err = p.Set(image, list, auth) err = p.Set(s, image, list, auth)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
return c.JSON(http.StatusOK, list) return c.JSON(http.StatusOK, list)
@ -118,8 +136,12 @@ func (bp *BackgroundProvider) SetBackground(c echo.Context) error {
// UploadBackground uploads a background and passes the id of the uploaded file as an Image to the Set function of the BackgroundProvider. // UploadBackground uploads a background and passes the id of the uploaded file as an Image to the Set function of the BackgroundProvider.
func (bp *BackgroundProvider) UploadBackground(c echo.Context) error { func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
list, auth, err := bp.setBackgroundPreparations(c) s := db.NewSession()
defer s.Close()
list, auth, err := bp.setBackgroundPreparations(s, c)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -128,10 +150,12 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
// Get + upload the image // Get + upload the image
file, err := c.FormFile("background") file, err := c.FormFile("background")
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
src, err := file.Open() src, err := file.Open()
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
defer src.Close() defer src.Close()
@ -139,9 +163,11 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
// Validate we're dealing with an image // Validate we're dealing with an image
mime, err := mimetype.DetectReader(src) mime, err := mimetype.DetectReader(src)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if !strings.HasPrefix(mime.String(), "image") { if !strings.HasPrefix(mime.String(), "image") {
_ = s.Rollback()
return c.JSON(http.StatusBadRequest, models.Message{Message: "Uploaded file is no image."}) return c.JSON(http.StatusBadRequest, models.Message{Message: "Uploaded file is no image."})
} }
_, _ = src.Seek(0, io.SeekStart) _, _ = src.Seek(0, io.SeekStart)
@ -149,6 +175,7 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
// Save the file // Save the file
f, err := files.CreateWithMime(src, file.Filename, uint64(file.Size), auth, mime.String()) f, err := files.CreateWithMime(src, file.Filename, uint64(file.Size), auth, mime.String())
if err != nil { if err != nil {
_ = s.Rollback()
if files.IsErrFileIsTooLarge(err) { if files.IsErrFileIsTooLarge(err) {
return echo.ErrBadRequest return echo.ErrBadRequest
} }
@ -158,10 +185,17 @@ func (bp *BackgroundProvider) UploadBackground(c echo.Context) error {
image := &background.Image{ID: strconv.FormatInt(f.ID, 10)} image := &background.Image{ID: strconv.FormatInt(f.ID, 10)}
err = p.Set(image, list, auth) err = p.Set(s, image, list, auth)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
return c.JSON(http.StatusOK, list) return c.JSON(http.StatusOK, list)
} }
@ -190,17 +224,23 @@ func GetListBackground(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid list ID: "+err.Error()) return echo.NewHTTPError(http.StatusBadRequest, "Invalid list ID: "+err.Error())
} }
s := db.NewSession()
defer s.Close()
// Check if a background for this list exists + Rights // Check if a background for this list exists + Rights
list := &models.List{ID: listID} list := &models.List{ID: listID}
can, _, err := list.CanRead(auth) can, _, err := list.CanRead(s, auth)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if !can { if !can {
_ = s.Rollback()
log.Infof("Tried to get list background of list %d while not having the rights for it (User: %v)", listID, auth) log.Infof("Tried to get list background of list %d while not having the rights for it (User: %v)", listID, auth)
return echo.NewHTTPError(http.StatusForbidden) return echo.NewHTTPError(http.StatusForbidden)
} }
if list.BackgroundFileID == 0 { if list.BackgroundFileID == 0 {
_ = s.Rollback()
return echo.NotFoundHandler(c) return echo.NotFoundHandler(c)
} }
@ -209,13 +249,19 @@ func GetListBackground(c echo.Context) error {
ID: list.BackgroundFileID, ID: list.BackgroundFileID,
} }
if err := bgFile.LoadFileByID(); err != nil { if err := bgFile.LoadFileByID(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
// Unsplash requires pingbacks as per their api usage guidelines. // Unsplash requires pingbacks as per their api usage guidelines.
// To do this in a privacy-preserving manner, we do the ping from inside of Vikunja to not expose any user details. // To do this in a privacy-preserving manner, we do the ping from inside of Vikunja to not expose any user details.
// FIXME: This should use an event once we have events // FIXME: This should use an event once we have events
unsplash.Pingback(bgFile) unsplash.Pingback(s, bgFile)
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Serve the file // Serve the file
return c.Stream(http.StatusOK, "image/jpg", bgFile.File) return c.Stream(http.StatusOK, "image/jpg", bgFile.File)

View file

@ -26,6 +26,8 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
@ -150,7 +152,7 @@ func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) {
// @Success 200 {array} background.Image "An array with photos" // @Success 200 {array} background.Image "An array with photos"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /backgrounds/unsplash/search [get] // @Router /backgrounds/unsplash/search [get]
func (p *Provider) Search(search string, page int64) (result []*background.Image, err error) { func (p *Provider) Search(s *xorm.Session, search string, page int64) (result []*background.Image, err error) {
// If we don't have a search query, return results from the unsplash featured collection // If we don't have a search query, return results from the unsplash featured collection
if search == "" { if search == "" {
@ -243,7 +245,7 @@ func (p *Provider) Search(search string, page int64) (result []*background.Image
// @Failure 403 {object} web.HTTPError "The user does not have access to the list" // @Failure 403 {object} web.HTTPError "The user does not have access to the list"
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/backgrounds/unsplash [post] // @Router /lists/{id}/backgrounds/unsplash [post]
func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth) (err error) { func (p *Provider) Set(s *xorm.Session, image *background.Image, list *models.List, auth web.Auth) (err error) {
// Find the photo // Find the photo
photo, err := getUnsplashPhotoInfoByID(image.ID) photo, err := getUnsplashPhotoInfoByID(image.ID)
@ -292,7 +294,7 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
return err return err
} }
if err := models.RemoveUnsplashPhoto(list.BackgroundFileID); err != nil { if err := models.RemoveUnsplashPhoto(s, list.BackgroundFileID); err != nil {
return err return err
} }
} }
@ -304,7 +306,7 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
Author: photo.User.Username, Author: photo.User.Username,
AuthorName: photo.User.Name, AuthorName: photo.User.Name,
} }
err = unsplashPhoto.Save() err = unsplashPhoto.Save(s)
if err != nil { if err != nil {
return return
} }
@ -315,13 +317,13 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
list.BackgroundInformation = unsplashPhoto list.BackgroundInformation = unsplashPhoto
// Set it as the list background // Set it as the list background
return models.SetListBackground(list.ID, file) return models.SetListBackground(s, list.ID, file)
} }
// Pingback pings the unsplash api if an unsplash photo has been accessed. // Pingback pings the unsplash api if an unsplash photo has been accessed.
func Pingback(f *files.File) { func Pingback(s *xorm.Session, f *files.File) {
// Check if the file is actually downloaded from unsplash // Check if the file is actually downloaded from unsplash
unsplashPhoto, err := models.GetUnsplashPhotoByFileID(f.ID) unsplashPhoto, err := models.GetUnsplashPhotoByFileID(s, f.ID)
if err != nil { if err != nil {
if files.IsErrFileIsNotUnsplashFile(err) { if files.IsErrFileIsNotUnsplashFile(err) {
return return

View file

@ -19,6 +19,8 @@ package upload
import ( import (
"strconv" "strconv"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/background" "code.vikunja.io/api/pkg/modules/background"
@ -30,7 +32,7 @@ type Provider struct {
} }
// Search is only used to implement the interface // Search is only used to implement the interface
func (p *Provider) Search(search string, page int64) (result []*background.Image, err error) { func (p *Provider) Search(s *xorm.Session, search string, page int64) (result []*background.Image, err error) {
return return
} }
@ -50,7 +52,7 @@ func (p *Provider) Search(search string, page int64) (result []*background.Image
// @Failure 404 {object} models.Message "The list does not exist." // @Failure 404 {object} models.Message "The list does not exist."
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /lists/{id}/backgrounds/upload [put] // @Router /lists/{id}/backgrounds/upload [put]
func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth) (err error) { func (p *Provider) Set(s *xorm.Session, image *background.Image, list *models.List, auth web.Auth) (err error) {
// Remove the old background if one exists // Remove the old background if one exists
if list.BackgroundFileID != 0 { if list.BackgroundFileID != 0 {
file := files.File{ID: list.BackgroundFileID} file := files.File{ID: list.BackgroundFileID}
@ -67,5 +69,5 @@ func (p *Provider) Set(image *background.Image, list *models.List, auth web.Auth
list.BackgroundInformation = &models.ListBackgroundType{Type: models.ListBackgroundUpload} list.BackgroundInformation = &models.ListBackgroundType{Type: models.ListBackgroundUpload}
return models.SetListBackground(list.ID, file) return models.SetListBackground(s, list.ID, file)
} }

View file

@ -20,6 +20,8 @@ import (
"bytes" "bytes"
"io/ioutil" "io/ioutil"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
@ -34,10 +36,14 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
labels := make(map[string]*models.Label) labels := make(map[string]*models.Label)
s := db.NewSession()
defer s.Close()
// Create all namespaces // Create all namespaces
for _, n := range str { for _, n := range str {
err = n.Create(user) err = n.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
@ -54,8 +60,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
needsDefaultBucket := false needsDefaultBucket := false
l.NamespaceID = n.ID l.NamespaceID = n.ID
err = l.Create(user) err = l.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
@ -67,11 +74,13 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
file, err := files.Create(backgroundFile, "", uint64(backgroundFile.Len()), user) file, err := files.Create(backgroundFile, "", uint64(backgroundFile.Len()), user)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
err = models.SetListBackground(l.ID, file) err = models.SetListBackground(s, l.ID, file)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
@ -87,8 +96,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
oldID := bucket.ID oldID := bucket.ID
bucket.ID = 0 // We want a new id bucket.ID = 0 // We want a new id
bucket.ListID = l.ID bucket.ListID = l.ID
err = bucket.Create(user) err = bucket.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
buckets[oldID] = bucket buckets[oldID] = bucket
@ -111,8 +121,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
} }
t.ListID = l.ID t.ListID = l.ID
err = t.Create(user) err = t.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
@ -132,8 +143,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
// First create the related tasks if they do not exist // First create the related tasks if they do not exist
if rt.ID == 0 { if rt.ID == 0 {
rt.ListID = t.ListID rt.ListID = t.ListID
err = rt.Create(user) err = rt.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
log.Debugf("[creating structure] Created related task %d", rt.ID) log.Debugf("[creating structure] Created related task %d", rt.ID)
@ -145,8 +157,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
OtherTaskID: rt.ID, OtherTaskID: rt.ID,
RelationKind: kind, RelationKind: kind,
} }
err = taskRel.Create(user) err = taskRel.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
@ -164,8 +177,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
if len(a.File.FileContent) > 0 { if len(a.File.FileContent) > 0 {
a.TaskID = t.ID a.TaskID = t.ID
fr := ioutil.NopCloser(bytes.NewReader(a.File.FileContent)) fr := ioutil.NopCloser(bytes.NewReader(a.File.FileContent))
err = a.NewAttachment(fr, a.File.Name, a.File.Size, user) err = a.NewAttachment(s, fr, a.File.Name, a.File.Size, user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
log.Debugf("[creating structure] Created new attachment %d", a.ID) log.Debugf("[creating structure] Created new attachment %d", a.ID)
@ -180,8 +194,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
var exists bool var exists bool
lb, exists = labels[label.Title+label.HexColor] lb, exists = labels[label.Title+label.HexColor]
if !exists { if !exists {
err = label.Create(user) err = label.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
log.Debugf("[creating structure] Created new label %d", label.ID) log.Debugf("[creating structure] Created new label %d", label.ID)
@ -193,8 +208,9 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
LabelID: lb.ID, LabelID: lb.ID,
TaskID: t.ID, TaskID: t.ID,
} }
err = lt.Create(user) err = lt.Create(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
log.Debugf("[creating structure] Associated task %d with label %d", t.ID, lb.ID) log.Debugf("[creating structure] Associated task %d with label %d", t.ID, lb.ID)
@ -204,13 +220,15 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
// All tasks brought their own bucket with them, therefore the newly created default bucket is just extra space // All tasks brought their own bucket with them, therefore the newly created default bucket is just extra space
if !needsDefaultBucket { if !needsDefaultBucket {
b := &models.Bucket{ListID: l.ID} b := &models.Bucket{ListID: l.ID}
bucketsIn, _, _, err := b.ReadAll(user, "", 1, 1) bucketsIn, _, _, err := b.ReadAll(s, user, "", 1, 1)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
buckets := bucketsIn.([]*models.Bucket) buckets := bucketsIn.([]*models.Bucket)
err = buckets[0].Delete() err = buckets[0].Delete(s)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
} }
@ -222,5 +240,5 @@ func InsertFromStructure(str []*models.NamespaceWithLists, user *user.User) (err
log.Debugf("[creating structure] Done inserting new task structure") log.Debugf("[creating structure] Done inserting new task structure")
return nil return s.Commit()
} }

View file

@ -19,20 +19,10 @@ package migration
import ( import (
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"xorm.io/xorm"
) )
var x *xorm.Engine
// InitDB sets up the database connection to use in this module // InitDB sets up the database connection to use in this module
func InitDB() (err error) { func InitDB() (err error) {
x, err = db.CreateDBEngine()
if err != nil {
log.Criticalf("Could not connect to db: %v", err.Error())
return
}
// Cache // Cache
if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" { if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" {
db.RegisterTableStructsForCache(GetTables()) db.RegisterTableStructsForCache(GetTables())

View file

@ -19,6 +19,7 @@ package migration
import ( import (
"time" "time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
) )
@ -37,17 +38,26 @@ func (s *Status) TableName() string {
// SetMigrationStatus sets the migration status for a user // SetMigrationStatus sets the migration status for a user
func SetMigrationStatus(m Migrator, u *user.User) (err error) { func SetMigrationStatus(m Migrator, u *user.User) (err error) {
s := db.NewSession()
defer s.Close()
status := &Status{ status := &Status{
UserID: u.ID, UserID: u.ID,
MigratorName: m.Name(), MigratorName: m.Name(),
} }
_, err = x.Insert(status) _, err = s.Insert(status)
return return
} }
// GetMigrationStatus returns the migration status for a migration and a user // GetMigrationStatus returns the migration status for a migration and a user
func GetMigrationStatus(m Migrator, u *user.User) (status *Status, err error) { func GetMigrationStatus(m Migrator, u *user.User) (status *Status, err error) {
s := db.NewSession()
defer s.Close()
status = &Status{} status = &Status{}
_, err = x.Where("user_id = ? and migrator_name = ?", u.ID, m.Name()).Desc("id").Get(status) _, err = s.
Where("user_id = ? and migrator_name = ?", u.ID, m.Name()).
Desc("id").
Get(status)
return return
} }

View file

@ -17,6 +17,7 @@
package v1 package v1
import ( import (
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/files"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
@ -56,8 +57,11 @@ func GetAvatar(c echo.Context) error {
// Get the username // Get the username
username := c.Param("username") username := c.Param("username")
s := db.NewSession()
defer s.Close()
// Get the user // Get the user
u, err := user.GetUserWithEmail(&user.User{Username: username}) u, err := user.GetUserWithEmail(s, &user.User{Username: username})
if err != nil { if err != nil {
log.Errorf("Error getting user for avatar: %v", err) log.Errorf("Error getting user for avatar: %v", err)
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
@ -113,22 +117,28 @@ func GetAvatar(c echo.Context) error {
// @Router /user/settings/avatar/upload [put] // @Router /user/settings/avatar/upload [put]
func UploadAvatar(c echo.Context) (err error) { func UploadAvatar(c echo.Context) (err error) {
s := db.NewSession()
defer s.Close()
uc, err := user.GetCurrentUser(c) uc, err := user.GetCurrentUser(c)
if err != nil { if err != nil {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
u, err := user.GetUserByID(uc.ID) u, err := user.GetUserByID(s, uc.ID)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
// Get + upload the image // Get + upload the image
file, err := c.FormFile("avatar") file, err := c.FormFile("avatar")
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
src, err := file.Open() src, err := file.Open()
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
defer src.Close() defer src.Close()
@ -136,6 +146,7 @@ func UploadAvatar(c echo.Context) (err error) {
// Validate we're dealing with an image // Validate we're dealing with an image
mime, err := mimetype.DetectReader(src) mime, err := mimetype.DetectReader(src)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if !strings.HasPrefix(mime.String(), "image") { if !strings.HasPrefix(mime.String(), "image") {
@ -148,6 +159,7 @@ func UploadAvatar(c echo.Context) (err error) {
f := &files.File{ID: u.AvatarFileID} f := &files.File{ID: u.AvatarFileID}
if err := f.Delete(); err != nil { if err := f.Delete(); err != nil {
if !files.IsErrFileDoesNotExist(err) { if !files.IsErrFileDoesNotExist(err) {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
} }
@ -157,11 +169,13 @@ func UploadAvatar(c echo.Context) (err error) {
// Resize the new file to a max height of 1024 // Resize the new file to a max height of 1024
img, _, err := image.Decode(src) img, _, err := image.Decode(src)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
resizedImg := imaging.Resize(img, 0, 1024, imaging.Lanczos) resizedImg := imaging.Resize(img, 0, 1024, imaging.Lanczos)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := png.Encode(buf, resizedImg); err != nil { if err := png.Encode(buf, resizedImg); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -170,6 +184,7 @@ func UploadAvatar(c echo.Context) (err error) {
// Save the file // Save the file
f, err := files.CreateWithMime(buf, file.Filename, uint64(file.Size), u, "image/png") f, err := files.CreateWithMime(buf, file.Filename, uint64(file.Size), u, "image/png")
if err != nil { if err != nil {
_ = s.Rollback()
if files.IsErrFileIsTooLarge(err) { if files.IsErrFileIsTooLarge(err) {
return echo.ErrBadRequest return echo.ErrBadRequest
} }
@ -180,7 +195,13 @@ func UploadAvatar(c echo.Context) (err error) {
u.AvatarFileID = f.ID u.AvatarFileID = f.ID
u.AvatarProvider = "upload" u.AvatarProvider = "upload"
if _, err := user.UpdateUser(u); err != nil { if _, err := user.UpdateUser(s, u); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/api/pkg/modules/auth"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
@ -45,8 +47,18 @@ type LinkShareToken struct {
// @Router /shares/{share}/auth [post] // @Router /shares/{share}/auth [post]
func AuthenticateLinkShare(c echo.Context) error { func AuthenticateLinkShare(c echo.Context) error {
hash := c.Param("share") hash := c.Param("share")
share, err := models.GetLinkShareByHash(hash)
s := db.NewSession()
defer s.Close()
share, err := models.GetLinkShareByHash(s, hash)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -20,6 +20,9 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"code.vikunja.io/api/pkg/db"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
@ -41,8 +44,11 @@ import (
// @Failure 500 {object} models.Message "Internal error" // @Failure 500 {object} models.Message "Internal error"
// @Router /namespaces/{id}/lists [get] // @Router /namespaces/{id}/lists [get]
func GetListsByNamespaceID(c echo.Context) error { func GetListsByNamespaceID(c echo.Context) error {
s := db.NewSession()
defer s.Close()
// Get our namespace // Get our namespace
namespace, err := getNamespace(c) namespace, err := getNamespace(s, c)
if err != nil { if err != nil {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -53,14 +59,14 @@ func GetListsByNamespaceID(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
lists, err := models.GetListsByNamespaceID(namespace.ID, doer) lists, err := models.GetListsByNamespaceID(s, namespace.ID, doer)
if err != nil { if err != nil {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
return c.JSON(http.StatusOK, lists) return c.JSON(http.StatusOK, lists)
} }
func getNamespace(c echo.Context) (namespace *models.Namespace, err error) { func getNamespace(s *xorm.Session, c echo.Context) (namespace *models.Namespace, err error) {
// Check if we have our ID // Check if we have our ID
id := c.Param("namespace") id := c.Param("namespace")
// Make int // Make int
@ -75,12 +81,12 @@ func getNamespace(c echo.Context) (namespace *models.Namespace, err error) {
} }
// Check if the user has acces to that namespace // Check if the user has acces to that namespace
user, err := user.GetCurrentUser(c) u, err := user.GetCurrentUser(c)
if err != nil { if err != nil {
return return
} }
namespace = &models.Namespace{ID: namespaceID} namespace = &models.Namespace{ID: namespaceID}
canRead, _, err := namespace.CanRead(user) canRead, _, err := namespace.CanRead(s, u)
if err != nil { if err != nil {
return namespace, err return namespace, err
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/api/pkg/modules/auth"
user2 "code.vikunja.io/api/pkg/user" user2 "code.vikunja.io/api/pkg/user"
@ -45,27 +47,38 @@ func Login(c echo.Context) error {
return c.JSON(http.StatusBadRequest, models.Message{Message: "Please provide a username and password."}) return c.JSON(http.StatusBadRequest, models.Message{Message: "Please provide a username and password."})
} }
s := db.NewSession()
defer s.Close()
// Check user // Check user
user, err := user2.CheckUserCredentials(&u) user, err := user2.CheckUserCredentials(s, &u)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
totpEnabled, err := user2.TOTPEnabledForUser(user) totpEnabled, err := user2.TOTPEnabledForUser(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if totpEnabled { if totpEnabled {
_, err = user2.ValidateTOTPPasscode(&user2.TOTPPasscode{ _, err = user2.ValidateTOTPPasscode(s, &user2.TOTPPasscode{
User: user, User: user,
Passcode: u.TOTPPasscode, Passcode: u.TOTPPasscode,
}) })
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
} }
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
// Create token // Create token
return auth.NewUserAuthTokenResponse(user, c) return auth.NewUserAuthTokenResponse(user, c)
} }
@ -82,18 +95,23 @@ func Login(c echo.Context) error {
// @Router /user/token [post] // @Router /user/token [post]
func RenewToken(c echo.Context) (err error) { func RenewToken(c echo.Context) (err error) {
s := db.NewSession()
defer s.Close()
jwtinf := c.Get("user").(*jwt.Token) jwtinf := c.Get("user").(*jwt.Token)
claims := jwtinf.Claims.(jwt.MapClaims) claims := jwtinf.Claims.(jwt.MapClaims)
typ := int(claims["type"].(float64)) typ := int(claims["type"].(float64))
if typ == auth.AuthTypeLinkShare { if typ == auth.AuthTypeLinkShare {
share := &models.LinkSharing{} share := &models.LinkSharing{}
share.ID = int64(claims["id"].(float64)) share.ID = int64(claims["id"].(float64))
err := share.ReadOne() err := share.ReadOne(s)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
t, err := auth.NewLinkShareJWTAuthtoken(share) t, err := auth.NewLinkShareJWTAuthtoken(share)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
return c.JSON(http.StatusOK, auth.Token{Token: t}) return c.JSON(http.StatusOK, auth.Token{Token: t})
@ -101,11 +119,18 @@ func RenewToken(c echo.Context) (err error) {
u, err := user2.GetUserFromClaims(claims) u, err := user2.GetUserFromClaims(claims)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
auth2 "code.vikunja.io/api/pkg/modules/auth" auth2 "code.vikunja.io/api/pkg/modules/auth"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
@ -52,8 +54,12 @@ func UploadTaskAttachment(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
can, err := taskAttachment.CanCreate(auth) s := db.NewSession()
defer s.Close()
can, err := taskAttachment.CanCreate(s, auth)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if !can { if !can {
@ -63,6 +69,7 @@ func UploadTaskAttachment(c echo.Context) error {
// Multipart form // Multipart form
form, err := c.MultipartForm() form, err := c.MultipartForm()
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -85,7 +92,7 @@ func UploadTaskAttachment(c echo.Context) error {
} }
defer f.Close() defer f.Close()
err = ta.NewAttachment(f, file.Filename, uint64(file.Size), auth) err = ta.NewAttachment(s, f, file.Filename, uint64(file.Size), auth)
if err != nil { if err != nil {
r.Errors = append(r.Errors, handler.HandleHTTPError(err, c)) r.Errors = append(r.Errors, handler.HandleHTTPError(err, c))
continue continue
@ -93,6 +100,11 @@ func UploadTaskAttachment(c echo.Context) error {
r.Success = append(r.Success, ta) r.Success = append(r.Success, ta)
} }
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
return c.JSON(http.StatusOK, r) return c.JSON(http.StatusOK, r)
} }
@ -121,8 +133,13 @@ func GetTaskAttachment(c echo.Context) error {
if err != nil { if err != nil {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
can, _, err := taskAttachment.CanRead(auth)
s := db.NewSession()
defer s.Close()
can, _, err := taskAttachment.CanRead(s, auth)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if !can { if !can {
@ -130,14 +147,21 @@ func GetTaskAttachment(c echo.Context) error {
} }
// Get the attachment incl file // Get the attachment incl file
err = taskAttachment.ReadOne() err = taskAttachment.ReadOne(s)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
// Open an send the file to the client // Open an send the file to the client
err = taskAttachment.File.LoadFileByID() err = taskAttachment.File.LoadFileByID()
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
@ -43,8 +45,17 @@ func UserConfirmEmail(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "No token provided.") return echo.NewHTTPError(http.StatusBadRequest, "No token provided.")
} }
err := user.ConfirmEmail(&emailConfirm) s := db.NewSession()
defer s.Close()
err := user.ConfirmEmail(s, &emailConfirm)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -20,6 +20,8 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
auth2 "code.vikunja.io/api/pkg/modules/auth" auth2 "code.vikunja.io/api/pkg/modules/auth"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
@ -40,9 +42,19 @@ import (
// @Failure 500 {object} models.Message "Internal server error." // @Failure 500 {object} models.Message "Internal server error."
// @Router /users [get] // @Router /users [get]
func UserList(c echo.Context) error { func UserList(c echo.Context) error {
s := c.QueryParam("s") search := c.QueryParam("s")
users, err := user.ListUsers(s)
s := db.NewSession()
defer s.Close()
users, err := user.ListUsers(s, search)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -80,17 +92,27 @@ func ListUsersForList(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
canRead, _, err := list.CanRead(auth) s := db.NewSession()
defer s.Close()
canRead, _, err := list.CanRead(s, auth)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
if !canRead { if !canRead {
return echo.ErrForbidden return echo.ErrForbidden
} }
s := c.QueryParam("s") search := c.QueryParam("s")
users, err := models.ListUsersFromList(&list, s) users, err := models.ListUsersFromList(s, &list, search)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
@ -43,8 +45,17 @@ func UserResetPassword(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "No password provided.") return echo.NewHTTPError(http.StatusBadRequest, "No password provided.")
} }
err := user.ResetPassword(&pwReset) s := db.NewSession()
defer s.Close()
err := user.ResetPassword(s, &pwReset)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -73,8 +84,17 @@ func UserRequestResetPasswordToken(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, err) return echo.NewHTTPError(http.StatusBadRequest, err)
} }
err := user.RequestUserPasswordResetTokenByEmail(&pwTokenReset) s := db.NewSession()
defer s.Close()
err := user.RequestUserPasswordResetTokenByEmail(s, &pwTokenReset)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
@ -50,15 +52,25 @@ func RegisterUser(c echo.Context) error {
return c.JSON(http.StatusBadRequest, models.Message{Message: "No or invalid user model provided."}) return c.JSON(http.StatusBadRequest, models.Message{Message: "No or invalid user model provided."})
} }
s := db.NewSession()
defer s.Close()
// Insert the user // Insert the user
newUser, err := user.CreateUser(datUser.APIFormat()) newUser, err := user.CreateUser(s, datUser.APIFormat())
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
// Add its namespace // Add its namespace
err = models.CreateNewNamespaceForUser(newUser) err = models.CreateNewNamespaceForUser(s, newUser)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
user2 "code.vikunja.io/api/pkg/user" user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
@ -57,8 +59,17 @@ func GetUserAvatarProvider(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) s := db.NewSession()
defer s.Close()
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -91,15 +102,25 @@ func ChangeUserAvatarProvider(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) s := db.NewSession()
defer s.Close()
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
user.AvatarProvider = uap.AvatarProvider user.AvatarProvider = uap.AvatarProvider
_, err = user2.UpdateUser(user) _, err = user2.UpdateUser(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -129,16 +150,26 @@ func UpdateGeneralUserSettings(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
user, err := user2.GetUserWithEmail(&user2.User{ID: u.ID}) s := db.NewSession()
defer s.Close()
user, err := user2.GetUserWithEmail(s, &user2.User{ID: u.ID})
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
user.Name = us.Name user.Name = us.Name
user.EmailRemindersEnabled = us.EmailRemindersEnabled user.EmailRemindersEnabled = us.EmailRemindersEnabled
_, err = user2.UpdateUser(user) _, err = user2.UpdateUser(s, user)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
user2 "code.vikunja.io/api/pkg/user" user2 "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -41,8 +43,17 @@ func UserShow(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, "Error getting current user.") return echo.NewHTTPError(http.StatusInternalServerError, "Error getting current user.")
} }
user, err := user2.GetUserByID(userInfos.ID) s := db.NewSession()
defer s.Close()
user, err := user2.GetUserByID(s, userInfos.ID)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -22,6 +22,8 @@ import (
"image/jpeg" "image/jpeg"
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
@ -47,8 +49,17 @@ func UserTOTPEnroll(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
t, err := user.EnrollTOTP(u) s := db.NewSession()
defer s.Close()
t, err := user.EnrollTOTP(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -86,8 +97,17 @@ func UserTOTPEnable(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid model provided.") return echo.NewHTTPError(http.StatusBadRequest, "Invalid model provided.")
} }
err = user.EnableTOTP(passcode) s := db.NewSession()
defer s.Close()
err = user.EnableTOTP(s, passcode)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -122,18 +142,29 @@ func UserTOTPDisable(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
u, err = user.GetUserByID(u.ID) s := db.NewSession()
defer s.Close()
u, err = user.GetUserByID(s, u.ID)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
err = user.CheckUserPassword(u, login.Password) err = user.CheckUserPassword(u, login.Password)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
err = user.DisableTOTP(u) err = user.DisableTOTP(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -156,14 +187,24 @@ func UserTOTPQrCode(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
qrcode, err := user.GetTOTPQrCodeForUser(u) s := db.NewSession()
defer s.Close()
qrcode, err := user.GetTOTPQrCodeForUser(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
buff := &bytes.Buffer{} buff := &bytes.Buffer{}
err = jpeg.Encode(buff, qrcode, nil) err = jpeg.Encode(buff, qrcode, nil)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
@ -186,8 +227,17 @@ func UserTOTP(c echo.Context) error {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
t, err := user.GetTOTPForUser(u) s := db.NewSession()
defer s.Close()
t, err := user.GetTOTPForUser(s, u)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -20,6 +20,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
@ -56,16 +58,26 @@ func UpdateUserEmail(c echo.Context) (err error) {
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
emailUpdate.User, err = user.CheckUserCredentials(&user.Login{ s := db.NewSession()
defer s.Close()
emailUpdate.User, err = user.CheckUserCredentials(s, &user.Login{
Username: emailUpdate.User.Username, Username: emailUpdate.User.Username,
Password: emailUpdate.Password, Password: emailUpdate.Password,
}) })
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
err = user.UpdateEmail(emailUpdate) err = user.UpdateEmail(s, emailUpdate)
if err != nil { if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -19,6 +19,8 @@ package v1
import ( import (
"net/http" "net/http"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/user"
"code.vikunja.io/web/handler" "code.vikunja.io/web/handler"
@ -61,13 +63,23 @@ func UserChangePassword(c echo.Context) error {
return handler.HandleHTTPError(user.ErrEmptyOldPassword{}, c) return handler.HandleHTTPError(user.ErrEmptyOldPassword{}, c)
} }
s := db.NewSession()
defer s.Close()
// Check the current password // Check the current password
if _, err = user.CheckUserCredentials(&user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil { if _, err = user.CheckUserCredentials(s, &user.Login{Username: doer.Username, Password: newPW.OldPassword}); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }
// Update the password // Update the password
if err = user.UpdateUserPassword(doer, newPW.NewPassword); err != nil { if err = user.UpdateUserPassword(s, doer, newPW.NewPassword); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c)
}
if err := s.Commit(); err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)
} }

View file

@ -21,6 +21,8 @@ import (
"strings" "strings"
"time" "time"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
user2 "code.vikunja.io/api/pkg/user" user2 "code.vikunja.io/api/pkg/user"
@ -90,9 +92,16 @@ func (vcls *VikunjaCaldavListStorage) GetResources(rpath string, withChildren bo
return []data.Resource{r}, nil return []data.Resource{r}, nil
} }
s := db.NewSession()
defer s.Close()
// Otherwise get all lists // Otherwise get all lists
thelists, _, _, err := vcls.list.ReadAll(vcls.user, "", -1, 50) thelists, _, _, err := vcls.list.ReadAll(s, vcls.user, "", -1, 50)
if err != nil { if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err return nil, err
} }
lists := thelists.([]*models.List) lists := thelists.([]*models.List)
@ -125,10 +134,17 @@ func (vcls *VikunjaCaldavListStorage) GetResourcesByList(rpaths []string) ([]dat
uids = append(uids, string(uid[:endlen])) uids = append(uids, string(uid[:endlen]))
} }
s := db.NewSession()
defer s.Close()
// GetTasksByUIDs... // GetTasksByUIDs...
// Parse these into ressources... // Parse these into ressources...
tasks, err := models.GetTasksByUIDs(uids) tasks, err := models.GetTasksByUIDs(s, uids)
if err != nil { if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err return nil, err
} }
@ -187,15 +203,22 @@ func (vcls *VikunjaCaldavListStorage) GetResource(rpath string) (*data.Resource,
// If the task is not nil, we need to get the task and not the list // If the task is not nil, we need to get the task and not the list
if vcls.task != nil { if vcls.task != nil {
s := db.NewSession()
defer s.Close()
// save and override the updated unix date to not break any later etag checks // save and override the updated unix date to not break any later etag checks
updated := vcls.task.Updated updated := vcls.task.Updated
task, err := models.GetTaskSimple(&models.Task{ID: vcls.task.ID, UID: vcls.task.UID}) task, err := models.GetTaskSimple(s, &models.Task{ID: vcls.task.ID, UID: vcls.task.UID})
if err != nil { if err != nil {
_ = s.Rollback()
if models.IsErrTaskDoesNotExist(err) { if models.IsErrTaskDoesNotExist(err) {
return nil, false, errs.ResourceNotFoundError return nil, false, errs.ResourceNotFoundError
} }
return nil, false, err return nil, false, err
} }
if err := s.Commit(); err != nil {
return nil, false, err
}
vcls.task = &task vcls.task = &task
if updated.Unix() > 0 { if updated.Unix() > 0 {
@ -230,6 +253,9 @@ func (vcls *VikunjaCaldavListStorage) GetShallowResource(rpath string) (*data.Re
// CreateResource creates a new resource // CreateResource creates a new resource
func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*data.Resource, error) { func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*data.Resource, error) {
s := db.NewSession()
defer s.Close()
vTask, err := parseTaskFromVTODO(content) vTask, err := parseTaskFromVTODO(content)
if err != nil { if err != nil {
return nil, err return nil, err
@ -238,7 +264,7 @@ func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*da
vTask.ListID = vcls.list.ID vTask.ListID = vcls.list.ID
// Check the rights // Check the rights
canCreate, err := vTask.CanCreate(vcls.user) canCreate, err := vTask.CanCreate(s, vcls.user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -247,8 +273,13 @@ func (vcls *VikunjaCaldavListStorage) CreateResource(rpath, content string) (*da
} }
// Create the task // Create the task
err = vTask.Create(vcls.user) err = vTask.Create(s, vcls.user)
if err != nil { if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err return nil, err
} }
@ -272,18 +303,28 @@ func (vcls *VikunjaCaldavListStorage) UpdateResource(rpath, content string) (*da
// At this point, we already have the right task in vcls.task, so we can use that ID directly // At this point, we already have the right task in vcls.task, so we can use that ID directly
vTask.ID = vcls.task.ID vTask.ID = vcls.task.ID
s := db.NewSession()
defer s.Close()
// Check the rights // Check the rights
canUpdate, err := vTask.CanUpdate(vcls.user) canUpdate, err := vTask.CanUpdate(s, vcls.user)
if err != nil { if err != nil {
_ = s.Rollback()
return nil, err return nil, err
} }
if !canUpdate { if !canUpdate {
_ = s.Rollback()
return nil, errs.ForbiddenError return nil, errs.ForbiddenError
} }
// Update the task // Update the task
err = vTask.Update() err = vTask.Update(s)
if err != nil { if err != nil {
_ = s.Rollback()
return nil, err
}
if err := s.Commit(); err != nil {
return nil, err return nil, err
} }
@ -299,9 +340,13 @@ func (vcls *VikunjaCaldavListStorage) UpdateResource(rpath, content string) (*da
// DeleteResource deletes a resource // DeleteResource deletes a resource
func (vcls *VikunjaCaldavListStorage) DeleteResource(rpath string) error { func (vcls *VikunjaCaldavListStorage) DeleteResource(rpath string) error {
if vcls.task != nil { if vcls.task != nil {
s := db.NewSession()
defer s.Close()
// Check the rights // Check the rights
canDelete, err := vcls.task.CanDelete(vcls.user) canDelete, err := vcls.task.CanDelete(s, vcls.user)
if err != nil { if err != nil {
_ = s.Rollback()
return err return err
} }
if !canDelete { if !canDelete {
@ -309,7 +354,13 @@ func (vcls *VikunjaCaldavListStorage) DeleteResource(rpath string) error {
} }
// Delete it // Delete it
return vcls.task.Delete() err = vcls.task.Delete(s)
if err != nil {
_ = s.Rollback()
return err
}
return s.Commit()
} }
return nil return nil
@ -385,16 +436,22 @@ func (vlra *VikunjaListResourceAdapter) GetModTime() time.Time {
} }
func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr VikunjaListResourceAdapter, err error) { func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr VikunjaListResourceAdapter, err error) {
can, _, err := vcls.list.CanRead(vcls.user) s := db.NewSession()
defer s.Close()
can, _, err := vcls.list.CanRead(s, vcls.user)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
if !can { if !can {
_ = s.Rollback()
log.Errorf("User %v tried to access a caldav resource (List %v) which they are not allowed to access", vcls.user.Username, vcls.list.ID) log.Errorf("User %v tried to access a caldav resource (List %v) which they are not allowed to access", vcls.user.Username, vcls.list.ID)
return rr, models.ErrUserDoesNotHaveAccessToList{ListID: vcls.list.ID} return rr, models.ErrUserDoesNotHaveAccessToList{ListID: vcls.list.ID}
} }
err = vcls.list.ReadOne() err = vcls.list.ReadOne(s)
if err != nil { if err != nil {
_ = s.Rollback()
return return
} }
@ -403,8 +460,9 @@ func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr Vi
tk := models.TaskCollection{ tk := models.TaskCollection{
ListID: vcls.list.ID, ListID: vcls.list.ID,
} }
iface, _, _, err := tk.ReadAll(vcls.user, "", 1, 1000) iface, _, _, err := tk.ReadAll(s, vcls.user, "", 1, 1000)
if err != nil { if err != nil {
_ = s.Rollback()
return rr, err return rr, err
} }
tasks, ok := iface.([]*models.Task) tasks, ok := iface.([]*models.Task)
@ -416,6 +474,10 @@ func (vcls *VikunjaCaldavListStorage) getListRessource(isCollection bool) (rr Vi
vcls.list.Tasks = tasks vcls.list.Tasks = tasks
} }
if err := s.Commit(); err != nil {
return rr, err
}
rr = VikunjaListResourceAdapter{ rr = VikunjaListResourceAdapter{
list: vcls.list, list: vcls.list,
listTasks: listTasks, listTasks: listTasks,

View file

@ -50,11 +50,8 @@ import (
"strings" "strings"
"time" "time"
microsofttodo "code.vikunja.io/api/pkg/modules/migration/microsoft-todo"
"code.vikunja.io/api/pkg/modules/migration/trello"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/modules/auth" "code.vikunja.io/api/pkg/modules/auth"
@ -65,7 +62,9 @@ import (
"code.vikunja.io/api/pkg/modules/background/upload" "code.vikunja.io/api/pkg/modules/background/upload"
"code.vikunja.io/api/pkg/modules/migration" "code.vikunja.io/api/pkg/modules/migration"
migrationHandler "code.vikunja.io/api/pkg/modules/migration/handler" migrationHandler "code.vikunja.io/api/pkg/modules/migration/handler"
microsofttodo "code.vikunja.io/api/pkg/modules/migration/microsoft-todo"
"code.vikunja.io/api/pkg/modules/migration/todoist" "code.vikunja.io/api/pkg/modules/migration/todoist"
"code.vikunja.io/api/pkg/modules/migration/trello"
"code.vikunja.io/api/pkg/modules/migration/wunderlist" "code.vikunja.io/api/pkg/modules/migration/wunderlist"
apiv1 "code.vikunja.io/api/pkg/routes/api/v1" apiv1 "code.vikunja.io/api/pkg/routes/api/v1"
"code.vikunja.io/api/pkg/routes/caldav" "code.vikunja.io/api/pkg/routes/caldav"
@ -175,6 +174,7 @@ func NewEcho() *echo.Echo {
}) })
handler.SetLoggingProvider(log.GetLogger()) handler.SetLoggingProvider(log.GetLogger())
handler.SetMaxItemsPerPage(config.ServiceMaxItemsPerPage.GetInt()) handler.SetMaxItemsPerPage(config.ServiceMaxItemsPerPage.GetInt())
handler.SetSessionFactory(db.NewSession)
return e return e
} }
@ -601,11 +601,19 @@ func caldavBasicAuth(username, password string, c echo.Context) (bool, error) {
Username: username, Username: username,
Password: password, Password: password,
} }
u, err := user.CheckUserCredentials(creds) s := db.NewSession()
defer s.Close()
u, err := user.CheckUserCredentials(s, creds)
if err != nil { if err != nil {
_ = s.Rollback()
log.Errorf("Error during basic auth for caldav: %v", err) log.Errorf("Error during basic auth for caldav: %v", err)
return false, nil return false, nil
} }
if err := s.Commit(); err != nil {
return false, err
}
// Save the user in echo context for later use // Save the user in echo context for later use
c.Set("userBasicAuth", u) c.Set("userBasicAuth", u)
return true, nil return true, nil

View file

@ -20,20 +20,10 @@ package user
import ( import (
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/db"
"code.vikunja.io/api/pkg/log"
"xorm.io/xorm"
) )
var x *xorm.Engine
// InitDB sets up the database connection to use in this module // InitDB sets up the database connection to use in this module
func InitDB() (err error) { func InitDB() (err error) {
x, err = db.CreateDBEngine()
if err != nil {
log.Criticalf("Could not connect to db: %v", err.Error())
return
}
// Cache // Cache
if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" { if config.CacheEnabled.GetBool() && config.CacheType.GetString() == "redis" {
db.RegisterTableStructsForCache(GetTables()) db.RegisterTableStructsForCache(GetTables())

View file

@ -24,8 +24,7 @@ import (
// InitTests handles the actual bootstrapping of the test env // InitTests handles the actual bootstrapping of the test env
func InitTests() { func InitTests() {
var err error x, err := db.CreateTestEngine()
x, err = db.CreateTestEngine()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View file

@ -19,6 +19,8 @@ package user
import ( import (
"image" "image"
"xorm.io/xorm"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"github.com/pquerna/otp" "github.com/pquerna/otp"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@ -47,19 +49,19 @@ type TOTPPasscode struct {
} }
// TOTPEnabledForUser checks if totp is enabled for a user - not if it is activated, use GetTOTPForUser to check that. // TOTPEnabledForUser checks if totp is enabled for a user - not if it is activated, use GetTOTPForUser to check that.
func TOTPEnabledForUser(user *User) (bool, error) { func TOTPEnabledForUser(s *xorm.Session, user *User) (bool, error) {
if !config.ServiceEnableTotp.GetBool() { if !config.ServiceEnableTotp.GetBool() {
return false, nil return false, nil
} }
t := &TOTP{} t := &TOTP{}
_, err := x.Where("user_id = ?", user.ID).Get(t) _, err := s.Where("user_id = ?", user.ID).Get(t)
return t.Enabled, err return t.Enabled, err
} }
// GetTOTPForUser returns the current state of totp settings for the user. // GetTOTPForUser returns the current state of totp settings for the user.
func GetTOTPForUser(user *User) (t *TOTP, err error) { func GetTOTPForUser(s *xorm.Session, user *User) (t *TOTP, err error) {
t = &TOTP{} t = &TOTP{}
exists, err := x.Where("user_id = ?", user.ID).Get(t) exists, err := s.Where("user_id = ?", user.ID).Get(t)
if err != nil { if err != nil {
return return
} }
@ -71,8 +73,8 @@ func GetTOTPForUser(user *User) (t *TOTP, err error) {
} }
// EnrollTOTP creates a new TOTP entry for the user - it does not enable it yet. // EnrollTOTP creates a new TOTP entry for the user - it does not enable it yet.
func EnrollTOTP(user *User) (t *TOTP, err error) { func EnrollTOTP(s *xorm.Session, user *User) (t *TOTP, err error) {
isEnrolled, err := x.Where("user_id = ?", user.ID).Exist(&TOTP{}) isEnrolled, err := s.Where("user_id = ?", user.ID).Exist(&TOTP{})
if err != nil { if err != nil {
return return
} }
@ -94,18 +96,18 @@ func EnrollTOTP(user *User) (t *TOTP, err error) {
Enabled: false, Enabled: false,
URL: key.URL(), URL: key.URL(),
} }
_, err = x.Insert(t) _, err = s.Insert(t)
return return
} }
// EnableTOTP enables totp for a user. The provided passcode is used to verify the user has a working totp setup. // EnableTOTP enables totp for a user. The provided passcode is used to verify the user has a working totp setup.
func EnableTOTP(passcode *TOTPPasscode) (err error) { func EnableTOTP(s *xorm.Session, passcode *TOTPPasscode) (err error) {
t, err := ValidateTOTPPasscode(passcode) t, err := ValidateTOTPPasscode(s, passcode)
if err != nil { if err != nil {
return return
} }
_, err = x. _, err = s.
Where("id = ?", t.ID). Where("id = ?", t.ID).
Cols("enabled"). Cols("enabled").
Update(&TOTP{Enabled: true}) Update(&TOTP{Enabled: true})
@ -113,14 +115,16 @@ func EnableTOTP(passcode *TOTPPasscode) (err error) {
} }
// DisableTOTP removes all totp settings for a user. // DisableTOTP removes all totp settings for a user.
func DisableTOTP(user *User) (err error) { func DisableTOTP(s *xorm.Session, user *User) (err error) {
_, err = x.Where("user_id = ?", user.ID).Delete(&TOTP{}) _, err = s.
Where("user_id = ?", user.ID).
Delete(&TOTP{})
return return
} }
// ValidateTOTPPasscode validated totp codes of users. // ValidateTOTPPasscode validated totp codes of users.
func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) { func ValidateTOTPPasscode(s *xorm.Session, passcode *TOTPPasscode) (t *TOTP, err error) {
t, err = GetTOTPForUser(passcode.User) t, err = GetTOTPForUser(s, passcode.User)
if err != nil { if err != nil {
return return
} }
@ -133,8 +137,8 @@ func ValidateTOTPPasscode(passcode *TOTPPasscode) (t *TOTP, err error) {
} }
// GetTOTPQrCodeForUser returns a qrcode for a user's totp setting // GetTOTPQrCodeForUser returns a qrcode for a user's totp setting
func GetTOTPQrCodeForUser(user *User) (qrcode image.Image, err error) { func GetTOTPQrCodeForUser(s *xorm.Session, user *User) (qrcode image.Image, err error) {
t, err := GetTOTPForUser(user) t, err := GetTOTPForUser(s, user)
if err != nil { if err != nil {
return return
} }

View file

@ -20,6 +20,7 @@ import (
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/mail" "code.vikunja.io/api/pkg/mail"
"code.vikunja.io/api/pkg/utils" "code.vikunja.io/api/pkg/utils"
"xorm.io/xorm"
) )
// EmailUpdate is the data structure to update a user's email address // EmailUpdate is the data structure to update a user's email address
@ -32,11 +33,11 @@ type EmailUpdate struct {
} }
// UpdateEmail lets a user update their email address // UpdateEmail lets a user update their email address
func UpdateEmail(update *EmailUpdate) (err error) { func UpdateEmail(s *xorm.Session, update *EmailUpdate) (err error) {
// Check the email is not already used // Check the email is not already used
user := &User{} user := &User{}
has, err := x.Where("email = ?", update.NewEmail).Get(user) has, err := s.Where("email = ?", update.NewEmail).Get(user)
if err != nil { if err != nil {
return return
} }
@ -46,7 +47,7 @@ func UpdateEmail(update *EmailUpdate) (err error) {
} }
// Set the user as unconfirmed and the new email address // Set the user as unconfirmed and the new email address
update.User, err = GetUserWithEmail(&User{ID: update.User.ID}) update.User, err = GetUserWithEmail(s, &User{ID: update.User.ID})
if err != nil { if err != nil {
return return
} }
@ -54,7 +55,7 @@ func UpdateEmail(update *EmailUpdate) (err error) {
update.User.IsActive = false update.User.IsActive = false
update.User.Email = update.NewEmail update.User.Email = update.NewEmail
update.User.EmailConfirmToken = utils.MakeRandomString(64) update.User.EmailConfirmToken = utils.MakeRandomString(64)
_, err = x. _, err = s.
Where("id = ?", update.User.ID). Where("id = ?", update.User.ID).
Cols("email", "is_active", "email_confirm_token"). Cols("email", "is_active", "email_confirm_token").
Update(update.User) Update(update.User)

Some files were not shown because too many files have changed in this diff Show more