Commit 77b6fbfd by Sheng

Block requests not come from trusted_downstream and public non-https requests

parent db3ee2b7
...@@ -15,6 +15,7 @@ matrix: ...@@ -15,6 +15,7 @@ matrix:
install: install:
- pip install -r requirements.txt - pip install -r requirements.txt
- pip install pytest pytest-cov codecov flake8 - pip install pytest pytest-cov codecov flake8
- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then pip install mock; fi
script: script:
- pytest --cov=webssh - pytest --cov=webssh
......
import unittest import unittest
import paramiko import paramiko
from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
from tornado.web import HTTPError
from tests.utils import read_file, make_tests_data_path from tests.utils import read_file, make_tests_data_path
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
try:
from unittest.mock import Mock
except ImportError:
from mock import Mock
class TestMixinHandler(unittest.TestCase): class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self):
handler = MixinHandler()
request = HTTPRequest('http://example.com/')
handler.request = request
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden())
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden())
context = Mock(
address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden())
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='https'
)
request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden())
def test_get_real_client_addr(self): def test_get_real_client_addr(self):
x_forwarded_for = '1.1.1.1' x_forwarded_for = '1.1.1.1'
x_forwarded_port = 1111 x_forwarded_port = 1111
......
...@@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop ...@@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
from webssh.settings import swallow_http_errors from webssh.settings import swallow_http_errors
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, is_valid_ip_address, is_valid_port, is_valid_hostname,
to_bytes, to_str, to_int, UnicodeType to_bytes, to_str, to_int, to_ip_address, UnicodeType
) )
from webssh.worker import Worker, recycle_worker, workers from webssh.worker import Worker, recycle_worker, workers
...@@ -39,6 +39,28 @@ class InvalidValueError(Exception): ...@@ -39,6 +39,28 @@ class InvalidValueError(Exception):
class MixinHandler(object): class MixinHandler(object):
def prepare(self):
if self.is_forbidden():
raise tornado.web.HTTPError(403)
def is_forbidden(self):
"""
Following requests are forbidden:
* requests not come from trusted_downstream (if set).
* non-https requests from a public network.
"""
context = self.request.connection.context
ip = context.address[0]
lst = context.trusted_downstream
if lst and ip not in lst:
return True
if context._orig_protocol == 'http':
ipaddr = to_ip_address(ip)
if ipaddr.is_global:
return True
def set_default_headers(self): def set_default_headers(self):
self.set_header('Server', 'TornadoServer') self.set_header('Server', 'TornadoServer')
......
...@@ -6,7 +6,7 @@ from tornado.options import options ...@@ -6,7 +6,7 @@ from tornado.options import options
from webssh.handler import IndexHandler, WsockHandler from webssh.handler import IndexHandler, WsockHandler
from webssh.settings import ( from webssh.settings import (
get_app_settings, get_host_keys_settings, get_policy_setting, get_app_settings, get_host_keys_settings, get_policy_setting,
get_ssl_context, max_body_size, xheaders get_ssl_context, get_server_settings
) )
...@@ -31,12 +31,12 @@ def main(): ...@@ -31,12 +31,12 @@ def main():
loop = tornado.ioloop.IOLoop.current() loop = tornado.ioloop.IOLoop.current()
app = make_app(make_handlers(loop, options), get_app_settings(options)) app = make_app(make_handlers(loop, options), get_app_settings(options))
ssl_ctx = get_ssl_context(options) ssl_ctx = get_ssl_context(options)
kwargs = dict(xheaders=xheaders, max_body_size=max_body_size) server_settings = get_server_settings(options)
app.listen(options.port, options.address, **kwargs) app.listen(options.port, options.address, **server_settings)
logging.info('Listening on {}:{}'.format(options.address, options.port)) logging.info('Listening on {}:{}'.format(options.address, options.port))
if ssl_ctx: if ssl_ctx:
kwargs.update(ssl_options=ssl_ctx) server_settings.update(ssl_options=ssl_ctx)
app.listen(options.sslPort, options.sslAddress, **kwargs) app.listen(options.sslPort, options.sslAddress, **server_settings)
logging.info('Listening on ssl {}:{}'.format(options.sslAddress, logging.info('Listening on ssl {}:{}'.format(options.sslAddress,
options.sslPort)) options.sslPort))
loop.start() loop.start()
......
...@@ -51,6 +51,15 @@ def get_app_settings(options): ...@@ -51,6 +51,15 @@ def get_app_settings(options):
return settings return settings
def get_server_settings(options):
settings = dict(
xheaders=xheaders,
max_body_size=max_body_size,
trusted_downstream=get_trusted_downstream(options)
)
return settings
def get_host_keys_settings(options): def get_host_keys_settings(options):
if not options.hostFile: if not options.hostFile:
host_keys_filename = os.path.join(base_dir, 'known_hosts') host_keys_filename = os.path.join(base_dir, 'known_hosts')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment