深度剖析:Go语言中的http.RoundTripper

发表时间: 2022-12-12 23:51

1、RoundTripper 接口

// Excerpt from the Go standard library, package net/http.
//
// RoundTripper is an interface representing the ability to execute a
// single HTTP transaction, obtaining the Response for a given Request.
//
// A RoundTripper must be safe for concurrent use by multiple
// goroutines.
type RoundTripper interface {
	// RoundTrip executes a single HTTP transaction, returning
	// a Response for the provided Request.
	//
	// RoundTrip should not attempt to interpret the response. In
	// particular, RoundTrip must return err == nil if it obtained
	// a response, regardless of the response's HTTP status code.
	// A non-nil err should be reserved for failure to obtain a
	// response. Similarly, RoundTrip should not attempt to
	// handle higher-level protocol details such as redirects,
	// authentication, or cookies.
	//
	// RoundTrip should not modify the request, except for
	// consuming and closing the Request's Body. RoundTrip may
	// read fields of the request in a separate goroutine. Callers
	// should not mutate or reuse the request until the Response's
	// Body has been closed.
	//
	// RoundTrip must always close the body, including on errors,
	// but depending on the implementation may do so in a separate
	// goroutine even after RoundTrip returns. This means that
	// callers wanting to reuse the body for subsequent requests
	// must arrange to wait for the Close call before doing so.
	//
	// The Request's URL and Header fields must be initialized.
	RoundTrip(*Request) (*Response, error)
}

一句话概括:可以把 RoundTripper 看成 http.Client 的中间件——它位于 Client 与网络之间,每个请求都要经过它,因此可以在这里拦截、修改请求,甚至不访问网络直接返回响应。

2、场景

(1)缓存 HTTP 响应:如果缓存中已有该响应,直接从缓存返回;否则向服务器发起真实请求,并把结果写入缓存。

(2)根据需要统一地设置 HTTP 请求头(例如认证信息、User-Agent)。

(3)限流(Rate limiting):控制客户端向服务器发送请求的速率。

3、实例

使用 http.RoundTripper 的实现来缓存 HTTP 响应

(1)实现一个 HTTP Server

// Package main runs the origin HTTP server for the caching-RoundTripper
// demo. Every request that actually reaches it is printed to stdout, which
// makes it easy to tell which client requests were answered from the
// client-side cache instead.
package main

import (
	"fmt"
	"log"
	"net/http"
	"time"
)

// handleRoot answers every request with a fixed body. It prints a line on
// each call so that responses served from the client's cache (which never
// reach this handler) can be distinguished from real round trips.
func handleRoot(writer http.ResponseWriter, request *http.Request) {
	fmt.Println("The request actually got here")
	// A write error only means the client went away; nothing useful to do.
	_, _ = writer.Write([]byte("You got here"))
}

func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/", handleRoot)

	// Explicit timeouts keep slow or stalled clients from holding
	// connections open indefinitely.
	srv := &http.Server{
		Addr:         ":8080",
		Handler:      mux,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	// ListenAndServe only returns on failure (for example, the port is
	// already in use); report it instead of exiting silently.
	log.Fatal(srv.ListenAndServe())
}

(2)新建一个 CacheTransport 类型,实现 http.RoundTripper 接口并为其加入缓存功能;缓存未命中时,由其内部持有的 http.Transport 发出真实请求

// Package cacheTransport provides an http.RoundTripper that keeps an
// in-memory copy of responses, keyed by request URL, and replays them
// instead of contacting the server again.
package cacheTransport

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/http/httputil"
	"sync"
	"time"
)

// cachKey returns the key under which the response to r is cached.
// Only the URL is used, which is why RoundTrip caches GET requests only:
// responses to other methods on the same URL would otherwise collide.
func cachKey(r *http.Request) string {
	return r.URL.String()
}

// CacheTransport is an http.RoundTripper that answers repeated GET
// requests from an in-memory cache and forwards everything else to an
// underlying http.Transport. Create it with NewCacheTransport.
type CacheTransport struct {
	data              map[string]string // cachKey -> dumped HTTP response (status line, headers, body)
	mu                sync.RWMutex      // guards data
	originalTransport http.RoundTripper // performs the real round trip on a cache miss
}

