From bda0c8de97e15d258d3a76e9ef5de74b23d8e490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 23 Oct 2025 23:33:06 +0800 Subject: [PATCH] fix: error on different frag lens --- cmd/main.go | 3 ++- conn.go | 31 ++++++++++++++++------ conn_test.go | 44 ++++++++++++++++++++++++++++++ dns/dns.go | 17 +++++++----- dns/doh.go | 39 ++++++++++++++++++++------- http/http.go | 18 ++++++++++--- http2/http2.go | 14 +++++++--- ip/ipv6.go | 1 + relay.go | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 206 insertions(+), 33 deletions(-) create mode 100644 relay.go diff --git a/cmd/main.go b/cmd/main.go index 9b7b0fd..b2bc086 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,4 +1,5 @@ -// Package main ... +// Package main provides the main entry point for terasu. +// It demonstrates basic Go usage of this library. package main import ( diff --git a/conn.go b/conn.go index 8ce9265..a08cfa0 100644 --- a/conn.go +++ b/conn.go @@ -1,8 +1,9 @@ package terasu import ( - "bytes" "encoding/binary" + "encoding/hex" + "fmt" "io" "net" "sync" @@ -14,14 +15,19 @@ var DefaultFirstFragmentLen = 4 // Conn remote: real server; local: relay type Conn struct { - mu sync.Mutex + relay relay + init *sync.Once conn *net.TCPConn isold bool } // NewConn wraps *net.TCPConn (net.Conn must be *net.TCPConn) func NewConn(conn net.Conn) *Conn { - return &Conn{conn: conn.(*net.TCPConn)} + return &Conn{ + relay: newrelay(), + init: &sync.Once{}, + conn: conn.(*net.TCPConn), + } } // Write is send @@ -29,14 +35,21 @@ func (conn *Conn) Write(b []byte) (int, error) { if conn.isold || DefaultFirstFragmentLen == 0 { return conn.conn.Write(b) } - conn.mu.Lock() - defer conn.mu.Unlock() - n, err := conn.ReadFrom(bytes.NewReader(b)) - return int(n), err + go conn.init.Do(func() { + _, err := io.Copy(conn, &conn.relay) + if err != nil { + _ = conn.relay.Close() + } + }) + return conn.relay.Write(b) } // ReadFrom when client want to send to server, detect and split. func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) { + if conn.isold || DefaultFirstFragmentLen == 0 { + return conn.conn.ReadFrom(r) + } + // ContentType [0:1] // Version [1:3] // Length [3:5] @@ -102,6 +115,7 @@ func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) { // split if x <= 4 { // first is in header range + fmt.Println(hex.EncodeToString(header[:])) // first binary.BigEndian.PutUint16(header[3:5], uint16(x)) bd.move(header[:5+x]) @@ -110,7 +124,7 @@ func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) { if err != nil { return } - copy(header[5:5+x], header[9-x:9]) + copy(header[5:9-x], header[5+x:9]) // second binary.BigEndian.PutUint16(header[3:5], plen-uint16(x)) bd.move(header[:9-x]) @@ -138,6 +152,7 @@ PIPE: if err != nil { return } + _ = conn.relay.Close() cnt, err := bd.send(conn.conn, r) n += cnt return diff --git a/conn_test.go b/conn_test.go index 32f7453..3108eed 100644 --- a/conn_test.go +++ b/conn_test.go @@ -9,6 +9,50 @@ import ( "testing" ) +func TestHTTPDialDifferentFragLen(t *testing.T) { + cli := http.Client{ + Transport: &http.Transport{ + DialTLS: func(network, addr string) (net.Conn, error) { + conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( + netip.MustParseAddrPort("52.222.136.117:443"), + )) + if err != nil { + return nil, err + } + t.Log("net.Dial succeeded") + tlsConn := tls.Client(NewConn(conn), &tls.Config{ + ServerName: "huggingface.co", + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + }) + err = tlsConn.Handshake() + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return tlsConn, nil + }, + }, + } + for i := 0; i < 10; i++ { + // will fail when i=0 in CN + DefaultFirstFragmentLen = i + resp, err := cli.Get("https://huggingface.co/") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatal("status code:", resp.StatusCode) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(data)) + } +} + func TestHTTPDialTLS13(t *testing.T) { cli := http.Client{ Transport: &http.Transport{ diff --git a/dns/dns.go b/dns/dns.go index e97cfca..c284f36 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "errors" "net" + "slices" "strings" "sync" "syscall" @@ -27,6 +28,7 @@ var dnsDialer = net.Dialer{ Timeout: time.Second * 4, } +// SetTimeout ... func SetTimeout(t time.Duration) { dnsDialer.Timeout = t } @@ -37,6 +39,7 @@ type dnsstat struct { keep bool } +// String ... func (ds *dnsstat) String() string { sb := strings.Builder{} sb.WriteString("[addr: ") @@ -78,6 +81,7 @@ func (ds *dnsstat) disable(reEnable time.Duration) { }) } +// DNSList is a bundle of DNSs type DNSList struct { sync.RWMutex hostseq []string @@ -85,6 +89,7 @@ type DNSList struct { b map[string][]string } +// DNSConfig is the user config type DNSConfig struct { Servers map[string][]string `yaml:"Servers"` // Servers map[dot.com]ip:ports Fallbacks map[string][]string `yaml:"Fallbacks"` // Fallbacks map[domain]ips @@ -102,14 +107,10 @@ func hasrecord(lst []*dnsstat, a string) bool { // hasrecord no lock, use under lock func hasfallback(lst []string, a string) bool { - for _, addr := range lst { - if addr == a { - return true - } - } - return false + return slices.Contains(lst, a) } +// Add ... func (ds *DNSList) Add(c *DNSConfig) { ds.Lock() defer ds.Unlock() @@ -193,6 +194,7 @@ func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) (hosts []stri return nil, ErrNoDNSAvailable } +// DialContext ... func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *tls.Conn, err error) { err = ErrNoDNSAvailable @@ -267,6 +269,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn return } +// IPv6Servers should only be used when IPv6 is available var IPv6Servers = DNSList{ hostseq: []string{ "dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net", @@ -303,6 +306,7 @@ var IPv6Servers = DNSList{ b: map[string][]string{}, } +// IPv4Servers is the default server set var IPv4Servers = DNSList{ hostseq: []string{ "dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net", @@ -339,6 +343,7 @@ var IPv4Servers = DNSList{ b: map[string][]string{}, } +// DefaultResolver ... var DefaultResolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, nw, _ string) (net.Conn, error) { diff --git a/dns/doh.go b/dns/doh.go index a2f1509..86325f4 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -18,6 +18,7 @@ import ( ) var ( + // ErrEmptyHostAddress ... ErrEmptyHostAddress = errors.New("empty host addr") ) @@ -29,25 +30,43 @@ const ( recordTypeAAAA recordType = 28 ) +// dohjsonresponse represents the JSON response structure for DNS over HTTPS (DoH) queries. +// It contains DNS query results and metadata about the response. type dohjsonresponse struct { - Status uint32 - TC bool - RD bool - RA bool - AD bool - CD bool + // Status indicates the DNS query status code (0 = NOERROR, etc.) + Status uint32 + // TC indicates whether the response was truncated (true if truncated) + TC bool + // RD indicates whether recursion was requested in the query + RD bool + // RA indicates whether the server supports recursion + RA bool + // AD indicates whether the response was authenticated (DNSSEC) + AD bool + // CD indicates whether the client requested that DNSSEC validation be disabled + CD bool + // Question contains the DNS query question section with name and type Question []struct { - Name string `json:"name"` + // Name is the domain name being queried + Name string `json:"name"` + // Type is the DNS record type being requested (A, AAAA, etc.) Type recordType `json:"type"` } + // Answer contains the DNS response answer section with resource records Answer []struct { - Name string `json:"name"` + // Name is the domain name for this resource record + Name string `json:"name"` + // Type is the DNS record type (A, AAAA, etc.) Type recordType `json:"type"` - TTL uint16 + // TTL is the time-to-live value for this resource record in seconds + TTL uint16 + // Data is the textual representation of the resource record data Data string `json:"data"` } + // EdnsClientSubnet is the EDNS client subnet information for geolocation EdnsClientSubnet string `json:"edns_client_subnet"` - Comment string + // Comment is an optional comment field for additional information + Comment string } func (jr *dohjsonresponse) hosts() []string { diff --git a/http/http.go b/http/http.go index b4aa946..433bdba 100644 --- a/http/http.go +++ b/http/http.go @@ -1,4 +1,4 @@ -// Package http is the same as the standard http lib +// Package http is a wrapper around the standard http library with enhanced DNS resolution and TLS handling capabilities. package http import ( @@ -16,18 +16,23 @@ import ( ) var ( - ErrNoTLSConnection = errors.New("no tls connection") + // ErrNoTLSConnection is returned when a TLS connection cannot be established. + ErrNoTLSConnection = errors.New("no tls connection") + // ErrEmptyHostAddress is returned when the host address is empty. ErrEmptyHostAddress = errors.New("empty host addr") ) +// defaultDialer is the default dialer used for connecting to hosts. var defaultDialer = net.Dialer{ Timeout: 10 * time.Second, } +// SetDefaultClientTimeout sets the default timeout for the client's dialer. func SetDefaultClientTimeout(t time.Duration) { defaultDialer.Timeout = t } +// DefaultClient is the default HTTP client with custom transport settings, including DNS resolution and TLS handling. var DefaultClient = http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -40,12 +45,13 @@ var DefaultClient = http.Client{ if err != nil { return nil, err } - if len(addr) == 0 { + if len(addrs) == 0 { return nil, ErrEmptyHostAddress } var conn net.Conn var tlsConn *tls.Conn for _, a := range addrs { + // Apply timeout if set, otherwise use deadline if defaultDialer.Timeout != 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) @@ -63,7 +69,7 @@ var DefaultClient = http.Client{ ServerName: host, MinVersion: tls.VersionTLS12, }) - // re-init ctx due to deadline settings in tcp dial + // Re-initialize context due to potential deadline changes from TCP dial if defaultDialer.Timeout != 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) @@ -104,18 +110,22 @@ var DefaultClient = http.Client{ }, } +// Get performs an HTTP GET request using the default client. func Get(url string) (resp *http.Response, err error) { return DefaultClient.Get(url) } +// Head performs an HTTP HEAD request using the default client. func Head(url string) (resp *http.Response, err error) { return DefaultClient.Head(url) } +// Post performs an HTTP POST request using the default client. func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) { return DefaultClient.Post(url, contentType, body) } +// PostForm performs an HTTP POST request with form data using the default client. func PostForm(url string, data url.Values) (resp *http.Response, err error) { return DefaultClient.PostForm(url, data) } diff --git a/http2/http2.go b/http2/http2.go index f5630c8..5550a31 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -17,18 +17,20 @@ import ( "github.com/fumiama/terasu/dns" ) -var ( - ErrEmptyHostAddress = errors.New("empty host addr") -) +// ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses +var ErrEmptyHostAddress = errors.New("empty host addr") +// defaultDialer is the default dialer used for establishing TCP connections var defaultDialer = net.Dialer{ Timeout: 10 * time.Second, } +// SetDefaultClientTimeout sets the default timeout for all HTTP2 client connections func SetDefaultClientTimeout(t time.Duration) { defaultDialer.Timeout = t } +// DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution var DefaultClient = http.Client{ Transport: &http2.Transport{ DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { @@ -40,7 +42,7 @@ var DefaultClient = http.Client{ if err != nil { return nil, err } - if len(addr) == 0 { + if len(addrs) == 0 { return nil, ErrEmptyHostAddress } var conn net.Conn @@ -93,18 +95,22 @@ var DefaultClient = http.Client{ }, } +// Get sends an HTTP GET request to the specified URL using the default HTTP2 client func Get(url string) (resp *http.Response, err error) { return DefaultClient.Get(url) } +// Head sends an HTTP HEAD request to the specified URL using the default HTTP2 client func Head(url string) (resp *http.Response, err error) { return DefaultClient.Head(url) } +// Post sends an HTTP POST request to the specified URL with the given content type and body using the default HTTP2 client func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) { return DefaultClient.Post(url, contentType, body) } +// PostForm sends an HTTP POST request with form data to the specified URL using the default HTTP2 client func PostForm(url string, data url.Values) (resp *http.Response, err error) { return DefaultClient.PostForm(url, data) } diff --git a/ip/ipv6.go b/ip/ipv6.go index 17c339b..eb7e9c0 100644 --- a/ip/ipv6.go +++ b/ip/ipv6.go @@ -1,4 +1,5 @@ // Package ip contains IP-related configs package ip +// IsIPv6Available ... var IsIPv6Available = false diff --git a/relay.go b/relay.go new file mode 100644 index 0000000..9730d6b --- /dev/null +++ b/relay.go @@ -0,0 +1,72 @@ +package terasu + +import ( + "io" + "sync" +) + +type relay struct { + mu sync.Mutex + buf chan []byte + rem []byte +} + +func newrelay() relay { + return relay{buf: make(chan []byte, 64)} +} + +// Read ... +func (r *relay) Read(p []byte) (n int, err error) { + r.mu.Lock() + defer r.mu.Unlock() + switch { + case len(p) == 0: + return + case len(p) <= len(r.rem): + n = copy(p, r.rem) + r.rem = r.rem[n:] + if len(r.rem) == 0 { + r.rem = nil + } + return + case len(r.rem) > 0: + n = copy(p, r.rem) + r.rem = nil + fallthrough + default: + for n < len(p) { + buf := <-r.buf + if len(buf) == 0 { + err = io.EOF + return + } + switch { + case len(buf) >= len(p)-n: + cnt := copy(p[n:], buf) + n += cnt + r.rem = buf[cnt:] + if len(r.rem) == 0 { + r.rem = nil + } + return + default: + n += copy(p[n:], buf) + } + } + } + panic("unexpected") +} + +// Write ... +func (r *relay) Write(p []byte) (n int, err error) { + buf := make([]byte, len(p)) + n = copy(buf, p) + r.buf <- p + return +} + +// Close ... +func (r *relay) Close() error { + close(r.buf) + return nil +}