1
0
mirror of https://github.com/fumiama/terasu.git synced 2026-06-05 01:00:23 +08:00

optimize: api

This commit is contained in:
源文雨
2024-04-16 15:10:47 +09:00
parent dc4fb1ae72
commit 62d3c9f6be
5 changed files with 50 additions and 15 deletions

View File

@@ -29,10 +29,15 @@ func main() {
if err != nil {
return nil, err
}
return terasu.Use(tls.Client(conn, &tls.Config{
ServerName: host,
InsecureSkipVerify: true,
})), nil
tlsConn := tls.Client(conn, &tls.Config{
ServerName: host,
})
err = terasu.Use(tlsConn).Handshake()
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return tlsConn, nil
},
},
}

View File

@@ -228,7 +228,7 @@ const (
)
type clientHandshakeStateTLS13 struct {
c *trsconn
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
@@ -272,7 +272,7 @@ type finishedHash struct {
}
type clientHandshakeState struct {
c *trsconn
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
@@ -308,7 +308,7 @@ func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcr
return c.writeRecordLocked(recordTypeHandshake, data)
}
func (cout *trsconn) clientHandshake(ctx context.Context) (err error) {
func (cout *Conn) clientHandshake(ctx context.Context) (err error) {
c := (*_trsconn)(unsafe.Pointer(cout))
if c.config == nil {

View File

@@ -1,12 +1,30 @@
package terasu
import (
"context"
"crypto/tls"
"unsafe"
)
// Use terasu in this TLS conn
func Use(conn *tls.Conn) *tls.Conn {
(*_trsconn)(unsafe.Pointer(conn)).handshakeFn = (*trsconn)(conn).clientHandshake
return conn
func Use(conn *tls.Conn) *Conn {
return (*Conn)(conn)
}
// Handshake do terasu handshake in this TLS conn
func (conn *Conn) Handshake() error {
expose := (*_trsconn)(unsafe.Pointer(conn))
fnbak := expose.handshakeFn
expose.handshakeFn = conn.clientHandshake
defer func() { expose.handshakeFn = fnbak }()
return (*tls.Conn)(conn).Handshake()
}
// Handshake do terasu handshake with ctx in this TLS conn
func (conn *Conn) HandshakeContext(ctx context.Context) error {
expose := (*_trsconn)(unsafe.Pointer(conn))
fnbak := expose.handshakeFn
expose.handshakeFn = conn.clientHandshake
defer func() { expose.handshakeFn = fnbak }()
return (*tls.Conn)(conn).HandshakeContext(ctx)
}

View File

@@ -17,10 +17,16 @@ func TestHTTPDialTLS13(t *testing.T) {
return nil, err
}
t.Log("net.Dial succeeded")
return Use(tls.Client(conn, &tls.Config{
tlsConn := tls.Client(conn, &tls.Config{
ServerName: "huggingface.co",
InsecureSkipVerify: true,
})), nil
})
err = Use(tlsConn).Handshake()
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return tlsConn, nil
},
},
}
@@ -48,11 +54,17 @@ func TestHTTPDialTLS12(t *testing.T) {
return nil, err
}
t.Log("net.Dial succeeded")
return Use(tls.Client(conn, &tls.Config{
tlsConn := tls.Client(conn, &tls.Config{
ServerName: "huggingface.co",
InsecureSkipVerify: true,
MaxVersion: tls.VersionTLS12,
})), nil
})
err = Use(tlsConn).Handshake()
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return tlsConn, nil
},
},
}

2
tls.go
View File

@@ -57,7 +57,7 @@ type halfConn struct {
trafficSecret []byte // current TLS 1.3 traffic secret
}
type trsconn tls.Conn
type Conn tls.Conn
// A _trsconn represents a secured connection.
// It implements the net._trsconn interface.