Add rate limit by ip for non-authenticated routes (#127)

Add rate limit by ip for non-authenticated routes

Signed-off-by: kolaente <k@knt.li>

Co-authored-by: kolaente <k@knt.li>
Reviewed-on: https://kolaente.dev/vikunja/api/pulls/127
This commit is contained in:
konrad 2020-01-26 19:53:47 +00:00
parent a464d1760c
commit 2abb858859
2 changed files with 21 additions and 16 deletions

View file

@ -32,11 +32,11 @@ import (
) )
// RateLimit is the rate limit middleware // RateLimit is the rate limit middleware
func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc { func RateLimit(rateLimiter *limiter.Limiter, rateLimitKind string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) (err error) {
var rateLimitKey string var rateLimitKey string
switch config.RateLimitKind.GetString() { switch rateLimitKind {
case "ip": case "ip":
rateLimitKey = c.RealIP() rateLimitKey = c.RealIP()
case "user": case "user":
@ -46,7 +46,7 @@ func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc {
} }
rateLimitKey = "user_" + strconv.FormatInt(auth.GetID(), 10) rateLimitKey = "user_" + strconv.FormatInt(auth.GetID(), 10)
default: default:
log.Errorf("Unknown rate limit kind configured: %s", config.RateLimitKind.GetString()) log.Errorf("Unknown rate limit kind configured: %s", rateLimitKind)
} }
limiterCtx, err := rateLimiter.Get(c.Request().Context(), rateLimitKey) limiterCtx, err := rateLimiter.Get(c.Request().Context(), rateLimitKey)
if err != nil { if err != nil {
@ -74,7 +74,7 @@ func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc {
} }
} }
func setupRateLimit(a *echo.Group) { func setupRateLimit(a *echo.Group, rateLimitKind string) {
if config.RateLimitEnabled.GetBool() { if config.RateLimitEnabled.GetBool() {
rate := limiter.Rate{ rate := limiter.Rate{
Period: config.RateLimitPeriod.GetDuration() * time.Second, Period: config.RateLimitPeriod.GetDuration() * time.Second,
@ -98,6 +98,6 @@ func setupRateLimit(a *echo.Group) {
} }
rateLimiter := limiter.New(store, rate) rateLimiter := limiter.New(store, rate)
log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period) log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period)
a.Use(RateLimit(rateLimiter)) a.Use(RateLimit(rateLimiter, rateLimitKind))
} }
} }

View file

@ -159,26 +159,31 @@ func RegisterRoutes(e *echo.Echo) {
func registerAPIRoutes(a *echo.Group) { func registerAPIRoutes(a *echo.Group) {
// This is the group with no auth
// It is its own group to be able to rate limit this based on different heuristics
n := a.Group("")
setupRateLimit(n, "ip")
// Docs // Docs
a.GET("/docs.json", apiv1.DocsJSON) n.GET("/docs.json", apiv1.DocsJSON)
a.GET("/docs", apiv1.RedocUI) n.GET("/docs", apiv1.RedocUI)
// Prometheus endpoint // Prometheus endpoint
setupMetrics(a) setupMetrics(n)
// User stuff // User stuff
a.POST("/login", apiv1.Login) n.POST("/login", apiv1.Login)
a.POST("/register", apiv1.RegisterUser) n.POST("/register", apiv1.RegisterUser)
a.POST("/user/password/token", apiv1.UserRequestResetPasswordToken) n.POST("/user/password/token", apiv1.UserRequestResetPasswordToken)
a.POST("/user/password/reset", apiv1.UserResetPassword) n.POST("/user/password/reset", apiv1.UserResetPassword)
a.POST("/user/confirm", apiv1.UserConfirmEmail) n.POST("/user/confirm", apiv1.UserConfirmEmail)
// Info endpoint // Info endpoint
a.GET("/info", apiv1.Info) n.GET("/info", apiv1.Info)
// Link share auth // Link share auth
if config.ServiceEnableLinkSharing.GetBool() { if config.ServiceEnableLinkSharing.GetBool() {
a.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare) n.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare)
} }
// ===== Routes with Authetication ===== // ===== Routes with Authetication =====
@ -186,7 +191,7 @@ func registerAPIRoutes(a *echo.Group) {
a.Use(middleware.JWT([]byte(config.ServiceJWTSecret.GetString()))) a.Use(middleware.JWT([]byte(config.ServiceJWTSecret.GetString())))
// Rate limit // Rate limit
setupRateLimit(a) setupRateLimit(a, config.RateLimitKind.GetString())
// Middleware to collect metrics // Middleware to collect metrics
setupMetricsMiddleware(a) setupMetricsMiddleware(a)