diff --git a/README.md b/README.md index 5f25408..2c4d579 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,14 @@ tls.Client(terasu.NewConn(conn), &tls.Config{ ServerName: host, }).Handshake() ``` + +## Custom Plugin (Linux Only) + +Custom plugin code is located in the `ext/custom` directory. You can write and build your own plugin. + +1. Write your plugin code in the `ext/custom` directory +2. Build the plugin: + +```bash +GOOS=linux go build -o terasu.plugin.so -buildmode=plugin -ldflags="-s -w" -trimpath ./ext/custom +``` diff --git a/cmd/main.go b/cmd/main.go index fe0ebb8..9d46e1e 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,6 +8,7 @@ import ( "os" "strings" + _ "github.com/fumiama/terasu/ext" "github.com/fumiama/terasu/http2" "github.com/sirupsen/logrus" ) diff --git a/conn_test.go b/conn_test.go index 3108eed..1613685 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,17 +5,16 @@ import ( "io" "net" "net/http" - "net/netip" "testing" + + "github.com/fumiama/terasu/dialer" ) func TestHTTPDialDifferentFragLen(t *testing.T) { cli := http.Client{ Transport: &http.Transport{ DialTLS: func(network, addr string) (net.Conn, error) { - conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( - netip.MustParseAddrPort("52.222.136.117:443"), - )) + conn, err := dialer.DefaultDialer.Dial("tcp", "3.164.110.114:443") if err != nil { return nil, err } @@ -57,9 +56,7 @@ func TestHTTPDialTLS13(t *testing.T) { cli := http.Client{ Transport: &http.Transport{ DialTLS: func(network, addr string) (net.Conn, error) { - conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( - netip.MustParseAddrPort("52.222.136.117:443"), - )) + conn, err := dialer.DefaultDialer.Dial("tcp", "3.164.110.114:443") if err != nil { return nil, err } @@ -97,9 +94,7 @@ func TestHTTPDialTLS12(t *testing.T) { cli := http.Client{ Transport: &http.Transport{ DialTLS: func(network, addr string) (net.Conn, error) { - conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( - netip.MustParseAddrPort("52.222.136.117:443"), - )) + conn, err := dialer.DefaultDialer.Dial("tcp", "3.164.110.114:443") if err != nil { return nil, err } diff --git a/dialer/dialer.go b/dialer/dialer.go new file mode 100644 index 0000000..7744e4b --- /dev/null +++ b/dialer/dialer.go @@ -0,0 +1,22 @@ +package dialer + +import ( + "net" + "syscall" + "time" +) + +// DefaultDialer is the default dialer used for establishing TCP connections +var DefaultDialer = net.Dialer{ + Timeout: 4 * time.Second, +} + +// SetDefaultTimeout sets the default timeout for all HTTP2 client connections +func SetDefaultTimeout(t time.Duration) { + DefaultDialer.Timeout = t +} + +// SetDefaultControl sets control of the default dailer +func SetDefaultControl(c func(network string, address string, c syscall.RawConn) error) { + DefaultDialer.Control = c +} diff --git a/dns/dns.go b/dns/dns.go index dcc5f5c..9725da9 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -13,6 +13,8 @@ import ( "time" "github.com/fumiama/terasu" + "github.com/fumiama/terasu/dialer" + "github.com/fumiama/terasu/doh" "github.com/fumiama/terasu/ip" "github.com/sirupsen/logrus" ) @@ -24,15 +26,6 @@ var ( ErrSuccess = errors.New("success") ) -var dnsDialer = net.Dialer{ - Timeout: time.Second * 4, -} - -// SetTimeout ... -func SetTimeout(t time.Duration) { - dnsDialer.Timeout = t -} - type dnsstat struct { addr string en bool @@ -173,9 +166,9 @@ func (ds *List) lookupHostDoH(ctx context.Context, host string) (hosts []string, if !addr.enabled() || !addr.ishttps() { // disabled or is not DoH continue } - jr, err := lookupdoh(ctx, addr.addr, host) + jr, err := doh.LookupDoH(ctx, addr.addr, host) if err == nil { - hosts = jr.hosts() + hosts = jr.Hosts() if len(hosts) > 0 { // this is a successful server, keep it addr.keepit() @@ -203,11 +196,11 @@ func (ds *List) lookupHostDoH(ctx context.Context, host string) (hosts []string, } // DialContext ... -func (ds *List) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *tls.Conn, err error) { +func (ds *List) DialContext(ctx context.Context, d *net.Dialer) (tlsConn *tls.Conn, err error) { err = ErrNoDNSAvailable - if dialer == nil { - dialer = &dnsDialer + if d == nil { + d = &dialer.DefaultDialer } ds.RLock() @@ -220,16 +213,16 @@ func (ds *List) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *t continue } logrus.Debugln("[terasu.dns] -> dial", host, addr) - if dialer.Timeout != 0 { + if d.Timeout != 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), dialer.Timeout) + ctx, cancel = context.WithTimeout(context.Background(), d.Timeout) defer cancel() - } else if !dialer.Deadline.IsZero() { + } else if !d.Deadline.IsZero() { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), dialer.Deadline) + ctx, cancel = context.WithDeadline(context.Background(), d.Deadline) defer cancel() } - conn, err = dialer.DialContext(ctx, "tcp", addr.addr) + conn, err = d.DialContext(ctx, "tcp", addr.addr) if err != nil { logrus.Debugln("[terasu.dns] -- dial tcp", host, addr, "err:", err) if !errors.Is(err, context.Canceled) && @@ -247,13 +240,13 @@ func (ds *List) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *t NextProtos: []string{"dns"}, }) // re-init ctx due to deadline settings in tcp dial - if dialer.Timeout != 0 { + if d.Timeout != 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), dialer.Timeout) + ctx, cancel = context.WithTimeout(context.Background(), d.Timeout) defer cancel() - } else if !dialer.Deadline.IsZero() { + } else if !d.Deadline.IsZero() { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), dialer.Deadline) + ctx, cancel = context.WithDeadline(context.Background(), d.Deadline) defer cancel() } err = tlsConn.HandshakeContext(ctx) diff --git a/dns/dns_test.go b/dns/dns_test.go index 1a43935..48e80d6 100644 --- a/dns/dns_test.go +++ b/dns/dns_test.go @@ -4,11 +4,11 @@ import ( "context" "crypto/tls" "fmt" - "net" "testing" "time" "github.com/fumiama/terasu" + "github.com/fumiama/terasu/dialer" "github.com/fumiama/terasu/ip" ) @@ -110,7 +110,7 @@ func (ds *List) test() { continue } fmt.Println("dial:", host, addr.addr) - conn, err := net.Dial("tcp", addr.addr) + conn, err := dialer.DefaultDialer.Dial("tcp", addr.addr) if err != nil { continue } diff --git a/dns/tls.go b/dns/tls.go new file mode 100644 index 0000000..d7f1244 --- /dev/null +++ b/dns/tls.go @@ -0,0 +1,19 @@ +package dns + +import ( + "context" + "crypto/tls" + "net" + + mtls "github.com/fumiama/terasu/tls" +) + +// DialTLSContext fills http.Transport method with terasu and DNS +func DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + return DialTLSContextWithConfig(ctx, network, addr, nil) +} + +// DialTLSContextWithConfig fills http2.Transport method with terasu and DNS +func DialTLSContextWithConfig(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return mtls.DialTLSContextCL(ctx, network, addr, cfg, nil) +} diff --git a/dns/doh.go b/doh/doh.go similarity index 53% rename from dns/doh.go rename to doh/doh.go index 86325f4..fbed697 100644 --- a/dns/doh.go +++ b/doh/doh.go @@ -1,11 +1,9 @@ -package dns +package doh import ( "context" - "crypto/tls" "encoding/json" "errors" - "net" "net/http" "net/url" "strconv" @@ -13,26 +11,22 @@ import ( "golang.org/x/net/http2" - "github.com/fumiama/terasu" "github.com/fumiama/terasu/ip" + "github.com/fumiama/terasu/tls" ) -var ( - // ErrEmptyHostAddress ... - ErrEmptyHostAddress = errors.New("empty host addr") -) - -type recordType uint16 +// RecordType ... +type RecordType uint16 const ( - recordTypeNone recordType = 0 - recordTypeA recordType = 1 - recordTypeAAAA recordType = 28 + RecordTypeNone RecordType = 0 // RecordTypeNone ... + RecordTypeA RecordType = 1 // RecordTypeA IPv4 + RecordTypeAAAA RecordType = 28 // RecordTypeAAAA IPv6 ) -// dohjsonresponse represents the JSON response structure for DNS over HTTPS (DoH) queries. +// Response represents the JSON response structure for DNS over HTTPS (DoH) queries. // It contains DNS query results and metadata about the response. -type dohjsonresponse struct { +type Response struct { // Status indicates the DNS query status code (0 = NOERROR, etc.) Status uint32 // TC indicates whether the response was truncated (true if truncated) @@ -50,14 +44,14 @@ type dohjsonresponse struct { // Name is the domain name being queried Name string `json:"name"` // Type is the DNS record type being requested (A, AAAA, etc.) - Type recordType `json:"type"` + Type RecordType `json:"type"` } // Answer contains the DNS response answer section with resource records Answer []struct { // Name is the domain name for this resource record Name string `json:"name"` // Type is the DNS record type (A, AAAA, etc.) - Type recordType `json:"type"` + Type RecordType `json:"type"` // TTL is the time-to-live value for this resource record in seconds TTL uint16 // Data is the textual representation of the resource record data @@ -69,13 +63,13 @@ type dohjsonresponse struct { Comment string } -func (jr *dohjsonresponse) hosts() []string { +func (jr *Response) Hosts() []string { if len(jr.Answer) == 0 { return nil } hosts := make([]string, 0, len(jr.Answer)) for _, ans := range jr.Answer { - if ans.Type == recordTypeA || ans.Type == recordTypeAAAA { + if ans.Type == RecordTypeA || ans.Type == RecordTypeAAAA { hosts = append(hosts, ans.Data) } } @@ -84,70 +78,29 @@ func (jr *dohjsonresponse) hosts() []string { var trsHTTP2ClientWithSystemDNS = http.Client{ Transport: &http2.Transport{ - DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - addrs := lookupTable.Get(host) - if len(addrs) == 0 { - addrs, err = net.DefaultResolver.LookupHost(ctx, host) - if err != nil { - return nil, err - } - lookupTable.Set(host, addrs) - } - if len(addr) == 0 { - return nil, ErrEmptyHostAddress - } - var conn net.Conn - var tlsConn *tls.Conn - for _, a := range addrs { - conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) - if err != nil { - continue - } - tlsConn = tls.Client(terasu.NewConn(conn), cfg) - err = tlsConn.HandshakeContext(ctx) - if err == nil { - break - } - _ = tlsConn.Close() - tlsConn = nil - conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) - if err != nil { - continue - } - tlsConn = tls.Client(terasu.NewConn(conn), cfg) - err = tlsConn.HandshakeContext(ctx) - if err == nil { - break - } - _ = tlsConn.Close() - tlsConn = nil - } - return tlsConn, err - }, + DialTLSContext: tls.DialTLSContextWithConfigAndSystemResolver, }, } -func lookupdoh(ctx context.Context, server, u string) (jr dohjsonresponse, err error) { - jr, err = lookupdohwithtype(ctx, server, u, preferreddohtype()) +// LookupDoH lookup uname's ip from server +func LookupDoH(ctx context.Context, server, name string) (jr Response, err error) { + jr, err = LookupDoHWithType(ctx, server, name, prefertyp()) if err == nil { return } if ip.IsIPv6Available { - jr, err = lookupdohwithtype(ctx, server, u, recordTypeA) + jr, err = LookupDoHWithType(ctx, server, name, RecordTypeA) } return } -func lookupdohwithtype(ctx context.Context, server, u string, typ recordType) (jr dohjsonresponse, err error) { +// LookupDoHWithType ... +func LookupDoHWithType(ctx context.Context, server, name string, typ RecordType) (jr Response, err error) { sb := strings.Builder{} sb.WriteString(server) sb.WriteString("?name=") - sb.WriteString(url.QueryEscape(u)) - if typ != recordTypeNone { + sb.WriteString(url.QueryEscape(name)) + if typ != RecordTypeNone { sb.WriteString("&type=") sb.WriteString(strconv.Itoa(int(typ))) } @@ -171,9 +124,9 @@ func lookupdohwithtype(ctx context.Context, server, u string, typ recordType) (j return } -func preferreddohtype() recordType { +func prefertyp() RecordType { if ip.IsIPv6Available { - return recordTypeAAAA + return RecordTypeAAAA } - return recordTypeA + return RecordTypeA } diff --git a/ext/custom/.gitignore b/ext/custom/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/ext/custom/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/ext/init.go b/ext/init.go new file mode 100644 index 0000000..09ebf10 --- /dev/null +++ b/ext/init.go @@ -0,0 +1,24 @@ +package ext + +import ( + "os" + "plugin" + + "github.com/sirupsen/logrus" +) + +const ( + TRSPluginFile = "./terasu.plugin.so" +) + +func init() { + if _, err := os.Stat(TRSPluginFile); err != nil { + return + } + _, err := plugin.Open(TRSPluginFile) + if err != nil { + logrus.Warnln("[terasu.plugin] load", TRSPluginFile, "err:", err) + logrus.Warnln("[terasu.plugin] hint: ensure the main binary and plugin are built with identical flags (e.g. both use -trimpath -ldflags=\"-s -w\"), and avoid using 'go run'") + return + } +} diff --git a/http/http.go b/http/http.go index 433bdba..b7dce31 100644 --- a/http/http.go +++ b/http/http.go @@ -2,106 +2,19 @@ package http import ( - "context" - "crypto/tls" - "errors" "io" - "net" "net/http" "net/url" "time" - "github.com/fumiama/terasu" "github.com/fumiama/terasu/dns" ) -var ( - // ErrNoTLSConnection is returned when a TLS connection cannot be established. - ErrNoTLSConnection = errors.New("no tls connection") - // ErrEmptyHostAddress is returned when the host address is empty. - ErrEmptyHostAddress = errors.New("empty host addr") -) - -// defaultDialer is the default dialer used for connecting to hosts. -var defaultDialer = net.Dialer{ - Timeout: 10 * time.Second, -} - -// SetDefaultClientTimeout sets the default timeout for the client's dialer. -func SetDefaultClientTimeout(t time.Duration) { - defaultDialer.Timeout = t -} - // DefaultClient is the default HTTP client with custom transport settings, including DNS resolution and TLS handling. var DefaultClient = http.Client{ Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - addrs, err := dns.LookupHost(ctx, host) - if err != nil { - return nil, err - } - if len(addrs) == 0 { - return nil, ErrEmptyHostAddress - } - var conn net.Conn - var tlsConn *tls.Conn - for _, a := range addrs { - // Apply timeout if set, otherwise use deadline - if defaultDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) - defer cancel() - } else if !defaultDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) - defer cancel() - } - conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) - if err != nil { - continue - } - tlsConn = tls.Client(terasu.NewConn(conn), &tls.Config{ - ServerName: host, - MinVersion: tls.VersionTLS12, - }) - // Re-initialize context due to potential deadline changes from TCP dial - if defaultDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) - defer cancel() - } else if !defaultDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) - defer cancel() - } - err = tlsConn.HandshakeContext(ctx) - if err == nil { - break - } - _ = tlsConn.Close() - tlsConn = nil - conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) - if err != nil { - continue - } - tlsConn = tls.Client(terasu.NewConn(conn), &tls.Config{ - ServerName: host, - MinVersion: tls.VersionTLS12, - }) - err = tlsConn.HandshakeContext(ctx) - if err == nil { - break - } - _ = tlsConn.Close() - tlsConn = nil - } - return tlsConn, err - }, + Proxy: http.ProxyFromEnvironment, + DialTLSContext: dns.DialTLSContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, diff --git a/http2/http.go b/http2/http.go index 5550a31..fa80969 100644 --- a/http2/http.go +++ b/http2/http.go @@ -2,96 +2,18 @@ package http2 import ( - "context" - "crypto/tls" - "errors" "io" - "net" "net/http" "net/url" - "time" - "golang.org/x/net/http2" - - "github.com/fumiama/terasu" "github.com/fumiama/terasu/dns" + "golang.org/x/net/http2" ) -// ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses -var ErrEmptyHostAddress = errors.New("empty host addr") - -// defaultDialer is the default dialer used for establishing TCP connections -var defaultDialer = net.Dialer{ - Timeout: 10 * time.Second, -} - -// SetDefaultClientTimeout sets the default timeout for all HTTP2 client connections -func SetDefaultClientTimeout(t time.Duration) { - defaultDialer.Timeout = t -} - // DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution var DefaultClient = http.Client{ Transport: &http2.Transport{ - DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - addrs, err := dns.LookupHost(ctx, host) - if err != nil { - return nil, err - } - if len(addrs) == 0 { - return nil, ErrEmptyHostAddress - } - var conn net.Conn - var tlsConn *tls.Conn - for _, a := range addrs { - if defaultDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) - defer cancel() - } else if !defaultDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) - defer cancel() - } - conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) - if err != nil { - continue - } - tlsConn = tls.Client(terasu.NewConn(conn), cfg) - // re-init ctx due to deadline settings in tcp dial - if defaultDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) - defer cancel() - } else if !defaultDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) - defer cancel() - } - err = tlsConn.HandshakeContext(ctx) - if err == nil { - break - } - _ = tlsConn.Close() - tlsConn = nil - conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) - if err != nil { - continue - } - tlsConn = tls.Client(terasu.NewConn(conn), cfg) - err = tlsConn.HandshakeContext(ctx) - if err == nil { - break - } - _ = tlsConn.Close() - tlsConn = nil - } - return tlsConn, err - }, + DialTLSContext: dns.DialTLSContextWithConfig, }, } diff --git a/http3/http.go b/http3/http.go index 014d652..503dd87 100644 --- a/http3/http.go +++ b/http3/http.go @@ -13,9 +13,9 @@ import ( "net/http" "net/netip" "net/url" - "time" base14 "github.com/fumiama/go-base16384" + "github.com/fumiama/terasu/dialer" "github.com/fumiama/terasu/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" @@ -24,16 +24,6 @@ import ( // ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses var ErrEmptyHostAddress = errors.New("empty host addr") -// defaultDialer is the default dialer used for establishing TCP connections -var defaultDialer = net.Dialer{ - Timeout: 10 * time.Second, -} - -// SetDefaultClientTimeout sets the default timeout for all HTTP2 client connections -func SetDefaultClientTimeout(t time.Duration) { - defaultDialer.Timeout = t -} - // DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution var DefaultClient = http.Client{ Transport: &http3.Transport{ @@ -52,13 +42,13 @@ var DefaultClient = http.Client{ var conn net.Conn var qConn quic.EarlyConnection for _, a := range addrs { - if defaultDialer.Timeout != 0 { + if dialer.DefaultDialer.Timeout != 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) + ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout) defer cancel() - } else if !defaultDialer.Deadline.IsZero() { + } else if !dialer.DefaultDialer.Deadline.IsZero() { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) + ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline) defer cancel() } conn, err = net.ListenUDP("udp", nil) @@ -76,13 +66,13 @@ var DefaultClient = http.Client{ _ = e.Close() _, _ = ucon.WriteToUDP(w.Bytes(), raddr) // re-init ctx due to deadline settings in tcp dial - if defaultDialer.Timeout != 0 { + if dialer.DefaultDialer.Timeout != 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) + ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout) defer cancel() - } else if !defaultDialer.Deadline.IsZero() { + } else if !dialer.DefaultDialer.Deadline.IsZero() { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) + ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline) defer cancel() } qConn, err = quic.DialEarly(ctx, ucon, raddr, tlsCfg, cfg) diff --git a/tls/dial.go b/tls/dial.go new file mode 100644 index 0000000..24b0c4c --- /dev/null +++ b/tls/dial.go @@ -0,0 +1,94 @@ +package tls + +import ( + "context" + "crypto/tls" + "errors" + "net" + + "github.com/fumiama/terasu" + "github.com/fumiama/terasu/dialer" +) + +// ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses +var ErrEmptyHostAddress = errors.New("empty host addr") + +// DialTLSContextWithConfigAndSystemResolver fills http2.Transport method with terasu and system DNS +func DialTLSContextWithConfigAndSystemResolver(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return DialTLSContextCL(ctx, network, addr, cfg, nil) +} + +// DialTLSContextCL fills http2.Transport method with terasu +func DialTLSContextCL( + ctx context.Context, network, addr string, + cfg *tls.Config, lookup func(ctx context.Context, host string, + ) (addrs []string, err error)) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + var addrs []string + if lookup != nil { + addrs, err = lookup(ctx, host) + } else { + addrs, err = net.DefaultResolver.LookupHost(ctx, host) + } + if err != nil { + return nil, err + } + if len(addrs) == 0 { + return nil, ErrEmptyHostAddress + } + if cfg == nil { + cfg = &tls.Config{ + ServerName: host, + MinVersion: tls.VersionTLS12, + } + } + var conn net.Conn + var tlsConn *tls.Conn + for _, a := range addrs { + if dialer.DefaultDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout) + defer cancel() + } else if !dialer.DefaultDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline) + defer cancel() + } + conn, err = dialer.DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(terasu.NewConn(conn), cfg) + // re-init ctx due to deadline settings in tcp dial + if dialer.DefaultDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout) + defer cancel() + } else if !dialer.DefaultDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline) + defer cancel() + } + err = tlsConn.HandshakeContext(ctx) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil + conn, err = dialer.DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port)) + if err != nil { + continue + } + tlsConn = tls.Client(terasu.NewConn(conn), cfg) + err = tlsConn.HandshakeContext(ctx) + if err == nil { + break + } + _ = tlsConn.Close() + tlsConn = nil + } + return tlsConn, err +}