package subscriptions

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/llehouerou/go-graphql-client"
	"github.com/pkg/errors"
	"github.com/rs/zerolog/log"
)

// WSURL is the websocket endpoint that subscriptions connect to.
const WSURL = "wss://ws.sorare.com/cable"

// wsClient owns one websocket connection carrying a single GraphQL
// subscription; decoded payloads of type T are delivered on Data.
//
// Goroutines started by newWsClient:
//   - readMessages: the only reader of the connection; on return it closes
//     the connection, Done and rawMessageChan.
//   - processMessages: decodes frames from rawMessageChan and closes Data
//     once rawMessageChan is closed.
//   - writeMessages: the only writer of data frames, fed through writeChan.
//   - the goroutine launched by subscribe, which sends the subscription
//     commands through sendMessage.
type wsClient[T any] struct {
	cancel context.CancelFunc
	// ctxDone is closed when Stop is called or the parent context ends.
	ctxDone            <-chan struct{}
	c                  *websocket.Conn
	Done               chan struct{}
	rawMessageChan     chan []byte
	writeChan          chan []byte
	lastPing           time.Time
	pingMutex          *sync.Mutex
	Data               chan T
	subscriptionName   string
	subscriptionParams string
	channelId          string
	debug              bool
}

// newWsClient dials WSURL, starts the reader/processor/writer goroutines and
// kicks off the subscription named subscriptionName (subscriptionParams is
// appended verbatim to the field name in the generated query). Cancelling ctx
// or calling Stop tears the connection down and eventually closes Data.
func newWsClient[T any](
	ctx context.Context,
	subscriptionName string,
	subscriptionParams string,
	debug bool,
) (*wsClient[T], error) {
	localctx, cancel := context.WithCancel(ctx)
	w := &wsClient[T]{
		cancel:             cancel,
		ctxDone:            localctx.Done(),
		rawMessageChan:     make(chan []byte),
		writeChan:          make(chan []byte),
		lastPing:           time.Now(),
		pingMutex:          &sync.Mutex{},
		Data:               make(chan T),
		subscriptionName:   subscriptionName,
		subscriptionParams: subscriptionParams,
		Done:               make(chan struct{}),
		debug:              debug,
		channelId:          uuid.New().String()[:7],
	}

	if err := w.connect(); err != nil {
		// Release the context; nothing else has been started yet.
		cancel()
		return nil, errors.Wrap(err, "connecting websocket")
	}

	go w.processMessages()
	go w.readMessages(localctx)
	go w.writeMessages(localctx)

	if err := w.subscribe(); err != nil {
		// Tear down the goroutines and the connection started above.
		w.Stop()
		return nil, errors.Wrap(err, "subscribing")
	}
	return w, nil
}

// connect dials WSURL with a 30s handshake timeout and stores the connection.
// On a failed handshake the HTTP status is included in the error.
func (w *wsClient[T]) connect() error {
	dialer := &websocket.Dialer{
		HandshakeTimeout:  30 * time.Second,
		EnableCompression: true,
	}
	connection, resp, err := dialer.Dial(WSURL, nil)
	if err != nil {
		if resp != nil {
			return errors.Wrapf(err, "dialing websocket (HTTP %d)", resp.StatusCode)
		}
		return errors.Wrap(err, "dialing websocket")
	}
	w.c = connection
	return nil
}

// WsMessage is the outer JSON envelope of every frame received from the
// server: Type selects the handler and Message holds the undecoded payload.
type WsMessage struct {
	Type    string
	Message json.RawMessage
}

// WsError is one entry of the GraphQL "errors" array in a result payload.
type WsError struct {
	Message   string
	Locations []struct {
		Line   int
		Column int
	}
	Path       []string
	Extensions struct {
		Code         string
		TypeName     string `json:"typeName"`
		ArgumentName string `json:"argumentName"`
	}
}

