1
0
mirror of https://github.com/fumiama/terasu.git synced 2026-06-12 06:00:37 +08:00

feat: add plugin

This commit is contained in:
源文雨
2026-02-16 15:20:45 +08:00
parent f6d5336492
commit 1d573cf2be
14 changed files with 234 additions and 295 deletions

View File

@@ -19,3 +19,14 @@ tls.Client(terasu.NewConn(conn), &tls.Config{
ServerName: host, ServerName: host,
}).Handshake() }).Handshake()
``` ```
## Custom Plugin (Linux Only)
Custom plugin code is located in the `ext/custom` directory. You can write and build your own plugin.
1. Write your plugin code in the `ext/custom` directory
2. Build the plugin:
```bash
GOOS=linux go build -o terasu.plugin.so -buildmode=plugin -ldflags="-s -w" -trimpath ./ext/custom
```

View File

@@ -8,6 +8,7 @@ import (
"os" "os"
"strings" "strings"
_ "github.com/fumiama/terasu/ext"
"github.com/fumiama/terasu/http2" "github.com/fumiama/terasu/http2"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )

View File

@@ -5,17 +5,16 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"testing" "testing"
"github.com/fumiama/terasu/dialer"
) )
func TestHTTPDialDifferentFragLen(t *testing.T) { func TestHTTPDialDifferentFragLen(t *testing.T) {
cli := http.Client{ cli := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialTLS: func(network, addr string) (net.Conn, error) { DialTLS: func(network, addr string) (net.Conn, error) {
conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( conn, err := dialer.DefaultDialer.Dial("tcp", "3.164.110.114:443")
netip.MustParseAddrPort("52.222.136.117:443"),
))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -57,9 +56,7 @@ func TestHTTPDialTLS13(t *testing.T) {
cli := http.Client{ cli := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialTLS: func(network, addr string) (net.Conn, error) { DialTLS: func(network, addr string) (net.Conn, error) {
conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( conn, err := dialer.DefaultDialer.Dial("tcp", "3.164.110.114:443")
netip.MustParseAddrPort("52.222.136.117:443"),
))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -97,9 +94,7 @@ func TestHTTPDialTLS12(t *testing.T) {
cli := http.Client{ cli := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialTLS: func(network, addr string) (net.Conn, error) { DialTLS: func(network, addr string) (net.Conn, error) {
conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort( conn, err := dialer.DefaultDialer.Dial("tcp", "3.164.110.114:443")
netip.MustParseAddrPort("52.222.136.117:443"),
))
if err != nil { if err != nil {
return nil, err return nil, err
} }

22
dialer/dialer.go Normal file
View File

@@ -0,0 +1,22 @@
package dialer
import (
"net"
"syscall"
"time"
)
// DefaultDialer is the default dialer used for establishing TCP connections
var DefaultDialer = net.Dialer{
Timeout: 4 * time.Second,
}
// SetDefaultTimeout sets the default timeout for all HTTP2 client connections
func SetDefaultTimeout(t time.Duration) {
DefaultDialer.Timeout = t
}
// SetDefaultControl sets control of the default dailer
func SetDefaultControl(c func(network string, address string, c syscall.RawConn) error) {
DefaultDialer.Control = c
}

View File

@@ -13,6 +13,8 @@ import (
"time" "time"
"github.com/fumiama/terasu" "github.com/fumiama/terasu"
"github.com/fumiama/terasu/dialer"
"github.com/fumiama/terasu/doh"
"github.com/fumiama/terasu/ip" "github.com/fumiama/terasu/ip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@@ -24,15 +26,6 @@ var (
ErrSuccess = errors.New("success") ErrSuccess = errors.New("success")
) )
var dnsDialer = net.Dialer{
Timeout: time.Second * 4,
}
// SetTimeout ...
func SetTimeout(t time.Duration) {
dnsDialer.Timeout = t
}
type dnsstat struct { type dnsstat struct {
addr string addr string
en bool en bool
@@ -173,9 +166,9 @@ func (ds *List) lookupHostDoH(ctx context.Context, host string) (hosts []string,
if !addr.enabled() || !addr.ishttps() { // disabled or is not DoH if !addr.enabled() || !addr.ishttps() { // disabled or is not DoH
continue continue
} }
jr, err := lookupdoh(ctx, addr.addr, host) jr, err := doh.LookupDoH(ctx, addr.addr, host)
if err == nil { if err == nil {
hosts = jr.hosts() hosts = jr.Hosts()
if len(hosts) > 0 { if len(hosts) > 0 {
// this is a successful server, keep it // this is a successful server, keep it
addr.keepit() addr.keepit()
@@ -203,11 +196,11 @@ func (ds *List) lookupHostDoH(ctx context.Context, host string) (hosts []string,
} }
// DialContext ... // DialContext ...
func (ds *List) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *tls.Conn, err error) { func (ds *List) DialContext(ctx context.Context, d *net.Dialer) (tlsConn *tls.Conn, err error) {
err = ErrNoDNSAvailable err = ErrNoDNSAvailable
if dialer == nil { if d == nil {
dialer = &dnsDialer d = &dialer.DefaultDialer
} }
ds.RLock() ds.RLock()
@@ -220,16 +213,16 @@ func (ds *List) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *t
continue continue
} }
logrus.Debugln("[terasu.dns] -> dial", host, addr) logrus.Debugln("[terasu.dns] -> dial", host, addr)
if dialer.Timeout != 0 { if d.Timeout != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), dialer.Timeout) ctx, cancel = context.WithTimeout(context.Background(), d.Timeout)
defer cancel() defer cancel()
} else if !dialer.Deadline.IsZero() { } else if !d.Deadline.IsZero() {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), dialer.Deadline) ctx, cancel = context.WithDeadline(context.Background(), d.Deadline)
defer cancel() defer cancel()
} }
conn, err = dialer.DialContext(ctx, "tcp", addr.addr) conn, err = d.DialContext(ctx, "tcp", addr.addr)
if err != nil { if err != nil {
logrus.Debugln("[terasu.dns] -- dial tcp", host, addr, "err:", err) logrus.Debugln("[terasu.dns] -- dial tcp", host, addr, "err:", err)
if !errors.Is(err, context.Canceled) && if !errors.Is(err, context.Canceled) &&
@@ -247,13 +240,13 @@ func (ds *List) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *t
NextProtos: []string{"dns"}, NextProtos: []string{"dns"},
}) })
// re-init ctx due to deadline settings in tcp dial // re-init ctx due to deadline settings in tcp dial
if dialer.Timeout != 0 { if d.Timeout != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), dialer.Timeout) ctx, cancel = context.WithTimeout(context.Background(), d.Timeout)
defer cancel() defer cancel()
} else if !dialer.Deadline.IsZero() { } else if !d.Deadline.IsZero() {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), dialer.Deadline) ctx, cancel = context.WithDeadline(context.Background(), d.Deadline)
defer cancel() defer cancel()
} }
err = tlsConn.HandshakeContext(ctx) err = tlsConn.HandshakeContext(ctx)

View File

@@ -4,11 +4,11 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net"
"testing" "testing"
"time" "time"
"github.com/fumiama/terasu" "github.com/fumiama/terasu"
"github.com/fumiama/terasu/dialer"
"github.com/fumiama/terasu/ip" "github.com/fumiama/terasu/ip"
) )
@@ -110,7 +110,7 @@ func (ds *List) test() {
continue continue
} }
fmt.Println("dial:", host, addr.addr) fmt.Println("dial:", host, addr.addr)
conn, err := net.Dial("tcp", addr.addr) conn, err := dialer.DefaultDialer.Dial("tcp", addr.addr)
if err != nil { if err != nil {
continue continue
} }

