diff --git a/web/service/inbound.go b/web/service/inbound.go index b00473e2..02311cdb 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -19,6 +19,13 @@ type InboundService struct { xrayApi xray.XrayAPI } +const ( + // Keep query variables below SQLite's classic 999 limit. + safeSQLVariablesPerQuery = 900 + // Save in small chunks so row-column placeholders stay under SQL var limits. + safeSaveBatchSize = 50 +) + func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) { db := database.GetDB() var inbounds []*model.Inbound @@ -805,9 +812,18 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr emails = append(emails, traffic.Email) } dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics)) - err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error - if err != nil { - return err + for start := 0; start < len(emails); start += safeSQLVariablesPerQuery { + end := start + safeSQLVariablesPerQuery + if end > len(emails) { + end = len(emails) + } + + batchClientTraffics := make([]*xray.ClientTraffic, 0, end-start) + err = tx.Model(xray.ClientTraffic{}).Where("email IN ?", emails[start:end]).Find(&batchClientTraffics).Error + if err != nil { + return err + } + dbClientTraffics = append(dbClientTraffics, batchClientTraffics...) } // Avoid empty slice error @@ -836,11 +852,21 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr } // Set onlineUsers - p.SetOnlineClients(onlineClients) + if p != nil { + p.SetOnlineClients(onlineClients) + } - err = tx.Save(dbClientTraffics).Error - if err != nil { - logger.Warning("AddClientTraffic update data ", err) + for start := 0; start < len(dbClientTraffics); start += safeSaveBatchSize { + end := start + safeSaveBatchSize + if end > len(dbClientTraffics) { + end = len(dbClientTraffics) + } + + err = tx.Save(dbClientTraffics[start:end]).Error + if err != nil { + logger.Warning("AddClientTraffic update data ", err) + return nil + } } return nil @@ -848,17 +874,31 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.ClientTraffic) ([]*xray.ClientTraffic, error) { inboundIds := make([]int, 0, len(dbClientTraffics)) + seenInboundIds := make(map[int]struct{}, len(dbClientTraffics)) for _, dbClientTraffic := range dbClientTraffics { if dbClientTraffic.ExpiryTime < 0 { + if _, seen := seenInboundIds[dbClientTraffic.InboundId]; seen { + continue + } inboundIds = append(inboundIds, dbClientTraffic.InboundId) + seenInboundIds[dbClientTraffic.InboundId] = struct{}{} } } if len(inboundIds) > 0 { - var inbounds []*model.Inbound - err := tx.Model(model.Inbound{}).Where("id IN (?)", inboundIds).Find(&inbounds).Error - if err != nil { - return nil, err + inbounds := make([]*model.Inbound, 0, len(inboundIds)) + for start := 0; start < len(inboundIds); start += safeSQLVariablesPerQuery { + end := start + safeSQLVariablesPerQuery + if end > len(inboundIds) { + end = len(inboundIds) + } + + batchInbounds := make([]*model.Inbound, 0, end-start) + err := tx.Model(model.Inbound{}).Where("id IN ?", inboundIds[start:end]).Find(&batchInbounds).Error + if err != nil { + return nil, err + } + inbounds = append(inbounds, batchInbounds...) } for inbound_index := range inbounds { settings := map[string]interface{}{} @@ -888,10 +928,19 @@ func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.Cl inbounds[inbound_index].Settings = string(modifiedSettings) } } - err = tx.Save(inbounds).Error - if err != nil { - logger.Warning("AddClientTraffic update inbounds ", err) - logger.Error(inbounds) + + for start := 0; start < len(inbounds); start += safeSaveBatchSize { + end := start + safeSaveBatchSize + if end > len(inbounds) { + end = len(inbounds) + } + + err := tx.Save(inbounds[start:end]).Error + if err != nil { + logger.Warning("AddClientTraffic update inbounds ", err) + logger.Error(inbounds[start:end]) + break + } } } diff --git a/web/service/inbound_add_client_traffic_test.go b/web/service/inbound_add_client_traffic_test.go new file mode 100644 index 00000000..dbfc7000 --- /dev/null +++ b/web/service/inbound_add_client_traffic_test.go @@ -0,0 +1,74 @@ +package service + +import ( + "fmt" + "testing" + + "github.com/alireza0/x-ui/xray" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestAddClientTrafficHandlesLargeBatch(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("open db: %v", err) + } + + if err := db.AutoMigrate(&xray.ClientTraffic{}); err != nil { + t.Fatalf("auto migrate: %v", err) + } + + const clientCount = 4000 + + seed := make([]*xray.ClientTraffic, 0, clientCount) + updates := make([]*xray.ClientTraffic, 0, clientCount) + for i := 0; i < clientCount; i++ { + email := fmt.Sprintf("user-%05d@example.com", i) + seed = append(seed, &xray.ClientTraffic{ + InboundId: 1, + Enable: true, + Email: email, + }) + updates = append(updates, &xray.ClientTraffic{ + Email: email, + Up: 1, + Down: 2, + }) + } + + if err := db.CreateInBatches(seed, 100).Error; err != nil { + t.Fatalf("seed traffic: %v", err) + } + + tx := db.Begin() + if tx.Error != nil { + t.Fatalf("begin tx: %v", tx.Error) + } + + s := &InboundService{} + if err := s.addClientTraffic(tx, updates); err != nil { + tx.Rollback() + t.Fatalf("add client traffic: %v", err) + } + if err := tx.Commit().Error; err != nil { + t.Fatalf("commit tx: %v", err) + } + + var totalUp int64 + if err := db.Model(&xray.ClientTraffic{}).Select("COALESCE(SUM(up), 0)").Scan(&totalUp).Error; err != nil { + t.Fatalf("sum up: %v", err) + } + + var totalDown int64 + if err := db.Model(&xray.ClientTraffic{}).Select("COALESCE(SUM(down), 0)").Scan(&totalDown).Error; err != nil { + t.Fatalf("sum down: %v", err) + } + + if totalUp != clientCount { + t.Fatalf("unexpected total up: got %d want %d", totalUp, clientCount) + } + if totalDown != 2*clientCount { + t.Fatalf("unexpected total down: got %d want %d", totalDown, 2*clientCount) + } +}