// processMessage decodes a data frame (envelope Type "") and, when it carries
// a non-null value for the subscription field, delivers it on Data. GraphQL
// errors are returned with their messages joined. When the payload reports
// More == false after a delivery, the client is stopped. If the client is
// stopped while waiting for the consumer, the value is dropped instead of
// blocking forever.
func (w *wsClient[T]) processMessage(message json.RawMessage) error {
	var m struct {
		More   bool
		Result struct {
			Data   map[string]json.RawMessage
			Errors []WsError
		}
	}
	if err := json.Unmarshal(message, &m); err != nil {
		return errors.Wrap(err, "unmarshalling message")
	}

	if len(m.Result.Errors) > 0 {
		msgs := make([]string, 0, len(m.Result.Errors))
		for _, e := range m.Result.Errors {
			msgs = append(msgs, e.Message)
		}
		return errors.Errorf("graphql errors: %s", strings.Join(msgs, "; "))
	}

	// A missing/empty Data map also yields !ok here.
	data, ok := m.Result.Data[w.subscriptionName]
	if !ok || string(data) == "null" {
		return nil
	}

	var payload T
	if err := graphql.UnmarshalGraphQL(data, &payload); err != nil {
		return errors.Wrap(err, "unmarshalling graphql data")
	}

	select {
	case w.Data <- payload:
	case <-w.ctxDone:
		// Stopped while the consumer was not reading; drop the value.
		return nil
	}

	if !m.More {
		w.Stop()
	}
	return nil
}

// Stop cancels the client's context and closes the connection. The reader
// then returns, which closes Done, and processMessages closes Data once it
// drains. Calling Stop more than once is harmless; close errors are ignored.
func (w *wsClient[T]) Stop() {
	w.cancel()
	_ = w.c.Close()
}

// processMessages dispatches every raw frame from rawMessageChan by envelope
// type and closes Data when rawMessageChan is closed. Per-frame failures are
// logged and do not stop the loop; unknown frame types are ignored.
func (w *wsClient[T]) processMessages() {
	defer close(w.Data)
	for raw := range w.rawMessageChan {
		if w.debug {
			fmt.Println("<-- " + string(raw))
		}
		var m WsMessage
		if err := json.Unmarshal(raw, &m); err != nil {
			log.Error().Err(err).Msg("unmarshalling message")
			continue
		}
		switch m.Type {
		case "welcome":
			// Ignored: nothing to do on the server greeting.
		case "ping":
			if err := w.processPing(m.Message); err != nil {
				log.Error().Err(err).Msg("processing ping message")
			}
		case "":
			if err := w.processMessage(m.Message); err != nil {
				log.Error().Err(err).Msg("processing subscription message")
			}
		}
	}
}

// PingMessage is the payload of a "ping" frame: a Unix timestamp in seconds.
type PingMessage time.Time

// UnmarshalJSON decodes a JSON integer of Unix seconds into a PingMessage.
func (p *PingMessage) UnmarshalJSON(bytes []byte) error {
	var raw int64
	if err := json.Unmarshal(bytes, &raw); err != nil {
		return errors.Wrap(err, "unmarshalling ping message")
	}
	*p = PingMessage(time.Unix(raw, 0))
	return nil
}

// Time returns the ping timestamp as a time.Time.
func (p *PingMessage) Time() time.Time {
	return time.Time(*p)
}

// processPing records the timestamp of a "ping" frame in lastPing, guarded by
// pingMutex.
func (w *wsClient[T]) processPing(message json.RawMessage) error {
	var m PingMessage
	if err := json.Unmarshal(message, &m); err != nil {
		return errors.Wrap(err, "unmarshalling ping message")
	}
	w.pingMutex.Lock()
	w.lastPing = m.Time()
	w.pingMutex.Unlock()
	return nil
}

// writeMessages is the only writer of data frames: it sends each payload from
// writeChan as a text frame. It returns when ctx is cancelled or the reader
// has exited (Done closed), so it never outlives the connection.
func (w *wsClient[T]) writeMessages(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-w.Done:
			return
		case d := <-w.writeChan:
			if w.debug {
				fmt.Println("--> " + string(d))
			}
			if err := w.c.WriteMessage(websocket.TextMessage, d); err != nil {
				log.Error().Err(err).Msg("writing websocket message")
			}
		}
	}
}

