// validate/validate.go (245 lines, 7.0 KiB, Go)

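// Package validate extracts the configured fields from JSON source data,
// converts them to their declared types, fills in default values and checks
// them against go-playground/validator rules.
//
// Description : validate JSON input against a dynamically built struct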
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 2025-03-18 15:10
package validate
import (
"encoding/json"
"errors"
"fmt"
"git.zhangdeman.cn/zhangdeman/consts"
dynamicStructGenerate "git.zhangdeman.cn/zhangdeman/dynamic-struct"
"git.zhangdeman.cn/zhangdeman/json_filter/gjson_hack"
"git.zhangdeman.cn/zhangdeman/serialize"
"github.com/creasty/defaults"
"github.com/go-playground/validator/v10"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"strings"
)
// validatorInstance is the package-wide go-playground validator used by
// handle.Run. It reads its rules from the struct tag named by TagValidate.
var validatorInstance *validator.Validate

func init() {
	instance := validator.New()
	instance.SetTagName(TagValidate)
	validatorInstance = instance
}
// Run validates sourceData against fieldList and returns the normalized JSON.
//
// For every field it builds a struct tag (see generateTag) and records which
// target paths are parents of other target paths. Parent paths are treated as
// intermediate levels and are skipped by handle.Run. A path P counts as the
// parent of Q only when Q starts with P followed by ".", the separator of the
// sjson path syntax used for TargetPath. Sibling paths that merely share a
// textual prefix (e.g. "user" and "user_name") are therefore not confused.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 15:12 2025/3/18
func Run(sourceData []byte, fieldList []StructField) ([]byte, error) {
	handleInstance := &handle{
		sourceData:       sourceData,
		fieldList:        fieldList,
		parentFieldTable: make(map[string]bool, len(fieldList)),
	}
	tagTable := make(map[string]string, len(fieldList))
	for _, item := range fieldList {
		tagTable[item.TargetPath] = handleInstance.generateTag(item)
		// item is a parent when some other field's target path lies underneath it.
		childPrefix := item.TargetPath + "."
		for _, field := range fieldList {
			if strings.HasPrefix(field.TargetPath, childPrefix) {
				handleInstance.parentFieldTable[item.TargetPath] = true
				break
			}
		}
	}
	handleInstance.dynamicStruct = dynamicStructGenerate.NewStruct(tagTable)
	return handleInstance.Run()
}
// handle carries the state of a single validation run started by Run.
type handle struct {
	sourceData []byte        // raw JSON input that each field's SourcePath is read from
	fieldList  []StructField // field definitions processed by handle.Run

	parentFieldTable map[string]bool               // target paths that are parents of other target paths
	formatVal        string                        // JSON assembled via sjson from the converted field values
	dynamicStruct    dynamicStructGenerate.Builder // builder for the struct that formatVal is decoded into
}
// Run executes the validation. It works in five steps:
//
//  1. Convert every non-parent field from h.sourceData and write it into
//     h.formatVal.
//  2. Register each such field on the dynamic struct.
//  3. Decode h.formatVal into a new instance of that struct.
//  4. Apply `default` tag values and run the validator.
//  5. Return the struct marshalled back to JSON.
//
// Validation failures are converted by GetValidateErr using the TagErrMsg tag.
func (h *handle) Run() ([]byte, error) {
	for _, field := range h.fieldList {
		if h.parentFieldTable[field.TargetPath] {
			// Intermediate (parent) level: its children carry the data.
			continue
		}
		if len(field.Errmsg) == 0 {
			field.Errmsg = field.JsonTag + " : 参数校验不通过"
		}
		required, hasRequired := h.checkRequired(field)
		field.Required = required
		if field.Required && !hasRequired {
			// Required but no explicit required rule: add one so the validator
			// enforces it. The three-index slice forces append to copy, so the
			// caller's RuleList backing array is never written to.
			n := len(field.RuleList)
			field.RuleList = append(field.RuleList[:n:n], Rule{
				Tag:  consts.ValidatorRuleCommonRequired.String(),
				Args: nil,
			})
		}
		// Convert the source value and store it in h.formatVal.
		if _, err := h.formatDataValue(field); err != nil {
			return nil, err
		}
		fieldTag := h.generateTag(field)
		// Only the type matters here, so the zero value of the declared type is enough.
		h.dynamicStruct.AddField(field.JsonTag, "", consts.GetDataTypeDefaultValue(field.Type), fieldTag, false)
	}

	formatted := h.formatVal
	if formatted == "" {
		// No field produced a value; decode an empty object rather than failing on empty input.
		formatted = "{}"
	}
	val := h.dynamicStruct.Build().New()
	if err := serialize.JSON.UnmarshalWithNumber([]byte(formatted), &val); err != nil {
		return nil, err
	}
	// The decode above only fills the struct if New() returned a pointer to it
	// (assumption about dynamic-struct; confirm). defaults.Set needs that struct
	// pointer. &val would be a pointer to the interface, which it rejects.
	if err := defaults.Set(val); err != nil {
		return nil, err
	}
	if err := validatorInstance.Struct(val); err != nil {
		return nil, GetValidateErr(val, err, TagErrMsg)
	}
	targetByte, err := json.Marshal(val)
	if err != nil {
		return nil, err
	}
	return targetByte, nil
}
// checkRequired resolves whether field must be present.
//
// Return value 1: whether the field is required. This is true when
// field.Required is set or when RuleList already contains the common
// "required" rule.
// Return value 2: whether RuleList already contains the "required" rule.
// handle.Run appends that rule only when value 1 is true and value 2 is false.
// The previous implementation returned value 2 inverted, so the rule was never
// added when missing and was duplicated when present.
func (h *handle) checkRequired(field StructField) (bool, bool) {
	hasRequiredRule := false
	for _, rule := range field.RuleList {
		if rule.Tag == consts.ValidatorRuleCommonRequired.String() {
			hasRequiredRule = true
			break
		}
	}
	return field.Required || hasRequiredRule, hasRequiredRule
}
// formatDataValue reads field.SourcePath from h.sourceData and converts the
// value to field.Type. When the conversion yields a non-nil value, it writes
// that value into h.formatVal at field.TargetPath.
//
// If the source value is missing:
//   - a required field returns an error;
//   - an optional field with a DefaultValue uses that default, presented to
//     the converters as a gjson string result;
//   - an optional field without a default is skipped and (nil, nil) is
//     returned.
//
// Errors from the type conversion and from sjson.Set are returned to the caller.
func (h *handle) formatDataValue(field StructField) (any, error) {
	sourceValue := gjson.GetBytes(h.sourceData, field.SourcePath)
	if !sourceValue.Exists() {
		if field.Required {
			return nil, errors.New(field.SourcePath + " is required")
		}
		if len(field.DefaultValue) == 0 {
			// Optional with no default: treat the field as absent.
			return nil, nil
		}
		// Optional with a default: use the default as if the source held that string.
		sourceValue = gjson.Result{
			Type: gjson.String,
			Raw:  field.DefaultValue,
			Str:  field.DefaultValue,
		}
	}

	var (
		val any
		err error
	)
	switch field.Type {
	case consts.DataTypeInt, consts.DataTypeIntPtr:
		val, err = gjson_hack.Int(sourceValue)
	case consts.DataTypeUint, consts.DataTypeUintPtr:
		val, err = gjson_hack.Uint(sourceValue)
	case consts.DataTypeFloat32, consts.DataTypeFloat32Ptr:
		// Parsed through the float64 helper.
		val, err = gjson_hack.Float64(sourceValue)
	case consts.DataTypeString:
		val = sourceValue.String()
	case consts.DataTypeBool:
		val = sourceValue.Bool()
	case consts.DataTypeSliceFloat:
		val, err = gjson_hack.SliceFloat(sourceValue)
	case consts.DataTypeSliceInt:
		val, err = gjson_hack.SliceInt(sourceValue)
	case consts.DataTypeSliceUint:
		val, err = gjson_hack.SliceUint(sourceValue)
	case consts.DataTypeSliceString:
		val, err = gjson_hack.SliceString(sourceValue)
	case consts.DataTypeSliceBool:
		val, err = gjson_hack.SliceBool(sourceValue)
	case consts.DataTypeMapStrAny:
		val, err = gjson_hack.MapStrAny[any](sourceValue)
	default:
		// Any other type keeps gjson's generic Go representation.
		val = sourceValue.Value()
	}
	if err != nil {
		return val, err
	}
	if val != nil {
		// Record the converted value in the formatted document; propagate sjson failures.
		updated, setErr := sjson.Set(h.formatVal, field.TargetPath, val)
		if setErr != nil {
			return nil, setErr
		}
		h.formatVal = updated
	}
	return val, nil
}
// generateTag builds the struct tag for field. The tag contains:
//   - the json name;
//   - the error message (TagErrMsg), defaulted when Errmsg is empty;
//   - the validator rules (TagValidate), formatted as "tag" or
//     "tag=arg1 arg2", joined by ",";
//   - the default value (TagDefaultValue).
//
// Each value is written with %q (Go string-literal quoting, as
// reflect.StructTag expects). Quotes or backslashes inside a message, rule
// argument or default therefore cannot corrupt the tag. For ordinary text the
// output is identical to plain "%s" quoting.
func (h *handle) generateTag(field StructField) string {
	if len(field.Errmsg) == 0 {
		field.Errmsg = field.JsonTag + ": 参数校验不通过"
	}
	ruleList := make([]string, 0, len(field.RuleList))
	for _, rule := range field.RuleList {
		if len(rule.Args) == 0 {
			ruleList = append(ruleList, rule.Tag)
			continue
		}
		ruleList = append(ruleList, fmt.Sprintf("%s=%s", rule.Tag, strings.Join(rule.Args, " ")))
	}
	// A "-" default on string fields is emitted as the empty string.
	if field.DefaultValue == "-" && (field.Type == consts.DataTypeString || field.Type == consts.DataTypeStringPtr) {
		field.DefaultValue = ""
	}
	tagList := []string{
		fmt.Sprintf(`json:%q`, field.JsonTag),                             // json tag
		fmt.Sprintf(`%s:%q`, TagErrMsg, field.Errmsg),                     // error message tag
		fmt.Sprintf(`%s:%q`, TagValidate, strings.Join(ruleList, ",")),    // validation rule tag
		fmt.Sprintf(`%s:%q`, TagDefaultValue, field.DefaultValue),         // default value tag
	}
	return strings.Join(tagList, " ")
}