diff --git a/proxy/forward.go b/proxy/forward.go index ce4cd51..7389902 100644 --- a/proxy/forward.go +++ b/proxy/forward.go @@ -8,13 +8,11 @@ package proxy import ( - "bytes" "compress/gzip" - "fmt" "io" "io/ioutil" - "net" "net/http" + "net/http/httputil" "strings" ) @@ -24,72 +22,30 @@ import ( // // Date : 2:08 下午 2021/8/6 func Forward(rw http.ResponseWriter, req *http.Request, serverConfig *Server) { - fmt.Printf("Received request %s %s %s\n", req.Method, req.Host, req.RemoteAddr) - transport := http.DefaultTransport - - // step 1 - outReq := new(http.Request) - *outReq = *req // this only does shallow copies of maps - - if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - if prior, ok := outReq.Header["X-Forwarded-For"]; ok { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - outReq.Header.Set("X-Forwarded-For", clientIP) + if !strings.HasPrefix(serverConfig.URI, "/") { + serverConfig.URI = "/" + serverConfig.URI } - - // 写入重写的请求Header - for k, v := range serverConfig.RewriteRequestHeader { - outReq.Header.Set(k, v) - } - - // 重写请求地址 - outReq.Host = serverConfig.Host - outReq.URL.Path = serverConfig.URI - outReq.URL.Scheme = serverConfig.Scheme - outReq.URL.Host = serverConfig.Host - - // step 2 - res, err := transport.RoundTrip(outReq) - if err != nil { - rw.WriteHeader(http.StatusBadGateway) - return - } - - // step 3 - for key, value := range res.Header { - for _, v := range value { - if strings.ToLower(key) == "content-encoding" { - continue - } - rw.Header().Add(key, v) + // 请求重写方法 + director := func(req *http.Request) { + req.URL.Scheme = serverConfig.Scheme + // req.URL.Host = projectDetail.GetProjectDetail().Domain + ":" + fmt.Sprintf("%v", projectDetail.GetProjectDetail().Port) + // req.Host = projectDetail.GetProjectDetail().Domain + ":" + fmt.Sprintf("%v", projectDetail.GetProjectDetail().Port) + req.Host = serverConfig.Host + req.URL.Host = serverConfig.Host + req.URL.Path = serverConfig.URI + req.RequestURI = serverConfig.URI + // 写入重写的请求Header + for k, v := range serverConfig.RewriteRequestHeader { + req.Header.Set(k, v) } } - - rw.WriteHeader(res.StatusCode) - - // 重写请求header - for k, v := range serverConfig.RewriteResponseHeader { - rw.Header().Set(k, v) + // TODO : 重写响应数据 + modifyResponseFunc := func(rep *http.Response) error { + return nil } - - defer res.Body.Close() - - // 重写响应数据 - if !strings.Contains(strings.ToLower(res.Header.Get("Content-Type")), "application/json") || nil == serverConfig.RewriteResponseData || len(serverConfig.RewriteResponseData) == 0 { - _, _ = io.Copy(rw, res.Body) - return - } - var ( - responseData []byte - ) - - responseData, err = getResponseData(res) - fmt.Println(string(responseData), err) - - bytesBuffer := bytes.NewReader([]byte(`{"data":{"permission":true}}`)) - _, _ = io.Copy(rw, bytesBuffer) + p := &httputil.ReverseProxy{Director: director, ModifyResponse: modifyResponseFunc} + p.ServeHTTP(rw, req) } // getResultCompressType 获取返回结果的压缩方式