|
5 | 5 | package http
|
6 | 6 |
|
7 | 7 | import (
|
| 8 | + "bufio" |
8 | 9 | "bytes"
|
9 | 10 | "errors"
|
10 | 11 | "fmt"
|
11 | 12 | "io"
|
12 | 13 | "io/ioutil"
|
| 14 | + "net" |
13 | 15 | "net/url"
|
14 | 16 | "strings"
|
15 | 17 | "testing"
|
| 18 | + "time" |
16 | 19 | )
|
17 | 20 |
|
18 | 21 | type reqWriteTest struct {
|
@@ -566,6 +569,138 @@ func TestRequestWrite(t *testing.T) {
|
566 | 569 | }
|
567 | 570 | }
|
568 | 571 |
|
| 572 | +func TestRequestWriteTransport(t *testing.T) { |
| 573 | + t.Parallel() |
| 574 | + |
| 575 | + matchSubstr := func(substr string) func(string) error { |
| 576 | + return func(written string) error { |
| 577 | + if !strings.Contains(written, substr) { |
| 578 | + return fmt.Errorf("expected substring %q in request: %s", substr, written) |
| 579 | + } |
| 580 | + return nil |
| 581 | + } |
| 582 | + } |
| 583 | + |
| 584 | + noContentLengthOrTransferEncoding := func(req string) error { |
| 585 | + if strings.Contains(req, "Content-Length: ") { |
| 586 | + return fmt.Errorf("unexpected Content-Length in request: %s", req) |
| 587 | + } |
| 588 | + if strings.Contains(req, "Transfer-Encoding: ") { |
| 589 | + return fmt.Errorf("unexpected Transfer-Encoding in request: %s", req) |
| 590 | + } |
| 591 | + return nil |
| 592 | + } |
| 593 | + |
| 594 | + all := func(checks ...func(string) error) func(string) error { |
| 595 | + return func(req string) error { |
| 596 | + for _, c := range checks { |
| 597 | + if err := c(req); err != nil { |
| 598 | + return err |
| 599 | + } |
| 600 | + } |
| 601 | + return nil |
| 602 | + } |
| 603 | + } |
| 604 | + |
| 605 | + type testCase struct { |
| 606 | + method string |
| 607 | + clen int64 // ContentLength |
| 608 | + body io.ReadCloser |
| 609 | + want func(string) error |
| 610 | + |
| 611 | + // optional: |
| 612 | + init func(*testCase) |
| 613 | + afterReqRead func() |
| 614 | + } |
| 615 | + |
| 616 | + tests := []testCase{ |
| 617 | + { |
| 618 | + method: "GET", |
| 619 | + want: noContentLengthOrTransferEncoding, |
| 620 | + }, |
| 621 | + { |
| 622 | + method: "GET", |
| 623 | + body: ioutil.NopCloser(strings.NewReader("")), |
| 624 | + want: noContentLengthOrTransferEncoding, |
| 625 | + }, |
| 626 | + { |
| 627 | + method: "GET", |
| 628 | + clen: -1, |
| 629 | + body: ioutil.NopCloser(strings.NewReader("")), |
| 630 | + want: noContentLengthOrTransferEncoding, |
| 631 | + }, |
| 632 | + // A GET with a body, with explicit content length: |
| 633 | + { |
| 634 | + method: "GET", |
| 635 | + clen: 7, |
| 636 | + body: ioutil.NopCloser(strings.NewReader("foobody")), |
| 637 | + want: all(matchSubstr("Content-Length: 7"), |
| 638 | + matchSubstr("foobody")), |
| 639 | + }, |
| 640 | + // A GET with a body, sniffing the leading "f" from "foobody". |
| 641 | + { |
| 642 | + method: "GET", |
| 643 | + clen: -1, |
| 644 | + body: ioutil.NopCloser(strings.NewReader("foobody")), |
| 645 | + want: all(matchSubstr("Transfer-Encoding: chunked"), |
| 646 | + matchSubstr("\r\n1\r\nf\r\n"), |
| 647 | + matchSubstr("oobody")), |
| 648 | + }, |
| 649 | + // But a POST request is expected to have a body, so |
| 650 | + // no sniffing happens: |
| 651 | + { |
| 652 | + method: "POST", |
| 653 | + clen: -1, |
| 654 | + body: ioutil.NopCloser(strings.NewReader("foobody")), |
| 655 | + want: all(matchSubstr("Transfer-Encoding: chunked"), |
| 656 | + matchSubstr("foobody")), |
| 657 | + }, |
| 658 | + { |
| 659 | + method: "POST", |
| 660 | + clen: -1, |
| 661 | + body: ioutil.NopCloser(strings.NewReader("")), |
| 662 | + want: all(matchSubstr("Transfer-Encoding: chunked")), |
| 663 | + }, |
| 664 | + // Verify that a blocking Request.Body doesn't block forever. |
| 665 | + { |
| 666 | + method: "GET", |
| 667 | + clen: -1, |
| 668 | + init: func(tt *testCase) { |
| 669 | + pr, pw := io.Pipe() |
| 670 | + tt.afterReqRead = func() { |
| 671 | + pw.Close() |
| 672 | + } |
| 673 | + tt.body = ioutil.NopCloser(pr) |
| 674 | + }, |
| 675 | + want: matchSubstr("Transfer-Encoding: chunked"), |
| 676 | + }, |
| 677 | + } |
| 678 | + |
| 679 | + for i, tt := range tests { |
| 680 | + if tt.init != nil { |
| 681 | + tt.init(&tt) |
| 682 | + } |
| 683 | + req := &Request{ |
| 684 | + Method: tt.method, |
| 685 | + URL: &url.URL{ |
| 686 | + Scheme: "http", |
| 687 | + Host: "example.com", |
| 688 | + }, |
| 689 | + Header: make(Header), |
| 690 | + ContentLength: tt.clen, |
| 691 | + Body: tt.body, |
| 692 | + } |
| 693 | + got, err := dumpRequestOut(req, tt.afterReqRead) |
| 694 | + if err != nil { |
| 695 | + t.Errorf("test[%d]: %v", i, err) |
| 696 | + continue |
| 697 | + } |
| 698 | + if err := tt.want(string(got)); err != nil { |
| 699 | + t.Errorf("test[%d]: %v", i, err) |
| 700 | + } |
| 701 | + } |
| 702 | +} |
| 703 | + |
569 | 704 | type closeChecker struct {
|
570 | 705 | io.Reader
|
571 | 706 | closed bool
|
@@ -672,3 +807,76 @@ func TestRequestWriteError(t *testing.T) {
|
672 | 807 | t.Fatalf("writeCalls constant is outdated in test")
|
673 | 808 | }
|
674 | 809 | }
|
| 810 | + |
| 811 | +// dumpRequestOut is a modified copy of net/http/httputil.DumpRequestOut. |
| 812 | +// Unlike the original, this version doesn't mutate the req.Body and |
| 813 | +// try to restore it. It always dumps the whole body. |
| 814 | +// And it doesn't support https. |
| 815 | +func dumpRequestOut(req *Request, onReadHeaders func()) ([]byte, error) { |
| 816 | + |
| 817 | + // Use the actual Transport code to record what we would send |
| 818 | + // on the wire, but not using TCP. Use a Transport with a |
| 819 | + // custom dialer that returns a fake net.Conn that waits |
| 820 | + // for the full input (and recording it), and then responds |
| 821 | + // with a dummy response. |
| 822 | + var buf bytes.Buffer // records the output |
| 823 | + pr, pw := io.Pipe() |
| 824 | + defer pr.Close() |
| 825 | + defer pw.Close() |
| 826 | + dr := &delegateReader{c: make(chan io.Reader)} |
| 827 | + |
| 828 | + t := &Transport{ |
| 829 | + Dial: func(net, addr string) (net.Conn, error) { |
| 830 | + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil |
| 831 | + }, |
| 832 | + } |
| 833 | + defer t.CloseIdleConnections() |
| 834 | + |
| 835 | + // Wait for the request before replying with a dummy response: |
| 836 | + go func() { |
| 837 | + req, err := ReadRequest(bufio.NewReader(pr)) |
| 838 | + if err == nil { |
| 839 | + if onReadHeaders != nil { |
| 840 | + onReadHeaders() |
| 841 | + } |
| 842 | + // Ensure all the body is read; otherwise |
| 843 | + // we'll get a partial dump. |
| 844 | + io.Copy(ioutil.Discard, req.Body) |
| 845 | + req.Body.Close() |
| 846 | + } |
| 847 | + dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") |
| 848 | + }() |
| 849 | + |
| 850 | + _, err := t.RoundTrip(req) |
| 851 | + if err != nil { |
| 852 | + return nil, err |
| 853 | + } |
| 854 | + return buf.Bytes(), nil |
| 855 | +} |
| 856 | + |
| 857 | +// delegateReader is a reader that delegates to another reader, |
| 858 | +// once it arrives on a channel. |
| 859 | +type delegateReader struct { |
| 860 | + c chan io.Reader |
| 861 | + r io.Reader // nil until received from c |
| 862 | +} |
| 863 | + |
| 864 | +func (r *delegateReader) Read(p []byte) (int, error) { |
| 865 | + if r.r == nil { |
| 866 | + r.r = <-r.c |
| 867 | + } |
| 868 | + return r.r.Read(p) |
| 869 | +} |
| 870 | + |
| 871 | +// dumpConn is a net.Conn that writes to Writer and reads from Reader. |
| 872 | +type dumpConn struct { |
| 873 | + io.Writer |
| 874 | + io.Reader |
| 875 | +} |
| 876 | + |
| 877 | +func (c *dumpConn) Close() error { return nil } |
| 878 | +func (c *dumpConn) LocalAddr() net.Addr { return nil } |
| 879 | +func (c *dumpConn) RemoteAddr() net.Addr { return nil } |
| 880 | +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } |
| 881 | +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } |
| 882 | +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } |
0 commit comments