Commit f6d2776a by Sheng

Let tornado parse xheaders

parent 88405edd
...@@ -9,25 +9,43 @@ from webssh.handler import MixinHandler, IndexHandler, InvalidValueError ...@@ -9,25 +9,43 @@ from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
class TestMixinHandler(unittest.TestCase): class TestMixinHandler(unittest.TestCase):
def test_get_real_client_addr(self): def test_get_real_client_addr(self):
x_forwarded_for = '1.1.1.1'
x_forwarded_port = 1111
x_real_ip = '2.2.2.2'
x_real_port = 2222
fake_port = 65535
handler = MixinHandler() handler = MixinHandler()
handler.request = HTTPServerRequest(uri='/') handler.request = HTTPServerRequest(uri='/')
handler.request.remote_ip = x_forwarded_for
self.assertIsNone(handler.get_real_client_addr()) self.assertIsNone(handler.get_real_client_addr())
ip = '127.0.0.1' handler.request.headers.add('X-Forwarded-For', x_forwarded_for)
handler.request.headers.add('X-Real-Ip', ip) self.assertEqual(handler.get_real_client_addr(),
self.assertEqual(handler.get_real_client_addr(), False) (x_forwarded_for, fake_port))
handler.request.headers.add('X-Forwarded-Port', fake_port + 1)
self.assertEqual(handler.get_real_client_addr(),
(x_forwarded_for, fake_port))
handler.request.headers['X-Forwarded-Port'] = x_forwarded_port
self.assertEqual(handler.get_real_client_addr(),
(x_forwarded_for, x_forwarded_port))
handler.request.headers.add('X-Real-Port', '12345x') handler.request.remote_ip = x_real_ip
self.assertEqual(handler.get_real_client_addr(), False)
handler.request.headers.update({'X-Real-Port': '12345'}) handler.request.headers.add('X-Real-Ip', x_real_ip)
self.assertEqual(handler.get_real_client_addr(), (ip, 12345)) self.assertEqual(handler.get_real_client_addr(),
(x_real_ip, fake_port))
handler.request.headers.update({'X-Real-ip': None}) handler.request.headers.add('X-Real-Port', fake_port + 1)
self.assertEqual(handler.get_real_client_addr(), False) self.assertEqual(handler.get_real_client_addr(),
(x_real_ip, fake_port))
handler.request.headers.update({'X-Real-Port': '12345x'}) handler.request.headers['X-Real-Port'] = x_real_port
self.assertEqual(handler.get_real_client_addr(), False) self.assertEqual(handler.get_real_client_addr(),
(x_real_ip, x_real_port))
class TestIndexHandler(unittest.TestCase): class TestIndexHandler(unittest.TestCase):
......
...@@ -45,19 +45,22 @@ class MixinHandler(object): ...@@ -45,19 +45,22 @@ class MixinHandler(object):
return value return value
def get_real_client_addr(self): def get_real_client_addr(self):
ip = self.request.headers.get('X-Real-Ip') ip = self.request.remote_ip
port = self.request.headers.get('X-Real-Port')
if ip is None and port is None: if ip == self.request.headers.get('X-Real-Ip'):
return # suppose this app doesn't run after an nginx server port = self.request.headers.get('X-Real-Port')
elif ip in self.request.headers.get('X-Forwarded-For', ''):
port = self.request.headers.get('X-Forwarded-Port')
else:
# not running behind an nginx server
return
if is_valid_ipv4_address(ip) or is_valid_ipv6_address(ip):
port = to_int(port) port = to_int(port)
if port and is_valid_port(port): if port is None or not is_valid_port(port):
return (ip, port) # fake port
port = 65535
logging.warning('Bad nginx configuration.') return (ip, port)
return False
class IndexHandler(MixinHandler, tornado.web.RequestHandler): class IndexHandler(MixinHandler, tornado.web.RequestHandler):
...@@ -94,13 +97,15 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -94,13 +97,15 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def get_privatekey(self): def get_privatekey(self):
name = 'privatekey' name = 'privatekey'
lst = self.request.files.get(name) # multipart form lst = self.request.files.get(name)
if lst: if lst:
# multipart form
self.privatekey_filename = lst[0]['filename'] self.privatekey_filename = lst[0]['filename']
data = lst[0]['body'] data = lst[0]['body']
value = self.decode_argument(data, name=name).strip() value = self.decode_argument(data, name=name).strip()
else: else:
value = self.get_argument(name, u'') # urlencoded form # urlencoded form
value = self.get_argument(name, u'')
if len(value) > KEY_MAX_SIZE: if len(value) > KEY_MAX_SIZE:
raise InvalidValueError( raise InvalidValueError(
......
...@@ -28,7 +28,8 @@ def main(): ...@@ -28,7 +28,8 @@ def main():
options.parse_command_line() options.parse_command_line()
loop = tornado.ioloop.IOLoop.current() loop = tornado.ioloop.IOLoop.current()
app = make_app(make_handlers(loop, options), get_app_settings(options)) app = make_app(make_handlers(loop, options), get_app_settings(options))
app.listen(options.port, options.address, max_body_size=max_body_size) server_settings = dict(xheaders=True, max_body_size=max_body_size)
app.listen(options.port, options.address, **server_settings)
logging.info('Listening on {}:{}'.format(options.address, options.port)) logging.info('Listening on {}:{}'.format(options.address, options.port))
loop.start() loop.start()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment