Refactor & fix storing struct-values in redis keyvalue

This commit is contained in:
kolaente 2021-05-28 10:52:32 +02:00
parent df45675df3
commit d48aa101cf
No known key found for this signature in database
GPG key ID: F40E70337AB24C9B
10 changed files with 117 additions and 59 deletions

View file

@ -91,12 +91,8 @@ func SetUserActive(a web.Auth) (err error) {
// getActiveUsers returns the active users from redis // getActiveUsers returns the active users from redis
func getActiveUsers() (users activeUsersMap, err error) { func getActiveUsers() (users activeUsersMap, err error) {
u, _, err := keyvalue.Get(ActiveUsersKey) users = activeUsersMap{}
if err != nil { _, err = keyvalue.GetWithValue(ActiveUsersKey, &users)
return nil, err
}
users = u.(activeUsersMap)
return return
} }

View file

@ -17,6 +17,8 @@
package metrics package metrics
import ( import (
"strconv"
"code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/modules/keyvalue" "code.vikunja.io/api/pkg/modules/keyvalue"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -132,7 +134,11 @@ func GetCount(key string) (count int64, err error) {
return 0, nil return 0, nil
} }
count = cnt.(int64) if s, is := cnt.(string); is {
count, err = strconv.ParseInt(s, 10, 64)
} else {
count = cnt.(int64)
}
return return
} }

View file

@ -46,12 +46,12 @@ type Callback struct {
// Provider is the structure of an OpenID Connect provider // Provider is the structure of an OpenID Connect provider
type Provider struct { type Provider struct {
Name string `json:"name"` Name string `json:"name"`
Key string `json:"key"` Key string `json:"key"`
AuthURL string `json:"auth_url"` AuthURL string `json:"auth_url"`
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
ClientSecret string `json:"-"` ClientSecret string `json:"-"`
OpenIDProvider *oidc.Provider `json:"-"` openIDProvider *oidc.Provider
Oauth2Config *oauth2.Config `json:"-"` Oauth2Config *oauth2.Config `json:"-"`
} }
@ -66,6 +66,11 @@ func init() {
rand.Seed(time.Now().UTC().UnixNano()) rand.Seed(time.Now().UTC().UnixNano())
} }
func (p *Provider) setOicdProvider() (err error) {
p.openIDProvider, err = oidc.NewProvider(context.Background(), p.AuthURL)
return err
}
// HandleCallback handles the auth request callback after redirecting from the provider with an auth code // HandleCallback handles the auth request callback after redirecting from the provider with an auth code
// @Summary Authenticate a user with OpenID Connect // @Summary Authenticate a user with OpenID Connect
// @Description After a redirect from the OpenID Connect provider to the frontend has been made with the authentication `code`, this endpoint can be used to obtain a jwt token for that user and thus log them in. // @Description After a redirect from the OpenID Connect provider to the frontend has been made with the authentication `code`, this endpoint can be used to obtain a jwt token for that user and thus log them in.
@ -122,7 +127,7 @@ func HandleCallback(c echo.Context) error {
return c.JSON(http.StatusBadRequest, models.Message{Message: "Missing token"}) return c.JSON(http.StatusBadRequest, models.Message{Message: "Missing token"})
} }
verifier := provider.OpenIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID}) verifier := provider.openIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID})
// Parse and verify ID Token payload. // Parse and verify ID Token payload.
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(context.Background(), rawIDToken)
@ -140,7 +145,7 @@ func HandleCallback(c echo.Context) error {
} }
if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" { if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" {
info, err := provider.OpenIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token)) info, err := provider.openIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token))
if err != nil { if err != nil {
log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err) log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err)
return handler.HandleHTTPError(err, c) return handler.HandleHTTPError(err, c)

View file

