mirror of
https://github.com/alireza0/x-ui.git
synced 2026-03-13 21:13:09 +00:00
Merge pull request #1622 from shayan775/fix-AddClientTraffic
The too many SQL variables path is now batched.
This commit is contained in:
@@ -19,6 +19,10 @@ type InboundService struct {
|
||||
xrayApi xray.XrayAPI
|
||||
}
|
||||
|
||||
const (
|
||||
safeBatchSize = 500
|
||||
)
|
||||
|
||||
func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) {
|
||||
db := database.GetDB()
|
||||
var inbounds []*model.Inbound
|
||||
@@ -789,6 +793,11 @@ func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic
|
||||
return nil
|
||||
}
|
||||
|
||||
type newExpiryTime struct {
|
||||
Email string
|
||||
NewExpiryTime int64
|
||||
}
|
||||
|
||||
func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) {
|
||||
if len(traffics) == 0 {
|
||||
// Empty onlineUsers
|
||||
@@ -805,9 +814,18 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr
|
||||
emails = append(emails, traffic.Email)
|
||||
}
|
||||
dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics))
|
||||
err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error
|
||||
if err != nil {
|
||||
return err
|
||||
for i := 0; i < len(emails); i += safeBatchSize {
|
||||
end := i + safeBatchSize
|
||||
if end > len(emails) {
|
||||
end = len(emails)
|
||||
}
|
||||
|
||||
batchClientTraffics := make([]*xray.ClientTraffic, 0, end-i)
|
||||
err = tx.Model(xray.ClientTraffic{}).Where("email IN ?", emails[i:end]).Find(&batchClientTraffics).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dbClientTraffics = append(dbClientTraffics, batchClientTraffics...)
|
||||
}
|
||||
|
||||
// Avoid empty slice error
|
||||
@@ -815,7 +833,20 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr
|
||||
return nil
|
||||
}
|
||||
|
||||
dbClientTraffics, err = s.adjustTraffics(tx, dbClientTraffics)
|
||||
inboundExpiryTimeMap := make(map[int][]newExpiryTime, 0)
|
||||
for index, t := range dbClientTraffics {
|
||||
if t.ExpiryTime < 0 {
|
||||
newClientExpiryTime := (time.Now().Unix() * 1000) - int64(t.ExpiryTime)
|
||||
newExpiryTime := newExpiryTime{
|
||||
Email: t.Email,
|
||||
NewExpiryTime: newClientExpiryTime,
|
||||
}
|
||||
inboundExpiryTimeMap[t.InboundId] = append(inboundExpiryTimeMap[t.InboundId], newExpiryTime)
|
||||
dbClientTraffics[index].ExpiryTime = newClientExpiryTime
|
||||
}
|
||||
}
|
||||
|
||||
err = s.adjustTraffics(tx, inboundExpiryTimeMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -836,66 +867,90 @@ func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTr
|
||||
}
|
||||
|
||||
// Set onlineUsers
|
||||
p.SetOnlineClients(onlineClients)
|
||||
if p != nil {
|
||||
p.SetOnlineClients(onlineClients)
|
||||
}
|
||||
|
||||
err = tx.Save(dbClientTraffics).Error
|
||||
if err != nil {
|
||||
logger.Warning("AddClientTraffic update data ", err)
|
||||
for i := 0; i < len(dbClientTraffics); i += safeBatchSize {
|
||||
end := i + safeBatchSize
|
||||
if end > len(dbClientTraffics) {
|
||||
end = len(dbClientTraffics)
|
||||
}
|
||||
|
||||
err = tx.Save(dbClientTraffics[i:end]).Error
|
||||
if err != nil {
|
||||
logger.Warning("AddClientTraffic update data ", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.ClientTraffic) ([]*xray.ClientTraffic, error) {
|
||||
inboundIds := make([]int, 0, len(dbClientTraffics))
|
||||
for _, dbClientTraffic := range dbClientTraffics {
|
||||
if dbClientTraffic.ExpiryTime < 0 {
|
||||
inboundIds = append(inboundIds, dbClientTraffic.InboundId)
|
||||
func (s *InboundService) adjustTraffics(tx *gorm.DB, inboundExpiryTimeMap map[int][]newExpiryTime) error {
|
||||
if len(inboundExpiryTimeMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
inboundIds := make([]int, 0)
|
||||
for inId := range inboundExpiryTimeMap {
|
||||
inboundIds = append(inboundIds, inId)
|
||||
}
|
||||
inbounds := make([]*model.Inbound, 0, len(inboundIds))
|
||||
for i := 0; i < len(inboundIds); i += safeBatchSize {
|
||||
end := i + safeBatchSize
|
||||
if end > len(inboundIds) {
|
||||
end = len(inboundIds)
|
||||
}
|
||||
|
||||
batchInbounds := make([]*model.Inbound, 0, end-i)
|
||||
err := tx.Model(model.Inbound{}).Where("id IN ?", inboundIds[i:end]).Find(&batchInbounds).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
inbounds = append(inbounds, batchInbounds...)
|
||||
}
|
||||
for inbound_index := range inbounds {
|
||||
settings := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings)
|
||||
clients, ok := settings["clients"].([]interface{})
|
||||
inbEmails := inboundExpiryTimeMap[inbounds[inbound_index].Id]
|
||||
if ok {
|
||||
var newClients []interface{}
|
||||
for client_index := range clients {
|
||||
c := clients[client_index].(map[string]interface{})
|
||||
for index := range inbEmails {
|
||||
if c["email"] == inbEmails[index].Email {
|
||||
c["expiryTime"] = inbEmails[index].NewExpiryTime
|
||||
break
|
||||
}
|
||||
}
|
||||
newClients = append(newClients, interface{}(c))
|
||||
}
|
||||
settings["clients"] = newClients
|
||||
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inbounds[inbound_index].Settings = string(modifiedSettings)
|
||||
}
|
||||
}
|
||||
|
||||
if len(inboundIds) > 0 {
|
||||
var inbounds []*model.Inbound
|
||||
err := tx.Model(model.Inbound{}).Where("id IN (?)", inboundIds).Find(&inbounds).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
for i := 0; i < len(inbounds); i += safeBatchSize {
|
||||
end := i + safeBatchSize
|
||||
if end > len(inbounds) {
|
||||
end = len(inbounds)
|
||||
}
|
||||
for inbound_index := range inbounds {
|
||||
settings := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings)
|
||||
clients, ok := settings["clients"].([]interface{})
|
||||
if ok {
|
||||
var newClients []interface{}
|
||||
for client_index := range clients {
|
||||
c := clients[client_index].(map[string]interface{})
|
||||
for traffic_index := range dbClientTraffics {
|
||||
if dbClientTraffics[traffic_index].ExpiryTime < 0 && c["email"] == dbClientTraffics[traffic_index].Email {
|
||||
oldExpiryTime := c["expiryTime"].(float64)
|
||||
newExpiryTime := (time.Now().Unix() * 1000) - int64(oldExpiryTime)
|
||||
c["expiryTime"] = newExpiryTime
|
||||
dbClientTraffics[traffic_index].ExpiryTime = newExpiryTime
|
||||
break
|
||||
}
|
||||
}
|
||||
newClients = append(newClients, interface{}(c))
|
||||
}
|
||||
settings["clients"] = newClients
|
||||
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inbounds[inbound_index].Settings = string(modifiedSettings)
|
||||
}
|
||||
}
|
||||
err = tx.Save(inbounds).Error
|
||||
err := tx.Save(inbounds[i:end]).Error
|
||||
if err != nil {
|
||||
logger.Warning("AddClientTraffic update inbounds ", err)
|
||||
logger.Error(inbounds)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return dbClientTraffics, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InboundService) autoRenewClients(tx *gorm.DB) (bool, int64, error) {
|
||||
|
||||
74
web/service/inbound_add_client_traffic_test.go
Normal file
74
web/service/inbound_add_client_traffic_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/alireza0/x-ui/xray"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestAddClientTrafficHandlesLargeBatch(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&xray.ClientTraffic{}); err != nil {
|
||||
t.Fatalf("auto migrate: %v", err)
|
||||
}
|
||||
|
||||
const clientCount = 4000
|
||||
|
||||
seed := make([]*xray.ClientTraffic, 0, clientCount)
|
||||
updates := make([]*xray.ClientTraffic, 0, clientCount)
|
||||
for i := 0; i < clientCount; i++ {
|
||||
email := fmt.Sprintf("user-%05d@example.com", i)
|
||||
seed = append(seed, &xray.ClientTraffic{
|
||||
InboundId: 1,
|
||||
Enable: true,
|
||||
Email: email,
|
||||
})
|
||||
updates = append(updates, &xray.ClientTraffic{
|
||||
Email: email,
|
||||
Up: 1,
|
||||
Down: 2,
|
||||
})
|
||||
}
|
||||
|
||||
if err := db.CreateInBatches(seed, 100).Error; err != nil {
|
||||
t.Fatalf("seed traffic: %v", err)
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatalf("begin tx: %v", tx.Error)
|
||||
}
|
||||
|
||||
s := &InboundService{}
|
||||
if err := s.addClientTraffic(tx, updates); err != nil {
|
||||
tx.Rollback()
|
||||
t.Fatalf("add client traffic: %v", err)
|
||||
}
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
t.Fatalf("commit tx: %v", err)
|
||||
}
|
||||
|
||||
var totalUp int64
|
||||
if err := db.Model(&xray.ClientTraffic{}).Select("COALESCE(SUM(up), 0)").Scan(&totalUp).Error; err != nil {
|
||||
t.Fatalf("sum up: %v", err)
|
||||
}
|
||||
|
||||
var totalDown int64
|
||||
if err := db.Model(&xray.ClientTraffic{}).Select("COALESCE(SUM(down), 0)").Scan(&totalDown).Error; err != nil {
|
||||
t.Fatalf("sum down: %v", err)
|
||||
}
|
||||
|
||||
if totalUp != clientCount {
|
||||
t.Fatalf("unexpected total up: got %d want %d", totalUp, clientCount)
|
||||
}
|
||||
if totalDown != 2*clientCount {
|
||||
t.Fatalf("unexpected total down: got %d want %d", totalDown, 2*clientCount)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user