Add rate limit by ip for non-authenticated routes (#127)
Add rate limit by ip for non-authenticated routes Signed-off-by: kolaente <k@knt.li> Co-authored-by: kolaente <k@knt.li> Reviewed-on: https://kolaente.dev/vikunja/api/pulls/127
This commit is contained in:
parent
a464d1760c
commit
2abb858859
2 changed files with 21 additions and 16 deletions
|
@ -32,11 +32,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// RateLimit is the rate limit middleware
|
// RateLimit is the rate limit middleware
|
||||||
func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc {
|
func RateLimit(rateLimiter *limiter.Limiter, rateLimitKind string) echo.MiddlewareFunc {
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) (err error) {
|
return func(c echo.Context) (err error) {
|
||||||
var rateLimitKey string
|
var rateLimitKey string
|
||||||
switch config.RateLimitKind.GetString() {
|
switch rateLimitKind {
|
||||||
case "ip":
|
case "ip":
|
||||||
rateLimitKey = c.RealIP()
|
rateLimitKey = c.RealIP()
|
||||||
case "user":
|
case "user":
|
||||||
|
@ -46,7 +46,7 @@ func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc {
|
||||||
}
|
}
|
||||||
rateLimitKey = "user_" + strconv.FormatInt(auth.GetID(), 10)
|
rateLimitKey = "user_" + strconv.FormatInt(auth.GetID(), 10)
|
||||||
default:
|
default:
|
||||||
log.Errorf("Unknown rate limit kind configured: %s", config.RateLimitKind.GetString())
|
log.Errorf("Unknown rate limit kind configured: %s", rateLimitKind)
|
||||||
}
|
}
|
||||||
limiterCtx, err := rateLimiter.Get(c.Request().Context(), rateLimitKey)
|
limiterCtx, err := rateLimiter.Get(c.Request().Context(), rateLimitKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -74,7 +74,7 @@ func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupRateLimit(a *echo.Group) {
|
func setupRateLimit(a *echo.Group, rateLimitKind string) {
|
||||||
if config.RateLimitEnabled.GetBool() {
|
if config.RateLimitEnabled.GetBool() {
|
||||||
rate := limiter.Rate{
|
rate := limiter.Rate{
|
||||||
Period: config.RateLimitPeriod.GetDuration() * time.Second,
|
Period: config.RateLimitPeriod.GetDuration() * time.Second,
|
||||||
|
@ -98,6 +98,6 @@ func setupRateLimit(a *echo.Group) {
|
||||||
}
|
}
|
||||||
rateLimiter := limiter.New(store, rate)
|
rateLimiter := limiter.New(store, rate)
|
||||||
log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period)
|
log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period)
|
||||||
a.Use(RateLimit(rateLimiter))
|
a.Use(RateLimit(rateLimiter, rateLimitKind))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -159,26 +159,31 @@ func RegisterRoutes(e *echo.Echo) {
|
||||||
|
|
||||||
func registerAPIRoutes(a *echo.Group) {
|
func registerAPIRoutes(a *echo.Group) {
|
||||||
|
|
||||||
|
// This is the group with no auth
|
||||||
|
// It is its own group to be able to rate limit this based on different heuristics
|
||||||
|
n := a.Group("")
|
||||||
|
setupRateLimit(n, "ip")
|
||||||
|
|
||||||
// Docs
|
// Docs
|
||||||
a.GET("/docs.json", apiv1.DocsJSON)
|
n.GET("/docs.json", apiv1.DocsJSON)
|
||||||
a.GET("/docs", apiv1.RedocUI)
|
n.GET("/docs", apiv1.RedocUI)
|
||||||
|
|
||||||
// Prometheus endpoint
|
// Prometheus endpoint
|
||||||
setupMetrics(a)
|
setupMetrics(n)
|
||||||
|
|
||||||
// User stuff
|
// User stuff
|
||||||
a.POST("/login", apiv1.Login)
|
n.POST("/login", apiv1.Login)
|
||||||
a.POST("/register", apiv1.RegisterUser)
|
n.POST("/register", apiv1.RegisterUser)
|
||||||
a.POST("/user/password/token", apiv1.UserRequestResetPasswordToken)
|
n.POST("/user/password/token", apiv1.UserRequestResetPasswordToken)
|
||||||
a.POST("/user/password/reset", apiv1.UserResetPassword)
|
n.POST("/user/password/reset", apiv1.UserResetPassword)
|
||||||
a.POST("/user/confirm", apiv1.UserConfirmEmail)
|
n.POST("/user/confirm", apiv1.UserConfirmEmail)
|
||||||
|
|
||||||
// Info endpoint
|
// Info endpoint
|
||||||
a.GET("/info", apiv1.Info)
|
n.GET("/info", apiv1.Info)
|
||||||
|
|
||||||
// Link share auth
|
// Link share auth
|
||||||
if config.ServiceEnableLinkSharing.GetBool() {
|
if config.ServiceEnableLinkSharing.GetBool() {
|
||||||
a.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare)
|
n.POST("/shares/:share/auth", apiv1.AuthenticateLinkShare)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===== Routes with Authetication =====
|
// ===== Routes with Authetication =====
|
||||||
|
@ -186,7 +191,7 @@ func registerAPIRoutes(a *echo.Group) {
|
||||||
a.Use(middleware.JWT([]byte(config.ServiceJWTSecret.GetString())))
|
a.Use(middleware.JWT([]byte(config.ServiceJWTSecret.GetString())))
|
||||||
|
|
||||||
// Rate limit
|
// Rate limit
|
||||||
setupRateLimit(a)
|
setupRateLimit(a, config.RateLimitKind.GetString())
|
||||||
|
|
||||||
// Middleware to collect metrics
|
// Middleware to collect metrics
|
||||||
setupMetricsMiddleware(a)
|
setupMetricsMiddleware(a)
|
||||||
|
|
Loading…
Reference in a new issue