Commit e2083055 authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

python: make all callbacks BaseSPDYRequestHandler's instance method

ssl socket and ssctrl are also instance variable now.
parent 31204502
...@@ -1103,52 +1103,6 @@ try: ...@@ -1103,52 +1103,6 @@ try:
import time import time
from xml.sax.saxutils import escape from xml.sax.saxutils import escape
def send_cb(session, data):
ssctrl = session.user_data
wlen = ssctrl.sock.send(data)
return wlen
def read_cb(session, stream_id, length, read_ctrl, source):
data = source.read(length)
if not data:
read_ctrl.flags = READ_EOF
return data
def on_ctrl_recv_cb(session, frame):
ssctrl = session.user_data
if frame.frame_type == SYN_STREAM:
stream = Stream(frame.stream_id)
ssctrl.streams[frame.stream_id] = stream
stream.process_headers(frame.nv)
elif frame.frame_type == HEADERS:
if frame.stream_id in ssctrl.streams:
stream = ssctrl.streams[frame.stream_id]
stream.process_headers(frame.nv)
def on_data_chunk_recv_cb(session, flags, stream_id, data):
ssctrl = session.user_data
if stream_id in ssctrl.streams:
stream = ssctrl.streams[stream_id]
if stream.method == 'POST':
if not stream.rfile:
stream.rfile = io.BytesIO()
stream.rfile.write(data)
else:
# We don't allow request body if method is not POST
session.submit_rst_stream(stream_id, PROTOCOL_ERROR)
def on_stream_close_cb(session, stream_id, status_code):
ssctrl = session.user_data
if stream_id in ssctrl.streams:
del ssctrl.streams[stream_id]
def on_request_recv_cb(session, stream_id):
ssctrl = session.user_data
if stream_id in ssctrl.streams:
stream = ssctrl.streams[stream_id]
ssctrl.handler.handle_one_request(stream)
class Stream: class Stream:
def __init__(self, stream_id): def __init__(self, stream_id):
self.stream_id = stream_id self.stream_id = stream_id
...@@ -1180,9 +1134,7 @@ try: ...@@ -1180,9 +1134,7 @@ try:
self.headers.append((k, v)) self.headers.append((k, v))
class SessionCtrl: class SessionCtrl:
def __init__(self, handler, sock): def __init__(self):
self.handler = handler
self.sock = sock
self.streams = {} self.streams = {}
class BaseSPDYRequestHandler(socketserver.BaseRequestHandler): class BaseSPDYRequestHandler(socketserver.BaseRequestHandler):
...@@ -1271,7 +1223,7 @@ try: ...@@ -1271,7 +1223,7 @@ try:
.format(stream.method)) .format(stream.method))
self.wfile.seek(0) self.wfile.seek(0)
data_prd = DataProvider(self.wfile, read_cb) data_prd = DataProvider(self.wfile, self.read_cb)
stream.data_prd = data_prd stream.data_prd = data_prd
self.send_header(':version', 'HTTP/1.1') self.send_header(':version', 'HTTP/1.1')
...@@ -1281,38 +1233,80 @@ try: ...@@ -1281,38 +1233,80 @@ try:
self.session.submit_response(stream.stream_id, self.session.submit_response(stream.stream_id,
self._response_headers, data_prd) self._response_headers, data_prd)
def send_cb(self, session, data):
return self.sslsock.send(data)
def read_cb(self, session, stream_id, length, read_ctrl, source):
data = source.read(length)
if not data:
read_ctrl.flags = READ_EOF
return data
def on_ctrl_recv_cb(self, session, frame):
if frame.frame_type == SYN_STREAM:
stream = Stream(frame.stream_id)
self.ssctrl.streams[frame.stream_id] = stream
stream.process_headers(frame.nv)
elif frame.frame_type == HEADERS:
if frame.stream_id in self.ssctrl.streams:
stream = self.ssctrl.streams[frame.stream_id]
stream.process_headers(frame.nv)
def on_data_chunk_recv_cb(self, session, flags, stream_id, data):
if stream_id in self.ssctrl.streams:
stream = self.ssctrl.streams[stream_id]
if stream.method == 'POST':
if not stream.rfile:
stream.rfile = io.BytesIO()
stream.rfile.write(data)
else:
# We don't allow request body if method is not POST
session.submit_rst_stream(stream_id, PROTOCOL_ERROR)
def on_stream_close_cb(self, session, stream_id, status_code):
if stream_id in self.ssctrl.streams:
del self.ssctrl.streams[stream_id]
def on_request_recv_cb(self, session, stream_id):
if stream_id in self.ssctrl.streams:
stream = self.ssctrl.streams[stream_id]
self.handle_one_request(stream)
def handle(self): def handle(self):
self.request.setsockopt(socket.IPPROTO_TCP, self.request.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY, True) socket.TCP_NODELAY, True)
# TODO We need to call handshake manually because 3.3.0b2 # TODO We need to call handshake manually because 3.3.0b2
# crashes if do_handshake_on_connect=True # crashes if do_handshake_on_connect=True
sock = self.server.ctx.wrap_socket(self.request, server_side=True, self.sslsock = self.server.ctx.wrap_socket(\
do_handshake_on_connect=False) self.request,
sock.setblocking(False) server_side=True,
do_handshake_on_connect=False)
self.sslsock.setblocking(False)
while True: while True:
try: try:
sock.do_handshake() self.sslsock.do_handshake()
break break
except ssl.SSLWantReadError as e: except ssl.SSLWantReadError as e:
select.select([sock], [], []) select.select([self.sslsock], [], [])
except ssl.SSLWantWriteError as e: except ssl.SSLWantWriteError as e:
select.select([], [sock], []) select.select([], [self.sslsock], [])
version = npn_get_version(sock.selected_npn_protocol()) version = npn_get_version(self.sslsock.selected_npn_protocol())
if version == 0: if version == 0:
return return
ssctrl = SessionCtrl(self, sock) self.ssctrl = SessionCtrl()
self.session = Session(\ self.session = Session(\
SERVER, SERVER, version,
version, send_cb=self.send_cb,
send_cb=send_cb, on_ctrl_recv_cb=self.on_ctrl_recv_cb,
on_ctrl_recv_cb=on_ctrl_recv_cb, on_data_chunk_recv_cb=self.on_data_chunk_recv_cb,
on_data_chunk_recv_cb=on_data_chunk_recv_cb, on_stream_close_cb=self.on_stream_close_cb,
on_stream_close_cb=on_stream_close_cb, on_request_recv_cb=self.on_request_recv_cb)
on_request_recv_cb=on_request_recv_cb,
user_data=ssctrl)
self.session.submit_settings(\ self.session.submit_settings(\
FLAG_SETTINGS_NONE, FLAG_SETTINGS_NONE,
...@@ -1322,7 +1316,7 @@ try: ...@@ -1322,7 +1316,7 @@ try:
while self.session.want_read() or self.session.want_write(): while self.session.want_read() or self.session.want_write():
want_read = want_write = False want_read = want_write = False
try: try:
data = sock.recv(4096) data = self.sslsock.recv(4096)
if data: if data:
self.session.recv(data) self.session.recv(data)
else: else:
...@@ -1339,8 +1333,8 @@ try: ...@@ -1339,8 +1333,8 @@ try:
want_write = True want_write = True
if want_read or want_write: if want_read or want_write:
select.select([sock] if want_read else [], select.select([self.sslsock] if want_read else [],
[sock] if want_write else [], [self.sslsock] if want_write else [],
[]) [])
# The following methods and attributes are copied from # The following methods and attributes are copied from
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment