mirror of
https://github.com/fumiama/go-nd-portal.git
synced 2026-06-28 15:10:26 +08:00
refactor: auto get client ip from challenge response (#5)
* refactor: auto get client ip from challenge response Since `outip()` was not working properly on devices getting local IP addresses behind a router, we should refactor this. After analyzing the auth process, it is shown that the challenge response includes key `client_ip` which is the real public IP address with key `ip` not specified in request. - removed `outip()` - added rsp key `ClientIP` to get client ip from challenge rsp * style: trim code * style: fix spelling issues * refactor: create `portal_test.go` to handle portal tests separately * feature: add `ResolveLocalClientIP()` and its test case * optimize: resolve ClientIP locally when cant be get from challenge response
This commit is contained in:
39
cmd/main.go
39
cmd/main.go
@@ -4,7 +4,6 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -17,15 +16,6 @@ import (
|
|||||||
"github.com/fumiama/go-nd-portal/portal"
|
"github.com/fumiama/go-nd-portal/portal"
|
||||||
)
|
)
|
||||||
|
|
||||||
func outip() (net.IP, error) {
|
|
||||||
conn, err := net.Dial("udp", "8.8.8.8:53")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
_ = conn.Close()
|
|
||||||
return conn.LocalAddr().(*net.UDPAddr).IP.To4(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func line() int {
|
func line() int {
|
||||||
_, _, fileLine, ok := runtime.Caller(1)
|
_, _, fileLine, ok := runtime.Caller(1)
|
||||||
if ok {
|
if ok {
|
||||||
@@ -38,14 +28,7 @@ const query = "query"
|
|||||||
|
|
||||||
// Main cmd program
|
// Main cmd program
|
||||||
func Main() {
|
func Main() {
|
||||||
ip, err := outip()
|
ip := flag.String("ip", "", "client IP, auto get from login host when empty")
|
||||||
ipf := ""
|
|
||||||
if err != nil {
|
|
||||||
ipf = query
|
|
||||||
} else {
|
|
||||||
ipf = ip.String()
|
|
||||||
}
|
|
||||||
ips := flag.String("ip", ipf, "public IP")
|
|
||||||
n := flag.String("n", query, "username")
|
n := flag.String("n", query, "username")
|
||||||
p := flag.String("p", query, "password")
|
p := flag.String("p", query, "password")
|
||||||
h := flag.Bool("h", false, "display this help")
|
h := flag.Bool("h", false, "display this help")
|
||||||
@@ -64,26 +47,18 @@ func Main() {
|
|||||||
} else if *w {
|
} else if *w {
|
||||||
logrus.SetLevel(logrus.WarnLevel)
|
logrus.SetLevel(logrus.WarnLevel)
|
||||||
}
|
}
|
||||||
if *ips == query {
|
if *ip != "" {
|
||||||
fmt.Printf("ip: ")
|
// just validate IP here,
|
||||||
_, err = fmt.Scanln(ips)
|
// dont convert to net.IP because we need only its string later
|
||||||
|
_, err := netip.ParseAddr(*ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln(err)
|
logrus.Errorln(err)
|
||||||
os.Exit(line())
|
os.Exit(line())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if *ips != ip.String() {
|
|
||||||
ipaddr, err := netip.ParseAddr(*ips)
|
|
||||||
if err != nil {
|
|
||||||
logrus.Errorln(err)
|
|
||||||
os.Exit(line())
|
|
||||||
}
|
|
||||||
a4 := ipaddr.As4()
|
|
||||||
copy(ip, a4[:])
|
|
||||||
}
|
|
||||||
if *n == query {
|
if *n == query {
|
||||||
fmt.Printf("username: ")
|
fmt.Printf("username: ")
|
||||||
_, err = fmt.Scanln(n)
|
_, err := fmt.Scanln(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln(err)
|
logrus.Errorln(err)
|
||||||
os.Exit(line())
|
os.Exit(line())
|
||||||
@@ -112,7 +87,7 @@ func Main() {
|
|||||||
// p: password
|
// p: password
|
||||||
// ip : public ip
|
// ip : public ip
|
||||||
// *t : login type
|
// *t : login type
|
||||||
ptl, err := portal.NewPortal(*n, *p, *s, ip, portal.LoginType(*t))
|
ptl, err := portal.NewPortal(*n, *p, *s, *ip, portal.LoginType(*t))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln(err)
|
logrus.Errorln(err)
|
||||||
os.Exit(line())
|
os.Exit(line())
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -16,12 +17,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrIllegalIPv4 is returned when an invalid IPv4 address is provided
|
|
||||||
ErrIllegalIPv4 = errors.New("illegal ipv4")
|
|
||||||
// ErrIllegalLoginType is returned when an invalid login type is provided
|
// ErrIllegalLoginType is returned when an invalid login type is provided
|
||||||
ErrIllegalLoginType = errors.New("illegal login type")
|
ErrIllegalLoginType = errors.New("illegal login type")
|
||||||
// ErrUnexpectedChallengeResponse is returned when challenge is shorter than expected
|
// ErrUnexpectedChallengeResponse is returned when challenge is shorter than expected
|
||||||
ErrUnexpectedChallengeResponse = errors.New("unexpected challenge response")
|
ErrUnexpectedChallengeResponse = errors.New("unexpected challenge response")
|
||||||
|
// ErrCannotDetermineClientIP is returned when client IP cant get from challenge or local resolution with cip not specified
|
||||||
|
ErrCannotDetermineClientIP = errors.New("failed to determine client IP from challenge response or local resolution")
|
||||||
// ErrUnexpectedLoginResponse is returned when login resp is shorter than expected
|
// ErrUnexpectedLoginResponse is returned when login resp is shorter than expected
|
||||||
ErrUnexpectedLoginResponse = errors.New("unexpected login response")
|
ErrUnexpectedLoginResponse = errors.New("unexpected login response")
|
||||||
)
|
)
|
||||||
@@ -30,7 +31,7 @@ var (
|
|||||||
type Portal struct {
|
type Portal struct {
|
||||||
name string
|
name string
|
||||||
pswd string
|
pswd string
|
||||||
cip net.IP
|
cip string
|
||||||
sip string
|
sip string
|
||||||
domain string
|
domain string
|
||||||
acid string
|
acid string
|
||||||
@@ -89,18 +90,26 @@ func (lt LoginType) ToDomainAcID() (string, string, error) {
|
|||||||
return domain, acid, nil
|
return domain, acid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveLocalClientIP resolves Client IP locally
|
||||||
|
func ResolveLocalClientIP() (string, error) {
|
||||||
|
conn, err := net.Dial("udp", "8.8.8.8:53")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return conn.LocalAddr().(*net.UDPAddr).IP.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// rsp struct for converting from raw response data to JSON
|
// rsp struct for converting from raw response data to JSON
|
||||||
type rsp struct {
|
type rsp struct {
|
||||||
|
ClientIP string `json:"client_ip"`
|
||||||
Challenge string `json:"challenge"`
|
Challenge string `json:"challenge"`
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPortal creates a new Portal instance
|
// NewPortal creates a new Portal instance
|
||||||
func NewPortal(name, password, sIP string, cIP net.IP, loginType LoginType) (*Portal, error) {
|
func NewPortal(name, password, sIP string, cIP string, loginType LoginType) (*Portal, error) {
|
||||||
if len(cIP) != 4 {
|
|
||||||
return nil, ErrIllegalIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
domain, acid, err := loginType.ToDomainAcID()
|
domain, acid, err := loginType.ToDomainAcID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -157,6 +166,22 @@ func (p *Portal) GetChallenge() (string, error) {
|
|||||||
if r.Error != "ok" {
|
if r.Error != "ok" {
|
||||||
return "", errors.New(r.Error)
|
return "", errors.New(r.Error)
|
||||||
}
|
}
|
||||||
|
// if cip was left empty, try get from challenge resp
|
||||||
|
if p.cip == "" {
|
||||||
|
logrus.Debugln("client ip is not specified, try get client ip from challenge resp")
|
||||||
|
_, err = netip.ParseAddr(r.ClientIP)
|
||||||
|
if err == nil {
|
||||||
|
p.cip = r.ClientIP
|
||||||
|
logrus.Debugln("get client ip from challenge resp:", r.ClientIP)
|
||||||
|
} else {
|
||||||
|
// if ClientIP is invalid, try resolve it locally
|
||||||
|
p.cip, err = ResolveLocalClientIP()
|
||||||
|
if err != nil {
|
||||||
|
return "", ErrCannotDetermineClientIP
|
||||||
|
}
|
||||||
|
logrus.Debugln("failed to get client ip from challenge resp, using locally resolved ip:", p.cip)
|
||||||
|
}
|
||||||
|
}
|
||||||
logrus.Debugln("get challenge:", r.Challenge)
|
logrus.Debugln("get challenge:", r.Challenge)
|
||||||
return r.Challenge, nil
|
return r.Challenge, nil
|
||||||
}
|
}
|
||||||
@@ -211,6 +236,11 @@ func (p *Portal) Login(challenge string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
logrus.Debugln("login rsp:", &r)
|
logrus.Debugln("login rsp:", &r)
|
||||||
|
// compare local cip with response client_ip
|
||||||
|
if p.cip != r.ClientIP {
|
||||||
|
logrus.Warnln("client ip in login request does not match response! unexpected errors may occur")
|
||||||
|
logrus.Warnf("request: %s, response: %s", p.cip, r.ClientIP)
|
||||||
|
}
|
||||||
if r.Error != "ok" {
|
if r.Error != "ok" {
|
||||||
return errors.New(r.Error)
|
return errors.New(r.Error)
|
||||||
}
|
}
|
||||||
|
|||||||
24
portal/portal_test.go
Normal file
24
portal/portal_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package portal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAutoSelectServerIP(t *testing.T) {
|
||||||
|
u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(u.sip)
|
||||||
|
assert.Equal(t, PortalServerIPQsh, u.sip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveLocalClientIP(t *testing.T) {
|
||||||
|
cip, err := ResolveLocalClientIP()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(cip)
|
||||||
|
}
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -87,13 +86,13 @@ type GetPortalReq struct {
|
|||||||
func GetChallengeURL(
|
func GetChallengeURL(
|
||||||
sIP,
|
sIP,
|
||||||
callback,
|
callback,
|
||||||
username, domain string,
|
username, domain,
|
||||||
cIP net.IP,
|
cIP string,
|
||||||
timestamp int64) (string, error) {
|
timestamp int64) (string, error) {
|
||||||
v, err := query.Values(&GetChallengeReq{
|
v, err := query.Values(&GetChallengeReq{
|
||||||
Callback: callback,
|
Callback: callback,
|
||||||
Username: username + domain,
|
Username: username + domain,
|
||||||
IP: cIP.String(),
|
IP: cIP,
|
||||||
Timestamp: timestamp,
|
Timestamp: timestamp,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -109,8 +108,8 @@ func GetLoginURL(
|
|||||||
callback,
|
callback,
|
||||||
username, domain,
|
username, domain,
|
||||||
md5Password,
|
md5Password,
|
||||||
acid string,
|
acid,
|
||||||
cIP net.IP,
|
cIP,
|
||||||
chksum,
|
chksum,
|
||||||
info string,
|
info string,
|
||||||
timestamp int64) (string, error) {
|
timestamp int64) (string, error) {
|
||||||
@@ -120,7 +119,7 @@ func GetLoginURL(
|
|||||||
Username: username + domain,
|
Username: username + domain,
|
||||||
EncryptedPassword: "{MD5}" + md5Password,
|
EncryptedPassword: "{MD5}" + md5Password,
|
||||||
AcID: acid,
|
AcID: acid,
|
||||||
IP: cIP.String(),
|
IP: cIP,
|
||||||
Checksum: chksum,
|
Checksum: chksum,
|
||||||
EncodedUserInfo: "{SRBX1}" + info,
|
EncodedUserInfo: "{SRBX1}" + info,
|
||||||
ConstantN: "200",
|
ConstantN: "200",
|
||||||
@@ -155,14 +154,14 @@ type UserInfo struct {
|
|||||||
func GetUserInfo(
|
func GetUserInfo(
|
||||||
username,
|
username,
|
||||||
domain,
|
domain,
|
||||||
password string,
|
password,
|
||||||
cIP net.IP,
|
cIP,
|
||||||
acid string) (string, error) {
|
acid string) (string, error) {
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
err := json.NewEncoder(&b).Encode(&UserInfo{
|
err := json.NewEncoder(&b).Encode(&UserInfo{
|
||||||
Username: username + domain,
|
Username: username + domain,
|
||||||
Password: password,
|
Password: password,
|
||||||
IP: cIP.String(),
|
IP: cIP,
|
||||||
AcID: acid,
|
AcID: acid,
|
||||||
EncVer: "srun_bx1",
|
EncVer: "srun_bx1",
|
||||||
})
|
})
|
||||||
@@ -234,8 +233,8 @@ func (p *Portal) CheckSum(
|
|||||||
username,
|
username,
|
||||||
domain,
|
domain,
|
||||||
hmd5,
|
hmd5,
|
||||||
acid string,
|
acid,
|
||||||
cIP net.IP,
|
cIP,
|
||||||
info string) string {
|
info string) string {
|
||||||
var buf [20]byte
|
var buf [20]byte
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
@@ -247,7 +246,7 @@ func (p *Portal) CheckSum(
|
|||||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||||
_, _ = h.Write([]byte(acid)) // acid
|
_, _ = h.Write([]byte(acid)) // acid
|
||||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||||
_, _ = h.Write(helper.StringToBytes(cIP.String()))
|
_, _ = h.Write(helper.StringToBytes(cIP))
|
||||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||||
_, _ = h.Write([]byte("200")) // n
|
_, _ = h.Write([]byte("200")) // n
|
||||||
_, _ = h.Write(helper.StringToBytes(challenge))
|
_, _ = h.Write(helper.StringToBytes(challenge))
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -12,17 +11,8 @@ import (
|
|||||||
"github.com/fumiama/go-nd-portal/helper"
|
"github.com/fumiama/go-nd-portal/helper"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAutoSelectServerIP(t *testing.T) {
|
|
||||||
u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
t.Log(u.sip)
|
|
||||||
assert.Equal(t, PortalServerIPQsh, u.sip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetUserInfo(t *testing.T) {
|
func TestGetUserInfo(t *testing.T) {
|
||||||
u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu)
|
u, err := NewPortal("2000010101001", "12345678", "", "1.2.3.4", LoginTypeQshEdu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -62,7 +52,7 @@ func TestDecodeKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeUserInfo(t *testing.T) {
|
func TestEncodeUserInfo(t *testing.T) {
|
||||||
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
|
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -76,7 +66,7 @@ func TestEncodeUserInfo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestHMd5(t *testing.T) {
|
func TestHMd5(t *testing.T) {
|
||||||
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
|
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -90,7 +80,7 @@ func TestSha1(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckSum(t *testing.T) {
|
func TestCheckSum(t *testing.T) {
|
||||||
u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu)
|
u, err := NewPortal("2001010101001", "1234567890", "", "113.54.148.243", LoginTypeQshEdu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user