Files
qgo-server/vendor/github.com/quic-go/webtransport-go/server.go
Smile Rex 6ace91a21a
All checks were successful
Create and publish a Docker image 🚀 / build-and-push-image (push) Successful in 1m49s
add vendor data
2026-03-10 01:11:41 +03:00

446 lines
10 KiB
Go

package webtransport
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"slices"
"sync"
"time"
"unicode/utf8"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/quicvarint"
"github.com/dunglas/httpsfv"
)
// HTTP header names used for WebTransport application protocol negotiation.
const (
// wtAvailableProtocolsHeader is the request header from which Upgrade reads
// the protocols offered by the client.
wtAvailableProtocolsHeader = "WT-Available-Protocols"
// wtProtocolHeader is the response header in which Upgrade reports the
// protocol it selected, if any.
wtProtocolHeader = "WT-Protocol"
)
// Varint values that ServeQUICConn expects at the very start of an incoming
// stream in order to treat it as a WebTransport stream rather than handing it
// to the HTTP/3 layer.
const (
// webTransportFrameType is checked on bidirectional streams.
webTransportFrameType = 0x41
// webTransportUniStreamType is checked on unidirectional streams.
webTransportUniStreamType = 0x54
)
// quicConnKeyType is the private type of the context key below; using a
// dedicated type avoids collisions with context keys from other packages.
type quicConnKeyType struct{}
// quicConnKey is the context key under which ConfigureHTTP3Server stores the
// *quic.Conn of a request, and from which Upgrade retrieves it.
var quicConnKey = quicConnKeyType{}
// ConfigureHTTP3Server prepares an http3.Server for WebTransport:
//   - it advertises the WebTransport setting (value 1) in AdditionalSettings,
//   - it turns on datagram support,
//   - it wraps ConnContext so that every request context carries the
//     underlying *quic.Conn under quicConnKey. A ConnContext callback that was
//     already installed keeps running and is invoked first.
func ConfigureHTTP3Server(s *http3.Server) {
	s.EnableDatagrams = true

	if s.AdditionalSettings == nil {
		s.AdditionalSettings = make(map[uint64]uint64, 1)
	}
	s.AdditionalSettings[settingsEnableWebtransport] = 1

	// Capture the user's callback before replacing it, so it can be chained.
	userConnContext := s.ConnContext
	s.ConnContext = func(ctx context.Context, conn *quic.Conn) context.Context {
		if userConnContext != nil {
			ctx = userConnContext(ctx, conn)
		}
		return context.WithValue(ctx, quicConnKey, conn)
	}
}
// Server is a WebTransport server layered on top of an http3.Server.
// Its unexported state is set up lazily on first use (see initialize).
type Server struct {
// H3 is the underlying HTTP/3 server. Serve and ListenAndServe read its
// Addr, TLSConfig and QUICConfig; Close closes it.
H3 *http3.Server
// ApplicationProtocols is a list of application protocols that can be negotiated,
// see section 3.3 of https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-14 for details.
ApplicationProtocols []string
// ReorderingTimeout is the maximum time an incoming WebTransport stream that cannot be associated
// with a session is buffered. It is also the maximum time a WebTransport connection request is
// blocked waiting for the client's SETTINGS to be received.
// Buffering is needed if the CONNECT request (that creates a new session) is reordered, and arrives
// after the first WebTransport stream(s) for that session.
// Defaults to 5 seconds.
ReorderingTimeout time.Duration
// CheckOrigin is used to validate the request origin, thereby preventing cross-site request forgery.
// CheckOrigin returns true if the request Origin header is acceptable.
// If unset, a safe default is used: If the Origin header is set, it is checked that it
// matches the request's Host header.
CheckOrigin func(r *http.Request) bool
ctx context.Context // cancelled when Close is called
ctxCancel context.CancelFunc
// refCount tracks the per-connection goroutines started by Serve; Close waits on it.
refCount sync.WaitGroup
// initOnce guards the one-time call to init; initErr holds its result.
initOnce sync.Once
initErr error
// connsMx protects conns.
connsMx sync.Mutex
// conns maps every QUIC connection being served to its session manager.
// Close sets it to nil.
conns map[*quic.Conn]*sessionManager
}
// initialize runs init exactly once (guarded by initOnce) and returns the
// error recorded by that single run on every call.
func (s *Server) initialize() error {
	runInit := func() {
		s.initErr = s.init()
	}
	s.initOnce.Do(runInit)
	return s.initErr
}
// timeout returns the configured ReorderingTimeout, or 5 seconds when the
// field was left at its zero value.
func (s *Server) timeout() time.Duration {
	if d := s.ReorderingTimeout; d != 0 {
		return d
	}
	return 5 * time.Second
}
// init sets up the server's internal state: the default origin check (when
// the user did not provide one), the connection table, and the cancellable
// server context. It currently always returns nil.
func (s *Server) init() error {
	if s.CheckOrigin == nil {
		s.CheckOrigin = checkSameOrigin
	}
	s.conns = make(map[*quic.Conn]*sessionManager)
	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
	return nil
}
// Serve listens for QUIC connections on conn (via quic.ListenEarly) and
// serves each accepted connection in its own goroutine using ServeQUICConn.
//
// The QUIC configuration is a clone of s.H3.QUICConfig (or an empty config if
// none is set) with datagrams and stream resets with partial delivery forced
// on. Serve returns when the listener cannot be created or when Accept fails;
// Accept is called with the server context, which Close cancels. Errors from
// individual connections are logged, not returned.
func (s *Server) Serve(conn net.PacketConn) error {
	if err := s.initialize(); err != nil {
		return err
	}

	// Never modify the user's config: work on a clone (or a fresh value).
	quicConf := &quic.Config{}
	if s.H3.QUICConfig != nil {
		quicConf = s.H3.QUICConfig.Clone()
	}
	quicConf.EnableDatagrams = true
	quicConf.EnableStreamResetPartialDelivery = true

	ln, listenErr := quic.ListenEarly(conn, s.H3.TLSConfig, quicConf)
	if listenErr != nil {
		return listenErr
	}
	defer ln.Close()

	for {
		accepted, acceptErr := ln.Accept(s.ctx)
		if acceptErr != nil {
			return acceptErr
		}
		// Tracked so that Close can wait for all connection goroutines.
		s.refCount.Add(1)
		go func(c *quic.Conn) {
			defer s.refCount.Done()
			if serveErr := s.ServeQUICConn(c); serveErr != nil {
				log.Printf("http3: error serving QUIC connection: %v", serveErr)
			}
		}(accepted)
	}
}
// ServeQUICConn serves a single QUIC connection.
//
// The connection must have been established with QUIC datagrams and stream
// resets with partial delivery enabled locally; otherwise an error is
// returned. If the server has already been closed, http.ErrServerClosed is
// returned.
//
// Every incoming stream is inspected: streams that start with the
// WebTransport frame type (bidirectional) or stream type (unidirectional) are
// handed to the connection's session manager together with the session ID
// that follows; all other streams are handed to the HTTP/3 layer.
//
// ServeQUICConn blocks until both accept loops have stopped (that is, until
// accepting a stream fails) and all stream goroutines they started have
// returned. It then returns nil.
func (s *Server) ServeQUICConn(conn *quic.Conn) error {
	connState := conn.ConnectionState()
	if !connState.SupportsDatagrams.Local {
		return errors.New("webtransport: QUIC DATAGRAM support required, enable it via QUICConfig.EnableDatagrams")
	}
	if !connState.SupportsStreamResetPartialDelivery.Local {
		return errors.New("webtransport: QUIC Stream Resets with Partial Delivery required, enable it via QUICConfig.EnableStreamResetPartialDelivery")
	}
	if err := s.initialize(); err != nil {
		return err
	}

	s.connsMx.Lock()
	if s.conns == nil {
		// Close sets conns to nil. Storing into a nil map would panic, so
		// refuse to serve connections once the server has been closed.
		s.connsMx.Unlock()
		return http.ErrServerClosed
	}
	sessMgr, ok := s.conns[conn]
	if !ok {
		sessMgr = newSessionManager(s.timeout())
		s.conns[conn] = sessMgr
	}
	s.connsMx.Unlock()

	// Clean up when the connection closes.
	context.AfterFunc(conn.Context(), func() {
		s.connsMx.Lock()
		delete(s.conns, conn) // deleting from a nil map (after Close) is a no-op
		s.connsMx.Unlock()
		sessMgr.Close()
	})

	http3Conn, err := s.H3.NewRawServerConn(conn)
	if err != nil {
		return err
	}

	// Close the connection when the server context is cancelled.
	go func() {
		select {
		case <-s.ctx.Done():
			conn.CloseWithError(0, "")
		case <-conn.Context().Done():
			// connection already closed
		}
	}()

	// wg counts the two accept loops plus one entry per stream goroutine.
	// A stream goroutine is always added from within a still-running accept
	// loop, so the counter cannot reach zero prematurely.
	var wg sync.WaitGroup
	wg.Add(2)

	// Bidirectional streams.
	go func() {
		defer wg.Done()
		for {
			str, err := conn.AcceptStream(s.ctx)
			if err != nil {
				return
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				typ, err := quicvarint.Peek(str)
				if err != nil {
					return
				}
				if typ != webTransportFrameType {
					http3Conn.HandleRequestStream(str)
					return
				}
				// read the frame type (already peeked)
				r := quicvarint.NewReader(str)
				if _, err := quicvarint.Read(r); err != nil {
					return
				}
				// read the session ID
				id, err := quicvarint.Read(r)
				if err != nil {
					str.CancelRead(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					str.CancelWrite(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					return
				}
				sessMgr.AddStream(str, sessionID(id))
			}()
		}
	}()

	// Unidirectional streams.
	go func() {
		defer wg.Done()
		for {
			str, err := conn.AcceptUniStream(s.ctx)
			if err != nil {
				return
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				typ, err := quicvarint.Peek(str)
				if err != nil {
					return
				}
				if typ != webTransportUniStreamType {
					http3Conn.HandleUnidirectionalStream(str)
					return
				}
				// read the stream type (already peeked) before passing to AddUniStream
				r := quicvarint.NewReader(str)
				if _, err := quicvarint.Read(r); err != nil {
					return
				}
				// read the session ID
				id, err := quicvarint.Read(r)
				if err != nil {
					str.CancelRead(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					return
				}
				sessMgr.AddUniStream(str, sessionID(id))
			}()
		}
	}()

	wg.Wait()
	return nil
}
// ListenAndServe opens a UDP socket on s.H3.Addr (":https" if that is empty)
// and serves WebTransport on it by calling Serve.
//
// It returns an error if the address cannot be resolved or bound, and
// otherwise whatever Serve returns.
func (s *Server) ListenAndServe() error {
	addr := s.H3.Addr
	if addr == "" {
		addr = ":https"
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	// The socket is created here, so it must be released here. Serve closes
	// only the QUIC listener, which is not expected to close a
	// caller-supplied net.PacketConn.
	defer conn.Close()
	return s.Serve(conn)
}
// ListenAndServeTLS loads the certificate/key pair from the given files,
// installs it as the only certificate of s.H3.TLSConfig (creating the config
// if it is nil) and then calls ListenAndServe.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
	keyPair, loadErr := tls.LoadX509KeyPair(certFile, keyFile)
	if loadErr != nil {
		return loadErr
	}

	tlsConf := s.H3.TLSConfig
	if tlsConf == nil {
		tlsConf = &tls.Config{}
		s.H3.TLSConfig = tlsConf
	}
	tlsConf.Certificates = []tls.Certificate{keyPair}

	return s.ListenAndServe()
}
// Close shuts the server down: it cancels the server context, closes every
// session manager, drops the connection table, closes the underlying HTTP/3
// server and finally waits for the connection goroutines started by Serve.
// The returned error is the one reported by s.H3.Close.
func (s *Server) Close() error {
	// Consume initOnce so that a server that was never started is not
	// initialized after it has been closed. In that (uncommon) case
	// ctxCancel is still nil, hence the check below.
	s.initOnce.Do(func() {})
	if cancel := s.ctxCancel; cancel != nil {
		cancel()
	}

	s.connsMx.Lock()
	// Ranging over a nil map is a no-op, so no explicit nil check is needed.
	for _, sessMgr := range s.conns {
		sessMgr.Close()
	}
	s.conns = nil
	s.connsMx.Unlock()

	closeErr := s.H3.Close()
	s.refCount.Wait()
	return closeErr
}
// Upgrade turns an incoming CONNECT request into a WebTransport session.
//
// It validates the request (method, protocol, origin), negotiates an
// application protocol from the WT-Available-Protocols request header, waits
// up to the reordering timeout for the client's SETTINGS, answers with
// status 200 and registers the new session with the session manager of the
// request's QUIC connection.
//
// The request must have been served by an http3.Server prepared with
// ConfigureHTTP3Server, and w must be the unwrapped HTTP/3 ResponseWriter:
// it has to implement http3.Settingser, http.Flusher and http3.HTTPStreamer.
// If it does not, an error is returned before anything is written.
func (s *Server) Upgrade(w http.ResponseWriter, r *http.Request) (*Session, error) {
	if err := s.initialize(); err != nil {
		return nil, err
	}
	if r.Method != http.MethodConnect {
		return nil, fmt.Errorf("expected CONNECT request, got %s", r.Method)
	}
	if r.Proto != protocolHeader {
		return nil, fmt.Errorf("unexpected protocol: %s", r.Proto)
	}
	if !s.CheckOrigin(r) {
		return nil, errors.New("webtransport: request origin not allowed")
	}
	// The comma-ok form covers both a missing value and a value of an
	// unexpected type without panicking.
	conn, ok := r.Context().Value(quicConnKey).(*quic.Conn)
	if !ok {
		return nil, errors.New("webtransport: missing QUIC connection")
	}

	// Verify all required capabilities up front, so that a wrapped
	// ResponseWriter results in an error instead of a panic after the
	// response status has already been sent.
	settingser, ok := w.(http3.Settingser)
	if !ok {
		return nil, errors.New("webtransport: http.ResponseWriter does not implement http3.Settingser")
	}
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("webtransport: http.ResponseWriter does not implement http.Flusher")
	}
	streamer, ok := w.(http3.HTTPStreamer)
	if !ok {
		return nil, errors.New("webtransport: http.ResponseWriter does not implement http3.HTTPStreamer")
	}

	selectedProtocol := s.selectProtocol(r.Header[http.CanonicalHeaderKey(wtAvailableProtocolsHeader)])

	// Wait for SETTINGS
	timer := time.NewTimer(s.timeout())
	defer timer.Stop()
	select {
	case <-settingser.ReceivedSettings():
	case <-timer.C:
		return nil, errors.New("webtransport: didn't receive the client's SETTINGS on time")
	}
	settings := settingser.Settings()
	if !settings.EnableDatagrams {
		return nil, errors.New("webtransport: missing datagram support")
	}

	if selectedProtocol != "" {
		v, err := httpsfv.Marshal(httpsfv.NewItem(selectedProtocol))
		if err != nil {
			return nil, fmt.Errorf("failed to marshal selected protocol: %w", err)
		}
		w.Header().Add(wtProtocolHeader, v)
	}
	w.WriteHeader(http.StatusOK)
	flusher.Flush()

	str := streamer.HTTPStream()
	// The session ID is the stream ID of the CONNECT request stream.
	sessID := sessionID(str.StreamID())

	// The session manager should already exist because ServeQUICConn creates it
	// before any HTTP requests can be processed on this connection.
	s.connsMx.Lock()
	defer s.connsMx.Unlock()
	sessMgr, ok := s.conns[conn]
	if !ok {
		return nil, errors.New("webtransport: connection session manager not found")
	}
	// The session keeps the request context's values but must outlive the
	// request, hence WithoutCancel.
	sess := newSession(context.WithoutCancel(r.Context()), sessID, conn, str, selectedProtocol)
	sessMgr.AddSession(sessID, sess)
	return sess, nil
}
// selectProtocol picks the application protocol for a session.
//
// theirs holds the raw values of the client's WT-Available-Protocols header,
// which are parsed as a structured-field list. The result is the first
// offered protocol (in the client's order) that also appears in
// s.ApplicationProtocols. An empty string is returned when the header cannot
// be parsed, when any list member is not a plain string item, or when there
// is no common protocol.
func (s *Server) selectProtocol(theirs []string) string {
	parsed, err := httpsfv.UnmarshalList(theirs)
	if err != nil {
		return ""
	}

	// Validate the complete list before selecting anything: a single
	// malformed member invalidates the whole offer.
	offered := make([]string, 0, len(parsed))
	for _, member := range parsed {
		item, isItem := member.(httpsfv.Item)
		if !isItem {
			return ""
		}
		name, isString := item.Value.(string)
		if !isString {
			return ""
		}
		offered = append(offered, name)
	}

	for _, candidate := range offered {
		if slices.Contains(s.ApplicationProtocols, candidate) {
			return candidate
		}
	}
	return ""
}
// checkSameOrigin is the default origin check.
// copied from https://github.com/gorilla/websocket
//
// A request without an Origin header is accepted. Otherwise the header must
// parse as a URL whose host equals the request's Host, ignoring ASCII case.
func checkSameOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return true
	}
	parsed, err := url.Parse(origin)
	// parsed is only dereferenced when parsing succeeded.
	return err == nil && equalASCIIFold(parsed.Host, r.Host)
}
// equalASCIIFold reports whether s and t are equal, ignoring the case of the
// ASCII letters A-Z. Non-ASCII runes must match exactly.
// Adapted from https://github.com/gorilla/websocket
//
// Bytes that are not valid UTF-8 all decode to utf8.RuneError, so comparing
// decoded runes alone would treat two different invalid bytes (or an invalid
// byte and a genuine U+FFFD) as equal. Whenever RuneError is involved, the
// raw encoded bytes are compared instead.
func equalASCIIFold(s, t string) bool {
	for s != "" && t != "" {
		sr, sn := utf8.DecodeRuneInString(s)
		tr, tn := utf8.DecodeRuneInString(t)
		sRaw, tRaw := s[:sn], t[:tn]
		s, t = s[sn:], t[tn:]

		if sr == utf8.RuneError || tr == utf8.RuneError {
			if sRaw != tRaw {
				return false
			}
			continue
		}
		if sr == tr {
			continue
		}
		if 'A' <= sr && sr <= 'Z' {
			sr = sr + 'a' - 'A'
		}
		if 'A' <= tr && tr <= 'Z' {
			tr = tr + 'a' - 'A'
		}
		if sr != tr {
			return false
		}
	}
	// Equal only if both strings were consumed completely.
	return s == t
}