sorare/subscriptions/wsclient.go

351 lines
6.5 KiB
Go

package subscriptions
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/llehouerou/go-graphql-client"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
// WSURL is the websocket endpoint dialed by connect.
const WSURL = "wss://ws.sorare.com/cable"
// wsClient holds the state of one websocket connection carrying a single
// subscription whose payloads are decoded into T.
type wsClient[T any] struct {
	// cancel cancels the context handed to the read and write goroutines.
	cancel context.CancelFunc
	// c is the underlying websocket connection, set by connect.
	c *websocket.Conn
	// Done is closed when readMessages returns.
	Done chan struct{}
	// rawMessageChan carries text frames from readMessages to processMessages.
	rawMessageChan chan []byte
	// writeChan carries marshalled messages from sendMessage to writeMessages.
	writeChan chan []byte
	// lastPing starts at construction time and is overwritten by processPing;
	// processPing holds pingMutex while writing it.
	lastPing  time.Time
	pingMutex *sync.Mutex
	// Data delivers decoded subscription payloads and is closed when
	// processMessages returns.
	Data chan T
	// subscriptionName is the key looked up in each result's data and the
	// name substituted into the generated query.
	subscriptionName string
	// subscriptionParams is appended to subscriptionName in the generated query.
	subscriptionParams string
	// channelId is sent as the channelId of the subscription identifier.
	channelId string
	// debug enables printing of every frame received and sent.
	debug bool
}
// newWsClient dials the websocket endpoint, starts the processing, reading and
// writing goroutines, and triggers the subscription for subscriptionName
// (with subscriptionParams appended to it in the query).
//
// ctx bounds the lifetime of the read and write goroutines. When debug is
// true every frame sent or received is printed.
//
// On failure nothing is left behind: the derived context is cancelled and,
// once the connection exists, it is closed so the background goroutines
// terminate instead of running for a client the caller never receives.
func newWsClient[T any](
	ctx context.Context,
	subscriptionName string,
	subscriptionParams string,
	debug bool,
) (*wsClient[T], error) {
	localctx, cancel := context.WithCancel(ctx)
	// Only the first 7 characters of the UUID string are kept as channel id.
	channelId := uuid.New().String()[0:7]
	w := &wsClient[T]{
		cancel:             cancel,
		rawMessageChan:     make(chan []byte),
		writeChan:          make(chan []byte),
		lastPing:           time.Now(),
		pingMutex:          &sync.Mutex{},
		Data:               make(chan T),
		subscriptionName:   subscriptionName,
		subscriptionParams: subscriptionParams,
		Done:               make(chan struct{}),
		debug:              debug,
		channelId:          channelId,
	}
	if err := w.connect(); err != nil {
		// No goroutine has been started yet; just release the context.
		cancel()
		return nil, errors.Wrap(err, "connecting websocket")
	}
	go w.processMessages()
	go w.readMessages(localctx)
	go w.writeMessages(localctx)
	if err := w.subscribe(); err != nil {
		// Stop cancels the context and closes the connection, which lets
		// the three goroutines started above wind down.
		w.Stop()
		return nil, errors.Wrap(err, "subscribing")
	}
	return w, nil
}
// connect dials WSURL with a 30 second handshake timeout and compression
// enabled, and stores the resulting connection on the client.
func (w *wsClient[T]) connect() error {
	d := websocket.Dialer{
		HandshakeTimeout:  30 * time.Second,
		EnableCompression: true,
	}
	// The handshake response is not needed.
	conn, _, dialErr := d.Dial(WSURL, nil)
	if dialErr != nil {
		return errors.Wrap(dialErr, "dialing websocket")
	}
	w.c = conn
	return nil
}
// WsMessage is the envelope decoded from every text frame: Type selects the
// handler in processMessages and Message is kept raw for that handler.
type WsMessage struct {
	Type    string
	Message json.RawMessage
}
// WsError is one entry of the Errors list decoded from a subscription result.
type WsError struct {
	// Message is the error text.
	Message string
	// Locations lists line/column pairs attached to the error.
	Locations []struct {
		Line, Column int
	}
	// Path is the list of path segments attached to the error.
	Path []string
	// Extensions carries the extra error details.
	Extensions struct {
		Code         string
		TypeName     string `json:"typeName"`
		ArgumentName string `json:"argumentName"`
	}
}
// processMessage decodes one subscription payload and, when it carries data
// for this client's subscription, publishes the decoded value on w.Data.
//
// It returns an error when the payload is not valid JSON, when the result
// lists GraphQL errors (their messages are included in the returned error),
// or when the data cannot be decoded into T. Payloads without an entry for
// the subscription name, or with a null entry, are ignored.
//
// After a value has been published, a payload whose More flag is false stops
// the client.
// NOTE(review): the More flag is only examined after data was published; a
// payload with More=false but no data (or with errors) does not stop the
// client. Confirm that this is intended.
func (w *wsClient[T]) processMessage(message json.RawMessage) error {
	var m struct {
		More   bool
		Result struct {
			Data   map[string]json.RawMessage
			Errors []WsError
		}
	}
	if err := json.Unmarshal(message, &m); err != nil {
		return errors.Wrap(err, "unmarshalling message")
	}
	if len(m.Result.Errors) > 0 {
		// Surface what the server reported instead of a bare marker.
		msgs := make([]string, 0, len(m.Result.Errors))
		for _, e := range m.Result.Errors {
			msgs = append(msgs, e.Message)
		}
		return errors.Errorf("graphql errors: %s", strings.Join(msgs, "; "))
	}
	// A lookup in a nil or empty map reports !ok, so no separate length
	// check is needed.
	data, ok := m.Result.Data[w.subscriptionName]
	if !ok || string(data) == "null" {
		return nil
	}
	var unmarshalledData T
	if err := graphql.UnmarshalGraphQL(data, &unmarshalledData); err != nil {
		return errors.Wrap(err, "unmarshalling graphql data")
	}
	w.Data <- unmarshalledData
	if !m.More {
		w.Stop()
	}
	return nil
}
// Stop cancels the context shared by the read and write goroutines and then
// closes the websocket connection.
func (w *wsClient[T]) Stop() {
	w.cancel()
	// The error from Close is discarded on purpose: there is nothing a
	// caller could do with it during shutdown.
	conn := w.c
	_ = conn.Close()
}
// processMessages consumes raw frames from rawMessageChan until that channel
// is closed, dispatching each one on its envelope type, and closes w.Data on
// return. Failures are logged and the loop moves on to the next frame.
func (w *wsClient[T]) processMessages() {
	defer close(w.Data)
	for raw := range w.rawMessageChan {
		if w.debug {
			fmt.Println("<-- " + string(raw))
		}
		var m WsMessage
		if err := json.Unmarshal(raw, &m); err != nil {
			log.Error().Err(err).Msg("unmarshalling message")
			continue
		}
		switch m.Type {
		case "welcome":
			// Nothing to do.
		case "ping":
			if err := w.processPing(m.Message); err != nil {
				log.Error().Err(err).Msg("processing ping")
			}
		case "":
			// Frames without a type carry subscription payloads.
			if err := w.processMessage(m.Message); err != nil {
				log.Error().Err(err).Msg("processing message")
			}
		}
		// Any other type is ignored.
	}
}
// PingMessage is a time.Time that is unmarshalled from a JSON integer holding
// Unix seconds.
type PingMessage time.Time
// UnmarshalJSON decodes a JSON integer, interprets it as Unix seconds and
// stores the corresponding time in p.
func (p *PingMessage) UnmarshalJSON(bytes []byte) error {
	var seconds int64
	if err := json.Unmarshal(bytes, &seconds); err != nil {
		return errors.Wrap(err, "unmarshalling ping message")
	}
	stamp := time.Unix(seconds, 0)
	*p = PingMessage(stamp)
	return nil
}
// Time returns the ping timestamp converted back to a plain time.Time.
func (p *PingMessage) Time() time.Time {
	stamp := *p
	return time.Time(stamp)
}
// processPing decodes the timestamp carried by a ping frame and records it as
// the client's lastPing while holding pingMutex.
func (w *wsClient[T]) processPing(message json.RawMessage) error {
	var ping PingMessage
	if err := json.Unmarshal(message, &ping); err != nil {
		return errors.Wrap(err, "unmarshalling ping message")
	}

	w.pingMutex.Lock()
	defer w.pingMutex.Unlock()
	w.lastPing = ping.Time()
	return nil
}
// writeMessages writes every payload received on writeChan to the websocket
// as a text frame until ctx is cancelled.
//
// A failed write is logged and the loop keeps going, so one failure does not
// stop later messages from being attempted.
func (w *wsClient[T]) writeMessages(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case d := <-w.writeChan:
			if w.debug {
				fmt.Println("--> " + string(d))
			}
			if err := w.c.WriteMessage(websocket.TextMessage, d); err != nil {
				log.Error().Err(err).Msg("writing websocket message")
			}
		}
	}
}
// readMessages reads frames from the websocket and forwards text frames to
// rawMessageChan until ctx is cancelled or a read fails.
//
// On return it closes the connection, then Done, then rawMessageChan. Closing
// the connection here means a cancelled context or a read error does not
// leave the socket open when nobody calls Stop.
func (w *wsClient[T]) readMessages(ctx context.Context) {
	defer close(w.rawMessageChan)
	defer close(w.Done)
	// Capture the connection once so the deferred close targets the
	// connection this loop has been reading from.
	conn := w.c
	defer func() { _ = conn.Close() }()

	// Control-frame handlers are replaced with no-ops.
	// NOTE(review): presumably so the read goroutine never writes control
	// frames while writeMessages is writing; confirm.
	conn.SetCloseHandler(func(code int, text string) error {
		return nil
	})
	conn.SetPongHandler(func(pong string) error {
		return nil
	})
	conn.SetPingHandler(func(ping string) error {
		return nil
	})

	for {
		if ctx.Err() != nil {
			return
		}
		t, msg, err := conn.ReadMessage()
		if err != nil {
			return
		}
		if t != websocket.TextMessage {
			// Binary and any other frame types are ignored.
			continue
		}
		// Give up on cancellation rather than block forever on a consumer
		// that has stopped draining the channel.
		select {
		case w.rawMessageChan <- msg:
		case <-ctx.Done():
			return
		}
	}
}
// identifier names the channel of a subscription. subscribe marshals it to
// JSON and embeds the result as a string in message.Identifier.
type identifier struct {
	Channel   string `json:"channel"`
	ChannelId string `json:"channelId"`
}
// message is the command envelope written to the websocket. Identifier and
// Data hold JSON documents encoded as strings; Data is omitted when empty.
type message struct {
	Command    string `json:"command"`
	Identifier string `json:"identifier"`
	Data       string `json:"data,omitempty"`
}
// queryData is the payload marshalled into message.Data to run a query.
// Variables and OperationName are omitted from the JSON when empty.
type queryData struct {
	Action        string   `json:"action"`
	Query         string   `json:"query"`
	Variables     []string `json:"variables,omitempty"`
	OperationName string   `json:"operationName,omitempty"`
}
// sendMessage marshals message to JSON and hands it to the writer goroutine
// through writeChan.
//
// It returns an error when marshalling fails, or when Done is closed before
// the writer accepts the payload. The second case keeps a sender from
// blocking forever once the client has shut down.
func (w *wsClient[T]) sendMessage(message any) error {
	data, err := json.Marshal(message)
	if err != nil {
		return errors.Wrap(err, "marshalling message")
	}
	select {
	case w.writeChan <- data:
		return nil
	case <-w.Done:
		return errors.New("sending message: websocket closed")
	}
}
// subscribe starts a goroutine that sends the "subscribe" command for the
// client's channel, waits five seconds, and then sends the GraphQL
// subscription query built from T. It always returns nil; failures inside
// the goroutine are logged.
//
// The wait is abandoned as soon as Done is closed, so the goroutine does not
// outlive a client that shuts down in the meantime.
// NOTE(review): the five second pause presumably gives the server time to
// confirm the channel subscription before the query is sent; confirm.
func (w *wsClient[T]) subscribe() error {
	go func() {
		ident, err := json.Marshal(identifier{
			Channel:   "GraphqlChannel",
			ChannelId: w.channelId,
		})
		if err != nil {
			log.Error().Err(err).Msg("marshalling subscription identifier")
			return
		}
		err = w.sendMessage(message{
			Command:    "subscribe",
			Identifier: string(ident),
		})
		if err != nil {
			log.Error().Err(err).Msg("sending subscribe command")
			return
		}

		timer := time.NewTimer(5 * time.Second)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-w.Done:
			return
		}

		// The struct's single field is only a placeholder: its generated
		// name is swapped for the real subscription name and parameters.
		var query struct {
			SubscriptionName T
		}
		queryMarshalled, err := graphql.ConstructSubscription(query, nil)
		if err != nil {
			log.Error().Err(err).Msg("constructing subscription query")
			return
		}
		queryMarshalled = strings.Replace(
			queryMarshalled,
			"subscriptionName",
			w.subscriptionName+w.subscriptionParams,
			1,
		)
		querystr, err := json.Marshal(queryData{
			Action: "execute",
			Query:  queryMarshalled,
		})
		if err != nil {
			log.Error().Err(err).Msg("marshalling subscription query")
			return
		}
		err = w.sendMessage(message{
			Command:    "message",
			Identifier: string(ident),
			Data:       string(querystr),
		})
		if err != nil {
			log.Error().Err(err).Msg("sending subscription query")
			return
		}
	}()
	return nil
}