From c887e26289a861c9c4a1ac9bc4c01cb91e9917d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 2 May 2024 18:19:19 +0900 Subject: [PATCH] feat: add DoH fallback supports --- dns/dns.go | 33 ++++++++- dns/dns_test.go | 23 +++++++ dns/doh.go | 180 ++++++++++++++++++++++++++++++++++++++++++++++++ http/http.go | 29 ++++++-- http2/http2.go | 27 ++++++-- 5 files changed, 279 insertions(+), 13 deletions(-) create mode 100644 dns/doh.go diff --git a/dns/dns.go b/dns/dns.go index f7ced5d..c7ee0c0 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "net" + "strings" "sync" "time" @@ -86,6 +87,22 @@ func (ds *DNSList) Add(c *DNSConfig) { func (ds *DNSList) LookupHostFallback(ctx context.Context, host string) ([]string, error) { ds.RLock() defer ds.RUnlock() + // try to use DoH first + for _, addrs := range ds.m { + for _, addr := range addrs { + if !addr.e || !strings.HasPrefix(addr.a, "https://") { // disabled or is not DoH + continue + } + jr, err := lookupdoh(addr.a, host) + if err == nil { + hosts := jr.hosts() + if len(hosts) > 0 { + return hosts, nil + } + } + addr.e = false // no need to acquire write lock + } + } if addrs, ok := ds.b[host]; ok { return addrs, nil } @@ -117,7 +134,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra var conn net.Conn for host, addrs := range ds.m { for _, addr := range addrs { - if !addr.e { + if !addr.e || strings.HasPrefix(addr.a, "https://") { // disabled or is DoH continue } conn, err = dialer.DialContext(ctx, "tcp", addr.a) @@ -142,14 +159,21 @@ var IPv6Servers = DNSList{ "dot.sb": { {"[2a09::]:853", true}, {"[2a11::]:853", true}, + {"https://doh.sb/dns-query", true}, }, "dns.google": { {"[2001:4860:4860::8888]:853", true}, {"[2001:4860:4860::8844]:853", true}, + {"https://dns.google/resolve", true}, + {"https://[2001:4860:4860::8888]/resolve", true}, + {"https://[2001:4860:4860::8844]/resolve", true}, }, "cloudflare-dns.com": { {"[2606:4700:4700::1111]:853", true}, {"[2606:4700:4700::1001]:853", true}, + {"https://cloudflare-dns.com/dns-query", true}, + {"https://[2606:4700:4700::1111]/dns-query", true}, + {"https://[2606:4700:4700::1001]/dns-query", true}, }, "dns.opendns.com": { {"[2620:119:35::35]:853", true}, @@ -168,14 +192,21 @@ var IPv4Servers = DNSList{ "dot.sb": { {"185.222.222.222:853", true}, {"45.11.45.11:853", true}, + {"https://doh.sb/dns-query", true}, }, "dns.google": { {"8.8.8.8:853", true}, {"8.8.4.4:853", true}, + {"https://dns.google/resolve", true}, + {"https://8.8.8.8/resolve", true}, + {"https://8.8.4.4/resolve", true}, }, "cloudflare-dns.com": { {"1.1.1.1:853", true}, {"1.0.0.1:853", true}, + {"https://cloudflare-dns.com/dns-query", true}, + {"https://1.1.1.1/dns-query", true}, + {"https://1.0.0.1/dns-query", true}, }, "dns.opendns.com": { {"208.67.222.222:853", true}, diff --git a/dns/dns_test.go b/dns/dns_test.go index 11ba18c..186136a 100644 --- a/dns/dns_test.go +++ b/dns/dns_test.go @@ -24,6 +24,29 @@ func TestResolver(t *testing.T) { } } +func TestResolverFallback(t *testing.T) { + t.Log("IsIPv6Available:", ip.IsIPv6Available.Get()) + + if ip.IsIPv6Available.Get() { + addrs, err := IPv6Servers.LookupHostFallback(context.TODO(), "huggingface.co") + if err != nil { + t.Fatal(err) + } + t.Log(addrs) + if len(addrs) == 0 { + t.Fail() + } + } + addrs, err := IPv4Servers.LookupHostFallback(context.TODO(), "huggingface.co") + if err != nil { + t.Fatal(err) + } + t.Log(addrs) + if len(addrs) == 0 { + t.Fail() + } +} + func TestDNS(t *testing.T) { if ip.IsIPv6Available.Get() { IPv6Servers.test() diff --git a/dns/doh.go b/dns/doh.go new file mode 100644 index 0000000..698a9b4 --- /dev/null +++ b/dns/doh.go @@ -0,0 +1,180 @@ +package dns + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/FloatTech/ttl" + "golang.org/x/net/http2" + + "github.com/fumiama/terasu" + "github.com/fumiama/terasu/ip" +) + +var ( + ErrEmptyHostAddress = errors.New("empty host addr") +) + +type recordType uint16 + +const ( + recordTypeNone recordType = 0 + recordTypeA recordType = 1 + recordTypeAAAA recordType = 28 +) + +type dohjsonresponse struct { + Status uint32 + TC bool + RD bool + RA bool + AD bool + CD bool + Question []struct { + Name string `json:"name"` + Type recordType `json:"type"` + } + Answer []struct { + Name string `json:"name"` + Type recordType `json:"type"` + TTL uint16 + Data string `json:"data"` + } + EdnsClientSubnet string `json:"edns_client_subnet"` + Comment string +} + +func (jr *dohjsonresponse) hosts() []string { + if len(jr.Answer) == 0 { + return nil + } + hosts := make([]string, 0, len(jr.Answer)) + for _, ans := range jr.Answer { + if ans.Type == recordTypeA || ans.Type == recordTypeAAAA { + hosts = append(hosts, ans.Data) + } + } + return hosts +} + +var defaultDialer = net.Dialer{ + Timeout: time.Second * 4, +} + +var lookupTable = ttl.NewCache[string, []string](time.Hour) + +var trsHTTP2ClientWithSystemDNS = http.Client{ + Transport: &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + if defaultDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout) + defer cancel() + } + + if !defaultDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline) + defer cancel() + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + addrs := lookupTable.Get(host) + if len(addrs) == 0 { + addrs, err = net.DefaultResolver.LookupHost(ctx, host) + if err != nil { + return nil, err + } + lookupTable.Set(host, addrs) + } + if len(addr) == 0 { + return nil, ErrEmptyHostAddress + } + var conn net.Conn + var tlsConn *tls.Conn + for _, a := range addrs { + conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(conn, cfg) + err = terasu.Use(tlsConn).HandshakeContext(ctx, terasu.DefaultFirstFragmentLen) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil + conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(conn, cfg) + err = tlsConn.HandshakeContext(ctx) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil + } + return tlsConn, err + }, + }, +} + +func lookupdoh(server, u string) (jr dohjsonresponse, err error) { + jr, err = lookupdohwithtype(server, u, preferreddohtype()) + if err == nil { + return + } + if ip.IsIPv6Available.Get() { + jr, err = lookupdohwithtype(server, u, recordTypeA) + } + return +} + +func lookupdohwithtype(server, u string, typ recordType) (jr dohjsonresponse, err error) { + sb := strings.Builder{} + sb.WriteString(server) + sb.WriteString("?name=") + sb.WriteString(url.QueryEscape(u)) + if typ != recordTypeNone { + sb.WriteString("&type=") + sb.WriteString(strconv.Itoa(int(typ))) + } + req, err := http.NewRequest("GET", sb.String(), nil) + if err != nil { + return + } + req.Header.Add("accept", "application/dns-json") + resp, err := trsHTTP2ClientWithSystemDNS.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + err = json.NewDecoder(resp.Body).Decode(&jr) + if err != nil { + return + } + if jr.Status != 0 { + err = errors.New("comment: " + jr.Comment) + } + return +} + +func preferreddohtype() recordType { + if ip.IsIPv6Available.Get() { + return recordTypeAAAA + } + return recordTypeA +} diff --git a/http/http.go b/http/http.go index 0e28f38..58ca5b6 100644 --- a/http/http.go +++ b/http/http.go @@ -22,24 +22,28 @@ var ( ErrEmptyHostAddress = errors.New("empty host addr") ) -var DefaultDialer = net.Dialer{ +var defaultDialer = net.Dialer{ Timeout: time.Minute, } +func SetDefaultClientTimeout(t time.Duration) { + defaultDialer.Timeout = t +} + var lookupTable = ttl.NewCache[string, []string](time.Hour) var DefaultClient = http.Client{ Transport: &http.Transport{ DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if DefaultDialer.Timeout != 0 { + if defaultDialer.Timeout != 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, DefaultDialer.Timeout) + ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout) defer cancel() } - if !DefaultDialer.Deadline.IsZero() { + if !defaultDialer.Deadline.IsZero() { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, DefaultDialer.Deadline) + ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline) defer cancel() } @@ -68,7 +72,7 @@ var DefaultClient = http.Client{ var conn net.Conn var tlsConn *tls.Conn for _, a := range addrs { - conn, err = DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) if err != nil { continue } @@ -81,6 +85,19 @@ var DefaultClient = http.Client{ } _ = tlsConn.Close() tlsConn = nil + conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(conn, &tls.Config{ + ServerName: host, + }) + err = tlsConn.HandshakeContext(ctx) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil } return tlsConn, err }, diff --git a/http2/http2.go b/http2/http2.go index 25d4f98..e5d9621 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -22,24 +22,28 @@ var ( ErrEmptyHostAddress = errors.New("empty host addr") ) -var DefaultDialer = net.Dialer{ +var defaultDialer = net.Dialer{ Timeout: time.Minute, } +func SetDefaultClientTimeout(t time.Duration) { + defaultDialer.Timeout = t +} + var lookupTable = ttl.NewCache[string, []string](time.Hour) var DefaultClient = http.Client{ Transport: &http2.Transport{ DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - if DefaultDialer.Timeout != 0 { + if defaultDialer.Timeout != 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, DefaultDialer.Timeout) + ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout) defer cancel() } - if !DefaultDialer.Deadline.IsZero() { + if !defaultDialer.Deadline.IsZero() { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, DefaultDialer.Deadline) + ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline) defer cancel() } @@ -68,7 +72,7 @@ var DefaultClient = http.Client{ var conn net.Conn var tlsConn *tls.Conn for _, a := range addrs { - conn, err = DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) if err != nil { continue } @@ -79,6 +83,17 @@ var DefaultClient = http.Client{ } _ = tlsConn.Close() tlsConn = nil + conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(conn, cfg) + err = tlsConn.HandshakeContext(ctx) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil } return tlsConn, err },