351 lines
6.5 KiB
Go
351 lines
6.5 KiB
Go
package subscriptions
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/llehouerou/go-graphql-client"
|
|
"github.com/pkg/errors"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// WSURL is the websocket endpoint dialed by every subscription client.
// The frames exchanged over it (welcome/ping messages, subscribe/message
// commands) appear to follow the Rails ActionCable protocol.
const WSURL = "wss://ws.sorare.com/cable"
|
|
|
|
// wsClient is one websocket connection carrying a single GraphQL
// subscription. Each payload received for the subscription field is decoded
// into T and delivered on Data.
type wsClient[T any] struct {
	// cancel cancels the context handed to readMessages/writeMessages.
	cancel context.CancelFunc
	// c is the underlying websocket connection, set by connect.
	c *websocket.Conn

	// Done is closed when readMessages exits.
	Done chan struct{}

	// rawMessageChan carries text frames from readMessages to processMessages.
	rawMessageChan chan []byte
	// writeChan carries outgoing frames to writeMessages, the only place in
	// this file that writes data frames to c.
	writeChan chan []byte

	// lastPing is the timestamp of the latest "ping" frame; processPing
	// writes it while holding pingMutex.
	lastPing time.Time

	pingMutex *sync.Mutex

	// Data receives every decoded subscription payload; processMessages
	// closes it when it exits.
	Data chan T
	// subscriptionName is the GraphQL subscription field whose value is
	// decoded into T.
	subscriptionName string
	// subscriptionParams is appended verbatim after subscriptionName in the
	// generated query (e.g. an argument list).
	subscriptionParams string

	// channelId is sent in the channel identifier of every command.
	channelId string

	// debug prints every raw incoming and outgoing frame to stdout.
	debug bool
}
|
|
|
|
// newWsClient dials WSURL, starts the reader, writer and message-processing
// goroutines, and subscribes to subscriptionName.
//
// subscriptionParams is appended verbatim after subscriptionName in the
// generated subscription query (for example an argument list). When debug is
// true every raw frame is printed to stdout.
//
// Cancelling ctx, or calling Stop, tears the client down. The connection is
// closed and the reader and writer goroutines exit. Data is closed once
// processMessages has drained the remaining frames.
func newWsClient[T any](
	ctx context.Context,
	subscriptionName string,
	subscriptionParams string,
	debug bool,
) (*wsClient[T], error) {
	localctx, cancel := context.WithCancel(ctx)
	channelId := uuid.New().String()[0:7]
	w := &wsClient[T]{
		cancel:             cancel,
		rawMessageChan:     make(chan []byte),
		writeChan:          make(chan []byte),
		lastPing:           time.Now(),
		pingMutex:          &sync.Mutex{},
		Data:               make(chan T),
		subscriptionName:   subscriptionName,
		subscriptionParams: subscriptionParams,
		Done:               make(chan struct{}),
		debug:              debug,
		channelId:          channelId,
	}

	if err := w.connect(); err != nil {
		// Nothing else will ever cancel the derived context; release it now.
		cancel()
		return nil, errors.Wrap(err, "connecting websocket")
	}

	// readMessages only checks ctx between frames and otherwise sits in
	// ReadMessage. Close the connection as soon as the context ends (e.g. the
	// parent ctx is cancelled without Stop) so the reader is unblocked and
	// the goroutines do not leak.
	go func() {
		<-localctx.Done()
		// A second Close (e.g. after Stop) only reports an already-closed
		// connection, which is not actionable.
		_ = w.c.Close()
	}()

	go w.processMessages()
	go w.readMessages(localctx)
	go w.writeMessages(localctx)

	if err := w.subscribe(); err != nil {
		// Tear down the goroutines and connection started above.
		w.Stop()
		return nil, errors.Wrap(err, "subscribing")
	}
	return w, nil
}
|
|
|
|
// connect dials WSURL with a 30 second handshake timeout, requesting
// per-message compression, and stores the resulting connection on w.
func (w *wsClient[T]) connect() error {
	conn, _, err := (&websocket.Dialer{
		EnableCompression: true,
		HandshakeTimeout:  30 * time.Second,
	}).Dial(WSURL, nil)
	if err != nil {
		return errors.Wrap(err, "dialing websocket")
	}

	w.c = conn
	return nil
}
|
|
|
|
// WsMessage is the envelope of every frame received from the server.
// processMessages dispatches on Type: "welcome" and "ping" are protocol
// frames, and an empty Type marks a subscription payload. Message holds the
// still-encoded payload. The fields carry no JSON tags, so encoding/json
// matches the keys case-insensitively (e.g. "type", "message").
type WsMessage struct {
	Type string

	Message json.RawMessage
}
|
|
|
|
// WsError is one entry of the "errors" array of a GraphQL result frame (see
// processMessage). Its shape mirrors the GraphQL response error format
// (message, locations, path, extensions). The Extensions fields appear to be
// server-specific.
type WsError struct {
	Message string

	// Locations lists the line/column positions related to the error.
	Locations []struct {
		Line int

		Column int
	}

	// NOTE(review): the GraphQL spec allows integer list indices inside a
	// path. A numeric element would make decoding into []string (and thus the
	// whole frame) fail. Confirm whether the server can emit one.
	Path []string

	Extensions struct {
		Code string

		TypeName string `json:"typeName"`

		ArgumentName string `json:"argumentName"`
	}
}
|
|
|
|
// processMessage decodes a subscription payload frame of the form
//
//	{"result": {"data": {...}, "errors": [...]}, "more": bool}
//
// and delivers the value of the subscriptionName field, decoded into T, on
// w.Data. The send blocks until a receiver takes the value.
//
// Frames carrying GraphQL errors yield an error that lists the error
// messages. Frames with no (or a null) value for subscriptionName deliver
// nothing. Whenever "more" is false or absent, the client is stopped after
// any delivery, because the server will send no further payloads. This also
// covers a final frame that carries no data at all.
func (w *wsClient[T]) processMessage(message json.RawMessage) error {
	var m struct {
		More   bool
		Result struct {
			Data   map[string]json.RawMessage
			Errors []WsError
		}
	}
	if err := json.Unmarshal(message, &m); err != nil {
		return errors.Wrap(err, "unmarshalling message")
	}

	// Honour the end-of-stream signal on every return path below, not only
	// after a successful delivery.
	if !m.More {
		defer w.Stop()
	}

	if len(m.Result.Errors) > 0 {
		msgs := make([]string, 0, len(m.Result.Errors))
		for _, e := range m.Result.Errors {
			msgs = append(msgs, e.Message)
		}
		return errors.Errorf("graphql errors: %s", strings.Join(msgs, "; "))
	}

	// A lookup on a nil/empty map simply reports !ok.
	data, ok := m.Result.Data[w.subscriptionName]
	if !ok || string(data) == "null" {
		return nil
	}

	var unmarshalledData T
	if err := graphql.UnmarshalGraphQL(data, &unmarshalledData); err != nil {
		return errors.Wrap(err, "unmarshalling graphql data")
	}

	w.Data <- unmarshalledData
	return nil
}
|
|
|
|
func (w *wsClient[T]) Stop() {
|
|
w.cancel()
|
|
_ = w.c.Close()
|
|
}
|
|
|
|
// processMessages decodes every raw frame forwarded by readMessages and
// dispatches it on its Type:
//   - "welcome" frames are ignored;
//   - "ping" frames update lastPing via processPing;
//   - "reject_subscription" frames are logged;
//   - untyped frames carry subscription payloads and go to processMessage.
//
// Any other type is ignored. Per-frame failures are logged and the loop moves
// on to the next frame. When rawMessageChan is closed the loop ends and
// w.Data is closed.
func (w *wsClient[T]) processMessages() {
	defer close(w.Data)

	for message := range w.rawMessageChan {
		if w.debug {
			fmt.Println("<-- " + string(message))
		}

		var m WsMessage
		if err := json.Unmarshal(message, &m); err != nil {
			log.Error().Err(err).Msg("unmarshalling message")
			continue
		}

		switch m.Type {
		case "welcome":
			// Nothing to do for the server's welcome frame.
		case "ping":
			if err := w.processPing(m.Message); err != nil {
				log.Error().Err(err).Msg("processing ping message")
			}
		case "reject_subscription":
			// Without this the subscription would silently never deliver data.
			log.Error().Str("channelId", w.channelId).Msg("subscription rejected by server")
		case "":
			if err := w.processMessage(m.Message); err != nil {
				log.Error().Err(err).Msg("processing subscription message")
			}
		}
	}
}
|
|
|
|
// PingMessage is the timestamp carried by a "ping" frame. Its UnmarshalJSON
// decodes it from a JSON integer of Unix seconds.
type PingMessage time.Time
|
|
|
|
// UnmarshalJSON decodes a JSON integer holding Unix seconds into p.
func (p *PingMessage) UnmarshalJSON(bytes []byte) error {
	var seconds int64
	if err := json.Unmarshal(bytes, &seconds); err != nil {
		return errors.Wrap(err, "unmarshalling ping message")
	}

	ts := time.Unix(seconds, 0)
	*p = PingMessage(ts)
	return nil
}
|
|
|
|
// Time returns the ping timestamp as a plain time.Time.
func (p *PingMessage) Time() time.Time {
	ts := time.Time(*p)
	return ts
}
|
|
|
|
// processPing decodes the payload of a "ping" frame and records its
// timestamp in w.lastPing while holding pingMutex.
func (w *wsClient[T]) processPing(message json.RawMessage) error {
	var ping PingMessage
	if err := json.Unmarshal(message, &ping); err != nil {
		return errors.Wrap(err, "unmarshalling ping message")
	}

	ts := ping.Time()

	w.pingMutex.Lock()
	defer w.pingMutex.Unlock()
	w.lastPing = ts
	return nil
}
|
|
|
|
// writeMessages sends every frame received on writeChan to the server as a
// text message, until ctx is done. It is the only place in this file that
// writes data frames, which keeps writes on the connection serialized.
//
// A failed write is logged and the loop continues with the next frame.
func (w *wsClient[T]) writeMessages(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case d := <-w.writeChan:
			if w.debug {
				fmt.Println("--> " + string(d))
			}
			if err := w.c.WriteMessage(websocket.TextMessage, d); err != nil {
				// Previously dropped silently, which hid a dead connection.
				log.Error().Err(err).Msg("writing websocket message")
			}
		}
	}
}
|
|
|
|
// readMessages forwards every text frame read from the connection to
// rawMessageChan. It stops when ctx is done or a read fails (for example
// because Stop closed the connection). Binary and other frame types are
// dropped.
//
// The close, pong and ping handlers are replaced with no-ops. As a result
// the client neither answers a server close frame nor replies to
// websocket-level pings automatically.
//
// On exit it closes w.Done and then rawMessageChan, which ends
// processMessages.
func (w *wsClient[T]) readMessages(ctx context.Context) {
	defer close(w.rawMessageChan)
	defer close(w.Done)

	w.c.SetCloseHandler(func(code int, text string) error { return nil })
	w.c.SetPongHandler(func(pong string) error { return nil })
	// NOTE(review): gorilla's default ping handler replies with a pong, and
	// this no-op suppresses that. Presumably the server relies on the
	// application-level "ping" frames instead. Confirm.
	w.c.SetPingHandler(func(ping string) error { return nil })

	for {
		if ctx.Err() != nil {
			return
		}

		t, msg, err := w.c.ReadMessage()
		if err != nil {
			return
		}
		if t != websocket.TextMessage {
			continue
		}

		// Don't block forever on a consumer that is no longer receiving once
		// the client is being torn down.
		select {
		case w.rawMessageChan <- msg:
		case <-ctx.Done():
			return
		}
	}
}
|
|
|
|
// identifier names the channel a command targets. subscribe marshals it to
// JSON and embeds the result as a string in message.Identifier.
type identifier struct {
	Channel string `json:"channel"`

	ChannelId string `json:"channelId"`
}
|
|
|
|
// message is an outgoing command frame. Identifier and Data hold JSON
// documents encoded as strings, not nested objects. Data is omitted when
// empty (e.g. for the "subscribe" command).
type message struct {
	Command string `json:"command"`

	Identifier string `json:"identifier"`

	Data string `json:"data,omitempty"`
}
|
|
|
|
// queryData is the Data payload of the "message" command sent by subscribe:
// an action name plus the GraphQL document to run.
type queryData struct {
	Action string `json:"action"`

	Query string `json:"query"`

	// NOTE(review): GraphQL variables are normally a JSON object, not an
	// array. subscribe never sets this field, so omitempty keeps it out of
	// the payload. Confirm the expected shape before populating it.
	Variables []string `json:"variables,omitempty"`

	OperationName string `json:"operationName,omitempty"`
}
|
|
|
|
// sendMessage marshals message to JSON and hands the frame to writeMessages.
//
// It blocks until the writer accepts the frame. If w.Done is closed first
// (the reader has exited because the connection closed or the client is
// stopping), it returns an error. Without this, the caller would block
// forever on an unbuffered channel whose writer may already be gone.
func (w *wsClient[T]) sendMessage(message any) error {
	data, err := json.Marshal(message)
	if err != nil {
		return errors.Wrap(err, "marshalling message")
	}

	select {
	case w.writeChan <- data:
		return nil
	case <-w.Done:
		return errors.New("websocket client stopped")
	}
}
|
|
|
|
// subscribe starts the subscription from a background goroutine by sending
// two commands:
//
//  1. a "subscribe" command for GraphqlChannel with this client's channelId;
//  2. after a fixed delay, a "message" command that executes the GraphQL
//     subscription built from T. In that query the placeholder field name is
//     replaced by subscriptionName+subscriptionParams.
//
// It returns immediately and always returns nil. Failures in the goroutine
// are logged. If the client shuts down (w.Done is closed) during the delay,
// the goroutine gives up.
func (w *wsClient[T]) subscribe() error {
	// NOTE(review): the fixed delay is presumably meant to give the server
	// time to confirm the channel subscription before the query runs. Waiting
	// for a confirmation frame would be more robust.
	const executeDelay = 5 * time.Second

	go func() {
		// Named identifierJSON so it does not shadow the identifier type.
		identifierJSON, err := json.Marshal(identifier{
			Channel:   "GraphqlChannel",
			ChannelId: w.channelId,
		})
		if err != nil {
			log.Error().Err(err).Msg("marshalling channel identifier")
			return
		}

		if err := w.sendMessage(message{
			Command:    "subscribe",
			Identifier: string(identifierJSON),
		}); err != nil {
			log.Error().Err(err).Msg("sending subscribe command")
			return
		}

		timer := time.NewTimer(executeDelay)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-w.Done:
			return
		}

		var query struct {
			SubscriptionName T
		}
		queryMarshalled, err := graphql.ConstructSubscription(query, nil)
		if err != nil {
			log.Error().Err(err).Msg("constructing subscription query")
			return
		}
		// The query builder presumably renders the struct field above as
		// `subscriptionName`. Swap that placeholder for the real field name and
		// its arguments.
		queryMarshalled = strings.Replace(
			queryMarshalled,
			"subscriptionName",
			w.subscriptionName+w.subscriptionParams,
			1,
		)

		querystr, err := json.Marshal(queryData{
			Action: "execute",
			Query:  queryMarshalled,
		})
		if err != nil {
			log.Error().Err(err).Msg("marshalling subscription query")
			return
		}

		if err := w.sendMessage(message{
			Command:    "message",
			Identifier: string(identifierJSON),
			Data:       string(querystr),
		}); err != nil {
			log.Error().Err(err).Msg("sending subscription query")
		}
	}()

	return nil
}
|