gopkg/proxy/forward.go

127 lines
2.9 KiB
Go
Raw Normal View History

2021-08-06 16:02:59 +08:00
// Package proxy implements a simple forward HTTP proxy.
//
// Description : 正向代理的实现 (implementation of a forward proxy)
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 2021-08-06 2:07 下午
package proxy
import (
2021-08-06 20:49:42 +08:00
"bytes"
2021-08-06 20:13:02 +08:00
"compress/gzip"
2021-08-06 16:02:59 +08:00
"fmt"
"io"
2021-08-06 20:13:02 +08:00
"io/ioutil"
2021-08-06 16:02:59 +08:00
"net"
"net/http"
"strings"
)
// Forward 正向代理的实现
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 2:08 下午 2021/8/6
func Forward(rw http.ResponseWriter, req *http.Request, serverConfig *Server) {
fmt.Printf("Received request %s %s %s\n", req.Method, req.Host, req.RemoteAddr)
transport := http.DefaultTransport
// step 1
outReq := new(http.Request)
*outReq = *req // this only does shallow copies of maps
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if prior, ok := outReq.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
outReq.Header.Set("X-Forwarded-For", clientIP)
}
// 写入重写的请求Header
2021-08-06 20:13:02 +08:00
for k, v := range serverConfig.RewriteRequestHeader {
2021-08-06 16:02:59 +08:00
outReq.Header.Set(k, v)
}
// 重写请求地址
outReq.Host = serverConfig.Host
outReq.URL.Path = serverConfig.URI
outReq.URL.Scheme = serverConfig.Scheme
outReq.URL.Host = serverConfig.Host
2021-08-06 20:13:02 +08:00
2021-08-06 16:02:59 +08:00
// step 2
res, err := transport.RoundTrip(outReq)
if err != nil {
rw.WriteHeader(http.StatusBadGateway)
return
}
// step 3
for key, value := range res.Header {
for _, v := range value {
2021-08-06 20:49:42 +08:00
if strings.ToLower(key) == "content-encoding" {
continue
}
2021-08-06 16:02:59 +08:00
rw.Header().Add(key, v)
}
}
rw.WriteHeader(res.StatusCode)
2021-08-06 20:13:02 +08:00
// 重写请求header
for k, v := range serverConfig.RewriteResponseHeader {
rw.Header().Set(k, v)
}
2021-09-11 22:38:45 +08:00
defer res.Body.Close()
2021-08-06 20:13:02 +08:00
// 重写响应数据
if !strings.Contains(strings.ToLower(res.Header.Get("Content-Type")), "application/json") || nil == serverConfig.RewriteResponseData || len(serverConfig.RewriteResponseData) == 0 {
2021-09-11 22:38:45 +08:00
_, _ = io.Copy(rw, res.Body)
return
}
var (
responseData []byte
)
responseData, err = getResponseData(res)
fmt.Println(string(responseData), err)
2021-08-06 20:13:02 +08:00
2021-09-11 22:38:45 +08:00
bytesBuffer := bytes.NewReader([]byte(`{"data":{"permission":true}}`))
_, _ = io.Copy(rw, bytesBuffer)
}
// getResultCompressType reports how the upstream response body is encoded.
// It returns the first Content-Encoding header value folded to lower case,
// or "" when the header is absent.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 9:11 PM 2021/9/11
func getResultCompressType(res *http.Response) string {
	contentEncoding := res.Header.Get("Content-Encoding")
	return strings.ToLower(contentEncoding)
}
// getResponseData 解析响应数据
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 9:15 下午 2021/9/11
func getResponseData(res *http.Response) ([]byte, error) {
var (
responseData []byte
err error
)
switch getResultCompressType(res) {
case "gzip":
var gzipData io.Reader
if gzipData, err = gzip.NewReader(res.Body); nil == err {
// gzip 处理过的数据
responseData, err = ioutil.ReadAll(gzipData)
}
default:
// 默认没有任何压缩
responseData, err = io.ReadAll(res.Body)
2021-08-06 20:49:42 +08:00
}
2021-09-11 22:38:45 +08:00
return responseData, err
2021-08-06 16:02:59 +08:00
}