Commit 6c9af890 by Sheng

Use method initialize to pass settings

parent fdf0f88e
...@@ -21,10 +21,6 @@ DELAY = 3 ...@@ -21,10 +21,6 @@ DELAY = 3
class MixinHandler(object): class MixinHandler(object):
def __init__(self, *args, **kwargs):
self.loop = args[0]._loop
super(MixinHandler, self).__init__(*args, **kwargs)
def get_client_addr(self): def get_client_addr(self):
ip = self.request.headers.get('X-Real-Ip') ip = self.request.headers.get('X-Real-Ip')
port = self.request.headers.get('X-Real-Port') port = self.request.headers.get('X-Real-Port')
...@@ -40,6 +36,11 @@ class MixinHandler(object): ...@@ -40,6 +36,11 @@ class MixinHandler(object):
class IndexHandler(MixinHandler, tornado.web.RequestHandler): class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def initialize(self, loop, policy, host_keys_settings):
self.loop = loop
self.policy = policy
self.host_keys_settings = host_keys_settings
def get_privatekey(self): def get_privatekey(self):
try: try:
data = self.request.files.get('privatekey')[0]['body'] data = self.request.files.get('privatekey')[0]['body']
...@@ -107,10 +108,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -107,10 +108,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def ssh_connect(self): def ssh_connect(self):
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh._system_host_keys = self.settings['system_host_keys'] ssh._system_host_keys = self.host_keys_settings['system_host_keys']
ssh._host_keys = self.settings['host_keys'] ssh._host_keys = self.host_keys_settings['host_keys']
ssh._host_keys_filename = self.settings['host_keys_filename'] ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
ssh.set_missing_host_key_policy(self.settings['policy']) ssh.set_missing_host_key_policy(self.policy)
args = self.get_args() args = self.get_args()
dst_addr = (args[0], args[1]) dst_addr = (args[0], args[1])
...@@ -167,9 +168,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -167,9 +168,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def __init__(self, *args, **kwargs): def initialize(self, loop):
self.loop = loop
self.worker_ref = None self.worker_ref = None
super(WsockHandler, self).__init__(*args, **kwargs)
def get_client_addr(self): def get_client_addr(self):
return super(WsockHandler, self).get_client_addr() or self.stream.\ return super(WsockHandler, self).get_client_addr() or self.stream.\
......
...@@ -2,23 +2,26 @@ import logging ...@@ -2,23 +2,26 @@ import logging
import tornado.web import tornado.web
import tornado.ioloop import tornado.ioloop
from tornado.options import parse_command_line, options from tornado.options import define, parse_command_line, options
from handler import IndexHandler, WsockHandler from handler import IndexHandler, WsockHandler
from settings import get_application_settings from settings import (get_app_settings, get_host_keys_settings,
get_policy_setting)
def main(): def main():
parse_command_line() parse_command_line()
settings = get_application_settings() app_settings = get_app_settings(options)
host_keys_settings = get_host_keys_settings(options)
policy = get_policy_setting(options, host_keys_settings)
loop = tornado.ioloop.IOLoop.current()
handlers = [ handlers = [
(r'/', IndexHandler), (r'/', IndexHandler, dict(loop=loop, policy=policy,
(r'/ws', WsockHandler) host_keys_settings=host_keys_settings)),
(r'/ws', WsockHandler, dict(loop=loop))
] ]
loop = tornado.ioloop.IOLoop.current() app = tornado.web.Application(handlers, **app_settings)
app = tornado.web.Application(handlers, **settings)
app._loop = loop
app.listen(options.port, options.address) app.listen(options.port, options.address)
logging.info('Listening on {}:{}'.format(options.address, options.port)) logging.info('Listening on {}:{}'.format(options.address, options.port))
loop.start() loop.start()
......
...@@ -35,6 +35,20 @@ def get_policy_class(policy): ...@@ -35,6 +35,20 @@ def get_policy_class(policy):
return cls return cls
def check_policy_setting(policy_class, host_keys_settings):
host_keys = host_keys_settings['host_keys']
host_keys_filename = host_keys_settings['host_keys_filename']
system_host_keys = host_keys_settings['system_host_keys']
if policy_class is paramiko.client.AutoAddPolicy:
host_keys.save(host_keys_filename) # for permission test
elif policy_class is paramiko.client.RejectPolicy:
if not host_keys and not system_host_keys:
raise ValueError(
'Reject policy could not be used without host keys.'
)
class AutoAddPolicy(paramiko.client.MissingHostKeyPolicy): class AutoAddPolicy(paramiko.client.MissingHostKeyPolicy):
""" """
thread-safe AutoAddPolicy thread-safe AutoAddPolicy
......
import logging import logging
import os.path import os.path
import uuid import uuid
import paramiko
from tornado.options import define, options from tornado.options import define
from policy import get_host_keys, get_policy_class from policy import get_host_keys, get_policy_class, check_policy_setting
define('address', default='127.0.0.1', help='listen address') define('address', default='127.0.0.1', help='listen address')
...@@ -12,32 +11,47 @@ define('port', default=8888, help='listen port', type=int) ...@@ -12,32 +11,47 @@ define('port', default=8888, help='listen port', type=int)
define('debug', default=False, help='debug mode', type=bool) define('debug', default=False, help='debug mode', type=bool)
define('policy', default='warning', define('policy', default='warning',
help='missing host key policy, reject|autoadd|warning') help='missing host key policy, reject|autoadd|warning')
define('hostFile', default='', help='User-defined host keys file')
define('sysHostFile', default='', help='System-wide host keys File')
def get_application_settings(): base_dir = os.path.dirname(__file__)
base_dir = os.path.dirname(__file__)
filename = os.path.join(base_dir, 'known_hosts')
host_keys = get_host_keys(filename)
system_host_keys = get_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
policy_class = get_policy_class(options.policy)
logging.info(policy_class.__name__)
if policy_class is paramiko.client.AutoAddPolicy:
host_keys.save(filename) # for permission test
elif policy_class is paramiko.client.RejectPolicy:
if not host_keys and not system_host_keys:
raise ValueError('Empty known_hosts with reject policy?')
def get_app_settings(options):
settings = dict( settings = dict(
template_path=os.path.join(base_dir, 'templates'), template_path=os.path.join(base_dir, 'templates'),
static_path=os.path.join(base_dir, 'static'), static_path=os.path.join(base_dir, 'static'),
cookie_secret=uuid.uuid4().hex, cookie_secret=uuid.uuid4().hex,
xsrf_cookies=True, xsrf_cookies=True,
host_keys=host_keys,
host_keys_filename=filename,
system_host_keys=system_host_keys,
policy=policy_class(),
debug=options.debug debug=options.debug
) )
return settings
def get_host_keys_settings(options):
if not options.hostFile:
host_keys_filename = os.path.join(base_dir, 'known_hosts')
else:
host_keys_filename = options.hostFile
host_keys = get_host_keys(host_keys_filename)
if not options.sysHostFile:
filename = os.path.expanduser('~/.ssh/known_hosts')
else:
filename = options.sysHostFile
system_host_keys = get_host_keys(filename)
settings = dict(
host_keys=host_keys,
system_host_keys=system_host_keys,
host_keys_filename=host_keys_filename
)
return settings return settings
def get_policy_setting(options, host_keys_settings):
policy_class = get_policy_class(options.policy)
logging.info(policy_class.__name__)
check_policy_setting(policy_class, host_keys_settings)
return policy_class()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment