feat: enable rate limit for unauthenticated routes

This commit is contained in:
kolaente 2021-11-14 20:42:33 +01:00
parent da2d5e41c7
commit 093d0c65ca
No known key found for this signature in database
GPG key ID: F40E70337AB24C9B
2 changed files with 39 additions and 24 deletions

View file

@ -74,12 +74,7 @@ func RateLimit(rateLimiter *limiter.Limiter, rateLimitKind string) echo.Middlewa
} }
} }
func setupRateLimit(a *echo.Group, rateLimitKind string) { func createRateLimiter(rate limiter.Rate) *limiter.Limiter {
if config.RateLimitEnabled.GetBool() {
rate := limiter.Rate{
Period: config.RateLimitPeriod.GetDuration() * time.Second,
Limit: config.RateLimitLimit.GetInt64(),
}
var store limiter.Store var store limiter.Store
var err error var err error
switch config.RateLimitStore.GetString() { switch config.RateLimitStore.GetString() {
@ -96,7 +91,16 @@ func setupRateLimit(a *echo.Group, rateLimitKind string) {
default: default:
log.Fatalf("Unknown Rate limit store \"%s\"", config.RateLimitStore.GetString()) log.Fatalf("Unknown Rate limit store \"%s\"", config.RateLimitStore.GetString())
} }
rateLimiter := limiter.New(store, rate) return limiter.New(store, rate)
}
func setupRateLimit(a *echo.Group, rateLimitKind string) {
if config.RateLimitEnabled.GetBool() {
rate := limiter.Rate{
Period: config.RateLimitPeriod.GetDuration() * time.Second,
Limit: config.RateLimitLimit.GetInt64(),
}
rateLimiter := createRateLimiter(rate)
log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period) log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period)
a.Use(RateLimit(rateLimiter, rateLimitKind)) a.Use(RateLimit(rateLimiter, rateLimitKind))
} }

View file

@ -52,6 +52,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/ulule/limiter/v3"
vikunja_file "code.vikunja.io/api/pkg/modules/migration/vikunja-file" vikunja_file "code.vikunja.io/api/pkg/modules/migration/vikunja-file"
"code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/config"
@ -235,17 +237,26 @@ func registerAPIRoutes(a *echo.Group) {
// Prometheus endpoint // Prometheus endpoint
setupMetrics(n) setupMetrics(n)
// Separate route for unauthenticated routes to enable rate limits for it
ur := a.Group("")
rate := limiter.Rate{
Period: 60 * time.Second,
Limit: 10,
}
rateLimiter := createRateLimiter(rate)
ur.Use(RateLimit(rateLimiter, "ip"))
if config.AuthLocalEnabled.GetBool() { if config.AuthLocalEnabled.GetBool() {
// User stuff // User stuff
n.POST("/login", apiv1.Login) ur.POST("/login", apiv1.Login)
n.POST("/register", apiv1.RegisterUser) ur.POST("/register", apiv1.RegisterUser)
n.POST("/user/password/token", apiv1.UserRequestResetPasswordToken) ur.POST("/user/password/token", apiv1.UserRequestResetPasswordToken)
n.POST("/user/password/reset", apiv1.UserResetPassword) ur.POST("/user/password/reset", apiv1.UserResetPassword)
n.POST("/user/confirm", apiv1.UserConfirmEmail) ur.POST("/user/confirm", apiv1.UserConfirmEmail)
} }
if config.AuthOpenIDEnabled.GetBool() { if config.AuthOpenIDEnabled.GetBool() {
n.POST("/auth/openid/:provider/callback", openid.HandleCallback) ur.POST("/auth/openid/:provider/callback", openid.HandleCallback)
} }
// Testing // Testing
@ -261,7 +272,7 @@ func registerAPIRoutes(a *echo.Group) {
// Link share auth // Link share auth
if config.ServiceEnableLinkSharing.GetBool() { if config.ServiceEnableLinkSharing.GetBool() {
n.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare) ur.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare)
} }
// ===== Routes with Authetication ===== // ===== Routes with Authetication =====