diff --git a/cmd/main.go b/cmd/main.go index 4bfe27b..1daaccf 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -4,7 +4,6 @@ package cmd import ( "flag" "fmt" - "net" "net/netip" "os" "runtime" @@ -17,15 +16,6 @@ import ( "github.com/fumiama/go-nd-portal/portal" ) -func outip() (net.IP, error) { - conn, err := net.Dial("udp", "8.8.8.8:53") - if err != nil { - return nil, err - } - _ = conn.Close() - return conn.LocalAddr().(*net.UDPAddr).IP.To4(), nil -} - func line() int { _, _, fileLine, ok := runtime.Caller(1) if ok { @@ -38,14 +28,7 @@ const query = "query" // Main cmd program func Main() { - ip, err := outip() - ipf := "" - if err != nil { - ipf = query - } else { - ipf = ip.String() - } - ips := flag.String("ip", ipf, "public IP") + ip := flag.String("ip", "", "client IP, auto get from login host when empty") n := flag.String("n", query, "username") p := flag.String("p", query, "password") h := flag.Bool("h", false, "display this help") @@ -64,26 +47,18 @@ func Main() { } else if *w { logrus.SetLevel(logrus.WarnLevel) } - if *ips == query { - fmt.Printf("ip: ") - _, err = fmt.Scanln(ips) + if *ip != "" { + // just validate IP here, + // dont convert to net.IP because we need only its string later + _, err := netip.ParseAddr(*ip) if err != nil { logrus.Errorln(err) os.Exit(line()) } } - if *ips != ip.String() { - ipaddr, err := netip.ParseAddr(*ips) - if err != nil { - logrus.Errorln(err) - os.Exit(line()) - } - a4 := ipaddr.As4() - copy(ip, a4[:]) - } if *n == query { fmt.Printf("username: ") - _, err = fmt.Scanln(n) + _, err := fmt.Scanln(n) if err != nil { logrus.Errorln(err) os.Exit(line()) @@ -112,7 +87,7 @@ func Main() { // p: password // ip : public ip // *t : login type - ptl, err := portal.NewPortal(*n, *p, *s, ip, portal.LoginType(*t)) + ptl, err := portal.NewPortal(*n, *p, *s, *ip, portal.LoginType(*t)) if err != nil { logrus.Errorln(err) os.Exit(line()) diff --git a/portal/portal.go b/portal/portal.go index 20a201e..34bed00 100644 --- a/portal/portal.go +++ b/portal/portal.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "net" + "net/netip" "time" "github.com/sirupsen/logrus" @@ -16,12 +17,12 @@ import ( ) var ( - // ErrIllegalIPv4 is returned when an invalid IPv4 address is provided - ErrIllegalIPv4 = errors.New("illegal ipv4") // ErrIllegalLoginType is returned when an invalid login type is provided ErrIllegalLoginType = errors.New("illegal login type") // ErrUnexpectedChallengeResponse is returned when challenge is shorter than expected ErrUnexpectedChallengeResponse = errors.New("unexpected challenge response") + // ErrCannotDetermineClientIP is returned when client IP cant get from challenge or local resolution with cip not specified + ErrCannotDetermineClientIP = errors.New("failed to determine client IP from challenge response or local resolution") // ErrUnexpectedLoginResponse is returned when login resp is shorter than expected ErrUnexpectedLoginResponse = errors.New("unexpected login response") ) @@ -30,7 +31,7 @@ var ( type Portal struct { name string pswd string - cip net.IP + cip string sip string domain string acid string @@ -89,18 +90,26 @@ func (lt LoginType) ToDomainAcID() (string, string, error) { return domain, acid, nil } +// ResolveLocalClientIP resolves Client IP locally +func ResolveLocalClientIP() (string, error) { + conn, err := net.Dial("udp", "8.8.8.8:53") + if err != nil { + return "", err + } + defer conn.Close() + + return conn.LocalAddr().(*net.UDPAddr).IP.String(), nil +} + // rsp struct for converting from raw response data to JSON type rsp struct { + ClientIP string `json:"client_ip"` Challenge string `json:"challenge"` Error string `json:"error"` } // NewPortal creates a new Portal instance -func NewPortal(name, password, sIP string, cIP net.IP, loginType LoginType) (*Portal, error) { - if len(cIP) != 4 { - return nil, ErrIllegalIPv4 - } - +func NewPortal(name, password, sIP string, cIP string, loginType LoginType) (*Portal, error) { domain, acid, err := loginType.ToDomainAcID() if err != nil { return nil, err @@ -157,6 +166,22 @@ func (p *Portal) GetChallenge() (string, error) { if r.Error != "ok" { return "", errors.New(r.Error) } + // if cip was left empty, try get from challenge resp + if p.cip == "" { + logrus.Debugln("client ip is not specified, try get client ip from challenge resp") + _, err = netip.ParseAddr(r.ClientIP) + if err == nil { + p.cip = r.ClientIP + logrus.Debugln("get client ip from challenge resp:", r.ClientIP) + } else { + // if ClientIP is invalid, try resolve it locally + p.cip, err = ResolveLocalClientIP() + if err != nil { + return "", ErrCannotDetermineClientIP + } + logrus.Debugln("failed to get client ip from challenge resp, using locally resolved ip:", p.cip) + } + } logrus.Debugln("get challenge:", r.Challenge) return r.Challenge, nil } @@ -211,6 +236,11 @@ func (p *Portal) Login(challenge string) error { return err } logrus.Debugln("login rsp:", &r) + // compare local cip with response client_ip + if p.cip != r.ClientIP { + logrus.Warnln("client ip in login request does not match response! unexpected errors may occur") + logrus.Warnf("request: %s, response: %s", p.cip, r.ClientIP) + } if r.Error != "ok" { return errors.New(r.Error) } diff --git a/portal/portal_test.go b/portal/portal_test.go new file mode 100644 index 0000000..a83209b --- /dev/null +++ b/portal/portal_test.go @@ -0,0 +1,24 @@ +package portal + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAutoSelectServerIP(t *testing.T) { + u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu) + if err != nil { + t.Fatal(err) + } + t.Log(u.sip) + assert.Equal(t, PortalServerIPQsh, u.sip) +} + +func TestResolveLocalClientIP(t *testing.T) { + cip, err := ResolveLocalClientIP() + if err != nil { + t.Fatal(err) + } + t.Log(cip) +} diff --git a/portal/server.go b/portal/server.go index fd0066d..c31b00c 100644 --- a/portal/server.go +++ b/portal/server.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "net" "strings" "github.com/google/go-querystring/query" @@ -87,13 +86,13 @@ type GetPortalReq struct { func GetChallengeURL( sIP, callback, - username, domain string, - cIP net.IP, + username, domain, + cIP string, timestamp int64) (string, error) { v, err := query.Values(&GetChallengeReq{ Callback: callback, Username: username + domain, - IP: cIP.String(), + IP: cIP, Timestamp: timestamp, }) if err != nil { @@ -109,8 +108,8 @@ func GetLoginURL( callback, username, domain, md5Password, - acid string, - cIP net.IP, + acid, + cIP, chksum, info string, timestamp int64) (string, error) { @@ -120,7 +119,7 @@ func GetLoginURL( Username: username + domain, EncryptedPassword: "{MD5}" + md5Password, AcID: acid, - IP: cIP.String(), + IP: cIP, Checksum: chksum, EncodedUserInfo: "{SRBX1}" + info, ConstantN: "200", @@ -155,14 +154,14 @@ type UserInfo struct { func GetUserInfo( username, domain, - password string, - cIP net.IP, + password, + cIP, acid string) (string, error) { var b strings.Builder err := json.NewEncoder(&b).Encode(&UserInfo{ Username: username + domain, Password: password, - IP: cIP.String(), + IP: cIP, AcID: acid, EncVer: "srun_bx1", }) @@ -234,8 +233,8 @@ func (p *Portal) CheckSum( username, domain, hmd5, - acid string, - cIP net.IP, + acid, + cIP, info string) string { var buf [20]byte h := sha1.New() @@ -247,7 +246,7 @@ func (p *Portal) CheckSum( _, _ = h.Write(helper.StringToBytes(challenge)) _, _ = h.Write([]byte(acid)) // acid _, _ = h.Write(helper.StringToBytes(challenge)) - _, _ = h.Write(helper.StringToBytes(cIP.String())) + _, _ = h.Write(helper.StringToBytes(cIP)) _, _ = h.Write(helper.StringToBytes(challenge)) _, _ = h.Write([]byte("200")) // n _, _ = h.Write(helper.StringToBytes(challenge)) diff --git a/portal/server_test.go b/portal/server_test.go index 0b4c82e..3930300 100644 --- a/portal/server_test.go +++ b/portal/server_test.go @@ -4,7 +4,6 @@ import ( "crypto/sha1" "encoding/binary" "encoding/hex" - "net" "testing" "github.com/stretchr/testify/assert" @@ -12,17 +11,8 @@ import ( "github.com/fumiama/go-nd-portal/helper" ) -func TestAutoSelectServerIP(t *testing.T) { - u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu) - if err != nil { - t.Fatal(err) - } - t.Log(u.sip) - assert.Equal(t, PortalServerIPQsh, u.sip) -} - func TestGetUserInfo(t *testing.T) { - u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu) + u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu) if err != nil { t.Fatal(err) } @@ -62,7 +52,7 @@ func TestDecodeKey(t *testing.T) { } func TestEncodeUserInfo(t *testing.T) { - u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu) + u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu) if err != nil { t.Fatal(err) } @@ -76,7 +66,7 @@ func TestEncodeUserInfo(t *testing.T) { } func TestHMd5(t *testing.T) { - u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu) + u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu) if err != nil { t.Fatal(err) } @@ -90,7 +80,7 @@ func TestSha1(t *testing.T) { } func TestCheckSum(t *testing.T) { - u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu) + u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu) if err != nil { t.Fatal(err) }