Refactor & fix storing struct-values in redis keyvalue
This commit is contained in:
parent
df45675df3
commit
d48aa101cf
10 changed files with 117 additions and 59 deletions
|
@ -91,12 +91,8 @@ func SetUserActive(a web.Auth) (err error) {
|
||||||
|
|
||||||
// getActiveUsers returns the active users from redis
|
// getActiveUsers returns the active users from redis
|
||||||
func getActiveUsers() (users activeUsersMap, err error) {
|
func getActiveUsers() (users activeUsersMap, err error) {
|
||||||
u, _, err := keyvalue.Get(ActiveUsersKey)
|
users = activeUsersMap{}
|
||||||
if err != nil {
|
_, err = keyvalue.GetWithValue(ActiveUsersKey, &users)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
users = u.(activeUsersMap)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package metrics
|
package metrics
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"code.vikunja.io/api/pkg/log"
|
"code.vikunja.io/api/pkg/log"
|
||||||
"code.vikunja.io/api/pkg/modules/keyvalue"
|
"code.vikunja.io/api/pkg/modules/keyvalue"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
@ -132,7 +134,11 @@ func GetCount(key string) (count int64, err error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s, is := cnt.(string); is {
|
||||||
|
count, err = strconv.ParseInt(s, 10, 64)
|
||||||
|
} else {
|
||||||
count = cnt.(int64)
|
count = cnt.(int64)
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,7 +51,7 @@ type Provider struct {
|
||||||
AuthURL string `json:"auth_url"`
|
AuthURL string `json:"auth_url"`
|
||||||
ClientID string `json:"client_id"`
|
ClientID string `json:"client_id"`
|
||||||
ClientSecret string `json:"-"`
|
ClientSecret string `json:"-"`
|
||||||
OpenIDProvider *oidc.Provider `json:"-"`
|
openIDProvider *oidc.Provider
|
||||||
Oauth2Config *oauth2.Config `json:"-"`
|
Oauth2Config *oauth2.Config `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,6 +66,11 @@ func init() {
|
||||||
rand.Seed(time.Now().UTC().UnixNano())
|
rand.Seed(time.Now().UTC().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) setOicdProvider() (err error) {
|
||||||
|
p.openIDProvider, err = oidc.NewProvider(context.Background(), p.AuthURL)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// HandleCallback handles the auth request callback after redirecting from the provider with an auth code
|
// HandleCallback handles the auth request callback after redirecting from the provider with an auth code
|
||||||
// @Summary Authenticate a user with OpenID Connect
|
// @Summary Authenticate a user with OpenID Connect
|
||||||
// @Description After a redirect from the OpenID Connect provider to the frontend has been made with the authentication `code`, this endpoint can be used to obtain a jwt token for that user and thus log them in.
|
// @Description After a redirect from the OpenID Connect provider to the frontend has been made with the authentication `code`, this endpoint can be used to obtain a jwt token for that user and thus log them in.
|
||||||
|
@ -122,7 +127,7 @@ func HandleCallback(c echo.Context) error {
|
||||||
return c.JSON(http.StatusBadRequest, models.Message{Message: "Missing token"})
|
return c.JSON(http.StatusBadRequest, models.Message{Message: "Missing token"})
|
||||||
}
|
}
|
||||||
|
|
||||||
verifier := provider.OpenIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID})
|
verifier := provider.openIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID})
|
||||||
|
|
||||||
// Parse and verify ID Token payload.
|
// Parse and verify ID Token payload.
|
||||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||||
|
@ -140,7 +145,7 @@ func HandleCallback(c echo.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" {
|
if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" {
|
||||||
info, err := provider.OpenIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token))
|
info, err := provider.openIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err)
|
log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err)
|
||||||
return handler.HandleHTTPError(err, c)
|
return handler.HandleHTTPError(err, c)
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
package openid
|
package openid
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -36,7 +35,8 @@ func GetAllProviders() (providers []*Provider, err error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ps, exists, err := keyvalue.Get("openid_providers")
|
providers = []*Provider{}
|
||||||
|
exists, err := keyvalue.GetWithValue("openid_providers", &providers)
|
||||||
if !exists {
|
if !exists {
|
||||||
rawProviders := config.AuthOpenIDProviders.Get()
|
rawProviders := config.AuthOpenIDProviders.Get()
|
||||||
if rawProviders == nil {
|
if rawProviders == nil {
|
||||||
|
@ -68,31 +68,30 @@ func GetAllProviders() (providers []*Provider, err error) {
|
||||||
err = keyvalue.Put("openid_providers", providers)
|
err = keyvalue.Put("openid_providers", providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ps != nil {
|
|
||||||
return ps.([]*Provider), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvider retrieves a provider from keyvalue
|
// GetProvider retrieves a provider from keyvalue
|
||||||
func GetProvider(key string) (provider *Provider, err error) {
|
func GetProvider(key string) (provider *Provider, err error) {
|
||||||
var p interface{}
|
provider = &Provider{}
|
||||||
p, exists, err := keyvalue.Get("openid_provider_" + key)
|
exists, err := keyvalue.GetWithValue("openid_provider_"+key, provider)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if !exists {
|
if !exists {
|
||||||
_, err = GetAllProviders() // This will put all providers in cache
|
_, err = GetAllProviders() // This will put all providers in cache
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p, _, err = keyvalue.Get("openid_provider_" + key)
|
_, err = keyvalue.GetWithValue("openid_provider_"+key, provider)
|
||||||
}
|
if err != nil {
|
||||||
|
|
||||||
if p != nil {
|
|
||||||
return p.(*Provider), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = provider.setOicdProvider()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getKeyFromName(name string) string {
|
func getKeyFromName(name string) string {
|
||||||
|
@ -100,7 +99,7 @@ func getKeyFromName(name string) string {
|
||||||
return reg.ReplaceAllString(strings.ToLower(name), "")
|
return reg.ReplaceAllString(strings.ToLower(name), "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
|
func getProviderFromMap(pi map[interface{}]interface{}) (provider *Provider, err error) {
|
||||||
name, is := pi["name"].(string)
|
name, is := pi["name"].(string)
|
||||||
if !is {
|
if !is {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -108,7 +107,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
|
||||||
|
|
||||||
k := getKeyFromName(name)
|
k := getKeyFromName(name)
|
||||||
|
|
||||||
provider := &Provider{
|
provider = &Provider{
|
||||||
Name: pi["name"].(string),
|
Name: pi["name"].(string),
|
||||||
Key: k,
|
Key: k,
|
||||||
AuthURL: pi["authurl"].(string),
|
AuthURL: pi["authurl"].(string),
|
||||||
|
@ -122,10 +121,9 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
|
||||||
provider.ClientID = pi["clientid"].(string)
|
provider.ClientID = pi["clientid"].(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
err = provider.setOicdProvider()
|
||||||
provider.OpenIDProvider, err = oidc.NewProvider(context.Background(), provider.AuthURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return provider, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
provider.Oauth2Config = &oauth2.Config{
|
provider.Oauth2Config = &oauth2.Config{
|
||||||
|
@ -134,7 +132,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
|
||||||
RedirectURL: config.AuthOpenIDRedirectURL.GetString() + k,
|
RedirectURL: config.AuthOpenIDRedirectURL.GetString() + k,
|
||||||
|
|
||||||
// Discovery returns the OAuth2 endpoints.
|
// Discovery returns the OAuth2 endpoints.
|
||||||
Endpoint: provider.OpenIDProvider.Endpoint(),
|
Endpoint: provider.openIDProvider.Endpoint(),
|
||||||
|
|
||||||
// "openid" is a required scope for OpenID Connect flows.
|
// "openid" is a required scope for OpenID Connect flows.
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
@ -142,5 +140,5 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
|
||||||
|
|
||||||
provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL
|
provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL
|
||||||
|
|
||||||
return provider, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,7 +127,8 @@ func getCacheKey(prefix string, keys ...int64) string {
|
||||||
func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) {
|
func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) {
|
||||||
cacheKey := getCacheKey("full", u.ID)
|
cacheKey := getCacheKey("full", u.ID)
|
||||||
|
|
||||||
a, exists, err := keyvalue.Get(cacheKey)
|
fullSizeAvatar = &image.RGBA64{}
|
||||||
|
exists, err := keyvalue.GetWithValue(cacheKey, fullSizeAvatar)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -145,8 +146,6 @@ func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
fullSizeAvatar = a.(*image.RGBA64)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fullSizeAvatar, nil
|
return fullSizeAvatar, nil
|
||||||
|
@ -156,7 +155,7 @@ func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) {
|
||||||
func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType string, err error) {
|
func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType string, err error) {
|
||||||
cacheKey := getCacheKey("resized", u.ID, size)
|
cacheKey := getCacheKey("resized", u.ID, size)
|
||||||
|
|
||||||
a, exists, err := keyvalue.Get(cacheKey)
|
exists, err := keyvalue.GetWithValue(cacheKey, &avatar)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
@ -180,7 +179,6 @@ func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
avatar = a.([]byte)
|
|
||||||
log.Debugf("Serving initials avatar for user %d and size %d from cache", u.ID, size)
|
log.Debugf("Serving initials avatar for user %d and size %d from cache", u.ID, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,22 +39,17 @@ func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType
|
||||||
|
|
||||||
cacheKey := "avatar_upload_" + strconv.Itoa(int(u.ID))
|
cacheKey := "avatar_upload_" + strconv.Itoa(int(u.ID))
|
||||||
|
|
||||||
ai, exists, err := keyvalue.Get(cacheKey)
|
var cached map[int64][]byte
|
||||||
|
exists, err := keyvalue.GetWithValue(cacheKey, &cached)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var cached map[int64][]byte
|
|
||||||
|
|
||||||
if ai != nil {
|
|
||||||
cached = ai.(map[int64][]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
// Nothing ever cached for this user so we need to create the size map to avoid panics
|
// Nothing ever cached for this user so we need to create the size map to avoid panics
|
||||||
cached = make(map[int64][]byte)
|
cached = make(map[int64][]byte)
|
||||||
} else {
|
} else {
|
||||||
a := ai.(map[int64][]byte)
|
a := cached
|
||||||
if a != nil && a[size] != nil {
|
if a != nil && a[size] != nil {
|
||||||
log.Debugf("Serving uploaded avatar for user %d and size %d from cache.", u.ID, size)
|
log.Debugf("Serving uploaded avatar for user %d and size %d from cache.", u.ID, size)
|
||||||
return a[size], "", nil
|
return a[size], "", nil
|
||||||
|
|
|
@ -122,7 +122,8 @@ func getImageID(fullURL string) string {
|
||||||
// Gets an unsplash photo either from cache or directly from the unsplash api
|
// Gets an unsplash photo either from cache or directly from the unsplash api
|
||||||
func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) {
|
func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) {
|
||||||
|
|
||||||
p, exists, err := keyvalue.Get(cachePrefix + photoID)
|
photo = &Photo{}
|
||||||
|
exists, err := keyvalue.GetWithValue(cachePrefix+photoID, photo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -134,8 +135,6 @@ func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
photo = p.(*Photo)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
type Storage interface {
|
type Storage interface {
|
||||||
Put(key string, value interface{}) (err error)
|
Put(key string, value interface{}) (err error)
|
||||||
Get(key string) (value interface{}, exists bool, err error)
|
Get(key string) (value interface{}, exists bool, err error)
|
||||||
|
GetWithValue(key string, value interface{}) (exists bool, err error)
|
||||||
Del(key string) (err error)
|
Del(key string) (err error)
|
||||||
IncrBy(key string, update int64) (err error)
|
IncrBy(key string, update int64) (err error)
|
||||||
DecrBy(key string, update int64) (err error)
|
DecrBy(key string, update int64) (err error)
|
||||||
|
@ -55,6 +56,10 @@ func Get(key string) (value interface{}, exists bool, err error) {
|
||||||
return store.Get(key)
|
return store.Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetWithValue(key string, value interface{}) (exists bool, err error) {
|
||||||
|
return store.GetWithValue(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
// Del removes a save value from a storage backend
|
// Del removes a save value from a storage backend
|
||||||
func Del(key string) (err error) {
|
func Del(key string) (err error) {
|
||||||
return store.Del(key)
|
return store.Del(key)
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
e "code.vikunja.io/api/pkg/modules/keyvalue/error"
|
e "code.vikunja.io/api/pkg/modules/keyvalue/error"
|
||||||
|
@ -52,6 +53,21 @@ func (s *Storage) Get(key string) (value interface{}, exists bool, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Storage) GetWithValue(key string, value interface{}) (exists bool, err error) {
|
||||||
|
v, exists, err := s.Get(key)
|
||||||
|
if !exists {
|
||||||
|
return exists, err
|
||||||
|
}
|
||||||
|
|
||||||
|
val := reflect.ValueOf(value)
|
||||||
|
if val.Kind() != reflect.Ptr {
|
||||||
|
panic("some: check must be a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
val.Elem().Set(reflect.ValueOf(v))
|
||||||
|
return exists, err
|
||||||
|
}
|
||||||
|
|
||||||
// Del removes a saved value from a memory storage
|
// Del removes a saved value from a memory storage
|
||||||
func (s *Storage) Del(key string) (err error) {
|
func (s *Storage) Del(key string) (err error) {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
|
|
|
@ -17,8 +17,10 @@
|
||||||
package redis
|
package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/gob"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"code.vikunja.io/api/pkg/red"
|
"code.vikunja.io/api/pkg/red"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
|
@ -40,23 +42,61 @@ func NewStorage() *Storage {
|
||||||
|
|
||||||
// Put puts a value into redis
|
// Put puts a value into redis
|
||||||
func (s *Storage) Put(key string, value interface{}) (err error) {
|
func (s *Storage) Put(key string, value interface{}) (err error) {
|
||||||
v, err := json.Marshal(value)
|
|
||||||
|
var v interface{}
|
||||||
|
|
||||||
|
switch value.(type) {
|
||||||
|
case int:
|
||||||
|
v = value
|
||||||
|
case int8:
|
||||||
|
v = value
|
||||||
|
case int16:
|
||||||
|
v = value
|
||||||
|
case int32:
|
||||||
|
v = value
|
||||||
|
case int64:
|
||||||
|
v = value
|
||||||
|
default:
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&buf)
|
||||||
|
err = enc.Encode(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return s.client.Set(context.Background(), key, buf.Bytes(), 0).Err()
|
||||||
|
}
|
||||||
|
|
||||||
return s.client.Set(context.Background(), key, v, 0).Err()
|
return s.client.Set(context.Background(), key, v, 0).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a saved value from redis
|
// Get retrieves a saved value from redis
|
||||||
func (s *Storage) Get(key string) (value interface{}, exists bool, err error) {
|
func (s *Storage) Get(key string) (value interface{}, exists bool, err error) {
|
||||||
|
value, err = s.client.Get(context.Background(), key).Result()
|
||||||
|
if err != nil && errors.Is(err, redis.Nil) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return value, true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) GetWithValue(key string, value interface{}) (exists bool, err error) {
|
||||||
b, err := s.client.Get(context.Background(), key).Bytes()
|
b, err := s.client.Get(context.Background(), key).Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
if errors.Is(err, redis.Nil) {
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = json.Unmarshal(b, value)
|
|
||||||
return
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, err = buf.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dec := gob.NewDecoder(&buf)
|
||||||
|
err = dec.Decode(value)
|
||||||
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Del removed a value from redis
|
// Del removed a value from redis
|
||||||
|
|
Loading…
Reference in a new issue