19
dns/tls.go Normal file
View File

@@ -0,0 +1,19 @@
package dns
import (
"context"
"crypto/tls"
"net"
mtls "github.com/fumiama/terasu/tls"
)
// DialTLSContext fills http.Transport method with terasu and DNS
func DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
return DialTLSContextWithConfig(ctx, network, addr, nil)
}
// DialTLSContextWithConfig fills http2.Transport method with terasu and DNS
func DialTLSContextWithConfig(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
return mtls.DialTLSContextCL(ctx, network, addr, cfg, nil)
}

View File

@@ -1,11 +1,9 @@
package dns package doh
import ( import (
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"errors" "errors"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@@ -13,26 +11,22 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"github.com/fumiama/terasu"
"github.com/fumiama/terasu/ip" "github.com/fumiama/terasu/ip"
"github.com/fumiama/terasu/tls"
) )
var ( // RecordType ...
// ErrEmptyHostAddress ... type RecordType uint16
ErrEmptyHostAddress = errors.New("empty host addr")
)
type recordType uint16
const ( const (
recordTypeNone recordType = 0 RecordTypeNone RecordType = 0 // RecordTypeNone ...
recordTypeA recordType = 1 RecordTypeA RecordType = 1 // RecordTypeA IPv4
recordTypeAAAA recordType = 28 RecordTypeAAAA RecordType = 28 // RecordTypeAAAA IPv6
) )
// dohjsonresponse represents the JSON response structure for DNS over HTTPS (DoH) queries. // Response represents the JSON response structure for DNS over HTTPS (DoH) queries.
// It contains DNS query results and metadata about the response. // It contains DNS query results and metadata about the response.
type dohjsonresponse struct { type Response struct {
// Status indicates the DNS query status code (0 = NOERROR, etc.) // Status indicates the DNS query status code (0 = NOERROR, etc.)
Status uint32 Status uint32
// TC indicates whether the response was truncated (true if truncated) // TC indicates whether the response was truncated (true if truncated)
@@ -50,14 +44,14 @@ type dohjsonresponse struct {
// Name is the domain name being queried // Name is the domain name being queried
Name string `json:"name"` Name string `json:"name"`
// Type is the DNS record type being requested (A, AAAA, etc.) // Type is the DNS record type being requested (A, AAAA, etc.)
Type recordType `json:"type"` Type RecordType `json:"type"`
} }
// Answer contains the DNS response answer section with resource records // Answer contains the DNS response answer section with resource records
Answer []struct { Answer []struct {
// Name is the domain name for this resource record // Name is the domain name for this resource record
Name string `json:"name"` Name string `json:"name"`
// Type is the DNS record type (A, AAAA, etc.) // Type is the DNS record type (A, AAAA, etc.)
Type recordType `json:"type"` Type RecordType `json:"type"`
// TTL is the time-to-live value for this resource record in seconds // TTL is the time-to-live value for this resource record in seconds
TTL uint16 TTL uint16
// Data is the textual representation of the resource record data // Data is the textual representation of the resource record data
@@ -69,13 +63,13 @@ type dohjsonresponse struct {
Comment string Comment string
} }
func (jr *dohjsonresponse) hosts() []string { func (jr *Response) Hosts() []string {
if len(jr.Answer) == 0 { if len(jr.Answer) == 0 {
return nil return nil
} }
hosts := make([]string, 0, len(jr.Answer)) hosts := make([]string, 0, len(jr.Answer))
for _, ans := range jr.Answer { for _, ans := range jr.Answer {
if ans.Type == recordTypeA || ans.Type == recordTypeAAAA { if ans.Type == RecordTypeA || ans.Type == RecordTypeAAAA {
hosts = append(hosts, ans.Data) hosts = append(hosts, ans.Data)
} }
} }
@@ -84,70 +78,29 @@ func (jr *dohjsonresponse) hosts() []string {
var trsHTTP2ClientWithSystemDNS = http.Client{ var trsHTTP2ClientWithSystemDNS = http.Client{
Transport: &http2.Transport{ Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { DialTLSContext: tls.DialTLSContextWithConfigAndSystemResolver,
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
addrs := lookupTable.Get(host)
if len(addrs) == 0 {
addrs, err = net.DefaultResolver.LookupHost(ctx, host)
if err != nil {
return nil, err
}
lookupTable.Set(host, addrs)
}
if len(addr) == 0 {
return nil, ErrEmptyHostAddress
}
var conn net.Conn
var tlsConn *tls.Conn
for _, a := range addrs {
conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), cfg)
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
conn, err = dnsDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), cfg)
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
}
return tlsConn, err
},
}, },
} }
func lookupdoh(ctx context.Context, server, u string) (jr dohjsonresponse, err error) { // LookupDoH lookup uname's ip from server
jr, err = lookupdohwithtype(ctx, server, u, preferreddohtype()) func LookupDoH(ctx context.Context, server, name string) (jr Response, err error) {
jr, err = LookupDoHWithType(ctx, server, name, prefertyp())
if err == nil { if err == nil {
return return
} }
if ip.IsIPv6Available { if ip.IsIPv6Available {
jr, err = lookupdohwithtype(ctx, server, u, recordTypeA) jr, err = LookupDoHWithType(ctx, server, name, RecordTypeA)
} }
return return
} }
func lookupdohwithtype(ctx context.Context, server, u string, typ recordType) (jr dohjsonresponse, err error) { // LookupDoHWithType ...
func LookupDoHWithType(ctx context.Context, server, name string, typ RecordType) (jr Response, err error) {
sb := strings.Builder{} sb := strings.Builder{}
sb.WriteString(server) sb.WriteString(server)
sb.WriteString("?name=") sb.WriteString("?name=")
sb.WriteString(url.QueryEscape(u)) sb.WriteString(url.QueryEscape(name))
if typ != recordTypeNone { if typ != RecordTypeNone {
sb.WriteString("&type=") sb.WriteString("&type=")
sb.WriteString(strconv.Itoa(int(typ))) sb.WriteString(strconv.Itoa(int(typ)))
} }
@@ -171,9 +124,9 @@ func lookupdohwithtype(ctx context.Context, server, u string, typ recordType) (j
return return
} }
func preferreddohtype() recordType { func prefertyp() RecordType {
if ip.IsIPv6Available { if ip.IsIPv6Available {
return recordTypeAAAA return RecordTypeAAAA
} }
return recordTypeA return RecordTypeA
} }

2
ext/custom/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*
!.gitignore

24
ext/init.go Normal file
View File

@@ -0,0 +1,24 @@
package ext
import (
"os"
"plugin"
"github.com/sirupsen/logrus"
)
const (
TRSPluginFile = "./terasu.plugin.so"
)
func init() {
if _, err := os.Stat(TRSPluginFile); err != nil {
return
}
_, err := plugin.Open(TRSPluginFile)
if err != nil {
logrus.Warnln("[terasu.plugin] load", TRSPluginFile, "err:", err)
logrus.Warnln("[terasu.plugin] hint: ensure the main binary and plugin are built with identical flags (e.g. both use -trimpath -ldflags=\"-s -w\"), and avoid using 'go run'")
return
}
}

View File

@@ -2,106 +2,19 @@
package http package http
import ( import (
"context"
"crypto/tls"
"errors"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
"github.com/fumiama/terasu"
"github.com/fumiama/terasu/dns" "github.com/fumiama/terasu/dns"
) )
var (
// ErrNoTLSConnection is returned when a TLS connection cannot be established.
ErrNoTLSConnection = errors.New("no tls connection")
// ErrEmptyHostAddress is returned when the host address is empty.
ErrEmptyHostAddress = errors.New("empty host addr")
)
// defaultDialer is the default dialer used for connecting to hosts.
var defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
}
// SetDefaultClientTimeout sets the default timeout for the client's dialer.
func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t
}
// DefaultClient is the default HTTP client with custom transport settings, including DNS resolution and TLS handling. // DefaultClient is the default HTTP client with custom transport settings, including DNS resolution and TLS handling.
var DefaultClient = http.Client{ var DefaultClient = http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialTLSContext: dns.DialTLSContext,
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
addrs, err := dns.LookupHost(ctx, host)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, ErrEmptyHostAddress
}
var conn net.Conn
var tlsConn *tls.Conn
for _, a := range addrs {
// Apply timeout if set, otherwise use deadline
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
defer cancel()
} else if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline)
defer cancel()
}
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), &tls.Config{
ServerName: host,
MinVersion: tls.VersionTLS12,
})
// Re-initialize context due to potential deadline changes from TCP dial
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
defer cancel()
} else if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline)
defer cancel()
}
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), &tls.Config{
ServerName: host,
MinVersion: tls.VersionTLS12,
})
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
}
return tlsConn, err
},
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
MaxIdleConns: 100, MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,

