Commit e85ae169 by Sheng

Added to_bytes function to utils

parent cb866825
...@@ -13,6 +13,7 @@ from tests.sshserver import run_ssh_server, banner ...@@ -13,6 +13,7 @@ from tests.sshserver import run_ssh_server, banner
from tests.utils import encode_multipart_formdata, read_file from tests.utils import encode_multipart_formdata, read_file
from webssh.main import make_app, make_handlers from webssh.main import make_app, make_handlers
from webssh.settings import get_app_settings, max_body_size, base_dir from webssh.settings import get_app_settings, max_body_size, base_dir
from webssh.utils import to_str
handler.DELAY = 0.1 handler.DELAY = 0.1
...@@ -22,7 +23,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -22,7 +23,7 @@ class TestApp(AsyncHTTPTestCase):
running = [True] running = [True]
sshserver_port = 2200 sshserver_port = 2200
body = u'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa body = 'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa
body_dict = { body_dict = {
'hostname': '127.0.0.1', 'hostname': '127.0.0.1',
'port': str(sshserver_port), 'port': str(sshserver_port),
...@@ -61,37 +62,37 @@ class TestApp(AsyncHTTPTestCase): ...@@ -61,37 +62,37 @@ class TestApp(AsyncHTTPTestCase):
def test_app_with_invalid_form(self): def test_app_with_invalid_form(self):
response = self.fetch('/') response = self.fetch('/')
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
body = u'hostname=&port=&username=&password' body = 'hostname=&port=&username=&password'
response = self.fetch('/', method="POST", body=body) response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty hostname"', response.body) self.assertIn(b'"status": "Empty hostname"', response.body)
body = u'hostname=127.0.0.1&port=&username=&password' body = 'hostname=127.0.0.1&port=&username=&password'
response = self.fetch('/', method="POST", body=body) response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty port"', response.body) self.assertIn(b'"status": "Empty port"', response.body)
body = u'hostname=127.0.0.1&port=port&username=&password' body = 'hostname=127.0.0.1&port=port&username=&password'
response = self.fetch('/', method="POST", body=body) response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Invalid port', response.body) self.assertIn(b'"status": "Invalid port', response.body)
body = u'hostname=127.0.0.1&port=70000&username=&password' body = 'hostname=127.0.0.1&port=70000&username=&password'
response = self.fetch('/', method="POST", body=body) response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Invalid port', response.body) self.assertIn(b'"status": "Invalid port', response.body)
body = u'hostname=127.0.0.1&port=7000&username=&password' body = 'hostname=127.0.0.1&port=7000&username=&password'
response = self.fetch('/', method="POST", body=body) response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty username"', response.body) self.assertIn(b'"status": "Empty username"', response.body)
def test_app_with_wrong_credentials(self): def test_app_with_wrong_credentials(self):
response = self.fetch('/') response = self.fetch('/')
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
response = self.fetch('/', method="POST", body=self.body + u's') response = self.fetch('/', method="POST", body=self.body + 's')
self.assertIn(b'Authentication failed.', response.body) self.assertIn(b'Authentication failed.', response.body)
def test_app_with_correct_credentials(self): def test_app_with_correct_credentials(self):
response = self.fetch('/') response = self.fetch('/')
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
response = self.fetch('/', method="POST", body=self.body) response = self.fetch('/', method="POST", body=self.body)
data = json.loads(response.body.decode('utf-8')) data = json.loads(to_str(response.body))
self.assertIsNone(data['status']) self.assertIsNone(data['status'])
self.assertIsNotNone(data['id']) self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding']) self.assertIsNotNone(data['encoding'])
...@@ -104,7 +105,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -104,7 +105,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
response = yield client.fetch(url, method="POST", body=self.body) response = yield client.fetch(url, method="POST", body=self.body)
data = json.loads(response.body.decode('utf-8')) data = json.loads(to_str(response.body))
self.assertIsNone(data['status']) self.assertIsNone(data['status'])
self.assertIsNotNone(data['id']) self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding']) self.assertIsNotNone(data['encoding'])
...@@ -133,7 +134,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -133,7 +134,7 @@ class TestApp(AsyncHTTPTestCase):
} }
response = yield client.fetch(url, method="POST", headers=headers, response = yield client.fetch(url, method="POST", headers=headers,
body=body) body=body)
data = json.loads(response.body.decode('utf-8')) data = json.loads(to_str(response.body))
self.assertIsNone(data['status']) self.assertIsNone(data['status'])
self.assertIsNotNone(data['id']) self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding']) self.assertIsNotNone(data['encoding'])
...@@ -142,7 +143,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -142,7 +143,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id'] ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url) ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message() msg = yield ws.read_message()
self.assertEqual(msg.decode(data['encoding']), banner) self.assertEqual(to_str(msg, data['encoding']), banner)
ws.close() ws.close()
@tornado.testing.gen_test @tornado.testing.gen_test
...@@ -153,7 +154,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -153,7 +154,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
privatekey = read_file(os.path.join(base_dir, 'tests', 'user_rsa_key')) privatekey = read_file(os.path.join(base_dir, 'tests', 'user_rsa_key'))
privatekey = privatekey[:100] + u'bad' + privatekey[100:] privatekey = privatekey[:100] + 'bad' + privatekey[100:]
files = [('privatekey', 'user_rsa_key', privatekey)] files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(), content_type, body = encode_multipart_formdata(self.body_dict.items(),
files) files)
...@@ -162,7 +163,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -162,7 +163,7 @@ class TestApp(AsyncHTTPTestCase):
} }
response = yield client.fetch(url, method="POST", headers=headers, response = yield client.fetch(url, method="POST", headers=headers,
body=body) body=body)
data = json.loads(response.body.decode('utf-8')) data = json.loads(to_str(response.body))
self.assertIsNotNone(data['status']) self.assertIsNotNone(data['status'])
self.assertIsNone(data['id']) self.assertIsNone(data['id'])
self.assertIsNone(data['encoding']) self.assertIsNone(data['encoding'])
...@@ -174,7 +175,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -174,7 +175,7 @@ class TestApp(AsyncHTTPTestCase):
response = yield client.fetch(url) response = yield client.fetch(url)
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
privatekey = u'h' * (2 * max_body_size) privatekey = 'h' * (2 * max_body_size)
files = [('privatekey', 'user_rsa_key', privatekey)] files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(), content_type, body = encode_multipart_formdata(self.body_dict.items(),
files) files)
...@@ -193,7 +194,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -193,7 +194,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
response = yield client.fetch(url, method="POST", body=self.body) response = yield client.fetch(url, method="POST", body=self.body)
data = json.loads(response.body.decode('utf-8')) data = json.loads(to_str(response.body))
self.assertIsNone(data['status']) self.assertIsNone(data['status'])
self.assertIsNotNone(data['id']) self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding']) self.assertIsNotNone(data['encoding'])
...@@ -202,7 +203,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -202,7 +203,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id'] ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url) ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message() msg = yield ws.read_message()
self.assertEqual(msg.decode(data['encoding']), banner) self.assertEqual(to_str(msg, data['encoding']), banner)
ws.close() ws.close()
@tornado.testing.gen_test @tornado.testing.gen_test
...@@ -214,7 +215,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -214,7 +215,7 @@ class TestApp(AsyncHTTPTestCase):
body = self.body.replace('robey', 'bar') body = self.body.replace('robey', 'bar')
response = yield client.fetch(url, method="POST", body=body) response = yield client.fetch(url, method="POST", body=body)
data = json.loads(response.body.decode('utf-8')) data = json.loads(to_str(response.body))
self.assertIsNone(data['status']) self.assertIsNone(data['status'])
self.assertIsNotNone(data['id']) self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding']) self.assertIsNotNone(data['encoding'])
...@@ -223,7 +224,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -223,7 +224,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id'] ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url) ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message() msg = yield ws.read_message()
self.assertEqual(msg.decode(data['encoding']), banner) self.assertEqual(to_str(msg, data['encoding']), banner)
# messages below will be ignored silently # messages below will be ignored silently
yield ws.write_message('hello') yield ws.write_message('hello')
......
...@@ -56,7 +56,7 @@ class TestIndexHandler(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestIndexHandler(unittest.TestCase):
key = read_file(os.path.join(base_dir, 'tests', fname)) key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, None) pkey = IndexHandler.get_specific_pkey(cls, key, None)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, key, b'iginored') pkey = IndexHandler.get_specific_pkey(cls, key, 'iginored')
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None) pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
self.assertIsNone(pkey) self.assertIsNone(pkey)
...@@ -64,7 +64,7 @@ class TestIndexHandler(unittest.TestCase): ...@@ -64,7 +64,7 @@ class TestIndexHandler(unittest.TestCase):
def test_get_specific_pkey_with_encrypted_key(self): def test_get_specific_pkey_with_encrypted_key(self):
fname = 'test_rsa_password.key' fname = 'test_rsa_password.key'
cls = paramiko.RSAKey cls = paramiko.RSAKey
password = b'television' password = 'television'
key = read_file(os.path.join(base_dir, 'tests', fname)) key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, password) pkey = IndexHandler.get_specific_pkey(cls, key, password)
...@@ -81,7 +81,7 @@ class TestIndexHandler(unittest.TestCase): ...@@ -81,7 +81,7 @@ class TestIndexHandler(unittest.TestCase):
key = read_file(os.path.join(base_dir, 'tests', fname)) key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_pkey_obj(key, None) pkey = IndexHandler.get_pkey_obj(key, None)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_pkey_obj(key, u'iginored') pkey = IndexHandler.get_pkey_obj(key, 'iginored')
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj('x'+key, None) pkey = IndexHandler.get_pkey_obj('x'+key, None)
...@@ -94,6 +94,6 @@ class TestIndexHandler(unittest.TestCase): ...@@ -94,6 +94,6 @@ class TestIndexHandler(unittest.TestCase):
pkey = IndexHandler.get_pkey_obj(key, password) pkey = IndexHandler.get_pkey_obj(key, password)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj(key, u'wrongpass') pkey = IndexHandler.get_pkey_obj(key, 'wrongpass')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj('x'+key, password) pkey = IndexHandler.get_pkey_obj('x'+key, password)
import unittest import unittest
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address, from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
is_valid_port, to_str) is_valid_port, to_str, to_bytes)
class TestUitls(unittest.TestCase): class TestUitls(unittest.TestCase):
...@@ -12,6 +12,12 @@ class TestUitls(unittest.TestCase): ...@@ -12,6 +12,12 @@ class TestUitls(unittest.TestCase):
self.assertEqual(to_str(b), u) self.assertEqual(to_str(b), u)
self.assertEqual(to_str(u), u) self.assertEqual(to_str(u), u)
def test_to_bytes(self):
b = b'hello'
u = u'hello'
self.assertEqual(to_bytes(b), b)
self.assertEqual(to_bytes(u), b)
def test_is_valid_ipv4_address(self): def test_is_valid_ipv4_address(self):
self.assertFalse(is_valid_ipv4_address('127.0.0')) self.assertFalse(is_valid_ipv4_address('127.0.0'))
self.assertFalse(is_valid_ipv4_address(b'127.0.0')) self.assertFalse(is_valid_ipv4_address(b'127.0.0'))
......
...@@ -10,10 +10,9 @@ import paramiko ...@@ -10,10 +10,9 @@ import paramiko
import tornado.web import tornado.web
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.util import basestring_type
from webssh.worker import Worker, recycle_worker, workers from webssh.worker import Worker, recycle_worker, workers
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address, from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
is_valid_port) is_valid_port, to_bytes, to_str, UnicodeType)
try: try:
from concurrent.futures import Future from concurrent.futures import Future
...@@ -70,7 +69,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -70,7 +69,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
data = self.request.files.get('privatekey')[0]['body'] data = self.request.files.get('privatekey')[0]['body']
except TypeError: except TypeError:
return return
return data.decode('utf-8') return to_str(data)
@classmethod @classmethod
def get_specific_pkey(cls, pkeycls, privatekey, password): def get_specific_pkey(cls, pkeycls, privatekey, password):
...@@ -87,7 +86,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -87,7 +86,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
@classmethod @classmethod
def get_pkey_obj(cls, privatekey, password): def get_pkey_obj(cls, privatekey, password):
password = password.encode('utf-8') if password else None password = to_bytes(password)
pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, password)\ pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, password)\
or cls.get_specific_pkey(paramiko.DSSKey, privatekey, password)\ or cls.get_specific_pkey(paramiko.DSSKey, privatekey, password)\
...@@ -138,8 +137,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -138,8 +137,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
except paramiko.SSHException: except paramiko.SSHException:
result = None result = None
else: else:
data = stdout.read().decode('utf-8') data = stdout.read()
result = parse_encoding(data) result = parse_encoding(to_str(data))
return result if result else 'utf-8' return result if result else 'utf-8'
...@@ -247,7 +246,7 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): ...@@ -247,7 +246,7 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
pass pass
data = msg.get('data') data = msg.get('data')
if data and isinstance(data, basestring_type): if data and isinstance(data, UnicodeType):
worker.data_to_dst.append(data) worker.data_to_dst.append(data)
worker.on_write() worker.on_write()
......
import ipaddress import ipaddress
try:
from types import UnicodeType
except ImportError:
UnicodeType = str
def to_str(s):
if isinstance(s, bytes): def to_str(bstr, encoding='utf-8'):
return s.decode('utf-8') if isinstance(bstr, bytes):
return s return bstr.decode(encoding)
return bstr
def to_bytes(ustr, encoding='utf-8'):
if isinstance(ustr, UnicodeType):
return ustr.encode(encoding)
return ustr
def is_valid_ipv4_address(ipstr): def is_valid_ipv4_address(ipstr):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment