diff --git a/.github/workflows/go-vet.yml b/.github/workflows/go-vet.yml new file mode 100644 index 0000000..ed30d12 --- /dev/null +++ b/.github/workflows/go-vet.yml @@ -0,0 +1,43 @@ +name: golang-ci + +on: [push, pull_request] + +jobs: + + build: + name: CI + runs-on: ubuntu-latest + steps: + - name: Set up Go 1.x + uses: actions/setup-go@master + with: + go-version: "1.20" + + - name: Check out code into the Go module directory + uses: actions/checkout@master + + - name: Get dependencies + run: go mod tidy + + - name: Build + run: go build -v ./... + + - name: Test + run: go test $(go list ./...) + + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - name: Set up Go 1.x + uses: actions/setup-go@master + with: + go-version: "1.20" + + - name: Check out code into the Go module directory + uses: actions/checkout@master + + - name: golangci-lint + uses: golangci/golangci-lint-action@master + with: + version: latest diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..290397a --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,68 @@ +linters-settings: + errcheck: + ignore: fmt:.*,io/ioutil:^Read.* + ignoretests: true + + goimports: + local-prefixes: github.com/fumiama/terasu + +linters: + # please, do not use `enable-all`: it's deprecated and will be removed soon. + # inverted configuration with `enable-all` and `disable` is not scalable during updates of golangci-lint + disable-all: true + fast: false + enable: + - bodyclose + #- depguard + - dogsled + - errcheck + #- exportloopref + - exhaustive + #- funlen + #- goconst + - gocritic + #- gocyclo + - gofmt + - goimports + - goprintffuncname + #- gosec + - gosimple + - govet + - ineffassign + #- misspell + - nolintlint + - rowserrcheck + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - unused + - whitespace + - prealloc + - predeclared + - asciicheck + - revive + - forbidigo + - makezero + + +run: + # default concurrency is a available CPU number. + # concurrency: 4 # explicitly omit this value to fully utilize available resources. + deadline: 5m + issues-exit-code: 1 + tests: false + go: '1.20' + +# output configuration options +output: + formats: + - format: "colored-line-number" + print-issued-lines: true + print-linter-name: true + uniq-by-line: true + +issues: + # Fix found issues (if it's supported by the linter) + fix: true diff --git a/README.md b/README.md index 6f33cfc..5f25408 100644 --- a/README.md +++ b/README.md @@ -15,5 +15,7 @@ ## Usage ```go -terasu.Use(tlsConn).Handshake() +tls.Client(terasu.NewConn(conn), &tls.Config{ + ServerName: host, +}).Handshake() ``` diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..39668c1 --- /dev/null +++ b/builder.go @@ -0,0 +1,28 @@ +package terasu + +import ( + "io" + "net" +) + +type builder net.Buffers + +func newbuilder() builder { + return builder{} +} + +// move is write without copy +func (bd *builder) move(b []byte) { + *bd = append(*bd, b) +} + +func (bd *builder) send(conn *net.TCPConn, rs ...io.Reader) (int64, error) { + if len(rs) == 0 { + return conn.ReadFrom((*net.Buffers)(bd)) + } + return conn.ReadFrom(io.MultiReader(append([]io.Reader{(*net.Buffers)(bd)}, rs...)...)) +} + +func (bd *builder) reset() { + *bd = (*bd)[:0] +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..8ce9265 --- /dev/null +++ b/conn.go @@ -0,0 +1,184 @@ +package terasu + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "sync" + "time" +) + +// DefaultFirstFragmentLen ... +var DefaultFirstFragmentLen = 4 + +// Conn remote: real server; local: relay +type Conn struct { + mu sync.Mutex + conn *net.TCPConn + isold bool +} + +// NewConn wraps *net.TCPConn (net.Conn must be *net.TCPConn) +func NewConn(conn net.Conn) *Conn { + return &Conn{conn: conn.(*net.TCPConn)} +} + +// Write is send +func (conn *Conn) Write(b []byte) (int, error) { + if conn.isold || DefaultFirstFragmentLen == 0 { + return conn.conn.Write(b) + } + conn.mu.Lock() + defer conn.mu.Unlock() + n, err := conn.ReadFrom(bytes.NewReader(b)) + return int(n), err +} + +// ReadFrom when client want to send to server, detect and split. +func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) { + // ContentType [0:1] + // Version [1:3] + // Length [3:5] + // Payload[Length] -> + // HandshakeType [5:6] + // HandshakeBodyLength [6:9] + var header [1 + 2 + 2 + 1 + 3]byte + x := DefaultFirstFragmentLen + // preserved as tmp vars + plen := uint16(0) + var b []byte + bd := newbuilder() + + defer func() { + conn.isold = true + }() + + // ContentType [0:1] Version [1:2] 0x03 + _, err = io.ReadFull(r, header[:2]) + if err != nil { + return + } + // recordTypeHandshake = 0x16 + // Version [1:2] = 0x03 + if binary.BigEndian.Uint16(header[:2]) != 0x1603 { + bd.move(header[:2]) + goto PIPE + } + // Version [2:3] (0x01 1.0) (0x02 1.1) (0x03 1.2) (0x04 1.3) + _, err = io.ReadFull(r, header[2:3]) + if err != nil { + return + } + // skip unsupported version + if header[2] < 1 || header[2] > 4 { + bd.move(header[:3]) + goto PIPE + } + // Length [3:5] HandshakeType [5:6] + _, err = io.ReadFull(r, header[3:6]) + if err != nil { + return + } + // skip unsupported handshake type + if header[5] != 0x01 { // client hello + bd.move(header[:6]) + goto PIPE + } + // HandshakeBodyLength [6:9] + _, err = io.ReadFull(r, header[6:9]) + if err != nil { + return + } + plen = binary.BigEndian.Uint16(header[3:5]) + if binary.BigEndian.Uint32(header[5:9])&0x00ffffff+ // body + // handshake type, body length + 1+3 != + // payload length + uint32(plen) { + bd.move(header[:9]) + goto PIPE + } + + // split + if x <= 4 { // first is in header range + // first + binary.BigEndian.PutUint16(header[3:5], uint16(x)) + bd.move(header[:5+x]) + n, err = bd.send(conn.conn) + bd.reset() + if err != nil { + return + } + copy(header[5:5+x], header[9-x:9]) + // second + binary.BigEndian.PutUint16(header[3:5], plen-uint16(x)) + bd.move(header[:9-x]) + goto PIPE + } + // first is out of header range + // first + binary.BigEndian.PutUint16(header[3:5], uint16(x)) + bd.move(header[:9]) + b = make([]byte, x-4) + _, err = io.ReadFull(r, b) + if err != nil { + return + } + bd.move(b) + n, err = bd.send(conn.conn) + bd.reset() + if err != nil { + return + } + // second + binary.BigEndian.PutUint16(header[3:5], plen-uint16(x)) + bd.move(header[:5]) +PIPE: + if err != nil { + return + } + cnt, err := bd.send(conn.conn, r) + n += cnt + return +} + +// Read is recv +func (conn *Conn) Read(b []byte) (int, error) { + return conn.conn.Read(b) +} + +// WriteTo remote response and releay it to local client without any change. +func (conn *Conn) WriteTo(w io.Writer) (int64, error) { + return conn.conn.WriteTo(w) +} + +// Close closes the connection. +func (conn *Conn) Close() error { + return conn.conn.Close() +} + +// LocalAddr returns the local network address. +func (conn *Conn) LocalAddr() net.Addr { + return conn.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (conn *Conn) RemoteAddr() net.Addr { + return conn.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +func (conn *Conn) SetDeadline(t time.Time) error { + return conn.conn.SetDeadline(t) +} + +// SetReadDeadline sets the deadline for future Read calls. +func (conn *Conn) SetReadDeadline(t time.Time) error { + return conn.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (conn *Conn) SetWriteDeadline(t time.Time) error { + return conn.conn.SetWriteDeadline(t) +} diff --git a/terasu_test.go b/conn_test.go similarity index 79% rename from terasu_test.go rename to conn_test.go index 4c072db..32f7453 100644 --- a/terasu_test.go +++ b/conn_test.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/http" + "net/netip" "testing" ) @@ -12,17 +13,19 @@ func TestHTTPDialTLS13(t *testing.T) { cli := http.Client{ Transport: &http.Transport{ DialTLS: func(network, addr string) (net.Conn, error) { - conn, err := net.Dial("tcp", "18.65.159.2:443") + conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( + netip.MustParseAddrPort("52.222.136.117:443"), + )) if err != nil { return nil, err } t.Log("net.Dial succeeded") - tlsConn := tls.Client(conn, &tls.Config{ + tlsConn := tls.Client(NewConn(conn), &tls.Config{ ServerName: "huggingface.co", MinVersion: tls.VersionTLS12, InsecureSkipVerify: true, }) - err = Use(tlsConn).Handshake(4) + err = tlsConn.Handshake() if err != nil { _ = tlsConn.Close() return nil, err @@ -50,18 +53,20 @@ func TestHTTPDialTLS12(t *testing.T) { cli := http.Client{ Transport: &http.Transport{ DialTLS: func(network, addr string) (net.Conn, error) { - conn, err := net.Dial("tcp", "18.65.159.2:443") + conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( + netip.MustParseAddrPort("52.222.136.117:443"), + )) if err != nil { return nil, err } t.Log("net.Dial succeeded") - tlsConn := tls.Client(conn, &tls.Config{ + tlsConn := tls.Client(NewConn(conn), &tls.Config{ ServerName: "huggingface.co", InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, MaxVersion: tls.VersionTLS12, }) - err = Use(tlsConn).Handshake(4) + err = tlsConn.Handshake() if err != nil { _ = tlsConn.Close() return nil, err diff --git a/dns/dns.go b/dns/dns.go index b98e6c0..173aa3f 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -192,7 +192,7 @@ func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) (hosts []stri return nil, ErrNoDNSAvailable } -func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFragmentLen uint8) (tlsConn *tls.Conn, err error) { +func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *tls.Conn, err error) { err = ErrNoDNSAvailable if dialer == nil { @@ -230,7 +230,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra } logrus.Debugln("[terasu.dns] <- dial tcp", host, addr, "succeeded") logrus.Debugln("[terasu.dns] -> hs tls", host, addr) - tlsConn = tls.Client(conn, &tls.Config{ + tlsConn = tls.Client(terasu.NewConn(conn), &tls.Config{ ServerName: host, MinVersion: tls.VersionTLS12, NextProtos: []string{"dns"}, @@ -245,13 +245,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra ctx, cancel = context.WithDeadline(context.Background(), dialer.Deadline) defer cancel() } - if firstFragmentLen > 0 { - logrus.Debugln("[terasu.dns] -- hs tls", host, addr, "use first frag len", firstFragmentLen) - err = terasu.Use(tlsConn).HandshakeContext(ctx, firstFragmentLen) - } else { - logrus.Debugln("[terasu.dns] -- hs tls", host, addr, "normally") - err = tlsConn.HandshakeContext(ctx) - } + err = tlsConn.HandshakeContext(ctx) if err == nil { logrus.Debugln("[terasu.dns] <- hs tls", host, addr, "succeeded") // this is a successful server, keep it @@ -348,8 +342,8 @@ var DefaultResolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, nw, _ string) (net.Conn, error) { if ip.IsIPv6Available { - return IPv6Servers.DialContext(ctx, nil, terasu.DefaultFirstFragmentLen) + return IPv6Servers.DialContext(ctx, nil) } - return IPv4Servers.DialContext(ctx, nil, terasu.DefaultFirstFragmentLen) + return IPv4Servers.DialContext(ctx, nil) }, } diff --git a/dns/dns_test.go b/dns/dns_test.go index ecefc44..8debc5b 100644 --- a/dns/dns_test.go +++ b/dns/dns_test.go @@ -114,11 +114,11 @@ func (ds *DNSList) test() { if err != nil { continue } - tlsConn := tls.Client(conn, &tls.Config{ + tlsConn := tls.Client(terasu.NewConn(conn), &tls.Config{ ServerName: host, MinVersion: tls.VersionTLS12, }) - err = terasu.Use(tlsConn).Handshake(4) + err = tlsConn.Handshake() _ = tlsConn.Close() if err == nil { fmt.Println("succ:", host, addr.addr) diff --git a/dns/doh.go b/dns/doh.go index b8e3ccc..a2f1509 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -88,12 +88,8 @@ var trsHTTP2ClientWithSystemDNS = http.Client{ if err != nil { continue } - tlsConn = tls.Client(conn, cfg) - if terasu.DefaultFirstFragmentLen > 0 { - err = terasu.Use(tlsConn).HandshakeContext(ctx, terasu.DefaultFirstFragmentLen) - } else { - err = tlsConn.HandshakeContext(ctx) - } + tlsConn = tls.Client(terasu.NewConn(conn), cfg) + err = tlsConn.HandshakeContext(ctx) if err == nil { break } @@ -103,7 +99,7 @@ var trsHTTP2ClientWithSystemDNS = http.Client{ if err != nil { continue } - tlsConn = tls.Client(conn, cfg) + tlsConn = tls.Client(terasu.NewConn(conn), cfg) err = tlsConn.HandshakeContext(ctx) if err == nil { break diff --git a/handshake_1.20.go b/handshake_1.20.go deleted file mode 100644 index fb0661d..0000000 --- a/handshake_1.20.go +++ /dev/null @@ -1,328 +0,0 @@ -//go:build !go1.21 - -package terasu - -import ( - "context" - "crypto/ecdh" - "crypto/tls" - "crypto/x509" - "errors" - "hash" - "time" - "unsafe" -) - -//go:linkname defaultConfig crypto/tls.defaultConfig -func defaultConfig() *tls.Config - -type clientHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - serverName string -} - -//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal -func marshal(m *clientHelloMsg) ([]byte, error) - -func (m *clientHelloMsg) marshal() ([]byte, error) { - return marshal(m) -} - -//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal -func unmarshal(m *clientHelloMsg, data []byte) bool - -func (m *clientHelloMsg) unmarshal(data []byte) bool { - return unmarshal(m, data) -} - -//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello -func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error) - -func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { - return makeClientHello(c) -} - -// ClientSessionState contains the state needed by clients to resume TLS -// sessions. -type sessionState struct { - sessionTicket []uint8 // Encrypted ticket used for session resumption with server - vers uint16 // TLS version negotiated for the session - cipherSuite uint16 // Ciphersuite negotiated for the session - masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret - serverCertificates []*x509.Certificate // Certificate chain presented by the server - verifiedChains [][]*x509.Certificate // Certificate chains we built for verification - receivedAt time.Time // When the session ticket was received from the server - ocspResponse []byte // Stapled OCSP response presented by the server - scts [][]byte // SCTs presented by the server - - // TLS 1.3 fields. - nonce []byte // Ticket nonce sent by the server, to derive PSK - useBy time.Time // Expiration of the ticket lifetime as set by the server - ageAdd uint32 // Random obfuscation factor for sending the ticket age -} - -//go:linkname loadSession crypto/tls.(*Conn).loadSession -func loadSession(c *_trsconn, hello *clientHelloMsg) (cacheKey string, - session *sessionState, earlySecret, binderKey []byte, err error, -) - -func (c *_trsconn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *sessionState, earlySecret, binderKey []byte, err error, -) { - return loadSession(c, hello) -} - -type handshakeMessage interface { - marshal() ([]byte, error) - unmarshal([]byte) bool -} - -type transcriptHash interface { - Write([]byte) (int, error) -} - -//go:linkname transcriptMsg crypto/tls.transcriptMsg -func transcriptMsg(msg handshakeMessage, h transcriptHash) error - -//go:linkname readHandshake crypto/tls.(*Conn).readHandshake -func readHandshake(c *_trsconn, transcript transcriptHash) (any, error) - -func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) { - return readHandshake(c, transcript) -} - -type serverHelloMsg struct { - raw []byte - vers uint16 - random []byte -} - -//go:linkname sendAlert crypto/tls.(*Conn).sendAlert -func sendAlert(c *_trsconn, err alert) error - -func (c *_trsconn) sendAlert(err alert) error { - return sendAlert(c, err) -} - -//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError -func unexpectedMessageError(wanted, got any) error - -const ( - alertUnexpectedMessage alert = 10 - alertIllegalParameter alert = 47 -) - -//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion -func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error - -func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error { - return pickTLSVersion(c, serverHello) -} - -//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion -func maxSupportedVersion(c *tls.Config, isClient bool) uint16 - -const roleClient = true - -const ( - // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server - // random as a downgrade protection if the server would be capable of - // negotiating a higher version. See RFC 8446, Section 4.1.3. - downgradeCanaryTLS12 = "DOWNGRD\x01" - downgradeCanaryTLS11 = "DOWNGRD\x00" -) - -type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheKey *ecdh.PrivateKey - - session *sessionState - earlySecret []byte - binderKey []byte - - certReq unsafe.Pointer - usingPSK bool - sentDummyCCS bool - suite unsafe.Pointer - transcript hash.Hash - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 -} - -//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake -func handshake13(hs *clientHandshakeStateTLS13) error - -func (hs *clientHandshakeStateTLS13) handshake() error { - return handshake13(hs) -} - -// A finishedHash calculates the hash of a set of handshake messages suitable -// for including in a Finished message. -type finishedHash struct { - client hash.Hash - server hash.Hash - - // Prior to TLS 1.2, an additional MD5 hash is required. - clientMD5 hash.Hash - serverMD5 hash.Hash - - // In TLS 1.2, a full buffer is sadly required. - buffer []byte - - version uint16 - prf func(result, secret, label, seed []byte) -} - -type clientHandshakeState struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - suite unsafe.Pointer - finishedHash finishedHash - masterSecret []byte - session *sessionState // the session being resumed - ticket []byte // a fresh ticket received during this handshake -} - -//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake -func handshake(hs *clientHandshakeState) error - -func (hs *clientHandshakeState) handshake() error { - return handshake(hs) -} - -// writeHandshakeRecord writes a handshake message to the connection and updates -// the record layer state. If transcript is non-nil the marshalled message is -// written to it. -func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) { - c.out.Lock() - defer c.out.Unlock() - - data, err := msg.marshal() - if err != nil { - return 0, err - } - if transcript != nil { - transcript.Write(data) - } - - return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data) -} - -func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error { - return func(ctx context.Context) (err error) { - c := (*_trsconn)(unsafe.Pointer(cout)) - - if c.config == nil { - c.config = defaultConfig() - } - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false - - hello, ecdheKey, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) - if err != nil { - return err - } - if cacheKey != "" && session != nil { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - }() - } - - if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil { - return err - } - - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - var serverHello *serverHelloMsg - if !isTypeEqual(msg, "*tls.serverHelloMsg") { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)( - unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))), - )) - - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } - - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := maxSupportedVersion(c.config, roleClient) - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } - - if c.vers == tls.VersionTLS13 { - hs := &clientHandshakeStateTLS13{ - c: cout, - ctx: ctx, - serverHello: serverHello, - hello: hello, - ecdheKey: ecdheKey, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - } - - // In TLS 1.3, session tickets are delivered after the handshake. - return hs.handshake() - } - - hs := &clientHandshakeState{ - c: cout, - ctx: ctx, - serverHello: serverHello, - hello: hello, - session: session, - } - - if err := hs.handshake(); err != nil { - return err - } - - // If we had a successful handshake and hs.session is different from - // the one already cached - cache a new one. - if cacheKey != "" && hs.session != nil && session != hs.session { - c.config.ClientSessionCache.Put(cacheKey, (*tls.ClientSessionState)(unsafe.Pointer(hs.session))) - } - - return nil - } -} diff --git a/handshake_1.21.go b/handshake_1.21.go deleted file mode 100644 index dbb6830..0000000 --- a/handshake_1.21.go +++ /dev/null @@ -1,426 +0,0 @@ -//go:build go1.21 && !go1.24 - -package terasu - -import ( - "context" - "crypto" - "crypto/ecdh" - "crypto/tls" - "errors" - "hash" - "unsafe" -) - -//go:linkname defaultConfig crypto/tls.defaultConfig -func defaultConfig() *tls.Config - -type clientHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - serverName string - ocspStapling bool - supportedCurves []tls.CurveID - supportedPoints []uint8 - ticketSupported bool - sessionTicket []uint8 - supportedSignatureAlgorithms []tls.SignatureScheme - supportedSignatureAlgorithmsCert []tls.SignatureScheme - secureRenegotiationSupported bool - secureRenegotiation []byte - extendedMasterSecret bool - alpnProtocols []string - scts bool - supportedVersions []uint16 - cookie []byte - keyShares []byte - earlyData bool -} - -//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal -func marshal(m *clientHelloMsg) ([]byte, error) - -func (m *clientHelloMsg) marshal() ([]byte, error) { - return marshal(m) -} - -//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal -func unmarshal(m *clientHelloMsg, data []byte) bool - -func (m *clientHelloMsg) unmarshal(data []byte) bool { - return unmarshal(m, data) -} - -//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello -func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error) - -func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { - return makeClientHello(c) -} - -// A sessionState is a resumable session. -type sessionState struct { - // Encoded as a SessionState (in the language of RFC 8446, Section 3). - // - // enum { server(1), client(2) } SessionStateType; - // - // opaque Certificate<1..2^24-1>; - // - // Certificate CertificateChain<0..2^24-1>; - // - // opaque Extra<0..2^24-1>; - // - // struct { - // uint16 version; - // SessionStateType type; - // uint16 cipher_suite; - // uint64 created_at; - // opaque secret<1..2^8-1>; - // Extra extra<0..2^24-1>; - // uint8 ext_master_secret = { 0, 1 }; - // uint8 early_data = { 0, 1 }; - // CertificateEntry certificate_list<0..2^24-1>; - // CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */ - // select (SessionState.early_data) { - // case 0: Empty; - // case 1: opaque alpn<1..2^8-1>; - // }; - // select (SessionState.type) { - // case server: Empty; - // case client: struct { - // select (SessionState.version) { - // case VersionTLS10..VersionTLS12: Empty; - // case VersionTLS13: struct { - // uint64 use_by; - // uint32 age_add; - // }; - // }; - // }; - // }; - // } SessionState; - // - - // Extra is ignored by crypto/tls, but is encoded by [SessionState.Bytes] - // and parsed by [ParseSessionState]. - // - // This allows [Config.UnwrapSession]/[Config.WrapSession] and - // [ClientSessionCache] implementations to store and retrieve additional - // data alongside this session. - // - // To allow different layers in a protocol stack to share this field, - // applications must only append to it, not replace it, and must use entries - // that can be recognized even if out of order (for example, by starting - // with an id and version prefix). - Extra [][]byte - - // EarlyData indicates whether the ticket can be used for 0-RTT in a QUIC - // connection. The application may set this to false if it is true to - // decline to offer 0-RTT even if supported. - EarlyData bool - - version uint16 - isClient bool - cipherSuite uint16 -} - -//go:linkname loadSession crypto/tls.(*Conn).loadSession -func loadSession(c *_trsconn, hello *clientHelloMsg) ( - session *sessionState, earlySecret, binderKey []byte, err error, -) - -func (c *_trsconn) loadSession(hello *clientHelloMsg) ( - session *sessionState, earlySecret, binderKey []byte, err error, -) { - return loadSession(c, hello) -} - -//go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey -func clientSessionCacheKey(c *_trsconn) string - -func (c *_trsconn) clientSessionCacheKey() string { - return clientSessionCacheKey(c) -} - -// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash -// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. -type cipherSuiteTLS13 struct { - id uint16 - keyLen int - aead func(key, fixedNonce []byte) any - hash crypto.Hash -} - -//go:linkname deriveSecret crypto/tls.(*cipherSuiteTLS13).deriveSecret -func deriveSecret(c *cipherSuiteTLS13, secret []byte, label string, transcript hash.Hash) []byte - -func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - return deriveSecret(c, secret, label, transcript) -} - -//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -type handshakeMessage interface { - marshal() ([]byte, error) - unmarshal([]byte) bool -} - -type transcriptHash interface { - Write([]byte) (int, error) -} - -//go:linkname transcriptMsg crypto/tls.transcriptMsg -func transcriptMsg(msg handshakeMessage, h transcriptHash) error - -const clientEarlyTrafficLabel = "c e traffic" - -//go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret -func quicSetWriteSecret(c *_trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte) - -//go:linkname readHandshake crypto/tls.(*Conn).readHandshake -func readHandshake(c *_trsconn, transcript transcriptHash) (any, error) - -func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) { - return readHandshake(c, transcript) -} - -type serverHelloMsg struct { - raw []byte - vers uint16 - random []byte -} - -//go:linkname sendAlert crypto/tls.(*Conn).sendAlert -func sendAlert(c *_trsconn, err alert) error - -func (c *_trsconn) sendAlert(err alert) error { - return sendAlert(c, err) -} - -//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError -func unexpectedMessageError(wanted, got any) error - -const ( - alertUnexpectedMessage alert = 10 - alertIllegalParameter alert = 47 -) - -//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion -func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error - -func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error { - return pickTLSVersion(c, serverHello) -} - -//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion -func maxSupportedVersion(c *tls.Config, isClient bool) uint16 - -const roleClient = true - -const ( - // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server - // random as a downgrade protection if the server would be capable of - // negotiating a higher version. See RFC 8446, Section 4.1.3. - downgradeCanaryTLS12 = "DOWNGRD\x01" - downgradeCanaryTLS11 = "DOWNGRD\x00" -) - -type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheKey *ecdh.PrivateKey - - session *sessionState - earlySecret []byte - binderKey []byte - - certReq unsafe.Pointer - usingPSK bool - sentDummyCCS bool - suite *cipherSuiteTLS13 - transcript hash.Hash - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 -} - -//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake -func handshake13(hs *clientHandshakeStateTLS13) error - -func (hs *clientHandshakeStateTLS13) handshake() error { - return handshake13(hs) -} - -// A finishedHash calculates the hash of a set of handshake messages suitable -// for including in a Finished message. -type finishedHash struct { - client hash.Hash - server hash.Hash - - // Prior to TLS 1.2, an additional MD5 hash is required. - clientMD5 hash.Hash - serverMD5 hash.Hash - - // In TLS 1.2, a full buffer is sadly required. - buffer []byte - - version uint16 - prf func(result, secret, label, seed []byte) -} - -type clientHandshakeState struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - suite unsafe.Pointer - finishedHash finishedHash - masterSecret []byte - session *sessionState // the session being resumed - ticket []byte // a fresh ticket received during this handshake -} - -//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake -func handshake(hs *clientHandshakeState) error - -func (hs *clientHandshakeState) handshake() error { - return handshake(hs) -} - -// writeHandshakeRecord writes a handshake message to the connection and updates -// the record layer state. If transcript is non-nil the marshalled message is -// written to it. -func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) { - c.out.Lock() - defer c.out.Unlock() - - data, err := msg.marshal() - if err != nil { - return 0, err - } - if transcript != nil { - transcript.Write(data) - } - - return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data) -} - -func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error { - return func(ctx context.Context) (err error) { - c := (*_trsconn)(unsafe.Pointer(cout)) - - if c.config == nil { - c.config = defaultConfig() - } - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false - - hello, ecdheKey, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - session, earlySecret, binderKey, err := c.loadSession(hello) - if err != nil { - return err - } - if session != nil { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - if cacheKey := c.clientSessionCacheKey(); cacheKey != "" { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - } - }() - } - - if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil { - return err - } - - if hello.earlyData { - suite := cipherSuiteTLS13ByID(session.cipherSuite) - transcript := suite.hash.New() - if err := transcriptMsg(hello, transcript); err != nil { - return err - } - earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript) - quicSetWriteSecret(c, tls.QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret) - } - - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - var serverHello *serverHelloMsg - if !isTypeEqual(msg, "*tls.serverHelloMsg") { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)( - unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))), - )) - - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } - - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := maxSupportedVersion(c.config, roleClient) - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } - - if c.vers == tls.VersionTLS13 { - hs := &clientHandshakeStateTLS13{ - c: cout, - ctx: ctx, - serverHello: serverHello, - hello: hello, - ecdheKey: ecdheKey, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - } - - // In TLS 1.3, session tickets are delivered after the handshake. - return hs.handshake() - } - - hs := &clientHandshakeState{ - c: cout, - ctx: ctx, - serverHello: serverHello, - hello: hello, - session: session, - } - - if err := hs.handshake(); err != nil { - return err - } - - return nil - } -} diff --git a/handshake_1.24.go b/handshake_1.24.go deleted file mode 100644 index 8fadb07..0000000 --- a/handshake_1.24.go +++ /dev/null @@ -1,609 +0,0 @@ -//go:build go1.24 - -package terasu - -import ( - "context" - "crypto" - "crypto/cipher" - "crypto/ecdh" - "crypto/mlkem" - "crypto/tls" - "crypto/x509" - "errors" - "hash" - "io" - "unsafe" -) - -//go:linkname defaultConfig crypto/tls.defaultConfig -func defaultConfig() *tls.Config - -// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved -// session. See RFC 8446, Section 4.2.11. -type pskIdentity struct { - label []byte - obfuscatedTicketAge uint32 -} - -type clientHelloMsg struct { - original []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - serverName string - ocspStapling bool - supportedCurves []tls.CurveID - supportedPoints []uint8 - ticketSupported bool - sessionTicket []uint8 - supportedSignatureAlgorithms []tls.SignatureScheme - supportedSignatureAlgorithmsCert []tls.SignatureScheme - secureRenegotiationSupported bool - secureRenegotiation []byte - extendedMasterSecret bool - alpnProtocols []string - scts bool - supportedVersions []uint16 - cookie []byte - keyShares []byte - earlyData bool - pskModes []uint8 - pskIdentities []pskIdentity - pskBinders [][]byte - quicTransportParameters []byte - encryptedClientHello []byte - // extensions are only populated on the server-side of a handshake - extensions []uint16 -} - -//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal -func marshal(m *clientHelloMsg) ([]byte, error) - -func (m *clientHelloMsg) marshal() ([]byte, error) { - return marshal(m) -} - -//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal -func unmarshal(m *clientHelloMsg, data []byte) bool - -func (m *clientHelloMsg) unmarshal(data []byte) bool { - return unmarshal(m, data) -} - -//go:linkname clone crypto/tls.(*clientHelloMsg).clone -func clone(m *clientHelloMsg) *clientHelloMsg - -func (m *clientHelloMsg) clone() *clientHelloMsg { - return clone(m) -} - -type keySharePrivateKeys struct { - curveID tls.CurveID - ecdhe *ecdh.PrivateKey - mlkem *mlkem.DecapsulationKey768 -} - -type echCipher struct { - KDFID uint16 - AEADID uint16 -} - -type echExtension struct { - Type uint16 - Data []byte -} - -type echConfig struct { - raw []byte - - Version uint16 - Length uint16 - - ConfigID uint8 - KemID uint16 - PublicKey []byte - SymmetricCipherSuite []echCipher - - MaxNameLength uint8 - PublicName []byte - Extensions []echExtension -} - -type uint128 struct { - hi, lo uint64 -} - -type hpkecontext struct { - aead cipher.AEAD - - sharedSecret []byte - - suiteID []byte - - key []byte - baseNonce []byte - exporterSecret []byte - - seqNum uint128 -} - -type hpkeSender struct { - *hpkecontext -} - -type echClientContext struct { - config *echConfig - hpkeContext *hpkeSender - encapsulatedKey []byte - innerHello *clientHelloMsg - innerTranscript hash.Hash - kdfID uint16 - aeadID uint16 - echRejected bool - retryConfigs []byte -} - -//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello -func makeClientHello(c *_trsconn) (*clientHelloMsg, *keySharePrivateKeys, *echClientContext, error) - -func (c *_trsconn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echClientContext, error) { - return makeClientHello(c) -} - -// activeCert is a handle to a certificate held in the cache. Once there are -// no alive activeCerts for a given certificate, the certificate is removed -// from the cache by a finalizer. -type activeCert struct { - cert *x509.Certificate -} - -// A sessionState is a resumable session. -type sessionState struct { - // Encoded as a SessionState (in the language of RFC 8446, Section 3). - // - // enum { server(1), client(2) } SessionStateType; - // - // opaque Certificate<1..2^24-1>; - // - // Certificate CertificateChain<0..2^24-1>; - // - // opaque Extra<0..2^24-1>; - // - // struct { - // uint16 version; - // SessionStateType type; - // uint16 cipher_suite; - // uint64 created_at; - // opaque secret<1..2^8-1>; - // Extra extra<0..2^24-1>; - // uint8 ext_master_secret = { 0, 1 }; - // uint8 early_data = { 0, 1 }; - // CertificateEntry certificate_list<0..2^24-1>; - // CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */ - // select (SessionState.early_data) { - // case 0: Empty; - // case 1: opaque alpn<1..2^8-1>; - // }; - // select (SessionState.type) { - // case server: Empty; - // case client: struct { - // select (SessionState.version) { - // case VersionTLS10..VersionTLS12: Empty; - // case VersionTLS13: struct { - // uint64 use_by; - // uint32 age_add; - // }; - // }; - // }; - // }; - // } SessionState; - // - - // Extra is ignored by crypto/tls, but is encoded by [SessionState.Bytes] - // and parsed by [ParseSessionState]. - // - // This allows [Config.UnwrapSession]/[Config.WrapSession] and - // [ClientSessionCache] implementations to store and retrieve additional - // data alongside this session. - // - // To allow different layers in a protocol stack to share this field, - // applications must only append to it, not replace it, and must use entries - // that can be recognized even if out of order (for example, by starting - // with an id and version prefix). - Extra [][]byte - - // EarlyData indicates whether the ticket can be used for 0-RTT in a QUIC - // connection. The application may set this to false if it is true to - // decline to offer 0-RTT even if supported. - EarlyData bool - - version uint16 - isClient bool - cipherSuite uint16 - // createdAt is the generation time of the secret on the sever (which for - // TLS 1.0–1.2 might be earlier than the current session) and the time at - // which the ticket was received on the client. - createdAt uint64 // seconds since UNIX epoch - secret []byte // master secret for TLS 1.2, or the PSK for TLS 1.3 - extMasterSecret bool - peerCertificates []*x509.Certificate - activeCertHandles []*activeCert - ocspResponse []byte - scts [][]byte - verifiedChains [][]*x509.Certificate - alpnProtocol string // only set if EarlyData is true - - // Client-side TLS 1.3-only fields. - useBy uint64 // seconds since UNIX epoch - ageAdd uint32 - ticket []byte -} - -type earlySecret struct { - secret []byte - hash func() any -} - -//go:linkname clientEarlyTrafficSecret crypto/internal/fips140/tls13.(*EarlySecret).ClientEarlyTrafficSecret -func clientEarlyTrafficSecret(s *earlySecret, transcript any) []byte - -//go:linkname loadSession crypto/tls.(*Conn).loadSession -func loadSession(c *_trsconn, hello *clientHelloMsg) ( - session *sessionState, earlySecret *earlySecret, binderKey []byte, err error, -) - -func (c *_trsconn) loadSession(hello *clientHelloMsg) ( - session *sessionState, earlySecret *earlySecret, binderKey []byte, err error, -) { - return loadSession(c, hello) -} - -//go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey -func clientSessionCacheKey(c *_trsconn) string - -func (c *_trsconn) clientSessionCacheKey() string { - return clientSessionCacheKey(c) -} - -// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash -// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. -type cipherSuiteTLS13 struct { - id uint16 - keyLen int - aead func(key, fixedNonce []byte) any - hash crypto.Hash -} - -//go:linkname deriveSecret crypto/tls.(*cipherSuiteTLS13).deriveSecret -func deriveSecret(c *cipherSuiteTLS13, secret []byte, label string, transcript hash.Hash) []byte - -func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - return deriveSecret(c, secret, label, transcript) -} - -//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -type handshakeMessage interface { - marshal() ([]byte, error) - unmarshal([]byte) bool -} - -type transcriptHash interface { - Write([]byte) (int, error) -} - -//go:linkname transcriptMsg crypto/tls.transcriptMsg -func transcriptMsg(msg handshakeMessage, h transcriptHash) error - -const clientEarlyTrafficLabel = "c e traffic" - -//go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret -func quicSetWriteSecret(c *_trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte) - -//go:linkname readHandshake crypto/tls.(*Conn).readHandshake -func readHandshake(c *_trsconn, transcript transcriptHash) (any, error) - -func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) { - return readHandshake(c, transcript) -} - -// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. -type keyShare struct { - group tls.CurveID - data []byte -} - -type serverHelloMsg struct { - original []byte - vers uint16 - random []byte - sessionId []byte - cipherSuite uint16 - compressionMethod uint8 - ocspStapling bool - ticketSupported bool - secureRenegotiationSupported bool - secureRenegotiation []byte - extendedMasterSecret bool - alpnProtocol string - scts [][]byte - supportedVersion uint16 - serverShare keyShare - selectedIdentityPresent bool - selectedIdentity uint16 - supportedPoints []uint8 - encryptedClientHello []byte - serverNameAck bool - - // HelloRetryRequest extensions - cookie []byte - selectedGroup tls.CurveID -} - -//go:linkname sendAlert crypto/tls.(*Conn).sendAlert -func sendAlert(c *_trsconn, err alert) error - -func (c *_trsconn) sendAlert(err alert) error { - return sendAlert(c, err) -} - -//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError -func unexpectedMessageError(wanted, got any) error - -const ( - alertUnexpectedMessage alert = 10 - alertIllegalParameter alert = 47 -) - -//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion -func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error - -func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error { - return pickTLSVersion(c, serverHello) -} - -//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion -func maxSupportedVersion(c *tls.Config, isClient bool) uint16 - -const roleClient = true - -const ( - // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server - // random as a downgrade protection if the server would be capable of - // negotiating a higher version. See RFC 8446, Section 4.1.3. - downgradeCanaryTLS12 = "DOWNGRD\x01" - downgradeCanaryTLS11 = "DOWNGRD\x00" -) - -type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - keyShareKeys *keySharePrivateKeys - - session *sessionState - earlySecret *earlySecret - binderKey []byte - - certReq unsafe.Pointer - usingPSK bool - sentDummyCCS bool - suite *cipherSuiteTLS13 - transcript hash.Hash - masterSecret unsafe.Pointer - trafficSecret []byte // client_application_traffic_secret_0 - - echContext *echClientContext -} - -//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake -func handshake13(hs *clientHandshakeStateTLS13) error - -func (hs *clientHandshakeStateTLS13) handshake() error { - return handshake13(hs) -} - -type prfFunc func(secret []byte, label string, seed []byte, keyLen int) []byte - -// A finishedHash calculates the hash of a set of handshake messages suitable -// for including in a Finished message. -type finishedHash struct { - client hash.Hash - server hash.Hash - - // Prior to TLS 1.2, an additional MD5 hash is required. - clientMD5 hash.Hash - serverMD5 hash.Hash - - // In TLS 1.2, a full buffer is sadly required. - buffer []byte - - version uint16 - prf prfFunc -} - -type clientHandshakeState struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - suite unsafe.Pointer - finishedHash finishedHash - masterSecret []byte - session *sessionState // the session being resumed - ticket []byte // a fresh ticket received during this handshake -} - -//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake -func handshake(hs *clientHandshakeState) error - -func (hs *clientHandshakeState) handshake() error { - return handshake(hs) -} - -//go:linkname computeAndUpdateOuterECHExtension crypto/tls.computeAndUpdateOuterECHExtension -func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echClientContext, useKey bool) error - -// writeHandshakeRecord writes a handshake message to the connection and updates -// the record layer state. If transcript is non-nil the marshalled message is -// written to it. -func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) { - c.out.Lock() - defer c.out.Unlock() - - data, err := msg.marshal() - if err != nil { - return 0, err - } - if transcript != nil { - transcript.Write(data) - } - - return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data) -} - -func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error { - return func(ctx context.Context) (err error) { - c := (*_trsconn)(unsafe.Pointer(cout)) - - if c.config == nil { - c.config = defaultConfig() - } - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false - - hello, keyShareKeys, ech, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - session, earlySecret, binderKey, err := c.loadSession(hello) - if err != nil { - return err - } - if session != nil { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - if cacheKey := c.clientSessionCacheKey(); cacheKey != "" { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - } - }() - } - - if ech != nil { - // Split hello into inner and outer - ech.innerHello = hello.clone() - - // Overwrite the server name in the outer hello with the public facing - // name. - hello.serverName = string(ech.config.PublicName) - // Generate a new random for the outer hello. - hello.random = make([]byte, 32) - _, err = io.ReadFull(tlsConfigRand(c.config), hello.random) - if err != nil { - return errors.New("tls: short read from Rand: " + err.Error()) - } - - // NOTE: we don't do PSK GREASE, in line with boringssl, it's meant to - // work around _possibly_ broken middleboxes, but there is little-to-no - // evidence that this is actually a problem. - - if err := computeAndUpdateOuterECHExtension(hello, ech.innerHello, ech, true); err != nil { - return err - } - } - - c.serverName = hello.serverName - - if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil { - return err - } - - if hello.earlyData { - suite := cipherSuiteTLS13ByID(session.cipherSuite) - transcript := suite.hash.New() - if err := transcriptMsg(hello, transcript); err != nil { - return err - } - earlyTrafficSecret := clientEarlyTrafficSecret(earlySecret, transcript) - quicSetWriteSecret(c, tls.QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret) - } - - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - var serverHello *serverHelloMsg - if !isTypeEqual(msg, "*tls.serverHelloMsg") { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)( - unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))), - )) - - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } - - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := maxSupportedVersion(c.config, roleClient) - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } - - if c.vers == tls.VersionTLS13 { - hs := &clientHandshakeStateTLS13{ - c: cout, - ctx: ctx, - serverHello: serverHello, - hello: hello, - keyShareKeys: keyShareKeys, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - echContext: ech, - } - - // In TLS 1.3, session tickets are delivered after the handshake. - return hs.handshake() - } - - hs := &clientHandshakeState{ - c: cout, - ctx: ctx, - serverHello: serverHello, - hello: hello, - session: session, - } - - if err := hs.handshake(); err != nil { - return err - } - - return nil - } -} diff --git a/http/http.go b/http/http.go index 123507e..946c70c 100644 --- a/http/http.go +++ b/http/http.go @@ -58,7 +58,7 @@ var DefaultClient = http.Client{ if err != nil { continue } - tlsConn = tls.Client(conn, &tls.Config{ + tlsConn = tls.Client(terasu.NewConn(conn), &tls.Config{ ServerName: host, MinVersion: tls.VersionTLS12, }) @@ -72,11 +72,7 @@ var DefaultClient = http.Client{ ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) defer cancel() } - if terasu.DefaultFirstFragmentLen > 0 { - err = terasu.Use(tlsConn).HandshakeContext(ctx, terasu.DefaultFirstFragmentLen) - } else { - err = tlsConn.HandshakeContext(ctx) - } + err = tlsConn.HandshakeContext(ctx) if err == nil { break } @@ -86,7 +82,7 @@ var DefaultClient = http.Client{ if err != nil { continue } - tlsConn = tls.Client(conn, &tls.Config{ + tlsConn = tls.Client(terasu.NewConn(conn), &tls.Config{ ServerName: host, MinVersion: tls.VersionTLS12, }) diff --git a/http2/http2.go b/http2/http2.go index ec3a155..e31e24f 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -58,7 +58,7 @@ var DefaultClient = http.Client{ if err != nil { continue } - tlsConn = tls.Client(conn, cfg) + tlsConn = tls.Client(terasu.NewConn(conn), cfg) // re-init ctx due to deadline settings in tcp dial if defaultDialer.Timeout != 0 { var cancel context.CancelFunc @@ -69,11 +69,7 @@ var DefaultClient = http.Client{ ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) defer cancel() } - if terasu.DefaultFirstFragmentLen > 0 { - err = terasu.Use(tlsConn).HandshakeContext(ctx, terasu.DefaultFirstFragmentLen) - } else { - err = tlsConn.HandshakeContext(ctx) - } + err = tlsConn.HandshakeContext(ctx) if err == nil { break } @@ -83,7 +79,7 @@ var DefaultClient = http.Client{ if err != nil { continue } - tlsConn = tls.Client(conn, cfg) + tlsConn = tls.Client(terasu.NewConn(conn), cfg) err = tlsConn.HandshakeContext(ctx) if err == nil { break diff --git a/ip/ipv6.go b/ip/ipv6.go index 9be0036..0ba5fef 100644 --- a/ip/ipv6.go +++ b/ip/ipv6.go @@ -1,3 +1,3 @@ package ip -var IsIPv6Available = true +var IsIPv6Available = false diff --git a/terasu.go b/terasu.go deleted file mode 100644 index c0d4cb0..0000000 --- a/terasu.go +++ /dev/null @@ -1,32 +0,0 @@ -package terasu - -import ( - "context" - "crypto/tls" - "unsafe" -) - -var DefaultFirstFragmentLen uint8 = 3 - -// Use terasu in this TLS conn -func Use(conn *tls.Conn) *Conn { - return (*Conn)(conn) -} - -// Handshake do terasu handshake in this TLS conn -func (conn *Conn) Handshake(firstFragmentLen uint8) error { - expose := (*_trsconn)(unsafe.Pointer(conn)) - fnbak := expose.handshakeFn - expose.handshakeFn = conn.clientHandshake(firstFragmentLen) - defer func() { expose.handshakeFn = fnbak }() - return (*tls.Conn)(conn).Handshake() -} - -// Handshake do terasu handshake with ctx in this TLS conn -func (conn *Conn) HandshakeContext(ctx context.Context, firstFragmentLen uint8) error { - expose := (*_trsconn)(unsafe.Pointer(conn)) - fnbak := expose.handshakeFn - expose.handshakeFn = conn.clientHandshake(firstFragmentLen) - defer func() { expose.handshakeFn = fnbak }() - return (*tls.Conn)(conn).HandshakeContext(ctx) -} diff --git a/tls_1.20.go b/tls_1.20.go deleted file mode 100644 index 3cc4f52..0000000 --- a/tls_1.20.go +++ /dev/null @@ -1,261 +0,0 @@ -//go:build !go1.21 - -package terasu - -import ( - "context" - "crypto/tls" - "crypto/x509" - "hash" - "io" - "net" - "sync" - "sync/atomic" - "unsafe" - _ "unsafe" -) - -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -const ( - recordHeaderLen = 5 // record header length -) - -type alert uint8 - -//go:linkname alertError tls.(tls.alert).Error -func alertError(e alert) string - -func (e alert) Error() string { - return alertError(e) -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher any // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher any // next encryption state - nextMac hash.Hash // next MAC algorithm - - trafficSecret []byte // current TLS 1.3 traffic secret -} - -type Conn tls.Conn - -// A _trsconn represents a secured connection. -// It implements the net._trsconn interface. -type _trsconn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - - // isHandshakeComplete is true if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // isHandshakeComplete is true implies handshakeErr == nil. - isHandshakeComplete atomic.Bool - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *tls.Config // configuration passed to constructor - // handshakes counts the number of handshakes performed on the - // connection so far. If renegotiation is disabled then this is either - // zero or one. - handshakes int - didResume bool // whether this connection was a session resumption - cipherSuite uint16 - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // activeCertHandles contains the cache handles to certificates in - // peerCertificates that are used to track active references. - activeCertHandles []unsafe.Pointer - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // resumptionSecret is the resumption_master_secret for handling - // or sending NewSessionTicket messages. - resumptionSecret []byte - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []byte - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - in, out halfConn -} - -//go:linkname outBufPool crypto/tls.outBufPool -var outBufPool sync.Pool - -//go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite -func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int - -func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int { - return maxPayloadSizeForWrite(c, typ) -} - -//go:linkname sliceForAppend crypto/tls.sliceForAppend -func sliceForAppend(in []byte, n int) (head, tail []byte) - -//go:linkname encrypt crypto/tls.(*halfConn).encrypt -func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error) - -func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - return encrypt(hc, record, payload, rand) -} - -//go:linkname rand crypto/tls.(*Config).rand -func rand(c *tls.Config) io.Reader - -//go:linkname write crypto/tls.(*Conn).write -func write(c *_trsconn, data []byte) (int, error) - -func (c *_trsconn) write(data []byte) (int, error) { - return write(c, data) -} - -//go:linkname flush crypto/tls.(*Conn).flush -func flush(c *_trsconn) (int, error) - -func (c *_trsconn) flush() (int, error) { - return flush(c) -} - -//go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec -func changeCipherSpec(hc *halfConn) error - -func (hc *halfConn) changeCipherSpec() error { - return changeCipherSpec(hc) -} - -//go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked -func sendAlertLocked(c *_trsconn, err alert) error - -func (c *_trsconn) sendAlertLocked(err alert) error { - return sendAlertLocked(c, err) -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) { - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - isFirstLoop := true - for len(data) > 0 { - m := len(data) - if !isFirstLoop { - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - } else { - m = int(firstFragmentLen) - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. - vers = tls.VersionTLS10 - } else if vers == tls.VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = tls.VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config)) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - if isFirstLoop { - isFirstLoop = false - if _, err := c.flush(); err != nil { - return n, err - } - } - } - - if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(alert( - *(*uintptr)( - unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))), - ), - )) - } - } - - return n, nil -} diff --git a/tls_1.21.go b/tls_1.21.go deleted file mode 100644 index 2cb6b1f..0000000 --- a/tls_1.21.go +++ /dev/null @@ -1,271 +0,0 @@ -//go:build go1.21 && !go1.23 - -package terasu - -import ( - "context" - "crypto/tls" - "crypto/x509" - "hash" - "io" - "net" - "sync" - "sync/atomic" - "unsafe" - _ "unsafe" -) - -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -const ( - recordHeaderLen = 5 // record header length -) - -type alert uint8 - -//go:linkname alertError tls.(tls.alert).Error -func alertError(e alert) string - -func (e alert) Error() string { - return alertError(e) -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher any // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher any // next encryption state - nextMac hash.Hash // next MAC algorithm - - level tls.QUICEncryptionLevel // current QUIC encryption level - trafficSecret []byte // current TLS 1.3 traffic secret -} - -type Conn tls.Conn - -// A _trsconn represents a secured connection. -// It implements the net._trsconn interface. -type _trsconn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - quic unsafe.Pointer // nil for non-QUIC connections - - // isHandshakeComplete is true if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // isHandshakeComplete is true implies handshakeErr == nil. - isHandshakeComplete atomic.Bool - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *tls.Config // configuration passed to constructor - // handshakes counts the number of handshakes performed on the - // connection so far. If renegotiation is disabled then this is either - // zero or one. - handshakes int - extMasterSecret bool - didResume bool // whether this connection was a session resumption - cipherSuite uint16 - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // activeCertHandles contains the cache handles to certificates in - // peerCertificates that are used to track active references. - activeCertHandles []unsafe.Pointer - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // resumptionSecret is the resumption_master_secret for handling - // or sending NewSessionTicket messages. - resumptionSecret []byte - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []byte - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - in, out halfConn -} - -//go:linkname outBufPool crypto/tls.outBufPool -var outBufPool sync.Pool - -//go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked -func tlsWriteRecordLocked(c *_trsconn, typ recordType, data []byte) (int, error) - -//go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite -func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int - -func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int { - return maxPayloadSizeForWrite(c, typ) -} - -//go:linkname sliceForAppend crypto/tls.sliceForAppend -func sliceForAppend(in []byte, n int) (head, tail []byte) - -//go:linkname encrypt crypto/tls.(*halfConn).encrypt -func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error) - -func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - return encrypt(hc, record, payload, rand) -} - -//go:linkname rand crypto/tls.(*Config).rand -func rand(c *tls.Config) io.Reader - -//go:linkname write crypto/tls.(*Conn).write -func write(c *_trsconn, data []byte) (int, error) - -func (c *_trsconn) write(data []byte) (int, error) { - return write(c, data) -} - -//go:linkname flush crypto/tls.(*Conn).flush -func flush(c *_trsconn) (int, error) - -func (c *_trsconn) flush() (int, error) { - return flush(c) -} - -//go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec -func changeCipherSpec(hc *halfConn) error - -func (hc *halfConn) changeCipherSpec() error { - return changeCipherSpec(hc) -} - -//go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked -func sendAlertLocked(c *_trsconn, err alert) error - -func (c *_trsconn) sendAlertLocked(err alert) error { - return sendAlertLocked(c, err) -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) { - if c.quic != nil { - return tlsWriteRecordLocked(c, typ, data) - } - - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - isFirstLoop := true - for len(data) > 0 { - m := len(data) - if !isFirstLoop { - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - } else { - m = int(firstFragmentLen) - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. - vers = tls.VersionTLS10 - } else if vers == tls.VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = tls.VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config)) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - if isFirstLoop { - isFirstLoop = false - if _, err := c.flush(); err != nil { - return n, err - } - } - } - - if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(alert( - *(*uintptr)( - unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))), - ), - )) - } - } - - return n, nil -} diff --git a/tls_1.23.go b/tls_1.23.go deleted file mode 100644 index 3c968cd..0000000 --- a/tls_1.23.go +++ /dev/null @@ -1,274 +0,0 @@ -//go:build go1.23 && !go1.24 - -package terasu - -import ( - "context" - "crypto/tls" - "crypto/x509" - "hash" - "io" - "net" - "sync" - "sync/atomic" - "unsafe" - _ "unsafe" -) - -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -const ( - recordHeaderLen = 5 // record header length -) - -type alert uint8 - -//go:linkname alertError tls.(tls.alert).Error -func alertError(e alert) string - -func (e alert) Error() string { - return alertError(e) -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher any // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher any // next encryption state - nextMac hash.Hash // next MAC algorithm - - level tls.QUICEncryptionLevel // current QUIC encryption level - trafficSecret []byte // current TLS 1.3 traffic secret -} - -type Conn tls.Conn - -// A _trsconn represents a secured connection. -// It implements the net._trsconn interface. -type _trsconn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - quic unsafe.Pointer // nil for non-QUIC connections - - // isHandshakeComplete is true if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // isHandshakeComplete is true implies handshakeErr == nil. - isHandshakeComplete atomic.Bool - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *tls.Config // configuration passed to constructor - // handshakes counts the number of handshakes performed on the - // connection so far. If renegotiation is disabled then this is either - // zero or one. - handshakes int - extMasterSecret bool - didResume bool // whether this connection was a session resumption - didHRR bool // whether a HelloRetryRequest was sent/received - cipherSuite uint16 - curveID tls.CurveID - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // activeCertHandles contains the cache handles to certificates in - // peerCertificates that are used to track active references. - activeCertHandles []unsafe.Pointer - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // resumptionSecret is the resumption_master_secret for handling - // or sending NewSessionTicket messages. - resumptionSecret []byte - echAccepted bool - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []byte - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - in, out halfConn -} - -//go:linkname outBufPool crypto/tls.outBufPool -var outBufPool sync.Pool - -//go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked -func tlsWriteRecordLocked(c *_trsconn, typ recordType, data []byte) (int, error) - -//go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite -func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int - -func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int { - return maxPayloadSizeForWrite(c, typ) -} - -//go:linkname sliceForAppend crypto/tls.sliceForAppend -func sliceForAppend(in []byte, n int) (head, tail []byte) - -//go:linkname encrypt crypto/tls.(*halfConn).encrypt -func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error) - -func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - return encrypt(hc, record, payload, rand) -} - -//go:linkname rand crypto/tls.(*Config).rand -func rand(c *tls.Config) io.Reader - -//go:linkname write crypto/tls.(*Conn).write -func write(c *_trsconn, data []byte) (int, error) - -func (c *_trsconn) write(data []byte) (int, error) { - return write(c, data) -} - -//go:linkname flush crypto/tls.(*Conn).flush -func flush(c *_trsconn) (int, error) - -func (c *_trsconn) flush() (int, error) { - return flush(c) -} - -//go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec -func changeCipherSpec(hc *halfConn) error - -func (hc *halfConn) changeCipherSpec() error { - return changeCipherSpec(hc) -} - -//go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked -func sendAlertLocked(c *_trsconn, err alert) error - -func (c *_trsconn) sendAlertLocked(err alert) error { - return sendAlertLocked(c, err) -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) { - if c.quic != nil { - return tlsWriteRecordLocked(c, typ, data) - } - - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - isFirstLoop := true - for len(data) > 0 { - m := len(data) - if !isFirstLoop { - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - } else { - m = int(firstFragmentLen) - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. - vers = tls.VersionTLS10 - } else if vers == tls.VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = tls.VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config)) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - if isFirstLoop { - isFirstLoop = false - if _, err := c.flush(); err != nil { - return n, err - } - } - } - - if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(alert( - *(*uintptr)( - unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))), - ), - )) - } - } - - return n, nil -} diff --git a/tls_1.24.go b/tls_1.24.go deleted file mode 100644 index 6274e0d..0000000 --- a/tls_1.24.go +++ /dev/null @@ -1,277 +0,0 @@ -//go:build go1.24 - -package terasu - -import ( - "context" - "crypto/tls" - "crypto/x509" - "hash" - "io" - "net" - "sync" - "sync/atomic" - "unsafe" - _ "unsafe" -) - -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -const ( - recordHeaderLen = 5 // record header length -) - -type alert uint8 - -//go:linkname tlsConfigRand crypto/tls.(*Config).rand -func tlsConfigRand(c *tls.Config) io.Reader - -//go:linkname alertError tls.(tls.alert).Error -func alertError(e alert) string - -func (e alert) Error() string { - return alertError(e) -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher any // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher any // next encryption state - nextMac hash.Hash // next MAC algorithm - - level tls.QUICEncryptionLevel // current QUIC encryption level - trafficSecret []byte // current TLS 1.3 traffic secret -} - -type Conn tls.Conn - -// A _trsconn represents a secured connection. -// It implements the net._trsconn interface. -type _trsconn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - quic unsafe.Pointer // nil for non-QUIC connections - - // isHandshakeComplete is true if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // isHandshakeComplete is true implies handshakeErr == nil. - isHandshakeComplete atomic.Bool - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *tls.Config // configuration passed to constructor - // handshakes counts the number of handshakes performed on the - // connection so far. If renegotiation is disabled then this is either - // zero or one. - handshakes int - extMasterSecret bool - didResume bool // whether this connection was a session resumption - didHRR bool // whether a HelloRetryRequest was sent/received - cipherSuite uint16 - curveID tls.CurveID - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // activeCertHandles contains the cache handles to certificates in - // peerCertificates that are used to track active references. - activeCertHandles []unsafe.Pointer - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // resumptionSecret is the resumption_master_secret for handling - // or sending NewSessionTicket messages. - resumptionSecret []byte - echAccepted bool - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []byte - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - in, out halfConn -} - -//go:linkname outBufPool crypto/tls.outBufPool -var outBufPool sync.Pool - -//go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked -func tlsWriteRecordLocked(c *_trsconn, typ recordType, data []byte) (int, error) - -//go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite -func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int - -func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int { - return maxPayloadSizeForWrite(c, typ) -} - -//go:linkname sliceForAppend crypto/tls.sliceForAppend -func sliceForAppend(in []byte, n int) (head, tail []byte) - -//go:linkname encrypt crypto/tls.(*halfConn).encrypt -func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error) - -func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - return encrypt(hc, record, payload, rand) -} - -//go:linkname rand crypto/tls.(*Config).rand -func rand(c *tls.Config) io.Reader - -//go:linkname write crypto/tls.(*Conn).write -func write(c *_trsconn, data []byte) (int, error) - -func (c *_trsconn) write(data []byte) (int, error) { - return write(c, data) -} - -//go:linkname flush crypto/tls.(*Conn).flush -func flush(c *_trsconn) (int, error) - -func (c *_trsconn) flush() (int, error) { - return flush(c) -} - -//go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec -func changeCipherSpec(hc *halfConn) error - -func (hc *halfConn) changeCipherSpec() error { - return changeCipherSpec(hc) -} - -//go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked -func sendAlertLocked(c *_trsconn, err alert) error - -func (c *_trsconn) sendAlertLocked(err alert) error { - return sendAlertLocked(c, err) -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) { - if c.quic != nil { - return tlsWriteRecordLocked(c, typ, data) - } - - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - isFirstLoop := true - for len(data) > 0 { - m := len(data) - if !isFirstLoop { - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - } else { - m = int(firstFragmentLen) - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. - vers = tls.VersionTLS10 - } else if vers == tls.VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = tls.VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config)) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - if isFirstLoop { - isFirstLoop = false - if _, err := c.flush(); err != nil { - return n, err - } - } - } - - if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(alert( - *(*uintptr)( - unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))), - ), - )) - } - } - - return n, nil -} diff --git a/utils.go b/utils.go deleted file mode 100644 index 1d1442d..0000000 --- a/utils.go +++ /dev/null @@ -1,9 +0,0 @@ -package terasu - -import ( - "reflect" -) - -func isTypeEqual(obj any, name string) bool { - return reflect.ValueOf(obj).Type().String() == name -}