diff --git a/internal/server/clients.go b/internal/server/clients.go index aa830ad..058c45d 100644 --- a/internal/server/clients.go +++ b/internal/server/clients.go @@ -9,14 +9,14 @@ import ( // ClientRegistry manages connected WebSocket clients thread-safely type ClientRegistry struct { - clients map[*websocket.Conn]bool + clients map[*websocket.Conn]*sync.Mutex mu sync.RWMutex } // NewClientRegistry creates a new client registry func NewClientRegistry() *ClientRegistry { return &ClientRegistry{ - clients: make(map[*websocket.Conn]bool), + clients: make(map[*websocket.Conn]*sync.Mutex), } } @@ -24,7 +24,7 @@ func NewClientRegistry() *ClientRegistry { func (r *ClientRegistry) Add(conn *websocket.Conn) { r.mu.Lock() defer r.mu.Unlock() - r.clients[conn] = true + r.clients[conn] = &sync.Mutex{} } // Remove unregisters a client connection @@ -43,6 +43,14 @@ func (r *ClientRegistry) Count() int { // Contains checks if a client is registered func (r *ClientRegistry) Contains(conn *websocket.Conn) bool { + r.mu.RLock() + defer r.mu.RUnlock() + _, ok := r.clients[conn] + return ok +} + +// GetMutex returns the mutex for a given connection, if it exists +func (r *ClientRegistry) GetMutex(conn *websocket.Conn) *sync.Mutex { r.mu.RLock() defer r.mu.RUnlock() return r.clients[conn] @@ -51,9 +59,14 @@ func (r *ClientRegistry) Contains(conn *websocket.Conn) bool { // ForEach executes a function for each connected client func (r *ClientRegistry) ForEach(fn func(*websocket.Conn)) { r.mu.RLock() - defer r.mu.RUnlock() - + // Copy keys to avoid holding lock during iteration + conns := make([]*websocket.Conn, 0, len(r.clients)) for conn := range r.clients { + conns = append(conns, conn) + } + r.mu.RUnlock() + + for _, conn := range conns { fn(conn) } } @@ -61,9 +74,13 @@ func (r *ClientRegistry) ForEach(fn func(*websocket.Conn)) { // Broadcast sends a message to all connected clients func (r *ClientRegistry) Broadcast(fn func(*websocket.Conn) error) { r.mu.RLock() - defer r.mu.RUnlock() - + conns := make([]*websocket.Conn, 0, len(r.clients)) for conn := range r.clients { + conns = append(conns, conn) + } + r.mu.RUnlock() + + for _, conn := range conns { _ = fn(conn) } } diff --git a/internal/server/server.go b/internal/server/server.go index 84ab1c9..6c6191d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -324,6 +324,15 @@ func (s *Server) NotifyClient(conn *websocket.Conn, response Response) error { return nil } + mu := s.clients.GetMutex(conn) + if mu == nil { + // Client disconnected or not registered + return nil + } + + mu.Lock() + defer mu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/internal/worker/processor.go b/internal/worker/processor.go index 44f13ad..09d5d81 100644 --- a/internal/worker/processor.go +++ b/internal/worker/processor.go @@ -118,10 +118,10 @@ func (w *Worker) processJob(job *server.PrintJob) { err := w.executePrint(job) duration := time.Since(startTime) - w.lastJobTime = time.Now() // Update statistics w.mu.Lock() + w.lastJobTime = time.Now() if err != nil { w.jobsFailed++ } else { @@ -154,11 +154,13 @@ func (w *Worker) processJob(job *server.PrintJob) { } } - // Notify client + // Notify client (async to not block worker loop) if job.ClientConn != nil && w.notifier != nil { - if err := w.notifier.NotifyClient(job.ClientConn, response); err != nil { - log.Printf("[WORKER] ⚠️ Failed to notify client for job %s: %v", job.ID, err) - } + go func() { + if err := w.notifier.NotifyClient(job.ClientConn, response); err != nil { + log.Printf("[WORKER] ⚠️ Failed to notify client for job %s: %v", job.ID, err) + } + }() } } diff --git a/internal/worker/processor_test.go b/internal/worker/processor_test.go new file mode 100644 index 0000000..90894c4 --- /dev/null +++ b/internal/worker/processor_test.go @@ -0,0 +1,76 @@ +package worker + +import ( + "encoding/json" + "testing" + "time" + + "github.com/adcondev/ticket-daemon/internal/server" + "github.com/coder/websocket" +) + +// mockSlowNotifier simulates a slow network connection +type mockSlowNotifier struct { + delay time.Duration +} + +func (m *mockSlowNotifier) NotifyClient(_ *websocket.Conn, _ server.Response) error { + time.Sleep(m.delay) + return nil +} + +func TestWorkerBlockingNotification(t *testing.T) { + // Setup + jobCount := 5 + notifier := &mockSlowNotifier{delay: 200 * time.Millisecond} // 200ms delay per notification + + // Create job queue + jobQueue := make(chan *server.PrintJob, jobCount) + + // Create worker + config := Config{DefaultPrinter: "test"} + w := NewWorker(jobQueue, notifier, config) + + // Start worker + w.Start() + defer w.Stop() + + // Create dummy connection (we need a non-nil pointer) + dummyConn := &websocket.Conn{} + + // Prepare jobs + for j := 0; j < jobCount; j++ { + job := &server.PrintJob{ + ID: "test-job", + ClientConn: dummyConn, + Document: json.RawMessage("{}"), // Invalid document to trigger failure but still notify + ReceivedAt: time.Now(), + } + jobQueue <- job + } + + start := time.Now() + + // Wait for processing + deadline := time.Now().Add(5 * time.Second) + for { + stats := w.Stats() + if stats.JobsProcessed+stats.JobsFailed >= int64(jobCount) { + break + } + if time.Now().After(deadline) { + t.Fatalf("Timeout waiting for jobs to process. Processed: %d, Failed: %d", stats.JobsProcessed, stats.JobsFailed) + } + time.Sleep(10 * time.Millisecond) + } + + duration := time.Since(start) + + // With blocking notification: 5 jobs * 200ms = 1000ms (1s) + // We expect it to take at least 1s. + if duration > 500*time.Millisecond { + t.Errorf("Expected duration < 500ms (async), got %v", duration) + } else { + t.Logf("Blocking verified: Duration %v for %d jobs", duration, jobCount) + } +}