1
0
mirror of https://github.com/fumiama/terasu.git synced 2026-06-27 16:20:25 +08:00

fix(dns): fast-failed on RST

This commit is contained in:
源文雨
2025-10-03 14:47:40 +08:00
parent 1d9b679c36
commit 9974bdca12
9 changed files with 86 additions and 110 deletions

31
dns/cache.go Normal file
View File

@@ -0,0 +1,31 @@
package dns
import (
"context"
"time"
"github.com/FloatTech/ttl"
"github.com/fumiama/terasu/ip"
)
var lookupTable = ttl.NewCache[string, []string](time.Hour)
// LookupHost use default resolver with its fallback
func LookupHost(ctx context.Context, host string) (addrs []string, err error) {
addrs = lookupTable.Get(host)
if len(addrs) == 0 {
addrs, err = DefaultResolver.LookupHost(ctx, host)
if err != nil {
if ip.IsIPv6Available.Get() {
addrs, err = IPv6Servers.lookupHostDoH(ctx, host)
} else {
addrs, err = IPv4Servers.lookupHostDoH(ctx, host)
}
if err != nil {
return nil, err
}
}
lookupTable.Set(host, addrs)
}
return
}

View File

@@ -14,15 +14,16 @@ import (
) )
var ( var (
// ErrNoDNSAvailable is reported when all servers failed to response
ErrNoDNSAvailable = errors.New("no dns available") ErrNoDNSAvailable = errors.New("no dns available")
) )
var defaultDialer = net.Dialer{ var dnsDialer = net.Dialer{
Timeout: time.Second * 4, Timeout: time.Second * 4,
} }
func SetTimeout(t time.Duration) { func SetTimeout(t time.Duration) {
defaultDialer.Timeout = t dnsDialer.Timeout = t
} }
type dnsstat struct { type dnsstat struct {
@@ -88,7 +89,7 @@ func (ds *DNSList) Add(c *DNSConfig) {
} }
} }
func (ds *DNSList) LookupHostFallback(ctx context.Context, host string) ([]string, error) { func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) ([]string, error) {
ds.RLock() ds.RLock()
defer ds.RUnlock() defer ds.RUnlock()
// try to use DoH first // try to use DoH first
@@ -97,7 +98,7 @@ func (ds *DNSList) LookupHostFallback(ctx context.Context, host string) ([]strin
if !addr.e || !strings.HasPrefix(addr.a, "https://") { // disabled or is not DoH if !addr.e || !strings.HasPrefix(addr.a, "https://") { // disabled or is not DoH
continue continue
} }
jr, err := lookupdoh(addr.a, host) jr, err := lookupdoh(ctx, addr.a, host)
if err == nil { if err == nil {
hosts := jr.hosts() hosts := jr.hosts()
if len(hosts) > 0 { if len(hosts) > 0 {
@@ -110,37 +111,34 @@ func (ds *DNSList) LookupHostFallback(ctx context.Context, host string) ([]strin
if addrs, ok := ds.b[host]; ok { if addrs, ok := ds.b[host]; ok {
return addrs, nil return addrs, nil
} }
return net.DefaultResolver.LookupHost(ctx, host) return nil, ErrNoDNSAvailable
} }
func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFragmentLen uint8) (tlsConn *tls.Conn, err error) { func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFragmentLen uint8) (tlsConn *tls.Conn, err error) {
err = ErrNoDNSAvailable err = ErrNoDNSAvailable
if dialer == nil { if dialer == nil {
dialer = &defaultDialer dialer = &dnsDialer
} }
ds.RLock() ds.RLock()
defer ds.RUnlock() defer ds.RUnlock()
if dialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, dialer.Timeout)
defer cancel()
}
if !dialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, dialer.Deadline)
defer cancel()
}
var conn net.Conn var conn net.Conn
for host, addrs := range ds.m { for host, addrs := range ds.m {
for _, addr := range addrs { for _, addr := range addrs {
if !addr.e || strings.HasPrefix(addr.a, "https://") { // disabled or is DoH if !addr.e || strings.HasPrefix(addr.a, "https://") { // disabled or is DoH
continue continue
} }
if dialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), dialer.Timeout)
defer cancel()
} else if !dialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), dialer.Deadline)
defer cancel()
}
conn, err = dialer.DialContext(ctx, "tcp", addr.a) conn, err = dialer.DialContext(ctx, "tcp", addr.a)
if err != nil { if err != nil {
addr.e = false // no need to acquire write lock addr.e = false // no need to acquire write lock

View File

@@ -28,7 +28,7 @@ func TestResolverFallback(t *testing.T) {
t.Log("IsIPv6Available:", ip.IsIPv6Available.Get()) t.Log("IsIPv6Available:", ip.IsIPv6Available.Get())
if ip.IsIPv6Available.Get() { if ip.IsIPv6Available.Get() {
addrs, err := IPv6Servers.LookupHostFallback(context.TODO(), "huggingface.co") addrs, err := IPv6Servers.lookupHostDoH(context.TODO(), "huggingface.co")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -37,7 +37,7 @@ func TestResolverFallback(t *testing.T) {
t.Fail() t.Fail()
} }
} }
addrs, err := IPv4Servers.LookupHostFallback(context.TODO(), "huggingface.co") addrs, err := IPv4Servers.lookupHostDoH(context.TODO(), "huggingface.co")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -10,9 +10,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/FloatTech/ttl"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"github.com/fumiama/terasu" "github.com/fumiama/terasu"
@@ -65,23 +63,9 @@ func (jr *dohjsonresponse) hosts() []string {
return hosts return hosts
} }
var lookupTable = ttl.NewCache[string, []string](time.Hour)
var trsHTTP2ClientWithSystemDNS = http.Client{ var trsHTTP2ClientWithSystemDNS = http.Client{
Transport: &http2.Transport{ Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout)
defer cancel()
}
if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline)
defer cancel()
}
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -100,7 +84,7 @@ var trsHTTP2ClientWithSystemDNS = http.Client{
var conn net.Conn var conn net.Conn
var tlsConn *tls.Conn var tlsConn *tls.Conn
for _, a := range addrs { for _, a := range addrs {
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil { if err != nil {
continue continue
} }
@@ -111,7 +95,7 @@ var trsHTTP2ClientWithSystemDNS = http.Client{
} }
_ = tlsConn.Close() _ = tlsConn.Close()
tlsConn = nil tlsConn = nil
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil { if err != nil {
continue continue
} }
@@ -128,18 +112,18 @@ var trsHTTP2ClientWithSystemDNS = http.Client{
}, },
} }
func lookupdoh(server, u string) (jr dohjsonresponse, err error) { func lookupdoh(ctx context.Context, server, u string) (jr dohjsonresponse, err error) {
jr, err = lookupdohwithtype(server, u, preferreddohtype()) jr, err = lookupdohwithtype(ctx, server, u, preferreddohtype())
if err == nil { if err == nil {
return return
} }
if ip.IsIPv6Available.Get() { if ip.IsIPv6Available.Get() {
jr, err = lookupdohwithtype(server, u, recordTypeA) jr, err = lookupdohwithtype(ctx, server, u, recordTypeA)
} }
return return
} }
func lookupdohwithtype(server, u string, typ recordType) (jr dohjsonresponse, err error) { func lookupdohwithtype(ctx context.Context, server, u string, typ recordType) (jr dohjsonresponse, err error) {
sb := strings.Builder{} sb := strings.Builder{}
sb.WriteString(server) sb.WriteString(server)
sb.WriteString("?name=") sb.WriteString("?name=")
@@ -148,7 +132,7 @@ func lookupdohwithtype(server, u string, typ recordType) (jr dohjsonresponse, er
sb.WriteString("&type=") sb.WriteString("&type=")
sb.WriteString(strconv.Itoa(int(typ))) sb.WriteString(strconv.Itoa(int(typ)))
} }
req, err := http.NewRequest("GET", sb.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", sb.String(), nil)
if err != nil { if err != nil {
return return
} }

2
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/fumiama/terasu
go 1.20 go 1.20
require ( require (
github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7
golang.org/x/net v0.24.0 golang.org/x/net v0.24.0
) )

4
go.sum
View File

@@ -1,5 +1,5 @@
github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 h1:g4pTnDJUW4VbJ9NvoRfUvdjDrHz/6QhfN/LoIIpICbo= github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d h1:mUQ/c3wXKsUGa4Sg9DBy01APXKB68PmobhxOyaJI7lY=
github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs=
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU=
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

View File

@@ -10,11 +10,8 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/FloatTech/ttl"
"github.com/fumiama/terasu" "github.com/fumiama/terasu"
"github.com/fumiama/terasu/dns" "github.com/fumiama/terasu/dns"
"github.com/fumiama/terasu/ip"
) )
var ( var (
@@ -23,49 +20,24 @@ var (
) )
var defaultDialer = net.Dialer{ var defaultDialer = net.Dialer{
Timeout: time.Minute, Timeout: 10 * time.Second,
} }
func SetDefaultClientTimeout(t time.Duration) { func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t defaultDialer.Timeout = t
} }
var lookupTable = ttl.NewCache[string, []string](time.Hour)
var DefaultClient = http.Client{ var DefaultClient = http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout)
defer cancel()
}
if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline)
defer cancel()
}
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
addrs := lookupTable.Get(host) addrs, err := dns.LookupHost(ctx, host)
if len(addrs) == 0 { if err != nil {
addrs, err = dns.DefaultResolver.LookupHost(ctx, host) return nil, err
if err != nil {
if ip.IsIPv6Available.Get() {
addrs, err = dns.IPv6Servers.LookupHostFallback(ctx, host)
} else {
addrs, err = dns.IPv4Servers.LookupHostFallback(ctx, host)
}
if err != nil {
return nil, err
}
}
lookupTable.Set(host, addrs)
} }
if len(addr) == 0 { if len(addr) == 0 {
return nil, ErrEmptyHostAddress return nil, ErrEmptyHostAddress
@@ -73,6 +45,15 @@ var DefaultClient = http.Client{
var conn net.Conn var conn net.Conn
var tlsConn *tls.Conn var tlsConn *tls.Conn
for _, a := range addrs { for _, a := range addrs {
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
defer cancel()
} else if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline)
defer cancel()
}
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil { if err != nil {
continue continue

View File

@@ -10,12 +10,10 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/FloatTech/ttl"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"github.com/fumiama/terasu" "github.com/fumiama/terasu"
"github.com/fumiama/terasu/dns" "github.com/fumiama/terasu/dns"
"github.com/fumiama/terasu/ip"
) )
var ( var (
@@ -23,48 +21,23 @@ var (
) )
var defaultDialer = net.Dialer{ var defaultDialer = net.Dialer{
Timeout: time.Minute, Timeout: 10 * time.Second,
} }
func SetDefaultClientTimeout(t time.Duration) { func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t defaultDialer.Timeout = t
} }
var lookupTable = ttl.NewCache[string, []string](time.Hour)
var DefaultClient = http.Client{ var DefaultClient = http.Client{
Transport: &http2.Transport{ Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout)
defer cancel()
}
if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline)
defer cancel()
}
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
addrs := lookupTable.Get(host) addrs, err := dns.LookupHost(ctx, host)
if len(addrs) == 0 { if err != nil {
addrs, err = dns.DefaultResolver.LookupHost(ctx, host) return nil, err
if err != nil {
if ip.IsIPv6Available.Get() {
addrs, err = dns.IPv6Servers.LookupHostFallback(ctx, host)
} else {
addrs, err = dns.IPv4Servers.LookupHostFallback(ctx, host)
}
if err != nil {
return nil, err
}
}
lookupTable.Set(host, addrs)
} }
if len(addr) == 0 { if len(addr) == 0 {
return nil, ErrEmptyHostAddress return nil, ErrEmptyHostAddress
@@ -72,6 +45,15 @@ var DefaultClient = http.Client{
var conn net.Conn var conn net.Conn
var tlsConn *tls.Conn var tlsConn *tls.Conn
for _, a := range addrs { for _, a := range addrs {
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
defer cancel()
} else if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline)
defer cancel()
}
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil { if err != nil {
continue continue

View File

@@ -6,7 +6,7 @@ import (
"unsafe" "unsafe"
) )
var DefaultFirstFragmentLen uint8 = 4 var DefaultFirstFragmentLen uint8 = 3
// Use terasu in this TLS conn // Use terasu in this TLS conn
func Use(conn *tls.Conn) *Conn { func Use(conn *tls.Conn) *Conn {