go fmt
@@ -272,7 +272,6 @@ subscribeLoop:
                TsNs: logEntry.TsNs,
            }

            if err := stream.Send(&mq_pb.SubscribeMessageResponse{Message: &mq_pb.SubscribeMessageResponse_Data{
                Data: dataMsg,
            }}); err != nil {
@@ -103,15 +103,15 @@ func TestIncrementalCooperativeAssignmentStrategy_RebalanceWithRevocation(t *tes
        t.Errorf("Expected member-2 to have 0 partitions during revocation, got %d", len(member2Assignments))
    }

    t.Logf("Revocation phase - Member-1: %d partitions, Member-2: %d partitions",
        len(member1Assignments), len(member2Assignments))

    // Simulate time passing and second call (should move to assignment phase)
    time.Sleep(10 * time.Millisecond)

    // Force move to assignment phase by setting timeout to 0
    state.RevocationTimeout = 0

    assignments2 := strategy.Assign(members, topicPartitions)

    // Should complete rebalance
@@ -136,7 +136,7 @@ func TestIncrementalCooperativeAssignmentStrategy_RebalanceWithRevocation(t *tes
        t.Errorf("Expected 4 total partitions after rebalance, got %d", totalFinalPartitions)
    }

    t.Logf("Final assignment - Member-1: %d partitions, Member-2: %d partitions",
        len(member1FinalAssignments), len(member2FinalAssignments))
}

@@ -239,7 +239,7 @@ func TestIncrementalCooperativeAssignmentStrategy_MultipleTopics(t *testing.T) {
            t.Errorf("Expected partition %s to be assigned", expected)
        }
    }

    // Debug: Print all assigned partitions
    t.Logf("All assigned partitions: %v", allAssignedPartitions)
}
@@ -390,7 +390,7 @@ func TestIncrementalCooperativeAssignmentStrategy_StateTransitions(t *testing.T)
    // Force timeout to move to assignment phase
    state.RevocationTimeout = 0
    strategy.Assign(members, topicPartitions)

    // Should complete and return to None
    state = strategy.GetRebalanceState()
    if state.Phase != RebalancePhaseNone {

@@ -24,12 +24,12 @@ func (rtm *RebalanceTimeoutManager) CheckRebalanceTimeouts() {

    for _, group := range rtm.coordinator.groups {
        group.Mu.Lock()

        // Only check timeouts for groups in rebalancing states
        if group.State == GroupStatePreparingRebalance || group.State == GroupStateCompletingRebalance {
            rtm.checkGroupRebalanceTimeout(group, now)
        }

        group.Mu.Unlock()
    }
}
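CheckRebalanceTimeouts walks every group under its lock and only inspects groups that are mid-rebalance, so something has to call it periodically. A minimal sketch of such a driver follows; the ticker interval and stop channel are illustrative assumptions, not the coordinator's actual wiring:

```go
// Hypothetical driver loop; the real coordinator may schedule this differently.
func runTimeoutChecker(rtm *RebalanceTimeoutManager, stop <-chan struct{}) {
    ticker := time.NewTicker(5 * time.Second) // assumed check interval
    defer ticker.Stop()
    for {
        select {
        case <-ticker.C:
            rtm.CheckRebalanceTimeouts() // scans groups, evicts expired members
        case <-stop:
            return
        }
    }
}
```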
@@ -37,7 +37,7 @@ func (rtm *RebalanceTimeoutManager) CheckRebalanceTimeouts() {
// checkGroupRebalanceTimeout checks and handles rebalance timeout for a specific group
func (rtm *RebalanceTimeoutManager) checkGroupRebalanceTimeout(group *ConsumerGroup, now time.Time) {
    expiredMembers := make([]string, 0)

    for memberID, member := range group.Members {
        // Check if member has exceeded its rebalance timeout
        rebalanceTimeout := time.Duration(member.RebalanceTimeout) * time.Millisecond
@@ -45,21 +45,21 @@ func (rtm *RebalanceTimeoutManager) checkGroupRebalanceTimeout(group *ConsumerGr
            // Use default rebalance timeout if not specified
            rebalanceTimeout = time.Duration(rtm.coordinator.rebalanceTimeoutMs) * time.Millisecond
        }

        // For members in pending state during rebalance, check against join time
        if member.State == MemberStatePending {
            if now.Sub(member.JoinedAt) > rebalanceTimeout {
                expiredMembers = append(expiredMembers, memberID)
            }
        }

        // Also check session timeout as a fallback
        sessionTimeout := time.Duration(member.SessionTimeout) * time.Millisecond
        if now.Sub(member.LastHeartbeat) > sessionTimeout {
            expiredMembers = append(expiredMembers, memberID)
        }
    }

    // Remove expired members and trigger rebalance if necessary
    if len(expiredMembers) > 0 {
        rtm.evictExpiredMembers(group, expiredMembers)
@@ -70,13 +70,13 @@ func (rtm *RebalanceTimeoutManager) checkGroupRebalanceTimeout(group *ConsumerGr
func (rtm *RebalanceTimeoutManager) evictExpiredMembers(group *ConsumerGroup, expiredMembers []string) {
    for _, memberID := range expiredMembers {
        delete(group.Members, memberID)

        // If the leader was evicted, clear leader
        if group.Leader == memberID {
            group.Leader = ""
        }
    }

    // Update group state based on remaining members
    if len(group.Members) == 0 {
        group.State = GroupStateEmpty
@@ -92,18 +92,18 @@ func (rtm *RebalanceTimeoutManager) evictExpiredMembers(group *ConsumerGroup, ex
                break
            }
        }

        // Reset to preparing rebalance to restart the process
        group.State = GroupStatePreparingRebalance
        group.Generation++

        // Mark remaining members as pending
        for _, member := range group.Members {
            member.State = MemberStatePending
        }
    }

    group.LastActivity = time.Now()
}

@@ -112,7 +112,7 @@ func (rtm *RebalanceTimeoutManager) IsRebalanceStuck(group *ConsumerGroup, maxRe
    if group.State != GroupStatePreparingRebalance && group.State != GroupStateCompletingRebalance {
        return false
    }

    return time.Since(group.LastActivity) > maxRebalanceDuration
}
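IsRebalanceStuck only reports; it does not repair. A sketch of how it could be paired with ForceCompleteRebalance as a watchdog pass (the group list and threshold here are illustrative assumptions):

```go
// Hypothetical watchdog pass; the groups slice and threshold are assumptions.
func forceStuckRebalances(rtm *RebalanceTimeoutManager, groups []*ConsumerGroup) {
    const maxRebalanceDuration = 10 * time.Minute // assumed stuck threshold
    for _, group := range groups {
        if rtm.IsRebalanceStuck(group, maxRebalanceDuration) {
            // Nudges PreparingRebalance -> CompletingRebalance -> Stable.
            rtm.ForceCompleteRebalance(group)
        }
    }
}
```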
@@ -120,14 +120,14 @@ func (rtm *RebalanceTimeoutManager) IsRebalanceStuck(group *ConsumerGroup, maxRe
func (rtm *RebalanceTimeoutManager) ForceCompleteRebalance(group *ConsumerGroup) {
    group.Mu.Lock()
    defer group.Mu.Unlock()

    // If stuck in preparing rebalance, move to completing
    if group.State == GroupStatePreparingRebalance {
        group.State = GroupStateCompletingRebalance
        group.LastActivity = time.Now()
        return
    }

    // If stuck in completing rebalance, force to stable
    if group.State == GroupStateCompletingRebalance {
        group.State = GroupStateStable
@@ -145,21 +145,21 @@ func (rtm *RebalanceTimeoutManager) GetRebalanceStatus(groupID string) *Rebalanc
    if group == nil {
        return nil
    }

    group.Mu.RLock()
    defer group.Mu.RUnlock()

    status := &RebalanceStatus{
        GroupID:           groupID,
        State:             group.State,
        Generation:        group.Generation,
        MemberCount:       len(group.Members),
        Leader:            group.Leader,
        LastActivity:      group.LastActivity,
        IsRebalancing:     group.State == GroupStatePreparingRebalance || group.State == GroupStateCompletingRebalance,
        RebalanceDuration: time.Since(group.LastActivity),
    }

    // Calculate member timeout status
    now := time.Now()
    for memberID, member := range group.Members {
@@ -171,48 +171,48 @@ func (rtm *RebalanceTimeoutManager) GetRebalanceStatus(groupID string) *Rebalanc
            SessionTimeout:   time.Duration(member.SessionTimeout) * time.Millisecond,
            RebalanceTimeout: time.Duration(member.RebalanceTimeout) * time.Millisecond,
        }

        // Calculate time until session timeout
        sessionTimeRemaining := memberStatus.SessionTimeout - now.Sub(member.LastHeartbeat)
        if sessionTimeRemaining < 0 {
            sessionTimeRemaining = 0
        }
        memberStatus.SessionTimeRemaining = sessionTimeRemaining

        // Calculate time until rebalance timeout
        rebalanceTimeRemaining := memberStatus.RebalanceTimeout - now.Sub(member.JoinedAt)
        if rebalanceTimeRemaining < 0 {
            rebalanceTimeRemaining = 0
        }
        memberStatus.RebalanceTimeRemaining = rebalanceTimeRemaining

        status.Members = append(status.Members, memberStatus)
    }

    return status
}

// RebalanceStatus represents the current status of a group's rebalance
type RebalanceStatus struct {
    GroupID           string                `json:"group_id"`
    State             GroupState            `json:"state"`
    Generation        int32                 `json:"generation"`
    MemberCount       int                   `json:"member_count"`
    Leader            string                `json:"leader"`
    LastActivity      time.Time             `json:"last_activity"`
    IsRebalancing     bool                  `json:"is_rebalancing"`
    RebalanceDuration time.Duration         `json:"rebalance_duration"`
    Members           []MemberTimeoutStatus `json:"members"`
}

// MemberTimeoutStatus represents timeout status for a group member
type MemberTimeoutStatus struct {
    MemberID               string        `json:"member_id"`
    State                  MemberState   `json:"state"`
    LastHeartbeat          time.Time     `json:"last_heartbeat"`
    JoinedAt               time.Time     `json:"joined_at"`
    SessionTimeout         time.Duration `json:"session_timeout"`
    RebalanceTimeout       time.Duration `json:"rebalance_timeout"`
    SessionTimeRemaining   time.Duration `json:"session_time_remaining"`
    RebalanceTimeRemaining time.Duration `json:"rebalance_time_remaining"`
}
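Both structs carry JSON tags, so a status snapshot serializes directly. A minimal sketch of exposing it over HTTP, assuming `net/http` and `encoding/json`; the route and handler wiring are assumptions, not part of this commit:

```go
// Hypothetical admin endpoint; query parameter and route are assumptions.
func rebalanceStatusHandler(rtm *RebalanceTimeoutManager) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        status := rtm.GetRebalanceStatus(r.URL.Query().Get("group"))
        if status == nil {
            http.Error(w, "group not found", http.StatusNotFound)
            return
        }
        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(status) // field names follow the json tags above
    }
}
```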
@@ -8,14 +8,14 @@ import (
func TestRebalanceTimeoutManager_CheckRebalanceTimeouts(t *testing.T) {
    coordinator := NewGroupCoordinator()
    defer coordinator.Close()

    rtm := coordinator.rebalanceTimeoutManager

    // Create a group with a member that has a short rebalance timeout
    group := coordinator.GetOrCreateGroup("test-group")
    group.Mu.Lock()
    group.State = GroupStatePreparingRebalance

    member := &GroupMember{
        ID:       "member1",
        ClientID: "client1",
@@ -27,15 +27,15 @@ func TestRebalanceTimeoutManager_CheckRebalanceTimeouts(t *testing.T) {
    }
    group.Members["member1"] = member
    group.Mu.Unlock()

    // Check timeouts - member should be evicted
    rtm.CheckRebalanceTimeouts()

    group.Mu.RLock()
    if len(group.Members) != 0 {
        t.Errorf("Expected member to be evicted due to rebalance timeout, but %d members remain", len(group.Members))
    }

    if group.State != GroupStateEmpty {
        t.Errorf("Expected group state to be Empty after member eviction, got %s", group.State.String())
    }
@@ -45,18 +45,18 @@ func TestRebalanceTimeoutManager_CheckRebalanceTimeouts(t *testing.T) {
func TestRebalanceTimeoutManager_SessionTimeoutFallback(t *testing.T) {
    coordinator := NewGroupCoordinator()
    defer coordinator.Close()

    rtm := coordinator.rebalanceTimeoutManager

    // Create a group with a member that has exceeded session timeout
    group := coordinator.GetOrCreateGroup("test-group")
    group.Mu.Lock()
    group.State = GroupStatePreparingRebalance

    member := &GroupMember{
        ID:               "member1",
        ClientID:         "client1",
        SessionTimeout:   1000,  // 1 second
        RebalanceTimeout: 30000, // 30 seconds
        State:            MemberStatePending,
        LastHeartbeat:    time.Now().Add(-2 * time.Second), // Last heartbeat 2 seconds ago
@@ -64,10 +64,10 @@ func TestRebalanceTimeoutManager_SessionTimeoutFallback(t *testing.T) {
    }
    group.Members["member1"] = member
    group.Mu.Unlock()

    // Check timeouts - member should be evicted due to session timeout
    rtm.CheckRebalanceTimeouts()

    group.Mu.RLock()
    if len(group.Members) != 0 {
        t.Errorf("Expected member to be evicted due to session timeout, but %d members remain", len(group.Members))
@@ -78,15 +78,15 @@ func TestRebalanceTimeoutManager_SessionTimeoutFallback(t *testing.T) {
func TestRebalanceTimeoutManager_LeaderEviction(t *testing.T) {
    coordinator := NewGroupCoordinator()
    defer coordinator.Close()

    rtm := coordinator.rebalanceTimeoutManager

    // Create a group with leader and another member
    group := coordinator.GetOrCreateGroup("test-group")
    group.Mu.Lock()
    group.State = GroupStatePreparingRebalance
    group.Leader = "member1"

    // Leader with expired rebalance timeout
    leader := &GroupMember{
        ID: "member1",
@@ -98,7 +98,7 @@ func TestRebalanceTimeoutManager_LeaderEviction(t *testing.T) {
        JoinedAt: time.Now().Add(-2 * time.Second),
    }
    group.Members["member1"] = leader

    // Another member that's still valid
    member2 := &GroupMember{
        ID: "member2",
@@ -111,19 +111,19 @@ func TestRebalanceTimeoutManager_LeaderEviction(t *testing.T) {
    }
    group.Members["member2"] = member2
    group.Mu.Unlock()

    // Check timeouts - leader should be evicted, new leader selected
    rtm.CheckRebalanceTimeouts()

    group.Mu.RLock()
    if len(group.Members) != 1 {
        t.Errorf("Expected 1 member to remain after leader eviction, got %d", len(group.Members))
    }

    if group.Leader != "member2" {
        t.Errorf("Expected member2 to become new leader, got %s", group.Leader)
    }

    if group.State != GroupStatePreparingRebalance {
        t.Errorf("Expected group to restart rebalancing after leader eviction, got %s", group.State.String())
    }
@@ -133,37 +133,37 @@ func TestRebalanceTimeoutManager_LeaderEviction(t *testing.T) {
func TestRebalanceTimeoutManager_IsRebalanceStuck(t *testing.T) {
    coordinator := NewGroupCoordinator()
    defer coordinator.Close()

    rtm := coordinator.rebalanceTimeoutManager

    // Create a group that's been rebalancing for a while
    group := coordinator.GetOrCreateGroup("test-group")
    group.Mu.Lock()
    group.State = GroupStatePreparingRebalance
    group.LastActivity = time.Now().Add(-15 * time.Minute) // 15 minutes ago
    group.Mu.Unlock()

    // Check if rebalance is stuck (max 10 minutes)
    maxDuration := 10 * time.Minute
    if !rtm.IsRebalanceStuck(group, maxDuration) {
        t.Error("Expected rebalance to be detected as stuck")
    }

    // Test with a group that's not stuck
    group.Mu.Lock()
    group.LastActivity = time.Now().Add(-5 * time.Minute) // 5 minutes ago
    group.Mu.Unlock()

    if rtm.IsRebalanceStuck(group, maxDuration) {
        t.Error("Expected rebalance to not be detected as stuck")
    }

    // Test with stable group (should not be stuck)
    group.Mu.Lock()
    group.State = GroupStateStable
    group.LastActivity = time.Now().Add(-15 * time.Minute)
    group.Mu.Unlock()

    if rtm.IsRebalanceStuck(group, maxDuration) {
        t.Error("Stable group should not be detected as stuck")
    }
@@ -172,37 +172,37 @@ func TestRebalanceTimeoutManager_IsRebalanceStuck(t *testing.T) {
func TestRebalanceTimeoutManager_ForceCompleteRebalance(t *testing.T) {
    coordinator := NewGroupCoordinator()
    defer coordinator.Close()

    rtm := coordinator.rebalanceTimeoutManager

    // Test forcing completion from PreparingRebalance
    group := coordinator.GetOrCreateGroup("test-group")
    group.Mu.Lock()
    group.State = GroupStatePreparingRebalance

    member := &GroupMember{
        ID:    "member1",
        State: MemberStatePending,
    }
    group.Members["member1"] = member
    group.Mu.Unlock()

    rtm.ForceCompleteRebalance(group)

    group.Mu.RLock()
    if group.State != GroupStateCompletingRebalance {
        t.Errorf("Expected group state to be CompletingRebalance, got %s", group.State.String())
    }
    group.Mu.RUnlock()

    // Test forcing completion from CompletingRebalance
    rtm.ForceCompleteRebalance(group)

    group.Mu.RLock()
    if group.State != GroupStateStable {
        t.Errorf("Expected group state to be Stable, got %s", group.State.String())
    }

    if member.State != MemberStateStable {
        t.Errorf("Expected member state to be Stable, got %s", member.State.String())
    }
@@ -212,15 +212,15 @@ func TestRebalanceTimeoutManager_ForceCompleteRebalance(t *testing.T) {
func TestRebalanceTimeoutManager_GetRebalanceStatus(t *testing.T) {
    coordinator := NewGroupCoordinator()
    defer coordinator.Close()

    rtm := coordinator.rebalanceTimeoutManager

    // Test with non-existent group
    status := rtm.GetRebalanceStatus("non-existent")
    if status != nil {
        t.Error("Expected nil status for non-existent group")
    }

    // Create a group with members
    group := coordinator.GetOrCreateGroup("test-group")
    group.Mu.Lock()
@@ -228,7 +228,7 @@ func TestRebalanceTimeoutManager_GetRebalanceStatus(t *testing.T) {
    group.Generation = 5
    group.Leader = "member1"
    group.LastActivity = time.Now().Add(-2 * time.Minute)

    member1 := &GroupMember{
        ID:    "member1",
        State: MemberStatePending,
@@ -238,7 +238,7 @@ func TestRebalanceTimeoutManager_GetRebalanceStatus(t *testing.T) {
        RebalanceTimeout: 300000, // 5 minutes
    }
    group.Members["member1"] = member1

    member2 := &GroupMember{
        ID:    "member2",
        State: MemberStatePending,
@@ -249,48 +249,48 @@ func TestRebalanceTimeoutManager_GetRebalanceStatus(t *testing.T) {
    }
    group.Members["member2"] = member2
    group.Mu.Unlock()

    // Get status
    status = rtm.GetRebalanceStatus("test-group")

    if status == nil {
        t.Fatal("Expected non-nil status")
    }

    if status.GroupID != "test-group" {
        t.Errorf("Expected group ID 'test-group', got %s", status.GroupID)
    }

    if status.State != GroupStatePreparingRebalance {
        t.Errorf("Expected state PreparingRebalance, got %s", status.State.String())
    }

    if status.Generation != 5 {
        t.Errorf("Expected generation 5, got %d", status.Generation)
    }

    if status.MemberCount != 2 {
        t.Errorf("Expected 2 members, got %d", status.MemberCount)
    }

    if status.Leader != "member1" {
        t.Errorf("Expected leader 'member1', got %s", status.Leader)
    }

    if !status.IsRebalancing {
        t.Error("Expected IsRebalancing to be true")
    }

    if len(status.Members) != 2 {
        t.Errorf("Expected 2 member statuses, got %d", len(status.Members))
    }

    // Check member timeout calculations
    for _, memberStatus := range status.Members {
        if memberStatus.SessionTimeRemaining < 0 {
            t.Errorf("Session time remaining should not be negative for member %s", memberStatus.MemberID)
        }

        if memberStatus.RebalanceTimeRemaining < 0 {
            t.Errorf("Rebalance time remaining should not be negative for member %s", memberStatus.MemberID)
        }
@@ -300,14 +300,14 @@ func TestRebalanceTimeoutManager_DefaultRebalanceTimeout(t *testing.T) {
func TestRebalanceTimeoutManager_DefaultRebalanceTimeout(t *testing.T) {
    coordinator := NewGroupCoordinator()
    defer coordinator.Close()

    rtm := coordinator.rebalanceTimeoutManager

    // Create a group with a member that has no rebalance timeout set (0)
    group := coordinator.GetOrCreateGroup("test-group")
    group.Mu.Lock()
    group.State = GroupStatePreparingRebalance

    member := &GroupMember{
        ID:       "member1",
        ClientID: "client1",
@@ -319,10 +319,10 @@ func TestRebalanceTimeoutManager_DefaultRebalanceTimeout(t *testing.T) {
    }
    group.Members["member1"] = member
    group.Mu.Unlock()

    // Default rebalance timeout is 5 minutes (300000ms), so member should be evicted
    rtm.CheckRebalanceTimeouts()

    group.Mu.RLock()
    if len(group.Members) != 0 {
        t.Errorf("Expected member to be evicted using default rebalance timeout, but %d members remain", len(group.Members))
@@ -142,4 +142,3 @@ func (m *MemoryStorage) Close() error {

    return nil
}

@@ -206,4 +206,3 @@ func TestMemoryStorageOverwrite(t *testing.T) {
    assert.Equal(t, int64(20), offset)
    assert.Equal(t, "meta2", metadata)
}

@@ -56,4 +56,3 @@ var (
    ErrInvalidPartition = fmt.Errorf("invalid partition")
    ErrStorageClosed    = fmt.Errorf("storage is closed")
)

@@ -121,7 +121,6 @@ func (m *mockSeaweedMQHandler) ProduceRecord(ctx context.Context, topicName stri
    offset := m.offsets[topicName][partitionID]
    m.offsets[topicName][partitionID]++

    // Store record
    record := &mockRecord{
        key: key,

@@ -9,5 +9,3 @@ package kafka
// - offset/: Offset management
// - schema/: Schema registry integration
// - consumer/: Consumer group coordination

@@ -51,5 +51,3 @@ func GetRangeSize() int32 {
func GetMaxKafkaPartitions() int32 {
    return int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions
}
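(For reference: the "// 72 partitions" comment implies pub_balancer.MaxPartitionCount is 2520 here, since 2520 / 35 = 72 with exact integer division.)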
@@ -37,7 +37,6 @@ func (h *Handler) handleDescribeCluster(correlationID uint32, apiVersion uint16,
    // Tagged fields at end of request
    // (We don't parse them, just skip)

    // Build response
    response := make([]byte, 0, 256)

@@ -109,6 +108,5 @@ func (h *Handler) handleDescribeCluster(correlationID uint32, apiVersion uint16,
    // Response-level tagged fields (flexible response)
    response = append(response, 0x00) // Empty tagged fields

    return response, nil
}

@@ -268,7 +268,6 @@ func parseCompactString(data []byte) ([]byte, int) {
        return nil, 0
    }

    if actualLength == 0 {
        // Empty string (length was 1)
        return []byte{}, consumed
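For context, compact strings in Kafka's flexible protocol versions encode the length as an unsigned varint of N+1 (0 means null, 1 means empty), which is why `actualLength == 0` above denotes the empty string. A standalone sketch of that decoding rule, independent of this handler's actual helper:

```go
// Standalone illustration of the compact-string length rule (KIP-482 style);
// not the handler's actual implementation.
func decodeCompactString(data []byte) (s []byte, consumed int, ok bool) {
    length, n := binary.Uvarint(data) // varint-encoded length+1
    if n <= 0 {
        return nil, 0, false
    }
    if length == 0 {
        return nil, n, true // 0 encodes a null string
    }
    actual := int(length - 1) // 1 encodes the empty string
    if len(data) < n+actual {
        return nil, 0, false
    }
    return data[n : n+actual], n + actual, true
}
```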
@@ -107,13 +107,13 @@ func (h *Handler) describeGroup(groupID string) DescribeGroupsGroup {
    }

    return DescribeGroupsGroup{
        ErrorCode:     0,
        GroupID:       groupID,
        State:         stateStr,
        ProtocolType:  "consumer", // Default protocol type
        Protocol:      group.Protocol,
        Members:       members,
        AuthorizedOps: []int32{}, // Empty for now
    }
}

@@ -175,8 +175,8 @@ func (h *Handler) listAllGroups(statesFilter []string) []ListGroupsGroup {
// Request/Response structures

type DescribeGroupsRequest struct {
    GroupIDs             []string
    IncludeAuthorizedOps bool
}

type DescribeGroupsResponse struct {

@@ -661,7 +661,7 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
            return
        }
        // Removed V(4) logging from hot path - only log errors and important events

        // Wrap request processing with panic recovery to prevent deadlocks
        // If processRequestSync panics, we MUST still send a response to avoid blocking the response writer
        var response []byte
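The comment above describes the pattern: recover from a panic during request processing and still produce a result so the response writer is never starved. A generic sketch of that shape, not the handler's exact code:

```go
// Generic panic-recovery wrapper around a request processor; on panic the
// caller still gets an error back and can send an error response.
func safeProcess(process func() ([]byte, error)) (response []byte, err error) {
    defer func() {
        if r := recover(); r != nil {
            err = fmt.Errorf("request processing panicked: %v", r)
        }
    }()
    return process()
}
```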
@@ -881,7 +881,6 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
        return fmt.Errorf("read message: %w", err)
    }

    // Parse at least the basic header to get API key and correlation ID
    if len(messageBuf) < 8 {
        return fmt.Errorf("message too short")
@@ -890,7 +889,7 @@ func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
    apiKey := binary.BigEndian.Uint16(messageBuf[0:2])
    apiVersion := binary.BigEndian.Uint16(messageBuf[2:4])
    correlationID := binary.BigEndian.Uint32(messageBuf[4:8])

    // Validate API version against what we support
    if err := h.validateAPIVersion(apiKey, apiVersion); err != nil {
        glog.Errorf("API VERSION VALIDATION FAILED: Key=%d (%s), Version=%d, error=%v", apiKey, getAPIName(APIKey(apiKey)), apiVersion, err)

@@ -1050,7 +1049,6 @@ func (h *Handler) processRequestSync(req *kafkaRequest) ([]byte, error) {
    requestStart := time.Now()
    apiName := getAPIName(APIKey(req.apiKey))

    // Only log high-volume requests at V(2), not V(4)
    if glog.V(2) {
        glog.V(2).Infof("[API] %s (key=%d, ver=%d, corr=%d)",

@@ -1589,15 +1587,15 @@ func (h *Handler) HandleMetadataV2(correlationID uint32, requestBody []byte) ([]
        for partitionID := int32(0); partitionID < partitionCount; partitionID++ {
            binary.Write(&buf, binary.BigEndian, int16(0))    // ErrorCode
            binary.Write(&buf, binary.BigEndian, partitionID) // PartitionIndex
            binary.Write(&buf, binary.BigEndian, nodeID)      // LeaderID

            // ReplicaNodes array (4 bytes length + nodes)
            binary.Write(&buf, binary.BigEndian, int32(1)) // 1 replica
            binary.Write(&buf, binary.BigEndian, nodeID)   // NodeID 1

            // IsrNodes array (4 bytes length + nodes)
            binary.Write(&buf, binary.BigEndian, int32(1)) // 1 ISR node
            binary.Write(&buf, binary.BigEndian, nodeID)   // NodeID 1
        }
    }
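For readability, the bytes emitted per partition above follow the classic (non-flexible) Metadata partition layout, all integers big-endian. A sketch of the same wire order as an annotated struct; this is documentation of the field order the `binary.Write` calls produce, not a type that can be serialized in one call (the arrays are length-prefixed on the wire):

```go
// Wire order of one partition entry in older Metadata responses,
// mirroring the binary.Write sequence above.
type partitionMetadataBlock struct {
    ErrorCode      int16   // 0 = no error
    PartitionIndex int32   // partition id
    LeaderID       int32   // leader broker node id
    ReplicaNodes   []int32 // int32 count, then node ids
    IsrNodes       []int32 // int32 count, then node ids
}
```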
@@ -1716,15 +1714,15 @@ func (h *Handler) HandleMetadataV3V4(correlationID uint32, requestBody []byte) (
        for partitionID := int32(0); partitionID < partitionCount; partitionID++ {
            binary.Write(&buf, binary.BigEndian, int16(0))    // ErrorCode
            binary.Write(&buf, binary.BigEndian, partitionID) // PartitionIndex
            binary.Write(&buf, binary.BigEndian, nodeID)      // LeaderID

            // ReplicaNodes array (4 bytes length + nodes)
            binary.Write(&buf, binary.BigEndian, int32(1)) // 1 replica
            binary.Write(&buf, binary.BigEndian, nodeID)   // NodeID 1

            // IsrNodes array (4 bytes length + nodes)
            binary.Write(&buf, binary.BigEndian, int32(1)) // 1 ISR node
            binary.Write(&buf, binary.BigEndian, nodeID)   // NodeID 1
        }
    }

@@ -1737,7 +1735,7 @@ func (h *Handler) HandleMetadataV3V4(correlationID uint32, requestBody []byte) (
    }
    if len(response) > 100 {
    }

    return response, nil
}

@@ -1828,7 +1826,6 @@ func (h *Handler) handleMetadataV5ToV8(correlationID uint32, requestBody []byte,
    // NOTE: Correlation ID is handled by writeResponseWithCorrelationID
    // Do NOT include it in the response body

    // ThrottleTimeMs (4 bytes) - v3+ addition
    binary.Write(&buf, binary.BigEndian, int32(0)) // No throttling

@@ -1896,7 +1893,7 @@ func (h *Handler) handleMetadataV5ToV8(correlationID uint32, requestBody []byte,
        for partitionID := int32(0); partitionID < partitionCount; partitionID++ {
            binary.Write(&buf, binary.BigEndian, int16(0))    // ErrorCode
            binary.Write(&buf, binary.BigEndian, partitionID) // PartitionIndex
            binary.Write(&buf, binary.BigEndian, nodeID)      // LeaderID

            // LeaderEpoch (4 bytes) - v7+ addition
            if apiVersion >= 7 {
@@ -1905,11 +1902,11 @@ func (h *Handler) handleMetadataV5ToV8(correlationID uint32, requestBody []byte,

            // ReplicaNodes array (4 bytes length + nodes)
            binary.Write(&buf, binary.BigEndian, int32(1)) // 1 replica
            binary.Write(&buf, binary.BigEndian, nodeID)   // NodeID 1

            // IsrNodes array (4 bytes length + nodes)
            binary.Write(&buf, binary.BigEndian, int32(1)) // 1 ISR node
            binary.Write(&buf, binary.BigEndian, nodeID)   // NodeID 1

            // OfflineReplicas array (4 bytes length + nodes) - v5+ addition
            binary.Write(&buf, binary.BigEndian, int32(0)) // No offline replicas
@@ -1930,7 +1927,7 @@ func (h *Handler) handleMetadataV5ToV8(correlationID uint32, requestBody []byte,
    }
    if len(response) > 100 {
    }

    return response, nil
}

@@ -1994,12 +1991,11 @@ func (h *Handler) handleListOffsets(correlationID uint32, apiVersion uint16, req
    // Parse minimal request to understand what's being asked (header already stripped)
    offset := 0

    maxBytes := len(requestBody)
    if maxBytes > 64 {
        maxBytes = 64
    }

    // v1+ has replica_id(4)
    if apiVersion >= 1 {
        if len(requestBody) < offset+4 {

@@ -3930,12 +3926,11 @@ func (h *Handler) handleInitProducerId(correlationID uint32, apiVersion uint16,
    // v2+: transactional_id(NULLABLE_STRING) + transaction_timeout_ms(INT32) + producer_id(INT64) + producer_epoch(INT16)
    // v4+: Uses flexible format with tagged fields

    maxBytes := len(requestBody)
    if maxBytes > 64 {
        maxBytes = 64
    }

    offset := 0

    // Parse transactional_id (NULLABLE_STRING or COMPACT_NULLABLE_STRING for flexible versions)

@@ -47,4 +47,3 @@ func (a *offsetStorageAdapter) DeleteGroup(group string) error {
func (a *offsetStorageAdapter) Close() error {
    return a.storage.Close()
}

@@ -140,4 +140,3 @@ func TestMetadataResponseHasBrokers(t *testing.T) {

    t.Logf("✓ Metadata response correctly has %d broker(s)", parsedCount)
}

@@ -7,46 +7,46 @@ import (

func TestParseConfluentEnvelope(t *testing.T) {
    tests := []struct {
        name         string
        input        []byte
        expectOK     bool
        expectID     uint32
        expectFormat Format
    }{
        {
            name:         "valid Avro message",
            input:        []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x10, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, // schema ID 1 + "Hello"
            expectOK:     true,
            expectID:     1,
            expectFormat: FormatAvro,
        },
        {
            name:         "valid message with larger schema ID",
            input:        []byte{0x00, 0x00, 0x00, 0x04, 0xd2, 0x02, 0x66, 0x6f, 0x6f}, // schema ID 1234 + "foo"
            expectOK:     true,
            expectID:     1234,
            expectFormat: FormatAvro,
        },
        {
            name:     "too short message",
            input:    []byte{0x00, 0x00, 0x00},
            expectOK: false,
        },
        {
            name:     "no magic byte",
            input:    []byte{0x01, 0x00, 0x00, 0x00, 0x01, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
            expectOK: false,
        },
        {
            name:     "empty message",
            input:    []byte{},
            expectOK: false,
        },
        {
            name:         "minimal valid message",
            input:        []byte{0x00, 0x00, 0x00, 0x00, 0x01}, // schema ID 1, empty payload
            expectOK:     true,
            expectID:     1,
            expectFormat: FormatAvro,
        },
    }
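The table above pins down the Confluent wire format these helpers target: magic byte 0x00, a 4-byte big-endian schema ID, then the payload (hence the 5-byte minimum). A minimal sketch of that parse rule, independent of the package's actual ParseConfluentEnvelope, which additionally detects the payload format:

```go
// Minimal illustration of the 5-byte Confluent envelope header.
func parseEnvelopeHeader(msg []byte) (schemaID uint32, payload []byte, ok bool) {
    if len(msg) < 5 || msg[0] != 0x00 { // magic byte must be 0x00
        return 0, nil, false
    }
    return binary.BigEndian.Uint32(msg[1:5]), msg[5:], true
}
```

Checked against the cases above: `{0x00, 0x00, 0x00, 0x04, 0xd2, ...}` yields schema ID 0x000004d2 = 1234, and the 5-byte minimal message yields ID 1 with an empty payload.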
@@ -54,24 +54,24 @@ func TestParseConfluentEnvelope(t *testing.T) {
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            envelope, ok := ParseConfluentEnvelope(tt.input)

            if ok != tt.expectOK {
                t.Errorf("ParseConfluentEnvelope() ok = %v, want %v", ok, tt.expectOK)
                return
            }

            if !tt.expectOK {
                return // No need to check further if we expected failure
            }

            if envelope.SchemaID != tt.expectID {
                t.Errorf("ParseConfluentEnvelope() schemaID = %v, want %v", envelope.SchemaID, tt.expectID)
            }

            if envelope.Format != tt.expectFormat {
                t.Errorf("ParseConfluentEnvelope() format = %v, want %v", envelope.Format, tt.expectFormat)
            }

            // Verify payload extraction
            expectedPayloadLen := len(tt.input) - 5 // 5 bytes for magic + schema ID
            if len(envelope.Payload) != expectedPayloadLen {
@@ -150,11 +150,11 @@ func TestExtractSchemaID(t *testing.T) {
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            id, ok := ExtractSchemaID(tt.input)

            if ok != tt.expectOK {
                t.Errorf("ExtractSchemaID() ok = %v, want %v", ok, tt.expectOK)
            }

            if id != tt.expectID {
                t.Errorf("ExtractSchemaID() id = %v, want %v", id, tt.expectID)
            }
@@ -200,12 +200,12 @@ func TestCreateConfluentEnvelope(t *testing.T) {
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := CreateConfluentEnvelope(tt.format, tt.schemaID, tt.indexes, tt.payload)

            if len(result) != len(tt.expected) {
                t.Errorf("CreateConfluentEnvelope() length = %v, want %v", len(result), len(tt.expected))
                return
            }

            for i, b := range result {
                if b != tt.expected[i] {
                    t.Errorf("CreateConfluentEnvelope() byte[%d] = %v, want %v", i, b, tt.expected[i])
@@ -262,7 +262,7 @@ func TestEnvelopeValidate(t *testing.T) {
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            err := tt.envelope.Validate()

            if (err != nil) != tt.expectErr {
                t.Errorf("Envelope.Validate() error = %v, expectErr %v", err, tt.expectErr)
            }
@@ -297,7 +297,7 @@ func TestEnvelopeMetadata(t *testing.T) {
func BenchmarkParseConfluentEnvelope(b *testing.B) {
    // Create a test message
    testMsg := make([]byte, 1024)
    testMsg[0] = 0x00                             // Magic byte
    binary.BigEndian.PutUint32(testMsg[1:5], 123) // Schema ID
    // Fill rest with dummy data
    for i := 5; i < len(testMsg); i++ {

@@ -100,7 +100,7 @@ func TestCreateConfluentEnvelopeWithProtobufIndexes(t *testing.T) {
    parsed, ok := ParseConfluentEnvelope(envelope)
    require.True(t, ok, "Should be able to parse envelope")
    assert.Equal(t, tc.schemaID, parsed.SchemaID)

    if tc.format == FormatProtobuf && len(tc.indexes) == 0 {
        // For Protobuf without indexes, payload should match
        assert.Equal(t, tc.payload, parsed.Payload, "Payload should match")

@@ -17,5 +17,3 @@ const (
    // Source file tracking for parquet deduplication
    ExtendedAttrSources = "sources" // JSON-encoded list of source log files
)
@@ -118,17 +118,17 @@ func (m *MigrationManager) GetCurrentVersion() (int, error) {
    if err != nil {
        return 0, fmt.Errorf("failed to create migrations table: %w", err)
    }

    var version sql.NullInt64
    err = m.db.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
    if err != nil {
        return 0, fmt.Errorf("failed to get current version: %w", err)
    }

    if !version.Valid {
        return 0, nil // No migrations applied yet
    }

    return int(version.Int64), nil
}

@@ -138,29 +138,29 @@ func (m *MigrationManager) ApplyMigrations() error {
    if err != nil {
        return fmt.Errorf("failed to get current version: %w", err)
    }

    migrations := GetMigrations()

    for _, migration := range migrations {
        if migration.Version <= currentVersion {
            continue // Already applied
        }

        fmt.Printf("Applying migration %d: %s\n", migration.Version, migration.Description)

        // Begin transaction
        tx, err := m.db.Begin()
        if err != nil {
            return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
        }

        // Execute migration SQL
        _, err = tx.Exec(migration.SQL)
        if err != nil {
            tx.Rollback()
            return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err)
        }

        // Record migration as applied
        _, err = tx.Exec(
            "INSERT INTO schema_migrations (version, description, applied_at) VALUES (?, ?, ?)",
@@ -172,16 +172,16 @@ func (m *MigrationManager) ApplyMigrations() error {
            tx.Rollback()
            return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
        }

        // Commit transaction
        err = tx.Commit()
        if err != nil {
            return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
        }

        fmt.Printf("Successfully applied migration %d\n", migration.Version)
    }

    return nil
}
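ApplyMigrations leans on GetMigrations() returning an ordered list whose entries carry Version, Description, and SQL, each applied in its own transaction and recorded in schema_migrations. A sketch of what such a list plausibly looks like; the struct shape matches the fields used above, but the table contents are invented for illustration:

```go
// Hypothetical migration list; the real GetMigrations() defines its own schema.
type Migration struct {
    Version     int
    Description string
    SQL         string
}

func exampleMigrations() []Migration {
    return []Migration{
        {Version: 1, Description: "create offsets table",
            SQL: "CREATE TABLE IF NOT EXISTS offsets (grp TEXT, topic TEXT, partition INTEGER, off INTEGER)"},
        {Version: 2, Description: "index offsets by group",
            SQL: "CREATE INDEX IF NOT EXISTS idx_offsets_grp ON offsets (grp)"},
    }
}
```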
@@ -203,7 +203,7 @@ func (m *MigrationManager) GetAppliedMigrations() ([]AppliedMigration, error) {
        return nil, fmt.Errorf("failed to query applied migrations: %w", err)
    }
    defer rows.Close()

    var migrations []AppliedMigration
    for rows.Next() {
        var migration AppliedMigration
@@ -213,7 +213,7 @@ func (m *MigrationManager) GetAppliedMigrations() ([]AppliedMigration, error) {
        }
        migrations = append(migrations, migration)
    }

    return migrations, nil
}

@@ -223,17 +223,17 @@ func (m *MigrationManager) ValidateSchema() error {
    if err != nil {
        return fmt.Errorf("failed to get current version: %w", err)
    }

    migrations := GetMigrations()
    if len(migrations) == 0 {
        return nil
    }

    latestVersion := migrations[len(migrations)-1].Version
    if currentVersion < latestVersion {
        return fmt.Errorf("schema is outdated: current version %d, latest version %d", currentVersion, latestVersion)
    }

    return nil
}

@@ -253,21 +253,21 @@ func getCurrentTimestamp() int64 {
func CreateDatabase(dbPath string) (*sql.DB, error) {
    // TODO: Support different database types (PostgreSQL, MySQL, etc.)
    // ASSUMPTION: Using SQLite for now, can be extended for other databases

    db, err := sql.Open("sqlite3", dbPath)
    if err != nil {
        return nil, fmt.Errorf("failed to open database: %w", err)
    }

    // Configure SQLite for better performance
    pragmas := []string{
        "PRAGMA journal_mode=WAL",   // Write-Ahead Logging for better concurrency
        "PRAGMA synchronous=NORMAL", // Balance between safety and performance
        "PRAGMA cache_size=10000",   // Increase cache size
        "PRAGMA foreign_keys=ON",    // Enable foreign key constraints
        "PRAGMA temp_store=MEMORY",  // Store temporary tables in memory
    }

    for _, pragma := range pragmas {
        _, err := db.Exec(pragma)
        if err != nil {
@@ -275,7 +275,7 @@ func CreateDatabase(dbPath string) (*sql.DB, error) {
            return nil, fmt.Errorf("failed to set pragma %s: %w", pragma, err)
        }
    }

    // Apply migrations
    migrationManager := NewMigrationManager(db)
    err = migrationManager.ApplyMigrations()
@@ -283,7 +283,7 @@ func CreateDatabase(dbPath string) (*sql.DB, error) {
        db.Close()
        return nil, fmt.Errorf("failed to apply migrations: %w", err)
    }

    return db, nil
}
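Putting the pieces together, a caller gets back a migrated, pragma-tuned handle in one call. A usage sketch; the file path is hypothetical, and the blank driver import is an assumption implied by the "sqlite3" driver name in sql.Open (commonly github.com/mattn/go-sqlite3):

```go
import (
    "log"

    _ "github.com/mattn/go-sqlite3" // assumed driver registration for "sqlite3"
)

func openOffsetDB() {
    db, err := CreateDatabase("/tmp/offsets.db") // hypothetical path
    if err != nil {
        log.Fatalf("open database: %v", err)
    }
    defer db.Close()
    // db is ready: pragmas applied, schema migrated to the latest version.
}
```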
@@ -248,11 +248,11 @@ func TestValidateKeyColumns(t *testing.T) {

// Helper function to check if string contains substring
func contains(str, substr string) bool {
    return len(str) >= len(substr) &&
        (len(substr) == 0 || str[len(str)-len(substr):] == substr ||
            str[:len(substr)] == substr ||
            len(str) > len(substr) && (str[len(str)-len(substr)-1:len(str)-len(substr)] == " " || str[len(str)-len(substr)-1] == ' ') && str[len(str)-len(substr):] == substr ||
            findInString(str, substr))
}
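The hand-rolled `contains` above (plus its `findInString` fallback below) amounts to a substring check; the standard library covers the same cases in one call, assuming the test file can import `strings`:

```go
// Equivalent substring check using the standard library.
ok := strings.Contains(str, substr)
```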

func findInString(str, substr string) bool {