Commit 16e91746 authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

nghttpx: Return 400 error if multiple CLs are received in SPDY upstream

This change adds SPDY upstream tests.
parent b9a9a23b
......@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/bradfitz/http2"
"github.com/bradfitz/http2/hpack"
"golang.org/x/net/spdy"
"io"
"io/ioutil"
"net/http"
......@@ -385,3 +386,87 @@ func TestH2H2InvalidResponseCL(t *testing.T) {
t.Errorf("status: %v; want %v", got, want)
}
}
func TestS3H1PlainGET(t *testing.T) {
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, noopHandler)
defer st.Close()
res, err := st.spdy(requestParam{
name: "TestS3H1PlainGET",
})
if err != nil {
t.Fatalf("Error st.spdy() = %v", err)
}
want := 200
if got := res.status; got != want {
t.Errorf("status = %v; want %v", got, want)
}
}
func TestS3H1BadRequestCL(t *testing.T) {
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, noopHandler)
defer st.Close()
// we set content-length: 1024, but the actual request body is
// 3 bytes.
res, err := st.spdy(requestParam{
name: "TestS3H1BadRequestCL",
method: "POST",
header: []hpack.HeaderField{
pair("content-length", "1024"),
},
body: []byte("foo"),
})
if err != nil {
t.Fatalf("Error st.spdy() = %v", err)
}
want := spdy.ProtocolError
if got := res.spdyRstErrCode; got != want {
t.Errorf("res.spdyRstErrCode = %v; want %v", got, want)
}
}
func TestS3H1MultipleRequestCL(t *testing.T) {
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("server should not forward bad request")
})
defer st.Close()
res, err := st.spdy(requestParam{
name: "TestS3H1MultipleRequestCL",
header: []hpack.HeaderField{
pair("content-length", "1"),
pair("content-length", "2"),
},
})
if err != nil {
t.Fatalf("Error st.spdy() = %v", err)
}
want := 400
if got := res.status; got != want {
t.Errorf("status: %v; want %v", got, want)
}
}
func TestS3H1InvalidRequestCL(t *testing.T) {
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("server should not forward bad request")
})
defer st.Close()
res, err := st.spdy(requestParam{
name: "TestS3H1InvalidRequestCL",
header: []hpack.HeaderField{
pair("content-length", ""),
},
})
if err != nil {
t.Fatalf("Error st.spdy() = %v", err)
}
want := 400
if got := res.status; got != want {
t.Errorf("status: %v; want %v", got, want)
}
}
......@@ -9,6 +9,7 @@ import (
"github.com/bradfitz/http2"
"github.com/bradfitz/http2/hpack"
"github.com/tatsuhiro-t/go-nghttp2"
"golang.org/x/net/spdy"
"io"
"io/ioutil"
"net"
......@@ -25,6 +26,7 @@ import (
const (
serverBin = buildDir + "/src/nghttpx"
serverPort = 3009
testDir = buildDir + "/integration-tests"
)
func pair(name, value string) hpack.HeaderField {
......@@ -43,24 +45,39 @@ type serverTester struct {
conn net.Conn // connection to frontend server
h2PrefaceSent bool // HTTP/2 preface was sent in conn
nextStreamID uint32 // next stream ID
fr *http2.Framer
headerBlkBuf bytes.Buffer // buffer to store encoded header block
enc *hpack.Encoder
header http.Header // received header fields
dec *hpack.Decoder
authority string // server's host:port
frCh chan http2.Frame
fr *http2.Framer // HTTP/2 framer
spdyFr *spdy.Framer // SPDY/3.1 framer
headerBlkBuf bytes.Buffer // buffer to store encoded header block
enc *hpack.Encoder // HTTP/2 HPACK encoder
header http.Header // received header fields
dec *hpack.Decoder // HTTP/2 HPACK decoder
authority string // server's host:port
frCh chan http2.Frame // used for incoming HTTP/2 frame
spdyFrCh chan spdy.Frame // used for incoming SPDY frame
errCh chan error
}
// newServerTester creates test context for plain TCP frontend
// connection.
func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
return newServerTesterInternal(args, t, handler, false)
}
// newServerTester creates test context for TLS frontend connection.
func newServerTesterTLS(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
return newServerTesterInternal(args, t, handler, true)
}
// newServerTesterInternal creates test context. If frontendTLS is
// true, set up TLS frontend connection.
func newServerTesterInternal(args []string, t *testing.T, handler http.HandlerFunc, frontendTLS bool) *serverTester {
ts := httptest.NewUnstartedServer(handler)
backendTLS := false
for _, k := range args {
if k == "--http2-bridge" {
switch k {
case "--http2-bridge":
backendTLS = true
break
}
}
if backendTLS {
......@@ -75,26 +92,36 @@ func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *ser
} else {
ts.Start()
}
u, err := url.Parse(ts.URL)
scheme := "http"
if frontendTLS {
scheme = "https"
args = append(args, testDir+"/server.key", testDir+"/server.crt")
} else {
args = append(args, "--frontend-no-tls")
}
backendURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("Error parsing URL from httptest.Server: %v", err)
}
// URL.Host looks like "127.0.0.1:8080", but we want
// "127.0.0.1,8080"
b := "-b" + strings.Replace(u.Host, ":", ",", -1)
b := "-b" + strings.Replace(backendURL.Host, ":", ",", -1)
args = append(args, fmt.Sprintf("-f127.0.0.1,%v", serverPort), b,
"--errorlog-file="+buildDir+"/integration-tests/log.txt",
"-LINFO", "--frontend-no-tls")
"--errorlog-file="+testDir+"/log.txt", "-LINFO")
authority := fmt.Sprintf("127.0.0.1:%v", serverPort)
st := &serverTester{
cmd: exec.Command(serverBin, args...),
t: t,
ts: ts,
url: fmt.Sprintf("http://127.0.0.1:%v", serverPort),
url: fmt.Sprintf("%v://%v", scheme, authority),
nextStreamID: 1,
authority: u.Host,
authority: authority,
frCh: make(chan http2.Frame),
spdyFrCh: make(chan spdy.Frame),
errCh: make(chan error),
}
......@@ -104,20 +131,45 @@ func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *ser
retry := 0
for {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%v", serverPort))
var conn net.Conn
var err error
if frontendTLS {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h2-14", "spdy/3.1"},
}
conn, err = tls.Dial("tcp", authority, tlsConfig)
} else {
conn, err = net.Dial("tcp", authority)
}
if err != nil {
retry += 1
if retry >= 100 {
st.Close()
st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid")
}
time.Sleep(150 * time.Millisecond)
continue
}
if frontendTLS {
tlsConn := conn.(*tls.Conn)
cs := tlsConn.ConnectionState()
if !cs.NegotiatedProtocolIsMutual {
st.Close()
st.t.Fatalf("Error negotiated next protocol is not mutual")
}
}
st.conn = conn
break
}
st.fr = http2.NewFramer(st.conn, st.conn)
spdyFr, err := spdy.NewFramer(st.conn, st.conn)
if err != nil {
st.Close()
st.t.Fatalf("Error spdy.NewFramer: %v", err)
}
st.spdyFr = spdyFr
st.enc = hpack.NewEncoder(&st.headerBlkBuf)
st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) {
st.header.Add(f.Name, f.Value)
......@@ -159,6 +211,26 @@ func (st *serverTester) readFrame() (http2.Frame, error) {
}
}
func (st *serverTester) readSpdyFrame() (spdy.Frame, error) {
go func() {
f, err := st.spdyFr.ReadFrame()
if err != nil {
st.errCh <- err
return
}
st.spdyFrCh <- f
}()
select {
case f := <-st.spdyFrCh:
return f, nil
case err := <-st.errCh:
return nil, err
case <-time.After(2 * time.Second):
return nil, errors.New("timeout waiting for frame")
}
}
type requestParam struct {
name string // name for this request to identify the request in log easily
streamID uint32 // stream ID, automatically assigned if 0
......@@ -211,6 +283,118 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
return res, nil
}
func (st *serverTester) spdy(rp requestParam) (*serverResponse, error) {
res := &serverResponse{}
var id spdy.StreamId
if rp.streamID != 0 {
id = spdy.StreamId(rp.streamID)
if id >= spdy.StreamId(st.nextStreamID) && id%2 == 1 {
st.nextStreamID = uint32(id) + 2
}
} else {
id = spdy.StreamId(st.nextStreamID)
st.nextStreamID += 2
}
method := "GET"
if rp.method != "" {
method = rp.method
}
scheme := "http"
if rp.scheme != "" {
scheme = rp.scheme
}
host := st.authority
if rp.authority != "" {
host = rp.authority
}
path := "/"
if rp.path != "" {
path = rp.path
}
header := make(http.Header)
header.Add(":method", method)
header.Add(":scheme", scheme)
header.Add(":host", host)
header.Add(":path", path)
header.Add(":version", "HTTP/1.1")
header.Add("test-case", rp.name)
for _, h := range rp.header {
header.Add(h.Name, h.Value)
}
var synStreamFlags spdy.ControlFlags
if len(rp.body) == 0 {
synStreamFlags = spdy.ControlFlagFin
}
if err := st.spdyFr.WriteFrame(&spdy.SynStreamFrame{
CFHeader: spdy.ControlFrameHeader{
Flags: synStreamFlags,
},
StreamId: id,
Headers: header,
}); err != nil {
return nil, err
}
if len(rp.body) != 0 {
if err := st.spdyFr.WriteFrame(&spdy.DataFrame{
StreamId: id,
Flags: spdy.DataFlagFin,
Data: rp.body,
}); err != nil {
return nil, err
}
}
loop:
for {
fr, err := st.readSpdyFrame()
if err != nil {
return res, err
}
switch f := fr.(type) {
case *spdy.SynReplyFrame:
if f.StreamId != id {
break
}
res.header = cloneHeader(f.Headers)
if _, err := fmt.Sscan(res.header.Get(":status"), &res.status); err != nil {
return res, fmt.Errorf("Error parsing status code: %v", err)
}
if f.CFHeader.Flags&spdy.ControlFlagFin != 0 {
break loop
}
case *spdy.DataFrame:
if f.StreamId != id {
break
}
res.body = append(res.body, f.Data...)
if f.Flags&spdy.DataFlagFin != 0 {
break loop
}
case *spdy.RstStreamFrame:
if f.StreamId != id {
break
}
res.spdyRstErrCode = f.Status
break loop
case *spdy.GoAwayFrame:
if f.Status == spdy.GoAwayOK {
break
}
res.spdyGoAwayErrCode = f.Status
break loop
}
}
return res, nil
}
func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
res := &serverResponse{}
st.headerBlkBuf.Reset()
......@@ -299,11 +483,12 @@ loop:
break
}
res.header = cloneHeader(st.header)
res.status, err = strconv.Atoi(res.header.Get(":status"))
var status int
status, err = strconv.Atoi(res.header.Get(":status"))
if err != nil {
return res, fmt.Errorf("Error parsing status code: %v", err)
}
res.status = status
if f.StreamEnded() {
break loop
}
......@@ -322,7 +507,7 @@ loop:
res.errCode = f.ErrCode
break loop
case *http2.GoAwayFrame:
if f.FrameHeader.StreamID != id || f.ErrCode == http2.ErrCodeNo {
if f.ErrCode == http2.ErrCodeNo {
break
}
res.errCode = f.ErrCode
......@@ -335,17 +520,20 @@ loop:
if err := st.fr.WriteSettingsAck(); err != nil {
return res, err
}
// TODO handle PUSH_PROMISE as well, since it alters HPACK context
}
}
return res, nil
}
type serverResponse struct {
status int // HTTP status code
header http.Header // response header fields
body []byte // response body
errCode http2.ErrCode // error code received in RST_STREAM or GOAWAY
connErr bool // true if connection error
status int // HTTP status code
header http.Header // response header fields
body []byte // response body
errCode http2.ErrCode // error code received in HTTP/2 RST_STREAM or GOAWAY
connErr bool // true if HTTP/2 connection error
spdyGoAwayErrCode spdy.GoAwayStatus // status code received in SPDY RST_STREAM
spdyRstErrCode spdy.RstStreamStatus // status code received in SPDY GOAWAY
}
func cloneHeader(h http.Header) http.Header {
......
......@@ -156,11 +156,23 @@ void on_ctrl_recv_callback(spdylay_session *session, spdylay_frame_type type,
auto nv = frame->syn_stream.nv;
if (LOG_ENABLED(INFO)) {
std::stringstream ss;
for (size_t i = 0; nv[i]; i += 2) {
ss << TTY_HTTP_HD << nv[i] << TTY_RST << ": " << nv[i + 1] << "\n";
}
ULOG(INFO, upstream) << "HTTP request headers. stream_id="
<< downstream->get_stream_id() << "\n" << ss.str();
}
for (size_t i = 0; nv[i]; i += 2) {
downstream->add_request_header(nv[i], nv[i + 1]);
}
downstream->index_request_headers();
if (downstream->index_request_headers() != 0) {
upstream->error_reply(downstream, 400);
return;
}
auto path = downstream->get_request_header(http2::HD__PATH);
auto scheme = downstream->get_request_header(http2::HD__SCHEME);
......@@ -193,15 +205,6 @@ void on_ctrl_recv_callback(spdylay_session *session, spdylay_frame_type type,
downstream->inspect_http2_request();
if (LOG_ENABLED(INFO)) {
std::stringstream ss;
for (size_t i = 0; nv[i]; i += 2) {
ss << TTY_HTTP_HD << nv[i] << TTY_RST << ": " << nv[i + 1] << "\n";
}
ULOG(INFO, upstream) << "HTTP request headers. stream_id="
<< downstream->get_stream_id() << "\n" << ss.str();
}
downstream->set_request_state(Downstream::HEADER_COMPLETE);
if (frame->syn_stream.hd.flags & SPDYLAY_CTRL_FLAG_FIN) {
if (!downstream->validate_request_bodylen()) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment