From 719e0c16831b77e35367fd64561aa992b220f167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 19 Apr 2024 00:12:45 +0900 Subject: [PATCH] feat: new implementations --- README.md | 29 +------ cmd/main.go | 38 ++------- dns/dns.go | 170 +++++++++++++++++++++++++++++++++++++++ dns/dns_test.go | 96 ++++++++++++++++++++++ go.mod | 8 ++ go.sum | 16 ++++ handshake.go | 190 ++++++++++++++++++++++---------------------- http/http.go | 97 ++++++++++++++++++++++ http/http_test.go | 28 +++++++ http2/http2.go | 96 ++++++++++++++++++++++ http2/http2_test.go | 28 +++++++ ip/ipv6.go | 28 +++++++ terasu.go | 10 ++- terasu_test.go | 4 +- tls.go | 6 +- 15 files changed, 682 insertions(+), 162 deletions(-) create mode 100644 dns/dns.go create mode 100644 dns/dns_test.go create mode 100644 go.sum create mode 100644 http/http.go create mode 100644 http/http_test.go create mode 100644 http2/http2.go create mode 100644 http2/http2_test.go create mode 100644 ip/ipv6.go diff --git a/README.md b/README.md index 0267e83..4b09921 100644 --- a/README.md +++ b/README.md @@ -13,32 +13,5 @@ ## Usage ```go -cli := http.Client{ - Transport: &http.Transport{ - DialTLS: func(network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - addrs, err := net.DefaultResolver.LookupHost(ctx, host) - if err != nil { - return nil, err - } - conn, err := net.Dial(network, net.JoinHostPort(addrs[0], port)) - if err != nil { - return nil, err - } - tlsConn := tls.Client(conn, &tls.Config{ - ServerName: host, - }) - err = terasu.Use(tlsConn).Handshake() - if err != nil { - _ = tlsConn.Close() - return nil, err - } - return tlsConn, nil - }, - }, -} -resp, err := cli.Get(url) +terasu.Use(tlsConn).Handshake() ``` diff --git a/cmd/main.go b/cmd/main.go index 293f47b..0e9180a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,47 +1,25 @@ package main import ( - "crypto/tls" - "flag" "fmt" "io" - "net" "net/http" + "os" "strings" - "github.com/fumiama/terasu" + "github.com/fumiama/terasu/http2" ) func main() { - u := flag.String("url", "https://huggingface.co/", "the url to get") - ipport := flag.String("dest", "18.65.159.2:443", "host:port") - flag.Parse() - if !strings.HasPrefix(*u, "https://") { + if len(os.Args) != 2 { + fmt.Println("Usage:", os.Args[0], "url") + return + } + if !strings.HasPrefix(os.Args[1], "https://") { fmt.Println("ERROR: invalid url") return } - host := (*u)[8:] - host, _, _ = strings.Cut(host, "/") - cli := http.Client{ - Transport: &http.Transport{ - DialTLS: func(network, addr string) (net.Conn, error) { - conn, err := net.Dial("tcp", *ipport) - if err != nil { - return nil, err - } - tlsConn := tls.Client(conn, &tls.Config{ - ServerName: host, - }) - err = terasu.Use(tlsConn).Handshake() - if err != nil { - _ = tlsConn.Close() - return nil, err - } - return tlsConn, nil - }, - }, - } - resp, err := cli.Get(*u) + resp, err := http2.Get(os.Args[1]) if err != nil { fmt.Println("ERROR:", err) return diff --git a/dns/dns.go b/dns/dns.go new file mode 100644 index 0000000..228b7de --- /dev/null +++ b/dns/dns.go @@ -0,0 +1,170 @@ +package dns + +import ( + "context" + "crypto/tls" + "errors" + "net" + "sync" + "time" + + "github.com/fumiama/terasu" + "github.com/fumiama/terasu/ip" +) + +var ( + ErrNoDNSAvailable = errors.New("no dns available") +) + +var DefaultDialer = net.Dialer{ + Timeout: time.Second * 8, +} + +type dnsstat struct { + A string + E bool +} + +type DNSList struct { + sync.RWMutex + m map[string][]*dnsstat +} + +// hasrecord no lock, use under lock +func hasrecord(lst []*dnsstat, a string) bool { + for _, addr := range lst { + if addr.A == a { + return true + } + } + return false +} + +func (ds *DNSList) Add(m map[string][]string) { + ds.Lock() + defer ds.Unlock() + addList := map[string][]*dnsstat{} + for host, addrs := range m { + for _, addr := range addrs { + if !hasrecord(ds.m[host], addr) && !hasrecord(addList[host], addr) { + addList[host] = append(addList[host], &dnsstat{addr, true}) + } + } + } + for host, addrs := range addList { + ds.m[host] = append(ds.m[host], addrs...) + } +} + +func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFragmentLen uint8) (tlsConn *tls.Conn, err error) { + err = ErrNoDNSAvailable + + if dialer == nil { + dialer = &DefaultDialer + } + + ds.RLock() + defer ds.RUnlock() + + if dialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, dialer.Timeout) + defer cancel() + } + + if !dialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, dialer.Deadline) + defer cancel() + } + + var conn net.Conn + for host, addrs := range ds.m { + for _, addr := range addrs { + if !addr.E { + continue + } + conn, err = dialer.DialContext(ctx, "tcp", addr.A) + if err != nil { + addr.E = false // no need to acquire write lock + continue + } + tlsConn = tls.Client(conn, &tls.Config{ServerName: host}) + err = terasu.Use(tlsConn).HandshakeContext(ctx, firstFragmentLen) + if err == nil { + return + } + _ = tlsConn.Close() + addr.E = false // no need to acquire write lock + } + } + return +} + +var IPv6Servers = DNSList{ + m: map[string][]*dnsstat{ + "dot.sb": { + {"[2a09::]:853", true}, + {"[2a11::]:853", true}, + }, + "dns.google": { + {"[2001:4860:4860::8888]:853", true}, + {"[2001:4860:4860::8844]:853", true}, + }, + "cloudflare-dns.com": { + {"[2606:4700:4700::1111]:853", true}, + {"[2606:4700:4700::1001]:853", true}, + }, + "dns.opendns.com": { + {"[2620:119:35::35]:853", true}, + {"[2620:119:53::53]:853", true}, + }, + "dns10.quad9.net": { + {"[2620:fe::10]:853", true}, + {"[2620:fe::fe:10]:853", true}, + }, + }, +} + +var IPv4Servers = DNSList{ + m: map[string][]*dnsstat{ + "dot.360.cn": { + {"101.198.192.33:853", true}, + {"112.65.69.15:853", true}, + {"101.226.4.6:853", true}, + {"218.30.118.6:853", true}, + {"123.125.81.6:853", true}, + {"140.207.198.6:853", true}, + }, + "dot.sb": { + {"185.222.222.222:853", true}, + {"45.11.45.11:853", true}, + }, + "dns.google": { + {"8.8.8.8:853", true}, + {"8.8.4.4:853", true}, + }, + "cloudflare-dns.com": { + {"1.1.1.1:853", true}, + {"1.0.0.1:853", true}, + }, + "dns.opendns.com": { + {"208.67.222.222:853", true}, + {"208.67.220.220:853", true}, + }, + "dns10.quad9.net": { + {"9.9.9.10:853", true}, + {"149.112.112.10:853", true}, + }, + }, +} + +var DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + if ip.IsIPv6Available.Get() { + return IPv6Servers.DialContext(ctx, nil, terasu.DefaultFirstFragmentLen) + } + return IPv4Servers.DialContext(ctx, nil, terasu.DefaultFirstFragmentLen) + }, +} diff --git a/dns/dns_test.go b/dns/dns_test.go new file mode 100644 index 0000000..8be8a7c --- /dev/null +++ b/dns/dns_test.go @@ -0,0 +1,96 @@ +package dns + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "testing" + "time" + + "github.com/fumiama/terasu" + "github.com/fumiama/terasu/ip" +) + +func TestResolver(t *testing.T) { + t.Log("IsIPv6Available:", ip.IsIPv6Available.Get()) + addrs, err := DefaultResolver.LookupHost(context.TODO(), "huggingface.co") + if err != nil { + t.Fatal(err) + } + t.Log(addrs) + if len(addrs) == 0 { + t.Fail() + } +} + +func TestDNS(t *testing.T) { + if ip.IsIPv6Available.Get() { + IPv6Servers.test() + } + IPv4Servers.test() + for i := 0; i < 100; i++ { + addrs, err := DefaultResolver.LookupHost(context.TODO(), "huggingface.co") + if err != nil { + t.Fatal(err) + } + t.Log(addrs) + if len(addrs) == 0 { + t.Fail() + } + time.Sleep(time.Millisecond * 50) + } +} + +func TestBadDNS(t *testing.T) { + dotv6serversbak := IPv6Servers.m + dotv4serversbak := IPv4Servers.m + defer func() { + IPv6Servers.m = dotv6serversbak + IPv4Servers.m = dotv4serversbak + }() + if ip.IsIPv6Available.Get() { + IPv6Servers = DNSList{ + m: map[string][]*dnsstat{}, + } + IPv6Servers.Add(map[string][]string{"test.bad.host": {"169.254.122.111"}}) + } else { + IPv4Servers = DNSList{ + m: map[string][]*dnsstat{}, + } + IPv4Servers.Add(map[string][]string{"test.bad.host": {"169.254.122.111:853"}}) + } + for i := 0; i < 10; i++ { + addrs, err := DefaultResolver.LookupHost(context.TODO(), "api.mangacopy.com") + t.Log(err) + if err == nil && len(addrs) > 0 { + t.Fatal("unexpected") + } + time.Sleep(time.Millisecond * 50) + } +} + +func (ds *DNSList) test() { + ds.RLock() + defer ds.RUnlock() + for host, addrs := range ds.m { + for _, addr := range addrs { + if !addr.E { + continue + } + fmt.Println("dial:", host, addr.A) + conn, err := net.Dial("tcp", addr.A) + if err != nil { + continue + } + tlsConn := tls.Client(conn, &tls.Config{ServerName: host}) + err = terasu.Use(tlsConn).Handshake(4) + _ = tlsConn.Close() + if err == nil { + fmt.Println("succ:", host, addr.A) + continue + } + fmt.Println("fail:", host, addr.A) + } + } +} diff --git a/go.mod b/go.mod index 02abb19..0fe894f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module github.com/fumiama/terasu go 1.22.1 + +require ( + github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 + github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 + golang.org/x/net v0.24.0 +) + +require golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..1a0c15d --- /dev/null +++ b/go.sum @@ -0,0 +1,16 @@ +github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 h1:g4pTnDJUW4VbJ9NvoRfUvdjDrHz/6QhfN/LoIIpICbo= +github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= +github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU= +github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handshake.go b/handshake.go index e9f840f..f0c3f39 100644 --- a/handshake.go +++ b/handshake.go @@ -293,7 +293,7 @@ func (hs *clientHandshakeState) handshake() error { // writeHandshakeRecord writes a handshake message to the connection and updates // the record layer state. If transcript is non-nil the marshalled message is // written to it. -func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { +func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) { c.out.Lock() defer c.out.Unlock() @@ -305,118 +305,120 @@ func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcr transcript.Write(data) } - return c.writeRecordLocked(recordTypeHandshake, data) + return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data) } -func (cout *Conn) clientHandshake(ctx context.Context) (err error) { - c := (*_trsconn)(unsafe.Pointer(cout)) +func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error { + return func(ctx context.Context) (err error) { + c := (*_trsconn)(unsafe.Pointer(cout)) - if c.config == nil { - c.config = defaultConfig() - } + if c.config == nil { + c.config = defaultConfig() + } - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false - hello, ecdheKey, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - session, earlySecret, binderKey, err := c.loadSession(hello) - if err != nil { - return err - } - if session != nil { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - if cacheKey := c.clientSessionCacheKey(); cacheKey != "" { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - } - }() - } - - if _, err := c.writeHandshakeRecord(hello, nil); err != nil { - return err - } - - if hello.earlyData { - suite := cipherSuiteTLS13ByID(session.cipherSuite) - transcript := suite.hash.New() - if err := transcriptMsg(hello, transcript); err != nil { + hello, ecdheKey, err := c.makeClientHello() + if err != nil { return err } - earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript) - quicSetWriteSecret(c, tls.QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret) - } + c.serverName = hello.serverName - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } + session, earlySecret, binderKey, err := c.loadSession(hello) + if err != nil { + return err + } + if session != nil { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + if cacheKey := c.clientSessionCacheKey(); cacheKey != "" { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + } + }() + } - var serverHello *serverHelloMsg - if !isTypeEqual(msg, "*tls.serverHelloMsg") { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)( - unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))), - )) + if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil { + return err + } - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } + if hello.earlyData { + suite := cipherSuiteTLS13ByID(session.cipherSuite) + transcript := suite.hash.New() + if err := transcriptMsg(hello, transcript); err != nil { + return err + } + earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript) + quicSetWriteSecret(c, tls.QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret) + } - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := maxSupportedVersion(c.config, roleClient) - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) + if err != nil { + return err + } - if c.vers == tls.VersionTLS13 { - hs := &clientHandshakeStateTLS13{ + var serverHello *serverHelloMsg + if !isTypeEqual(msg, "*tls.serverHelloMsg") { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)( + unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))), + )) + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := maxSupportedVersion(c.config, roleClient) + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + + if c.vers == tls.VersionTLS13 { + hs := &clientHandshakeStateTLS13{ + c: cout, + ctx: ctx, + serverHello: serverHello, + hello: hello, + ecdheKey: ecdheKey, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + } + + // In TLS 1.3, session tickets are delivered after the handshake. + return hs.handshake() + } + + hs := &clientHandshakeState{ c: cout, ctx: ctx, serverHello: serverHello, hello: hello, - ecdheKey: ecdheKey, session: session, - earlySecret: earlySecret, - binderKey: binderKey, } - // In TLS 1.3, session tickets are delivered after the handshake. - return hs.handshake() - } + if err := hs.handshake(); err != nil { + return err + } - hs := &clientHandshakeState{ - c: cout, - ctx: ctx, - serverHello: serverHello, - hello: hello, - session: session, + return nil } - - if err := hs.handshake(); err != nil { - return err - } - - return nil } diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..14ddc1c --- /dev/null +++ b/http/http.go @@ -0,0 +1,97 @@ +package http + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net" + "net/http" + "net/url" + "time" + + "github.com/FloatTech/ttl" + + "github.com/fumiama/terasu" + "github.com/fumiama/terasu/dns" +) + +var ( + ErrEmptyHostAddress = errors.New("empty host addr") +) + +var DefaultDialer = net.Dialer{ + Timeout: time.Minute, +} + +var lookupTable = ttl.NewCache[string, []string](time.Hour) + +var DefaultClient = http.Client{ + Transport: &http.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if DefaultDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, DefaultDialer.Timeout) + defer cancel() + } + + if !DefaultDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, DefaultDialer.Deadline) + defer cancel() + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + addrs := lookupTable.Get(host) + if len(addrs) == 0 { + addrs, err = dns.DefaultResolver.LookupHost(ctx, host) + if err != nil { + addrs, err = net.DefaultResolver.LookupHost(ctx, host) + if err != nil { + return nil, err + } + } + lookupTable.Set(host, addrs) + } + if len(addr) == 0 { + return nil, ErrEmptyHostAddress + } + var tlsConn *tls.Conn + for _, a := range addrs { + conn, err := DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(conn, &tls.Config{ + ServerName: host, + }) + err = terasu.Use(tlsConn).HandshakeContext(ctx, terasu.DefaultFirstFragmentLen) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil + } + return tlsConn, err + }, + }, +} + +func Get(url string) (resp *http.Response, err error) { + return DefaultClient.Get(url) +} + +func Head(url string) (resp *http.Response, err error) { + return DefaultClient.Head(url) +} + +func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) { + return DefaultClient.Post(url, contentType, body) +} + +func PostForm(url string, data url.Values) (resp *http.Response, err error) { + return DefaultClient.PostForm(url, data) +} diff --git a/http/http_test.go b/http/http_test.go new file mode 100644 index 0000000..9dc2068 --- /dev/null +++ b/http/http_test.go @@ -0,0 +1,28 @@ +package http + +import ( + "io" + "testing" +) + +func TestClientGet(t *testing.T) { + resp, err := Get("https://huggingface.co/") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + t.Log("[T] response code", resp.StatusCode) + for k, vs := range resp.Header { + for _, v := range vs { + t.Log("[T] response header", k+":", v) + } + } + data, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if len(data) == 0 { + t.Fail() + } + t.Log(string(data)) +} diff --git a/http2/http2.go b/http2/http2.go new file mode 100644 index 0000000..a4122fb --- /dev/null +++ b/http2/http2.go @@ -0,0 +1,96 @@ +package http2 + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net" + "net/http" + "net/url" + "time" + + "github.com/FloatTech/ttl" + "golang.org/x/net/http2" + + "github.com/fumiama/terasu" + "github.com/fumiama/terasu/dns" +) + +var ( + ErrEmptyHostAddress = errors.New("empty host addr") +) + +var DefaultDialer = net.Dialer{ + Timeout: time.Minute, +} + +var lookupTable = ttl.NewCache[string, []string](time.Hour) + +var DefaultClient = http.Client{ + Transport: &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + if DefaultDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, DefaultDialer.Timeout) + defer cancel() + } + + if !DefaultDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, DefaultDialer.Deadline) + defer cancel() + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + addrs := lookupTable.Get(host) + if len(addrs) == 0 { + addrs, err = dns.DefaultResolver.LookupHost(ctx, host) + if err != nil { + addrs, err = net.DefaultResolver.LookupHost(ctx, host) + if err != nil { + return nil, err + } + } + lookupTable.Set(host, addrs) + } + if len(addr) == 0 { + return nil, ErrEmptyHostAddress + } + var tlsConn *tls.Conn + for _, a := range addrs { + conn, err := DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(conn, cfg) + err = terasu.Use(tlsConn).HandshakeContext(ctx, terasu.DefaultFirstFragmentLen) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil + } + return tlsConn, err + }, + }, +} + +func Get(url string) (resp *http.Response, err error) { + return DefaultClient.Get(url) +} + +func Head(url string) (resp *http.Response, err error) { + return DefaultClient.Head(url) +} + +func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) { + return DefaultClient.Post(url, contentType, body) +} + +func PostForm(url string, data url.Values) (resp *http.Response, err error) { + return DefaultClient.PostForm(url, data) +} diff --git a/http2/http2_test.go b/http2/http2_test.go new file mode 100644 index 0000000..cadf0a9 --- /dev/null +++ b/http2/http2_test.go @@ -0,0 +1,28 @@ +package http2 + +import ( + "io" + "testing" +) + +func TestClientGet(t *testing.T) { + resp, err := Get("https://huggingface.co/") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + t.Log("[T] response code", resp.StatusCode) + for k, vs := range resp.Header { + for _, v := range vs { + t.Log("[T] response header", k+":", v) + } + } + data, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if len(data) == 0 { + t.Fail() + } + t.Log(string(data)) +} diff --git a/ip/ipv6.go b/ip/ipv6.go new file mode 100644 index 0000000..880a271 --- /dev/null +++ b/ip/ipv6.go @@ -0,0 +1,28 @@ +package ip + +import ( + "context" + "net/http" + "time" + + "github.com/RomiChan/syncx" +) + +var IsIPv6Available = syncx.Lazy[bool]{Init: func() bool { + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", "http://v6.ipv6-test.com/json/widgetdata.php?callback=?", nil) + if err != nil { + return false + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false + } + _ = resp.Body.Close() + return true +}} + +func init() { + go IsIPv6Available.Get() +} diff --git a/terasu.go b/terasu.go index 24ceb4a..54a03a7 100644 --- a/terasu.go +++ b/terasu.go @@ -6,25 +6,27 @@ import ( "unsafe" ) +var DefaultFirstFragmentLen uint8 = 4 + // Use terasu in this TLS conn func Use(conn *tls.Conn) *Conn { return (*Conn)(conn) } // Handshake do terasu handshake in this TLS conn -func (conn *Conn) Handshake() error { +func (conn *Conn) Handshake(firstFragmentLen uint8) error { expose := (*_trsconn)(unsafe.Pointer(conn)) fnbak := expose.handshakeFn - expose.handshakeFn = conn.clientHandshake + expose.handshakeFn = conn.clientHandshake(firstFragmentLen) defer func() { expose.handshakeFn = fnbak }() return (*tls.Conn)(conn).Handshake() } // Handshake do terasu handshake with ctx in this TLS conn -func (conn *Conn) HandshakeContext(ctx context.Context) error { +func (conn *Conn) HandshakeContext(ctx context.Context, firstFragmentLen uint8) error { expose := (*_trsconn)(unsafe.Pointer(conn)) fnbak := expose.handshakeFn - expose.handshakeFn = conn.clientHandshake + expose.handshakeFn = conn.clientHandshake(firstFragmentLen) defer func() { expose.handshakeFn = fnbak }() return (*tls.Conn)(conn).HandshakeContext(ctx) } diff --git a/terasu_test.go b/terasu_test.go index bc24a1b..b282910 100644 --- a/terasu_test.go +++ b/terasu_test.go @@ -21,7 +21,7 @@ func TestHTTPDialTLS13(t *testing.T) { ServerName: "huggingface.co", InsecureSkipVerify: true, }) - err = Use(tlsConn).Handshake() + err = Use(tlsConn).Handshake(4) if err != nil { _ = tlsConn.Close() return nil, err @@ -59,7 +59,7 @@ func TestHTTPDialTLS12(t *testing.T) { InsecureSkipVerify: true, MaxVersion: tls.VersionTLS12, }) - err = Use(tlsConn).Handshake() + err = Use(tlsConn).Handshake(4) if err != nil { _ = tlsConn.Close() return nil, err diff --git a/tls.go b/tls.go index 35ae50d..15bd9b1 100644 --- a/tls.go +++ b/tls.go @@ -13,8 +13,6 @@ import ( _ "unsafe" ) -const firstFragmentLen = 4 - type recordType uint8 const ( @@ -193,7 +191,7 @@ func (c *_trsconn) sendAlertLocked(err alert) error { // writeRecordLocked writes a TLS record with the given type and payload to the // connection and updates the record layer state. -func (c *_trsconn) writeRecordLocked(typ recordType, data []byte) (int, error) { +func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) { if c.quic != nil { return tlsWriteRecordLocked(c, typ, data) } @@ -219,7 +217,7 @@ func (c *_trsconn) writeRecordLocked(typ recordType, data []byte) (int, error) { m = maxPayload } } else { - m = firstFragmentLen + m = int(firstFragmentLen) } _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)