Commit 5c31c130 authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

integration: Add test case for trailer part

parent b9d6fff9
...@@ -211,6 +211,41 @@ func TestH1H1HTTP10NoHostRewrite(t *testing.T) { ...@@ -211,6 +211,41 @@ func TestH1H1HTTP10NoHostRewrite(t *testing.T) {
} }
} }
// TestH1H1RequestTrailer tests request trailer part is forwarded to
// backend.
func TestH1H1RequestTrailer(t *testing.T) {
st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) {
buf := make([]byte, 4096)
for {
_, err := r.Body.Read(buf)
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("r.Body.Read() = %v", err)
}
}
if got, want := r.Trailer.Get("foo"), "bar"; got != want {
t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want)
}
})
defer st.Close()
res, err := st.http1(requestParam{
name: "TestH1H1RequestTrailer",
body: []byte("1"),
trailer: []hpack.HeaderField{
pair("foo", "bar"),
},
})
if err != nil {
t.Fatalf("Error st.http1() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH1H2ConnectFailure tests that server handles the situation that // TestH1H2ConnectFailure tests that server handles the situation that
// connection attempt to HTTP/2 backend failed. // connection attempt to HTTP/2 backend failed.
func TestH1H2ConnectFailure(t *testing.T) { func TestH1H2ConnectFailure(t *testing.T) {
......
...@@ -520,6 +520,41 @@ func TestH2H1ServerPush(t *testing.T) { ...@@ -520,6 +520,41 @@ func TestH2H1ServerPush(t *testing.T) {
} }
} }
// TestH2H1RequestTrailer tests request trailer part is forwarded to
// backend.
func TestH2H1RequestTrailer(t *testing.T) {
st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) {
buf := make([]byte, 4096)
for {
_, err := r.Body.Read(buf)
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("r.Body.Read() = %v", err)
}
}
if got, want := r.Trailer.Get("foo"), "bar"; got != want {
t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want)
}
})
defer st.Close()
res, err := st.http2(requestParam{
name: "TestH2H1RequestTrailer",
body: []byte("1"),
trailer: []hpack.HeaderField{
pair("foo", "bar"),
},
})
if err != nil {
t.Fatalf("Error st.http2() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH2H1GracefulShutdown tests graceful shutdown. // TestH2H1GracefulShutdown tests graceful shutdown.
func TestH2H1GracefulShutdown(t *testing.T) { func TestH2H1GracefulShutdown(t *testing.T) {
st := newServerTester(nil, t, noopHandler) st := newServerTester(nil, t, noopHandler)
......
...@@ -255,6 +255,27 @@ type requestParam struct { ...@@ -255,6 +255,27 @@ type requestParam struct {
path string // path, defaults to / path string // path, defaults to /
header []hpack.HeaderField // additional request header fields header []hpack.HeaderField // additional request header fields
body []byte // request body body []byte // request body
trailer []hpack.HeaderField // trailer part
}
// wrapper for request body to set trailer part
type chunkedBodyReader struct {
trailer []hpack.HeaderField
trailerWritten bool
body io.Reader
req *http.Request
}
func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) {
// document says that we have to set http.Request.Trailer
// after request was sent and before body returns EOF.
if !cbr.trailerWritten {
cbr.trailerWritten = true
for _, h := range cbr.trailer {
cbr.req.Trailer.Set(h.Name, h.Value)
}
}
return cbr.body.Read(p)
} }
func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
...@@ -264,8 +285,16 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { ...@@ -264,8 +285,16 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
} }
var body io.Reader var body io.Reader
var cbr *chunkedBodyReader
if rp.body != nil { if rp.body != nil {
body = bytes.NewBuffer(rp.body) body = bytes.NewBuffer(rp.body)
if len(rp.trailer) != 0 {
cbr = &chunkedBodyReader{
trailer: rp.trailer,
body: body,
}
body = cbr
}
} }
req, err := http.NewRequest(method, st.url, body) req, err := http.NewRequest(method, st.url, body)
if err != nil { if err != nil {
...@@ -275,7 +304,15 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { ...@@ -275,7 +304,15 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
req.Header.Add(h.Name, h.Value) req.Header.Add(h.Name, h.Value)
} }
req.Header.Add("Test-Case", rp.name) req.Header.Add("Test-Case", rp.name)
if cbr != nil {
cbr.req = req
// this makes request use chunked encoding
req.ContentLength = -1
req.Trailer = make(http.Header)
for _, h := range cbr.trailer {
req.Trailer.Set(h.Name, "")
}
}
if err := req.Write(st.conn); err != nil { if err := req.Write(st.conn); err != nil {
return nil, err return nil, err
} }
...@@ -473,7 +510,7 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { ...@@ -473,7 +510,7 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
err := st.fr.WriteHeaders(http2.HeadersFrameParam{ err := st.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: id, StreamID: id,
EndStream: len(rp.body) == 0, EndStream: len(rp.body) == 0 && len(rp.trailer) == 0,
EndHeaders: true, EndHeaders: true,
BlockFragment: st.headerBlkBuf.Bytes(), BlockFragment: st.headerBlkBuf.Bytes(),
}) })
...@@ -483,7 +520,23 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { ...@@ -483,7 +520,23 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
if len(rp.body) != 0 { if len(rp.body) != 0 {
// TODO we assume rp.body fits in 1 frame // TODO we assume rp.body fits in 1 frame
if err := st.fr.WriteData(id, true, rp.body); err != nil { if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil {
return nil, err
}
}
if len(rp.trailer) != 0 {
st.headerBlkBuf.Reset()
for _, h := range rp.trailer {
_ = st.enc.WriteField(h)
}
err := st.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: id,
EndStream: true,
EndHeaders: true,
BlockFragment: st.headerBlkBuf.Bytes(),
})
if err != nil {
return nil, err return nil, err
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment