1
0
mirror of https://github.com/fumiama/terasu-cloudflared.git synced 2026-06-12 14:10:31 +08:00

TUN-8861: Add session limiter to UDP session manager

## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This first
commit introduces the session limiter and adds it to the UDP
handling path. For now the limiter is set to run only in
unlimited mode.
This commit is contained in:
João "Pisco" Fernandes
2025-01-20 02:52:32 -08:00
parent 8918b6729e
commit bf4954e96a
66 changed files with 3409 additions and 1184 deletions

View File

@@ -904,6 +904,10 @@ func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error {
return errors.New("ssh: signature did not verify")
}
func (k *skECDSAPublicKey) CryptoPublicKey() crypto.PublicKey {
return &k.PublicKey
}
type skEd25519PublicKey struct {
// application is a URL-like string, typically "ssh:" for SSH.
// see openssh/PROTOCOL.u2f for details.
@@ -1000,6 +1004,10 @@ func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error {
return nil
}
func (k *skEd25519PublicKey) CryptoPublicKey() crypto.PublicKey {
return k.PublicKey
}
// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
// *ecdsa.PrivateKey or any other crypto.Signer and returns a
// corresponding Signer instance. ECDSA keys must use P-256, P-384 or

View File

@@ -462,6 +462,24 @@ func (p *PartialSuccessError) Error() string {
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")
// BannerError is an error that can be returned by authentication handlers in
// ServerConfig to send a banner message to the client.
type BannerError struct {
Err error
Message string
}
func (b *BannerError) Unwrap() error {
return b.Err
}
func (b *BannerError) Error() string {
if b.Err == nil {
return b.Message
}
return b.Err.Error()
}
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID()
var cache pubKeyCache
@@ -734,6 +752,18 @@ userAuthLoop:
config.AuthLogCallback(s, userAuthReq.Method, authErr)
}
var bannerErr *BannerError
if errors.As(authErr, &bannerErr) {
if bannerErr.Message != "" {
bannerMsg := &userAuthBannerMsg{
Message: bannerErr.Message,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
if authErr == nil {
break userAuthLoop
}

View File

@@ -38,6 +38,7 @@ type File struct {
Module *Module
Go *Go
Toolchain *Toolchain
Godebug []*Godebug
Require []*Require
Exclude []*Exclude
Replace []*Replace
@@ -65,6 +66,13 @@ type Toolchain struct {
Syntax *Line
}
// A Godebug is a single godebug key=value statement.
type Godebug struct {
Key string
Value string
Syntax *Line
}
// An Exclude is a single exclude statement.
type Exclude struct {
Mod module.Version
@@ -289,7 +297,7 @@ func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (parse
})
}
continue
case "module", "require", "exclude", "replace", "retract":
case "module", "godebug", "require", "exclude", "replace", "retract":
for _, l := range x.Line {
f.add(&errs, x, l, x.Token[0], l.Token, fix, strict)
}
@@ -308,7 +316,9 @@ var laxGoVersionRE = lazyregexp.New(`^v?(([1-9][0-9]*)\.(0|[1-9][0-9]*))([^0-9].
// Toolchains must be named beginning with `go1`,
// like "go1.20.3" or "go1.20.3-gccgo". As a special case, "default" is also permitted.
// TODO(samthanawalla): Replace regex with https://pkg.go.dev/go/version#IsValid in 1.23+
// Note that this regexp is a much looser condition than go/version.IsValid,
// for forward compatibility.
// (This code has to be work to identify new toolchains even if we tweak the syntax in the future.)
var ToolchainRE = lazyregexp.New(`^default$|^go1($|\.)`)
func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, args []string, fix VersionFixer, strict bool) {
@@ -384,7 +394,7 @@ func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, a
if len(args) != 1 {
errorf("toolchain directive expects exactly one argument")
return
} else if strict && !ToolchainRE.MatchString(args[0]) {
} else if !ToolchainRE.MatchString(args[0]) {
errorf("invalid toolchain version '%s': must match format go1.23.0 or default", args[0])
return
}
@@ -412,6 +422,22 @@ func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, a
}
f.Module.Mod = module.Version{Path: s}
case "godebug":
if len(args) != 1 || strings.ContainsAny(args[0], "\"`',") {
errorf("usage: godebug key=value")
return
}
key, value, ok := strings.Cut(args[0], "=")
if !ok {
errorf("usage: godebug key=value")
return
}
f.Godebug = append(f.Godebug, &Godebug{
Key: key,
Value: value,
Syntax: line,
})
case "require", "exclude":
if len(args) != 2 {
errorf("usage: %s module/path v1.2.3", verb)
@@ -654,6 +680,22 @@ func (f *WorkFile) add(errs *ErrorList, line *Line, verb string, args []string,
f.Toolchain = &Toolchain{Syntax: line}
f.Toolchain.Name = args[0]
case "godebug":
if len(args) != 1 || strings.ContainsAny(args[0], "\"`',") {
errorf("usage: godebug key=value")
return
}
key, value, ok := strings.Cut(args[0], "=")
if !ok {
errorf("usage: godebug key=value")
return
}
f.Godebug = append(f.Godebug, &Godebug{
Key: key,
Value: value,
Syntax: line,
})
case "use":
if len(args) != 1 {
errorf("usage: %s local/dir", verb)
@@ -929,6 +971,15 @@ func (f *File) Format() ([]byte, error) {
// Cleanup cleans out all the cleared entries.
func (f *File) Cleanup() {
w := 0
for _, g := range f.Godebug {
if g.Key != "" {
f.Godebug[w] = g
w++
}
}
f.Godebug = f.Godebug[:w]
w = 0
for _, r := range f.Require {
if r.Mod.Path != "" {
f.Require[w] = r
@@ -1027,6 +1078,45 @@ func (f *File) AddToolchainStmt(name string) error {
return nil
}
// AddGodebug sets the first godebug line for key to value,
// preserving any existing comments for that line and removing all
// other godebug lines for key.
//
// If no line currently exists for key, AddGodebug adds a new line
// at the end of the last godebug block.
func (f *File) AddGodebug(key, value string) error {
need := true
for _, g := range f.Godebug {
if g.Key == key {
if need {
g.Value = value
f.Syntax.updateLine(g.Syntax, "godebug", key+"="+value)
need = false
} else {
g.Syntax.markRemoved()
*g = Godebug{}
}
}
}
if need {
f.addNewGodebug(key, value)
}
return nil
}
// addNewGodebug adds a new godebug key=value line at the end
// of the last godebug block, regardless of any existing godebug lines for key.
func (f *File) addNewGodebug(key, value string) {
line := f.Syntax.addLine(nil, "godebug", key+"="+value)
g := &Godebug{
Key: key,
Value: value,
Syntax: line,
}
f.Godebug = append(f.Godebug, g)
}
// AddRequire sets the first require line for path to version vers,
// preserving any existing comments for that line and removing all
// other lines for path.
@@ -1334,6 +1424,16 @@ func (f *File) SetRequireSeparateIndirect(req []*Require) {
f.SortBlocks()
}
func (f *File) DropGodebug(key string) error {
for _, g := range f.Godebug {
if g.Key == key {
g.Syntax.markRemoved()
*g = Godebug{}
}
}
return nil
}
func (f *File) DropRequire(path string) error {
for _, r := range f.Require {
if r.Mod.Path == path {

View File

@@ -14,6 +14,7 @@ import (
type WorkFile struct {
Go *Go
Toolchain *Toolchain
Godebug []*Godebug
Use []*Use
Replace []*Replace
@@ -68,7 +69,7 @@ func ParseWork(file string, data []byte, fix VersionFixer) (*WorkFile, error) {
Err: fmt.Errorf("unknown block type: %s", strings.Join(x.Token, " ")),
})
continue
case "use", "replace":
case "godebug", "use", "replace":
for _, l := range x.Line {
f.add(&errs, l, x.Token[0], l.Token, fix)
}
@@ -184,6 +185,55 @@ func (f *WorkFile) DropToolchainStmt() {
}
}
// AddGodebug sets the first godebug line for key to value,
// preserving any existing comments for that line and removing all
// other godebug lines for key.
//
// If no line currently exists for key, AddGodebug adds a new line
// at the end of the last godebug block.
func (f *WorkFile) AddGodebug(key, value string) error {
need := true
for _, g := range f.Godebug {
if g.Key == key {
if need {
g.Value = value
f.Syntax.updateLine(g.Syntax, "godebug", key+"="+value)
need = false
} else {
g.Syntax.markRemoved()
*g = Godebug{}
}
}
}
if need {
f.addNewGodebug(key, value)
}
return nil
}
// addNewGodebug adds a new godebug key=value line at the end
// of the last godebug block, regardless of any existing godebug lines for key.
func (f *WorkFile) addNewGodebug(key, value string) {
line := f.Syntax.addLine(nil, "godebug", key+"="+value)
g := &Godebug{
Key: key,
Value: value,
Syntax: line,
}
f.Godebug = append(f.Godebug, g)
}
func (f *WorkFile) DropGodebug(key string) error {
for _, g := range f.Godebug {
if g.Key == key {
g.Syntax.markRemoved()
*g = Godebug{}
}
}
return nil
}
func (f *WorkFile) AddUse(diskPath, modulePath string) error {
need := true
for _, d := range f.Use {

View File

@@ -506,6 +506,7 @@ var badWindowsNames = []string{
"PRN",
"AUX",
"NUL",
"COM0",
"COM1",
"COM2",
"COM3",
@@ -515,6 +516,7 @@ var badWindowsNames = []string{
"COM7",
"COM8",
"COM9",
"LPT0",
"LPT1",
"LPT2",
"LPT3",

View File

@@ -17,6 +17,7 @@ package http2 // import "golang.org/x/net/http2"
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
@@ -26,6 +27,7 @@ import (
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/http/httpguts"
)
@@ -210,12 +212,6 @@ type stringWriter interface {
WriteString(s string) (n int, err error)
}
// A gate lets two goroutines coordinate their activities.
type gate chan struct{}
func (g gate) Done() { g <- struct{}{} }
func (g gate) Wait() { <-g }
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type closeWaiter chan struct{}
@@ -383,3 +379,14 @@ func validPseudoPath(v string) bool {
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
type incomparable [0]func()
// synctestGroupInterface is the methods of synctestGroup used by Server and Transport.
// It's defined as an interface here to let us keep synctestGroup entirely test-only
// and not a part of non-test builds.
type synctestGroupInterface interface {
Join()
Now() time.Time
NewTimer(d time.Duration) timer
AfterFunc(d time.Duration, f func()) timer
ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
}

View File

@@ -154,6 +154,39 @@ type Server struct {
// so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers.
state *serverInternalState
// Synchronization group used for testing.
// Outside of tests, this is nil.
group synctestGroupInterface
}
func (s *Server) markNewGoroutine() {
if s.group != nil {
s.group.Join()
}
}
func (s *Server) now() time.Time {
if s.group != nil {
return s.group.Now()
}
return time.Now()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (s *Server) newTimer(d time.Duration) timer {
if s.group != nil {
return s.group.NewTimer(d)
}
return timeTimer{time.NewTimer(d)}
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (s *Server) afterFunc(d time.Duration, f func()) timer {
if s.group != nil {
return s.group.AfterFunc(d, f)
}
return timeTimer{time.AfterFunc(d, f)}
}
func (s *Server) initialConnRecvWindowSize() int32 {
@@ -400,6 +433,10 @@ func (o *ServeConnOpts) handler() http.Handler {
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
s.serveConn(c, opts, nil)
}
func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel()
@@ -426,6 +463,9 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
pushEnabled: true,
sawClientPreface: opts.SawClientPreface,
}
if newf != nil {
newf(sc)
}
s.state.registerConn(sc)
defer s.state.unregisterConn(sc)
@@ -599,8 +639,8 @@ type serverConn struct {
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write
goAwayCode ErrCode
shutdownTimer *time.Timer // nil until used
idleTimer *time.Timer // nil if unused
shutdownTimer timer // nil until used
idleTimer timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@@ -649,12 +689,12 @@ type stream struct {
flow outflow // limits writing from Handler to client
inflow inflow // what the client is allowed to POST/etc to us
state streamState
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline *time.Timer // nil if unused
writeDeadline *time.Timer // nil if unused
closeErr error // set before cw is closed
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline timer // nil if unused
writeDeadline timer // nil if unused
closeErr error // set before cw is closed
trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer
@@ -811,8 +851,9 @@ type readFrameResult struct {
// consumer is done with the frame.
// It's run on its own goroutine.
func (sc *serverConn) readFrames() {
gate := make(gate)
gateDone := gate.Done
sc.srv.markNewGoroutine()
gate := make(chan struct{})
gateDone := func() { gate <- struct{}{} }
for {
f, err := sc.framer.ReadFrame()
select {
@@ -843,6 +884,7 @@ type frameWriteResult struct {
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
sc.srv.markNewGoroutine()
var err error
if wd == nil {
err = wr.write.writeFrame(sc)
@@ -922,13 +964,13 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
loopNum := 0
@@ -1057,10 +1099,10 @@ func (sc *serverConn) readPreface() error {
errc <- nil
}
}()
timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server?
timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server?
defer timer.Stop()
select {
case <-timer.C:
case <-timer.C():
return errPrefaceTimeout
case err := <-errc:
if err == nil {
@@ -1425,7 +1467,7 @@ func (sc *serverConn) goAway(code ErrCode) {
func (sc *serverConn) shutDownIn(d time.Duration) {
sc.serveG.check()
sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer)
}
func (sc *serverConn) resetStream(se StreamError) {
@@ -1639,7 +1681,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 {
if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if h1ServerKeepAlivesDisabled(sc.hs) {
@@ -1661,6 +1703,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
}
}
st.closeErr = err
st.cancelCtx()
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id)
}
@@ -2021,7 +2064,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
return sc.scheduleHandler(id, rw, req, handler)
@@ -2119,7 +2162,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
sc.streams[id] = st
@@ -2343,6 +2386,7 @@ func (sc *serverConn) handlerDone() {
// Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
sc.srv.markNewGoroutine()
defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true
defer func() {
@@ -2639,7 +2683,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
var date string
if _, ok := rws.snapHeader["Date"]; !ok {
// TODO(bradfitz): be faster here, like net/http? measure.
date = time.Now().UTC().Format(http.TimeFormat)
date = rws.conn.srv.now().UTC().Format(http.TimeFormat)
}
for _, v := range rws.snapHeader["Trailer"] {
@@ -2761,7 +2805,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() {
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onReadTimeout()
@@ -2777,9 +2821,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout)
} else {
st.readDeadline.Reset(deadline.Sub(time.Now()))
st.readDeadline.Reset(deadline.Sub(sc.srv.now()))
}
})
return nil
@@ -2787,7 +2831,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
@@ -2803,9 +2847,9 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout)
} else {
st.writeDeadline.Reset(deadline.Sub(time.Now()))
st.writeDeadline.Reset(deadline.Sub(sc.srv.now()))
}
})
return nil

View File

@@ -1,331 +0,0 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"context"
"sync"
"time"
)
// testSyncHooks coordinates goroutines in tests.
//
// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
// - the goroutine running RoundTrip;
// - the clientStream.doRequest goroutine, which writes the request; and
// - the clientStream.readLoop goroutine, which reads the response.
//
// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
// are blocked waiting for some condition such as reading the Request.Body or waiting for
// flow control to become available.
//
// The testSyncHooks also manage timers and synthetic time in tests.
// This permits us to, for example, start a request and cause it to time out waiting for
// response headers without resorting to time.Sleep calls.
type testSyncHooks struct {
// active/inactive act as a mutex and condition variable.
//
// - neither chan contains a value: testSyncHooks is locked.
// - active contains a value: unlocked, and at least one goroutine is not blocked
// - inactive contains a value: unlocked, and all goroutines are blocked
active chan struct{}
inactive chan struct{}
// goroutine counts
total int // total goroutines
condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
blocked []*testBlockedGoroutine // otherwise blocked
// fake time
now time.Time
timers []*fakeTimer
// Transport testing: Report various events.
newclientconn func(*ClientConn)
newstream func(*clientStream)
}
// testBlockedGoroutine is a blocked goroutine.
type testBlockedGoroutine struct {
f func() bool // blocked until f returns true
ch chan struct{} // closed when unblocked
}
func newTestSyncHooks() *testSyncHooks {
h := &testSyncHooks{
active: make(chan struct{}, 1),
inactive: make(chan struct{}, 1),
condwait: map[*sync.Cond]int{},
}
h.inactive <- struct{}{}
h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
return h
}
// lock acquires the testSyncHooks mutex.
func (h *testSyncHooks) lock() {
select {
case <-h.active:
case <-h.inactive:
}
}
// waitInactive waits for all goroutines to become inactive.
func (h *testSyncHooks) waitInactive() {
for {
<-h.inactive
if !h.unlock() {
break
}
}
}
// unlock releases the testSyncHooks mutex.
// It reports whether any goroutines are active.
func (h *testSyncHooks) unlock() (active bool) {
// Look for a blocked goroutine which can be unblocked.
blocked := h.blocked[:0]
unblocked := false
for _, b := range h.blocked {
if !unblocked && b.f() {
unblocked = true
close(b.ch)
} else {
blocked = append(blocked, b)
}
}
h.blocked = blocked
// Count goroutines blocked on condition variables.
condwait := 0
for _, count := range h.condwait {
condwait += count
}
if h.total > condwait+len(blocked) {
h.active <- struct{}{}
return true
} else {
h.inactive <- struct{}{}
return false
}
}
// goRun starts a new goroutine.
func (h *testSyncHooks) goRun(f func()) {
h.lock()
h.total++
h.unlock()
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
f()
}()
}
// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
// It waits until f returns true before proceeding.
//
// Example usage:
//
// h.blockUntil(func() bool {
// // Is the context done yet?
// select {
// case <-ctx.Done():
// default:
// return false
// }
// return true
// })
// // Wait for the context to become done.
// <-ctx.Done()
//
// The function f passed to blockUntil must be non-blocking and idempotent.
func (h *testSyncHooks) blockUntil(f func() bool) {
if f() {
return
}
ch := make(chan struct{})
h.lock()
h.blocked = append(h.blocked, &testBlockedGoroutine{
f: f,
ch: ch,
})
h.unlock()
<-ch
}
// broadcast is sync.Cond.Broadcast.
func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
h.lock()
delete(h.condwait, cond)
h.unlock()
cond.Broadcast()
}
// broadcast is sync.Cond.Wait.
func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.lock()
h.condwait[cond]++
h.unlock()
}
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
defer h.unlock()
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
timers = append(timers, t)
case t.when.IsZero():
// stopped timer
default:
t.when = time.Time{}
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
h.timers = timers
}
// A timer wraps a time.Timer, or a synthetic equivalent in tests.
// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
close(ch)
})
return &timeTimer{t, ch}
}
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
stopped := t.when.IsZero()
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

20
vendor/golang.org/x/net/http2/timer.go generated vendored Normal file
View File

@@ -0,0 +1,20 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import "time"
// A timer is a time.Timer, as an interface which can be replaced in tests.
type timer = interface {
C() <-chan time.Time
Reset(d time.Duration) bool
Stop() bool
}
// timeTimer adapts a time.Timer to the timer interface.
type timeTimer struct {
*time.Timer
}
func (t timeTimer) C() <-chan time.Time { return t.Timer.C }

View File

@@ -185,7 +185,45 @@ type Transport struct {
connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool
syncHooks *testSyncHooks
*transportTestHooks
}
// Hook points used for testing.
// Outside of tests, t.transportTestHooks is nil and these all have minimal implementations.
// Inside tests, see the testSyncHooks function docs.
type transportTestHooks struct {
newclientconn func(*ClientConn)
group synctestGroupInterface
}
func (t *Transport) markNewGoroutine() {
if t != nil && t.transportTestHooks != nil {
t.transportTestHooks.group.Join()
}
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (t *Transport) newTimer(d time.Duration) timer {
if t.transportTestHooks != nil {
return t.transportTestHooks.group.NewTimer(d)
}
return timeTimer{time.NewTimer(d)}
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (t *Transport) afterFunc(d time.Duration, f func()) timer {
if t.transportTestHooks != nil {
return t.transportTestHooks.group.AfterFunc(d, f)
}
return timeTimer{time.AfterFunc(d, f)}
}
func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if t.transportTestHooks != nil {
return t.transportTestHooks.group.ContextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
}
func (t *Transport) maxHeaderListSize() uint32 {
@@ -352,60 +390,6 @@ type ClientConn struct {
werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
syncHooks *testSyncHooks // can be nil
}
// Hook points used for testing.
// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
// Inside tests, see the testSyncHooks function docs.
// goRun starts a new goroutine.
func (cc *ClientConn) goRun(f func()) {
if cc.syncHooks != nil {
cc.syncHooks.goRun(f)
return
}
go f()
}
// condBroadcast is cc.cond.Broadcast.
func (cc *ClientConn) condBroadcast() {
if cc.syncHooks != nil {
cc.syncHooks.condBroadcast(cc.cond)
}
cc.cond.Broadcast()
}
// condWait is cc.cond.Wait.
func (cc *ClientConn) condWait() {
if cc.syncHooks != nil {
cc.syncHooks.condWait(cc.cond)
}
cc.cond.Wait()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (cc *ClientConn) newTimer(d time.Duration) timer {
if cc.syncHooks != nil {
return cc.syncHooks.newTimer(d)
}
return newTimeTimer(d)
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
if cc.syncHooks != nil {
return cc.syncHooks.afterFunc(d, f)
}
return newTimeAfterFunc(d, f)
}
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if cc.syncHooks != nil {
return cc.syncHooks.contextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
}
// clientStream is the state for a single HTTP/2 stream. One of these
@@ -487,7 +471,7 @@ func (cs *clientStream) abortStreamLocked(err error) {
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
cs.cc.condBroadcast()
cs.cc.cond.Broadcast()
}
}
@@ -497,7 +481,7 @@ func (cs *clientStream) abortRequestBodyWrite() {
defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
cc.condBroadcast()
cc.cond.Broadcast()
}
}
@@ -507,10 +491,11 @@ func (cs *clientStream) closeReqBodyLocked() {
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
cs.cc.goRun(func() {
go func() {
cs.cc.t.markNewGoroutine()
cs.reqBody.Close()
close(reqBodyClosed)
})
}()
}
type stickyErrWriter struct {
@@ -626,21 +611,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
var tm timer
if t.syncHooks != nil {
tm = t.syncHooks.newTimer(d)
t.syncHooks.blockUntil(func() bool {
select {
case <-tm.C():
case <-req.Context().Done():
default:
return false
}
return true
})
} else {
tm = newTimeTimer(d)
}
tm := t.newTimer(d)
select {
case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
@@ -725,8 +696,8 @@ func canRetryError(err error) bool {
}
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
if t.syncHooks != nil {
return t.newClientConn(nil, singleUse, t.syncHooks)
if t.transportTestHooks != nil {
return t.newClientConn(nil, singleUse)
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
@@ -736,7 +707,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil {
return nil, err
}
return t.newClientConn(tconn, singleUse, nil)
return t.newClientConn(tconn, singleUse)
}
func (t *Transport) newTLSConfig(host string) *tls.Config {
@@ -802,10 +773,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 {
}
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives(), nil)
return t.newClientConn(c, t.disableKeepAlives())
}
func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) {
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
cc := &ClientConn{
t: t,
tconn: c,
@@ -820,16 +791,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
syncHooks: hooks,
}
if hooks != nil {
hooks.newclientconn(cc)
if t.transportTestHooks != nil {
t.markNewGoroutine()
t.transportTestHooks.newclientconn(cc)
c = cc.tconn
}
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout)
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
}
@@ -893,7 +860,13 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo
return nil, cc.werr
}
cc.goRun(cc.readLoop)
// Start the idle timer after the connection is fully initialized.
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout)
}
go cc.readLoop()
return cc, nil
}
@@ -901,7 +874,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@@ -1144,7 +1117,8 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
// Wait for all in-flight streams to complete or connection to close
done := make(chan struct{})
cancelled := false // guarded by cc.mu
cc.goRun(func() {
go func() {
cc.t.markNewGoroutine()
cc.mu.Lock()
defer cc.mu.Unlock()
for {
@@ -1156,9 +1130,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
if cancelled {
break
}
cc.condWait()
cc.cond.Wait()
}
})
}()
shutdownEnterWaitStateHook()
select {
case <-done:
@@ -1168,7 +1142,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
cc.mu.Lock()
// Free the goroutine above
cancelled = true
cc.condBroadcast()
cc.cond.Broadcast()
cc.mu.Unlock()
return ctx.Err()
}
@@ -1206,7 +1180,7 @@ func (cc *ClientConn) closeForError(err error) {
for _, cs := range cc.streams {
cs.abortStreamLocked(err)
}
cc.condBroadcast()
cc.cond.Broadcast()
cc.mu.Unlock()
cc.closeConn()
}
@@ -1321,23 +1295,30 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
cc.goRun(func() {
cs.doRequest(req)
})
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
!cs.isHead {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
//
// Note that we don't request this for HEAD requests,
// due to a bug in nginx:
// http://trac.nginx.org/nginx/ticket/358
// https://golang.org/issue/5522
//
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
cs.requestedGzip = true
}
go cs.doRequest(req, streamf)
waitDone := func() error {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.donec:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.donec:
return nil
@@ -1398,24 +1379,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
return err
}
if streamf != nil {
streamf(cs)
}
for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.respHeaderRecv:
return handleResponseHeaders()
@@ -1445,8 +1409,9 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
// doRequest runs for the duration of the request lifetime.
//
// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
func (cs *clientStream) doRequest(req *http.Request) {
err := cs.writeRequest(req)
func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)) {
cs.cc.t.markNewGoroutine()
err := cs.writeRequest(req, streamf)
cs.cleanupWriteRequest(err)
}
@@ -1457,7 +1422,7 @@ func (cs *clientStream) doRequest(req *http.Request) {
//
// It returns non-nil if the request ends otherwise.
// If the returned error is StreamError, the error Code may be used in resetting the stream.
func (cs *clientStream) writeRequest(req *http.Request) (err error) {
func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStream)) (err error) {
cc := cs.cc
ctx := cs.ctx
@@ -1471,21 +1436,6 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
var newStreamHook func(*clientStream)
if cc.syncHooks != nil {
newStreamHook = cc.syncHooks.newstream
cc.syncHooks.blockUntil(func() bool {
select {
case cc.reqHeaderMu <- struct{}{}:
<-cc.reqHeaderMu
case <-cs.reqCancel:
case <-ctx.Done():
default:
return false
}
return true
})
}
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@@ -1510,28 +1460,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
}
cc.mu.Unlock()
if newStreamHook != nil {
newStreamHook(cs)
}
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
!cs.isHead {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
//
// Note that we don't request this for HEAD requests,
// due to a bug in nginx:
// http://trac.nginx.org/nginx/ticket/358
// https://golang.org/issue/5522
//
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
cs.requestedGzip = true
if streamf != nil {
streamf(cs)
}
continueTimeout := cc.t.expectContinueTimeout()
@@ -1594,7 +1524,7 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
timer := cc.newTimer(d)
timer := cc.t.newTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv
@@ -1603,21 +1533,6 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
// or until the request is aborted (via context, error, or otherwise),
// whichever comes first.
for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.peerClosed:
case <-respHeaderTimer:
case <-respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.peerClosed:
return nil
@@ -1766,7 +1681,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
return nil
}
cc.pendingRequests++
cc.condWait()
cc.cond.Wait()
cc.pendingRequests--
select {
case <-cs.abort:
@@ -2028,7 +1943,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
cs.flow.take(take)
return take, nil
}
cc.condWait()
cc.cond.Wait()
}
}
@@ -2311,7 +2226,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
cc.condBroadcast()
cc.cond.Broadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
@@ -2333,6 +2248,7 @@ type clientConnReadLoop struct {
// readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *ClientConn) readLoop() {
cc.t.markNewGoroutine()
rl := &clientConnReadLoop{cc: cc}
defer rl.cleanup()
cc.readerErr = rl.run()
@@ -2399,7 +2315,7 @@ func (rl *clientConnReadLoop) cleanup() {
cs.abortStreamLocked(err)
}
}
cc.condBroadcast()
cc.cond.Broadcast()
cc.mu.Unlock()
}
@@ -2436,7 +2352,7 @@ func (rl *clientConnReadLoop) run() error {
readIdleTimeout := cc.t.ReadIdleTimeout
var t timer
if readIdleTimeout != 0 {
t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
@@ -3034,7 +2950,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
for _, cs := range cc.streams {
cs.flow.add(delta)
}
cc.condBroadcast()
cc.cond.Broadcast()
cc.initialWindowSize = s.Val
case SettingHeaderTableSize:
@@ -3089,7 +3005,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
return ConnectionError(ErrCodeFlowControl)
}
cc.condBroadcast()
cc.cond.Broadcast()
return nil
}
@@ -3133,7 +3049,8 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
}
var pingError error
errc := make(chan struct{})
cc.goRun(func() {
go func() {
cc.t.markNewGoroutine()
cc.wmu.Lock()
defer cc.wmu.Unlock()
if pingError = cc.fr.WritePing(false, p); pingError != nil {
@@ -3144,20 +3061,7 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
close(errc)
return
}
})
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-c:
case <-errc:
case <-ctx.Done():
case <-cc.readerDone:
default:
return false
}
return true
})
}
}()
select {
case <-c:
return nil

View File

@@ -443,8 +443,8 @@ func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, max
}
func (ws *priorityWriteScheduler) removeNode(n *priorityNode) {
for k := n.kids; k != nil; k = k.next {
k.setParent(n.parent)
for n.kids != nil {
n.kids.setParent(n.parent)
}
n.setParent(nil)
delete(ws.nodes, n.id)

View File

@@ -8,7 +8,6 @@ import (
"bytes"
"encoding/binary"
"io"
"io/ioutil"
"math/rand"
"net"
"runtime"
@@ -173,7 +172,7 @@ func testRacyRead(t *testing.T, c1, c2 net.Conn) {
// testRacyWrite tests that it is safe to mutate the input Write buffer
// immediately after cancelation has occurred.
func testRacyWrite(t *testing.T, c1, c2 net.Conn) {
go chunkedCopy(ioutil.Discard, c2)
go chunkedCopy(io.Discard, c2)
var wg sync.WaitGroup
defer wg.Wait()
@@ -200,7 +199,7 @@ func testRacyWrite(t *testing.T, c1, c2 net.Conn) {
// testReadTimeout tests that Read timeouts do not affect Write.
func testReadTimeout(t *testing.T, c1, c2 net.Conn) {
go chunkedCopy(ioutil.Discard, c2)
go chunkedCopy(io.Discard, c2)
c1.SetReadDeadline(aLongTimeAgo)
_, err := c1.Read(make([]byte, 1024))

View File

@@ -8,7 +8,6 @@ package nettest
import (
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"os/exec"
@@ -226,7 +225,7 @@ func LocalPath() (string, error) {
if runtime.GOOS == "darwin" {
dir = "/tmp"
}
f, err := ioutil.TempFile(dir, "go-nettest")
f, err := os.CreateTemp(dir, "go-nettest")
if err != nil {
return "", err
}

View File

@@ -137,9 +137,7 @@ func (p *PerHost) AddNetwork(net *net.IPNet) {
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
zone = strings.TrimSuffix(zone, ".")
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
@@ -148,8 +146,6 @@ func (p *PerHost) AddZone(zone string) {
// AddHost specifies a host name that will use the bypass proxy.
func (p *PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
host = strings.TrimSuffix(host, ".")
p.bypassHosts = append(p.bypassHosts, host)
}

View File

@@ -16,7 +16,6 @@ import (
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
@@ -279,7 +278,7 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, er
}
}
if header := frame.HeaderReader(); header != nil {
io.Copy(ioutil.Discard, header)
io.Copy(io.Discard, header)
}
switch frame.PayloadType() {
case ContinuationFrame:
@@ -294,7 +293,7 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, er
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err
}
io.Copy(ioutil.Discard, frame)
io.Copy(io.Discard, frame)
if frame.PayloadType() == PingFrame {
if _, err := handler.WritePong(b[:n]); err != nil {
return nil, err

View File

@@ -17,7 +17,6 @@ import (
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
@@ -208,7 +207,7 @@ again:
n, err = ws.frameReader.Read(msg)
if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
io.Copy(ioutil.Discard, trailer)
io.Copy(io.Discard, trailer)
}
ws.frameReader = nil
goto again
@@ -330,7 +329,7 @@ func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
if ws.frameReader != nil {
_, err = io.Copy(ioutil.Discard, ws.frameReader)
_, err = io.Copy(io.Discard, ws.frameReader)
if err != nil {
return err
}
@@ -362,7 +361,7 @@ again:
return ErrFrameTooLarge
}
payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame)
data, err := io.ReadAll(frame)
if err != nil {
return err
}

View File

@@ -263,6 +263,7 @@ struct ltchars {
#include <linux/sched.h>
#include <linux/seccomp.h>
#include <linux/serial.h>
#include <linux/sock_diag.h>
#include <linux/sockios.h>
#include <linux/taskstats.h>
#include <linux/tipc.h>
@@ -549,6 +550,7 @@ ccflags="$@"
$2 !~ "NLA_TYPE_MASK" &&
$2 !~ /^RTC_VL_(ACCURACY|BACKUP|DATA)/ &&
$2 ~ /^(NETLINK|NLM|NLMSG|NLA|IFA|IFAN|RT|RTC|RTCF|RTN|RTPROT|RTNH|ARPHRD|ETH_P|NETNSA)_/ ||
$2 ~ /^SOCK_|SK_DIAG_|SKNLGRP_$/ ||
$2 ~ /^FIORDCHK$/ ||
$2 ~ /^SIOC/ ||
$2 ~ /^TIOC/ ||

View File

@@ -502,6 +502,7 @@ const (
BPF_IMM = 0x0
BPF_IND = 0x40
BPF_JA = 0x0
BPF_JCOND = 0xe0
BPF_JEQ = 0x10
BPF_JGE = 0x30
BPF_JGT = 0x20
@@ -657,6 +658,9 @@ const (
CAN_NPROTO = 0x8
CAN_RAW = 0x1
CAN_RAW_FILTER_MAX = 0x200
CAN_RAW_XL_VCID_RX_FILTER = 0x4
CAN_RAW_XL_VCID_TX_PASS = 0x2
CAN_RAW_XL_VCID_TX_SET = 0x1
CAN_RTR_FLAG = 0x40000000
CAN_SFF_ID_BITS = 0xb
CAN_SFF_MASK = 0x7ff
@@ -1339,6 +1343,7 @@ const (
F_OFD_SETLK = 0x25
F_OFD_SETLKW = 0x26
F_OK = 0x0
F_SEAL_EXEC = 0x20
F_SEAL_FUTURE_WRITE = 0x10
F_SEAL_GROW = 0x4
F_SEAL_SEAL = 0x1
@@ -1627,6 +1632,7 @@ const (
IP_FREEBIND = 0xf
IP_HDRINCL = 0x3
IP_IPSEC_POLICY = 0x10
IP_LOCAL_PORT_RANGE = 0x33
IP_MAXPACKET = 0xffff
IP_MAX_MEMBERSHIPS = 0x14
IP_MF = 0x2000
@@ -1653,6 +1659,7 @@ const (
IP_PMTUDISC_OMIT = 0x5
IP_PMTUDISC_PROBE = 0x3
IP_PMTUDISC_WANT = 0x1
IP_PROTOCOL = 0x34
IP_RECVERR = 0xb
IP_RECVERR_RFC4884 = 0x1a
IP_RECVFRAGSIZE = 0x19
@@ -2169,7 +2176,7 @@ const (
NFT_SECMARK_CTX_MAXLEN = 0x100
NFT_SET_MAXNAMELEN = 0x100
NFT_SOCKET_MAX = 0x3
NFT_TABLE_F_MASK = 0x3
NFT_TABLE_F_MASK = 0x7
NFT_TABLE_MAXNAMELEN = 0x100
NFT_TRACETYPE_MAX = 0x3
NFT_TUNNEL_F_MASK = 0x7
@@ -2403,6 +2410,7 @@ const (
PERF_RECORD_MISC_USER = 0x2
PERF_SAMPLE_BRANCH_PLM_ALL = 0x7
PERF_SAMPLE_WEIGHT_TYPE = 0x1004000
PID_FS_MAGIC = 0x50494446
PIPEFS_MAGIC = 0x50495045
PPPIOCGNPMODE = 0xc008744c
PPPIOCNEWUNIT = 0xc004743e
@@ -2896,8 +2904,9 @@ const (
RWF_APPEND = 0x10
RWF_DSYNC = 0x2
RWF_HIPRI = 0x1
RWF_NOAPPEND = 0x20
RWF_NOWAIT = 0x8
RWF_SUPPORTED = 0x1f
RWF_SUPPORTED = 0x3f
RWF_SYNC = 0x4
RWF_WRITE_LIFE_NOT_SET = 0x0
SCHED_BATCH = 0x3
@@ -2918,7 +2927,9 @@ const (
SCHED_RESET_ON_FORK = 0x40000000
SCHED_RR = 0x2
SCM_CREDENTIALS = 0x2
SCM_PIDFD = 0x4
SCM_RIGHTS = 0x1
SCM_SECURITY = 0x3
SCM_TIMESTAMP = 0x1d
SC_LOG_FLUSH = 0x100000
SECCOMP_ADDFD_FLAG_SEND = 0x2
@@ -3051,6 +3062,8 @@ const (
SIOCSMIIREG = 0x8949
SIOCSRARP = 0x8962
SIOCWANDEV = 0x894a
SK_DIAG_BPF_STORAGE_MAX = 0x3
SK_DIAG_BPF_STORAGE_REQ_MAX = 0x1
SMACK_MAGIC = 0x43415d53
SMART_AUTOSAVE = 0xd2
SMART_AUTO_OFFLINE = 0xdb
@@ -3071,6 +3084,8 @@ const (
SOCKFS_MAGIC = 0x534f434b
SOCK_BUF_LOCK_MASK = 0x3
SOCK_DCCP = 0x6
SOCK_DESTROY = 0x15
SOCK_DIAG_BY_FAMILY = 0x14
SOCK_IOC_TYPE = 0x89
SOCK_PACKET = 0xa
SOCK_RAW = 0x3
@@ -3260,6 +3275,7 @@ const (
TCP_MAX_WINSHIFT = 0xe
TCP_MD5SIG = 0xe
TCP_MD5SIG_EXT = 0x20
TCP_MD5SIG_FLAG_IFINDEX = 0x2
TCP_MD5SIG_FLAG_PREFIX = 0x1
TCP_MD5SIG_MAXKEYLEN = 0x50
TCP_MSS = 0x200

View File

@@ -118,6 +118,7 @@ const (
IXOFF = 0x1000
IXON = 0x400
MAP_32BIT = 0x40
MAP_ABOVE4G = 0x80
MAP_ANON = 0x20
MAP_ANONYMOUS = 0x20
MAP_DENYWRITE = 0x800

View File

@@ -118,6 +118,7 @@ const (
IXOFF = 0x1000
IXON = 0x400
MAP_32BIT = 0x40
MAP_ABOVE4G = 0x80
MAP_ANON = 0x20
MAP_ANONYMOUS = 0x20
MAP_DENYWRITE = 0x800

View File

@@ -87,6 +87,7 @@ const (
FICLONE = 0x40049409
FICLONERANGE = 0x4020940d
FLUSHO = 0x1000
FPMR_MAGIC = 0x46504d52
FPSIMD_MAGIC = 0x46508001
FS_IOC_ENABLE_VERITY = 0x40806685
FS_IOC_GETFLAGS = 0x80086601

View File

@@ -4605,7 +4605,7 @@ const (
NL80211_ATTR_MAC_HINT = 0xc8
NL80211_ATTR_MAC_MASK = 0xd7
NL80211_ATTR_MAX_AP_ASSOC_STA = 0xca
NL80211_ATTR_MAX = 0x149
NL80211_ATTR_MAX = 0x14a
NL80211_ATTR_MAX_CRIT_PROT_DURATION = 0xb4
NL80211_ATTR_MAX_CSA_COUNTERS = 0xce
NL80211_ATTR_MAX_MATCH_SETS = 0x85
@@ -5209,7 +5209,7 @@ const (
NL80211_FREQUENCY_ATTR_GO_CONCURRENT = 0xf
NL80211_FREQUENCY_ATTR_INDOOR_ONLY = 0xe
NL80211_FREQUENCY_ATTR_IR_CONCURRENT = 0xf
NL80211_FREQUENCY_ATTR_MAX = 0x1f
NL80211_FREQUENCY_ATTR_MAX = 0x20
NL80211_FREQUENCY_ATTR_MAX_TX_POWER = 0x6
NL80211_FREQUENCY_ATTR_NO_10MHZ = 0x11
NL80211_FREQUENCY_ATTR_NO_160MHZ = 0xc
@@ -5703,7 +5703,7 @@ const (
NL80211_STA_FLAG_ASSOCIATED = 0x7
NL80211_STA_FLAG_AUTHENTICATED = 0x5
NL80211_STA_FLAG_AUTHORIZED = 0x1
NL80211_STA_FLAG_MAX = 0x7
NL80211_STA_FLAG_MAX = 0x8
NL80211_STA_FLAG_MAX_OLD_API = 0x6
NL80211_STA_FLAG_MFP = 0x4
NL80211_STA_FLAG_SHORT_PREAMBLE = 0x2
@@ -6001,3 +6001,34 @@ type CachestatRange struct {
Off uint64
Len uint64
}
const (
SK_MEMINFO_RMEM_ALLOC = 0x0
SK_MEMINFO_RCVBUF = 0x1
SK_MEMINFO_WMEM_ALLOC = 0x2
SK_MEMINFO_SNDBUF = 0x3
SK_MEMINFO_FWD_ALLOC = 0x4
SK_MEMINFO_WMEM_QUEUED = 0x5
SK_MEMINFO_OPTMEM = 0x6
SK_MEMINFO_BACKLOG = 0x7
SK_MEMINFO_DROPS = 0x8
SK_MEMINFO_VARS = 0x9
SKNLGRP_NONE = 0x0
SKNLGRP_INET_TCP_DESTROY = 0x1
SKNLGRP_INET_UDP_DESTROY = 0x2
SKNLGRP_INET6_TCP_DESTROY = 0x3
SKNLGRP_INET6_UDP_DESTROY = 0x4
SK_DIAG_BPF_STORAGE_REQ_NONE = 0x0
SK_DIAG_BPF_STORAGE_REQ_MAP_FD = 0x1
SK_DIAG_BPF_STORAGE_REP_NONE = 0x0
SK_DIAG_BPF_STORAGE = 0x1
SK_DIAG_BPF_STORAGE_NONE = 0x0
SK_DIAG_BPF_STORAGE_PAD = 0x1
SK_DIAG_BPF_STORAGE_MAP_ID = 0x2
SK_DIAG_BPF_STORAGE_MAP_VALUE = 0x3
)
type SockDiagReq struct {
Family uint8
Protocol uint8
}

View File

@@ -68,6 +68,7 @@ type UserInfo10 struct {
//sys NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **byte) (neterr error) = netapi32.NetUserGetInfo
//sys NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) = netapi32.NetGetJoinInformation
//sys NetApiBufferFree(buf *byte) (neterr error) = netapi32.NetApiBufferFree
//sys NetUserEnum(serverName *uint16, level uint32, filter uint32, buf **byte, prefMaxLen uint32, entriesRead *uint32, totalEntries *uint32, resumeHandle *uint32) (neterr error) = netapi32.NetUserEnum
const (
// do not reorder

View File

@@ -401,6 +401,7 @@ var (
procTransmitFile = modmswsock.NewProc("TransmitFile")
procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree")
procNetGetJoinInformation = modnetapi32.NewProc("NetGetJoinInformation")
procNetUserEnum = modnetapi32.NewProc("NetUserEnum")
procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo")
procNtCreateFile = modntdll.NewProc("NtCreateFile")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
@@ -3486,6 +3487,14 @@ func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (nete
return
}
func NetUserEnum(serverName *uint16, level uint32, filter uint32, buf **byte, prefMaxLen uint32, entriesRead *uint32, totalEntries *uint32, resumeHandle *uint32) (neterr error) {
r0, _, _ := syscall.Syscall9(procNetUserEnum.Addr(), 8, uintptr(unsafe.Pointer(serverName)), uintptr(level), uintptr(filter), uintptr(unsafe.Pointer(buf)), uintptr(prefMaxLen), uintptr(unsafe.Pointer(entriesRead)), uintptr(unsafe.Pointer(totalEntries)), uintptr(unsafe.Pointer(resumeHandle)), 0)
if r0 != 0 {
neterr = syscall.Errno(r0)
}
return
}
func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **byte) (neterr error) {
r0, _, _ := syscall.Syscall6(procNetUserGetInfo.Addr(), 4, uintptr(unsafe.Pointer(serverName)), uintptr(unsafe.Pointer(userName)), uintptr(level), uintptr(unsafe.Pointer(buf)), 0, 0)
if r0 != 0 {

View File

@@ -13,6 +13,7 @@ import (
"golang.org/x/tools/internal/gocommand"
)
// TODO(adonovan): move back into go/packages.
func GetSizesForArgsGolist(ctx context.Context, inv gocommand.Invocation, gocmdRunner *gocommand.Runner) (string, string, error) {
inv.Verb = "list"
inv.Args = []string{"-f", "{{context.GOARCH}} {{context.Compiler}}", "--", "unsafe"}

View File

@@ -198,14 +198,6 @@ Instead, ssadump no longer requests the runtime package,
but seeks it among the dependencies of the user-specified packages,
and emits an error if it is not found.
Overlays: The Overlay field in the Config allows providing alternate contents
for Go source files, by providing a mapping from file path to contents.
go/packages will pull in new imports added in overlay files when go/packages
is run in LoadImports mode or greater.
Overlay support for the go list driver isn't complete yet: if the file doesn't
exist on disk, it will only be recognized in an overlay if it is a non-test file
and the package would be reported even without the overlay.
Questions & Tasks
- Add GOARCH/GOOS?

View File

@@ -34,8 +34,8 @@ type DriverRequest struct {
// Tests specifies whether the patterns should also return test packages.
Tests bool `json:"tests"`
// Overlay maps file paths (relative to the driver's working directory) to the byte contents
// of overlay files.
// Overlay maps file paths (relative to the driver's working directory)
// to the contents of overlay files (see Config.Overlay).
Overlay map[string][]byte `json:"overlay"`
}
@@ -119,7 +119,19 @@ func findExternalDriver(cfg *Config) driver {
stderr := new(bytes.Buffer)
cmd := exec.CommandContext(cfg.Context, tool, words...)
cmd.Dir = cfg.Dir
cmd.Env = cfg.Env
// The cwd gets resolved to the real path. On Darwin, where
// /tmp is a symlink, this breaks anything that expects the
// working directory to keep the original path, including the
// go command when dealing with modules.
//
// os.Getwd stdlib has a special feature where if the
// cwd and the PWD are the same node then it trusts
// the PWD, so by setting it in the env for the child
// process we fix up all the paths returned by the go
// command.
//
// (See similar trick in Invocation.run in ../../internal/gocommand/invoke.go)
cmd.Env = append(slicesClip(cfg.Env), "PWD="+cfg.Dir)
cmd.Stdin = bytes.NewReader(req)
cmd.Stdout = buf
cmd.Stderr = stderr
@@ -138,3 +150,7 @@ func findExternalDriver(cfg *Config) driver {
return &response, nil
}
}
// slicesClip removes unused capacity from the slice, returning s[:len(s):len(s)].
// TODO(adonovan): use go1.21 slices.Clip.
func slicesClip[S ~[]E, E any](s S) S { return s[:len(s):len(s)] }

View File

@@ -841,6 +841,7 @@ func (state *golistState) cfgInvocation() gocommand.Invocation {
Env: cfg.Env,
Logf: cfg.Logf,
WorkingDir: cfg.Dir,
Overlay: cfg.goListOverlayFile,
}
}
@@ -849,26 +850,6 @@ func (state *golistState) invokeGo(verb string, args ...string) (*bytes.Buffer,
cfg := state.cfg
inv := state.cfgInvocation()
// For Go versions 1.16 and above, `go list` accepts overlays directly via
// the -overlay flag. Set it, if it's available.
//
// The check for "list" is not necessarily required, but we should avoid
// getting the go version if possible.
if verb == "list" {
goVersion, err := state.getGoVersion()
if err != nil {
return nil, err
}
if goVersion >= 16 {
filename, cleanup, err := state.writeOverlays()
if err != nil {
return nil, err
}
defer cleanup()
inv.Overlay = filename
}
}
inv.Verb = verb
inv.Args = args
gocmdRunner := cfg.gocmdRunner
@@ -1015,67 +996,6 @@ func (state *golistState) invokeGo(verb string, args ...string) (*bytes.Buffer,
return stdout, nil
}
// OverlayJSON is the format overlay files are expected to be in.
// The Replace map maps from overlaid paths to replacement paths:
// the Go command will forward all reads trying to open
// each overlaid path to its replacement path, or consider the overlaid
// path not to exist if the replacement path is empty.
//
// From golang/go#39958.
type OverlayJSON struct {
Replace map[string]string `json:"replace,omitempty"`
}
// writeOverlays writes out files for go list's -overlay flag, as described
// above.
func (state *golistState) writeOverlays() (filename string, cleanup func(), err error) {
// Do nothing if there are no overlays in the config.
if len(state.cfg.Overlay) == 0 {
return "", func() {}, nil
}
dir, err := os.MkdirTemp("", "gopackages-*")
if err != nil {
return "", nil, err
}
// The caller must clean up this directory, unless this function returns an
// error.
cleanup = func() {
os.RemoveAll(dir)
}
defer func() {
if err != nil {
cleanup()
}
}()
overlays := map[string]string{}
for k, v := range state.cfg.Overlay {
// Create a unique filename for the overlaid files, to avoid
// creating nested directories.
noSeparator := strings.Join(strings.Split(filepath.ToSlash(k), "/"), "")
f, err := os.CreateTemp(dir, fmt.Sprintf("*-%s", noSeparator))
if err != nil {
return "", func() {}, err
}
if _, err := f.Write(v); err != nil {
return "", func() {}, err
}
if err := f.Close(); err != nil {
return "", func() {}, err
}
overlays[k] = f.Name()
}
b, err := json.Marshal(OverlayJSON{Replace: overlays})
if err != nil {
return "", func() {}, err
}
// Write out the overlay file that contains the filepath mappings.
filename = filepath.Join(dir, "overlay.json")
if err := os.WriteFile(filename, b, 0665); err != nil {
return "", func() {}, err
}
return filename, cleanup, nil
}
func containsGoFile(s []string) bool {
for _, f := range s {
if strings.HasSuffix(f, ".go") {

View File

@@ -37,10 +37,20 @@ import (
// A LoadMode controls the amount of detail to return when loading.
// The bits below can be combined to specify which fields should be
// filled in the result packages.
//
// The zero value is a special case, equivalent to combining
// the NeedName, NeedFiles, and NeedCompiledGoFiles bits.
//
// ID and Errors (if present) will always be filled.
// Load may return more information than requested.
// [Load] may return more information than requested.
//
// Unfortunately there are a number of open bugs related to
// interactions among the LoadMode bits:
// - https://github.com/golang/go/issues/48226
// - https://github.com/golang/go/issues/56633
// - https://github.com/golang/go/issues/56677
// - https://github.com/golang/go/issues/58726
// - https://github.com/golang/go/issues/63517
type LoadMode int
const (
@@ -123,7 +133,14 @@ const (
// A Config specifies details about how packages should be loaded.
// The zero value is a valid configuration.
//
// Calls to Load do not modify this struct.
//
// TODO(adonovan): #67702: this is currently false: in fact,
// calls to [Load] do not modify the public fields of this struct, but
// may modify hidden fields, so concurrent calls to [Load] must not
// use the same Config. But perhaps we should reestablish the
// documented invariant.
type Config struct {
// Mode controls the level of information returned for each package.
Mode LoadMode
@@ -199,13 +216,23 @@ type Config struct {
// setting Tests may have no effect.
Tests bool
// Overlay provides a mapping of absolute file paths to file contents.
// If the file with the given path already exists, the parser will use the
// alternative file contents provided by the map.
// Overlay is a mapping from absolute file paths to file contents.
//
// Overlays provide incomplete support for when a given file doesn't
// already exist on disk. See the package doc above for more details.
// For each map entry, [Load] uses the alternative file
// contents provided by the overlay mapping instead of reading
// from the file system. This mechanism can be used to enable
// editor-integrated tools to correctly analyze the contents
// of modified but unsaved buffers, for example.
//
// The overlay mapping is passed to the build system's driver
// (see "The driver protocol") so that it too can report
// consistent package metadata about unsaved files. However,
// drivers may vary in their level of support for overlays.
Overlay map[string][]byte
// goListOverlayFile is the JSON file that encodes the Overlay
// mapping, used by 'go list -overlay=...'
goListOverlayFile string
}
// Load loads and returns the Go packages named by the given patterns.
@@ -213,6 +240,20 @@ type Config struct {
// Config specifies loading options;
// nil behaves the same as an empty Config.
//
// The [Config.Mode] field is a set of bits that determine what kinds
// of information should be computed and returned. Modes that require
// more information tend to be slower. See [LoadMode] for details
// and important caveats. Its zero value is equivalent to
// NeedName | NeedFiles | NeedCompiledGoFiles.
//
// Each call to Load returns a new set of [Package] instances.
// The Packages and their Imports form a directed acyclic graph.
//
// If the [NeedTypes] mode flag was set, each call to Load uses a new
// [types.Importer], so [types.Object] and [types.Type] values from
// different calls to Load must not be mixed as they will have
// inconsistent notions of type identity.
//
// If any of the patterns was invalid as defined by the
// underlying build system, Load returns an error.
// It may return an empty list of packages without an error,
@@ -286,6 +327,17 @@ func defaultDriver(cfg *Config, patterns ...string) (*DriverResponse, bool, erro
// (fall through)
}
// go list fallback
//
// Write overlays once, as there are many calls
// to 'go list' (one per chunk plus others too).
overlay, cleanupOverlay, err := gocommand.WriteOverlays(cfg.Overlay)
if err != nil {
return nil, false, err
}
defer cleanupOverlay()
cfg.goListOverlayFile = overlay
response, err := callDriverOnChunks(goListDriver, cfg, chunks)
if err != nil {
return nil, false, err
@@ -365,6 +417,9 @@ func mergeResponses(responses ...*DriverResponse) *DriverResponse {
}
// A Package describes a loaded Go package.
//
// It also defines part of the JSON schema of [DriverResponse].
// See the package documentation for an overview.
type Package struct {
// ID is a unique identifier for a package,
// in a syntax provided by the underlying build system.
@@ -423,6 +478,13 @@ type Package struct {
// to corresponding loaded Packages.
Imports map[string]*Package
// Module is the module information for the package if it exists.
//
// Note: it may be missing for std and cmd; see Go issue #65816.
Module *Module
// -- The following fields are not part of the driver JSON schema. --
// Types provides type information for the package.
// The NeedTypes LoadMode bit sets this field for packages matching the
// patterns; type information for dependencies may be missing or incomplete,
@@ -431,15 +493,15 @@ type Package struct {
// Each call to [Load] returns a consistent set of type
// symbols, as defined by the comment at [types.Identical].
// Avoid mixing type information from two or more calls to [Load].
Types *types.Package
Types *types.Package `json:"-"`
// Fset provides position information for Types, TypesInfo, and Syntax.
// It is set only when Types is set.
Fset *token.FileSet
Fset *token.FileSet `json:"-"`
// IllTyped indicates whether the package or any dependency contains errors.
// It is set only when Types is set.
IllTyped bool
IllTyped bool `json:"-"`
// Syntax is the package's syntax trees, for the files listed in CompiledGoFiles.
//
@@ -449,26 +511,28 @@ type Package struct {
//
// Syntax is kept in the same order as CompiledGoFiles, with the caveat that nils are
// removed. If parsing returned nil, Syntax may be shorter than CompiledGoFiles.
Syntax []*ast.File
Syntax []*ast.File `json:"-"`
// TypesInfo provides type information about the package's syntax trees.
// It is set only when Syntax is set.
TypesInfo *types.Info
TypesInfo *types.Info `json:"-"`
// TypesSizes provides the effective size function for types in TypesInfo.
TypesSizes types.Sizes
TypesSizes types.Sizes `json:"-"`
// -- internal --
// forTest is the package under test, if any.
forTest string
// depsErrors is the DepsErrors field from the go list response, if any.
depsErrors []*packagesinternal.PackageError
// module is the module information for the package if it exists.
Module *Module
}
// Module provides module information for a package.
//
// It also defines part of the JSON schema of [DriverResponse].
// See the package documentation for an overview.
type Module struct {
Path string // module path
Version string // module version
@@ -601,6 +665,7 @@ func (p *Package) UnmarshalJSON(b []byte) error {
OtherFiles: flat.OtherFiles,
EmbedFiles: flat.EmbedFiles,
EmbedPatterns: flat.EmbedPatterns,
IgnoredFiles: flat.IgnoredFiles,
ExportFile: flat.ExportFile,
}
if len(flat.Imports) > 0 {

View File

@@ -8,12 +8,14 @@ package gocommand
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"os/exec"
"path/filepath"
"reflect"
"regexp"
"runtime"
@@ -167,7 +169,9 @@ type Invocation struct {
// TODO(rfindley): remove, in favor of Args.
ModFile string
// If Overlay is set, the go command is invoked with -overlay=Overlay.
// Overlay is the name of the JSON overlay file that describes
// unsaved editor buffers; see [WriteOverlays].
// If set, the go command is invoked with -overlay=Overlay.
// TODO(rfindley): remove, in favor of Args.
Overlay string
@@ -255,12 +259,15 @@ func (i *Invocation) run(ctx context.Context, stdout, stderr io.Writer) error {
waitDelay.Set(reflect.ValueOf(30 * time.Second))
}
// On darwin the cwd gets resolved to the real path, which breaks anything that
// expects the working directory to keep the original path, including the
// The cwd gets resolved to the real path. On Darwin, where
// /tmp is a symlink, this breaks anything that expects the
// working directory to keep the original path, including the
// go command when dealing with modules.
// The Go stdlib has a special feature where if the cwd and the PWD are the
// same node then it trusts the PWD, so by setting it in the env for the child
// process we fix up all the paths returned by the go command.
//
// os.Getwd has a special feature where if the cwd and the PWD
// are the same node then it trusts the PWD, so by setting it
// in the env for the child process we fix up all the paths
// returned by the go command.
if !i.CleanEnv {
cmd.Env = os.Environ()
}
@@ -351,6 +358,7 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) (err error) {
}
}
startTime := time.Now()
err = cmd.Start()
if stdoutW != nil {
// The child process has inherited the pipe file,
@@ -377,7 +385,7 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) (err error) {
case err := <-resChan:
return err
case <-timer.C:
HandleHangingGoCommand(cmd.Process)
HandleHangingGoCommand(startTime, cmd)
case <-ctx.Done():
}
} else {
@@ -411,7 +419,7 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) (err error) {
return <-resChan
}
func HandleHangingGoCommand(proc *os.Process) {
func HandleHangingGoCommand(start time.Time, cmd *exec.Cmd) {
switch runtime.GOOS {
case "linux", "darwin", "freebsd", "netbsd":
fmt.Fprintln(os.Stderr, `DETECTED A HANGING GO COMMAND
@@ -444,7 +452,7 @@ See golang/go#54461 for more details.`)
panic(fmt.Sprintf("running %s: %v", listFiles, err))
}
}
panic(fmt.Sprintf("detected hanging go command (pid %d): see golang/go#54461 for more details", proc.Pid))
panic(fmt.Sprintf("detected hanging go command (golang/go#54461); waited %s\n\tcommand:%s\n\tpid:%d", time.Since(start), cmd, cmd.Process.Pid))
}
func cmdDebugStr(cmd *exec.Cmd) string {
@@ -468,3 +476,73 @@ func cmdDebugStr(cmd *exec.Cmd) string {
}
return fmt.Sprintf("GOROOT=%v GOPATH=%v GO111MODULE=%v GOPROXY=%v PWD=%v %v", env["GOROOT"], env["GOPATH"], env["GO111MODULE"], env["GOPROXY"], env["PWD"], strings.Join(args, " "))
}
// WriteOverlays writes each value in the overlay (see the Overlay
// field of go/packages.Config) to a temporary file and returns the name
// of a JSON file describing the mapping that is suitable for the "go
// list -overlay" flag.
//
// On success, the caller must call the cleanup function exactly once
// when the files are no longer needed.
func WriteOverlays(overlay map[string][]byte) (filename string, cleanup func(), err error) {
// Do nothing if there are no overlays in the config.
if len(overlay) == 0 {
return "", func() {}, nil
}
dir, err := os.MkdirTemp("", "gocommand-*")
if err != nil {
return "", nil, err
}
// The caller must clean up this directory,
// unless this function returns an error.
// (The cleanup operand of each return
// statement below is ignored.)
defer func() {
cleanup = func() {
os.RemoveAll(dir)
}
if err != nil {
cleanup()
cleanup = nil
}
}()
// Write each map entry to a temporary file.
overlays := make(map[string]string)
for k, v := range overlay {
// Use a unique basename for each file (001-foo.go),
// to avoid creating nested directories.
base := fmt.Sprintf("%d-%s.go", 1+len(overlays), filepath.Base(k))
filename := filepath.Join(dir, base)
err := os.WriteFile(filename, v, 0666)
if err != nil {
return "", nil, err
}
overlays[k] = filename
}
// Write the JSON overlay file that maps logical file names to temp files.
//
// OverlayJSON is the format overlay files are expected to be in.
// The Replace map maps from overlaid paths to replacement paths:
// the Go command will forward all reads trying to open
// each overlaid path to its replacement path, or consider the overlaid
// path not to exist if the replacement path is empty.
//
// From golang/go#39958.
type OverlayJSON struct {
Replace map[string]string `json:"replace,omitempty"`
}
b, err := json.Marshal(OverlayJSON{Replace: overlays})
if err != nil {
return "", nil, err
}
filename = filepath.Join(dir, "overlay.json")
if err := os.WriteFile(filename, b, 0666); err != nil {
return "", nil, err
}
return filename, nil, nil
}

View File

@@ -104,7 +104,10 @@ type packageInfo struct {
// parseOtherFiles parses all the Go files in srcDir except filename, including
// test files if filename looks like a test.
func parseOtherFiles(fset *token.FileSet, srcDir, filename string) []*ast.File {
//
// It returns an error only if ctx is cancelled. Files with parse errors are
// ignored.
func parseOtherFiles(ctx context.Context, fset *token.FileSet, srcDir, filename string) ([]*ast.File, error) {
// This could use go/packages but it doesn't buy much, and it fails
// with https://golang.org/issue/26296 in LoadFiles mode in some cases.
considerTests := strings.HasSuffix(filename, "_test.go")
@@ -112,11 +115,14 @@ func parseOtherFiles(fset *token.FileSet, srcDir, filename string) []*ast.File {
fileBase := filepath.Base(filename)
packageFileInfos, err := os.ReadDir(srcDir)
if err != nil {
return nil
return nil, ctx.Err()
}
var files []*ast.File
for _, fi := range packageFileInfos {
if ctx.Err() != nil {
return nil, ctx.Err()
}
if fi.Name() == fileBase || !strings.HasSuffix(fi.Name(), ".go") {
continue
}
@@ -132,7 +138,7 @@ func parseOtherFiles(fset *token.FileSet, srcDir, filename string) []*ast.File {
files = append(files, f)
}
return files
return files, ctx.Err()
}
// addGlobals puts the names of package vars into the provided map.
@@ -557,12 +563,7 @@ func (p *pass) addCandidate(imp *ImportInfo, pkg *packageInfo) {
// fixImports adds and removes imports from f so that all its references are
// satisfied and there are no unused imports.
//
// This is declared as a variable rather than a function so goimports can
// easily be extended by adding a file with an init function.
var fixImports = fixImportsDefault
func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) error {
func fixImports(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) error {
fixes, err := getFixes(context.Background(), fset, f, filename, env)
if err != nil {
return err
@@ -592,7 +593,10 @@ func getFixes(ctx context.Context, fset *token.FileSet, f *ast.File, filename st
return fixes, nil
}
otherFiles := parseOtherFiles(fset, srcDir, filename)
otherFiles, err := parseOtherFiles(ctx, fset, srcDir, filename)
if err != nil {
return nil, err
}
// Second pass: add information from other files in the same package,
// like their package vars and imports.
@@ -1192,7 +1196,7 @@ func addExternalCandidates(ctx context.Context, pass *pass, refs references, fil
if err != nil {
return err
}
if err = resolver.scan(context.Background(), callback); err != nil {
if err = resolver.scan(ctx, callback); err != nil {
return err
}
@@ -1203,7 +1207,7 @@ func addExternalCandidates(ctx context.Context, pass *pass, refs references, fil
}
results := make(chan result, len(refs))
ctx, cancel := context.WithCancel(context.TODO())
ctx, cancel := context.WithCancel(ctx)
var wg sync.WaitGroup
defer func() {
cancel()

View File

@@ -12,7 +12,7 @@ import (
"go/types"
)
// FileVersions returns a file's Go version.
// FileVersion returns a file's Go version.
// The reported version is an unknown Future version if a
// version cannot be determined.
func FileVersion(info *types.Info, file *ast.File) string {