From 871703d64492c3da39a09ad228fcc3347d8132b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 2 May 2024 00:36:08 +0900 Subject: [PATCH] feat: adapt to go1.20 --- handshake.go | 2 + handshake_1.20.go | 328 ++++++++++++++++++++++++++++++++++++++++++++++ tls.go | 2 + tls_1.20.go | 261 ++++++++++++++++++++++++++++++++++++ 4 files changed, 593 insertions(+) create mode 100644 handshake_1.20.go create mode 100644 tls_1.20.go diff --git a/handshake.go b/handshake.go index f0c3f39..6f6b5e7 100644 --- a/handshake.go +++ b/handshake.go @@ -1,3 +1,5 @@ +//go:build go1.21 + package terasu import ( diff --git a/handshake_1.20.go b/handshake_1.20.go new file mode 100644 index 0000000..ec0e61d --- /dev/null +++ b/handshake_1.20.go @@ -0,0 +1,328 @@ +//go:build !go1.21 + +package terasu + +import ( + "context" + "crypto/ecdh" + "crypto/tls" + "crypto/x509" + "errors" + "hash" + "time" + "unsafe" +) + +//go:linkname defaultConfig crypto/tls.defaultConfig +func defaultConfig() *tls.Config + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + serverName string +} + +//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal +func marshal(m *clientHelloMsg) ([]byte, error) + +func (m *clientHelloMsg) marshal() ([]byte, error) { + return marshal(m) +} + +//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal +func unmarshal(m *clientHelloMsg, data []byte) bool + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + return unmarshal(m, data) +} + +//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello +func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error) + +func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { + return makeClientHello(c) +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type sessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + receivedAt time.Time // When the session ticket was received from the server + ocspResponse []byte // Stapled OCSP response presented by the server + scts [][]byte // SCTs presented by the server + + // TLS 1.3 fields. + nonce []byte // Ticket nonce sent by the server, to derive PSK + useBy time.Time // Expiration of the ticket lifetime as set by the server + ageAdd uint32 // Random obfuscation factor for sending the ticket age +} + +//go:linkname loadSession crypto/tls.(*Conn).loadSession +func loadSession(c *_trsconn, hello *clientHelloMsg) (cacheKey string, + session *sessionState, earlySecret, binderKey []byte, err error, +) + +func (c *_trsconn) loadSession(hello *clientHelloMsg) (cacheKey string, + session *sessionState, earlySecret, binderKey []byte, err error, +) { + return loadSession(c, hello) +} + +type handshakeMessage interface { + marshal() ([]byte, error) + unmarshal([]byte) bool +} + +type transcriptHash interface { + Write([]byte) (int, error) +} + +//go:linkname transcriptMsg crypto/tls.transcriptMsg +func transcriptMsg(msg handshakeMessage, h transcriptHash) error + +//go:linkname readHandshake crypto/tls.(*Conn).readHandshake +func readHandshake(c *_trsconn, transcript transcriptHash) (any, error) + +func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) { + return readHandshake(c, transcript) +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte +} + +//go:linkname sendAlert crypto/tls.(*Conn).sendAlert +func sendAlert(c *_trsconn, err alert) error + +func (c *_trsconn) sendAlert(err alert) error { + return sendAlert(c, err) +} + +//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError +func unexpectedMessageError(wanted, got any) error + +const ( + alertUnexpectedMessage alert = 10 + alertIllegalParameter alert = 47 +) + +//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion +func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error + +func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error { + return pickTLSVersion(c, serverHello) +} + +//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion +func maxSupportedVersion(c *tls.Config, isClient bool) uint16 + +const roleClient = true + +const ( + // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server + // random as a downgrade protection if the server would be capable of + // negotiating a higher version. See RFC 8446, Section 4.1.3. + downgradeCanaryTLS12 = "DOWNGRD\x01" + downgradeCanaryTLS11 = "DOWNGRD\x00" +) + +type clientHandshakeStateTLS13 struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheKey *ecdh.PrivateKey + + session *sessionState + earlySecret []byte + binderKey []byte + + certReq *uintptr + usingPSK bool + sentDummyCCS bool + suite *uintptr + transcript hash.Hash + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 +} + +//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake +func handshake13(hs *clientHandshakeStateTLS13) error + +func (hs *clientHandshakeStateTLS13) handshake() error { + return handshake13(hs) +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +type clientHandshakeState struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *uintptr + finishedHash finishedHash + masterSecret []byte + session *sessionState // the session being resumed + ticket []byte // a fresh ticket received during this handshake +} + +//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake +func handshake(hs *clientHandshakeState) error + +func (hs *clientHandshakeState) handshake() error { + return handshake(hs) +} + +// writeHandshakeRecord writes a handshake message to the connection and updates +// the record layer state. If transcript is non-nil the marshalled message is +// written to it. +func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) { + c.out.Lock() + defer c.out.Unlock() + + data, err := msg.marshal() + if err != nil { + return 0, err + } + if transcript != nil { + transcript.Write(data) + } + + return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data) +} + +func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error { + return func(ctx context.Context) (err error) { + c := (*_trsconn)(unsafe.Pointer(cout)) + + if c.config == nil { + c.config = defaultConfig() + } + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, ecdheKey, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName + + cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) + if err != nil { + return err + } + if cacheKey != "" && session != nil { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + }() + } + + if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil { + return err + } + + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) + if err != nil { + return err + } + + var serverHello *serverHelloMsg + if !isTypeEqual(msg, "*tls.serverHelloMsg") { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)( + unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))), + )) + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := maxSupportedVersion(c.config, roleClient) + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + + if c.vers == tls.VersionTLS13 { + hs := &clientHandshakeStateTLS13{ + c: cout, + ctx: ctx, + serverHello: serverHello, + hello: hello, + ecdheKey: ecdheKey, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + } + + // In TLS 1.3, session tickets are delivered after the handshake. + return hs.handshake() + } + + hs := &clientHandshakeState{ + c: cout, + ctx: ctx, + serverHello: serverHello, + hello: hello, + session: session, + } + + if err := hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one. + if cacheKey != "" && hs.session != nil && session != hs.session { + c.config.ClientSessionCache.Put(cacheKey, (*tls.ClientSessionState)(unsafe.Pointer(hs.session))) + } + + return nil + } +} diff --git a/tls.go b/tls.go index 15bd9b1..eebce06 100644 --- a/tls.go +++ b/tls.go @@ -1,3 +1,5 @@ +//go:build go1.21 + package terasu import ( diff --git a/tls_1.20.go b/tls_1.20.go new file mode 100644 index 0000000..3f92d8b --- /dev/null +++ b/tls_1.20.go @@ -0,0 +1,261 @@ +//go:build !go1.21 + +package terasu + +import ( + "context" + "crypto/tls" + "crypto/x509" + "hash" + "io" + "net" + "sync" + "sync/atomic" + "unsafe" + _ "unsafe" +) + +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +const ( + recordHeaderLen = 5 // record header length +) + +type alert uint8 + +//go:linkname alertError tls.(tls.alert).Error +func alertError(e alert) string + +func (e alert) Error() string { + return alertError(e) +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher any // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape + + nextCipher any // next encryption state + nextMac hash.Hash // next MAC algorithm + + trafficSecret []byte // current TLS 1.3 traffic secret +} + +type Conn tls.Conn + +// A _trsconn represents a secured connection. +// It implements the net._trsconn interface. +type _trsconn struct { + // constant + conn net.Conn + isClient bool + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake + + // isHandshakeComplete is true if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // isHandshakeComplete is true implies handshakeErr == nil. + isHandshakeComplete atomic.Bool + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *tls.Config // configuration passed to constructor + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate + // activeCertHandles contains the cache handles to certificates in + // peerCertificates that are used to track active references. + activeCertHandles []*uintptr + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + // ekm is a closure for exporting keying material. + ekm func(label string, context []byte, length int) ([]byte, error) + // resumptionSecret is the resumption_master_secret for handling + // or sending NewSessionTicket messages. + resumptionSecret []byte + + // ticketKeys is the set of active session ticket keys for this + // connection. The first one is used to encrypt new tickets and + // all are tried to decrypt tickets. + ticketKeys []byte + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string + + // input/output + in, out halfConn +} + +//go:linkname outBufPool crypto/tls.outBufPool +var outBufPool sync.Pool + +//go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite +func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int + +func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int { + return maxPayloadSizeForWrite(c, typ) +} + +//go:linkname sliceForAppend crypto/tls.sliceForAppend +func sliceForAppend(in []byte, n int) (head, tail []byte) + +//go:linkname encrypt crypto/tls.(*halfConn).encrypt +func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error) + +func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { + return encrypt(hc, record, payload, rand) +} + +//go:linkname rand crypto/tls.(*Config).rand +func rand(c *tls.Config) io.Reader + +//go:linkname write crypto/tls.(*Conn).write +func write(c *_trsconn, data []byte) (int, error) + +func (c *_trsconn) write(data []byte) (int, error) { + return write(c, data) +} + +//go:linkname flush crypto/tls.(*Conn).flush +func flush(c *_trsconn) (int, error) + +func (c *_trsconn) flush() (int, error) { + return flush(c) +} + +//go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec +func changeCipherSpec(hc *halfConn) error + +func (hc *halfConn) changeCipherSpec() error { + return changeCipherSpec(hc) +} + +//go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked +func sendAlertLocked(c *_trsconn, err alert) error + +func (c *_trsconn) sendAlertLocked(err alert) error { + return sendAlertLocked(c, err) +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + + var n int + isFirstLoop := true + for len(data) > 0 { + m := len(data) + if !isFirstLoop { + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + } else { + m = int(firstFragmentLen) + } + + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = tls.VersionTLS10 + } else if vers == tls.VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = tls.VersionTLS12 + } + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) + + var err error + outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config)) + if err != nil { + return n, err + } + if _, err := c.write(outBuf); err != nil { + return n, err + } + n += m + data = data[m:] + if isFirstLoop { + isFirstLoop = false + if _, err := c.flush(); err != nil { + return n, err + } + } + } + + if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 { + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(alert( + *(*uintptr)( + unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))), + ), + )) + } + } + + return n, nil +}