From bc5a8afd6c28b152115b09b7eb8b592b850e0d38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E8=8C=B6=E6=B8=85=E6=AC=A2?= Date: Thu, 13 Feb 2025 16:03:31 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=82=E6=95=B0=E5=A2=9E=E5=8A=A0=E5=BF=85?= =?UTF-8?q?=E4=BC=A0=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generate.go | 25 +++++++++++----------- parser_test.go | 6 +++--- struct_field.go | 15 ++++++++++++++ validateRule.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 16 deletions(-) create mode 100644 validateRule.go diff --git a/generate.go b/generate.go index d29c6e7..10c0ff7 100644 --- a/generate.go +++ b/generate.go @@ -278,19 +278,13 @@ func (g *Generate) AddComponentsSchema(rootSchemaName string, pkgPath string, in schemaName := strings.ReplaceAll(pkgPath+"."+inputName, "/", "-") if _, exist := g.docData.Components.Schemas[schemaName]; !exist { s := &define.Schema{ - Nullable: false, - Discriminator: nil, - ReadOnly: false, - WriteOnly: false, - Xml: nil, - ExternalDocs: nil, - Example: "", - Deprecated: false, - Properties: make(map[string]*define.Property), - Required: make([]string, 0), - Enum: make([]any, 0), - Type: consts.SwaggerDataTypeObject, - Ref: g.getSchemaRef(schemaName), + Nullable: false, + Deprecated: false, + Properties: make(map[string]*define.Property), + Required: make([]string, 0), + Enum: make([]any, 0), + Type: consts.SwaggerDataTypeObject, // TODO : 区分数组 + Ref: g.getSchemaRef(schemaName), } if len(rootSchemaName) == 0 || inputType.Kind() == reflect.Struct { s.Ref = "" @@ -328,6 +322,11 @@ func (g *Generate) AddComponentsSchema(rootSchemaName string, pkgPath string, in } // g.docData.Components.Schemas[schemaName].Ref = consts.SwaggerDataTypeObject for i := 0; i < inputType.NumField(); i++ { + if ValidateRule.IsRequired(inputType.Field(i)) { + // 必传字段 + g.docData.Components.Schemas[schemaName].Required = append(g.docData.Components.Schemas[schemaName].Required, ParseStructField.GetParamName(inputType.Field(i))) + } + if inputType.Field(i).Type.Kind() == reflect.Ptr || inputType.Field(i).Type.Kind() == reflect.Struct || inputType.Field(i).Type.Kind() == reflect.Map || diff --git a/parser_test.go b/parser_test.go index 1caf209..25c293d 100644 --- a/parser_test.go +++ b/parser_test.go @@ -24,12 +24,12 @@ import ( // Date : 17:55 2024/7/19 func Test_parser_Openapi3(t *testing.T) { type User struct { - Name string `json:"name" d:"zhang" desc:"用户姓名"` - Age int `json:"age" d:"18" desc:"年龄"` + Name string `json:"name" d:"zhang" desc:"用户姓名" binding:"required"` + Age int `json:"age" d:"18" desc:"年龄" binding:"required"` } type List struct { Total int64 `json:"total"` - UserList []User `json:"user_list"` + UserList []User `json:"user_list" binding:"required"` } var l List g := NewOpenapiDoc(nil, nil) diff --git a/struct_field.go b/struct_field.go index e672641..1d4b977 100644 --- a/struct_field.go +++ b/struct_field.go @@ -71,3 +71,18 @@ func (psf parseStructField) GetDefaultValue(structField reflect.StructField) str } return "" } + +// GetValidateRule 获取验证规则 +// +// Author : go_developer@163.com<白茶清欢> +// +// Date : 15:30 2025/2/13 +func (psf parseStructField) GetValidateRule(structField reflect.StructField) string { + defaultTagList := []string{define.TagValidate, define.TagBinding} + for _, tag := range defaultTagList { + if tagVal, exist := structField.Tag.Lookup(tag); exist && len(tagVal) > 0 { + return tagVal + } + } + return "" +} diff --git a/validateRule.go b/validateRule.go new file mode 100644 index 0000000..92bef26 --- /dev/null +++ b/validateRule.go @@ -0,0 +1,55 @@ +// Package api_doc ... +// +// Description : api_doc ... +// +// Author : go_developer@163.com<白茶清欢> +// +// Date : 2025-02-13 15:26 +package api_doc + +import ( + "git.zhangdeman.cn/zhangdeman/consts" + "reflect" + "strings" +) + +var ( + ValidateRule = validateRule{} +) + +type validateRule struct{} + +// IsRequired 判断是否必传 +// +// Author : go_developer@163.com<白茶清欢> +// +// Date : 15:32 2025/2/13 +func (r validateRule) IsRequired(structField reflect.StructField) bool { + ruleTable := r.getValidateRuleTable(structField) + _, exist := ruleTable[consts.ValidatorRuleCommonRequired.String()] + // 存在即为必传 + return exist +} + +// getValidateRuleTable 解析验证规则表 +// +// Author : go_developer@163.com<白茶清欢> +// +// Date : 15:29 2025/2/13 +func (r validateRule) getValidateRuleTable(structField reflect.StructField) map[string]string { + res := map[string]string{} + ruleStr := ParseStructField.GetValidateRule(structField) + if len(ruleStr) == 0 { + return res + } + expressList := strings.Split(ruleStr, ",") + for _, item := range expressList { + if strings.Contains(item, "=") { + arr := strings.Split(item, "=") + res[strings.TrimSpace(arr[0])] = strings.Join(arr[1:], "=") + } else { + res[strings.TrimSpace(item)] = "" + } + } + return res +}