379 lines
9.3 KiB
Go
379 lines
9.3 KiB
Go
package notification
|
||
|
||
import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"

	"WiiCITMS/process/hr"

	"github.com/google/uuid"
)

// EventType names the kind of an SSE event. Its string value is written to the
// wire as the SSE "event:" field, so clients can subscribe by type.
type EventType string

// Event types pushed by this package.
const (
	PositionChangeEvent      EventType = "position_change"      // job position change
	ApprovalNoticeEvent      EventType = "approval_notice"      // approval notice
	MessageNotificationEvent EventType = "message_notification" // message reminder
	SystemNotificationEvent  EventType = "system_notification"  // system notification
	ScheduleReminderEvent    EventType = "schedule_reminder"    // schedule reminder
	ScheduleConflictEvent    EventType = "schedule_conflict"    // schedule conflict reminder
)

// Event 表示SSE事件
|
||
type Event struct {
|
||
ID string `json:"id"` // 事件ID
|
||
Type EventType `json:"type"` // 事件类型
|
||
Data interface{} `json:"data"` // 事件数据
|
||
Timestamp int64 `json:"timestamp"` // 事件时间戳
|
||
}
|
||
|
||
// Client 代表一个连接的客户端
|
||
type Client struct {
|
||
UserGuid string // 用户ID
|
||
StaffGuid string // 员工ID
|
||
Channel chan Event // 事件通道
|
||
Connected bool // 连接状态
|
||
LastPing time.Time // 最后一次心跳时间
|
||
}
|
||
|
||
// SSEServer 是SSE服务器的实现
|
||
type SSEServer struct {
|
||
clients map[string]*Client // 客户端映射表 (UserGuid -> Client)
|
||
staffMap map[string]string // 员工映射表 (StaffGuid -> UserGuid)
|
||
register chan *Client // 注册通道
|
||
unregister chan string // 注销通道
|
||
broadcast chan Event // 广播通道
|
||
targetedSend chan TargetedEvent // 定向发送通道
|
||
mutex sync.RWMutex // 读写锁
|
||
}
|
||
|
||
// TargetedEvent 表示针对特定用户的事件
|
||
type TargetedEvent struct {
|
||
TargetStaffGuids []string // 目标员工ID列表
|
||
Event Event // 事件内容
|
||
}
|
||
|
||
// NewSSEServer 创建一个新的SSE服务器
|
||
func NewSSEServer() *SSEServer {
|
||
return &SSEServer{
|
||
clients: make(map[string]*Client),
|
||
staffMap: make(map[string]string),
|
||
register: make(chan *Client),
|
||
unregister: make(chan string),
|
||
broadcast: make(chan Event),
|
||
targetedSend: make(chan TargetedEvent),
|
||
mutex: sync.RWMutex{},
|
||
}
|
||
}
|
||
|
||
// Start 启动SSE服务器
|
||
func (s *SSEServer) Start() {
|
||
go s.run()
|
||
}
|
||
|
||
// 服务器主循环
|
||
func (s *SSEServer) run() {
|
||
// 定期清理断开连接的客户端
|
||
ticker := time.NewTicker(30 * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case client := <-s.register:
|
||
s.registerClient(client)
|
||
case userGuid := <-s.unregister:
|
||
s.unregisterClient(userGuid)
|
||
case event := <-s.broadcast:
|
||
s.broadcastEvent(event)
|
||
case targetedEvent := <-s.targetedSend:
|
||
s.sendToTargets(targetedEvent)
|
||
case <-ticker.C:
|
||
s.cleanupClients()
|
||
}
|
||
}
|
||
}
|
||
|
||
// registerClient stores client under its UserGuid. Any client already
// registered under that UserGuid is replaced and its channel is closed. When
// the StaffGuid is known, the StaffGuid -> UserGuid mapping is recorded as well.
func (s *SSEServer) registerClient(client *Client) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	guid := client.UserGuid
	if previous, found := s.clients[guid]; found {
		// Closing the old channel makes the old connection's handler return.
		close(previous.Channel)
	}
	s.clients[guid] = client

	if staff := client.StaffGuid; staff != "" {
		s.staffMap[staff] = guid
	}

	fmt.Printf("Client registered: UserGuid=%s, StaffGuid=%s\n", client.UserGuid, client.StaffGuid)
}

// unregisterClient removes the client registered under userGuid. It first drops
// every staff mapping that points at that user, then closes the client's
// channel. Unknown IDs are ignored.
func (s *SSEServer) unregisterClient(userGuid string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	client, found := s.clients[userGuid]
	if !found {
		return
	}

	for staff, owner := range s.staffMap {
		if owner == userGuid {
			delete(s.staffMap, staff)
		}
	}

	close(client.Channel)
	delete(s.clients, userGuid)
	fmt.Printf("Client unregistered: UserGuid=%s\n", userGuid)
}

