From 62d3c9f6be8000376bb715b862e962365753c595 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:10:47 +0900 Subject: [PATCH] optimize: api --- cmd/main.go | 13 +++++++++---- handshake.go | 6 +++--- terasu.go | 24 +++++++++++++++++++++--- terasu_test.go | 20 ++++++++++++++++---- tls.go | 2 +- 5 files changed, 50 insertions(+), 15 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index f952848..293f47b 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -29,10 +29,15 @@ func main() { if err != nil { return nil, err } - return terasu.Use(tls.Client(conn, &tls.Config{ - ServerName: host, - InsecureSkipVerify: true, - })), nil + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: host, + }) + err = terasu.Use(tlsConn).Handshake() + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return tlsConn, nil }, }, } diff --git a/handshake.go b/handshake.go index 0e008c4..e9f840f 100644 --- a/handshake.go +++ b/handshake.go @@ -228,7 +228,7 @@ const ( ) type clientHandshakeStateTLS13 struct { - c *trsconn + c *Conn ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg @@ -272,7 +272,7 @@ type finishedHash struct { } type clientHandshakeState struct { - c *trsconn + c *Conn ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg @@ -308,7 +308,7 @@ func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcr return c.writeRecordLocked(recordTypeHandshake, data) } -func (cout *trsconn) clientHandshake(ctx context.Context) (err error) { +func (cout *Conn) clientHandshake(ctx context.Context) (err error) { c := (*_trsconn)(unsafe.Pointer(cout)) if c.config == nil { diff --git a/terasu.go b/terasu.go index c504411..24ceb4a 100644 --- a/terasu.go +++ b/terasu.go @@ -1,12 +1,30 @@ package terasu import ( + "context" "crypto/tls" "unsafe" ) // Use terasu in this TLS conn -func Use(conn *tls.Conn) *tls.Conn { - (*_trsconn)(unsafe.Pointer(conn)).handshakeFn = (*trsconn)(conn).clientHandshake - return conn +func Use(conn *tls.Conn) *Conn { + return (*Conn)(conn) +} + +// Handshake do terasu handshake in this TLS conn +func (conn *Conn) Handshake() error { + expose := (*_trsconn)(unsafe.Pointer(conn)) + fnbak := expose.handshakeFn + expose.handshakeFn = conn.clientHandshake + defer func() { expose.handshakeFn = fnbak }() + return (*tls.Conn)(conn).Handshake() +} + +// Handshake do terasu handshake with ctx in this TLS conn +func (conn *Conn) HandshakeContext(ctx context.Context) error { + expose := (*_trsconn)(unsafe.Pointer(conn)) + fnbak := expose.handshakeFn + expose.handshakeFn = conn.clientHandshake + defer func() { expose.handshakeFn = fnbak }() + return (*tls.Conn)(conn).HandshakeContext(ctx) } diff --git a/terasu_test.go b/terasu_test.go index 7300091..bc24a1b 100644 --- a/terasu_test.go +++ b/terasu_test.go @@ -17,10 +17,16 @@ func TestHTTPDialTLS13(t *testing.T) { return nil, err } t.Log("net.Dial succeeded") - return Use(tls.Client(conn, &tls.Config{ + tlsConn := tls.Client(conn, &tls.Config{ ServerName: "huggingface.co", InsecureSkipVerify: true, - })), nil + }) + err = Use(tlsConn).Handshake() + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return tlsConn, nil }, }, } @@ -48,11 +54,17 @@ func TestHTTPDialTLS12(t *testing.T) { return nil, err } t.Log("net.Dial succeeded") - return Use(tls.Client(conn, &tls.Config{ + tlsConn := tls.Client(conn, &tls.Config{ ServerName: "huggingface.co", InsecureSkipVerify: true, MaxVersion: tls.VersionTLS12, - })), nil + }) + err = Use(tlsConn).Handshake() + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return tlsConn, nil }, }, } diff --git a/tls.go b/tls.go index e5092e1..35ae50d 100644 --- a/tls.go +++ b/tls.go @@ -57,7 +57,7 @@ type halfConn struct { trafficSecret []byte // current TLS 1.3 traffic secret } -type trsconn tls.Conn +type Conn tls.Conn // A _trsconn represents a secured connection. // It implements the net._trsconn interface.