1
0
mirror of https://github.com/fumiama/paper-manager.git synced 2026-06-19 09:40:22 +08:00

finish register

This commit is contained in:
源文雨
2023-03-19 21:07:22 +08:00
parent bee5caaadc
commit f3757deecf
6 changed files with 92 additions and 26 deletions

View File

@@ -3,6 +3,7 @@ package backend
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings"
"github.com/fumiama/paper-manager/backend/utils" "github.com/fumiama/paper-manager/backend/utils"
) )
@@ -75,6 +76,39 @@ func init() {
writeresult(w, codeSuccess, nil, messageOk, typeSuccess) writeresult(w, codeSuccess, nil, messageOk, typeSuccess)
}} }}
apimap["/api/register"] = &apihandler{"POST", func(w http.ResponseWriter, r *http.Request) {
type registerbody struct {
Username string `json:"username"`
Mobile string `json:"mobile"`
Password string `json:"password"`
}
if r.Header.Get("Authorization") != "" {
writeresult(w, codeError, nil, errInvalidToken.Error(), typeError)
return
}
var body registerbody
defer r.Body.Close()
err := json.NewDecoder(r.Body).Decode(&body)
if err != nil {
writeresult(w, codeError, nil, err.Error(), typeError)
return
}
ip := r.RemoteAddr
i := strings.LastIndex(ip, ":")
if i >= 0 {
ip = ip[:i]
}
err = register(ip, body.Username, body.Mobile, body.Password)
if err != nil {
writeresult(w, codeError, nil, err.Error(), typeError)
return
}
type message struct {
M string `json:"msg"`
}
writeresult(w, codeSuccess, &message{M: "成功, 请耐心等待通知"}, messageOk, typeSuccess)
}}
apimap["/api/getUsersCount"] = &apihandler{"GET", func(w http.ResponseWriter, r *http.Request) { apimap["/api/getUsersCount"] = &apihandler{"GET", func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization") token := r.Header.Get("Authorization")
n, err := getUsersCount(token) n, err := getUsersCount(token)

View File

@@ -3,7 +3,6 @@ package global
import ( import (
"errors" "errors"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/fumiama/paper-manager/backend/utils" "github.com/fumiama/paper-manager/backend/utils"
@@ -65,7 +64,7 @@ var (
ErrEmptyName = errors.New("empty name") ErrEmptyName = errors.New("empty name")
ErrInvalidUsersCount = errors.New("invalid users count") ErrInvalidUsersCount = errors.New("invalid users count")
ErrEmptyUserID = errors.New("empty user ID") ErrEmptyUserID = errors.New("empty user ID")
ErrEmptyContect = errors.New("empty contact") ErrEmptyContact = errors.New("empty contact")
ErrUsernameExists = errors.New("username exists") ErrUsernameExists = errors.New("username exists")
ErrInvalidName = errors.New("invalid name") ErrInvalidName = errors.New("invalid name")
) )
@@ -201,7 +200,7 @@ func (u *UserDatabase) UpdateUserPassword(id int, npwd string) error {
// UpdateUserContact ... // UpdateUserContact ...
func (u *UserDatabase) UpdateUserContact(id int, ncont string) error { func (u *UserDatabase) UpdateUserContact(id int, ncont string) error {
if ncont == "" { if ncont == "" {
return ErrEmptyContect return ErrEmptyContact
} }
user, err := u.GetUserByID(id) user, err := u.GetUserByID(id)
if err != nil { if err != nil {
@@ -214,9 +213,8 @@ func (u *UserDatabase) UpdateUserContact(id int, ncont string) error {
return u.db.Insert(UserTableUser, &user) return u.db.Insert(UserTableUser, &user)
} }
// GetUserByName avoids sql injection by removing ; ' " = // GetUserByName avoids sql injection by limiting username to 0-9A-Za-z
func (u *UserDatabase) GetUserByName(username string) (user User, err error) { func (u *UserDatabase) GetUserByName(username string) (user User, err error) {
username = strings.NewReplacer(";", "", "'", "", `"`, "", "=", "").Replace(username)
for _, c := range username { for _, c := range username {
if !(c >= '0' && c <= '9') && !(c >= 'A' && c <= 'Z') && !(c >= 'a' && c <= 'z') { if !(c >= '0' && c <= '9') && !(c >= 'A' && c <= 'Z') && !(c >= 'a' && c <= 'z') {
err = ErrInvalidName err = ErrInvalidName
@@ -229,9 +227,8 @@ func (u *UserDatabase) GetUserByName(username string) (user User, err error) {
return return
} }
// IsNameExists avoids sql injection by removing ; ' " = // IsNameExists avoids sql injection by limiting username to 0-9A-Za-z
func (u *UserDatabase) IsNameExists(username string) bool { func (u *UserDatabase) IsNameExists(username string) bool {
username = strings.NewReplacer(";", "", "'", "", `"`, "", "=", "").Replace(username)
for _, c := range username { for _, c := range username {
if !(c >= '0' && c <= '9') && !(c >= 'A' && c <= 'Z') && !(c >= 'a' && c <= 'z') { if !(c >= '0' && c <= '9') && !(c >= 'A' && c <= 'Z') && !(c >= 'a' && c <= 'z') {
return false return false
@@ -354,17 +351,25 @@ func (u *UserDatabase) SendMessage(m *Message) error {
} }
// NotifyRegister will send register notification to all supers // NotifyRegister will send register notification to all supers
func (u *UserDatabase) NotifyRegister(name, cont, pswd string) error { func (u *UserDatabase) NotifyRegister(ip, name, cont, pswd string) error {
if name == "" { if name == "" {
return ErrEmptyName return ErrEmptyName
} }
if cont == "" {
return ErrEmptyContact
}
if pswd == "" {
return ErrEmptyPassword
}
for _, c := range name { for _, c := range name {
if !(c >= '0' && c <= '9') && !(c >= 'A' && c <= 'Z') && !(c >= 'a' && c <= 'z') { if !(c >= '0' && c <= '9') && !(c >= 'A' && c <= 'Z') && !(c >= 'a' && c <= 'z') {
return ErrInvalidName return ErrInvalidName
} }
} }
if pswd == "" {
return ErrEmptyPassword _, err := u.GetUserByName(name)
if err == nil {
return ErrInvalidName
} }
tos, err := u.GetSuperIDs() tos, err := u.GetSuperIDs()
@@ -374,7 +379,7 @@ func (u *UserDatabase) NotifyRegister(name, cont, pswd string) error {
m := Message{ m := Message{
Date: time.Now().Unix(), Date: time.Now().Unix(),
Text: "收到来自 " + name + " 的注册请求, 联系方式: " + cont, Text: "收到来自 " + ip + ", 用户名 " + name + " 的注册请求, 联系方式: " + cont,
Name: name, Name: name,
Cont: cont, Cont: cont,
Pswd: pswd, Pswd: pswd,
@@ -424,6 +429,9 @@ func (u *UserDatabase) notifyContactChange(name, cont string) error {
if name == "" { if name == "" {
return ErrEmptyName return ErrEmptyName
} }
if cont == "" {
return ErrEmptyContact
}
tos, err := u.GetSuperIDs() tos, err := u.GetSuperIDs()
if err != nil { if err != nil {

View File

@@ -27,11 +27,17 @@ func getMessageList(token string) ([]messageList, error) {
return nil, nil return nil, nil
} }
ml := make([]messageList, len(ms)) ml := make([]messageList, len(ms))
am := make(map[string]string, 64)
for i, m := range ms { for i, m := range ms {
avtr := "" avtr := ""
u, err := global.UserDB.GetUserByName(m.Name) if a, ok := am[m.Name]; ok {
if err == nil { avtr = a
avtr = u.Avtr } else {
u, err := global.UserDB.GetUserByName(m.Name)
if err == nil {
avtr = u.Avtr
am[m.Name] = u.Avtr
}
} }
ml[i].ID = *m.ID ml[i].ID = *m.ID
ml[i].Avatar = avtr ml[i].Avatar = avtr

27
backend/register.go Normal file
View File

@@ -0,0 +1,27 @@
package backend
import (
"errors"
"time"
"github.com/FloatTech/ttl"
"github.com/fumiama/paper-manager/backend/global"
)
var registerlimit = ttl.NewCache[string, bool](time.Minute * 10)
var (
errRegisterTooFast = errors.New("register too fast")
errInvalidIP = errors.New("invalid IP")
)
func register(ip, name, mobile, npwd string) error {
if registerlimit.Get(ip) {
return errRegisterTooFast
}
if ip == "" {
return errInvalidIP
}
registerlimit.Set(ip, true)
return global.UserDB.NotifyRegister(ip, name, mobile, npwd)
}

View File

@@ -6,19 +6,10 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// IP gets ip from r.Header's X-FORWARDED-FOR or r.RemoteAddr
func IP(r *http.Request) string {
forwarded := r.Header.Get("X-FORWARDED-FOR")
if forwarded != "" {
return forwarded
}
return r.RemoteAddr
}
// IsMethod check if the method meets the requirement // IsMethod check if the method meets the requirement
// and response 405 Method Not Allowed if not matched // and response 405 Method Not Allowed if not matched
func IsMethod(m string, w http.ResponseWriter, r *http.Request) bool { func IsMethod(m string, w http.ResponseWriter, r *http.Request) bool {
logrus.Infoln("[utils.IsMethod] accept", IP(r), r.Method, r.URL) logrus.Infoln("[utils.IsMethod] accept", r.RemoteAddr, r.Method, r.URL)
if r.Method != m { if r.Method != m {
http.Error(w, "405 Method Not Allowed", http.StatusMethodNotAllowed) http.Error(w, "405 Method Not Allowed", http.StatusMethodNotAllowed)
return false return false

View File

@@ -103,7 +103,7 @@ export default [
}, },
}, },
// mock register // mock register
{ /*{
url: '/api/register', url: '/api/register',
timeout: 200, timeout: 200,
method: 'post', method: 'post',
@@ -113,7 +113,7 @@ export default [
msg: '已将用户' + username + '电话' + mobile + '的注册请求上报, 请耐心等待!', msg: '已将用户' + username + '电话' + mobile + '的注册请求上报, 请耐心等待!',
}) })
}, },
}, },*/
/*{ /*{
url: '/api/getUserInfo', url: '/api/getUserInfo',
method: 'get', method: 'get',