Commit e85ae169 by Sheng

Added to_bytes function to utils

parent cb866825
......@@ -13,6 +13,7 @@ from tests.sshserver import run_ssh_server, banner
from tests.utils import encode_multipart_formdata, read_file
from webssh.main import make_app, make_handlers
from webssh.settings import get_app_settings, max_body_size, base_dir
from webssh.utils import to_str
handler.DELAY = 0.1
......@@ -22,7 +23,7 @@ class TestApp(AsyncHTTPTestCase):
running = [True]
sshserver_port = 2200
body = u'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa
body = 'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa
body_dict = {
'hostname': '127.0.0.1',
'port': str(sshserver_port),
......@@ -61,37 +62,37 @@ class TestApp(AsyncHTTPTestCase):
def test_app_with_invalid_form(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
body = u'hostname=&port=&username=&password'
body = 'hostname=&port=&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty hostname"', response.body)
body = u'hostname=127.0.0.1&port=&username=&password'
body = 'hostname=127.0.0.1&port=&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty port"', response.body)
body = u'hostname=127.0.0.1&port=port&username=&password'
body = 'hostname=127.0.0.1&port=port&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Invalid port', response.body)
body = u'hostname=127.0.0.1&port=70000&username=&password'
body = 'hostname=127.0.0.1&port=70000&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Invalid port', response.body)
body = u'hostname=127.0.0.1&port=7000&username=&password'
body = 'hostname=127.0.0.1&port=7000&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty username"', response.body)
def test_app_with_wrong_credentials(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
response = self.fetch('/', method="POST", body=self.body + u's')
response = self.fetch('/', method="POST", body=self.body + 's')
self.assertIn(b'Authentication failed.', response.body)
def test_app_with_correct_credentials(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
response = self.fetch('/', method="POST", body=self.body)
data = json.loads(response.body.decode('utf-8'))
data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
......@@ -104,7 +105,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200)
response = yield client.fetch(url, method="POST", body=self.body)
data = json.loads(response.body.decode('utf-8'))
data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
......@@ -133,7 +134,7 @@ class TestApp(AsyncHTTPTestCase):
}
response = yield client.fetch(url, method="POST", headers=headers,
body=body)
data = json.loads(response.body.decode('utf-8'))
data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
......@@ -142,7 +143,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message()
self.assertEqual(msg.decode(data['encoding']), banner)
self.assertEqual(to_str(msg, data['encoding']), banner)
ws.close()
@tornado.testing.gen_test
......@@ -153,7 +154,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200)
privatekey = read_file(os.path.join(base_dir, 'tests', 'user_rsa_key'))
privatekey = privatekey[:100] + u'bad' + privatekey[100:]
privatekey = privatekey[:100] + 'bad' + privatekey[100:]
files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(),
files)
......@@ -162,7 +163,7 @@ class TestApp(AsyncHTTPTestCase):
}
response = yield client.fetch(url, method="POST", headers=headers,
body=body)
data = json.loads(response.body.decode('utf-8'))
data = json.loads(to_str(response.body))
self.assertIsNotNone(data['status'])
self.assertIsNone(data['id'])
self.assertIsNone(data['encoding'])
......@@ -174,7 +175,7 @@ class TestApp(AsyncHTTPTestCase):
response = yield client.fetch(url)
self.assertEqual(response.code, 200)
privatekey = u'h' * (2 * max_body_size)
privatekey = 'h' * (2 * max_body_size)
files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(),
files)
......@@ -193,7 +194,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200)
response = yield client.fetch(url, method="POST", body=self.body)
data = json.loads(response.body.decode('utf-8'))
data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
......@@ -202,7 +203,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message()
self.assertEqual(msg.decode(data['encoding']), banner)
self.assertEqual(to_str(msg, data['encoding']), banner)
ws.close()
@tornado.testing.gen_test
......@@ -214,7 +215,7 @@ class TestApp(AsyncHTTPTestCase):
body = self.body.replace('robey', 'bar')
response = yield client.fetch(url, method="POST", body=body)
data = json.loads(response.body.decode('utf-8'))
data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
......@@ -223,7 +224,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message()
self.assertEqual(msg.decode(data['encoding']), banner)
self.assertEqual(to_str(msg, data['encoding']), banner)
# messages below will be ignored silently
yield ws.write_message('hello')
......
......@@ -56,7 +56,7 @@ class TestIndexHandler(unittest.TestCase):
key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, None)
self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, key, b'iginored')
pkey = IndexHandler.get_specific_pkey(cls, key, 'iginored')
self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
self.assertIsNone(pkey)
......@@ -64,7 +64,7 @@ class TestIndexHandler(unittest.TestCase):
def test_get_specific_pkey_with_encrypted_key(self):
fname = 'test_rsa_password.key'
cls = paramiko.RSAKey
password = b'television'
password = 'television'
key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, password)
......@@ -81,7 +81,7 @@ class TestIndexHandler(unittest.TestCase):
key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_pkey_obj(key, None)
self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_pkey_obj(key, u'iginored')
pkey = IndexHandler.get_pkey_obj(key, 'iginored')
self.assertIsInstance(pkey, cls)
with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj('x'+key, None)
......@@ -94,6 +94,6 @@ class TestIndexHandler(unittest.TestCase):
pkey = IndexHandler.get_pkey_obj(key, password)
self.assertIsInstance(pkey, cls)
with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj(key, u'wrongpass')
pkey = IndexHandler.get_pkey_obj(key, 'wrongpass')
with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj('x'+key, password)
import unittest
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
is_valid_port, to_str)
is_valid_port, to_str, to_bytes)
class TestUitls(unittest.TestCase):
......@@ -12,6 +12,12 @@ class TestUitls(unittest.TestCase):
self.assertEqual(to_str(b), u)
self.assertEqual(to_str(u), u)
def test_to_bytes(self):
b = b'hello'
u = u'hello'
self.assertEqual(to_bytes(b), b)
self.assertEqual(to_bytes(u), b)
def test_is_valid_ipv4_address(self):
self.assertFalse(is_valid_ipv4_address('127.0.0'))
self.assertFalse(is_valid_ipv4_address(b'127.0.0'))
......
......@@ -10,10 +10,9 @@ import paramiko
import tornado.web
from tornado.ioloop import IOLoop
from tornado.util import basestring_type
from webssh.worker import Worker, recycle_worker, workers
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
is_valid_port)
is_valid_port, to_bytes, to_str, UnicodeType)
try:
from concurrent.futures import Future
......@@ -70,7 +69,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
data = self.request.files.get('privatekey')[0]['body']
except TypeError:
return
return data.decode('utf-8')
return to_str(data)
@classmethod
def get_specific_pkey(cls, pkeycls, privatekey, password):
......@@ -87,7 +86,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
@classmethod
def get_pkey_obj(cls, privatekey, password):
password = password.encode('utf-8') if password else None
password = to_bytes(password)
pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, password)\
or cls.get_specific_pkey(paramiko.DSSKey, privatekey, password)\
......@@ -138,8 +137,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
except paramiko.SSHException:
result = None
else:
data = stdout.read().decode('utf-8')
result = parse_encoding(data)
data = stdout.read()
result = parse_encoding(to_str(data))
return result if result else 'utf-8'
......@@ -247,7 +246,7 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
pass
data = msg.get('data')
if data and isinstance(data, basestring_type):
if data and isinstance(data, UnicodeType):
worker.data_to_dst.append(data)
worker.on_write()
......
import ipaddress
try:
from types import UnicodeType
except ImportError:
UnicodeType = str
def to_str(s):
if isinstance(s, bytes):
return s.decode('utf-8')
return s
def to_str(bstr, encoding='utf-8'):
if isinstance(bstr, bytes):
return bstr.decode(encoding)
return bstr
def to_bytes(ustr, encoding='utf-8'):
if isinstance(ustr, UnicodeType):
return ustr.encode(encoding)
return ustr
def is_valid_ipv4_address(ipstr):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment