1
0
mirror of https://github.com/fumiama/go-nd-portal.git synced 2026-06-05 00:10:25 +08:00

refactor: auto get client ip from challenge response (#5)

* refactor: auto get client ip from challenge response
Since `outip()` was not working properly on devices getting local IP addresses behind a router, we should refactor this.
After analyzing the auth process, it is shown that  the challenge response includes key `client_ip` which is the real public IP address with key `ip` not specified in request.
- removed `outip()`
- added rsp key `ClientIP` to get client ip from challenge rsp

* style: trim code

* style: fix spelling issues

* refactor: create `portal_test.go` to handle portal tests separately

* feature: add `ResolveLocalClientIP()` and its test case

* optimize: resolve ClientIP locally when cant be get from challenge response
This commit is contained in:
chasey-dev
2025-09-01 22:33:57 +08:00
committed by GitHub
parent 32fdf3ae90
commit f2459dd8d9
5 changed files with 85 additions and 67 deletions

View File

@@ -4,7 +4,6 @@ package cmd
import (
"flag"
"fmt"
"net"
"net/netip"
"os"
"runtime"
@@ -17,15 +16,6 @@ import (
"github.com/fumiama/go-nd-portal/portal"
)
func outip() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:53")
if err != nil {
return nil, err
}
_ = conn.Close()
return conn.LocalAddr().(*net.UDPAddr).IP.To4(), nil
}
func line() int {
_, _, fileLine, ok := runtime.Caller(1)
if ok {
@@ -38,14 +28,7 @@ const query = "query"
// Main cmd program
func Main() {
ip, err := outip()
ipf := ""
if err != nil {
ipf = query
} else {
ipf = ip.String()
}
ips := flag.String("ip", ipf, "public IP")
ip := flag.String("ip", "", "client IP, auto get from login host when empty")
n := flag.String("n", query, "username")
p := flag.String("p", query, "password")
h := flag.Bool("h", false, "display this help")
@@ -64,26 +47,18 @@ func Main() {
} else if *w {
logrus.SetLevel(logrus.WarnLevel)
}
if *ips == query {
fmt.Printf("ip: ")
_, err = fmt.Scanln(ips)
if *ip != "" {
// just validate IP here,
// dont convert to net.IP because we need only its string later
_, err := netip.ParseAddr(*ip)
if err != nil {
logrus.Errorln(err)
os.Exit(line())
}
}
if *ips != ip.String() {
ipaddr, err := netip.ParseAddr(*ips)
if err != nil {
logrus.Errorln(err)
os.Exit(line())
}
a4 := ipaddr.As4()
copy(ip, a4[:])
}
if *n == query {
fmt.Printf("username: ")
_, err = fmt.Scanln(n)
_, err := fmt.Scanln(n)
if err != nil {
logrus.Errorln(err)
os.Exit(line())
@@ -112,7 +87,7 @@ func Main() {
// p: password
// ip : public ip
// *t : login type
ptl, err := portal.NewPortal(*n, *p, *s, ip, portal.LoginType(*t))
ptl, err := portal.NewPortal(*n, *p, *s, *ip, portal.LoginType(*t))
if err != nil {
logrus.Errorln(err)
os.Exit(line())

View File

@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"net"
"net/netip"
"time"
"github.com/sirupsen/logrus"
@@ -16,12 +17,12 @@ import (
)
var (
// ErrIllegalIPv4 is returned when an invalid IPv4 address is provided
ErrIllegalIPv4 = errors.New("illegal ipv4")
// ErrIllegalLoginType is returned when an invalid login type is provided
ErrIllegalLoginType = errors.New("illegal login type")
// ErrUnexpectedChallengeResponse is returned when challenge is shorter than expected
ErrUnexpectedChallengeResponse = errors.New("unexpected challenge response")
// ErrCannotDetermineClientIP is returned when client IP cant get from challenge or local resolution with cip not specified
ErrCannotDetermineClientIP = errors.New("failed to determine client IP from challenge response or local resolution")
// ErrUnexpectedLoginResponse is returned when login resp is shorter than expected
ErrUnexpectedLoginResponse = errors.New("unexpected login response")
)
@@ -30,7 +31,7 @@ var (
type Portal struct {
name string
pswd string
cip net.IP
cip string
sip string
domain string
acid string
@@ -89,18 +90,26 @@ func (lt LoginType) ToDomainAcID() (string, string, error) {
return domain, acid, nil
}
// ResolveLocalClientIP resolves Client IP locally
func ResolveLocalClientIP() (string, error) {
conn, err := net.Dial("udp", "8.8.8.8:53")
if err != nil {
return "", err
}
defer conn.Close()
return conn.LocalAddr().(*net.UDPAddr).IP.String(), nil
}
// rsp struct for converting from raw response data to JSON
type rsp struct {
ClientIP string `json:"client_ip"`
Challenge string `json:"challenge"`
Error string `json:"error"`
}
// NewPortal creates a new Portal instance
func NewPortal(name, password, sIP string, cIP net.IP, loginType LoginType) (*Portal, error) {
if len(cIP) != 4 {
return nil, ErrIllegalIPv4
}
func NewPortal(name, password, sIP string, cIP string, loginType LoginType) (*Portal, error) {
domain, acid, err := loginType.ToDomainAcID()
if err != nil {
return nil, err
@@ -157,6 +166,22 @@ func (p *Portal) GetChallenge() (string, error) {
if r.Error != "ok" {
return "", errors.New(r.Error)
}
// if cip was left empty, try get from challenge resp
if p.cip == "" {
logrus.Debugln("client ip is not specified, try get client ip from challenge resp")
_, err = netip.ParseAddr(r.ClientIP)
if err == nil {
p.cip = r.ClientIP
logrus.Debugln("get client ip from challenge resp:", r.ClientIP)
} else {
// if ClientIP is invalid, try resolve it locally
p.cip, err = ResolveLocalClientIP()
if err != nil {
return "", ErrCannotDetermineClientIP
}
logrus.Debugln("failed to get client ip from challenge resp, using locally resolved ip:", p.cip)
}
}
logrus.Debugln("get challenge:", r.Challenge)
return r.Challenge, nil
}
@@ -211,6 +236,11 @@ func (p *Portal) Login(challenge string) error {
return err
}
logrus.Debugln("login rsp:", &r)
// compare local cip with response client_ip
if p.cip != r.ClientIP {
logrus.Warnln("client ip in login request does not match response! unexpected errors may occur")
logrus.Warnf("request: %s, response: %s", p.cip, r.ClientIP)
}
if r.Error != "ok" {
return errors.New(r.Error)
}

24
portal/portal_test.go Normal file
View File

@@ -0,0 +1,24 @@
package portal
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAutoSelectServerIP(t *testing.T) {
u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu)
if err != nil {
t.Fatal(err)
}
t.Log(u.sip)
assert.Equal(t, PortalServerIPQsh, u.sip)
}
func TestResolveLocalClientIP(t *testing.T) {
cip, err := ResolveLocalClientIP()
if err != nil {
t.Fatal(err)
}
t.Log(cip)
}

View File

@@ -6,7 +6,6 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"net"
"strings"
"github.com/google/go-querystring/query"
@@ -87,13 +86,13 @@ type GetPortalReq struct {
func GetChallengeURL(
sIP,
callback,
username, domain string,
cIP net.IP,
username, domain,
cIP string,
timestamp int64) (string, error) {
v, err := query.Values(&GetChallengeReq{
Callback: callback,
Username: username + domain,
IP: cIP.String(),
IP: cIP,
Timestamp: timestamp,
})
if err != nil {
@@ -109,8 +108,8 @@ func GetLoginURL(
callback,
username, domain,
md5Password,
acid string,
cIP net.IP,
acid,
cIP,
chksum,
info string,
timestamp int64) (string, error) {
@@ -120,7 +119,7 @@ func GetLoginURL(
Username: username + domain,
EncryptedPassword: "{MD5}" + md5Password,
AcID: acid,
IP: cIP.String(),
IP: cIP,
Checksum: chksum,
EncodedUserInfo: "{SRBX1}" + info,
ConstantN: "200",
@@ -155,14 +154,14 @@ type UserInfo struct {
func GetUserInfo(
username,
domain,
password string,
cIP net.IP,
password,
cIP,
acid string) (string, error) {
var b strings.Builder
err := json.NewEncoder(&b).Encode(&UserInfo{
Username: username + domain,
Password: password,
IP: cIP.String(),
IP: cIP,
AcID: acid,
EncVer: "srun_bx1",
})
@@ -234,8 +233,8 @@ func (p *Portal) CheckSum(
username,
domain,
hmd5,
acid string,
cIP net.IP,
acid,
cIP,
info string) string {
var buf [20]byte
h := sha1.New()
@@ -247,7 +246,7 @@ func (p *Portal) CheckSum(
_, _ = h.Write(helper.StringToBytes(challenge))
_, _ = h.Write([]byte(acid)) // acid
_, _ = h.Write(helper.StringToBytes(challenge))
_, _ = h.Write(helper.StringToBytes(cIP.String()))
_, _ = h.Write(helper.StringToBytes(cIP))
_, _ = h.Write(helper.StringToBytes(challenge))
_, _ = h.Write([]byte("200")) // n
_, _ = h.Write(helper.StringToBytes(challenge))

View File

@@ -4,7 +4,6 @@ import (
"crypto/sha1"
"encoding/binary"
"encoding/hex"
"net"
"testing"
"github.com/stretchr/testify/assert"
@@ -12,17 +11,8 @@ import (
"github.com/fumiama/go-nd-portal/helper"
)
func TestAutoSelectServerIP(t *testing.T) {
u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu)
if err != nil {
t.Fatal(err)
}
t.Log(u.sip)
assert.Equal(t, PortalServerIPQsh, u.sip)
}
func TestGetUserInfo(t *testing.T) {
u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu)
u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu)
if err != nil {
t.Fatal(err)
}
@@ -62,7 +52,7 @@ func TestDecodeKey(t *testing.T) {
}
func TestEncodeUserInfo(t *testing.T) {
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
if err != nil {
t.Fatal(err)
}
@@ -76,7 +66,7 @@ func TestEncodeUserInfo(t *testing.T) {
}
func TestHMd5(t *testing.T) {
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
if err != nil {
t.Fatal(err)
}
@@ -90,7 +80,7 @@ func TestSha1(t *testing.T) {
}
func TestCheckSum(t *testing.T) {
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
if err != nil {
t.Fatal(err)
}