1
0
mirror of https://github.com/fumiama/terasu.git synced 2026-06-05 01:00:23 +08:00

fix: tls conn nil pointer

This commit is contained in:
源文雨
2024-04-16 14:33:54 +09:00
parent e65650a52c
commit dc4fb1ae72
4 changed files with 70 additions and 35 deletions

View File

@@ -54,9 +54,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
} }
//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello //go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello
func makeClientHello(c *trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error) func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error)
func (c *trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
return makeClientHello(c) return makeClientHello(c)
} }
@@ -126,20 +126,20 @@ type sessionState struct {
} }
//go:linkname loadSession crypto/tls.(*Conn).loadSession //go:linkname loadSession crypto/tls.(*Conn).loadSession
func loadSession(c *trsconn, hello *clientHelloMsg) ( func loadSession(c *_trsconn, hello *clientHelloMsg) (
session *sessionState, earlySecret, binderKey []byte, err error, session *sessionState, earlySecret, binderKey []byte, err error,
) )
func (c *trsconn) loadSession(hello *clientHelloMsg) ( func (c *_trsconn) loadSession(hello *clientHelloMsg) (
session *sessionState, earlySecret, binderKey []byte, err error, session *sessionState, earlySecret, binderKey []byte, err error,
) { ) {
return loadSession(c, hello) return loadSession(c, hello)
} }
//go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey //go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey
func clientSessionCacheKey(c *trsconn) string func clientSessionCacheKey(c *_trsconn) string
func (c *trsconn) clientSessionCacheKey() string { func (c *_trsconn) clientSessionCacheKey() string {
return clientSessionCacheKey(c) return clientSessionCacheKey(c)
} }
@@ -177,12 +177,12 @@ func transcriptMsg(msg handshakeMessage, h transcriptHash) error
const clientEarlyTrafficLabel = "c e traffic" const clientEarlyTrafficLabel = "c e traffic"
//go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret //go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret
func quicSetWriteSecret(c *trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte) func quicSetWriteSecret(c *_trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte)
//go:linkname readHandshake crypto/tls.(*Conn).readHandshake //go:linkname readHandshake crypto/tls.(*Conn).readHandshake
func readHandshake(c *trsconn, transcript transcriptHash) (any, error) func readHandshake(c *_trsconn, transcript transcriptHash) (any, error)
func (c *trsconn) readHandshake(transcript transcriptHash) (any, error) { func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) {
return readHandshake(c, transcript) return readHandshake(c, transcript)
} }
@@ -193,9 +193,9 @@ type serverHelloMsg struct {
} }
//go:linkname sendAlert crypto/tls.(*Conn).sendAlert //go:linkname sendAlert crypto/tls.(*Conn).sendAlert
func sendAlert(c *trsconn, err alert) error func sendAlert(c *_trsconn, err alert) error
func (c *trsconn) sendAlert(err alert) error { func (c *_trsconn) sendAlert(err alert) error {
return sendAlert(c, err) return sendAlert(c, err)
} }
@@ -208,9 +208,9 @@ const (
) )
//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion //go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion
func pickTLSVersion(c *trsconn, serverHello *serverHelloMsg) error func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error
func (c *trsconn) pickTLSVersion(serverHello *serverHelloMsg) error { func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error {
return pickTLSVersion(c, serverHello) return pickTLSVersion(c, serverHello)
} }
@@ -293,7 +293,7 @@ func (hs *clientHandshakeState) handshake() error {
// writeHandshakeRecord writes a handshake message to the connection and updates // writeHandshakeRecord writes a handshake message to the connection and updates
// the record layer state. If transcript is non-nil the marshalled message is // the record layer state. If transcript is non-nil the marshalled message is
// written to it. // written to it.
func (c *trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
c.out.Lock() c.out.Lock()
defer c.out.Unlock() defer c.out.Unlock()
@@ -305,10 +305,12 @@ func (c *trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcri
transcript.Write(data) transcript.Write(data)
} }
return writeRecordLocked(c, recordTypeHandshake, data) return c.writeRecordLocked(recordTypeHandshake, data)
} }
func (c *trsconn) clientHandshake(ctx context.Context) (err error) { func (cout *trsconn) clientHandshake(ctx context.Context) (err error) {
c := (*_trsconn)(unsafe.Pointer(cout))
if c.config == nil { if c.config == nil {
c.config = defaultConfig() c.config = defaultConfig()
} }
@@ -390,7 +392,7 @@ func (c *trsconn) clientHandshake(ctx context.Context) (err error) {
if c.vers == tls.VersionTLS13 { if c.vers == tls.VersionTLS13 {
hs := &clientHandshakeStateTLS13{ hs := &clientHandshakeStateTLS13{
c: c, c: cout,
ctx: ctx, ctx: ctx,
serverHello: serverHello, serverHello: serverHello,
hello: hello, hello: hello,
@@ -405,7 +407,7 @@ func (c *trsconn) clientHandshake(ctx context.Context) (err error) {
} }
hs := &clientHandshakeState{ hs := &clientHandshakeState{
c: c, c: cout,
ctx: ctx, ctx: ctx,
serverHello: serverHello, serverHello: serverHello,
hello: hello, hello: hello,

View File

@@ -7,7 +7,6 @@ import (
// Use terasu in this TLS conn // Use terasu in this TLS conn
func Use(conn *tls.Conn) *tls.Conn { func Use(conn *tls.Conn) *tls.Conn {
trsConn := (*trsconn)(unsafe.Pointer(conn)) (*_trsconn)(unsafe.Pointer(conn)).handshakeFn = (*trsconn)(conn).clientHandshake
trsConn.handshakeFn = trsConn.clientHandshake return conn
return (*tls.Conn)(unsafe.Pointer(trsConn))
} }

View File

@@ -8,7 +8,7 @@ import (
"testing" "testing"
) )
func TestHTTPDialTLS(t *testing.T) { func TestHTTPDialTLS13(t *testing.T) {
cli := http.Client{ cli := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialTLS: func(network, addr string) (net.Conn, error) { DialTLS: func(network, addr string) (net.Conn, error) {
@@ -38,3 +38,35 @@ func TestHTTPDialTLS(t *testing.T) {
} }
t.Log(string(data)) t.Log(string(data))
} }
func TestHTTPDialTLS12(t *testing.T) {
cli := http.Client{
Transport: &http.Transport{
DialTLS: func(network, addr string) (net.Conn, error) {
conn, err := net.Dial("tcp", "18.65.159.2:443")
if err != nil {
return nil, err
}
t.Log("net.Dial succeeded")
return Use(tls.Client(conn, &tls.Config{
ServerName: "huggingface.co",
InsecureSkipVerify: true,
MaxVersion: tls.VersionTLS12,
})), nil
},
},
}
resp, err := cli.Get("https://huggingface.co/")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("status code:", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(data))
}

28
tls.go
View File

@@ -57,9 +57,11 @@ type halfConn struct {
trafficSecret []byte // current TLS 1.3 traffic secret trafficSecret []byte // current TLS 1.3 traffic secret
} }
// A trsconn represents a secured connection. type trsconn tls.Conn
// It implements the net.trsconn interface.
type trsconn struct { // A _trsconn represents a secured connection.
// It implements the net._trsconn interface.
type _trsconn struct {
// constant // constant
conn net.Conn conn net.Conn
isClient bool isClient bool
@@ -139,12 +141,12 @@ type trsconn struct {
var outBufPool sync.Pool var outBufPool sync.Pool
//go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked //go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked
func tlsWriteRecordLocked(c *trsconn, typ recordType, data []byte) (int, error) func tlsWriteRecordLocked(c *_trsconn, typ recordType, data []byte) (int, error)
//go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite //go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite
func maxPayloadSizeForWrite(c *trsconn, typ recordType) int func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int
func (c *trsconn) maxPayloadSizeForWrite(typ recordType) int { func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int {
return maxPayloadSizeForWrite(c, typ) return maxPayloadSizeForWrite(c, typ)
} }
@@ -162,16 +164,16 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
func rand(c *tls.Config) io.Reader func rand(c *tls.Config) io.Reader
//go:linkname write crypto/tls.(*Conn).write //go:linkname write crypto/tls.(*Conn).write
func write(c *trsconn, data []byte) (int, error) func write(c *_trsconn, data []byte) (int, error)
func (c *trsconn) write(data []byte) (int, error) { func (c *_trsconn) write(data []byte) (int, error) {
return write(c, data) return write(c, data)
} }
//go:linkname flush crypto/tls.(*Conn).flush //go:linkname flush crypto/tls.(*Conn).flush
func flush(c *trsconn) (int, error) func flush(c *_trsconn) (int, error)
func (c *trsconn) flush() (int, error) { func (c *_trsconn) flush() (int, error) {
return flush(c) return flush(c)
} }
@@ -183,15 +185,15 @@ func (hc *halfConn) changeCipherSpec() error {
} }
//go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked //go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked
func sendAlertLocked(c *trsconn, err alert) error func sendAlertLocked(c *_trsconn, err alert) error
func (c *trsconn) sendAlertLocked(err alert) error { func (c *_trsconn) sendAlertLocked(err alert) error {
return sendAlertLocked(c, err) return sendAlertLocked(c, err)
} }
// writeRecordLocked writes a TLS record with the given type and payload to the // writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state. // connection and updates the record layer state.
func writeRecordLocked(c *trsconn, typ recordType, data []byte) (int, error) { func (c *_trsconn) writeRecordLocked(typ recordType, data []byte) (int, error) {
if c.quic != nil { if c.quic != nil {
return tlsWriteRecordLocked(c, typ, data) return tlsWriteRecordLocked(c, typ, data)
} }