diff --git a/router/common_param.go b/router/common_param.go index 29f03f9..87d965d 100644 --- a/router/common_param.go +++ b/router/common_param.go @@ -36,24 +36,42 @@ func (s *server) AddCommonParamRules(rules map[string]GetCommonParam) { func (s *server) injectCommonParam(ctx *gin.Context, formValue any) error { innerCtx := util.GinCtxToContext(ctx) var ( - val any - err error + val any + err error + reflectFormValue reflect.Value + reflectType reflect.Type + ok bool ) - reflectType := reflect.TypeOf(formValue) + if reflectFormValue, ok = formValue.(reflect.Value); !ok { + reflectFormValue = reflect.ValueOf(formValue) + reflectType = reflect.TypeOf(formValue) + } else { + reflectType = reflectFormValue.Type() + } + fieldTable := map[string]bool{} fieldNum := reflectType.Elem().NumField() for i := 0; i < fieldNum; i++ { - // 提取全部结构体字段 - fieldTable[reflectType.Elem().Field(i).Name] = true + if reflectType.Elem().Field(i).Anonymous && ((reflectType.Elem().Field(i).Type.Kind() == reflect.Ptr && reflectType.Elem().Field(i).Type.Kind() == reflect.Struct) || reflectType.Elem().Field(i).Type.Kind() == reflect.Struct) { + anonymousFieldType := reflectType.Elem().Field(i).Type + if anonymousFieldType.Kind() == reflect.Ptr { + anonymousFieldType = anonymousFieldType.Elem() + } + for j := 0; j < anonymousFieldType.NumField(); j++ { + fieldTable[anonymousFieldType.Field(j).Name] = true + } + } else { + // 提取全部结构体字段 + fieldTable[reflectType.Elem().Field(i).Name] = true + } } - reflectValue := reflect.ValueOf(formValue) for fieldName, getParamFunc := range s.commonParam { - if _, ok := fieldTable[fieldName]; !ok { + if _, ok = fieldTable[fieldName]; !ok { // 结构体字段未配置自动注入 logger.Instance.Debug("当前结构体不包含指定字段, 忽略执行", pkgLogger.NewLogData(innerCtx, logger.RecordType, logger.CodeInjectCommonParam, map[string]any{ "field_name": fieldName, - "struct": reflectValue.Elem().Type().String(), + "struct": reflectFormValue.Elem().Type().String(), }).ToFieldList()...) continue } @@ -64,7 +82,7 @@ func (s *server) injectCommonParam(ctx *gin.Context, formValue any) error { }).ToFieldList()...) return err } - fieldValue := reflectValue.Elem().FieldByName(fieldName) + fieldValue := reflectFormValue.Elem().FieldByName(fieldName) if !fieldValue.CanSet() { logDataList := pkgLogger.NewLogData(util.GinCtxToContext(ctx), logger.RecordType, logger.CodeInjectCommonParam, map[string]any{ "field_name": fieldName, diff --git a/router/handler.go b/router/handler.go index 7ab1ff0..084a817 100644 --- a/router/handler.go +++ b/router/handler.go @@ -46,10 +46,11 @@ func (s *server) getFormInitValue(ctx *gin.Context, uriCfg UriConfig) (any, erro func (s *server) RequestHandler(uriCfg UriConfig) gin.HandlerFunc { return func(ctx *gin.Context) { var ( - err error - ok bool - e exception.IException - formValue any + err error + ok bool + e exception.IException + formValue any + firstParam reflect.Value ) if formValue, err = s.getFormInitValue(ctx, uriCfg); nil != err { @@ -60,6 +61,17 @@ func (s *server) RequestHandler(uriCfg UriConfig) gin.HandlerFunc { ctx.Abort() return } + // 表单数据 + inputValue := reflect.ValueOf(formValue) + // 注入公共参数 + if err = s.injectCommonParam(ctx, inputValue); nil != err { + e = exception.NewFromError(500, err) + response.SendWithException(ctx, e, &define.ResponseOption{ + ContentType: consts.MimeTypeJson, + }) + ctx.Abort() + return + } isSuccess := false // 初始化响应之后logic @@ -94,11 +106,9 @@ func (s *server) RequestHandler(uriCfg UriConfig) gin.HandlerFunc { }() }() // 执行逻辑 - inputValue := reflect.ValueOf(formValue) if uriCfg.FormDataType.Kind() != reflect.Ptr { inputValue = inputValue.Elem() } - var firstParam reflect.Value if uriCfg.CtxType == CustomContextType { customCtx := ctx.MustGet(define.CustomContextKey) firstParam = reflect.ValueOf(customCtx) @@ -109,7 +119,7 @@ func (s *server) RequestHandler(uriCfg UriConfig) gin.HandlerFunc { if resList[1].IsNil() { // 请求成功 isSuccess = true - response.SuccessWithExtension(ctx, resList[0].Interface(), &define.ResponseOption{ContentType: "application/json;charset=utf-8"}) + response.SuccessWithExtension(ctx, resList[0].Interface(), &define.ResponseOption{ContentType: consts.MimeTypeJson}) return } // 请求失败 @@ -123,7 +133,7 @@ func (s *server) RequestHandler(uriCfg UriConfig) gin.HandlerFunc { }) } response.SendWithException(ctx, e, &define.ResponseOption{ - ContentType: "application/json;charset=utf-8", + ContentType: consts.MimeTypeJson, }) return } diff --git a/router/server.go b/router/server.go index 2fdf84c..0e58838 100644 --- a/router/server.go +++ b/router/server.go @@ -126,11 +126,12 @@ func NewServer(port int, optionList ...SetServerOptionFunc) *server { pprof.Register(r) } return &server{ - router: r, - uiInstance: apiDoc.NewSwaggerUI(option.serverInfo, option.serverList, apiDocEnum.SwaggerUITheme(option.swaggerUiTheme)), - port: port, - option: option, - lock: &sync.RWMutex{}, + router: r, + uiInstance: apiDoc.NewSwaggerUI(option.serverInfo, option.serverList, apiDocEnum.SwaggerUITheme(option.swaggerUiTheme)), + port: port, + option: option, + lock: &sync.RWMutex{}, + commonParam: map[string]GetCommonParam{}, } } diff --git a/router/server_test.go b/router/server_test.go index c1233d7..7151270 100644 --- a/router/server_test.go +++ b/router/server_test.go @@ -13,8 +13,28 @@ import ( "github.com/gin-gonic/gin" ) +type testCommon struct { + UserID uint `json:"user_id"` +} + +type testForm struct { + Meta `json:"-" method:"get" path:"test"` + testCommon + Name string `json:"name"` +} + func TestNewServer(t *testing.T) { s := NewServer(9087) - s.Router().GET("/ping", func(c *gin.Context) {}) + s.AddCommonParamRule("UserID", func(ctx *gin.Context) (any, error) { + return uint(123456), nil + }) + s.Group("", nil, testController{}) s.Start() } + +type testController struct { +} + +func (tc testController) Test(ctx *gin.Context, requestData *testForm) (*testCommon, error) { + return &testCommon{UserID: requestData.UserID}, nil +}