Commit f4197f0e by Sheng

Use base_dir as the project root directory

parent d6de1340
...@@ -18,21 +18,23 @@ ...@@ -18,21 +18,23 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc., # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from binascii import hexlify import os.path
import random
import socket import socket
# import sys # import sys
import threading import threading
import random
# import traceback # import traceback
import paramiko import paramiko
from binascii import hexlify
from paramiko.py3compat import u, decodebytes from paramiko.py3compat import u, decodebytes
from webssh.settings import base_dir
# setup logging # setup logging
paramiko.util.log_to_file('tests/sshserver.log') paramiko.util.log_to_file(os.path.join(base_dir, 'tests', 'sshserver.log'))
host_key = paramiko.RSAKey(filename='tests/test_rsa.key') host_key = paramiko.RSAKey(filename=os.path.join(base_dir, 'tests', 'test_rsa.key')) # noqa
# host_key = paramiko.DSSKey(filename='test_dss.key') # host_key = paramiko.DSSKey(filename='test_dss.key')
print('Read key: ' + u(hexlify(host_key.get_fingerprint()))) print('Read key: ' + u(hexlify(host_key.get_fingerprint())))
......
import json import json
import os.path
import random import random
import threading import threading
import tornado.websocket import tornado.websocket
...@@ -8,10 +9,10 @@ import webssh.handler as handler ...@@ -8,10 +9,10 @@ import webssh.handler as handler
from tornado.testing import AsyncHTTPTestCase from tornado.testing import AsyncHTTPTestCase
from tornado.httpclient import HTTPError from tornado.httpclient import HTTPError
from tornado.options import options from tornado.options import options
from webssh.main import make_app, make_handlers
from webssh.settings import get_app_settings, max_body_size
from tests.sshserver import run_ssh_server, banner from tests.sshserver import run_ssh_server, banner
from tests.utils import encode_multipart_formdata from tests.utils import encode_multipart_formdata, read_file
from webssh.main import make_app, make_handlers
from webssh.settings import get_app_settings, max_body_size, base_dir
handler.DELAY = 0.1 handler.DELAY = 0.1
...@@ -52,9 +53,6 @@ class TestApp(AsyncHTTPTestCase): ...@@ -52,9 +53,6 @@ class TestApp(AsyncHTTPTestCase):
cls.running.pop() cls.running.pop()
print('='*20) print('='*20)
def read_privatekey(self, filename):
return open(filename, 'rb').read().decode('utf-8')
def get_httpserver_options(self): def get_httpserver_options(self):
options = super(TestApp, self).get_httpserver_options() options = super(TestApp, self).get_httpserver_options()
options.update(max_body_size=max_body_size) options.update(max_body_size=max_body_size)
...@@ -126,7 +124,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -126,7 +124,7 @@ class TestApp(AsyncHTTPTestCase):
response = yield client.fetch(url) response = yield client.fetch(url)
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
privatekey = self.read_privatekey('tests/user_rsa_key') privatekey = read_file(os.path.join(base_dir, 'tests', 'user_rsa_key'))
files = [('privatekey', 'user_rsa_key', privatekey)] files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(), content_type, body = encode_multipart_formdata(self.body_dict.items(),
files) files)
...@@ -154,7 +152,7 @@ class TestApp(AsyncHTTPTestCase): ...@@ -154,7 +152,7 @@ class TestApp(AsyncHTTPTestCase):
response = yield client.fetch(url) response = yield client.fetch(url)
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
privatekey = self.read_privatekey('tests/user_rsa_key') privatekey = read_file(os.path.join(base_dir, 'tests', 'user_rsa_key'))
privatekey = privatekey[:100] + u'bad' + privatekey[100:] privatekey = privatekey[:100] + u'bad' + privatekey[100:]
files = [('privatekey', 'user_rsa_key', privatekey)] files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(), content_type, body = encode_multipart_formdata(self.body_dict.items(),
......
...@@ -3,7 +3,9 @@ import os.path ...@@ -3,7 +3,9 @@ import os.path
import paramiko import paramiko
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
from tests.utils import read_file
from webssh.handler import MixinHandler, IndexHandler, parse_encoding from webssh.handler import MixinHandler, IndexHandler, parse_encoding
from webssh.settings import base_dir
class TestHandler(unittest.TestCase): class TestHandler(unittest.TestCase):
...@@ -47,15 +49,11 @@ class TestMixinHandler(unittest.TestCase): ...@@ -47,15 +49,11 @@ class TestMixinHandler(unittest.TestCase):
class TestIndexHandler(unittest.TestCase): class TestIndexHandler(unittest.TestCase):
def read_privatekey(self, filename):
with open(filename, 'rb') as f:
return f.read().decode('utf-8')
def test_get_specific_pkey_with_plain_key(self): def test_get_specific_pkey_with_plain_key(self):
fname = 'test_rsa.key' fname = 'test_rsa.key'
cls = paramiko.RSAKey cls = paramiko.RSAKey
key = self.read_privatekey(os.path.join('tests', fname)) key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, None) pkey = IndexHandler.get_specific_pkey(cls, key, None)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, key, b'iginored') pkey = IndexHandler.get_specific_pkey(cls, key, b'iginored')
...@@ -68,7 +66,7 @@ class TestIndexHandler(unittest.TestCase): ...@@ -68,7 +66,7 @@ class TestIndexHandler(unittest.TestCase):
cls = paramiko.RSAKey cls = paramiko.RSAKey
password = b'television' password = b'television'
key = self.read_privatekey(os.path.join('tests', fname)) key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, password) pkey = IndexHandler.get_specific_pkey(cls, key, password)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None) pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
...@@ -80,7 +78,7 @@ class TestIndexHandler(unittest.TestCase): ...@@ -80,7 +78,7 @@ class TestIndexHandler(unittest.TestCase):
def test_get_pkey_obj_with_plain_key(self): def test_get_pkey_obj_with_plain_key(self):
fname = 'test_ed25519.key' fname = 'test_ed25519.key'
cls = paramiko.Ed25519Key cls = paramiko.Ed25519Key
key = self.read_privatekey(os.path.join('tests', fname)) key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_pkey_obj(key, None) pkey = IndexHandler.get_pkey_obj(key, None)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_pkey_obj(key, u'iginored') pkey = IndexHandler.get_pkey_obj(key, u'iginored')
...@@ -92,7 +90,7 @@ class TestIndexHandler(unittest.TestCase): ...@@ -92,7 +90,7 @@ class TestIndexHandler(unittest.TestCase):
fname = 'test_ed25519_password.key' fname = 'test_ed25519_password.key'
password = 'abc123' password = 'abc123'
cls = paramiko.Ed25519Key cls = paramiko.Ed25519Key
key = self.read_privatekey(os.path.join('tests', fname)) key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_pkey_obj(key, password) pkey = IndexHandler.get_pkey_obj(key, password)
self.assertIsInstance(pkey, cls) self.assertIsInstance(pkey, cls)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
......
...@@ -8,6 +8,7 @@ from webssh.policy import ( ...@@ -8,6 +8,7 @@ from webssh.policy import (
AutoAddPolicy, get_policy_dictionary, load_host_keys, AutoAddPolicy, get_policy_dictionary, load_host_keys,
get_policy_class, check_policy_setting get_policy_class, check_policy_setting
) )
from webssh.settings import base_dir
class TestPolicy(unittest.TestCase): class TestPolicy(unittest.TestCase):
...@@ -28,7 +29,7 @@ class TestPolicy(unittest.TestCase): ...@@ -28,7 +29,7 @@ class TestPolicy(unittest.TestCase):
host_keys = load_host_keys(path) host_keys = load_host_keys(path)
self.assertFalse(host_keys) self.assertFalse(host_keys)
path = 'tests/known_hosts_example' path = os.path.join(base_dir, 'tests', 'known_hosts_example')
host_keys = load_host_keys(path) host_keys = load_host_keys(path)
self.assertEqual(host_keys, paramiko.hostkeys.HostKeys(path)) self.assertEqual(host_keys, paramiko.hostkeys.HostKeys(path))
...@@ -44,7 +45,7 @@ class TestPolicy(unittest.TestCase): ...@@ -44,7 +45,7 @@ class TestPolicy(unittest.TestCase):
get_policy_class(key) get_policy_class(key)
def test_check_policy_setting(self): def test_check_policy_setting(self):
host_keys_filename = './tests/host_keys_test.db' host_keys_filename = os.path.join(base_dir, 'tests', 'host_keys_test.db') # noqa
host_keys_settings = dict( host_keys_settings = dict(
host_keys=paramiko.hostkeys.HostKeys(), host_keys=paramiko.hostkeys.HostKeys(),
system_host_keys=paramiko.hostkeys.HostKeys(), system_host_keys=paramiko.hostkeys.HostKeys(),
...@@ -63,8 +64,8 @@ class TestPolicy(unittest.TestCase): ...@@ -63,8 +64,8 @@ class TestPolicy(unittest.TestCase):
def test_is_missing_host_key(self): def test_is_missing_host_key(self):
client = paramiko.SSHClient() client = paramiko.SSHClient()
file1 = 'tests/known_hosts_example' file1 = os.path.join(base_dir, 'tests', 'known_hosts_example')
file2 = 'tests/known_hosts_example2' file2 = os.path.join(base_dir, 'tests', 'known_hosts_example2')
client.load_host_keys(file1) client.load_host_keys(file1)
client.load_system_host_keys(file2) client.load_system_host_keys(file2)
...@@ -85,7 +86,7 @@ class TestPolicy(unittest.TestCase): ...@@ -85,7 +86,7 @@ class TestPolicy(unittest.TestCase):
autoadd.is_missing_host_key(client, hostname, key) autoadd.is_missing_host_key(client, hostname, key)
) )
file3 = 'tests/known_hosts_example3' file3 = os.path.join(base_dir, 'tests', 'known_hosts_example3')
entry = paramiko.hostkeys.HostKeys(file3)._entries[0] entry = paramiko.hostkeys.HostKeys(file3)._entries[0]
hostname = entry.hostnames[0] hostname = entry.hostnames[0]
key = entry.key key = entry.key
...@@ -94,9 +95,9 @@ class TestPolicy(unittest.TestCase): ...@@ -94,9 +95,9 @@ class TestPolicy(unittest.TestCase):
def test_missing_host_key(self): def test_missing_host_key(self):
client = paramiko.SSHClient() client = paramiko.SSHClient()
file1 = 'tests/known_hosts_example' file1 = os.path.join(base_dir, 'tests', 'known_hosts_example')
file2 = 'tests/known_hosts_example2' file2 = os.path.join(base_dir, 'tests', 'known_hosts_example2')
filename = 'tests/known_hosts' filename = os.path.join(base_dir, 'tests', 'known_hosts')
copyfile(file1, filename) copyfile(file1, filename)
client.load_host_keys(filename) client.load_host_keys(filename)
n1 = len(client._host_keys) n1 = len(client._host_keys)
......
...@@ -3,10 +3,10 @@ import unittest ...@@ -3,10 +3,10 @@ import unittest
import paramiko import paramiko
import tornado.options as options import tornado.options as options
from webssh.policy import load_host_keys
from webssh.settings import ( from webssh.settings import (
get_host_keys_settings, get_policy_setting, base_dir, print_version get_host_keys_settings, get_policy_setting, base_dir, print_version
) )
from webssh.policy import load_host_keys
from webssh._version import __version__ from webssh._version import __version__
...@@ -30,8 +30,8 @@ class TestSettings(unittest.TestCase): ...@@ -30,8 +30,8 @@ class TestSettings(unittest.TestCase):
load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
) )
options.hostFile = 'tests/known_hosts_example' options.hostFile = os.path.join(base_dir, 'tests', 'known_hosts_example') # noqa
options.sysHostFile = 'tests/known_hosts_example2' options.sysHostFile = os.path.join(base_dir, 'tests', 'known_hosts_example2') # noqa
dic2 = get_host_keys_settings(options) dic2 = get_host_keys_settings(options)
self.assertEqual(dic2['host_keys'], load_host_keys(options.hostFile)) self.assertEqual(dic2['host_keys'], load_host_keys(options.hostFile))
self.assertEqual(dic2['host_keys_filename'], options.hostFile) self.assertEqual(dic2['host_keys_filename'], options.hostFile)
......
...@@ -36,3 +36,7 @@ def encode_multipart_formdata(fields, files): ...@@ -36,3 +36,7 @@ def encode_multipart_formdata(fields, files):
def get_content_type(filename): def get_content_type(filename):
return mimetypes.guess_type(filename)[0] or 'application/octet-stream' return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
def read_file(path, encoding='utf-8'):
return open(path, 'rb').read().decode(encoding)
...@@ -28,14 +28,14 @@ define('version', type=bool, help='Show version information', ...@@ -28,14 +28,14 @@ define('version', type=bool, help='Show version information',
callback=print_version) callback=print_version)
base_dir = os.path.dirname(__file__) base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
max_body_size = 1 * 1024 * 1024 max_body_size = 1 * 1024 * 1024
def get_app_settings(options): def get_app_settings(options):
settings = dict( settings = dict(
template_path=os.path.join(base_dir, 'templates'), template_path=os.path.join(base_dir, 'webssh', 'templates'),
static_path=os.path.join(base_dir, 'static'), static_path=os.path.join(base_dir, 'webssh', 'static'),
cookie_secret=uuid.uuid4().hex, cookie_secret=uuid.uuid4().hex,
websocket_ping_interval=options.wpIntvl, websocket_ping_interval=options.wpIntvl,
xsrf_cookies=(not options.debug), xsrf_cookies=(not options.debug),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment