+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..1f6876f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,17 @@
+
+
+# TeRaSu (TRS)
+
+よの光遍く空へ照しつつ
+
+土棲むものは孰れか見ゆや
+
+
+
+
+
+## Usage
+
+```go
+tlsConn = terasu.Use(tlsConn)
+```
diff --git a/cmd/main.go b/cmd/main.go
new file mode 100644
index 0000000..f952848
--- /dev/null
+++ b/cmd/main.go
@@ -0,0 +1,55 @@
+package main
+
+import (
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+
+ "github.com/fumiama/terasu"
+)
+
+func main() {
+ u := flag.String("url", "https://huggingface.co/", "the url to get")
+ ipport := flag.String("dest", "18.65.159.2:443", "host:port")
+ flag.Parse()
+ if !strings.HasPrefix(*u, "https://") {
+ fmt.Println("ERROR: invalid url")
+ return
+ }
+ host := (*u)[8:]
+ host, _, _ = strings.Cut(host, "/")
+ cli := http.Client{
+ Transport: &http.Transport{
+ DialTLS: func(network, addr string) (net.Conn, error) {
+ conn, err := net.Dial("tcp", *ipport)
+ if err != nil {
+ return nil, err
+ }
+ return terasu.Use(tls.Client(conn, &tls.Config{
+ ServerName: host,
+ InsecureSkipVerify: true,
+ })), nil
+ },
+ },
+ }
+ resp, err := cli.Get(*u)
+ if err != nil {
+ fmt.Println("ERROR:", err)
+ return
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ fmt.Println("ERROR:", "status code:", resp.StatusCode)
+ return
+ }
+ data, err := io.ReadAll(resp.Body)
+ if err != nil {
+ fmt.Println("ERROR:", err)
+ return
+ }
+ fmt.Print(string(data))
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..02abb19
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module github.com/fumiama/terasu
+
+go 1.22.1
diff --git a/handshake.go b/handshake.go
new file mode 100644
index 0000000..ce7f9e7
--- /dev/null
+++ b/handshake.go
@@ -0,0 +1,420 @@
+package terasu
+
+import (
+ "context"
+ "crypto"
+ "crypto/ecdh"
+ "crypto/tls"
+ "errors"
+ "hash"
+ "unsafe"
+)
+
+//go:linkname defaultConfig crypto/tls.defaultConfig
+func defaultConfig() *tls.Config
+
+type clientHelloMsg struct {
+ raw []byte
+ vers uint16
+ random []byte
+ sessionId []byte
+ cipherSuites []uint16
+ compressionMethods []uint8
+ serverName string
+ ocspStapling bool
+ supportedCurves []tls.CurveID
+ supportedPoints []uint8
+ ticketSupported bool
+ sessionTicket []uint8
+ supportedSignatureAlgorithms []tls.SignatureScheme
+ supportedSignatureAlgorithmsCert []tls.SignatureScheme
+ secureRenegotiationSupported bool
+ secureRenegotiation []byte
+ extendedMasterSecret bool
+ alpnProtocols []string
+ scts bool
+ supportedVersions []uint16
+ cookie []byte
+ keyShares []byte
+ earlyData bool
+}
+
+//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal
+func marshal(m *clientHelloMsg) ([]byte, error)
+
+func (m *clientHelloMsg) marshal() ([]byte, error) {
+ return marshal(m)
+}
+
+//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal
+func unmarshal(m *clientHelloMsg, data []byte) bool
+
+func (m *clientHelloMsg) unmarshal(data []byte) bool {
+ return unmarshal(m, data)
+}
+
+//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello
+func makeClientHello(c *trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error)
+
+func (c *trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
+ return makeClientHello(c)
+}
+
+// A sessionState is a resumable session.
+type sessionState struct {
+ // Encoded as a SessionState (in the language of RFC 8446, Section 3).
+ //
+ // enum { server(1), client(2) } SessionStateType;
+ //
+ // opaque Certificate<1..2^24-1>;
+ //
+ // Certificate CertificateChain<0..2^24-1>;
+ //
+ // opaque Extra<0..2^24-1>;
+ //
+ // struct {
+ // uint16 version;
+ // SessionStateType type;
+ // uint16 cipher_suite;
+ // uint64 created_at;
+ // opaque secret<1..2^8-1>;
+ // Extra extra<0..2^24-1>;
+ // uint8 ext_master_secret = { 0, 1 };
+ // uint8 early_data = { 0, 1 };
+ // CertificateEntry certificate_list<0..2^24-1>;
+ // CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */
+ // select (SessionState.early_data) {
+ // case 0: Empty;
+ // case 1: opaque alpn<1..2^8-1>;
+ // };
+ // select (SessionState.type) {
+ // case server: Empty;
+ // case client: struct {
+ // select (SessionState.version) {
+ // case VersionTLS10..VersionTLS12: Empty;
+ // case VersionTLS13: struct {
+ // uint64 use_by;
+ // uint32 age_add;
+ // };
+ // };
+ // };
+ // };
+ // } SessionState;
+ //
+
+ // Extra is ignored by crypto/tls, but is encoded by [SessionState.Bytes]
+ // and parsed by [ParseSessionState].
+ //
+ // This allows [Config.UnwrapSession]/[Config.WrapSession] and
+ // [ClientSessionCache] implementations to store and retrieve additional
+ // data alongside this session.
+ //
+ // To allow different layers in a protocol stack to share this field,
+ // applications must only append to it, not replace it, and must use entries
+ // that can be recognized even if out of order (for example, by starting
+ // with an id and version prefix).
+ Extra [][]byte
+
+ // EarlyData indicates whether the ticket can be used for 0-RTT in a QUIC
+ // connection. The application may set this to false if it is true to
+ // decline to offer 0-RTT even if supported.
+ EarlyData bool
+
+ version uint16
+ isClient bool
+ cipherSuite uint16
+}
+
+//go:linkname loadSession crypto/tls.(*Conn).loadSession
+func loadSession(c *trsconn, hello *clientHelloMsg) (
+ session *sessionState, earlySecret, binderKey []byte, err error,
+)
+
+func (c *trsconn) loadSession(hello *clientHelloMsg) (
+ session *sessionState, earlySecret, binderKey []byte, err error,
+) {
+ return loadSession(c, hello)
+}
+
+//go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey
+func clientSessionCacheKey(c *trsconn) string
+
+func (c *trsconn) clientSessionCacheKey() string {
+ return clientSessionCacheKey(c)
+}
+
+// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
+// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
+type cipherSuiteTLS13 struct {
+ id uint16
+ keyLen int
+ aead func(key, fixedNonce []byte) any
+ hash crypto.Hash
+}
+
+//go:linkname deriveSecret crypto/tls.(*cipherSuiteTLS13).deriveSecret
+func deriveSecret(c *cipherSuiteTLS13, secret []byte, label string, transcript hash.Hash) []byte
+
+func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
+ return deriveSecret(c, secret, label, transcript)
+}
+
+//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
+func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
+
+type handshakeMessage interface {
+ marshal() ([]byte, error)
+ unmarshal([]byte) bool
+}
+
+type transcriptHash interface {
+ Write([]byte) (int, error)
+}
+
+//go:linkname transcriptMsg crypto/tls.transcriptMsg
+func transcriptMsg(msg handshakeMessage, h transcriptHash) error
+
+const clientEarlyTrafficLabel = "c e traffic"
+
+//go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret
+func quicSetWriteSecret(c *trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte)
+
+//go:linkname readHandshake crypto/tls.(*Conn).readHandshake
+func readHandshake(c *trsconn, transcript transcriptHash) (any, error)
+
+func (c *trsconn) readHandshake(transcript transcriptHash) (any, error) {
+ return readHandshake(c, transcript)
+}
+
+type serverHelloMsg struct {
+ raw []byte
+ vers uint16
+ random []byte
+}
+
+//go:linkname sendAlert crypto/tls.(*Conn).sendAlert
+func sendAlert(c *trsconn, err alert) error
+
+func (c *trsconn) sendAlert(err alert) error {
+ return sendAlert(c, err)
+}
+
+//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError
+func unexpectedMessageError(wanted, got any) error
+
+const (
+ alertUnexpectedMessage alert = 10
+ alertIllegalParameter alert = 47
+)
+
+//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion
+func pickTLSVersion(c *trsconn, serverHello *serverHelloMsg) error
+
+func (c *trsconn) pickTLSVersion(serverHello *serverHelloMsg) error {
+ return pickTLSVersion(c, serverHello)
+}
+
+//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion
+func maxSupportedVersion(c *tls.Config, isClient bool) uint16
+
+const roleClient = true
+
+const (
+ // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server
+ // random as a downgrade protection if the server would be capable of
+ // negotiating a higher version. See RFC 8446, Section 4.1.3.
+ downgradeCanaryTLS12 = "DOWNGRD\x01"
+ downgradeCanaryTLS11 = "DOWNGRD\x00"
+)
+
+type clientHandshakeStateTLS13 struct {
+ c *trsconn
+ ctx context.Context
+ serverHello *serverHelloMsg
+ hello *clientHelloMsg
+ ecdheKey *ecdh.PrivateKey
+
+ session *sessionState
+ earlySecret []byte
+ binderKey []byte
+
+ certReq *uintptr
+ usingPSK bool
+ sentDummyCCS bool
+ suite *cipherSuiteTLS13
+ transcript hash.Hash
+ masterSecret []byte
+ trafficSecret []byte // client_application_traffic_secret_0
+}
+
+//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake
+func handshake13(hs *clientHandshakeStateTLS13) error
+
+func (hs *clientHandshakeStateTLS13) handshake() error {
+ return handshake13(hs)
+}
+
+// A finishedHash calculates the hash of a set of handshake messages suitable
+// for including in a Finished message.
+type finishedHash struct {
+ client hash.Hash
+ server hash.Hash
+
+ // Prior to TLS 1.2, an additional MD5 hash is required.
+ clientMD5 hash.Hash
+ serverMD5 hash.Hash
+
+ // In TLS 1.2, a full buffer is sadly required.
+ buffer []byte
+
+ version uint16
+ prf func(result, secret, label, seed []byte)
+}
+
+type clientHandshakeState struct {
+ c *trsconn
+ ctx context.Context
+ serverHello *serverHelloMsg
+ hello *clientHelloMsg
+ suite *uintptr
+ finishedHash finishedHash
+ masterSecret []byte
+ session *sessionState // the session being resumed
+ ticket []byte // a fresh ticket received during this handshake
+}
+
+//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake
+func handshake(hs *clientHandshakeState) error
+
+func (hs *clientHandshakeState) handshake() error {
+ return handshake(hs)
+}
+
+// writeHandshakeRecord writes a handshake message to the connection and updates
+// the record layer state. If transcript is non-nil the marshalled message is
+// written to it.
+func (c *trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
+ c.out.Lock()
+ defer c.out.Unlock()
+
+ data, err := msg.marshal()
+ if err != nil {
+ return 0, err
+ }
+ if transcript != nil {
+ transcript.Write(data)
+ }
+
+ return writeRecordLocked(c, recordTypeHandshake, data)
+}
+
+func (c *trsconn) clientHandshake(ctx context.Context) (err error) {
+ if c.config == nil {
+ c.config = defaultConfig()
+ }
+
+ // This may be a renegotiation handshake, in which case some fields
+ // need to be reset.
+ c.didResume = false
+
+ hello, ecdheKey, err := c.makeClientHello()
+ if err != nil {
+ return err
+ }
+ c.serverName = hello.serverName
+
+ session, earlySecret, binderKey, err := c.loadSession(hello)
+ if err != nil {
+ return err
+ }
+ if session != nil {
+ defer func() {
+ // If we got a handshake failure when resuming a session, throw away
+ // the session ticket. See RFC 5077, Section 3.2.
+ //
+ // RFC 8446 makes no mention of dropping tickets on failure, but it
+ // does require servers to abort on invalid binders, so we need to
+ // delete tickets to recover from a corrupted PSK.
+ if err != nil {
+ if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
+ c.config.ClientSessionCache.Put(cacheKey, nil)
+ }
+ }
+ }()
+ }
+
+ if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
+ return err
+ }
+
+ if hello.earlyData {
+ suite := cipherSuiteTLS13ByID(session.cipherSuite)
+ transcript := suite.hash.New()
+ if err := transcriptMsg(hello, transcript); err != nil {
+ return err
+ }
+ earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript)
+ quicSetWriteSecret(c, tls.QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret)
+ }
+
+ // serverHelloMsg is not included in the transcript
+ msg, err := c.readHandshake(nil)
+ if err != nil {
+ return err
+ }
+
+ var serverHello *serverHelloMsg
+ if !isTypeEqual(msg, "*tls.serverHelloMsg") {
+ c.sendAlert(alertUnexpectedMessage)
+ return unexpectedMessageError(serverHello, msg)
+ }
+ serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)(
+ unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))),
+ ))
+
+ if err := c.pickTLSVersion(serverHello); err != nil {
+ return err
+ }
+
+ // If we are negotiating a protocol version that's lower than what we
+ // support, check for the server downgrade canaries.
+ // See RFC 8446, Section 4.1.3.
+ maxVers := maxSupportedVersion(c.config, roleClient)
+ tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12
+ tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11
+ if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) ||
+ maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade {
+ c.sendAlert(alertIllegalParameter)
+ return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox")
+ }
+
+ if c.vers == tls.VersionTLS13 {
+ hs := &clientHandshakeStateTLS13{
+ c: c,
+ ctx: ctx,
+ serverHello: serverHello,
+ hello: hello,
+ ecdheKey: ecdheKey,
+ session: session,
+ earlySecret: earlySecret,
+ binderKey: binderKey,
+ }
+
+ // In TLS 1.3, session tickets are delivered after the handshake.
+ return hs.handshake()
+ }
+
+ hs := &clientHandshakeState{
+ c: c,
+ ctx: ctx,
+ serverHello: serverHello,
+ hello: hello,
+ session: session,
+ }
+
+ if err := hs.handshake(); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/terasu.go b/terasu.go
new file mode 100644
index 0000000..fdc3342
--- /dev/null
+++ b/terasu.go
@@ -0,0 +1,13 @@
+package terasu
+
+import (
+ "crypto/tls"
+ "unsafe"
+)
+
+// Use terasu in this TLS conn
+func Use(conn *tls.Conn) *tls.Conn {
+ trsConn := (*trsconn)(unsafe.Pointer(conn))
+ trsConn.handshakeFn = trsConn.clientHandshake
+ return conn
+}
diff --git a/terasu_test.go b/terasu_test.go
new file mode 100644
index 0000000..dfaea3e
--- /dev/null
+++ b/terasu_test.go
@@ -0,0 +1,40 @@
+package terasu
+
+import (
+ "crypto/tls"
+ "io"
+ "net"
+ "net/http"
+ "testing"
+)
+
+func TestHTTPDialTLS(t *testing.T) {
+ cli := http.Client{
+ Transport: &http.Transport{
+ DialTLS: func(network, addr string) (net.Conn, error) {
+ conn, err := net.Dial("tcp", "18.65.159.2:443")
+ if err != nil {
+ return nil, err
+ }
+ t.Log("net.Dial succeeded")
+ return Use(tls.Client(conn, &tls.Config{
+ ServerName: "huggingface.co",
+ InsecureSkipVerify: true,
+ })), nil
+ },
+ },
+ }
+ resp, err := cli.Get("https://huggingface.co/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ t.Fatal("status code:", resp.StatusCode)
+ }
+ data, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(string(data))
+}
diff --git a/tls.go b/tls.go
new file mode 100644
index 0000000..306b477
--- /dev/null
+++ b/tls.go
@@ -0,0 +1,269 @@
+package terasu
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "hash"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+ _ "unsafe"
+)
+
+const firstFragmentLen = 4
+
+type recordType uint8
+
+const (
+ recordTypeChangeCipherSpec recordType = 20
+ recordTypeAlert recordType = 21
+ recordTypeHandshake recordType = 22
+ recordTypeApplicationData recordType = 23
+)
+
+const (
+ recordHeaderLen = 5 // record header length
+)
+
+type alert uint8
+
+//go:linkname alertError tls.(tls.alert).Error
+func alertError(e alert) string
+
+func (e alert) Error() string {
+ return alertError(e)
+}
+
+// A halfConn represents one direction of the record layer
+// connection, either sending or receiving.
+type halfConn struct {
+ sync.Mutex
+
+ err error // first permanent error
+ version uint16 // protocol version
+ cipher any // cipher algorithm
+ mac hash.Hash
+ seq [8]byte // 64-bit sequence number
+
+ scratchBuf [13]byte // to avoid allocs; interface method args escape
+
+ nextCipher any // next encryption state
+ nextMac hash.Hash // next MAC algorithm
+
+ level tls.QUICEncryptionLevel // current QUIC encryption level
+ trafficSecret []byte // current TLS 1.3 traffic secret
+}
+
+// A trsconn represents a secured connection.
+// It implements the net.trsconn interface.
+type trsconn struct {
+ // constant
+ conn net.Conn
+ isClient bool
+ handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
+ quic *uintptr // nil for non-QUIC connections
+
+ // isHandshakeComplete is true if the connection is currently transferring
+ // application data (i.e. is not currently processing a handshake).
+ // isHandshakeComplete is true implies handshakeErr == nil.
+ isHandshakeComplete atomic.Bool
+ // constant after handshake; protected by handshakeMutex
+ handshakeMutex sync.Mutex
+ handshakeErr error // error resulting from handshake
+ vers uint16 // TLS version
+ haveVers bool // version has been negotiated
+ config *tls.Config // configuration passed to constructor
+ // handshakes counts the number of handshakes performed on the
+ // connection so far. If renegotiation is disabled then this is either
+ // zero or one.
+ handshakes int
+ extMasterSecret bool
+ didResume bool // whether this connection was a session resumption
+ cipherSuite uint16
+ ocspResponse []byte // stapled OCSP response
+ scts [][]byte // signed certificate timestamps from server
+ peerCertificates []*x509.Certificate
+ // activeCertHandles contains the cache handles to certificates in
+ // peerCertificates that are used to track active references.
+ activeCertHandles []*uintptr
+ // verifiedChains contains the certificate chains that we built, as
+ // opposed to the ones presented by the server.
+ verifiedChains [][]*x509.Certificate
+ // serverName contains the server name indicated by the client, if any.
+ serverName string
+ // secureRenegotiation is true if the server echoed the secure
+ // renegotiation extension. (This is meaningless as a server because
+ // renegotiation is not supported in that case.)
+ secureRenegotiation bool
+ // ekm is a closure for exporting keying material.
+ ekm func(label string, context []byte, length int) ([]byte, error)
+ // resumptionSecret is the resumption_master_secret for handling
+ // or sending NewSessionTicket messages.
+ resumptionSecret []byte
+
+ // ticketKeys is the set of active session ticket keys for this
+ // connection. The first one is used to encrypt new tickets and
+ // all are tried to decrypt tickets.
+ ticketKeys []byte
+
+ // clientFinishedIsFirst is true if the client sent the first Finished
+ // message during the most recent handshake. This is recorded because
+ // the first transmitted Finished message is the tls-unique
+ // channel-binding value.
+ clientFinishedIsFirst bool
+
+ // closeNotifyErr is any error from sending the alertCloseNotify record.
+ closeNotifyErr error
+ // closeNotifySent is true if the Conn attempted to send an
+ // alertCloseNotify record.
+ closeNotifySent bool
+
+ // clientFinished and serverFinished contain the Finished message sent
+ // by the client or server in the most recent handshake. This is
+ // retained to support the renegotiation extension and tls-unique
+ // channel-binding.
+ clientFinished [12]byte
+ serverFinished [12]byte
+
+ // clientProtocol is the negotiated ALPN protocol.
+ clientProtocol string
+
+ // input/output
+ in, out halfConn
+}
+
+//go:linkname outBufPool crypto/tls.outBufPool
+var outBufPool sync.Pool
+
+//go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked
+func tlsWriteRecordLocked(c *trsconn, typ recordType, data []byte) (int, error)
+
+//go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite
+func maxPayloadSizeForWrite(c *trsconn, typ recordType) int
+
+func (c *trsconn) maxPayloadSizeForWrite(typ recordType) int {
+ return maxPayloadSizeForWrite(c, typ)
+}
+
+//go:linkname sliceForAppend crypto/tls.sliceForAppend
+func sliceForAppend(in []byte, n int) (head, tail []byte)
+
+//go:linkname encrypt crypto/tls.(*halfConn).encrypt
+func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error)
+
+func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
+ return encrypt(hc, record, payload, rand)
+}
+
+//go:linkname rand crypto/tls.(*Config).rand
+func rand(c *tls.Config) io.Reader
+
+//go:linkname write crypto/tls.(*Conn).write
+func write(c *trsconn, data []byte) (int, error)
+
+func (c *trsconn) write(data []byte) (int, error) {
+ return write(c, data)
+}
+
+//go:linkname flush crypto/tls.(*Conn).flush
+func flush(c *trsconn) (int, error)
+
+func (c *trsconn) flush() (int, error) {
+ return flush(c)
+}
+
+//go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec
+func changeCipherSpec(hc *halfConn) error
+
+func (hc *halfConn) changeCipherSpec() error {
+ return changeCipherSpec(hc)
+}
+
+//go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked
+func sendAlertLocked(c *trsconn, err alert) error
+
+func (c *trsconn) sendAlertLocked(err alert) error {
+ return sendAlertLocked(c, err)
+}
+
+// writeRecordLocked writes a TLS record with the given type and payload to the
+// connection and updates the record layer state.
+func writeRecordLocked(c *trsconn, typ recordType, data []byte) (int, error) {
+ if c.quic != nil {
+ return tlsWriteRecordLocked(c, typ, data)
+ }
+
+ outBufPtr := outBufPool.Get().(*[]byte)
+ outBuf := *outBufPtr
+ defer func() {
+ // You might be tempted to simplify this by just passing &outBuf to Put,
+ // but that would make the local copy of the outBuf slice header escape
+ // to the heap, causing an allocation. Instead, we keep around the
+ // pointer to the slice header returned by Get, which is already on the
+ // heap, and overwrite and return that.
+ *outBufPtr = outBuf
+ outBufPool.Put(outBufPtr)
+ }()
+
+ var n int
+ isFirstLoop := true
+ for len(data) > 0 {
+ m := len(data)
+ if !isFirstLoop {
+ if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
+ m = maxPayload
+ }
+ } else {
+ m = firstFragmentLen
+ }
+
+ _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
+ outBuf[0] = byte(typ)
+ vers := c.vers
+ if vers == 0 {
+ // Some TLS servers fail if the record version is
+ // greater than TLS 1.0 for the initial ClientHello.
+ vers = tls.VersionTLS10
+ } else if vers == tls.VersionTLS13 {
+ // TLS 1.3 froze the record layer version to 1.2.
+ // See RFC 8446, Section 5.1.
+ vers = tls.VersionTLS12
+ }
+ outBuf[1] = byte(vers >> 8)
+ outBuf[2] = byte(vers)
+ outBuf[3] = byte(m >> 8)
+ outBuf[4] = byte(m)
+
+ var err error
+ outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config))
+ if err != nil {
+ return n, err
+ }
+ if _, err := c.write(outBuf); err != nil {
+ return n, err
+ }
+ n += m
+ data = data[m:]
+ if isFirstLoop {
+ isFirstLoop = false
+ if _, err := c.flush(); err != nil {
+ return n, err
+ }
+ }
+ }
+
+ if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 {
+ if err := c.out.changeCipherSpec(); err != nil {
+ return n, c.sendAlertLocked(alert(
+ *(*uintptr)(
+ unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))),
+ ),
+ ))
+ }
+ }
+
+ return n, nil
+}
diff --git a/utils.go b/utils.go
new file mode 100644
index 0000000..1d1442d
--- /dev/null
+++ b/utils.go
@@ -0,0 +1,9 @@
+package terasu
+
+import (
+ "reflect"
+)
+
+func isTypeEqual(obj any, name string) bool {
+ return reflect.ValueOf(obj).Type().String() == name
+}