View File

@@ -2,96 +2,18 @@
package http2 package http2
import ( import (
"context"
"crypto/tls"
"errors"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"time"
"golang.org/x/net/http2"
"github.com/fumiama/terasu"
"github.com/fumiama/terasu/dns" "github.com/fumiama/terasu/dns"
"golang.org/x/net/http2"
) )
// ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses
var ErrEmptyHostAddress = errors.New("empty host addr")
// defaultDialer is the default dialer used for establishing TCP connections
var defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
}
// SetDefaultClientTimeout sets the default timeout for all HTTP2 client connections
func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t
}
// DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution // DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution
var DefaultClient = http.Client{ var DefaultClient = http.Client{
Transport: &http2.Transport{ Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { DialTLSContext: dns.DialTLSContextWithConfig,
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
addrs, err := dns.LookupHost(ctx, host)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, ErrEmptyHostAddress
}
var conn net.Conn
var tlsConn *tls.Conn
for _, a := range addrs {
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
defer cancel()
} else if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline)
defer cancel()
}
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), cfg)
// re-init ctx due to deadline settings in tcp dial
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
defer cancel()
} else if !defaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline)
defer cancel()
}
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), cfg)
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
}
return tlsConn, err
},
}, },
} }

View File

@@ -13,9 +13,9 @@ import (
"net/http" "net/http"
"net/netip" "net/netip"
"net/url" "net/url"
"time"
base14 "github.com/fumiama/go-base16384" base14 "github.com/fumiama/go-base16384"
"github.com/fumiama/terasu/dialer"
"github.com/fumiama/terasu/dns" "github.com/fumiama/terasu/dns"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/http3"
@@ -24,16 +24,6 @@ import (
// ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses // ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses
var ErrEmptyHostAddress = errors.New("empty host addr") var ErrEmptyHostAddress = errors.New("empty host addr")
// defaultDialer is the default dialer used for establishing TCP connections
var defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
}
// SetDefaultClientTimeout sets the default timeout for all HTTP2 client connections
func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t
}
// DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution // DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution
var DefaultClient = http.Client{ var DefaultClient = http.Client{
Transport: &http3.Transport{ Transport: &http3.Transport{
@@ -52,13 +42,13 @@ var DefaultClient = http.Client{
var conn net.Conn var conn net.Conn
var qConn quic.EarlyConnection var qConn quic.EarlyConnection
for _, a := range addrs { for _, a := range addrs {
if defaultDialer.Timeout != 0 { if dialer.DefaultDialer.Timeout != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout)
defer cancel() defer cancel()
} else if !defaultDialer.Deadline.IsZero() { } else if !dialer.DefaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline)
defer cancel() defer cancel()
} }
conn, err = net.ListenUDP("udp", nil) conn, err = net.ListenUDP("udp", nil)
@@ -76,13 +66,13 @@ var DefaultClient = http.Client{
_ = e.Close() _ = e.Close()
_, _ = ucon.WriteToUDP(w.Bytes(), raddr) _, _ = ucon.WriteToUDP(w.Bytes(), raddr)
// re-init ctx due to deadline settings in tcp dial // re-init ctx due to deadline settings in tcp dial
if defaultDialer.Timeout != 0 { if dialer.DefaultDialer.Timeout != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout)
defer cancel() defer cancel()
} else if !defaultDialer.Deadline.IsZero() { } else if !dialer.DefaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), defaultDialer.Deadline) ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline)
defer cancel() defer cancel()
} }
qConn, err = quic.DialEarly(ctx, ucon, raddr, tlsCfg, cfg) qConn, err = quic.DialEarly(ctx, ucon, raddr, tlsCfg, cfg)