@ -17,7 +17,6 @@
package openid package openid
import ( import (
"context"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -36,7 +35,8 @@ func GetAllProviders() (providers []*Provider, err error) {
return nil, nil return nil, nil
} }
ps, exists, err := keyvalue.Get("openid_providers") providers = []*Provider{}
exists, err := keyvalue.GetWithValue("openid_providers", &providers)
if !exists { if !exists {
rawProviders := config.AuthOpenIDProviders.Get() rawProviders := config.AuthOpenIDProviders.Get()
if rawProviders == nil { if rawProviders == nil {
@ -68,31 +68,30 @@ func GetAllProviders() (providers []*Provider, err error) {
err = keyvalue.Put("openid_providers", providers) err = keyvalue.Put("openid_providers", providers)
} }
if ps != nil {
return ps.([]*Provider), nil
}
return return
} }
// GetProvider retrieves a provider from keyvalue // GetProvider retrieves a provider from keyvalue
func GetProvider(key string) (provider *Provider, err error) { func GetProvider(key string) (provider *Provider, err error) {
var p interface{} provider = &Provider{}
p, exists, err := keyvalue.Get("openid_provider_" + key) exists, err := keyvalue.GetWithValue("openid_provider_"+key, provider)
if err != nil {
return nil, err
}
if !exists { if !exists {
_, err = GetAllProviders() // This will put all providers in cache _, err = GetAllProviders() // This will put all providers in cache
if err != nil { if err != nil {
return nil, err return nil, err
} }
p, _, err = keyvalue.Get("openid_provider_" + key) _, err = keyvalue.GetWithValue("openid_provider_"+key, provider)
if err != nil {
return nil, err
}
} }
if p != nil { err = provider.setOicdProvider()
return p.(*Provider), nil return
}
return nil, err
} }
func getKeyFromName(name string) string { func getKeyFromName(name string) string {
@ -100,7 +99,7 @@ func getKeyFromName(name string) string {
return reg.ReplaceAllString(strings.ToLower(name), "") return reg.ReplaceAllString(strings.ToLower(name), "")
} }
func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) { func getProviderFromMap(pi map[interface{}]interface{}) (provider *Provider, err error) {
name, is := pi["name"].(string) name, is := pi["name"].(string)
if !is { if !is {
return nil, nil return nil, nil
@ -108,7 +107,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
k := getKeyFromName(name) k := getKeyFromName(name)
provider := &Provider{ provider = &Provider{
Name: pi["name"].(string), Name: pi["name"].(string),
Key: k, Key: k,
AuthURL: pi["authurl"].(string), AuthURL: pi["authurl"].(string),
@ -122,10 +121,9 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
provider.ClientID = pi["clientid"].(string) provider.ClientID = pi["clientid"].(string)
} }
var err error err = provider.setOicdProvider()
provider.OpenIDProvider, err = oidc.NewProvider(context.Background(), provider.AuthURL)
if err != nil { if err != nil {
return provider, err return
} }
provider.Oauth2Config = &oauth2.Config{ provider.Oauth2Config = &oauth2.Config{
@ -134,7 +132,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
RedirectURL: config.AuthOpenIDRedirectURL.GetString() + k, RedirectURL: config.AuthOpenIDRedirectURL.GetString() + k,
// Discovery returns the OAuth2 endpoints. // Discovery returns the OAuth2 endpoints.
Endpoint: provider.OpenIDProvider.Endpoint(), Endpoint: provider.openIDProvider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows. // "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
@ -142,5 +140,5 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL
return provider, nil return
} }

View file

@ -127,7 +127,8 @@ func getCacheKey(prefix string, keys ...int64) string {
func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) { func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) {
cacheKey := getCacheKey("full", u.ID) cacheKey := getCacheKey("full", u.ID)
a, exists, err := keyvalue.Get(cacheKey) fullSizeAvatar = &image.RGBA64{}
exists, err := keyvalue.GetWithValue(cacheKey, fullSizeAvatar)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,8 +146,6 @@ func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
fullSizeAvatar = a.(*image.RGBA64)
} }
return fullSizeAvatar, nil return fullSizeAvatar, nil
@ -156,7 +155,7 @@ func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) {
func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType string, err error) { func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType string, err error) {
cacheKey := getCacheKey("resized", u.ID, size) cacheKey := getCacheKey("resized", u.ID, size)
a, exists, err := keyvalue.Get(cacheKey) exists, err := keyvalue.GetWithValue(cacheKey, &avatar)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -180,7 +179,6 @@ func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType
return nil, "", err return nil, "", err
} }
} else { } else {
avatar = a.([]byte)
log.Debugf("Serving initials avatar for user %d and size %d from cache", u.ID, size) log.Debugf("Serving initials avatar for user %d and size %d from cache", u.ID, size)
} }

View file

@ -39,22 +39,17 @@ func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType
cacheKey := "avatar_upload_" + strconv.Itoa(int(u.ID)) cacheKey := "avatar_upload_" + strconv.Itoa(int(u.ID))
ai, exists, err := keyvalue.Get(cacheKey) var cached map[int64][]byte
exists, err := keyvalue.GetWithValue(cacheKey, &cached)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
var cached map[int64][]byte
if ai != nil {
cached = ai.(map[int64][]byte)
}
if !exists { if !exists {
// Nothing ever cached for this user so we need to create the size map to avoid panics // Nothing ever cached for this user so we need to create the size map to avoid panics
cached = make(map[int64][]byte) cached = make(map[int64][]byte)
} else { } else {
a := ai.(map[int64][]byte) a := cached
if a != nil && a[size] != nil { if a != nil && a[size] != nil {
log.Debugf("Serving uploaded avatar for user %d and size %d from cache.", u.ID, size) log.Debugf("Serving uploaded avatar for user %d and size %d from cache.", u.ID, size)
return a[size], "", nil return a[size], "", nil

View file

@ -122,7 +122,8 @@ func getImageID(fullURL string) string {
// Gets an unsplash photo either from cache or directly from the unsplash api // Gets an unsplash photo either from cache or directly from the unsplash api
func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) { func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) {
p, exists, err := keyvalue.Get(cachePrefix + photoID) photo = &Photo{}
exists, err := keyvalue.GetWithValue(cachePrefix+photoID, photo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,8 +135,6 @@ func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) {
if err != nil { if err != nil {
return return
} }
} else {
photo = p.(*Photo)
} }
return return
} }

View file

@ -26,6 +26,7 @@ import (
type Storage interface { type Storage interface {
Put(key string, value interface{}) (err error) Put(key string, value interface{}) (err error)
Get(key string) (value interface{}, exists bool, err error) Get(key string) (value interface{}, exists bool, err error)
GetWithValue(key string, value interface{}) (exists bool, err error)
Del(key string) (err error) Del(key string) (err error)
IncrBy(key string, update int64) (err error) IncrBy(key string, update int64) (err error)
DecrBy(key string, update int64) (err error) DecrBy(key string, update int64) (err error)
@ -55,6 +56,10 @@ func Get(key string) (value interface{}, exists bool, err error) {
return store.Get(key) return store.Get(key)
} }
func GetWithValue(key string, value interface{}) (exists bool, err error) {
return store.GetWithValue(key, value)
}
// Del removes a save value from a storage backend // Del removes a save value from a storage backend
func Del(key string) (err error) { func Del(key string) (err error) {
return store.Del(key) return store.Del(key)

View file

@ -17,6 +17,7 @@
package memory package memory
import ( import (
"reflect"
"sync" "sync"
e "code.vikunja.io/api/pkg/modules/keyvalue/error" e "code.vikunja.io/api/pkg/modules/keyvalue/error"
@ -52,6 +53,21 @@ func (s *Storage) Get(key string) (value interface{}, exists bool, err error) {
return return
} }
func (s *Storage) GetWithValue(key string, value interface{}) (exists bool, err error) {
v, exists, err := s.Get(key)
if !exists {
return exists, err
}
val := reflect.ValueOf(value)
if val.Kind() != reflect.Ptr {
panic("some: check must be a pointer")
}
val.Elem().Set(reflect.ValueOf(v))
return exists, err
}
// Del removes a saved value from a memory storage // Del removes a saved value from a memory storage
func (s *Storage) Del(key string) (err error) { func (s *Storage) Del(key string) (err error) {
s.mutex.Lock() s.mutex.Lock()

View file

@ -17,8 +17,10 @@
package redis package redis
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/gob"
"errors"
"code.vikunja.io/api/pkg/red" "code.vikunja.io/api/pkg/red"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
@ -40,9 +42,28 @@ func NewStorage() *Storage {
// Put puts a value into redis // Put puts a value into redis
func (s *Storage) Put(key string, value interface{}) (err error) { func (s *Storage) Put(key string, value interface{}) (err error) {
v, err := json.Marshal(value)
if err != nil { var v interface{}
return err
switch value.(type) {
case int:
v = value
case int8:
v = value
case int16:
v = value
case int32:
v = value
case int64:
v = value
default:
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err = enc.Encode(value)
if err != nil {
return err
}
return s.client.Set(context.Background(), key, buf.Bytes(), 0).Err()
} }
return s.client.Set(context.Background(), key, v, 0).Err() return s.client.Set(context.Background(), key, v, 0).Err()
@ -50,13 +71,32 @@ func (s *Storage) Put(key string, value interface{}) (err error) {
// Get retrieves a saved value from redis // Get retrieves a saved value from redis
func (s *Storage) Get(key string) (value interface{}, exists bool, err error) { func (s *Storage) Get(key string) (value interface{}, exists bool, err error) {
value, err = s.client.Get(context.Background(), key).Result()
if err != nil && errors.Is(err, redis.Nil) {
return nil, false, nil
}
return value, true, err
}
func (s *Storage) GetWithValue(key string, value interface{}) (exists bool, err error) {
b, err := s.client.Get(context.Background(), key).Bytes() b, err := s.client.Get(context.Background(), key).Bytes()
if err != nil { if err != nil {
return nil, false, err if errors.Is(err, redis.Nil) {
return false, nil
}
return
} }
err = json.Unmarshal(b, value) var buf bytes.Buffer
return _, err = buf.Write(b)
if err != nil {
return
}
dec := gob.NewDecoder(&buf)
err = dec.Decode(value)
return true, err
} }
// Del removed a value from redis // Del removed a value from redis