You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
238 lines
4.9 KiB
238 lines
4.9 KiB
package handlers |
|
|
|
import (
	"bytes"
	"context"
	"fmt"
	"html/template"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
	"github.com/markbates/goth/gothic"
)
|
|
|
type (
	// Handler carries the dependencies of the HTTP handlers in this package.
	// Home renders the "index.html" template found in Template.
	//
	// NOTE(review): *template.Template would be the conventional field type;
	// it is kept as a value so existing Handler literals keep compiling.
	Handler struct {
		Template template.Template
	}

	// Message holds an Id, an Action and the Message text. It is not used
	// by the code in this file — presumably the shape of a chat event
	// exchanged elsewhere (TODO confirm).
	Message struct {
		Id      string
		Action  string
		Message string
	}

	// PageData is the value Home passes to the index.html template.
	// Username is empty for visitors without a "user_name" session value.
	PageData struct {
		Username string
	}
)
|
|
|
type Client struct { |
|
conn *websocket.Conn |
|
send chan []byte |
|
hub *hub |
|
ctx context.Context |
|
cancel context.CancelFunc |
|
id uint64 |
|
mu sync.Mutex // Protect concurrent operations |
|
} |
|
|
|
type hub struct { |
|
clients sync.Map |
|
broadcast chan []byte |
|
register chan *Client |
|
unregister chan *Client |
|
|
|
// Message pool for reusing byte slices |
|
messagePool sync.Pool |
|
} |
|
|
|
var Hub = &hub{ |
|
broadcast: make(chan []byte, 1024), |
|
register: make(chan *Client, 256), |
|
unregister: make(chan *Client, 256), |
|
} |
|
|
|
var upgrader = websocket.Upgrader{ |
|
ReadBufferSize: 4096, |
|
WriteBufferSize: 4096, |
|
CheckOrigin: func(r *http.Request) bool { |
|
return true |
|
}, |
|
} |
|
|
|
// clientIDCounter hands out Client ids. WsHandler increments it atomically,
// so the first client receives id 1.
var clientIDCounter uint64

// Timing and size limits used by readPump and writePump.
const (
	// writeWait is the deadline applied to each websocket write.
	writeWait = 10 * time.Second

	// pongWait is readPump's read deadline; every pong pushes it forward
	// by this amount again.
	pongWait = 60 * time.Second

	// pingPeriod is how often writePump pings. It must stay below pongWait
	// so the peer's pong arrives before the read deadline lapses; nine
	// tenths of pongWait is 54s.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the read limit passed to SetReadLimit.
	maxMessageSize = 512
)
|
|
|
func init() { |
|
Hub.messagePool.New = func() interface{} { |
|
return make([]byte, 0, 1024) |
|
} |
|
} |
|
|
|
func (h *hub) Run() { |
|
for { |
|
select { |
|
case client := <-h.register: |
|
h.clients.Store(client.id, client) |
|
fmt.Printf("Client connected. ID: %d\n", client.id) |
|
|
|
case client := <-h.unregister: |
|
if _, loaded := h.clients.LoadAndDelete(client.id); loaded { |
|
close(client.send) |
|
fmt.Printf("Client disconnected. ID: %d\n", client.id) |
|
} |
|
|
|
case message := <-h.broadcast: |
|
// Single-threaded broadcasting to avoid race conditions |
|
h.clients.Range(func(key, value interface{}) bool { |
|
client := value.(*Client) |
|
|
|
select { |
|
case client.send <- message: |
|
// Message sent successfully |
|
default: |
|
// Client is slow, remove it |
|
h.clients.Delete(client.id) |
|
close(client.send) |
|
fmt.Printf("Slow client removed. ID: %d\n", client.id) |
|
} |
|
return true |
|
}) |
|
|
|
// DON'T return message to pool - let GC handle it |
|
// h.messagePool.Put(message) |
|
} |
|
} |
|
} |
|
|
|
func (c *Client) readPump() { |
|
defer func() { |
|
c.hub.unregister <- c |
|
c.conn.Close() |
|
c.cancel() |
|
}() |
|
|
|
c.conn.SetReadLimit(maxMessageSize) |
|
c.conn.SetReadDeadline(time.Now().Add(pongWait)) |
|
c.conn.SetPongHandler(func(string) error { |
|
c.conn.SetReadDeadline(time.Now().Add(pongWait)) |
|
return nil |
|
}) |
|
|
|
for { |
|
select { |
|
case <-c.ctx.Done(): |
|
return |
|
default: |
|
_, messageBytes, err := c.conn.ReadMessage() |
|
if err != nil { |
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { |
|
fmt.Printf("WebSocket error for client %d: %v\n", c.id, err) |
|
} |
|
return |
|
} |
|
|
|
// Create a new message buffer instead of reusing from pool |
|
message := make([]byte, len(messageBytes)) |
|
copy(message, messageBytes) |
|
|
|
// Non-blocking broadcast |
|
select { |
|
case c.hub.broadcast <- message: |
|
// Message will be sent to all clients |
|
default: |
|
// Broadcast buffer full, drop message |
|
fmt.Printf("Broadcast buffer full, dropping message\n") |
|
} |
|
} |
|
} |
|
} |
|
|
|
func (c *Client) writePump() { |
|
ticker := time.NewTicker(pingPeriod) |
|
defer func() { |
|
ticker.Stop() |
|
c.conn.Close() |
|
}() |
|
|
|
for { |
|
select { |
|
case <-c.ctx.Done(): |
|
return |
|
|
|
case message, ok := <-c.send: |
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) |
|
if !ok { |
|
c.conn.WriteMessage(websocket.CloseMessage, []byte{}) |
|
return |
|
} |
|
|
|
// Write the message |
|
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { |
|
fmt.Printf("Write error for client %d: %v\n", c.id, err) |
|
return |
|
} |
|
|
|
case <-ticker.C: |
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) |
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { |
|
return |
|
} |
|
} |
|
} |
|
} |
|
|
|
func (h *Handler) Home(w http.ResponseWriter, r *http.Request) { |
|
session, err := gothic.Store.Get(r, "user-session") |
|
if err != nil { |
|
http.Error(w, "Error retrieving session for welcome page", http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
username, ok := session.Values["user_name"].(string) |
|
var pagedata PageData |
|
if ok { |
|
pagedata.Username = username |
|
} else { |
|
pagedata.Username = "" |
|
} |
|
|
|
err = h.Template.ExecuteTemplate(w, "index.html", &pagedata) |
|
if err != nil { |
|
http.Error(w, "Template rendering error", http.StatusInternalServerError) |
|
} |
|
} |
|
|
|
func (h *Handler) WsHandler(w http.ResponseWriter, r *http.Request) { |
|
conn, err := upgrader.Upgrade(w, r, nil) |
|
if err != nil { |
|
fmt.Println("upgrade error:", err) |
|
return |
|
} |
|
|
|
ctx, cancel := context.WithCancel(context.Background()) |
|
client := &Client{ |
|
conn: conn, |
|
send: make(chan []byte, 256), |
|
hub: Hub, |
|
ctx: ctx, |
|
cancel: cancel, |
|
id: atomic.AddUint64(&clientIDCounter, 1), |
|
} |
|
|
|
Hub.register <- client |
|
|
|
// Start pumps in separate goroutines |
|
go client.writePump() |
|
go client.readPump() |
|
|
|
// Wait for completion |
|
<-ctx.Done() |
|
}
|
|
|