1
0
mirror of https://github.com/fumiama/terasu.git synced 2026-06-05 01:00:23 +08:00

fix: error on different frag lens

This commit is contained in:
源文雨
2025-10-23 23:33:06 +08:00
parent 3f56d5341b
commit bda0c8de97
9 changed files with 206 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
// Package main ... // Package main provides the main entry point for terasu.
// It demonstrates basic Go usage of this library.
package main package main
import ( import (

31
conn.go
View File

@@ -1,8 +1,9 @@
package terasu package terasu
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"encoding/hex"
"fmt"
"io" "io"
"net" "net"
"sync" "sync"
@@ -14,14 +15,19 @@ var DefaultFirstFragmentLen = 4
// Conn remote: real server; local: relay // Conn remote: real server; local: relay
type Conn struct { type Conn struct {
mu sync.Mutex relay relay
init *sync.Once
conn *net.TCPConn conn *net.TCPConn
isold bool isold bool
} }
// NewConn wraps *net.TCPConn (net.Conn must be *net.TCPConn) // NewConn wraps *net.TCPConn (net.Conn must be *net.TCPConn)
func NewConn(conn net.Conn) *Conn { func NewConn(conn net.Conn) *Conn {
return &Conn{conn: conn.(*net.TCPConn)} return &Conn{
relay: newrelay(),
init: &sync.Once{},
conn: conn.(*net.TCPConn),
}
} }
// Write is send // Write is send
@@ -29,14 +35,21 @@ func (conn *Conn) Write(b []byte) (int, error) {
if conn.isold || DefaultFirstFragmentLen == 0 { if conn.isold || DefaultFirstFragmentLen == 0 {
return conn.conn.Write(b) return conn.conn.Write(b)
} }
conn.mu.Lock() go conn.init.Do(func() {
defer conn.mu.Unlock() _, err := io.Copy(conn, &conn.relay)
n, err := conn.ReadFrom(bytes.NewReader(b)) if err != nil {
return int(n), err _ = conn.relay.Close()
}
})
return conn.relay.Write(b)
} }
// ReadFrom when client want to send to server, detect and split. // ReadFrom when client want to send to server, detect and split.
func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) { func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) {
if conn.isold || DefaultFirstFragmentLen == 0 {
return conn.conn.ReadFrom(r)
}
// ContentType [0:1] // ContentType [0:1]
// Version [1:3] // Version [1:3]
// Length [3:5] // Length [3:5]
@@ -102,6 +115,7 @@ func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) {
// split // split
if x <= 4 { // first is in header range if x <= 4 { // first is in header range
fmt.Println(hex.EncodeToString(header[:]))
// first // first
binary.BigEndian.PutUint16(header[3:5], uint16(x)) binary.BigEndian.PutUint16(header[3:5], uint16(x))
bd.move(header[:5+x]) bd.move(header[:5+x])
@@ -110,7 +124,7 @@ func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) {
if err != nil { if err != nil {
return return
} }
copy(header[5:5+x], header[9-x:9]) copy(header[5:9-x], header[5+x:9])
// second // second
binary.BigEndian.PutUint16(header[3:5], plen-uint16(x)) binary.BigEndian.PutUint16(header[3:5], plen-uint16(x))
bd.move(header[:9-x]) bd.move(header[:9-x])
@@ -138,6 +152,7 @@ PIPE:
if err != nil { if err != nil {
return return
} }
_ = conn.relay.Close()
cnt, err := bd.send(conn.conn, r) cnt, err := bd.send(conn.conn, r)
n += cnt n += cnt
return return

View File

@@ -9,6 +9,50 @@ import (
"testing" "testing"
) )
func TestHTTPDialDifferentFragLen(t *testing.T) {
cli := http.Client{
Transport: &http.Transport{
DialTLS: func(network, addr string) (net.Conn, error) {
conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort(
netip.MustParseAddrPort("52.222.136.117:443"),
))
if err != nil {
return nil, err
}
t.Log("net.Dial succeeded")
tlsConn := tls.Client(NewConn(conn), &tls.Config{
ServerName: "huggingface.co",
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: true,
})
err = tlsConn.Handshake()
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return tlsConn, nil
},
},
}
for i := 0; i < 10; i++ {
// will fail when i=0 in CN
DefaultFirstFragmentLen = i
resp, err := cli.Get("https://huggingface.co/")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("status code:", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(data))
}
}
func TestHTTPDialTLS13(t *testing.T) { func TestHTTPDialTLS13(t *testing.T) {
cli := http.Client{ cli := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
"slices"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -27,6 +28,7 @@ var dnsDialer = net.Dialer{
Timeout: time.Second * 4, Timeout: time.Second * 4,
} }
// SetTimeout ...
func SetTimeout(t time.Duration) { func SetTimeout(t time.Duration) {
dnsDialer.Timeout = t dnsDialer.Timeout = t
} }
@@ -37,6 +39,7 @@ type dnsstat struct {
keep bool keep bool
} }
// String ...
func (ds *dnsstat) String() string { func (ds *dnsstat) String() string {
sb := strings.Builder{} sb := strings.Builder{}
sb.WriteString("[addr: ") sb.WriteString("[addr: ")
@@ -78,6 +81,7 @@ func (ds *dnsstat) disable(reEnable time.Duration) {
}) })
} }
// DNSList is a bundle of DNSs
type DNSList struct { type DNSList struct {
sync.RWMutex sync.RWMutex
hostseq []string hostseq []string
@@ -85,6 +89,7 @@ type DNSList struct {
b map[string][]string b map[string][]string
} }
// DNSConfig is the user config
type DNSConfig struct { type DNSConfig struct {
Servers map[string][]string `yaml:"Servers"` // Servers map[dot.com]ip:ports Servers map[string][]string `yaml:"Servers"` // Servers map[dot.com]ip:ports
Fallbacks map[string][]string `yaml:"Fallbacks"` // Fallbacks map[domain]ips Fallbacks map[string][]string `yaml:"Fallbacks"` // Fallbacks map[domain]ips
@@ -102,14 +107,10 @@ func hasrecord(lst []*dnsstat, a string) bool {
// hasrecord no lock, use under lock // hasrecord no lock, use under lock
func hasfallback(lst []string, a string) bool { func hasfallback(lst []string, a string) bool {
for _, addr := range lst { return slices.Contains(lst, a)
if addr == a {
return true
}
}
return false
} }
// Add ...
func (ds *DNSList) Add(c *DNSConfig) { func (ds *DNSList) Add(c *DNSConfig) {
ds.Lock() ds.Lock()
defer ds.Unlock() defer ds.Unlock()
@@ -193,6 +194,7 @@ func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) (hosts []stri
return nil, ErrNoDNSAvailable return nil, ErrNoDNSAvailable
} }
// DialContext ...
func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *tls.Conn, err error) { func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *tls.Conn, err error) {
err = ErrNoDNSAvailable err = ErrNoDNSAvailable
@@ -267,6 +269,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn
return return
} }
// IPv6Servers should only be used when IPv6 is available
var IPv6Servers = DNSList{ var IPv6Servers = DNSList{
hostseq: []string{ hostseq: []string{
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net", "dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
@@ -303,6 +306,7 @@ var IPv6Servers = DNSList{
b: map[string][]string{}, b: map[string][]string{},
} }
// IPv4Servers is the default server set
var IPv4Servers = DNSList{ var IPv4Servers = DNSList{
hostseq: []string{ hostseq: []string{
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net", "dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
@@ -339,6 +343,7 @@ var IPv4Servers = DNSList{
b: map[string][]string{}, b: map[string][]string{},
} }
// DefaultResolver ...
var DefaultResolver = &net.Resolver{ var DefaultResolver = &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, nw, _ string) (net.Conn, error) { Dial: func(ctx context.Context, nw, _ string) (net.Conn, error) {

View File

@@ -18,6 +18,7 @@ import (
) )
var ( var (
// ErrEmptyHostAddress ...
ErrEmptyHostAddress = errors.New("empty host addr") ErrEmptyHostAddress = errors.New("empty host addr")
) )
@@ -29,25 +30,43 @@ const (
recordTypeAAAA recordType = 28 recordTypeAAAA recordType = 28
) )
// dohjsonresponse represents the JSON response structure for DNS over HTTPS (DoH) queries.
// It contains DNS query results and metadata about the response.
type dohjsonresponse struct { type dohjsonresponse struct {
Status uint32 // Status indicates the DNS query status code (0 = NOERROR, etc.)
TC bool Status uint32
RD bool // TC indicates whether the response was truncated (true if truncated)
RA bool TC bool
AD bool // RD indicates whether recursion was requested in the query
CD bool RD bool
// RA indicates whether the server supports recursion
RA bool
// AD indicates whether the response was authenticated (DNSSEC)
AD bool
// CD indicates whether the client requested that DNSSEC validation be disabled
CD bool
// Question contains the DNS query question section with name and type
Question []struct { Question []struct {
Name string `json:"name"` // Name is the domain name being queried
Name string `json:"name"`
// Type is the DNS record type being requested (A, AAAA, etc.)
Type recordType `json:"type"` Type recordType `json:"type"`
} }
// Answer contains the DNS response answer section with resource records
Answer []struct { Answer []struct {
Name string `json:"name"` // Name is the domain name for this resource record
Name string `json:"name"`
// Type is the DNS record type (A, AAAA, etc.)
Type recordType `json:"type"` Type recordType `json:"type"`
TTL uint16 // TTL is the time-to-live value for this resource record in seconds
TTL uint16
// Data is the textual representation of the resource record data
Data string `json:"data"` Data string `json:"data"`
} }
// EdnsClientSubnet is the EDNS client subnet information for geolocation
EdnsClientSubnet string `json:"edns_client_subnet"` EdnsClientSubnet string `json:"edns_client_subnet"`
Comment string // Comment is an optional comment field for additional information
Comment string
} }
func (jr *dohjsonresponse) hosts() []string { func (jr *dohjsonresponse) hosts() []string {

View File

@@ -1,4 +1,4 @@
// Package http is the same as the standard http lib // Package http is a wrapper around the standard http library with enhanced DNS resolution and TLS handling capabilities.
package http package http
import ( import (
@@ -16,18 +16,23 @@ import (
) )
var ( var (
ErrNoTLSConnection = errors.New("no tls connection") // ErrNoTLSConnection is returned when a TLS connection cannot be established.
ErrNoTLSConnection = errors.New("no tls connection")
// ErrEmptyHostAddress is returned when the host address is empty.
ErrEmptyHostAddress = errors.New("empty host addr") ErrEmptyHostAddress = errors.New("empty host addr")
) )
// defaultDialer is the default dialer used for connecting to hosts.
var defaultDialer = net.Dialer{ var defaultDialer = net.Dialer{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
} }
// SetDefaultClientTimeout sets the default timeout for the client's dialer.
func SetDefaultClientTimeout(t time.Duration) { func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t defaultDialer.Timeout = t
} }
// DefaultClient is the default HTTP client with custom transport settings, including DNS resolution and TLS handling.
var DefaultClient = http.Client{ var DefaultClient = http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
@@ -40,12 +45,13 @@ var DefaultClient = http.Client{
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(addr) == 0 { if len(addrs) == 0 {
return nil, ErrEmptyHostAddress return nil, ErrEmptyHostAddress
} }
var conn net.Conn var conn net.Conn
var tlsConn *tls.Conn var tlsConn *tls.Conn
for _, a := range addrs { for _, a := range addrs {
// Apply timeout if set, otherwise use deadline
if defaultDialer.Timeout != 0 { if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
@@ -63,7 +69,7 @@ var DefaultClient = http.Client{
ServerName: host, ServerName: host,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}) })
// re-init ctx due to deadline settings in tcp dial // Re-initialize context due to potential deadline changes from TCP dial
if defaultDialer.Timeout != 0 { if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout) ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
@@ -104,18 +110,22 @@ var DefaultClient = http.Client{
}, },
} }
// Get performs an HTTP GET request using the default client.
func Get(url string) (resp *http.Response, err error) { func Get(url string) (resp *http.Response, err error) {
return DefaultClient.Get(url) return DefaultClient.Get(url)
} }
// Head performs an HTTP HEAD request using the default client.
func Head(url string) (resp *http.Response, err error) { func Head(url string) (resp *http.Response, err error) {
return DefaultClient.Head(url) return DefaultClient.Head(url)
} }
// Post performs an HTTP POST request using the default client.
func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) { func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) {
return DefaultClient.Post(url, contentType, body) return DefaultClient.Post(url, contentType, body)
} }
// PostForm performs an HTTP POST request with form data using the default client.
func PostForm(url string, data url.Values) (resp *http.Response, err error) { func PostForm(url string, data url.Values) (resp *http.Response, err error) {
return DefaultClient.PostForm(url, data) return DefaultClient.PostForm(url, data)
} }

View File

@@ -17,18 +17,20 @@ import (
"github.com/fumiama/terasu/dns" "github.com/fumiama/terasu/dns"
) )
var ( // ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses
ErrEmptyHostAddress = errors.New("empty host addr") var ErrEmptyHostAddress = errors.New("empty host addr")
)
// defaultDialer is the default dialer used for establishing TCP connections
var defaultDialer = net.Dialer{ var defaultDialer = net.Dialer{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
} }
// SetDefaultClientTimeout sets the default timeout for all HTTP2 client connections
func SetDefaultClientTimeout(t time.Duration) { func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t defaultDialer.Timeout = t
} }
// DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution
var DefaultClient = http.Client{ var DefaultClient = http.Client{
Transport: &http2.Transport{ Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -40,7 +42,7 @@ var DefaultClient = http.Client{
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(addr) == 0 { if len(addrs) == 0 {
return nil, ErrEmptyHostAddress return nil, ErrEmptyHostAddress
} }
var conn net.Conn var conn net.Conn
@@ -93,18 +95,22 @@ var DefaultClient = http.Client{
}, },
} }
// Get sends an HTTP GET request to the specified URL using the default HTTP2 client
func Get(url string) (resp *http.Response, err error) { func Get(url string) (resp *http.Response, err error) {
return DefaultClient.Get(url) return DefaultClient.Get(url)
} }
// Head sends an HTTP HEAD request to the specified URL using the default HTTP2 client
func Head(url string) (resp *http.Response, err error) { func Head(url string) (resp *http.Response, err error) {
return DefaultClient.Head(url) return DefaultClient.Head(url)
} }
// Post sends an HTTP POST request to the specified URL with the given content type and body using the default HTTP2 client
func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) { func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) {
return DefaultClient.Post(url, contentType, body) return DefaultClient.Post(url, contentType, body)
} }
// PostForm sends an HTTP POST request with form data to the specified URL using the default HTTP2 client
func PostForm(url string, data url.Values) (resp *http.Response, err error) { func PostForm(url string, data url.Values) (resp *http.Response, err error) {
return DefaultClient.PostForm(url, data) return DefaultClient.PostForm(url, data)
} }

View File

@@ -1,4 +1,5 @@
// Package ip contains IP-related configs // Package ip contains IP-related configs
package ip package ip
// IsIPv6Available ...
var IsIPv6Available = false var IsIPv6Available = false

72
relay.go Normal file
View File

@@ -0,0 +1,72 @@
package terasu
import (
"io"
"sync"
)
type relay struct {
mu sync.Mutex
buf chan []byte
rem []byte
}
func newrelay() relay {
return relay{buf: make(chan []byte, 64)}
}
// Read ...
func (r *relay) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
switch {
case len(p) == 0:
return
case len(p) <= len(r.rem):
n = copy(p, r.rem)
r.rem = r.rem[n:]
if len(r.rem) == 0 {
r.rem = nil
}
return
case len(r.rem) > 0:
n = copy(p, r.rem)
r.rem = nil
fallthrough
default:
for n < len(p) {
buf := <-r.buf
if len(buf) == 0 {
err = io.EOF
return
}
switch {
case len(buf) >= len(p)-n:
cnt := copy(p[n:], buf)
n += cnt
r.rem = buf[cnt:]
if len(r.rem) == 0 {
r.rem = nil
}
return
default:
n += copy(p[n:], buf)
}
}
}
panic("unexpected")
}
// Write ...
func (r *relay) Write(p []byte) (n int, err error) {
buf := make([]byte, len(p))
n = copy(buf, p)
r.buf <- p
return
}
// Close ...
func (r *relay) Close() error {
close(r.buf)
return nil
}