diff --git a/apps/user.go b/apps/user.go index c9ef1fd..4b9d96e 100644 --- a/apps/user.go +++ b/apps/user.go @@ -21,16 +21,19 @@ func InitUserGroup(g *echo.Group) { } ctx := context.WithValue(context.Background(), constant.DBKey, repository.Db) + platform := utils.GetPlatform(&c) + req.Platform = platform - err := service.Login.Register(ctx, *req) + err := service.UserBiz.Register(ctx, *req) if err != nil { return c.JSON(http.StatusOK, utils.Error(err)) } //登录 - token, _ := service.Login.Login(ctx, dto.LoginReq{ + token, _ := service.UserBiz.Login(ctx, dto.LoginReq{ Account: req.Account, Password: req.Password, + Platform: platform, }) return c.JSON(http.StatusOK, utils.Ok(token)) @@ -44,8 +47,9 @@ func InitUserGroup(g *echo.Group) { } ctx := context.WithValue(context.Background(), constant.DBKey, repository.Db) - - token, err := service.Login.Login(ctx, *req) + //设置platform + req.Platform = utils.GetPlatform(&c) + token, err := service.UserBiz.Login(ctx, *req) if err != nil { return c.JSON(http.StatusOK, utils.Error(err)) } @@ -60,7 +64,7 @@ func InitUserGroup(g *echo.Group) { return c.JSON(http.StatusOK, utils.Error(err)) } - err := service.Login.SendResetPwdCode(context.Background(), req.Account) + err := service.UserBiz.SendResetPwdCode(context.Background(), req.Email) if err != nil { return c.JSON(http.StatusOK, utils.Error(err)) } @@ -77,7 +81,7 @@ func InitUserGroup(g *echo.Group) { ctx := context.WithValue(context.Background(), constant.DBKey, repository.Db) - err := service.Login.ResetPwd(ctx, *req) + err := service.UserBiz.ResetPwd(ctx, *req) if err != nil { return c.JSON(http.StatusOK, utils.Error(err)) } @@ -88,7 +92,10 @@ func InitUserGroup(g *echo.Group) { //根据 token 获取 登录信息 g.POST("findLoginResult", func(c echo.Context) error { ctx := context.WithValue(context.Background(), constant.DBKey, repository.Db) - loginResult := service.Login.GetLoginResult(ctx, &c) + loginResult, err := service.UserBiz.GetLoginResult(ctx, &c) + if err != nil { + return c.JSON(http.StatusOK, utils.Error(err)) + } return c.JSON(http.StatusOK, utils.Ok(loginResult)) }) } diff --git a/common/config/config.dev.yaml b/common/config/config.dev.yaml index 07c7228..506abfb 100644 --- a/common/config/config.dev.yaml +++ b/common/config/config.dev.yaml @@ -1,6 +1,7 @@ # 本机配置 -fbServer: +myServer: appName: mylomen_server + jwtSigningKey: t2vr8fyqh5mvfvrjdszb6ev3eas4a9kw # fbConsul fbConsul: diff --git a/common/config/config.prod.yaml b/common/config/config.prod.yaml index 5c4bb4a..a97fc90 100644 --- a/common/config/config.prod.yaml +++ b/common/config/config.prod.yaml @@ -1,6 +1,7 @@ # 本机配置 -fbServer: +myServer: appName: mylomen_server + jwtSigningKey: t6p4n4g79mceqbe2b6syq4k5ahupy2h9 # fbConsul fbConsul: diff --git a/common/config/init.go b/common/config/init.go index 9f18646..3bfa07d 100644 --- a/common/config/init.go +++ b/common/config/init.go @@ -42,6 +42,11 @@ type Redis struct { WriteTimeout int64 `mapstructure:"writeTimeout" json:"writeTimeout" yaml:"writeTimeout"` } +type MyServer struct { + AppName string `mapstructure:"appName" json:"appName" yaml:"appName"` + JwtSigningKey string `mapstructure:"jwtSigningKey" json:"jwtSigningKey" yaml:"jwtSigningKey"` +} + type Conf struct { FbConsul *FbConsul `mapstructure:"fbConsul" json:"fbConsul" yaml:"fbConsul"` // fbConsul 配置 @@ -50,6 +55,8 @@ type Conf struct { PgSql *PgSql `mapstructure:"pgSql" json:"pgSql" yaml:"pgSql"` // pgSql 配置 Redis *Redis `mapstructure:"redis" json:"redis" yaml:"redis"` // redis 配置 + + MyServer *MyServer `mapstructure:"myServer" json:"myServer" yaml:"myServer"` // myServer 配置 } var Instance = initCf() diff --git a/common/constant/redis.go b/common/constant/redis.go index 78fbc61..d1625be 100644 --- a/common/constant/redis.go +++ b/common/constant/redis.go @@ -3,5 +3,7 @@ package constant const ( THIRD_LOGIN_TOKEN = ":login:token:" + THIRD_LOGIN_SN = ":login:sn:" + G_RESET_PWD_CODE = "g:reset:pwd:email:" ) diff --git a/common/dto/third_user_login.go b/common/dto/third_user_login.go index d834e35..bb049b2 100644 --- a/common/dto/third_user_login.go +++ b/common/dto/third_user_login.go @@ -8,7 +8,7 @@ type RegisterReq struct { } type SendResetPwdCodeReq struct { - Account string `json:"account"` // 帐号 + Email string `json:"email"` // 帐号 } type ResetPwdReq struct { @@ -20,9 +20,12 @@ type ResetPwdReq struct { } type LoginReq struct { - Account *string `json:"account"` // 帐号 - Email *string `json:"email"` // 邮箱 - Phone *string `json:"phone"` // 手机号码 + Account *string `json:"account"` // 帐号 + Password string `json:"password"` // 密码 + Platform string `json:"platform"` // 平台 + + Email *string `json:"email"` // 邮箱 + Phone *string `json:"phone"` // 手机号码 WxId *string `json:"wxId"` // 微信unionId QqId *string `json:"qqId"` // qqId @@ -30,7 +33,6 @@ type LoginReq struct { GoogleId *string `json:"googleId"` // googleId FacebookId *string `json:"facebookId"` // facebookId - Password string `json:"password"` // 密码 } type UserLoginToken struct { diff --git a/common/dto/user_vo.go b/common/dto/user_vo.go index cec2172..815e335 100644 --- a/common/dto/user_vo.go +++ b/common/dto/user_vo.go @@ -1,10 +1,51 @@ package dto -type UserVO struct { - Sn string `json:"sn"` - Token string `json:"token"` - Account *string `json:"Account"` +import ( + "github.com/golang-jwt/jwt/v5" + "time" +) +type UserVO struct { + Sn string `json:"sn"` + Token string `json:"token"` + Platform string `json:"platform"` + + Account *string `json:"Account"` Nickname *string `json:"nickname"` Avatar *string `json:"avatar"` } + +type LoginTokenVo struct { + jwt.RegisteredClaims + + Sn string `json:"sn"` + Token string `json:"token"` + Platform string `json:"platform"` +} + +func (vo *LoginTokenVo) SetExtByPlatform(platform string) *LoginTokenVo { + switch platform { + case "web": + { + vo.Platform = "web" + vo.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 7)) + } + case "android": + { + vo.Platform = "android" + vo.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 30)) + } + case "ios": + { + vo.Platform = "ios" + vo.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 30)) + } + default: + { + vo.Platform = "web" + vo.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour * 7 * 30)) + } + } + + return vo +} diff --git a/common/utils/httpHeaderUtil.go b/common/utils/httpHeaderUtil.go index 2c12b93..5340554 100644 --- a/common/utils/httpHeaderUtil.go +++ b/common/utils/httpHeaderUtil.go @@ -12,3 +12,7 @@ func GetAccessToken(c *echo.Context) string { return (*c).Request().Header.Get("udt") } + +func GetPlatform(c *echo.Context) string { + return (*c).Request().Header.Get("platform") +} diff --git a/common/xjwt/jwt.go b/common/xjwt/jwt.go new file mode 100644 index 0000000..84adaa0 --- /dev/null +++ b/common/xjwt/jwt.go @@ -0,0 +1,36 @@ +package xjwt + +import ( + "github.com/golang-jwt/jwt/v5" + "mylomen_server/common/config" + "mylomen_server/common/dto" + "time" +) + +func GenJwtToken(claims dto.LoginTokenVo) string { + mySigningKey := []byte(config.Instance.MyServer.JwtSigningKey) + + claims.Issuer = "mylomen.com" + claims.IssuedAt = jwt.NewNumericDate(time.Now()) + + //加密 + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + ss, _ := token.SignedString(mySigningKey) + return ss +} + +func ParseJwtToken(jwtToken string) (*dto.LoginTokenVo, error) { + if jwtToken == "" || len(jwtToken) > 256 { + return nil, nil + } + //parse + parseToken, err := jwt.ParseWithClaims(jwtToken, &dto.LoginTokenVo{}, func(token *jwt.Token) (interface{}, error) { + return []byte(config.Instance.MyServer.JwtSigningKey), nil + }) + //check + if userClaims, ok := parseToken.Claims.(*dto.LoginTokenVo); ok && parseToken.Valid { + return userClaims, nil + } else { + return nil, err + } +} diff --git a/common/xjwt/jwt_test.go b/common/xjwt/jwt_test.go new file mode 100644 index 0000000..ed5b3fa --- /dev/null +++ b/common/xjwt/jwt_test.go @@ -0,0 +1,109 @@ +package xjwt + +import ( + "fmt" + "github.com/golang-jwt/jwt/v5" + "mylomen_server/common/dto" + "testing" + "time" +) + +func TestGenAndParse(t *testing.T) { + mySigningKey := []byte("AllYourBase") + + claims := dto.LoginTokenVo{ + Sn: "mySn", + Token: "myToken", + + RegisteredClaims: jwt.RegisteredClaims{ + Audience: []string{"mylomen.com"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: "mylomen.com", + NotBefore: jwt.NewNumericDate(time.Now()), + Subject: "1", + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + ss, err := token.SignedString(mySigningKey) + fmt.Println(ss, err) + + parseToken, err := jwt.ParseWithClaims(ss, &dto.LoginTokenVo{}, func(token *jwt.Token) (interface{}, error) { + return mySigningKey, nil + }) + + if userClaims, ok := parseToken.Claims.(*dto.LoginTokenVo); ok && parseToken.Valid { + t.Log(userClaims, userClaims.RegisteredClaims.Issuer) + } else { + fmt.Println(err) + t.Error("验证失败") + } +} + +func TestHs256(t *testing.T) { + type User struct { + Id int64 + Name string + } + type UserClaims struct { + User User + jwt.RegisteredClaims + } + // 1 jwt.NewWithClaims生成token + user := User{ + Id: 101, + Name: "hisheng", + } + userClaims := UserClaims{ + User: user, + RegisteredClaims: jwt.RegisteredClaims{}, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, userClaims) + + // 2 把token加密 + mySigningKey := []byte("ushjlwmwnwht") + ss, err := token.SignedString(mySigningKey) + t.Log(ss, err) +} + +func TestHs256Parse(t *testing.T) { + tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJVc2VyIjp7IklkIjoxMDEsIk5hbWUiOiJoaXNoZW5nIn19.ij1kWID03f_CiELe0fPLZJ-Y64dkf2nDE-f6nGERBSE" + + type User struct { + Id int64 + Name string + } + type UserClaims struct { + User User + jwt.RegisteredClaims + } + + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte("ushjlwmwnwht"), nil + }) + + if userClaims, ok := token.Claims.(*UserClaims); ok && token.Valid { + t.Log(userClaims, userClaims.RegisteredClaims.Issuer) + } else { + t.Log(err) + } +} + +func TestGenJwtTokenThenParse(t *testing.T) { + claims := dto.LoginTokenVo{} + claims.Sn = "mySn" + claims.Token = "myToken" + claims.RegisteredClaims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(24 * time.Hour)) + jwtToken := GenJwtToken(claims) + + //parse + data, err := ParseJwtToken(jwtToken) + if err != nil { + t.Error(err) + return + } + if data.Sn != claims.Sn { + t.Error("Sn 不一致") + } +} diff --git a/go.mod b/go.mod index 6f8ff82..c34e4be 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 // indirect diff --git a/go.sum b/go.sum index 6e7bb69..216b5dc 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= diff --git a/infrastructure/repository/user_base.go b/infrastructure/repository/user_base.go index 53e689b..d71f653 100644 --- a/infrastructure/repository/user_base.go +++ b/infrastructure/repository/user_base.go @@ -8,7 +8,10 @@ import ( type gUserRepository interface { FindByReq(ctx context.Context, req dto.LoginReq) *UserDO + //FindById 根据 id 查询用户 FindById(ctx context.Context, id int64) *UserDO + FindBySn(ctx context.Context, sn string) *UserDO + //FindByAccount 根据 account 查询用户 FindByAccount(ctx context.Context, account string) *UserDO FindByEmail(ctx context.Context, email string) *UserDO @@ -25,4 +28,4 @@ type gUserRepository interface { UpdateById(ctx context.Context, data *UserDO) error } -var GUser gUserRepository = new(gUserRepositoryImpl) +var GUser gUserRepository = new(userRepositoryImpl) diff --git a/infrastructure/repository/user_base_repository.go b/infrastructure/repository/user_base_repository.go index 34e3e98..6dbf25a 100644 --- a/infrastructure/repository/user_base_repository.go +++ b/infrastructure/repository/user_base_repository.go @@ -43,10 +43,10 @@ func (m *UserDO) TableName() string { return "user_base" } -type gUserRepositoryImpl struct { +type userRepositoryImpl struct { } -func (rp *gUserRepositoryImpl) FindByReq(ctx context.Context, req dto.LoginReq) *UserDO { +func (rp *userRepositoryImpl) FindByReq(ctx context.Context, req dto.LoginReq) *UserDO { if req.Account != nil { return rp.FindByAccount(ctx, *req.Account) } @@ -73,7 +73,7 @@ func (rp *gUserRepositoryImpl) FindByReq(ctx context.Context, req dto.LoginReq) return nil } -func (rp *gUserRepositoryImpl) FindById(ctx context.Context, id int64) *UserDO { +func (rp *userRepositoryImpl) FindById(ctx context.Context, id int64) *UserDO { //验证参数 if id <= 0 { return nil @@ -90,7 +90,24 @@ func (rp *gUserRepositoryImpl) FindById(ctx context.Context, id int64) *UserDO { return &model } -func (rp *gUserRepositoryImpl) FindByAccount(ctx context.Context, account string) *UserDO { +func (rp *userRepositoryImpl) FindBySn(ctx context.Context, sn string) *UserDO { + //验证参数 + if sn == "" || len(sn) > 16 { + return nil + } + db, ok := ctx.Value("db").(*gorm.DB) + if !ok { + db = Db + } + var model UserDO + sqlErr := db.Model(&model).Where("sn=?", sn).First(&model).Error + if sqlErr != nil { + return nil + } + return &model +} + +func (rp *userRepositoryImpl) FindByAccount(ctx context.Context, account string) *UserDO { //验证参数 if account == "" || len(account) > 32 { return nil @@ -107,7 +124,7 @@ func (rp *gUserRepositoryImpl) FindByAccount(ctx context.Context, account string return &model } -func (rp *gUserRepositoryImpl) FindByEmail(ctx context.Context, email string) *UserDO { +func (rp *userRepositoryImpl) FindByEmail(ctx context.Context, email string) *UserDO { //验证参数 if email == "" || len(email) > 64 { return nil @@ -124,7 +141,7 @@ func (rp *gUserRepositoryImpl) FindByEmail(ctx context.Context, email string) *U return &model } -func (rp *gUserRepositoryImpl) FindByPhone(ctx context.Context, phone string) *UserDO { +func (rp *userRepositoryImpl) FindByPhone(ctx context.Context, phone string) *UserDO { //验证参数 if phone == "" || len(phone) > 16 { return nil @@ -141,7 +158,7 @@ func (rp *gUserRepositoryImpl) FindByPhone(ctx context.Context, phone string) *U return &model } -func (rp *gUserRepositoryImpl) FindByWxId(ctx context.Context, wxId string) *UserDO { +func (rp *userRepositoryImpl) FindByWxId(ctx context.Context, wxId string) *UserDO { //验证参数 if wxId == "" || len(wxId) > 16 { return nil @@ -158,7 +175,7 @@ func (rp *gUserRepositoryImpl) FindByWxId(ctx context.Context, wxId string) *Use return &model } -func (rp *gUserRepositoryImpl) FindByQqId(ctx context.Context, qqId string) *UserDO { +func (rp *userRepositoryImpl) FindByQqId(ctx context.Context, qqId string) *UserDO { //验证参数 if qqId == "" || len(qqId) > 32 { return nil @@ -175,7 +192,7 @@ func (rp *gUserRepositoryImpl) FindByQqId(ctx context.Context, qqId string) *Use return &model } -func (rp *gUserRepositoryImpl) FindByGoogleId(ctx context.Context, googleId string) *UserDO { +func (rp *userRepositoryImpl) FindByGoogleId(ctx context.Context, googleId string) *UserDO { //验证参数 if googleId == "" || len(googleId) > 32 { return nil @@ -192,7 +209,7 @@ func (rp *gUserRepositoryImpl) FindByGoogleId(ctx context.Context, googleId stri return &model } -func (rp *gUserRepositoryImpl) FindByFacebookId(ctx context.Context, facebookId string) *UserDO { +func (rp *userRepositoryImpl) FindByFacebookId(ctx context.Context, facebookId string) *UserDO { //验证参数 if facebookId == "" || len(facebookId) > 32 { return nil @@ -210,7 +227,7 @@ func (rp *gUserRepositoryImpl) FindByFacebookId(ctx context.Context, facebookId } // Create 创建用户 -func (rp *gUserRepositoryImpl) Create(ctx context.Context, data *UserDO) error { +func (rp *userRepositoryImpl) Create(ctx context.Context, data *UserDO) error { db, ok := ctx.Value("db").(*gorm.DB) if !ok { db = Db @@ -225,7 +242,7 @@ func (rp *gUserRepositoryImpl) Create(ctx context.Context, data *UserDO) error { } // UpdateById 根据ID更新用户 -func (rp *gUserRepositoryImpl) UpdateById(ctx context.Context, data *UserDO) error { +func (rp *userRepositoryImpl) UpdateById(ctx context.Context, data *UserDO) error { db, ok := ctx.Value("db").(*gorm.DB) if !ok { db = Db diff --git a/service/login.go b/service/user.go similarity index 61% rename from service/login.go rename to service/user.go index 391aae1..89a8d32 100644 --- a/service/login.go +++ b/service/user.go @@ -4,7 +4,6 @@ import ( "context" "crypto/sha256" "encoding/hex" - "encoding/json" "errors" "github.com/google/uuid" "github.com/labstack/echo/v4" @@ -14,6 +13,7 @@ import ( "mylomen_server/common/email" "mylomen_server/common/logs" "mylomen_server/common/utils" + "mylomen_server/common/xjwt" "mylomen_server/infrastructure/convert" "mylomen_server/infrastructure/redis" "mylomen_server/infrastructure/repository" @@ -21,22 +21,27 @@ import ( "time" ) -type login struct { +type user struct { } -var Login login +var UserBiz user // 令牌桶大小为 100, 以每秒 10 个 Token 的速率向桶中放置 Token var limiter = rate.NewLimiter(10, 10) // Register 注册 -func (l login) Register(ctx context.Context, req dto.RegisterReq) error { +func (l user) Register(ctx context.Context, req dto.RegisterReq) error { acUser := repository.GUser.FindByReq(ctx, req.LoginReq) if acUser != nil && acUser.Deleted == false { return errors.New("user exist") } + //fill data userDO := convert.UserReq2DO(req) + userDO.Deleted = false + userDO.UpdateTime = time.Now() + userDO.NickName = req.NickName + userDO.Avatar = req.Avatar //密码加密 h := sha256.Sum256([]byte(req.Password)) @@ -46,7 +51,7 @@ func (l login) Register(ctx context.Context, req dto.RegisterReq) error { return repository.GUser.Create(ctx, &userDO) } -func (l login) Login(ctx context.Context, req dto.LoginReq) (*dto.UserVO, error) { +func (l user) Login(ctx context.Context, req dto.LoginReq) (*dto.UserVO, error) { start := time.Now().UnixMilli() //1. 查询用户 @@ -70,20 +75,22 @@ func (l login) Login(ctx context.Context, req dto.LoginReq) (*dto.UserVO, error) //3. 组装数据 result := convert.UserDO2VO(*acUser) + result.Platform = req.Platform - //生成token todo 使用jwt - token := uuid.New().String() - token = strings.ReplaceAll(token, "-", "") - result.Token = token - + //生成token + token := strings.ReplaceAll(uuid.New().String(), "-", "") + claims := dto.LoginTokenVo{Sn: result.Sn, Token: token} + //设置超时时间 + claims.SetExtByPlatform(req.Platform) + result.Token = xjwt.GenJwtToken(claims) diff := time.Now().UnixMilli() - start logs.NewLog("").Infof("Login cost: %d", diff) return &result, nil } -func (l login) SendResetPwdCode(ctx context.Context, account string) error { +func (l user) SendResetPwdCode(ctx context.Context, emailP string) error { //1. 验证账号 - acUser := repository.GUser.FindByAccount(ctx, account) + acUser := repository.GUser.FindByEmail(ctx, emailP) if acUser == nil || acUser.Deleted == true { return errors.New("user not exist") } @@ -92,7 +99,7 @@ func (l login) SendResetPwdCode(ctx context.Context, account string) error { code := utils.GetPseudoRandomCode(6) //3. save into redis - if err := redis.Set(constant.G_RESET_PWD_CODE+account, code, time.Duration(30)*time.Minute); err != nil { + if err := redis.Set(constant.G_RESET_PWD_CODE+emailP, code, time.Duration(30)*time.Minute); err != nil { return errors.New("cache illegal") } @@ -102,7 +109,7 @@ func (l login) SendResetPwdCode(ctx context.Context, account string) error { } //4. send code - if err := email.SendEmailVerifyCodeByEmail(code, account); err != nil { + if err := email.SendEmailVerifyCodeByEmail(code, emailP); err != nil { return errors.New("send email code illegal") } @@ -110,7 +117,7 @@ func (l login) SendResetPwdCode(ctx context.Context, account string) error { } // ResetPwd 重置密码 -func (l login) ResetPwd(ctx context.Context, req dto.ResetPwdReq) error { +func (l user) ResetPwd(ctx context.Context, req dto.ResetPwdReq) error { //1. 验证code redisStr, err := redis.Get(constant.G_RESET_PWD_CODE + req.Account) if err != nil && redisStr != req.Code { @@ -127,29 +134,29 @@ func (l login) ResetPwd(ctx context.Context, req dto.ResetPwdReq) error { h := sha256.Sum256([]byte(req.Password)) passHash := hex.EncodeToString(h[:]) acUser.Pwd = &passHash - repository.GUser.UpdateById(ctx, acUser) - - //生成token - token := uuid.New().String() - token = strings.ReplaceAll(token, "-", "") + if updateErr := repository.GUser.UpdateById(ctx, acUser); updateErr != nil { + logs.NewLog("").Errorf("update user password req:%+v error: %+v", req, updateErr) + return errors.New("system error") + } return nil } -func (l login) GetLoginResult(ctx context.Context, c *echo.Context) *dto.UserVO { +func (l user) GetLoginResult(ctx context.Context, c *echo.Context) (*dto.UserVO, error) { accessToken := utils.GetAccessToken(c) - if accessToken == "" { - return nil + loginVo, err := xjwt.ParseJwtToken(accessToken) + if err != nil || loginVo.Sn == "" { + return nil, errors.New("token illegal") } - //redis - redisStr, err := redis.Get(constant.THIRD_LOGIN_TOKEN + accessToken) - if err != nil && redisStr != "" { - var loginInfo dto.UserLoginToken - if redisErr := json.Unmarshal([]byte(redisStr), &loginInfo); redisErr == nil { - return nil - } + //查询数据库 + userDO := repository.GUser.FindBySn(ctx, loginVo.Sn) + if err != nil || userDO == nil { + return nil, errors.New("user not exist") } - return nil + result := convert.UserDO2VO(*userDO) + result.Token = accessToken + result.Platform = loginVo.Platform + return utils.ToPtr(result), nil } diff --git a/service/login_test.go b/service/user_test.go similarity index 85% rename from service/login_test.go rename to service/user_test.go index 5a75ee1..8ce88b5 100644 --- a/service/login_test.go +++ b/service/user_test.go @@ -12,3 +12,7 @@ func Test1(t *testing.T) { passHash := hex.EncodeToString(h[:]) fmt.Print(passHash) } + +func TestLogin(t *testing.T) { + +}