From 7141050f8b8cbb0d59008507529e3482c553edb1 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 2 Jul 2020 21:16:39 +0200 Subject: [PATCH] Make sure the metrics map accesses only happen explicitly --- pkg/metrics/active_users.go | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/pkg/metrics/active_users.go b/pkg/metrics/active_users.go index 2d07f5a5..a232ad2b 100644 --- a/pkg/metrics/active_users.go +++ b/pkg/metrics/active_users.go @@ -23,6 +23,7 @@ import ( "encoding/gob" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "sync" "time" ) @@ -39,13 +40,19 @@ type ActiveUser struct { } // ActiveUsersMap is the type used to save active users -type ActiveUsersMap map[int64]*ActiveUser +type ActiveUsers struct { + users map[int64]*ActiveUser + mutex *sync.Mutex +} // activeUsers holds a map with all active users -var activeUsers ActiveUsersMap +var activeUsers *ActiveUsers func init() { - activeUsers = make(ActiveUsersMap) + activeUsers = &ActiveUsers{ + users: make(map[int64]*ActiveUser), + mutex: &sync.Mutex{}, + } promauto.NewGaugeFunc(prometheus.GaugeOpts{ Name: "vikunja_active_users", @@ -56,8 +63,11 @@ func init() { if err != nil { log.Error(err.Error()) } + if allActiveUsers == nil { + return 0 + } activeUsersCount := 0 - for _, u := range allActiveUsers { + for _, u := range allActiveUsers.users { if time.Since(u.LastSeen) < SecondsUntilInactive*time.Second { activeUsersCount++ } @@ -68,15 +78,17 @@ func init() { // SetUserActive sets a user as active and pushes it to redis func SetUserActive(a web.Auth) (err error) { - activeUsers[a.GetID()] = &ActiveUser{ + activeUsers.mutex.Lock() + activeUsers.users[a.GetID()] = &ActiveUser{ UserID: a.GetID(), LastSeen: time.Now(), } + activeUsers.mutex.Unlock() return PushActiveUsers() } // GetActiveUsers returns the active users from redis -func GetActiveUsers() (users ActiveUsersMap, err error) { +func GetActiveUsers() (users *ActiveUsers, err error) { activeUsersR, err := r.Get(ActiveUsersKey).Bytes() if err != nil { @@ -102,7 +114,9 @@ func GetActiveUsers() (users ActiveUsersMap, err error) { func PushActiveUsers() (err error) { var b bytes.Buffer e := gob.NewEncoder(&b) - if err := e.Encode(activeUsers); err != nil { + activeUsers.mutex.Lock() + defer activeUsers.mutex.Unlock() + if err := e.Encode(activeUsers.users); err != nil { return err }