// 广播事件给所有客户端
|
||
func (s *SSEServer) broadcastEvent(event Event) {
|
||
s.mutex.RLock()
|
||
defer s.mutex.RUnlock()
|
||
|
||
for _, client := range s.clients {
|
||
if client.Connected {
|
||
select {
|
||
case client.Channel <- event:
|
||
// 事件已发送
|
||
default:
|
||
// 通道已满或已关闭,无法发送
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 发送事件给目标员工
|
||
func (s *SSEServer) sendToTargets(targetedEvent TargetedEvent) {
|
||
s.mutex.RLock()
|
||
defer s.mutex.RUnlock()
|
||
|
||
for _, staffGuid := range targetedEvent.TargetStaffGuids {
|
||
if userGuid, exists := s.staffMap[staffGuid]; exists {
|
||
if client, ok := s.clients[userGuid]; ok && client.Connected {
|
||
select {
|
||
case client.Channel <- targetedEvent.Event:
|
||
// 事件已发送
|
||
default:
|
||
// 通道已满或已关闭,无法发送
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 清理断开连接的客户端
|
||
func (s *SSEServer) cleanupClients() {
|
||
s.mutex.Lock()
|
||
defer s.mutex.Unlock()
|
||
|
||
now := time.Now()
|
||
timeout := 3 * time.Minute
|
||
|
||
for userGuid, client := range s.clients {
|
||
if !client.Connected || now.Sub(client.LastPing) > timeout {
|
||
// 关闭通道并移除客户端
|
||
close(client.Channel)
|
||
delete(s.clients, userGuid)
|
||
|
||
// 移除staffMap中的映射
|
||
for staffGuid, uGuid := range s.staffMap {
|
||
if uGuid == userGuid {
|
||
delete(s.staffMap, staffGuid)
|
||
}
|
||
}
|
||
|
||
fmt.Printf("Client timed out: UserGuid=%s\n", userGuid)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 根据StaffGuid获取Client
|
||
func (s *SSEServer) GetClientByStaffGuid(staffGuid string) *Client {
|
||
s.mutex.RLock()
|
||
defer s.mutex.RUnlock()
|
||
|
||
if userGuid, exists := s.staffMap[staffGuid]; exists {
|
||
return s.clients[userGuid]
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetClients 获取所有客户端
|
||
func (s *SSEServer) GetClients() map[string]*Client {
|
||
s.mutex.RLock()
|
||
defer s.mutex.RUnlock()
|
||
|
||
// 创建一个副本以避免并发问题
|
||
result := make(map[string]*Client)
|
||
for k, v := range s.clients {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// BroadcastEvent 广播事件给所有客户端
|
||
func (s *SSEServer) BroadcastEvent(eventType EventType, data interface{}) {
|
||
event := Event{
|
||
ID: uuid.New().String(),
|
||
Type: eventType,
|
||
Data: data,
|
||
Timestamp: time.Now().Unix(),
|
||
}
|
||
s.broadcast <- event
|
||
}
|
||
|
||
// SendEventToStaff builds a new event (fresh UUID, current Unix time) and hands
// it to the run loop for delivery to the given staff members. It blocks until
// the run loop accepts the request.
func (s *SSEServer) SendEventToStaff(staffGuids []string, eventType EventType, data interface{}) {
	s.targetedSend <- TargetedEvent{
		TargetStaffGuids: staffGuids,
		Event: Event{
			ID:        uuid.New().String(),
			Type:      eventType,
			Data:      data,
			Timestamp: time.Now().Unix(),
		},
	}
}

// HandleSSE serves a Server-Sent Events stream to the user identified by the
// "userGuid" query parameter.
//
// It responds 400 if userGuid is missing and 500 if the ResponseWriter cannot
// flush. It then proceeds as follows:
//  1. Resolve the user's StaffGuid via hr.GetStaffByUserGuid. If the lookup
//     fails, StaffGuid stays empty: the connection still receives broadcasts
//     but no staff-targeted events.
//  2. Register a Client with a 100-event buffer.
//  3. Write a welcome system notification.
//  4. Stream queued events, with an SSE comment heartbeat every 30 seconds.
//
// The handler returns in any of these cases:
//   - the request context is done;
//   - the server closes the client's channel (the client was replaced by a
//     newer connection for the same user, unregistered, or evicted);
//   - writing to the connection fails.
//
// NOTE(review): userGuid comes from the query string and is not authenticated
// in this handler, so any caller can subscribe as any user. Confirm that
// middleware enforces identity, and reconsider the wildcard
// Access-Control-Allow-Origin header.
func (s *SSEServer) HandleSSE(w http.ResponseWriter, r *http.Request) {
	userGuid := r.URL.Query().Get("userGuid")
	if userGuid == "" {
		http.Error(w, "Missing userGuid parameter", http.StatusBadRequest)
		return
	}

	// Check streaming support once, up front, before doing any lookups or
	// committing the SSE headers.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	// Resolve the StaffGuid used to route staff-targeted events.
	var staffGuid string
	staffDetail, proc := hr.GetStaffByUserGuid(userGuid)
	if !proc.IsError() && staffDetail != nil {
		staffGuid = staffDetail.StaffGuid
	}

	header := w.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	header.Set("Access-Control-Allow-Origin", "*")
	flusher.Flush()

	client := &Client{
		UserGuid:  userGuid,
		StaffGuid: staffGuid,
		Channel:   make(chan Event, 100), // up to 100 pending events
		Connected: true,
		LastPing:  time.Now(),
	}
	s.register <- client

	// On exit, unregister only if this connection is still the one registered
	// for userGuid. Unregistration is keyed by userGuid, so sending it
	// unconditionally after this client was replaced would close the newer
	// connection for the same user. (A narrow window remains between this check
	// and the run loop processing the request.)
	defer func() {
		s.mutex.RLock()
		current := s.clients[userGuid] == client
		s.mutex.RUnlock()
		if current {
			s.unregister <- userGuid
		}
	}()

	// cleanupClients reads LastPing while holding s.mutex. Writing it under the
	// same lock avoids a data race.
	touch := func() {
		s.mutex.Lock()
		client.LastPing = time.Now()
		s.mutex.Unlock()
	}

	// Marshalling a map[string]string cannot fail, so the error is ignored.
	welcomeJSON, _ := json.Marshal(map[string]string{"message": "connect success"})
	welcome := Event{
		ID:        uuid.New().String(),
		Type:      SystemNotificationEvent,
		Data:      string(welcomeJSON),
		Timestamp: time.Now().Unix(),
	}
	if err := writeSSEEvent(w, flusher, welcome); err != nil {
		return
	}

	// A single ticker replaces a per-iteration time.After timer.
	heartbeat := time.NewTicker(30 * time.Second)
	defer heartbeat.Stop()

	done := r.Context().Done()
	for {
		select {
		case <-done:
			return
		case event, open := <-client.Channel:
			if !open {
				// The server closed the channel: replaced, unregistered or evicted.
				return
			}
			touch()
			if err := writeSSEEvent(w, flusher, event); err != nil {
				return
			}
		case <-heartbeat.C:
			// An SSE comment line: EventSource ignores it, but it keeps the
			// idle connection alive.
			if _, err := fmt.Fprint(w, ": heartbeat\n\n"); err != nil {
				return
			}
			flusher.Flush()
			touch()
		}
	}
}

// writeSSEEvent writes event to w in text/event-stream framing, then flushes.
//
// Payload encoding depends on the Data type:
//   - string, []byte or json.RawMessage: sent as-is;
//   - anything else: JSON-encoded, so structured data reaches clients as
//     parseable JSON. If encoding fails, the fmt %v form is used instead.
//
// Each line of the payload goes on its own "data:" line. A bare newline inside
// a data field would otherwise end the field early and corrupt the event.
// The "timestamp:" field is not part of the SSE standard (EventSource ignores
// unknown fields); it is kept for compatibility.
func writeSSEEvent(w http.ResponseWriter, flusher http.Flusher, event Event) error {
	var payload string
	switch v := event.Data.(type) {
	case string:
		payload = v
	case json.RawMessage:
		payload = string(v)
	case []byte:
		payload = string(v)
	default:
		encoded, err := json.Marshal(v)
		if err != nil {
			payload = fmt.Sprint(v) // not JSON-encodable: fall back to %v form
			break
		}
		payload = string(encoded)
	}

	// SSE treats CRLF, CR and LF alike as line terminators.
	payload = strings.ReplaceAll(payload, "\r\n", "\n")
	payload = strings.ReplaceAll(payload, "\r", "\n")

	var b strings.Builder
	fmt.Fprintf(&b, "id: %s\n", event.ID)
	fmt.Fprintf(&b, "event: %s\n", event.Type)
	for _, line := range strings.Split(payload, "\n") {
		b.WriteString("data: ")
		b.WriteString(line)
		b.WriteByte('\n')
	}
	fmt.Fprintf(&b, "timestamp: %d\n\n", event.Timestamp)

	if _, err := fmt.Fprint(w, b.String()); err != nil {
		return err
	}
	flusher.Flush()
	return nil
}

// Global SSE server instance
|
||
var globalSSEServer *SSEServer
|
||
|
||
// GetSSEServer 获取全局SSE服务器实例
|
||
func GetSSEServer() *SSEServer {
|
||
if globalSSEServer == nil {
|
||
globalSSEServer = NewSSEServer()
|
||
globalSSEServer.Start()
|
||
}
|
||
return globalSSEServer
|
||
}
|
||
|
||
// SetupSSERoutes registers the global SSE server's HandleSSE at /api/v1/events.
//
// NOTE(review): the router argument is not used. The route is always added to
// http.DefaultServeMux, so it is reachable only if that mux is being served.
// http.ServeMux panics on a duplicate pattern, so call this only once.
func SetupSSERoutes(router http.Handler) {
	handler := GetSSEServer().HandleSSE
	http.HandleFunc("/api/v1/events", handler)
}