diff --git a/proxy/config.go b/proxy/config.go index 925b206..a59b8ec 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -13,8 +13,10 @@ package proxy // // Date : 3:34 下午 2021/8/6 type Server struct { - Scheme string // 转发scheme - Host string // 服务器地址 - URI string // 转发接口 - RewriteHeader map[string]string // 重写header + Scheme string // 转发scheme + Host string // 服务器地址 + URI string // 转发接口 + RewriteRequestHeader map[string]string // 重写请求header + RewriteResponseHeader map[string]string // 重写响应header + RewriteResponseData map[string]string // 重写响应数据 } diff --git a/proxy/forward.go b/proxy/forward.go index e0a531f..b7273d1 100644 --- a/proxy/forward.go +++ b/proxy/forward.go @@ -8,11 +8,15 @@ package proxy import ( + "compress/gzip" "fmt" "io" + "io/ioutil" "net" "net/http" "strings" + + "git.zhangdeman.cn/zhangdeman/gopkg/safe" ) // Forward 正向代理的实现 @@ -37,7 +41,7 @@ func Forward(rw http.ResponseWriter, req *http.Request, serverConfig *Server) { } // 写入重写的请求Header - for k, v := range serverConfig.RewriteHeader { + for k, v := range serverConfig.RewriteRequestHeader { outReq.Header.Set(k, v) } @@ -46,6 +50,7 @@ func Forward(rw http.ResponseWriter, req *http.Request, serverConfig *Server) { outReq.URL.Path = serverConfig.URI outReq.URL.Scheme = serverConfig.Scheme outReq.URL.Host = serverConfig.Host + // step 2 res, err := transport.RoundTrip(outReq) if err != nil { @@ -61,6 +66,24 @@ func Forward(rw http.ResponseWriter, req *http.Request, serverConfig *Server) { } rw.WriteHeader(res.StatusCode) - io.Copy(rw, res.Body) + + // 重写请求header + for k, v := range serverConfig.RewriteResponseHeader { + rw.Header().Set(k, v) + } + + // 重写响应数据 + if !strings.Contains(strings.ToLower(res.Header.Get("Content-Type")), "application/json") || nil == serverConfig.RewriteResponseData || len(serverConfig.RewriteResponseData) == 0 { + io.Copy(rw, res.Body) + } else { + var responseData []byte + var gzipData io.Reader + gzipData, err = gzip.NewReader(res.Body) + responseData, err = ioutil.ReadAll(gzipData) + fmt.Println(string(responseData), err) + formatData, _ := safe.Filter(responseData, serverConfig.RewriteResponseData) + rw.Write(formatData) + } + res.Body.Close() } diff --git a/safe/data.go b/safe/data.go index 91d2a16..5ef1750 100644 --- a/safe/data.go +++ b/safe/data.go @@ -18,18 +18,18 @@ import ( // Author : go_developer@163.com<白茶清欢> // // Date : 6:40 下午 2021/3/10 -func Filter(source []byte, filter []string) ([]byte, error) { +func Filter(source []byte, filter map[string]string) ([]byte, error) { var ( bt []byte setErr error ) - for _, item := range filter { + for result, item := range filter { fieldList := strings.Split(item, ".") val, _, _, err := jsonparser.Get(source, fieldList...) if nil != err { return nil, err } - if bt, setErr = jsonparser.Set(bt, val, fieldList...); nil != setErr { + if bt, setErr = jsonparser.Set(bt, val, strings.Split(result, ".")...); nil != setErr { return nil, setErr } }