From 1329b2bb61daf5310229107d146557f4441d2d7f Mon Sep 17 00:00:00 2001 From: Alireza Ahmadi Date: Thu, 22 Feb 2024 14:18:31 +0100 Subject: [PATCH] support multi address-port #997 --- database/model/model.go | 2 +- web/controller/inbound.go | 12 ++++++++++-- web/service/inbound.go | 32 +++++++++++++++++++++++++++----- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/database/model/model.go b/database/model/model.go index e2559d40..f7cdf322 100644 --- a/database/model/model.go +++ b/database/model/model.go @@ -36,7 +36,7 @@ type Inbound struct { // config part Listen string `json:"listen" form:"listen"` - Port int `json:"port" form:"port" gorm:"unique"` + Port int `json:"port" form:"port"` Protocol Protocol `json:"protocol" form:"protocol"` Settings string `json:"settings" form:"settings"` StreamSettings string `json:"streamSettings" form:"streamSettings"` diff --git a/web/controller/inbound.go b/web/controller/inbound.go index 02573bd7..4be7b454 100644 --- a/web/controller/inbound.go +++ b/web/controller/inbound.go @@ -82,7 +82,11 @@ func (a *InboundController) addInbound(c *gin.Context) { } user := session.GetLoginUser(c) inbound.UserId = user.Id - inbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) + if inbound.Listen == "" || inbound.Listen == "0.0.0.0" || inbound.Listen == "::" || inbound.Listen == "::0" { + inbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) + } else { + inbound.Tag = fmt.Sprintf("inbound-%v:%v", inbound.Listen, inbound.Port) + } needRestart := false inbound, needRestart, err = a.inboundService.AddInbound(inbound) @@ -266,7 +270,11 @@ func (a *InboundController) importInbound(c *gin.Context) { user := session.GetLoginUser(c) inbound.Id = 0 inbound.UserId = user.Id - inbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) + if inbound.Listen == "" || inbound.Listen == "0.0.0.0" || inbound.Listen == "::" || inbound.Listen == "::0" { + inbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) + } else { + inbound.Tag = fmt.Sprintf("inbound-%v:%v", inbound.Listen, inbound.Port) + } for index := range inbound.ClientStats { inbound.ClientStats[index].Id = 0 diff --git a/web/service/inbound.go b/web/service/inbound.go index 0bedaf1b..1acb2523 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -38,9 +38,25 @@ func (s *InboundService) GetAllInbounds() ([]*model.Inbound, error) { return inbounds, nil } -func (s *InboundService) checkPortExist(port int, ignoreId int) (bool, error) { +func (s *InboundService) checkPortExist(listen string, port int, ignoreId int) (bool, error) { db := database.GetDB() - db = db.Model(model.Inbound{}).Where("port = ?", port) + if listen == "" || listen == "0.0.0.0" || listen == "::" || listen == "::0" { + db = db.Model(model.Inbound{}).Where("port = ?", port) + } else { + db = db.Model(model.Inbound{}). + Where("port = ?", port). + Where( + db.Model(model.Inbound{}).Where( + "listen = ?", listen, + ).Or( + "listen = \"\"", + ).Or( + "listen = \"0.0.0.0\"", + ).Or( + "listen = \"::\"", + ).Or( + "listen = \"::0\"")) + } if ignoreId > 0 { db = db.Where("id != ?", ignoreId) } @@ -135,7 +151,7 @@ func (s *InboundService) checkEmailExistForInbound(inbound *model.Inbound) (stri } func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, bool, error) { - exist, err := s.checkPortExist(inbound.Port, 0) + exist, err := s.checkPortExist(inbound.Listen, inbound.Port, 0) if err != nil { return inbound, false, err } @@ -238,7 +254,7 @@ func (s *InboundService) GetInbound(id int) (*model.Inbound, error) { } func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, bool, error) { - exist, err := s.checkPortExist(inbound.Port, inbound.Id) + exist, err := s.checkPortExist(inbound.Listen, inbound.Port, inbound.Id) if err != nil { return inbound, false, err } @@ -281,7 +297,11 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, oldInbound.Settings = inbound.Settings oldInbound.StreamSettings = inbound.StreamSettings oldInbound.Sniffing = inbound.Sniffing - oldInbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) + if inbound.Listen == "" || inbound.Listen == "0.0.0.0" || inbound.Listen == "::" || inbound.Listen == "::0" { + oldInbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) + } else { + oldInbound.Tag = fmt.Sprintf("inbound-%v:%v", inbound.Listen, inbound.Port) + } needRestart := false s.xrayApi.Init(p.GetAPIPort()) @@ -1309,6 +1329,8 @@ func (s *InboundService) MigrationRequirements() { } }() + db.Migrator().DropIndex(&model.Inbound{}, "port") + // Fix inbounds based problems var inbounds []*model.Inbound err = tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error