Commit c06bf531 by Sheng

Added an option for blocking public non-https requests

parent 746982b0
...@@ -3,7 +3,6 @@ import paramiko ...@@ -3,7 +3,6 @@ import paramiko
from tornado.httpclient import HTTPRequest from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
from tornado.web import HTTPError
from tests.utils import read_file, make_tests_data_path from tests.utils import read_file, make_tests_data_path
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
...@@ -17,6 +16,8 @@ class TestMixinHandler(unittest.TestCase): ...@@ -17,6 +16,8 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self): def test_is_forbidden(self):
handler = MixinHandler() handler = MixinHandler()
handler.is_open_to_public = True
handler.forbid_public_http = True
request = HTTPRequest('http://example.com/') request = HTTPRequest('http://example.com/')
handler.request = request handler.request = request
......
...@@ -7,10 +7,11 @@ import paramiko ...@@ -7,10 +7,11 @@ import paramiko
import tornado.options as options import tornado.options as options
from tests.utils import make_tests_data_path from tests.utils import make_tests_data_path
from webssh import settings
from webssh.policy import load_host_keys from webssh.policy import load_host_keys
from webssh.settings import ( from webssh.settings import (
get_host_keys_settings, get_policy_setting, base_dir, print_version, get_host_keys_settings, get_policy_setting, base_dir, print_version,
get_ssl_context, get_trusted_downstream get_ssl_context, get_trusted_downstream, detect_is_open_to_public,
) )
from webssh.utils import UnicodeType from webssh.utils import UnicodeType
from webssh._version import __version__ from webssh._version import __version__
...@@ -137,3 +138,29 @@ class TestSettings(unittest.TestCase): ...@@ -137,3 +138,29 @@ class TestSettings(unittest.TestCase):
options.tdstream = '1.1.1.1, 2.2.2.' options.tdstream = '1.1.1.1, 2.2.2.'
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
get_trusted_downstream(options), tdstream get_trusted_downstream(options), tdstream
def test_detect_is_open_to_public(self):
options.fbidhttp = True
options.address = 'localhost'
detect_is_open_to_public(options)
self.assertFalse(settings.is_open_to_public)
options.address = '127.0.0.1'
detect_is_open_to_public(options)
self.assertFalse(settings.is_open_to_public)
options.address = '192.168.1.1'
detect_is_open_to_public(options)
self.assertFalse(settings.is_open_to_public)
options.address = ''
detect_is_open_to_public(options)
self.assertTrue(settings.is_open_to_public)
options.address = '0.0.0.0'
detect_is_open_to_public(options)
self.assertTrue(settings.is_open_to_public)
options.address = '::'
detect_is_open_to_public(options)
self.assertTrue(settings.is_open_to_public)
import unittest import unittest
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
to_str, to_bytes, to_int to_int, on_public_network_interface, on_public_network_interfaces,
get_ips_by_name
) )
...@@ -51,3 +52,25 @@ class TestUitls(unittest.TestCase): ...@@ -51,3 +52,25 @@ class TestUitls(unittest.TestCase):
self.assertFalse(is_valid_hostname('https://www.google.com')) self.assertFalse(is_valid_hostname('https://www.google.com'))
self.assertFalse(is_valid_hostname('127.0.0.1')) self.assertFalse(is_valid_hostname('127.0.0.1'))
self.assertFalse(is_valid_hostname('::1')) self.assertFalse(is_valid_hostname('::1'))
def test_get_ips_by_name(self):
self.assertTrue(get_ips_by_name(''), {'0.0.0.0', '::'})
self.assertTrue(get_ips_by_name('localhost'), {'127.0.0.1'})
self.assertTrue(get_ips_by_name('192.68.1.1'), {'192.168.1.1'})
self.assertTrue(get_ips_by_name('2.2.2.2'), {'2.2.2.2'})
def test_on_public_network_interface(self):
self.assertTrue(on_public_network_interface('0.0.0.0'))
self.assertTrue(on_public_network_interface('::'))
self.assertTrue(on_public_network_interface('0:0:0:0:0:0:0:0'))
self.assertTrue(on_public_network_interface('2.2.2.2'))
self.assertTrue(on_public_network_interface('2:2:2:2:2:2:2:2'))
self.assertIsNone(on_public_network_interface('127.0.0.1'))
def test_on_public_network_interfaces(self):
self.assertTrue(
on_public_network_interfaces(['0.0.0.0', '127.0.0.1'])
)
self.assertIsNone(
on_public_network_interfaces(['192.168.1.1', '127.0.0.1'])
)
...@@ -10,7 +10,8 @@ import paramiko ...@@ -10,7 +10,8 @@ import paramiko
import tornado.web import tornado.web
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from webssh.settings import swallow_http_errors from tornado.options import options
from webssh import settings
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, is_valid_ip_address, is_valid_port, is_valid_hostname,
to_bytes, to_str, to_int, to_ip_address, UnicodeType to_bytes, to_str, to_int, to_ip_address, UnicodeType
...@@ -39,11 +40,20 @@ class InvalidValueError(Exception): ...@@ -39,11 +40,20 @@ class InvalidValueError(Exception):
class MixinHandler(object): class MixinHandler(object):
is_open_to_public = None
forbid_public_http = None
custom_headers = { custom_headers = {
'Server': 'TornadoServer' 'Server': 'TornadoServer'
} }
def initialize(self): def initialize(self):
if self.is_open_to_public is None:
MixinHandler.is_open_to_public = settings.is_open_to_public
if self.forbid_public_http is None:
MixinHandler.forbid_public_http = options.fbidhttp
if self.is_forbidden(): if self.is_forbidden():
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version) result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
self.request.connection.stream.write(to_bytes(result)) self.request.connection.stream.write(to_bytes(result))
...@@ -66,11 +76,12 @@ class MixinHandler(object): ...@@ -66,11 +76,12 @@ class MixinHandler(object):
) )
return True return True
if context._orig_protocol == 'http': if self.is_open_to_public and self.forbid_public_http:
ipaddr = to_ip_address(ip) if context._orig_protocol == 'http':
if not ipaddr.is_private: ipaddr = to_ip_address(ip)
logging.warning('Public non-https request is forbidden.') if not ipaddr.is_private:
return True logging.warning('Public non-https request is forbidden.')
return True
def set_default_headers(self): def set_default_headers(self):
for header in self.custom_headers.items(): for header in self.custom_headers.items():
...@@ -127,7 +138,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -127,7 +138,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
super(IndexHandler, self).initialize() super(IndexHandler, self).initialize()
def write_error(self, status_code, **kwargs): def write_error(self, status_code, **kwargs):
if self.request.method != 'POST' or not swallow_http_errors: if self.request.method != 'POST' or not settings.swallow_http_errors:
super(IndexHandler, self).write_error(status_code, **kwargs) super(IndexHandler, self).write_error(status_code, **kwargs)
else: else:
exc_info = kwargs.get('exc_info') exc_info = kwargs.get('exc_info')
......
...@@ -6,7 +6,7 @@ from tornado.options import options ...@@ -6,7 +6,7 @@ from tornado.options import options
from webssh.handler import IndexHandler, WsockHandler, NotFoundHandler from webssh.handler import IndexHandler, WsockHandler, NotFoundHandler
from webssh.settings import ( from webssh.settings import (
get_app_settings, get_host_keys_settings, get_policy_setting, get_app_settings, get_host_keys_settings, get_policy_setting,
get_ssl_context, get_server_settings get_ssl_context, get_server_settings, detect_is_open_to_public
) )
...@@ -40,6 +40,7 @@ def main(): ...@@ -40,6 +40,7 @@ def main():
app.listen(options.sslport, options.ssladdress, **server_settings) app.listen(options.sslport, options.ssladdress, **server_settings)
logging.info('Listening on ssl {}:{}'.format(options.ssladdress, logging.info('Listening on ssl {}:{}'.format(options.ssladdress,
options.sslport)) options.sslport))
detect_is_open_to_public(options)
loop.start() loop.start()
......
...@@ -7,7 +7,9 @@ from tornado.options import define ...@@ -7,7 +7,9 @@ from tornado.options import define
from webssh.policy import ( from webssh.policy import (
load_host_keys, get_policy_class, check_policy_setting load_host_keys, get_policy_class, check_policy_setting
) )
from webssh.utils import to_ip_address from webssh.utils import (
to_ip_address, get_ips_by_name, on_public_network_interfaces
)
from webssh._version import __version__ from webssh._version import __version__
...@@ -29,6 +31,7 @@ define('policy', default='warning', ...@@ -29,6 +31,7 @@ define('policy', default='warning',
define('hostfile', default='', help='User defined host keys file') define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file') define('syshostfile', default='', help='System wide host keys file')
define('tdstream', default='', help='trusted downstream, separated by comma') define('tdstream', default='', help='trusted downstream, separated by comma')
define('fbidhttp', type=bool, default=True, help='forbid public http request')
define('wpintvl', type=int, default=0, help='Websocket ping interval') define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('version', type=bool, help='Show version information', define('version', type=bool, help='Show version information',
callback=print_version) callback=print_version)
...@@ -38,6 +41,7 @@ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ...@@ -38,6 +41,7 @@ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
max_body_size = 1 * 1024 * 1024 max_body_size = 1 * 1024 * 1024
swallow_http_errors = True swallow_http_errors = True
xheaders = True xheaders = True
is_open_to_public = False
def get_app_settings(options): def get_app_settings(options):
...@@ -113,3 +117,12 @@ def get_trusted_downstream(options): ...@@ -113,3 +117,12 @@ def get_trusted_downstream(options):
to_ip_address(ip) to_ip_address(ip)
tdstream.add(ip) tdstream.add(ip)
return tdstream return tdstream
def detect_is_open_to_public(options):
global is_open_to_public
if on_public_network_interfaces(get_ips_by_name(options.address)):
is_open_to_public = True
logging.info('Forbid public http: {}'.format(options.fbidhttp))
else:
is_open_to_public = False
import ipaddress import ipaddress
import re import re
import socket
try: try:
from types import UnicodeType from types import UnicodeType
...@@ -10,6 +11,9 @@ except ImportError: ...@@ -10,6 +11,9 @@ except ImportError:
numeric = re.compile(r'[0-9]+$') numeric = re.compile(r'[0-9]+$')
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE) allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
default_public_ipv4addr = ipaddress.ip_address(u'0.0.0.0')
default_public_ipv6addr = ipaddress.ip_address(u'::')
def to_str(bstr, encoding='utf-8'): def to_str(bstr, encoding='utf-8'):
if isinstance(bstr, bytes): if isinstance(bstr, bytes):
...@@ -60,3 +64,25 @@ def is_valid_hostname(hostname): ...@@ -60,3 +64,25 @@ def is_valid_hostname(hostname):
return False return False
return all(allowed.match(label) for label in labels) return all(allowed.match(label) for label in labels)
def get_ips_by_name(name):
if name == '':
return {'0.0.0.0', '::'}
ret = socket.getaddrinfo(name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)
return {t[4][0] for t in ret}
def on_public_network_interface(ip):
ipaddr = to_ip_address(ip)
if ipaddr == default_public_ipv4addr or ipaddr == default_public_ipv6addr:
return True
if not ipaddr.is_private:
return True
def on_public_network_interfaces(ips):
for ip in ips:
if on_public_network_interface(ip):
return True
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment