1
0
mirror of https://github.com/fumiama/terasu.git synced 2026-06-05 01:00:23 +08:00

fix: error on different frag lens

This commit is contained in:
源文雨
2025-10-23 23:33:06 +08:00
parent 3f56d5341b
commit bda0c8de97
9 changed files with 206 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
// Package main ...
// Package main provides the main entry point for terasu.
// It demonstrates basic Go usage of this library.
package main
import (

31
conn.go
View File

@@ -1,8 +1,9 @@
package terasu
import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"net"
"sync"
@@ -14,14 +15,19 @@ var DefaultFirstFragmentLen = 4
// Conn remote: real server; local: relay
type Conn struct {
mu sync.Mutex
relay relay
init *sync.Once
conn *net.TCPConn
isold bool
}
// NewConn wraps *net.TCPConn (net.Conn must be *net.TCPConn)
func NewConn(conn net.Conn) *Conn {
return &Conn{conn: conn.(*net.TCPConn)}
return &Conn{
relay: newrelay(),
init: &sync.Once{},
conn: conn.(*net.TCPConn),
}
}
// Write is send
@@ -29,14 +35,21 @@ func (conn *Conn) Write(b []byte) (int, error) {
if conn.isold || DefaultFirstFragmentLen == 0 {
return conn.conn.Write(b)
}
conn.mu.Lock()
defer conn.mu.Unlock()
n, err := conn.ReadFrom(bytes.NewReader(b))
return int(n), err
go conn.init.Do(func() {
_, err := io.Copy(conn, &conn.relay)
if err != nil {
_ = conn.relay.Close()
}
})
return conn.relay.Write(b)
}
// ReadFrom when client want to send to server, detect and split.
func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) {
if conn.isold || DefaultFirstFragmentLen == 0 {
return conn.conn.ReadFrom(r)
}
// ContentType [0:1]
// Version [1:3]
// Length [3:5]
@@ -102,6 +115,7 @@ func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) {
// split
if x <= 4 { // first is in header range
fmt.Println(hex.EncodeToString(header[:]))
// first
binary.BigEndian.PutUint16(header[3:5], uint16(x))
bd.move(header[:5+x])
@@ -110,7 +124,7 @@ func (conn *Conn) ReadFrom(r io.Reader) (n int64, err error) {
if err != nil {
return
}
copy(header[5:5+x], header[9-x:9])
copy(header[5:9-x], header[5+x:9])
// second
binary.BigEndian.PutUint16(header[3:5], plen-uint16(x))
bd.move(header[:9-x])
@@ -138,6 +152,7 @@ PIPE:
if err != nil {
return
}
_ = conn.relay.Close()
cnt, err := bd.send(conn.conn, r)
n += cnt
return

View File

@@ -9,6 +9,50 @@ import (
"testing"
)
func TestHTTPDialDifferentFragLen(t *testing.T) {
cli := http.Client{
Transport: &http.Transport{
DialTLS: func(network, addr string) (net.Conn, error) {
conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort(
netip.MustParseAddrPort("52.222.136.117:443"),
))
if err != nil {
return nil, err
}
t.Log("net.Dial succeeded")
tlsConn := tls.Client(NewConn(conn), &tls.Config{
ServerName: "huggingface.co",
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: true,
})
err = tlsConn.Handshake()
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return tlsConn, nil
},
},
}
for i := 0; i < 10; i++ {
// will fail when i=0 in CN
DefaultFirstFragmentLen = i
resp, err := cli.Get("https://huggingface.co/")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("status code:", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(data))
}
}
func TestHTTPDialTLS13(t *testing.T) {
cli := http.Client{
Transport: &http.Transport{

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"errors"
"net"
"slices"
"strings"
"sync"
"syscall"
@@ -27,6 +28,7 @@ var dnsDialer = net.Dialer{
Timeout: time.Second * 4,
}
// SetTimeout ...
func SetTimeout(t time.Duration) {
dnsDialer.Timeout = t
}
@@ -37,6 +39,7 @@ type dnsstat struct {
keep bool
}
// String ...
func (ds *dnsstat) String() string {
sb := strings.Builder{}
sb.WriteString("[addr: ")
@@ -78,6 +81,7 @@ func (ds *dnsstat) disable(reEnable time.Duration) {
})
}
// DNSList is a bundle of DNSs
type DNSList struct {
sync.RWMutex
hostseq []string
@@ -85,6 +89,7 @@ type DNSList struct {
b map[string][]string
}
// DNSConfig is the user config
type DNSConfig struct {
Servers map[string][]string `yaml:"Servers"` // Servers map[dot.com]ip:ports
Fallbacks map[string][]string `yaml:"Fallbacks"` // Fallbacks map[domain]ips
@@ -102,14 +107,10 @@ func hasrecord(lst []*dnsstat, a string) bool {
// hasrecord no lock, use under lock
func hasfallback(lst []string, a string) bool {
for _, addr := range lst {
if addr == a {
return true
}
}
return false
return slices.Contains(lst, a)
}
// Add ...
func (ds *DNSList) Add(c *DNSConfig) {
ds.Lock()
defer ds.Unlock()
@@ -193,6 +194,7 @@ func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) (hosts []stri
return nil, ErrNoDNSAvailable
}
// DialContext ...
func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn *tls.Conn, err error) {
err = ErrNoDNSAvailable
@@ -267,6 +269,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer) (tlsConn
return
}
// IPv6Servers should only be used when IPv6 is available
var IPv6Servers = DNSList{
hostseq: []string{
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
@@ -303,6 +306,7 @@ var IPv6Servers = DNSList{
b: map[string][]string{},
}
// IPv4Servers is the default server set
var IPv4Servers = DNSList{
hostseq: []string{
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
@@ -339,6 +343,7 @@ var IPv4Servers = DNSList{
b: map[string][]string{},
}
// DefaultResolver ...
var DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, nw, _ string) (net.Conn, error) {

View File

@@ -18,6 +18,7 @@ import (
)
var (
// ErrEmptyHostAddress ...
ErrEmptyHostAddress = errors.New("empty host addr")
)
@@ -29,25 +30,43 @@ const (
recordTypeAAAA recordType = 28
)
// dohjsonresponse represents the JSON response structure for DNS over HTTPS (DoH) queries.
// It contains DNS query results and metadata about the response.
type dohjsonresponse struct {
Status uint32
TC bool
RD bool
RA bool
AD bool
CD bool
// Status indicates the DNS query status code (0 = NOERROR, etc.)
Status uint32
// TC indicates whether the response was truncated (true if truncated)
TC bool
// RD indicates whether recursion was requested in the query
RD bool
// RA indicates whether the server supports recursion
RA bool
// AD indicates whether the response was authenticated (DNSSEC)
AD bool
// CD indicates whether the client requested that DNSSEC validation be disabled
CD bool
// Question contains the DNS query question section with name and type
Question []struct {
Name string `json:"name"`
// Name is the domain name being queried
Name string `json:"name"`
// Type is the DNS record type being requested (A, AAAA, etc.)
Type recordType `json:"type"`
}
// Answer contains the DNS response answer section with resource records
Answer []struct {
Name string `json:"name"`
// Name is the domain name for this resource record
Name string `json:"name"`
// Type is the DNS record type (A, AAAA, etc.)
Type recordType `json:"type"`
TTL uint16
// TTL is the time-to-live value for this resource record in seconds
TTL uint16
// Data is the textual representation of the resource record data
Data string `json:"data"`
}
// EdnsClientSubnet is the EDNS client subnet information for geolocation
EdnsClientSubnet string `json:"edns_client_subnet"`
Comment string
// Comment is an optional comment field for additional information
Comment string
}
func (jr *dohjsonresponse) hosts() []string {

View File

@@ -1,4 +1,4 @@
// Package http is the same as the standard http lib
// Package http is a wrapper around the standard http library with enhanced DNS resolution and TLS handling capabilities.
package http
import (
@@ -16,18 +16,23 @@ import (
)
var (
ErrNoTLSConnection = errors.New("no tls connection")
// ErrNoTLSConnection is returned when a TLS connection cannot be established.
ErrNoTLSConnection = errors.New("no tls connection")
// ErrEmptyHostAddress is returned when the host address is empty.
ErrEmptyHostAddress = errors.New("empty host addr")
)
// defaultDialer is the default dialer used for connecting to hosts.
var defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
}
// SetDefaultClientTimeout sets the default timeout for the client's dialer.
func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t
}
// DefaultClient is the default HTTP client with custom transport settings, including DNS resolution and TLS handling.
var DefaultClient = http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -40,12 +45,13 @@ var DefaultClient = http.Client{
if err != nil {
return nil, err
}
if len(addr) == 0 {
if len(addrs) == 0 {
return nil, ErrEmptyHostAddress
}
var conn net.Conn
var tlsConn *tls.Conn
for _, a := range addrs {
// Apply timeout if set, otherwise use deadline
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
@@ -63,7 +69,7 @@ var DefaultClient = http.Client{
ServerName: host,
MinVersion: tls.VersionTLS12,
})
// re-init ctx due to deadline settings in tcp dial
// Re-initialize context due to potential deadline changes from TCP dial
if defaultDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), defaultDialer.Timeout)
@@ -104,18 +110,22 @@ var DefaultClient = http.Client{
},
}
// Get performs an HTTP GET request using the default client.
func Get(url string) (resp *http.Response, err error) {
return DefaultClient.Get(url)
}
// Head performs an HTTP HEAD request using the default client.
func Head(url string) (resp *http.Response, err error) {
return DefaultClient.Head(url)
}
// Post performs an HTTP POST request using the default client.
func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) {
return DefaultClient.Post(url, contentType, body)
}
// PostForm performs an HTTP POST request with form data using the default client.
func PostForm(url string, data url.Values) (resp *http.Response, err error) {
return DefaultClient.PostForm(url, data)
}

View File

@@ -17,18 +17,20 @@ import (
"github.com/fumiama/terasu/dns"
)
var (
ErrEmptyHostAddress = errors.New("empty host addr")
)
// ErrEmptyHostAddress is returned when DNS lookup for a host returns no addresses
var ErrEmptyHostAddress = errors.New("empty host addr")
// defaultDialer is the default dialer used for establishing TCP connections
var defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
}
// SetDefaultClientTimeout sets the default timeout for all HTTP2 client connections
func SetDefaultClientTimeout(t time.Duration) {
defaultDialer.Timeout = t
}
// DefaultClient is the default HTTP2 client that supports HTTP/2 and DNS resolution
var DefaultClient = http.Client{
Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -40,7 +42,7 @@ var DefaultClient = http.Client{
if err != nil {
return nil, err
}
if len(addr) == 0 {
if len(addrs) == 0 {
return nil, ErrEmptyHostAddress
}
var conn net.Conn
@@ -93,18 +95,22 @@ var DefaultClient = http.Client{
},
}
// Get sends an HTTP GET request to the specified URL using the default HTTP2 client
func Get(url string) (resp *http.Response, err error) {
return DefaultClient.Get(url)
}
// Head sends an HTTP HEAD request to the specified URL using the default HTTP2 client
func Head(url string) (resp *http.Response, err error) {
return DefaultClient.Head(url)
}
// Post sends an HTTP POST request to the specified URL with the given content type and body using the default HTTP2 client
func Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) {
return DefaultClient.Post(url, contentType, body)
}
// PostForm sends an HTTP POST request with form data to the specified URL using the default HTTP2 client
func PostForm(url string, data url.Values) (resp *http.Response, err error) {
return DefaultClient.PostForm(url, data)
}

View File

@@ -1,4 +1,5 @@
// Package ip contains IP-related configs
package ip
// IsIPv6Available ...
var IsIPv6Available = false

72
relay.go Normal file
View File

@@ -0,0 +1,72 @@
package terasu
import (
"io"
"sync"
)
type relay struct {
mu sync.Mutex
buf chan []byte
rem []byte
}
func newrelay() relay {
return relay{buf: make(chan []byte, 64)}
}
// Read ...
func (r *relay) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
switch {
case len(p) == 0:
return
case len(p) <= len(r.rem):
n = copy(p, r.rem)
r.rem = r.rem[n:]
if len(r.rem) == 0 {
r.rem = nil
}
return
case len(r.rem) > 0:
n = copy(p, r.rem)
r.rem = nil
fallthrough
default:
for n < len(p) {
buf := <-r.buf
if len(buf) == 0 {
err = io.EOF
return
}
switch {
case len(buf) >= len(p)-n:
cnt := copy(p[n:], buf)
n += cnt
r.rem = buf[cnt:]
if len(r.rem) == 0 {
r.rem = nil
}
return
default:
n += copy(p[n:], buf)
}
}
}
panic("unexpected")
}
// Write ...
func (r *relay) Write(p []byte) (n int, err error) {
buf := make([]byte, len(p))
n = copy(buf, p)
r.buf <- p
return
}
// Close ...
func (r *relay) Close() error {
close(r.buf)
return nil
}