diff --git a/web/service/inbound.go b/web/service/inbound.go index b00473e2..b116514d 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -19,6 +19,10 @@ type InboundService struct { xrayApi xray.XrayAPI } +const ( + safeBatchSize = 500 +) + func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) { db := database.GetDB() var inbounds []*model.Inbound @@ -789,6 +793,11 @@ func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic return nil } +type newExpiryTime struct { + Email string + NewExpiryTime int64 +} + func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) { if len(traffics) == 0 { // Empty onlineUsers @@ -805,9 +814,18 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr emails = append(emails, traffic.Email) } dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics)) - err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error - if err != nil { - return err + for i := 0; i < len(emails); i += safeBatchSize { + end := i + safeBatchSize + if end > len(emails) { + end = len(emails) + } + + batchClientTraffics := make([]*xray.ClientTraffic, 0, end-i) + err = tx.Model(xray.ClientTraffic{}).Where("email IN ?", emails[i:end]).Find(&batchClientTraffics).Error + if err != nil { + return err + } + dbClientTraffics = append(dbClientTraffics, batchClientTraffics...) } // Avoid empty slice error @@ -815,7 +833,20 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr return nil } - dbClientTraffics, err = s.adjustTraffics(tx, dbClientTraffics) + inboundExpiryTimeMap := make(map[int][]newExpiryTime, 0) + for index, t := range dbClientTraffics { + if t.ExpiryTime < 0 { + newClientExpiryTime := (time.Now().Unix() * 1000) - int64(t.ExpiryTime) + newExpiryTime := newExpiryTime{ + Email: t.Email, + NewExpiryTime: newClientExpiryTime, + } + inboundExpiryTimeMap[t.InboundId] = append(inboundExpiryTimeMap[t.InboundId], newExpiryTime) + dbClientTraffics[index].ExpiryTime = newClientExpiryTime + } + } + + err = s.adjustTraffics(tx, inboundExpiryTimeMap) if err != nil { return err } @@ -836,66 +867,90 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr } // Set onlineUsers - p.SetOnlineClients(onlineClients) + if p != nil { + p.SetOnlineClients(onlineClients) + } - err = tx.Save(dbClientTraffics).Error - if err != nil { - logger.Warning("AddClientTraffic update data ", err) + for i := 0; i < len(dbClientTraffics); i += safeBatchSize { + end := i + safeBatchSize + if end > len(dbClientTraffics) { + end = len(dbClientTraffics) + } + + err = tx.Save(dbClientTraffics[i:end]).Error + if err != nil { + logger.Warning("AddClientTraffic update data ", err) + return nil + } } return nil } -func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.ClientTraffic) ([]*xray.ClientTraffic, error) { - inboundIds := make([]int, 0, len(dbClientTraffics)) - for _, dbClientTraffic := range dbClientTraffics { - if dbClientTraffic.ExpiryTime < 0 { - inboundIds = append(inboundIds, dbClientTraffic.InboundId) +func (s *InboundService) adjustTraffics(tx *gorm.DB, inboundExpiryTimeMap map[int][]newExpiryTime) error { + if len(inboundExpiryTimeMap) == 0 { + return nil + } + + inboundIds := make([]int, 0) + for inId := range inboundExpiryTimeMap { + inboundIds = append(inboundIds, inId) + } + inbounds := make([]*model.Inbound, 0, len(inboundIds)) + for i := 0; i < len(inboundIds); i += safeBatchSize { + end := i + safeBatchSize + if end > len(inboundIds) { + end = len(inboundIds) + } + + batchInbounds := make([]*model.Inbound, 0, end-i) + err := tx.Model(model.Inbound{}).Where("id IN ?", inboundIds[i:end]).Find(&batchInbounds).Error + if err != nil { + return err + } + inbounds = append(inbounds, batchInbounds...) + } + for inbound_index := range inbounds { + settings := map[string]interface{}{} + json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings) + clients, ok := settings["clients"].([]interface{}) + inbEmails := inboundExpiryTimeMap[inbounds[inbound_index].Id] + if ok { + var newClients []interface{} + for client_index := range clients { + c := clients[client_index].(map[string]interface{}) + for index := range inbEmails { + if c["email"] == inbEmails[index].Email { + c["expiryTime"] = inbEmails[index].NewExpiryTime + break + } + } + newClients = append(newClients, interface{}(c)) + } + settings["clients"] = newClients + modifiedSettings, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return err + } + + inbounds[inbound_index].Settings = string(modifiedSettings) } } - if len(inboundIds) > 0 { - var inbounds []*model.Inbound - err := tx.Model(model.Inbound{}).Where("id IN (?)", inboundIds).Find(&inbounds).Error - if err != nil { - return nil, err + for i := 0; i < len(inbounds); i += safeBatchSize { + end := i + safeBatchSize + if end > len(inbounds) { + end = len(inbounds) } - for inbound_index := range inbounds { - settings := map[string]interface{}{} - json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings) - clients, ok := settings["clients"].([]interface{}) - if ok { - var newClients []interface{} - for client_index := range clients { - c := clients[client_index].(map[string]interface{}) - for traffic_index := range dbClientTraffics { - if dbClientTraffics[traffic_index].ExpiryTime < 0 && c["email"] == dbClientTraffics[traffic_index].Email { - oldExpiryTime := c["expiryTime"].(float64) - newExpiryTime := (time.Now().Unix() * 1000) - int64(oldExpiryTime) - c["expiryTime"] = newExpiryTime - dbClientTraffics[traffic_index].ExpiryTime = newExpiryTime - break - } - } - newClients = append(newClients, interface{}(c)) - } - settings["clients"] = newClients - modifiedSettings, err := json.MarshalIndent(settings, "", " ") - if err != nil { - return nil, err - } - inbounds[inbound_index].Settings = string(modifiedSettings) - } - } - err = tx.Save(inbounds).Error + err := tx.Save(inbounds[i:end]).Error if err != nil { logger.Warning("AddClientTraffic update inbounds ", err) - logger.Error(inbounds) + break } } - return dbClientTraffics, nil + return nil } func (s *InboundService) autoRenewClients(tx *gorm.DB) (bool, int64, error) { diff --git a/web/service/inbound_add_client_traffic_test.go b/web/service/inbound_add_client_traffic_test.go new file mode 100644 index 00000000..dbfc7000 --- /dev/null +++ b/web/service/inbound_add_client_traffic_test.go @@ -0,0 +1,74 @@ +package service + +import ( + "fmt" + "testing" + + "github.com/alireza0/x-ui/xray" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestAddClientTrafficHandlesLargeBatch(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("open db: %v", err) + } + + if err := db.AutoMigrate(&xray.ClientTraffic{}); err != nil { + t.Fatalf("auto migrate: %v", err) + } + + const clientCount = 4000 + + seed := make([]*xray.ClientTraffic, 0, clientCount) + updates := make([]*xray.ClientTraffic, 0, clientCount) + for i := 0; i < clientCount; i++ { + email := fmt.Sprintf("user-%05d@example.com", i) + seed = append(seed, &xray.ClientTraffic{ + InboundId: 1, + Enable: true, + Email: email, + }) + updates = append(updates, &xray.ClientTraffic{ + Email: email, + Up: 1, + Down: 2, + }) + } + + if err := db.CreateInBatches(seed, 100).Error; err != nil { + t.Fatalf("seed traffic: %v", err) + } + + tx := db.Begin() + if tx.Error != nil { + t.Fatalf("begin tx: %v", tx.Error) + } + + s := &InboundService{} + if err := s.addClientTraffic(tx, updates); err != nil { + tx.Rollback() + t.Fatalf("add client traffic: %v", err) + } + if err := tx.Commit().Error; err != nil { + t.Fatalf("commit tx: %v", err) + } + + var totalUp int64 + if err := db.Model(&xray.ClientTraffic{}).Select("COALESCE(SUM(up), 0)").Scan(&totalUp).Error; err != nil { + t.Fatalf("sum up: %v", err) + } + + var totalDown int64 + if err := db.Model(&xray.ClientTraffic{}).Select("COALESCE(SUM(down), 0)").Scan(&totalDown).Error; err != nil { + t.Fatalf("sum down: %v", err) + } + + if totalUp != clientCount { + t.Fatalf("unexpected total up: got %d want %d", totalUp, clientCount) + } + if totalDown != 2*clientCount { + t.Fatalf("unexpected total down: got %d want %d", totalDown, 2*clientCount) + } +}