Files
qgo-server/vendor/github.com/quic-go/webtransport-go/session.go
Smile Rex 6ace91a21a
All checks were successful
Create and publish a Docker image 🚀 / build-and-push-image (push) Successful in 1m49s
add vendor data
2026-03-10 01:11:41 +03:00

468 lines
12 KiB
Go

package webtransport
import (
"context"
"encoding/binary"
"io"
"math/rand/v2"
"net"
"sync"
"time"
"unicode/utf8"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/quicvarint"
)
// sessionID is the WebTransport Session ID
type sessionID uint64
const closeSessionCapsuleType http3.CapsuleType = 0x2843
const maxCloseCapsuleErrorMsgLen = 1024
type acceptQueue[T any] struct {
mx sync.Mutex
// The channel is used to notify consumers (via Chan) about new incoming items.
// Needs to be buffered to preserve the notification if an item is enqueued
// between a call to Next and to Chan.
c chan struct{}
// Contains all the streams waiting to be accepted.
// There's no explicit limit to the length of the queue, but it is implicitly
// limited by the stream flow control provided by QUIC.
queue []T
}
func newAcceptQueue[T any]() *acceptQueue[T] {
return &acceptQueue[T]{c: make(chan struct{}, 1)}
}
func (q *acceptQueue[T]) Add(str T) {
q.mx.Lock()
q.queue = append(q.queue, str)
q.mx.Unlock()
select {
case q.c <- struct{}{}:
default:
}
}
func (q *acceptQueue[T]) Next() T {
q.mx.Lock()
defer q.mx.Unlock()
if len(q.queue) == 0 {
return *new(T)
}
str := q.queue[0]
q.queue = q.queue[1:]
return str
}
func (q *acceptQueue[T]) Chan() <-chan struct{} { return q.c }
type http3Stream interface {
io.ReadWriteCloser
ReceiveDatagram(context.Context) ([]byte, error)
SendDatagram([]byte) error
CancelRead(quic.StreamErrorCode)
CancelWrite(quic.StreamErrorCode)
SetWriteDeadline(time.Time) error
}
var (
_ http3Stream = &http3.Stream{}
_ http3Stream = &http3.RequestStream{}
)
// SessionState contains the state of a WebTransport session
type SessionState struct {
// ConnectionState contains the QUIC connection state, including TLS handshake information
ConnectionState quic.ConnectionState
// ApplicationProtocol contains the application protocol negotiated for the session
ApplicationProtocol string
}
type Session struct {
sessionID sessionID
conn *quic.Conn
str http3Stream
applicationProtocol string
streamHdr []byte
uniStreamHdr []byte
ctx context.Context
closeMx sync.Mutex
closeErr error // not nil once the session is closed
// streamCtxs holds all the context.CancelFuncs of calls to Open{Uni}StreamSync calls currently active.
// When the session is closed, this allows us to cancel all these contexts and make those calls return.
streamCtxs map[int]context.CancelFunc
bidiAcceptQueue acceptQueue[*Stream]
uniAcceptQueue acceptQueue[*ReceiveStream]
streams streamsMap
}
func newSession(ctx context.Context, sessionID sessionID, conn *quic.Conn, str http3Stream, applicationProtocol string) *Session {
ctx, ctxCancel := context.WithCancel(ctx)
c := &Session{
sessionID: sessionID,
conn: conn,
str: str,
applicationProtocol: applicationProtocol,
ctx: ctx,
streamCtxs: make(map[int]context.CancelFunc),
bidiAcceptQueue: *newAcceptQueue[*Stream](),
uniAcceptQueue: *newAcceptQueue[*ReceiveStream](),
streams: *newStreamsMap(),
}
// precompute the headers for unidirectional streams
c.uniStreamHdr = make([]byte, 0, 2+quicvarint.Len(uint64(c.sessionID)))
c.uniStreamHdr = quicvarint.Append(c.uniStreamHdr, webTransportUniStreamType)
c.uniStreamHdr = quicvarint.Append(c.uniStreamHdr, uint64(c.sessionID))
// precompute the headers for bidirectional streams
c.streamHdr = make([]byte, 0, 2+quicvarint.Len(uint64(c.sessionID)))
c.streamHdr = quicvarint.Append(c.streamHdr, webTransportFrameType)
c.streamHdr = quicvarint.Append(c.streamHdr, uint64(c.sessionID))
go func() {
defer ctxCancel()
c.handleConn()
}()
return c
}
func (s *Session) handleConn() {
err := s.parseNextCapsule()
s.closeWithError(err)
}
// parseNextCapsule parses the next Capsule sent on the request stream.
// It returns a SessionError, if the capsule received is a WT_CLOSE_SESSION Capsule.
func (s *Session) parseNextCapsule() error {
for {
typ, r, err := http3.ParseCapsule(quicvarint.NewReader(s.str))
if err != nil {
return err
}
switch typ {
case closeSessionCapsuleType:
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
appErrCode := binary.BigEndian.Uint32(b[:])
// the length of the error message is limited to 1024 bytes
appErrMsg, err := io.ReadAll(io.LimitReader(r, maxCloseCapsuleErrorMsgLen))
if err != nil {
return err
}
return &SessionError{
Remote: true,
ErrorCode: SessionErrorCode(appErrCode),
Message: string(appErrMsg),
}
default:
// unknown capsule, skip it
if _, err := io.ReadAll(r); err != nil {
return err
}
}
}
}
func (s *Session) addStream(qstr *quic.Stream, addStreamHeader bool) *Stream {
var hdr []byte
if addStreamHeader {
hdr = s.streamHdr
}
str := newStream(qstr, hdr, func() { s.streams.RemoveStream(qstr.StreamID()) })
s.streams.AddStream(qstr.StreamID(), str.closeWithSession)
return str
}
func (s *Session) addReceiveStream(qstr *quic.ReceiveStream) *ReceiveStream {
str := newReceiveStream(qstr, func() { s.streams.RemoveStream(qstr.StreamID()) })
s.streams.AddStream(qstr.StreamID(), str.closeWithSession)
return str
}
func (s *Session) addSendStream(qstr *quic.SendStream) *SendStream {
str := newSendStream(qstr, s.uniStreamHdr, func() { s.streams.RemoveStream(qstr.StreamID()) })
s.streams.AddStream(qstr.StreamID(), str.closeWithSession)
return str
}
// addIncomingStream adds a bidirectional stream that the remote peer opened
func (s *Session) addIncomingStream(qstr *quic.Stream) {
s.closeMx.Lock()
closeErr := s.closeErr
if closeErr != nil {
s.closeMx.Unlock()
qstr.CancelRead(WTSessionGoneErrorCode)
qstr.CancelWrite(WTSessionGoneErrorCode)
return
}
str := s.addStream(qstr, false)
s.closeMx.Unlock()
s.bidiAcceptQueue.Add(str)
}
// addIncomingUniStream adds a unidirectional stream that the remote peer opened
func (s *Session) addIncomingUniStream(qstr *quic.ReceiveStream) {
s.closeMx.Lock()
closeErr := s.closeErr
if closeErr != nil {
s.closeMx.Unlock()
qstr.CancelRead(WTSessionGoneErrorCode)
return
}
str := s.addReceiveStream(qstr)
s.closeMx.Unlock()
s.uniAcceptQueue.Add(str)
}
// Context returns a context that is closed when the session is closed.
func (s *Session) Context() context.Context {
return s.ctx
}
func (s *Session) AcceptStream(ctx context.Context) (*Stream, error) {
s.closeMx.Lock()
closeErr := s.closeErr
s.closeMx.Unlock()
if closeErr != nil {
return nil, closeErr
}
for {
// If there's a stream in the accept queue, return it immediately.
if str := s.bidiAcceptQueue.Next(); str != nil {
return str, nil
}
// No stream in the accept queue. Wait until we accept one.
select {
case <-s.ctx.Done():
return nil, s.closeErr
case <-ctx.Done():
return nil, ctx.Err()
case <-s.bidiAcceptQueue.Chan():
}
}
}
func (s *Session) AcceptUniStream(ctx context.Context) (*ReceiveStream, error) {
s.closeMx.Lock()
closeErr := s.closeErr
s.closeMx.Unlock()
if closeErr != nil {
return nil, s.closeErr
}
for {
// If there's a stream in the accept queue, return it immediately.
if str := s.uniAcceptQueue.Next(); str != nil {
return str, nil
}
// No stream in the accept queue. Wait until we accept one.
select {
case <-s.ctx.Done():
return nil, s.closeErr
case <-ctx.Done():
return nil, ctx.Err()
case <-s.uniAcceptQueue.Chan():
}
}
}
func (s *Session) OpenStream() (*Stream, error) {
s.closeMx.Lock()
defer s.closeMx.Unlock()
if s.closeErr != nil {
return nil, s.closeErr
}
qstr, err := s.conn.OpenStream()
if err != nil {
return nil, err
}
return s.addStream(qstr, true), nil
}
func (s *Session) addStreamCtxCancel(cancel context.CancelFunc) (id int) {
rand:
id = rand.Int()
if _, ok := s.streamCtxs[id]; ok {
goto rand
}
s.streamCtxs[id] = cancel
return id
}
func (s *Session) OpenStreamSync(ctx context.Context) (*Stream, error) {
s.closeMx.Lock()
if s.closeErr != nil {
s.closeMx.Unlock()
return nil, s.closeErr
}
ctx, cancel := context.WithCancel(ctx)
id := s.addStreamCtxCancel(cancel)
s.closeMx.Unlock()
// open a new bidirectional stream without holding the mutex: this call might block
qstr, err := s.conn.OpenStreamSync(ctx)
s.closeMx.Lock()
defer s.closeMx.Unlock()
delete(s.streamCtxs, id)
// the session might have been closed concurrently with OpenStreamSync returning
if qstr != nil && s.closeErr != nil {
qstr.CancelRead(WTSessionGoneErrorCode)
qstr.CancelWrite(WTSessionGoneErrorCode)
return nil, s.closeErr
}
if err != nil {
if s.closeErr != nil {
return nil, s.closeErr
}
return nil, err
}
return s.addStream(qstr, true), nil
}
func (s *Session) OpenUniStream() (*SendStream, error) {
s.closeMx.Lock()
defer s.closeMx.Unlock()
if s.closeErr != nil {
return nil, s.closeErr
}
qstr, err := s.conn.OpenUniStream()
if err != nil {
return nil, err
}
return s.addSendStream(qstr), nil
}
func (s *Session) OpenUniStreamSync(ctx context.Context) (str *SendStream, err error) {
s.closeMx.Lock()
if s.closeErr != nil {
s.closeMx.Unlock()
return nil, s.closeErr
}
ctx, cancel := context.WithCancel(ctx)
id := s.addStreamCtxCancel(cancel)
s.closeMx.Unlock()
// open a new unidirectional stream without holding the mutex: this call might block
qstr, err := s.conn.OpenUniStreamSync(ctx)
s.closeMx.Lock()
defer s.closeMx.Unlock()
delete(s.streamCtxs, id)
// the session might have been closed concurrently with OpenStreamSync returning
if qstr != nil && s.closeErr != nil {
qstr.CancelWrite(WTSessionGoneErrorCode)
return nil, s.closeErr
}
if err != nil {
if s.closeErr != nil {
return nil, s.closeErr
}
return nil, err
}
return s.addSendStream(qstr), nil
}
func (s *Session) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}
func (s *Session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *Session) CloseWithError(code SessionErrorCode, msg string) error {
first, err := s.closeWithError(&SessionError{ErrorCode: code, Message: msg})
if err != nil || !first {
return err
}
// truncate the message if it's too long
if len(msg) > maxCloseCapsuleErrorMsgLen {
msg = truncateUTF8(msg, maxCloseCapsuleErrorMsgLen)
}
b := make([]byte, 4, 4+len(msg))
binary.BigEndian.PutUint32(b, uint32(code))
b = append(b, []byte(msg)...)
// Optimistically send the WT_CLOSE_SESSION Capsule:
// If we're flow-control limited, we don't want to wait for the receiver to issue new flow control credits.
// There's no idiomatic way to do a non-blocking write in Go, so we set a short deadline.
s.str.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
if err := http3.WriteCapsule(quicvarint.NewWriter(s.str), closeSessionCapsuleType, b); err != nil {
s.str.CancelWrite(WTSessionGoneErrorCode)
}
s.str.CancelRead(WTSessionGoneErrorCode)
err = s.str.Close()
<-s.ctx.Done()
return err
}
func (s *Session) SendDatagram(b []byte) error {
return s.str.SendDatagram(b)
}
func (s *Session) ReceiveDatagram(ctx context.Context) ([]byte, error) {
return s.str.ReceiveDatagram(ctx)
}
func (s *Session) closeWithError(closeErr error) (bool /* first call to close session */, error) {
s.closeMx.Lock()
defer s.closeMx.Unlock()
// Duplicate call, or the remote already closed this session.
if s.closeErr != nil {
return false, nil
}
s.closeErr = closeErr
for _, cancel := range s.streamCtxs {
cancel()
}
s.streams.CloseSession(closeErr)
return true, nil
}
// SessionState returns the current state of the session
func (s *Session) SessionState() SessionState {
return SessionState{
ConnectionState: s.conn.ConnectionState(),
ApplicationProtocol: s.applicationProtocol,
}
}
// truncateUTF8 cuts a string to max n bytes without breaking UTF-8 characters.
func truncateUTF8(s string, n int) string {
if len(s) <= n {
return s
}
for n > 0 && !utf8.RuneStart(s[n]) {
n--
}
return s[:n]
}