diff --git a/pkg/routes/rate_limit.go b/pkg/routes/rate_limit.go index 17558a11..8809f571 100644 --- a/pkg/routes/rate_limit.go +++ b/pkg/routes/rate_limit.go @@ -32,11 +32,11 @@ import ( ) // RateLimit is the rate limit middleware -func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc { +func RateLimit(rateLimiter *limiter.Limiter, rateLimitKind string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { var rateLimitKey string - switch config.RateLimitKind.GetString() { + switch rateLimitKind { case "ip": rateLimitKey = c.RealIP() case "user": @@ -46,7 +46,7 @@ func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc { } rateLimitKey = "user_" + strconv.FormatInt(auth.GetID(), 10) default: - log.Errorf("Unknown rate limit kind configured: %s", config.RateLimitKind.GetString()) + log.Errorf("Unknown rate limit kind configured: %s", rateLimitKind) } limiterCtx, err := rateLimiter.Get(c.Request().Context(), rateLimitKey) if err != nil { @@ -74,7 +74,7 @@ func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc { } } -func setupRateLimit(a *echo.Group) { +func setupRateLimit(a *echo.Group, rateLimitKind string) { if config.RateLimitEnabled.GetBool() { rate := limiter.Rate{ Period: config.RateLimitPeriod.GetDuration() * time.Second, @@ -98,6 +98,6 @@ func setupRateLimit(a *echo.Group) { } rateLimiter := limiter.New(store, rate) log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period) - a.Use(RateLimit(rateLimiter)) + a.Use(RateLimit(rateLimiter, rateLimitKind)) } } diff --git a/pkg/routes/routes.go b/pkg/routes/routes.go index aab84105..1172d410 100644 --- a/pkg/routes/routes.go +++ b/pkg/routes/routes.go @@ -159,26 +159,31 @@ func RegisterRoutes(e *echo.Echo) { func registerAPIRoutes(a *echo.Group) { + // This is the group with no auth + // It is its own group to be able to rate limit this based on different heuristics + n := a.Group("") + setupRateLimit(n, "ip") + // Docs - a.GET("/docs.json", apiv1.DocsJSON) - a.GET("/docs", apiv1.RedocUI) + n.GET("/docs.json", apiv1.DocsJSON) + n.GET("/docs", apiv1.RedocUI) // Prometheus endpoint - setupMetrics(a) + setupMetrics(n) // User stuff - a.POST("/login", apiv1.Login) - a.POST("/register", apiv1.RegisterUser) - a.POST("/user/password/token", apiv1.UserRequestResetPasswordToken) - a.POST("/user/password/reset", apiv1.UserResetPassword) - a.POST("/user/confirm", apiv1.UserConfirmEmail) + n.POST("/login", apiv1.Login) + n.POST("/register", apiv1.RegisterUser) + n.POST("/user/password/token", apiv1.UserRequestResetPasswordToken) + n.POST("/user/password/reset", apiv1.UserResetPassword) + n.POST("/user/confirm", apiv1.UserConfirmEmail) // Info endpoint - a.GET("/info", apiv1.Info) + n.GET("/info", apiv1.Info) // Link share auth if config.ServiceEnableLinkSharing.GetBool() { - a.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare) + n.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare) } // ===== Routes with Authetication ===== @@ -186,7 +191,7 @@ func registerAPIRoutes(a *echo.Group) { a.Use(middleware.JWT([]byte(config.ServiceJWTSecret.GetString()))) // Rate limit - setupRateLimit(a) + setupRateLimit(a, config.RateLimitKind.GetString()) // Middleware to collect metrics setupMetricsMiddleware(a)