diff --git a/openapi/generate.go b/openapi/generate.go index 23c2219..abd13d2 100644 --- a/openapi/generate.go +++ b/openapi/generate.go @@ -29,7 +29,7 @@ var ( func NewOpenApiDoc(optionFunc ...OptionFunc) *openapi3.T { t := &openapi3.T{ Extensions: map[string]any{}, - OpenAPI: "3.1.0", + OpenAPI: "3.0.3", Components: &openapi3.Components{ Extensions: map[string]any{}, Origin: &openapi3.Origin{ @@ -124,7 +124,13 @@ func NewOpenApiDoc(optionFunc ...OptionFunc) *openapi3.T { // Generate 生成 OpenApi 标准规范的文档 type Generate struct { - docTable map[string]*openapi3.T + docTable map[string]*openapi3.T + enableRedundantStorageComponents bool // 冗余存储 +} + +// EnableRedundantStorageComponents 开启冗余存储 +func (g *Generate) EnableRedundantStorageComponents() { + g.enableRedundantStorageComponents = true } // DocData 获取一个文档数据 @@ -141,10 +147,9 @@ func (g *Generate) NewOpenApiDoc(docFlag string, docOption ...OptionFunc) *opena return t } -// AddApiDoc 添加接口文档 -func (g *Generate) AddApiDoc(docFlag string, apiMeta define.UriConfig, request any, response any) error { +// formatType 输入输出数据类型, 转化为统一reflect.Type +func (g *Generate) formatType(request any, response any) (reflect.Type, reflect.Type) { var ( - err error requestType reflect.Type responseType reflect.Type ok bool @@ -164,63 +169,106 @@ func (g *Generate) AddApiDoc(docFlag string, apiMeta define.UriConfig, request a responseType = responseType.Elem() } } + return requestType, responseType +} - schemaData := GenerateOpenAPISchema(requestType) - apiOperate, isRead := g.initApiConfig(docFlag, apiMeta) - requestTypeStr := requestType.String() - if isRead { - for paramName, paramConfig := range schemaData.Value.Properties { - apiOperate.Parameters = append(apiOperate.Parameters, &openapi3.ParameterRef{ - Extensions: nil, - Origin: nil, - Ref: "", - Value: &openapi3.Parameter{ - Extensions: nil, - Origin: nil, - Name: paramName, - In: strings.ToLower(consts.RequestDataLocationQuery.String()), - Description: paramConfig.Value.Description, - Style: "", - Explode: nil, - AllowEmptyValue: paramConfig.Value.AllowEmptyValue, - AllowReserved: false, - Deprecated: false, - Required: op_array.ArrayType(paramConfig.Value.Required).Has(paramName) >= 0, - Schema: paramConfig, - Example: nil, - Examples: nil, - Content: nil, - }, - }) - } - } else { - apiOperate.RequestBody = &openapi3.RequestBodyRef{ +// getComponentsSchemaRef 获取组件 schema ref +func (g *Generate) getComponentsSchemaRef(structKindString string) string { + return "#/components/schemas/" + strings.TrimLeft(structKindString, "*") +} + +// setReadRequestParameter 设置读请求参数 +func (g *Generate) setReadRequestParameter(apiOperate *openapi3.Operation, schemaData *openapi3.SchemaRef) { + if nil == schemaData { + return + } + for paramName, paramConfig := range schemaData.Value.Properties { + apiOperate.Parameters = append(apiOperate.Parameters, &openapi3.ParameterRef{ Extensions: nil, Origin: nil, Ref: "", - Value: &openapi3.RequestBody{ - Extensions: nil, - Origin: nil, - Description: "", - Required: false, - Content: map[string]*openapi3.MediaType{ - consts.MimeTypeJson: { + Value: &openapi3.Parameter{ + Extensions: nil, + Origin: nil, + Name: paramName, + In: strings.ToLower(consts.RequestDataLocationQuery.String()), + Description: paramConfig.Value.Description, + Style: "", + Explode: nil, + AllowEmptyValue: paramConfig.Value.AllowEmptyValue, + AllowReserved: false, + Deprecated: false, + Required: op_array.ArrayType(paramConfig.Value.Required).Has(paramName) >= 0, + Schema: paramConfig, + Example: nil, + Examples: nil, + Content: nil, + }, + }) + } +} + +// setWriteRequestBody 设置写请求请求 Body +func (g *Generate) setWriteRequestBody(apiOperate *openapi3.Operation, schemaDataRef string) { + apiOperate.RequestBody = &openapi3.RequestBodyRef{ + Extensions: nil, + Origin: nil, + Ref: "", + Value: &openapi3.RequestBody{ + Extensions: nil, + Origin: nil, + Description: "接口请求数据", + Required: true, + Content: map[string]*openapi3.MediaType{ + consts.MimeTypeJson: { + Extensions: nil, + Origin: nil, + Schema: &openapi3.SchemaRef{ Extensions: nil, Origin: nil, - Schema: schemaData, - Example: nil, - Examples: nil, - Encoding: nil, + Ref: schemaDataRef, + Value: nil, }, + Example: nil, + Examples: nil, + Encoding: nil, }, }, + }, + } +} + +// AddApiDoc 添加接口文档 +func (g *Generate) AddApiDoc(docFlag string, apiMeta define.UriConfig, request any, response any) error { + var ( + err error + requestType reflect.Type + responseType reflect.Type + ) + + // 初始化请求数据与响应数据类型 + requestType, responseType = g.formatType(request, response) + + schemaData := GenerateOpenAPISchema(requestType) + apiOperate, isRead := g.initApiConfig(docFlag, apiMeta) + if isRead { + if g.enableRedundantStorageComponents { + // 此处是冗余 components 设置, 便于查看结构体, 不冗余文档也可正常解析 + requestTypeStr := requestType.String() + if _, exist := g.docTable[docFlag].Components.Schemas[requestTypeStr]; !exist { + g.docTable[docFlag].Components.Schemas[requestTypeStr] = schemaData + } + // 冗余处理结束 } + g.setReadRequestParameter(apiOperate, schemaData) + } else { + requestTypeStr := requestType.String() + if _, exist := g.docTable[docFlag].Components.Schemas[requestTypeStr]; !exist { + g.docTable[docFlag].Components.Schemas[requestTypeStr] = schemaData + } + g.setWriteRequestBody(apiOperate, g.getComponentsSchemaRef(requestType.String())) } - // 初始化接口配置 - if _, exist := g.docTable[docFlag].Components.Schemas[requestTypeStr]; !exist { - g.docTable[docFlag].Components.Schemas[requestTypeStr] = schemaData - } responseTypeStr := responseType.String() if _, exist := g.docTable[docFlag].Components.Schemas[responseTypeStr]; !exist { g.docTable[docFlag].Components.Schemas[responseTypeStr] = GenerateOpenAPISchema(responseType) @@ -231,18 +279,12 @@ func (g *Generate) AddApiDoc(docFlag string, apiMeta define.UriConfig, request a Origin: nil, Ref: "", Value: &openapi3.Response{ - Extensions: nil, - Origin: nil, Description: &desc, - Headers: nil, Content: map[string]*openapi3.MediaType{ consts.MimeTypeJson: { - Extensions: nil, - Origin: nil, - Schema: g.docTable[docFlag].Components.Schemas[responseTypeStr], - Example: nil, - Examples: nil, - Encoding: nil, + Schema: &openapi3.SchemaRef{ + Ref: g.getComponentsSchemaRef(responseTypeStr), + }, }, }, Links: nil, @@ -275,6 +317,12 @@ func (g *Generate) initApiConfig(docFlag string, apiMeta define.UriConfig) (*ope } newOperate := openapi3.NewOperation() newOperate.Parameters = make(openapi3.Parameters, 0) + // 合入公共的 请求参数 + if nil != g.docTable[docFlag].Components.Parameters { + for _, v := range g.docTable[docFlag].Components.Parameters { + newOperate.Parameters = append(newOperate.Parameters, v) + } + } newOperate.Responses = openapi3.NewResponses() newOperate.Summary = apiMeta.Desc newOperate.Description = apiMeta.Desc diff --git a/openapi/generate_test.go b/openapi/generate_test.go index 985a944..143efa3 100644 --- a/openapi/generate_test.go +++ b/openapi/generate_test.go @@ -11,10 +11,13 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "testing" "time" "git.zhangdeman.cn/zhangdeman/api-doc/define" + "git.zhangdeman.cn/zhangdeman/consts" + "github.com/getkin/kin-openapi/openapi3" ) func TestGenerate_AddApiDoc(t *testing.T) { @@ -34,8 +37,38 @@ func TestGenerate_AddApiDoc(t *testing.T) { UpdatedAt *time.Time `json:"updated_at,omitempty" description:"更新时间"` Category *Category `json:"category,omitempty" description:"分类"` } + DocManager.EnableRedundantStorageComponents() // 启用 components 冗余存储 docFlag := "demo" - DocManager.NewOpenApiDoc(docFlag) + DocManager.NewOpenApiDoc(docFlag, WithSecurity(&openapi3.SecuritySchemes{ + "Token-Auth": { + Extensions: nil, + Origin: nil, + Ref: "", + Value: &openapi3.SecurityScheme{ + Extensions: nil, + Origin: nil, + Type: "apiKey", + Description: "Token 身份认证", + Name: "token", + In: strings.ToLower(consts.RequestDataLocationHeader.String()), + }, + }, + }), WithCommonParameter(&openapi3.ParametersMap{ + "Token": { + Value: &openapi3.Parameter{ + Name: "Token", + In: strings.ToLower(consts.RequestDataLocationHeader.String()), + Description: "用户登录 Token", + }, + }, + "User-Agent": { + Value: &openapi3.Parameter{ + Name: "User-Agent", + In: strings.ToLower(consts.RequestDataLocationHeader.String()), + Description: "用户访问 UA", + }, + }, + })) DocManager.AddApiDoc(docFlag, define.UriConfig{ Path: "/a/b/c", RequestMethod: http.MethodGet, diff --git a/openapi/option.go b/openapi/option.go index 3603a99..76760a1 100644 --- a/openapi/option.go +++ b/openapi/option.go @@ -7,7 +7,11 @@ // Date : 2026-01-06 22:48 package openapi -import "github.com/getkin/kin-openapi/openapi3" +import ( + "sort" + + "github.com/getkin/kin-openapi/openapi3" +) // OptionFunc 设置文档选项 type OptionFunc func(t *openapi3.T) @@ -31,3 +35,53 @@ func WithInfo(info *openapi3.Info) OptionFunc { t.Info = info } } + +// WithSecurity 设置安全策略 +func WithSecurity(securityTable *openapi3.SecuritySchemes) OptionFunc { + return func(t *openapi3.T) { + if nil == securityTable { + return + } + if nil == t.Components { + t.Components = &openapi3.Components{} + } + if nil == t.Components.SecuritySchemes { + t.Components.SecuritySchemes = make(map[string]*openapi3.SecuritySchemeRef) + } + if nil == t.Security { + t.Security = make([]openapi3.SecurityRequirement, 0) + } + keyList := make([]string, 0) + for k, v := range *securityTable { + keyList = append(keyList, k) + t.Components.SecuritySchemes[k] = v + } + // 保证生成结果有序 + sort.Strings(keyList) + for _, k := range keyList { + t.Security = append(t.Security, map[string][]string{k: {}}) + } + } +} + +// WithCommonParameter 设置公共请求蚕食 +func WithCommonParameter(commonParameterTable *openapi3.ParametersMap) OptionFunc { + return func(t *openapi3.T) { + if nil == commonParameterTable { + return + } + if nil == t.Components { + t.Components = &openapi3.Components{} + } + if nil == t.Components.Parameters { + t.Components.Parameters = make(map[string]*openapi3.ParameterRef) + } + // 不要直接复制, 逐个设置, 可以重复调用, 后面的会覆盖前面的 + for k, v := range *commonParameterTable { + if nil == v { + continue + } + t.Components.Parameters[k] = v + } + } +}