diff --git a/pkg/routes/rate_limit.go b/pkg/routes/rate_limit.go index 02002731..75993cb7 100644 --- a/pkg/routes/rate_limit.go +++ b/pkg/routes/rate_limit.go @@ -74,29 +74,33 @@ func RateLimit(rateLimiter *limiter.Limiter, rateLimitKind string) echo.Middlewa } } +func createRateLimiter(rate limiter.Rate) *limiter.Limiter { + var store limiter.Store + var err error + switch config.RateLimitStore.GetString() { + case "memory": + store = memory.NewStore() + case "redis": + if !config.RedisEnabled.GetBool() { + log.Fatal("Redis is configured for rate limiting, but not enabled!") + } + store, err = redis.NewStore(red.GetRedis()) + if err != nil { + log.Fatalf("Error while creating rate limit redis store: %s", err) + } + default: + log.Fatalf("Unknown Rate limit store \"%s\"", config.RateLimitStore.GetString()) + } + return limiter.New(store, rate) +} + func setupRateLimit(a *echo.Group, rateLimitKind string) { if config.RateLimitEnabled.GetBool() { rate := limiter.Rate{ Period: config.RateLimitPeriod.GetDuration() * time.Second, Limit: config.RateLimitLimit.GetInt64(), } - var store limiter.Store - var err error - switch config.RateLimitStore.GetString() { - case "memory": - store = memory.NewStore() - case "redis": - if !config.RedisEnabled.GetBool() { - log.Fatal("Redis is configured for rate limiting, but not enabled!") - } - store, err = redis.NewStore(red.GetRedis()) - if err != nil { - log.Fatalf("Error while creating rate limit redis store: %s", err) - } - default: - log.Fatalf("Unknown Rate limit store \"%s\"", config.RateLimitStore.GetString()) - } - rateLimiter := limiter.New(store, rate) + rateLimiter := createRateLimiter(rate) log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period) a.Use(RateLimit(rateLimiter, rateLimitKind)) } diff --git a/pkg/routes/routes.go b/pkg/routes/routes.go index f0bac083..044f3a30 100644 --- a/pkg/routes/routes.go +++ b/pkg/routes/routes.go @@ -52,6 +52,8 @@ import ( "strings" "time" + "github.com/ulule/limiter/v3" + vikunja_file "code.vikunja.io/api/pkg/modules/migration/vikunja-file" "code.vikunja.io/api/pkg/config" @@ -235,17 +237,26 @@ func registerAPIRoutes(a *echo.Group) { // Prometheus endpoint setupMetrics(n) + // Separate route for unauthenticated routes to enable rate limits for it + ur := a.Group("") + rate := limiter.Rate{ + Period: 60 * time.Second, + Limit: 10, + } + rateLimiter := createRateLimiter(rate) + ur.Use(RateLimit(rateLimiter, "ip")) + if config.AuthLocalEnabled.GetBool() { // User stuff - n.POST("/login", apiv1.Login) - n.POST("/register", apiv1.RegisterUser) - n.POST("/user/password/token", apiv1.UserRequestResetPasswordToken) - n.POST("/user/password/reset", apiv1.UserResetPassword) - n.POST("/user/confirm", apiv1.UserConfirmEmail) + ur.POST("/login", apiv1.Login) + ur.POST("/register", apiv1.RegisterUser) + ur.POST("/user/password/token", apiv1.UserRequestResetPasswordToken) + ur.POST("/user/password/reset", apiv1.UserResetPassword) + ur.POST("/user/confirm", apiv1.UserConfirmEmail) } if config.AuthOpenIDEnabled.GetBool() { - n.POST("/auth/openid/:provider/callback", openid.HandleCallback) + ur.POST("/auth/openid/:provider/callback", openid.HandleCallback) } // Testing @@ -261,7 +272,7 @@ func registerAPIRoutes(a *echo.Group) { // Link share auth if config.ServiceEnableLinkSharing.GetBool() { - n.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare) + ur.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare) } // ===== Routes with Authetication =====