diff --git a/dns/cache.go b/dns/cache.go new file mode 100644 index 0000000..2288f7d --- /dev/null +++ b/dns/cache.go @@ -0,0 +1,31 @@ +package dns + +import ( + "context" + "time" + + "github.com/FloatTech/ttl" + "github.com/fumiama/terasu/ip" +) + +var lookupTable = ttl.NewCache[string, []string](time.Hour) + +// LookupHost use default resolver with its fallback +func LookupHost(ctx context.Context, host string) (addrs []string, err error) { + addrs = lookupTable.Get(host) + if len(addrs) == 0 { + addrs, err = DefaultResolver.LookupHost(ctx, host) + if err != nil { + if ip.IsIPv6Available.Get() { + addrs, err = IPv6Servers.lookupHostDoH(ctx, host) + } else { + addrs, err = IPv4Servers.lookupHostDoH(ctx, host) + } + if err != nil { + return nil, err + } + } + lookupTable.Set(host, addrs) + } + return +} diff --git a/dns/dns.go b/dns/dns.go index 26f5e13..ad3b192 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -14,15 +14,16 @@ import ( ) var ( + // ErrNoDNSAvailable is reported when all servers failed to response ErrNoDNSAvailable = errors.New("no dns available") ) -var defaultDialer = net.Dialer{ +var dnsDialer = net.Dialer{ Timeout: time.Second * 4, } func SetTimeout(t time.Duration) { - defaultDialer.Timeout = t + dnsDialer.Timeout = t } type dnsstat struct { @@ -88,7 +89,7 @@ func (ds *DNSList) Add(c *DNSConfig) { } } -func (ds *DNSList) LookupHostFallback(ctx context.Context, host string) ([]string, error) { +func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) ([]string, error) { ds.RLock() defer ds.RUnlock() // try to use DoH first @@ -97,7 +98,7 @@ func (ds *DNSList) LookupHostFallback(ctx context.Context, host string) ([]strin if !addr.e || !strings.HasPrefix(addr.a, "https://") { // disabled or is not DoH continue } - jr, err := lookupdoh(addr.a, host) + jr, err := lookupdoh(ctx, addr.a, host) if err == nil { hosts := jr.hosts() if len(hosts) > 0 { @@ -110,37 +111,34 @@ func (ds *DNSList) LookupHostFallback(ctx context.Context, host string) ([]strin if addrs, ok := ds.b[host]; ok { return addrs, nil } - return net.DefaultResolver.LookupHost(ctx, host) + return nil, ErrNoDNSAvailable } func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFragmentLen uint8) (tlsConn *tls.Conn, err error) { err = ErrNoDNSAvailable if dialer == nil { - dialer = &defaultDialer + dialer = &dnsDialer } ds.RLock() defer ds.RUnlock() - if dialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, dialer.Timeout) - defer cancel() - } - - if !dialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, dialer.Deadline) - defer cancel() - } - var conn net.Conn for host, addrs := range ds.m { for _, addr := range addrs { if !addr.e || strings.HasPrefix(addr.a, "https://") { // disabled or is DoH continue } + if dialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), dialer.Timeout) + defer cancel() + } else if !dialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), dialer.Deadline) + defer cancel() + } conn, err = dialer.DialContext(ctx, "tcp", addr.a) if err != nil { addr.e = false // no need to acquire write lock diff --git a/dns/dns_test.go b/dns/dns_test.go index 186136a..62b8035 100644 --- a/dns/dns_test.go +++ b/dns/dns_test.go @@ -28,7 +28,7 @@ func TestResolverFallback(t *testing.T) { t.Log("IsIPv6Available:", ip.IsIPv6Available.Get()) if ip.IsIPv6Available.Get() { - addrs, err := IPv6Servers.LookupHostFallback(context.TODO(), "huggingface.co") + addrs, err := IPv6Servers.lookupHostDoH(context.TODO(), "huggingface.co") if err != nil { t.Fatal(err) } @@ -37,7 +37,7 @@ func TestResolverFallback(t *testing.T) { t.Fail() } } - addrs, err := IPv4Servers.LookupHostFallback(context.TODO(), "huggingface.co") + addrs, err := IPv4Servers.lookupHostDoH(context.TODO(), "huggingface.co") if err != nil { t.Fatal(err) } diff --git a/dns/doh.go b/dns/doh.go index dd4d4f2..76b118b 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -10,9 +10,7 @@ import ( "net/url" "strconv" "strings" - "time" - "github.com/FloatTech/ttl" "golang.org/x/net/http2" "github.com/fumiama/terasu" @@ -65,23 +63,9 @@ func (jr *dohjsonresponse) hosts() []string { return hosts } -var lookupTable = ttl.NewCache[string, []string](time.Hour) - var trsHTTP2ClientWithSystemDNS = http.Client{ Transport: &http2.Transport{ DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - if defaultDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout) - defer cancel() - } - - if !defaultDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline) - defer cancel() - } - host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -100,7 +84,7 @@ var trsHTTP2ClientWithSystemDNS = http.Client{ var conn net.Conn var tlsConn *tls.Conn for _, a := range addrs { - conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) if err != nil { continue } @@ -111,7 +95,7 @@ var trsHTTP2ClientWithSystemDNS = http.Client{ } _ = tlsConn.Close() tlsConn = nil - conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) if err != nil { continue } @@ -128,18 +112,18 @@ var trsHTTP2ClientWithSystemDNS = http.Client{ }, } -func lookupdoh(server, u string) (jr dohjsonresponse, err error) { - jr, err = lookupdohwithtype(server, u, preferreddohtype()) +func lookupdoh(ctx context.Context, server, u string) (jr dohjsonresponse, err error) { + jr, err = lookupdohwithtype(ctx, server, u, preferreddohtype()) if err == nil { return } if ip.IsIPv6Available.Get() { - jr, err = lookupdohwithtype(server, u, recordTypeA) + jr, err = lookupdohwithtype(ctx, server, u, recordTypeA) } return } -func lookupdohwithtype(server, u string, typ recordType) (jr dohjsonresponse, err error) { +func lookupdohwithtype(ctx context.Context, server, u string, typ recordType) (jr dohjsonresponse, err error) { sb := strings.Builder{} sb.WriteString(server) sb.WriteString("?name=") @@ -148,7 +132,7 @@ func lookupdohwithtype(server, u string, typ recordType) (jr dohjsonresponse, er sb.WriteString("&type=") sb.WriteString(strconv.Itoa(int(typ))) } - req, err := http.NewRequest("GET", sb.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", sb.String(), nil) if err != nil { return } diff --git a/go.mod b/go.mod index 6ca9e72..f0ee650 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/fumiama/terasu go 1.20 require ( - github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 + github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 golang.org/x/net v0.24.0 ) diff --git a/go.sum b/go.sum index 0788821..79dcf50 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 h1:g4pTnDJUW4VbJ9NvoRfUvdjDrHz/6QhfN/LoIIpICbo= -github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= +github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d h1:mUQ/c3wXKsUGa4Sg9DBy01APXKB68PmobhxOyaJI7lY= +github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/http/http.go b/http/http.go index f1f4452..7e55fd0 100644 --- a/http/http.go +++ b/http/http.go @@ -10,11 +10,8 @@ import ( "net/url" "time" - "github.com/FloatTech/ttl" - "github.com/fumiama/terasu" "github.com/fumiama/terasu/dns" - "github.com/fumiama/terasu/ip" ) var ( @@ -23,49 +20,24 @@ var ( ) var defaultDialer = net.Dialer{ - Timeout: time.Minute, + Timeout: 10 * time.Second, } func SetDefaultClientTimeout(t time.Duration) { defaultDialer.Timeout = t } -var lookupTable = ttl.NewCache[string, []string](time.Hour) - var DefaultClient = http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if defaultDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout) - defer cancel() - } - - if !defaultDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline) - defer cancel() - } - host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err } - addrs := lookupTable.Get(host) - if len(addrs) == 0 { - addrs, err = dns.DefaultResolver.LookupHost(ctx, host) - if err != nil { - if ip.IsIPv6Available.Get() { - addrs, err = dns.IPv6Servers.LookupHostFallback(ctx, host) - } else { - addrs, err = dns.IPv4Servers.LookupHostFallback(ctx, host) - } - if err != nil { - return nil, err - } - } - lookupTable.Set(host, addrs) + addrs, err := dns.LookupHost(ctx, host) + if err != nil { + return nil, err } if len(addr) == 0 { return nil, ErrEmptyHostAddress @@ -73,6 +45,15 @@ var DefaultClient = http.Client{ var conn net.Conn var tlsConn *tls.Conn for _, a := range addrs { + if defaultDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) + defer cancel() + } else if !defaultDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) + defer cancel() + } conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) if err != nil { continue diff --git a/http2/http2.go b/http2/http2.go index e5d9621..c6e738b 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -10,12 +10,10 @@ import ( "net/url" "time" - "github.com/FloatTech/ttl" "golang.org/x/net/http2" "github.com/fumiama/terasu" "github.com/fumiama/terasu/dns" - "github.com/fumiama/terasu/ip" ) var ( @@ -23,48 +21,23 @@ var ( ) var defaultDialer = net.Dialer{ - Timeout: time.Minute, + Timeout: 10 * time.Second, } func SetDefaultClientTimeout(t time.Duration) { defaultDialer.Timeout = t } -var lookupTable = ttl.NewCache[string, []string](time.Hour) - var DefaultClient = http.Client{ Transport: &http2.Transport{ DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - if defaultDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout) - defer cancel() - } - - if !defaultDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline) - defer cancel() - } - host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err } - addrs := lookupTable.Get(host) - if len(addrs) == 0 { - addrs, err = dns.DefaultResolver.LookupHost(ctx, host) - if err != nil { - if ip.IsIPv6Available.Get() { - addrs, err = dns.IPv6Servers.LookupHostFallback(ctx, host) - } else { - addrs, err = dns.IPv4Servers.LookupHostFallback(ctx, host) - } - if err != nil { - return nil, err - } - } - lookupTable.Set(host, addrs) + addrs, err := dns.LookupHost(ctx, host) + if err != nil { + return nil, err } if len(addr) == 0 { return nil, ErrEmptyHostAddress @@ -72,6 +45,15 @@ var DefaultClient = http.Client{ var conn net.Conn var tlsConn *tls.Conn for _, a := range addrs { + if defaultDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) + defer cancel() + } else if !defaultDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) + defer cancel() + } conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) if err != nil { continue diff --git a/terasu.go b/terasu.go index 54a03a7..c0d4cb0 100644 --- a/terasu.go +++ b/terasu.go @@ -6,7 +6,7 @@ import ( "unsafe" ) -var DefaultFirstFragmentLen uint8 = 4 +var DefaultFirstFragmentLen uint8 = 3 // Use terasu in this TLS conn func Use(conn *tls.Conn) *Conn {