diff --git a/sub/sub.go b/sub/sub.go index be541ed2..b642f7f2 100644 --- a/sub/sub.go +++ b/sub/sub.go @@ -7,10 +7,10 @@ import ( "net" "net/http" "strconv" - "strings" "x-ui/config" "x-ui/logger" "x-ui/util/common" + "x-ui/web/middleware" "x-ui/web/network" "x-ui/web/service" @@ -58,18 +58,7 @@ func (s *Server) initRouter() (*gin.Engine, error) { } if subDomain != "" { - validateDomain := func(c *gin.Context) { - host := strings.Split(c.Request.Host, ":")[0] - - if host != subDomain { - c.AbortWithStatus(http.StatusForbidden) - return - } - - c.Next() - } - - engine.Use(validateDomain) + engine.Use(middleware.DomainValidatorMiddleware(subDomain)) } g := engine.Group(subPath) @@ -116,11 +105,13 @@ func (s *Server) Start() (err error) { if err != nil { return err } + listenAddr := net.JoinHostPort(listen, strconv.Itoa(port)) listener, err := net.Listen("tcp", listenAddr) if err != nil { return err } + if certFile != "" || keyFile != "" { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { diff --git a/web/web.go b/web/web.go index a8c9a533..02e858bc 100644 --- a/web/web.go +++ b/web/web.go @@ -19,6 +19,7 @@ import ( "x-ui/web/controller" "x-ui/web/job" "x-ui/web/locale" + "x-ui/web/middleware" "x-ui/web/network" "x-ui/web/service" @@ -155,6 +156,15 @@ func (s *Server) initRouter() (*gin.Engine, error) { engine := gin.Default() + webDomain, err := s.settingService.GetWebDomain() + if err != nil { + return nil, err + } + + if webDomain != "" { + engine.Use(middleware.DomainValidatorMiddleware(webDomain)) + } + secret, err := s.settingService.GetSecret() if err != nil { return nil, err