94
tls/dial.go Normal file
View File

@@ -0,0 +1,94 @@
package tls
import (
"context"
"crypto/tls"
"errors"
"net"
"github.com/fumiama/terasu"
"github.com/fumiama/terasu/dialer"
)
// ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses
var ErrEmptyHostAddress = errors.New("empty host addr")
// DialTLSContextWithConfigAndSystemResolver fills http2.Transport method with terasu and system DNS
func DialTLSContextWithConfigAndSystemResolver(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
return DialTLSContextCL(ctx, network, addr, cfg, nil)
}
// DialTLSContextCL fills http2.Transport method with terasu
func DialTLSContextCL(
ctx context.Context, network, addr string,
cfg *tls.Config, lookup func(ctx context.Context, host string,
) (addrs []string, err error)) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
var addrs []string
if lookup != nil {
addrs, err = lookup(ctx, host)
} else {
addrs, err = net.DefaultResolver.LookupHost(ctx, host)
}
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, ErrEmptyHostAddress
}
if cfg == nil {
cfg = &tls.Config{
ServerName: host,
MinVersion: tls.VersionTLS12,
}
}
var conn net.Conn
var tlsConn *tls.Conn
for _, a := range addrs {
if dialer.DefaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout)
defer cancel()
} else if !dialer.DefaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline)
defer cancel()
}
conn, err = dialer.DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), cfg)
// re-init ctx due to deadline settings in tcp dial
if dialer.DefaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), dialer.DefaultDialer.Timeout)
defer cancel()
} else if !dialer.DefaultDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), dialer.DefaultDialer.Deadline)
defer cancel()
}
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
conn, err = dialer.DefaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
if err != nil {
continue
}
tlsConn = tls.Client(terasu.NewConn(conn), cfg)
err = tlsConn.HandshakeContext(ctx)
if err == nil {
break
}
_ = tlsConn.Close()
tlsConn = nil
}
return tlsConn, err
}