// readMessages is the connection's only reader. It forwards text frames to
// rawMessageChan and skips all other frame types. It returns when a read
// fails (peer closed, network error, or local close) or ctx is cancelled. On
// return it closes the connection, then Done, then rawMessageChan.
func (w *wsClient[T]) readMessages(ctx context.Context) {
	defer close(w.rawMessageChan)
	defer close(w.Done)
	defer func() { _ = w.c.Close() }()

	// gorilla's default control handlers are kept on purpose: they answer
	// pings with pongs and echo close frames (RFC 6455 §5.5). They write via
	// WriteControl, which is safe alongside writeMessages.

	// ReadMessage does not observe ctx, so closing the connection on
	// cancellation is what unblocks it. The watcher also exits as soon as
	// this reader returns (Done closed).
	go func() {
		select {
		case <-ctx.Done():
			_ = w.c.Close()
		case <-w.Done:
		}
	}()

	for ctx.Err() == nil {
		msgType, payload, err := w.c.ReadMessage()
		if err != nil {
			if ctx.Err() == nil {
				log.Debug().Err(err).Msg("reading websocket message")
			}
			return
		}
		if msgType != websocket.TextMessage {
			continue
		}
		select {
		case w.rawMessageChan <- payload:
		case <-ctx.Done():
			return
		}
	}
}

// identifier is JSON-encoded into the "identifier" string of every command.
type identifier struct {
	Channel   string `json:"channel"`
	ChannelId string `json:"channelId"`
}

// message is a client→server command frame.
type message struct {
	Command    string `json:"command"`
	Identifier string `json:"identifier"`
	Data       string `json:"data,omitempty"`
}

// queryData is JSON-encoded into the "data" string of the execute command.
type queryData struct {
	Action        string   `json:"action"`
	Query         string   `json:"query"`
	Variables     []string `json:"variables,omitempty"`
	OperationName string   `json:"operationName,omitempty"`
}

// sendMessage marshals message and hands it to writeMessages. It returns an
// error instead of blocking forever once the client is stopped or the
// connection is gone.
func (w *wsClient[T]) sendMessage(message any) error {
	data, err := json.Marshal(message)
	if err != nil {
		return errors.Wrap(err, "marshalling message")
	}
	select {
	case w.writeChan <- data:
		return nil
	case <-w.Done:
		return errors.New("websocket closed")
	case <-w.ctxDone:
		return errors.New("websocket client stopped")
	}
}

// subscribe asynchronously sends the "subscribe" command for GraphqlChannel,
// waits 5s (abandoned early if the client stops), then sends the "execute"
// action carrying the subscription query built from T. Failures are logged.
// It currently always returns nil.
func (w *wsClient[T]) subscribe() error {
	go func() {
		ident, err := json.Marshal(identifier{
			Channel:   "GraphqlChannel",
			ChannelId: w.channelId,
		})
		if err != nil {
			log.Error().Err(err).Msg("marshalling channel identifier")
			return
		}

		err = w.sendMessage(message{
			Command:    "subscribe",
			Identifier: string(ident),
		})
		if err != nil {
			log.Error().Err(err).Msg("sending subscribe command")
			return
		}

		// Presumably gives the server time to confirm the channel
		// subscription before the query is executed — TODO confirm; waiting
		// for an explicit confirmation frame would be more robust.
		timer := time.NewTimer(5 * time.Second)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-w.Done:
			return
		case <-w.ctxDone:
			return
		}

		var query struct {
			SubscriptionName T
		}
		queryMarshalled, err := graphql.ConstructSubscription(query, nil)
		if err != nil {
			log.Error().Err(err).Msg("constructing subscription query")
			return
		}
		// Swap the placeholder field name for the real subscription field
		// plus its parameters.
		queryMarshalled = strings.Replace(
			queryMarshalled,
			"subscriptionName",
			w.subscriptionName+w.subscriptionParams,
			1,
		)

		querystr, err := json.Marshal(queryData{
			Action: "execute",
			Query:  queryMarshalled,
		})
		if err != nil {
			log.Error().Err(err).Msg("marshalling subscription query")
			return
		}

		err = w.sendMessage(message{
			Command:    "message",
			Identifier: string(ident),
			Data:       string(querystr),
		})
		if err != nil {
			log.Error().Err(err).Msg("sending subscription query")
		}
	}()
	return nil
}