// NewCacheTransport returns a CacheTransport with an empty cache. Cache
// misses are forwarded to an http.Transport configured with short dial and
// response-header timeouts.
func NewCacheTransport() *CacheTransport {
	return &CacheTransport{
		data: make(map[string]string, 20),
		originalTransport: &http.Transport{
			// net.Dialer.DualStack is deprecated: Fast Fallback is on by default.
			DialContext: (&net.Dialer{
				Timeout:   2 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			ResponseHeaderTimeout: 5 * time.Second,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
		},
	}
}

// RoundTrip implements http.RoundTripper.
//
// A GET request whose URL is already cached is answered from the cache
// without touching the network. Any other request goes to the underlying
// transport. If that response is a 200 OK to a GET request, it is stored
// so the next identical request can be served from memory. Errors from the
// underlying transport, and from dumping the response, are returned as-is.
func (c *CacheTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	cacheable := r.Method == http.MethodGet

	if cacheable {
		if val, err := c.Get(r); err == nil {
			fmt.Println("Fetching the response from the cache")
			// The RoundTripper contract requires the request body to be
			// closed even though no network round trip happens.
			if r.Body != nil {
				_ = r.Body.Close()
			}
			return cachedResponse([]byte(val), r)
		}
	}

	// Not cached (or the cache was cleared): ask the real server.
	resp, err := c.originalTransport.RoundTrip(r)
	if err != nil {
		return nil, err
	}

	// Caching error responses would keep replaying a transient failure.
	if cacheable && resp.StatusCode == http.StatusOK {
		// DumpResponse consumes resp.Body and replaces it with an in-memory
		// copy, so resp can still be handed back to the caller.
		buf, err := httputil.DumpResponse(resp, true)
		if err != nil {
			resp.Body.Close()
			return nil, err
		}
		c.Set(r, string(buf))
	}
	fmt.Println("Fetching the data form the real source")
	return resp, nil
}

// cachedResponse rebuilds an *http.Response for r from bytes produced by
// httputil.DumpResponse.
func cachedResponse(b []byte, r *http.Request) (*http.Response, error) {
	return http.ReadResponse(bufio.NewReader(bytes.NewReader(b)), r)
}

// Set stores value as the cached response for r's URL.
func (c *CacheTransport) Set(r *http.Request, value string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.data[cachKey(r)] = value
}

// Get returns the cached response for r's URL, or an error if none is stored.
func (c *CacheTransport) Get(r *http.Request) (string, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if val, ok := c.data[cachKey(r)]; ok {
		return val, nil
	}
	return "", errors.New("key not found in cache")
}

// Clear drops every cached response. It always returns nil.
func (c *CacheTransport) Clear() error {
	// Hold the lock for the whole swap; releasing it before the write
	// would race with concurrent Get/Set calls.
	c.mu.Lock()
	defer c.mu.Unlock()
	c.data = make(map[string]string)
	return nil
}

(3)实现客户端的 main 函数,并使用一个 ticker(time.Ticker)定时清理缓存

// Package main is the client for the caching-RoundTripper demo. It requests
// http://localhost:8080 once per second through a CacheTransport and clears
// the cache every five seconds, so the output shows which responses came
// from the cache and which came from the real server.
package main

import (
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"awesomeProject/interview/roundtripper/cacheTransport"
)

func main() {
	// Named transport (not cacheTransport) so the package name is not shadowed.
	transport := cacheTransport.NewCacheTransport()
	// Use a dedicated client: http.Get would go through http.DefaultClient,
	// which knows nothing about our caching RoundTripper.
	client := &http.Client{
		Transport: transport,
		Timeout:   time.Second * 5,
	}

	// Periodically clearing the cache simulates entries expiring, forcing
	// some requests back to the original server.
	cacheClearTicker := time.NewTicker(time.Second * 5)
	defer cacheClearTicker.Stop()
	// One request per second makes it easy to see whether each response
	// came from the real server or from the cache.
	reqTicker := time.NewTicker(time.Second * 1)
	defer reqTicker.Stop()

	terminateChannel := make(chan os.Signal, 1)
	// os.Interrupt covers Ctrl+C, which would otherwise kill the process
	// without running the deferred cleanup.
	signal.Notify(terminateChannel, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)

	// The request has no body, so the same *http.Request can be sent on
	// every tick.
	req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
	if err != nil {
		log.Fatalf("building request: %v", err)
	}

	for {
		select {
		case <-cacheClearTicker.C:
			// Clear the cache so the next request hits the original server.
			if err := transport.Clear(); err != nil {
				log.Printf("clearing cache: %v", err)
			}
		case <-terminateChannel:
			return
		case <-reqTicker.C:
			resp, err := client.Do(req)
			if err != nil {
				log.Printf("An error occurred.... %v", err)
				continue
			}
			buf, err := ioutil.ReadAll(resp.Body)
			// Close right away (not via defer, which would only run when
			// main returns) so the connection can be reused.
			resp.Body.Close()
			if err != nil {
				log.Printf("An error occurred...%v", err)
				continue
			}
			fmt.Printf("The body of the response if \"%s\" \n\n", string(buf))
		}
	}
}

(4)服务端执行结果如下

客户端执行结果如下: