// Package handlers serves the home page and a WebSocket endpoint whose
// messages are fanned out to every connected client through a central hub.
package handlers

import (
	"bytes"
	"context"
	"fmt"
	"html/template"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
	"github.com/markbates/goth/gothic"
)

// Handler carries the parsed templates used by the HTTP handlers.
type Handler struct {
	Template template.Template
}

// Message is a payload description with an id, an action and a text body.
// It is not referenced by the code in this file.
type Message struct {
	Id      string
	Action  string
	Message string
}

// PageData is the data passed to the "index.html" template.
type PageData struct {
	Username string
}

// Client is one WebSocket connection attached to the hub.
type Client struct {
	conn   *websocket.Conn
	send   chan []byte // outbound messages; closed only by the hub goroutine
	hub    *hub
	ctx    context.Context // cancelled when readPump exits
	cancel context.CancelFunc
	id     uint64
	mu     sync.Mutex // Protect concurrent operations (not used in this file)
}

// hub tracks the connected clients and broadcasts messages to all of them.
type hub struct {
	clients    sync.Map // client id (uint64) -> *Client
	broadcast  chan []byte
	register   chan *Client
	unregister chan *Client
	// Message pool for reusing byte slices (not used in this file).
	messagePool sync.Pool
}

// Hub is the package-wide hub instance. Its Run method has to be started by
// the caller for clients to be registered and messages to be delivered.
var Hub = &hub{
	broadcast:  make(chan []byte, 1024),
	register:   make(chan *Client, 256),
	unregister: make(chan *Client, 256),
}

// NOTE(review): CheckOrigin accepts every origin, which leaves the endpoint
// open to cross-site WebSocket requests from any page. Confirm this is
// intended before exposing the handler publicly.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
	CheckOrigin:     func(r *http.Request) bool { return true },
}

// clientIDCounter is incremented atomically to hand out client ids.
var clientIDCounter uint64

const (
	// writeWait is the deadline applied to every write to a peer.
	writeWait = 10 * time.Second
	// pongWait is how long a read may block before the connection times out;
	// the deadline is pushed forward on every pong.
	pongWait = 60 * time.Second
	// pingPeriod is the interval between pings; it is shorter than pongWait
	// so a pong can arrive before the read deadline expires.
	pingPeriod = 54 * time.Second
	// maxMessageSize is the read limit, in bytes, set on each connection.
	maxMessageSize = 512
)

func init() {
	Hub.messagePool.New = func() interface{} {
		return make([]byte, 0, 1024)
	}
}

// Run is the hub event loop: it registers and unregisters clients and fans
// every broadcast message out to all registered clients. It never returns.
// All closes of Client.send happen here, so it is intended to run in a single
// goroutine.
func (h *hub) Run() {
	for {
		select {
		case client := <-h.register:
			// select picks among ready cases in no particular order, so the
			// unregister request of a short-lived client can be handled
			// before its register request. readPump cancels the context
			// before it asks to be unregistered, so a cancelled context here
			// means the client is already gone and must not be stored.
			if client.ctx != nil && client.ctx.Err() != nil {
				continue
			}
			h.clients.Store(client.id, client)
			fmt.Printf("Client connected. ID: %d\n", client.id)

		case client := <-h.unregister:
			// Close send only if the client was still registered; a slow
			// client may already have been removed (and closed) below.
			if _, loaded := h.clients.LoadAndDelete(client.id); loaded {
				close(client.send)
				fmt.Printf("Client disconnected. ID: %d\n", client.id)
			}

		case message := <-h.broadcast:
			// Single-threaded broadcasting to avoid race conditions.
			h.clients.Range(func(_, value interface{}) bool {
				client := value.(*Client)
				select {
				case client.send <- message:
					// Queued for this client's writePump.
				default:
					// The client's send buffer is full: drop the client
					// rather than block the whole hub.
					h.clients.Delete(client.id)
					close(client.send)
					fmt.Printf("Slow client removed. ID: %d\n", client.id)
				}
				return true
			})
		}
	}
}

// readPump reads messages from the connection and hands a copy of each one
// to the hub for broadcasting. It returns on the first read error or once the
// client's context is cancelled; on exit it cancels the context, closes the
// connection and asks the hub to unregister the client.
func (c *Client) readPump() {
	defer func() {
		// Cancel first so the hub can recognise a late register request for
		// this client as stale (see hub.Run).
		c.cancel()
		c.conn.Close()
		c.hub.unregister <- c
	}()

	c.conn.SetReadLimit(maxMessageSize)
	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		// Checked between reads only; ReadMessage itself blocks until data
		// arrives, the deadline passes or the connection is closed.
		if c.ctx.Err() != nil {
			return
		}

		_, messageBytes, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				fmt.Printf("WebSocket error for client %d: %v\n", c.id, err)
			}
			return
		}

		// Create a new message buffer instead of reusing from pool.
		message := make([]byte, len(messageBytes))
		copy(message, messageBytes)

		// Non-blocking broadcast: a full hub queue drops the message instead
		// of stalling this reader.
		select {
		case c.hub.broadcast <- message:
		default:
			fmt.Printf("Broadcast buffer full, dropping message\n")
		}
	}
}

// writePump writes queued messages to the connection as text frames and
// sends a ping every pingPeriod. It returns when the context is cancelled,
// when the hub closes the send channel, or on the first write error, and
// closes the connection on the way out.
func (c *Client) writePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.conn.Close()
	}()

	for {
		select {
		case <-c.ctx.Done():
			return

		case message, ok := <-c.send:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The hub closed the channel. The close frame is best
				// effort: the connection is torn down either way.
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
				fmt.Printf("Write error for client %d: %v\n", c.id, err)
				return
			}

		case <-ticker.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// Home renders "index.html" with the user name stored in the "user-session"
// session, or with an empty name when the session holds none. It replies
// with 500 when the session cannot be loaded or the template fails.
func (h *Handler) Home(w http.ResponseWriter, r *http.Request) {
	session, err := gothic.Store.Get(r, "user-session")
	if err != nil {
		http.Error(w, "Error retrieving session for welcome page", http.StatusInternalServerError)
		return
	}

	var pagedata PageData
	if username, ok := session.Values["user_name"].(string); ok {
		pagedata.Username = username
	}

	// Render into a buffer first: executing straight into w would already
	// have sent a 200 and part of the page by the time a template error
	// surfaces, making the 500 below impossible to deliver cleanly.
	var page bytes.Buffer
	if err := h.Template.ExecuteTemplate(&page, "index.html", &pagedata); err != nil {
		http.Error(w, "Template rendering error", http.StatusInternalServerError)
		return
	}
	if _, err := page.WriteTo(w); err != nil {
		fmt.Printf("Home: writing response: %v\n", err)
	}
}

// WsHandler upgrades the request to a WebSocket connection, registers a new
// client with Hub, starts its read and write pumps, and blocks until the
// client's context is cancelled.
func (h *Handler) WsHandler(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		fmt.Println("upgrade error:", err)
		return
	}

	ctx, cancel := context.WithCancel(context.Background())
	client := &Client{
		conn:   conn,
		send:   make(chan []byte, 256),
		hub:    Hub,
		ctx:    ctx,
		cancel: cancel,
		id:     atomic.AddUint64(&clientIDCounter, 1),
	}

	Hub.register <- client

	// Start pumps in separate goroutines.
	go client.writePump()
	go client.readPump()

	// Wait for completion.
	<-ctx.Done()
}