diff --git a/handshake.go b/handshake.go index ce7f9e7..0e008c4 100644 --- a/handshake.go +++ b/handshake.go @@ -54,9 +54,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } //go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello -func makeClientHello(c *trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error) +func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error) -func (c *trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { +func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { return makeClientHello(c) } @@ -126,20 +126,20 @@ type sessionState struct { } //go:linkname loadSession crypto/tls.(*Conn).loadSession -func loadSession(c *trsconn, hello *clientHelloMsg) ( +func loadSession(c *_trsconn, hello *clientHelloMsg) ( session *sessionState, earlySecret, binderKey []byte, err error, ) -func (c *trsconn) loadSession(hello *clientHelloMsg) ( +func (c *_trsconn) loadSession(hello *clientHelloMsg) ( session *sessionState, earlySecret, binderKey []byte, err error, ) { return loadSession(c, hello) } //go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey -func clientSessionCacheKey(c *trsconn) string +func clientSessionCacheKey(c *_trsconn) string -func (c *trsconn) clientSessionCacheKey() string { +func (c *_trsconn) clientSessionCacheKey() string { return clientSessionCacheKey(c) } @@ -177,12 +177,12 @@ func transcriptMsg(msg handshakeMessage, h transcriptHash) error const clientEarlyTrafficLabel = "c e traffic" //go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret -func quicSetWriteSecret(c *trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte) +func quicSetWriteSecret(c *_trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte) //go:linkname readHandshake crypto/tls.(*Conn).readHandshake -func readHandshake(c *trsconn, transcript transcriptHash) (any, error) +func readHandshake(c *_trsconn, transcript transcriptHash) (any, error) -func (c *trsconn) readHandshake(transcript transcriptHash) (any, error) { +func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) { return readHandshake(c, transcript) } @@ -193,9 +193,9 @@ type serverHelloMsg struct { } //go:linkname sendAlert crypto/tls.(*Conn).sendAlert -func sendAlert(c *trsconn, err alert) error +func sendAlert(c *_trsconn, err alert) error -func (c *trsconn) sendAlert(err alert) error { +func (c *_trsconn) sendAlert(err alert) error { return sendAlert(c, err) } @@ -208,9 +208,9 @@ const ( ) //go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion -func pickTLSVersion(c *trsconn, serverHello *serverHelloMsg) error +func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error -func (c *trsconn) pickTLSVersion(serverHello *serverHelloMsg) error { +func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error { return pickTLSVersion(c, serverHello) } @@ -293,7 +293,7 @@ func (hs *clientHandshakeState) handshake() error { // writeHandshakeRecord writes a handshake message to the connection and updates // the record layer state. If transcript is non-nil the marshalled message is // written to it. -func (c *trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { +func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { c.out.Lock() defer c.out.Unlock() @@ -305,10 +305,12 @@ func (c *trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcri transcript.Write(data) } - return writeRecordLocked(c, recordTypeHandshake, data) + return c.writeRecordLocked(recordTypeHandshake, data) } -func (c *trsconn) clientHandshake(ctx context.Context) (err error) { +func (cout *trsconn) clientHandshake(ctx context.Context) (err error) { + c := (*_trsconn)(unsafe.Pointer(cout)) + if c.config == nil { c.config = defaultConfig() } @@ -390,7 +392,7 @@ func (c *trsconn) clientHandshake(ctx context.Context) (err error) { if c.vers == tls.VersionTLS13 { hs := &clientHandshakeStateTLS13{ - c: c, + c: cout, ctx: ctx, serverHello: serverHello, hello: hello, @@ -405,7 +407,7 @@ func (c *trsconn) clientHandshake(ctx context.Context) (err error) { } hs := &clientHandshakeState{ - c: c, + c: cout, ctx: ctx, serverHello: serverHello, hello: hello, diff --git a/terasu.go b/terasu.go index 2b8cbf4..c504411 100644 --- a/terasu.go +++ b/terasu.go @@ -7,7 +7,6 @@ import ( // Use terasu in this TLS conn func Use(conn *tls.Conn) *tls.Conn { - trsConn := (*trsconn)(unsafe.Pointer(conn)) - trsConn.handshakeFn = trsConn.clientHandshake - return (*tls.Conn)(unsafe.Pointer(trsConn)) + (*_trsconn)(unsafe.Pointer(conn)).handshakeFn = (*trsconn)(conn).clientHandshake + return conn } diff --git a/terasu_test.go b/terasu_test.go index dfaea3e..7300091 100644 --- a/terasu_test.go +++ b/terasu_test.go @@ -8,7 +8,7 @@ import ( "testing" ) -func TestHTTPDialTLS(t *testing.T) { +func TestHTTPDialTLS13(t *testing.T) { cli := http.Client{ Transport: &http.Transport{ DialTLS: func(network, addr string) (net.Conn, error) { @@ -38,3 +38,35 @@ func TestHTTPDialTLS(t *testing.T) { } t.Log(string(data)) } + +func TestHTTPDialTLS12(t *testing.T) { + cli := http.Client{ + Transport: &http.Transport{ + DialTLS: func(network, addr string) (net.Conn, error) { + conn, err := net.Dial("tcp", "18.65.159.2:443") + if err != nil { + return nil, err + } + t.Log("net.Dial succeeded") + return Use(tls.Client(conn, &tls.Config{ + ServerName: "huggingface.co", + InsecureSkipVerify: true, + MaxVersion: tls.VersionTLS12, + })), nil + }, + }, + } + resp, err := cli.Get("https://huggingface.co/") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatal("status code:", resp.StatusCode) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(data)) +} diff --git a/tls.go b/tls.go index 306b477..e5092e1 100644 --- a/tls.go +++ b/tls.go @@ -57,9 +57,11 @@ type halfConn struct { trafficSecret []byte // current TLS 1.3 traffic secret } -// A trsconn represents a secured connection. -// It implements the net.trsconn interface. -type trsconn struct { +type trsconn tls.Conn + +// A _trsconn represents a secured connection. +// It implements the net._trsconn interface. +type _trsconn struct { // constant conn net.Conn isClient bool @@ -139,12 +141,12 @@ type trsconn struct { var outBufPool sync.Pool //go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked -func tlsWriteRecordLocked(c *trsconn, typ recordType, data []byte) (int, error) +func tlsWriteRecordLocked(c *_trsconn, typ recordType, data []byte) (int, error) //go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite -func maxPayloadSizeForWrite(c *trsconn, typ recordType) int +func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int -func (c *trsconn) maxPayloadSizeForWrite(typ recordType) int { +func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int { return maxPayloadSizeForWrite(c, typ) } @@ -162,16 +164,16 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err func rand(c *tls.Config) io.Reader //go:linkname write crypto/tls.(*Conn).write -func write(c *trsconn, data []byte) (int, error) +func write(c *_trsconn, data []byte) (int, error) -func (c *trsconn) write(data []byte) (int, error) { +func (c *_trsconn) write(data []byte) (int, error) { return write(c, data) } //go:linkname flush crypto/tls.(*Conn).flush -func flush(c *trsconn) (int, error) +func flush(c *_trsconn) (int, error) -func (c *trsconn) flush() (int, error) { +func (c *_trsconn) flush() (int, error) { return flush(c) } @@ -183,15 +185,15 @@ func (hc *halfConn) changeCipherSpec() error { } //go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked -func sendAlertLocked(c *trsconn, err alert) error +func sendAlertLocked(c *_trsconn, err alert) error -func (c *trsconn) sendAlertLocked(err alert) error { +func (c *_trsconn) sendAlertLocked(err alert) error { return sendAlertLocked(c, err) } // writeRecordLocked writes a TLS record with the given type and payload to the // connection and updates the record layer state. -func writeRecordLocked(c *trsconn, typ recordType, data []byte) (int, error) { +func (c *_trsconn) writeRecordLocked(typ recordType, data []byte) (int, error) { if c.quic != nil { return tlsWriteRecordLocked(c, typ, data) }