mirror of
https://github.com/fumiama/go-nd-portal.git
synced 2026-06-05 00:10:25 +08:00
refactor: auto get client ip from challenge response (#5)
* refactor: auto get client ip from challenge response Since `outip()` was not working properly on devices getting local IP addresses behind a router, we should refactor this. After analyzing the auth process, it is shown that the challenge response includes key `client_ip` which is the real public IP address with key `ip` not specified in request. - removed `outip()` - added rsp key `ClientIP` to get client ip from challenge rsp * style: trim code * style: fix spelling issues * refactor: create `portal_test.go` to handle portal tests separately * feature: add `ResolveLocalClientIP()` and its test case * optimize: resolve ClientIP locally when cant be get from challenge response
This commit is contained in:
39
cmd/main.go
39
cmd/main.go
@@ -4,7 +4,6 @@ package cmd
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
@@ -17,15 +16,6 @@ import (
|
||||
"github.com/fumiama/go-nd-portal/portal"
|
||||
)
|
||||
|
||||
func outip() (net.IP, error) {
|
||||
conn, err := net.Dial("udp", "8.8.8.8:53")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = conn.Close()
|
||||
return conn.LocalAddr().(*net.UDPAddr).IP.To4(), nil
|
||||
}
|
||||
|
||||
func line() int {
|
||||
_, _, fileLine, ok := runtime.Caller(1)
|
||||
if ok {
|
||||
@@ -38,14 +28,7 @@ const query = "query"
|
||||
|
||||
// Main cmd program
|
||||
func Main() {
|
||||
ip, err := outip()
|
||||
ipf := ""
|
||||
if err != nil {
|
||||
ipf = query
|
||||
} else {
|
||||
ipf = ip.String()
|
||||
}
|
||||
ips := flag.String("ip", ipf, "public IP")
|
||||
ip := flag.String("ip", "", "client IP, auto get from login host when empty")
|
||||
n := flag.String("n", query, "username")
|
||||
p := flag.String("p", query, "password")
|
||||
h := flag.Bool("h", false, "display this help")
|
||||
@@ -64,26 +47,18 @@ func Main() {
|
||||
} else if *w {
|
||||
logrus.SetLevel(logrus.WarnLevel)
|
||||
}
|
||||
if *ips == query {
|
||||
fmt.Printf("ip: ")
|
||||
_, err = fmt.Scanln(ips)
|
||||
if *ip != "" {
|
||||
// just validate IP here,
|
||||
// dont convert to net.IP because we need only its string later
|
||||
_, err := netip.ParseAddr(*ip)
|
||||
if err != nil {
|
||||
logrus.Errorln(err)
|
||||
os.Exit(line())
|
||||
}
|
||||
}
|
||||
if *ips != ip.String() {
|
||||
ipaddr, err := netip.ParseAddr(*ips)
|
||||
if err != nil {
|
||||
logrus.Errorln(err)
|
||||
os.Exit(line())
|
||||
}
|
||||
a4 := ipaddr.As4()
|
||||
copy(ip, a4[:])
|
||||
}
|
||||
if *n == query {
|
||||
fmt.Printf("username: ")
|
||||
_, err = fmt.Scanln(n)
|
||||
_, err := fmt.Scanln(n)
|
||||
if err != nil {
|
||||
logrus.Errorln(err)
|
||||
os.Exit(line())
|
||||
@@ -112,7 +87,7 @@ func Main() {
|
||||
// p: password
|
||||
// ip : public ip
|
||||
// *t : login type
|
||||
ptl, err := portal.NewPortal(*n, *p, *s, ip, portal.LoginType(*t))
|
||||
ptl, err := portal.NewPortal(*n, *p, *s, *ip, portal.LoginType(*t))
|
||||
if err != nil {
|
||||
logrus.Errorln(err)
|
||||
os.Exit(line())
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -16,12 +17,12 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrIllegalIPv4 is returned when an invalid IPv4 address is provided
|
||||
ErrIllegalIPv4 = errors.New("illegal ipv4")
|
||||
// ErrIllegalLoginType is returned when an invalid login type is provided
|
||||
ErrIllegalLoginType = errors.New("illegal login type")
|
||||
// ErrUnexpectedChallengeResponse is returned when challenge is shorter than expected
|
||||
ErrUnexpectedChallengeResponse = errors.New("unexpected challenge response")
|
||||
// ErrCannotDetermineClientIP is returned when client IP cant get from challenge or local resolution with cip not specified
|
||||
ErrCannotDetermineClientIP = errors.New("failed to determine client IP from challenge response or local resolution")
|
||||
// ErrUnexpectedLoginResponse is returned when login resp is shorter than expected
|
||||
ErrUnexpectedLoginResponse = errors.New("unexpected login response")
|
||||
)
|
||||
@@ -30,7 +31,7 @@ var (
|
||||
type Portal struct {
|
||||
name string
|
||||
pswd string
|
||||
cip net.IP
|
||||
cip string
|
||||
sip string
|
||||
domain string
|
||||
acid string
|
||||
@@ -89,18 +90,26 @@ func (lt LoginType) ToDomainAcID() (string, string, error) {
|
||||
return domain, acid, nil
|
||||
}
|
||||
|
||||
// ResolveLocalClientIP resolves Client IP locally
|
||||
func ResolveLocalClientIP() (string, error) {
|
||||
conn, err := net.Dial("udp", "8.8.8.8:53")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.LocalAddr().(*net.UDPAddr).IP.String(), nil
|
||||
}
|
||||
|
||||
// rsp struct for converting from raw response data to JSON
|
||||
type rsp struct {
|
||||
ClientIP string `json:"client_ip"`
|
||||
Challenge string `json:"challenge"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// NewPortal creates a new Portal instance
|
||||
func NewPortal(name, password, sIP string, cIP net.IP, loginType LoginType) (*Portal, error) {
|
||||
if len(cIP) != 4 {
|
||||
return nil, ErrIllegalIPv4
|
||||
}
|
||||
|
||||
func NewPortal(name, password, sIP string, cIP string, loginType LoginType) (*Portal, error) {
|
||||
domain, acid, err := loginType.ToDomainAcID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -157,6 +166,22 @@ func (p *Portal) GetChallenge() (string, error) {
|
||||
if r.Error != "ok" {
|
||||
return "", errors.New(r.Error)
|
||||
}
|
||||
// if cip was left empty, try get from challenge resp
|
||||
if p.cip == "" {
|
||||
logrus.Debugln("client ip is not specified, try get client ip from challenge resp")
|
||||
_, err = netip.ParseAddr(r.ClientIP)
|
||||
if err == nil {
|
||||
p.cip = r.ClientIP
|
||||
logrus.Debugln("get client ip from challenge resp:", r.ClientIP)
|
||||
} else {
|
||||
// if ClientIP is invalid, try resolve it locally
|
||||
p.cip, err = ResolveLocalClientIP()
|
||||
if err != nil {
|
||||
return "", ErrCannotDetermineClientIP
|
||||
}
|
||||
logrus.Debugln("failed to get client ip from challenge resp, using locally resolved ip:", p.cip)
|
||||
}
|
||||
}
|
||||
logrus.Debugln("get challenge:", r.Challenge)
|
||||
return r.Challenge, nil
|
||||
}
|
||||
@@ -211,6 +236,11 @@ func (p *Portal) Login(challenge string) error {
|
||||
return err
|
||||
}
|
||||
logrus.Debugln("login rsp:", &r)
|
||||
// compare local cip with response client_ip
|
||||
if p.cip != r.ClientIP {
|
||||
logrus.Warnln("client ip in login request does not match response! unexpected errors may occur")
|
||||
logrus.Warnf("request: %s, response: %s", p.cip, r.ClientIP)
|
||||
}
|
||||
if r.Error != "ok" {
|
||||
return errors.New(r.Error)
|
||||
}
|
||||
|
||||
24
portal/portal_test.go
Normal file
24
portal/portal_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package portal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAutoSelectServerIP(t *testing.T) {
|
||||
u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(u.sip)
|
||||
assert.Equal(t, PortalServerIPQsh, u.sip)
|
||||
}
|
||||
|
||||
func TestResolveLocalClientIP(t *testing.T) {
|
||||
cip, err := ResolveLocalClientIP()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(cip)
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/google/go-querystring/query"
|
||||
@@ -87,13 +86,13 @@ type GetPortalReq struct {
|
||||
func GetChallengeURL(
|
||||
sIP,
|
||||
callback,
|
||||
username, domain string,
|
||||
cIP net.IP,
|
||||
username, domain,
|
||||
cIP string,
|
||||
timestamp int64) (string, error) {
|
||||
v, err := query.Values(&GetChallengeReq{
|
||||
Callback: callback,
|
||||
Username: username + domain,
|
||||
IP: cIP.String(),
|
||||
IP: cIP,
|
||||
Timestamp: timestamp,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -109,8 +108,8 @@ func GetLoginURL(
|
||||
callback,
|
||||
username, domain,
|
||||
md5Password,
|
||||
acid string,
|
||||
cIP net.IP,
|
||||
acid,
|
||||
cIP,
|
||||
chksum,
|
||||
info string,
|
||||
timestamp int64) (string, error) {
|
||||
@@ -120,7 +119,7 @@ func GetLoginURL(
|
||||
Username: username + domain,
|
||||
EncryptedPassword: "{MD5}" + md5Password,
|
||||
AcID: acid,
|
||||
IP: cIP.String(),
|
||||
IP: cIP,
|
||||
Checksum: chksum,
|
||||
EncodedUserInfo: "{SRBX1}" + info,
|
||||
ConstantN: "200",
|
||||
@@ -155,14 +154,14 @@ type UserInfo struct {
|
||||
func GetUserInfo(
|
||||
username,
|
||||
domain,
|
||||
password string,
|
||||
cIP net.IP,
|
||||
password,
|
||||
cIP,
|
||||
acid string) (string, error) {
|
||||
var b strings.Builder
|
||||
err := json.NewEncoder(&b).Encode(&UserInfo{
|
||||
Username: username + domain,
|
||||
Password: password,
|
||||
IP: cIP.String(),
|
||||
IP: cIP,
|
||||
AcID: acid,
|
||||
EncVer: "srun_bx1",
|
||||
})
|
||||
@@ -234,8 +233,8 @@ func (p *Portal) CheckSum(
|
||||
username,
|
||||
domain,
|
||||
hmd5,
|
||||
acid string,
|
||||
cIP net.IP,
|
||||
acid,
|
||||
cIP,
|
||||
info string) string {
|
||||
var buf [20]byte
|
||||
h := sha1.New()
|
||||
@@ -247,7 +246,7 @@ func (p *Portal) CheckSum(
|
||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||
_, _ = h.Write([]byte(acid)) // acid
|
||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||
_, _ = h.Write(helper.StringToBytes(cIP.String()))
|
||||
_, _ = h.Write(helper.StringToBytes(cIP))
|
||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||
_, _ = h.Write([]byte("200")) // n
|
||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,17 +11,8 @@ import (
|
||||
"github.com/fumiama/go-nd-portal/helper"
|
||||
)
|
||||
|
||||
func TestAutoSelectServerIP(t *testing.T) {
|
||||
u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(u.sip)
|
||||
assert.Equal(t, PortalServerIPQsh, u.sip)
|
||||
}
|
||||
|
||||
func TestGetUserInfo(t *testing.T) {
|
||||
u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu)
|
||||
u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -62,7 +52,7 @@ func TestDecodeKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEncodeUserInfo(t *testing.T) {
|
||||
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
|
||||
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -76,7 +66,7 @@ func TestEncodeUserInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHMd5(t *testing.T) {
|
||||
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
|
||||
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -90,7 +80,7 @@ func TestSha1(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckSum(t *testing.T) {
|
||||
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
|
||||
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user