Commit 80bdddc2 by Sheng

Added setting swallow_http_errors

parent e1fbc417
...@@ -11,7 +11,9 @@ from tornado.options import options ...@@ -11,7 +11,9 @@ from tornado.options import options
from tests.sshserver import run_ssh_server, banner from tests.sshserver import run_ssh_server, banner
from tests.utils import encode_multipart_formdata, read_file, make_tests_data_path # noqa from tests.utils import encode_multipart_formdata, read_file, make_tests_data_path # noqa
from webssh.main import make_app, make_handlers from webssh.main import make_app, make_handlers
from webssh.settings import get_app_settings, max_body_size from webssh.settings import (
get_app_settings, max_body_size, swallow_http_errors
)
from webssh.utils import to_str from webssh.utils import to_str
try: try:
...@@ -65,59 +67,56 @@ class TestApp(AsyncHTTPTestCase): ...@@ -65,59 +67,56 @@ class TestApp(AsyncHTTPTestCase):
options.update(max_body_size=max_body_size) options.update(max_body_size=max_body_size)
return options return options
def my_assertIn(self, part, whole):
if swallow_http_errors:
self.assertIn(part, whole)
else:
self.assertIn(b'Bad Request', whole)
def test_app_with_invalid_form_for_missing_argument(self): def test_app_with_invalid_form_for_missing_argument(self):
response = self.fetch('/') response = self.fetch('/')
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
body = 'port=7000&username=admin&password' body = 'port=7000&username=admin&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertEqual(response.code, 400) self.my_assertIn(b'Missing argument hostname', response.body)
self.assertIn(b'Missing argument hostname', response.body)
body = 'hostname=127.0.0.1&username=admin&password' body = 'hostname=127.0.0.1&username=admin&password'
self.assertEqual(response.code, 400)
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertIn(b'Missing argument port', response.body) self.my_assertIn(b'Missing argument port', response.body)
body = 'hostname=127.0.0.1&port=7000&password' body = 'hostname=127.0.0.1&port=7000&password'
self.assertEqual(response.code, 400)
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertIn(b'Missing argument username', response.body) self.my_assertIn(b'Missing argument username', response.body)
body = 'hostname=&port=&username=&password' body = 'hostname=&port=&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertEqual(response.code, 400) self.my_assertIn(b'Missing value hostname', response.body)
self.assertIn(b'Missing value hostname', response.body)
body = 'hostname=127.0.0.1&port=&username=&password' body = 'hostname=127.0.0.1&port=&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertEqual(response.code, 400) self.my_assertIn(b'Missing value port', response.body)
self.assertIn(b'Missing value port', response.body)
body = 'hostname=127.0.0.1&port=7000&username=&password' body = 'hostname=127.0.0.1&port=7000&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertEqual(response.code, 400) self.my_assertIn(b'Missing value username', response.body)
self.assertIn(b'Missing value username', response.body)
def test_app_with_invalid_form_for_invalid_value(self): def test_app_with_invalid_form_for_invalid_value(self):
body = 'hostname=127.0.0&port=22&username=&password' body = 'hostname=127.0.0&port=22&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertIn(b'Invalid hostname', response.body) self.my_assertIn(b'Invalid hostname', response.body)
body = 'hostname=http://www.googe.com&port=22&username=&password' body = 'hostname=http://www.googe.com&port=22&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertEqual(response.code, 400) self.my_assertIn(b'Invalid hostname', response.body)
self.assertIn(b'Invalid hostname', response.body)
body = 'hostname=127.0.0.1&port=port&username=&password' body = 'hostname=127.0.0.1&port=port&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertEqual(response.code, 400) self.my_assertIn(b'Invalid port', response.body)
self.assertIn(b'Invalid port', response.body)
body = 'hostname=127.0.0.1&port=70000&username=&password' body = 'hostname=127.0.0.1&port=70000&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertEqual(response.code, 400) self.my_assertIn(b'Invalid port', response.body)
self.assertIn(b'Invalid port', response.body)
def test_app_with_wrong_hostname_ip(self): def test_app_with_wrong_hostname_ip(self):
body = 'hostname=127.0.0.1&port=7000&username=admin' body = 'hostname=127.0.0.1&port=7000&username=admin'
...@@ -370,10 +369,16 @@ class TestApp(AsyncHTTPTestCase): ...@@ -370,10 +369,16 @@ class TestApp(AsyncHTTPTestCase):
headers = { headers = {
'Content-Type': content_type, 'content-length': str(len(body)) 'Content-Type': content_type, 'content-length': str(len(body))
} }
if swallow_http_errors:
response = yield client.fetch(url, method='POST', headers=headers,
body=body)
self.assertIn(b'Invalid private key', response.body)
else:
with self.assertRaises(HTTPError) as ctx: with self.assertRaises(HTTPError) as ctx:
yield client.fetch(url, method='POST', headers=headers, body=body) yield client.fetch(url, method='POST', headers=headers,
self.assertEqual(ctx.exception.code, 400) body=body)
self.assertIn('Invalid private key', ctx.exception.message) self.assertIn('Bad Request', ctx.exception.message)
@tornado.testing.gen_test @tornado.testing.gen_test
def test_app_auth_with_pubkey_exceeds_key_max_size(self): def test_app_auth_with_pubkey_exceeds_key_max_size(self):
...@@ -389,10 +394,15 @@ class TestApp(AsyncHTTPTestCase): ...@@ -389,10 +394,15 @@ class TestApp(AsyncHTTPTestCase):
headers = { headers = {
'Content-Type': content_type, 'content-length': str(len(body)) 'Content-Type': content_type, 'content-length': str(len(body))
} }
if swallow_http_errors:
response = yield client.fetch(url, method='POST', headers=headers,
body=body)
self.assertIn(b'Invalid private key', response.body)
else:
with self.assertRaises(HTTPError) as ctx: with self.assertRaises(HTTPError) as ctx:
yield client.fetch(url, method='POST', headers=headers, body=body) yield client.fetch(url, method='POST', headers=headers,
self.assertEqual(ctx.exception.code, 400) body=body)
self.assertIn('Invalid private key', ctx.exception.message) self.assertIn('Bad Request', ctx.exception.message)
@tornado.testing.gen_test @tornado.testing.gen_test
def test_app_auth_with_pubkey_cannot_be_decoded_by_multipart_form(self): def test_app_auth_with_pubkey_cannot_be_decoded_by_multipart_form(self):
...@@ -411,10 +421,15 @@ class TestApp(AsyncHTTPTestCase): ...@@ -411,10 +421,15 @@ class TestApp(AsyncHTTPTestCase):
headers = { headers = {
'Content-Type': content_type, 'content-length': str(len(body)) 'Content-Type': content_type, 'content-length': str(len(body))
} }
if swallow_http_errors:
response = yield client.fetch(url, method='POST', headers=headers,
body=body)
self.assertIn(b'Invalid unicode', response.body)
else:
with self.assertRaises(HTTPError) as ctx: with self.assertRaises(HTTPError) as ctx:
yield client.fetch(url, method='POST', headers=headers, body=body) yield client.fetch(url, method='POST', headers=headers,
self.assertEqual(ctx.exception.code, 400) body=body)
self.assertIn('Invalid unicode', ctx.exception.message) self.assertIn('Bad Request', ctx.exception.message)
@tornado.testing.gen_test @tornado.testing.gen_test
def test_app_post_form_with_large_body_size_by_multipart_form(self): def test_app_post_form_with_large_body_size_by_multipart_form(self):
...@@ -432,8 +447,8 @@ class TestApp(AsyncHTTPTestCase): ...@@ -432,8 +447,8 @@ class TestApp(AsyncHTTPTestCase):
} }
with self.assertRaises(HTTPError) as ctx: with self.assertRaises(HTTPError) as ctx:
yield client.fetch(url, method='POST', headers=headers, body=body) yield client.fetch(url, method='POST', headers=headers,
self.assertEqual(ctx.exception.code, 400) body=body)
self.assertIn('Bad Request', ctx.exception.message) self.assertIn('Bad Request', ctx.exception.message)
@tornado.testing.gen_test @tornado.testing.gen_test
...@@ -447,7 +462,6 @@ class TestApp(AsyncHTTPTestCase): ...@@ -447,7 +462,6 @@ class TestApp(AsyncHTTPTestCase):
body = self.body + '&privatekey=' + privatekey body = self.body + '&privatekey=' + privatekey
with self.assertRaises(HTTPError) as ctx: with self.assertRaises(HTTPError) as ctx:
yield client.fetch(url, method='POST', body=body) yield client.fetch(url, method='POST', body=body)
self.assertEqual(ctx.exception.code, 400)
self.assertIn('Bad Request', ctx.exception.message) self.assertIn('Bad Request', ctx.exception.message)
@tornado.testing.gen_test @tornado.testing.gen_test
......
...@@ -4,7 +4,7 @@ import paramiko ...@@ -4,7 +4,7 @@ import paramiko
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
from tests.utils import read_file, make_tests_data_path from tests.utils import read_file, make_tests_data_path
from webssh.handler import ( from webssh.handler import (
MixinHandler, IndexHandler, parse_encoding, InvalidException MixinHandler, IndexHandler, parse_encoding, InvalidValueError
) )
...@@ -83,7 +83,7 @@ class TestIndexHandler(unittest.TestCase): ...@@ -83,7 +83,7 @@ class TestIndexHandler(unittest.TestCase):
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_pkey_obj(key, 'iginored', fname) pkey = IndexHandler.get_pkey_obj(key, 'iginored', fname)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
with self.assertRaises(InvalidException) as exc: with self.assertRaises(InvalidValueError) as exc:
pkey = IndexHandler.get_pkey_obj('x'+key, None, fname) pkey = IndexHandler.get_pkey_obj('x'+key, None, fname)
self.assertIn('Invalid private key', str(exc)) self.assertIn('Invalid private key', str(exc))
...@@ -94,9 +94,9 @@ class TestIndexHandler(unittest.TestCase): ...@@ -94,9 +94,9 @@ class TestIndexHandler(unittest.TestCase):
key = read_file(make_tests_data_path(fname)) key = read_file(make_tests_data_path(fname))
pkey = IndexHandler.get_pkey_obj(key, password, fname) pkey = IndexHandler.get_pkey_obj(key, password, fname)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
with self.assertRaises(InvalidException) as exc: with self.assertRaises(InvalidValueError) as exc:
pkey = IndexHandler.get_pkey_obj(key, 'wrongpass', fname) pkey = IndexHandler.get_pkey_obj(key, 'wrongpass', fname)
self.assertIn('Wrong password', str(exc)) self.assertIn('Wrong password', str(exc))
with self.assertRaises(InvalidException) as exc: with self.assertRaises(InvalidValueError) as exc:
pkey = IndexHandler.get_pkey_obj('x'+key, password, fname) pkey = IndexHandler.get_pkey_obj('x'+key, password, fname)
self.assertIn('Invalid private key', str(exc)) self.assertIn('Invalid private key', str(exc))
...@@ -10,11 +10,12 @@ import paramiko ...@@ -10,11 +10,12 @@ import paramiko
import tornado.web import tornado.web
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from webssh.worker import Worker, recycle_worker, workers from webssh.settings import swallow_http_errors
from webssh.utils import ( from webssh.utils import (
is_valid_ipv4_address, is_valid_ipv6_address, is_valid_port, is_valid_ipv4_address, is_valid_ipv6_address, is_valid_port,
is_valid_hostname, to_bytes, to_str, UnicodeType is_valid_hostname, to_bytes, to_str, UnicodeType
) )
from webssh.worker import Worker, recycle_worker, workers
try: try:
from concurrent.futures import Future from concurrent.futures import Future
...@@ -38,34 +39,24 @@ def parse_encoding(data): ...@@ -38,34 +39,24 @@ def parse_encoding(data):
return s.strip('"').split('.')[-1] return s.strip('"').split('.')[-1]
class InvalidException(Exception): class InvalidValueError(Exception):
pass pass
class MixinHandler(object): class MixinHandler(object):
formater = 'Missing value {}'
def write_error(self, status_code, **kwargs):
exc_info = kwargs.get('exc_info')
if exc_info and len(exc_info) > 1:
info = str(exc_info[1])
if info:
self._reason = info.split(':', 1)[-1].strip()
super(MixinHandler, self).write_error(status_code, **kwargs)
def get_value(self, name): def get_value(self, name):
value = self.get_argument(name) value = self.get_argument(name)
if not value: if not value:
raise InvalidException(self.formater.format(name)) raise InvalidValueError('Missing value {}'.format(name))
return value return value
def get_real_client_addr(self): def get_real_client_addr(self):
ip = self.request.headers.get('X-Real-Ip') ip = self.request.headers.get('X-Real-Ip')
port = self.request.headers.get('X-Real-Port') port = self.request.headers.get('X-Real-Port')
if ip is None and port is None: # suppose the server doesn't use nginx if ip is None and port is None:
return return # suppose this app doesn't run after an nginx server
if is_valid_ipv4_address(ip) or is_valid_ipv6_address(ip): if is_valid_ipv4_address(ip) or is_valid_ipv6_address(ip):
try: try:
...@@ -87,19 +78,33 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -87,19 +78,33 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.policy = policy self.policy = policy
self.host_keys_settings = host_keys_settings self.host_keys_settings = host_keys_settings
self.filename = None self.filename = None
self.result = dict(id=None, status=None, encoding=None)
def write_error(self, status_code, **kwargs):
if self.settings.get('serve_traceback') or status_code == 500 or \
not swallow_http_errors:
super(MixinHandler, self).write_error(status_code, **kwargs)
else:
exc_info = kwargs.get('exc_info')
if exc_info:
self._reason = exc_info[1].log_message
self.result.update(status=self._reason)
self.set_status(200)
self.finish(self.result)
def get_privatekey(self): def get_privatekey(self):
lst = self.request.files.get('privatekey') # multipart form name = 'privatekey'
lst = self.request.files.get(name) # multipart form
if not lst: if not lst:
return self.get_argument('privatekey', u'') # urlencoded form return self.get_argument(name, u'') # urlencoded form
else: else:
self.filename = lst[0]['filename'] self.filename = lst[0]['filename']
data = lst[0]['body'] data = lst[0]['body']
if len(data) > KEY_MAX_SIZE: if len(data) > KEY_MAX_SIZE:
raise InvalidException( raise InvalidValueError(
'Invalid private key: {}'.format(self.filename) 'Invalid private key: {}'.format(self.filename)
) )
return self.decode_argument(data, name=self.filename) return self.decode_argument(data, name=name)
@classmethod @classmethod
def get_specific_pkey(cls, pkeycls, privatekey, password): def get_specific_pkey(cls, pkeycls, privatekey, password):
...@@ -130,7 +135,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -130,7 +135,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
error = ( error = (
'Wrong password {!r} for decrypting the private key.' 'Wrong password {!r} for decrypting the private key.'
) .format(password) ) .format(password)
raise InvalidException(error) raise InvalidValueError(error)
return pkey return pkey
...@@ -138,7 +143,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -138,7 +143,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
value = self.get_value('hostname') value = self.get_value('hostname')
if not (is_valid_hostname(value) | is_valid_ipv4_address(value) | if not (is_valid_hostname(value) | is_valid_ipv4_address(value) |
is_valid_ipv6_address(value)): is_valid_ipv6_address(value)):
raise InvalidException('Invalid hostname: {}'.format(value)) raise InvalidValueError('Invalid hostname: {}'.format(value))
return value return value
def get_port(self): def get_port(self):
...@@ -151,7 +156,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -151,7 +156,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
if is_valid_port(port): if is_valid_port(port):
return port return port
raise InvalidException('Invalid port: {}'.format(value)) raise InvalidValueError('Invalid port: {}'.format(value))
def get_args(self): def get_args(self):
hostname = self.get_hostname() hostname = self.get_hostname()
...@@ -189,7 +194,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -189,7 +194,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
try: try:
args = self.get_args() args = self.get_args()
except InvalidException as exc: except InvalidValueError as exc:
raise tornado.web.HTTPError(400, str(exc)) raise tornado.web.HTTPError(400, str(exc))
dst_addr = (args[0], args[1]) dst_addr = (args[0], args[1])
...@@ -227,10 +232,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -227,10 +232,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
@tornado.gen.coroutine @tornado.gen.coroutine
def post(self): def post(self):
worker_id = None
status = None
encoding = None
future = Future() future = Future()
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,)) t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
t.setDaemon(True) t.setDaemon(True)
...@@ -239,20 +240,17 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -239,20 +240,17 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
try: try:
worker = yield future worker = yield future
except (ValueError, paramiko.SSHException) as exc: except (ValueError, paramiko.SSHException) as exc:
status = str(exc) self.result.update(status=str(exc))
else: else:
worker_id = worker.id workers[worker.id] = worker
workers[worker_id] = worker
self.loop.call_later(DELAY, recycle_worker, worker) self.loop.call_later(DELAY, recycle_worker, worker)
encoding = worker.encoding self.result.update(id=worker.id, encoding=worker.encoding)
self.write(dict(id=worker_id, status=status, encoding=encoding)) self.write(self.result)
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
formater = 'Bad Request (Missing value {})'
def initialize(self, loop): def initialize(self, loop):
self.loop = loop self.loop = loop
self.worker_ref = None self.worker_ref = None
...@@ -265,8 +263,8 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): ...@@ -265,8 +263,8 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
logging.info('Connected from {}:{}'.format(*self.src_addr)) logging.info('Connected from {}:{}'.format(*self.src_addr))
try: try:
worker_id = self.get_value('id') worker_id = self.get_value('id')
except (tornado.web.MissingArgumentError, InvalidException) as exc: except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
self.close(reason=str(exc).split(':', 1)[-1].strip()) self.close(reason=str(exc))
else: else:
worker = workers.get(worker_id) worker = workers.get(worker_id)
if worker and worker.src_addr[0] == self.src_addr[0]: if worker and worker.src_addr[0] == self.src_addr[0]:
......
...@@ -29,6 +29,7 @@ define('version', type=bool, help='Show version information', ...@@ -29,6 +29,7 @@ define('version', type=bool, help='Show version information',
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
max_body_size = 1 * 1024 * 1024 max_body_size = 1 * 1024 * 1024
swallow_http_errors = True
def get_app_settings(options): def get_app_settings(options):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment