Commit 221bd815 by Sheng

Refactored code

parent aa442b05
...@@ -197,7 +197,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -197,7 +197,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def ssh_connect(self): def ssh_connect(self):
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh.load_host_keys(self.settings['host_file']) if isinstance(self.settings['policy'], paramiko.client.AutoAddPolicy):
ssh.load_host_keys(self.settings['host_file'])
else:
ssh._host_keys = self.settings.get('host_keys')
ssh.set_missing_host_key_policy(self.settings['policy']) ssh.set_missing_host_key_policy(self.settings['policy'])
args = self.get_args() args = self.get_args()
dst_addr = (args[0], args[1]) dst_addr = (args[0], args[1])
...@@ -284,13 +287,15 @@ def recycle(worker): ...@@ -284,13 +287,15 @@ def recycle(worker):
def get_host_keys(path): def get_host_keys(path):
if os.path.exists(path) and os.path.isfile(path): if os.path.exists(path) and os.path.isfile(path):
return paramiko.hostkeys.HostKeys(filename=path) return paramiko.hostkeys.HostKeys(filename=path)
return paramiko.hostkeys.HostKeys()
def create_host_file(host_file): def create_host_file(host_file):
host_keys = get_host_keys(host_file) host_keys = get_host_keys(host_file)
if not host_keys: if not host_keys:
host_keys = get_host_keys(os.path.expanduser("~/.ssh/known_hosts")) host_keys = get_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
host_keys.save(host_file) host_keys.save(host_file)
return host_keys
def get_policy_class(policy): def get_policy_class(policy):
...@@ -311,7 +316,7 @@ def get_policy_class(policy): ...@@ -311,7 +316,7 @@ def get_policy_class(policy):
def main(): def main():
base_dir = os.path.dirname(__file__) base_dir = os.path.dirname(__file__)
host_file = os.path.join(base_dir, 'known_hosts') host_file = os.path.join(base_dir, 'known_hosts')
create_host_file(host_file) host_keys = create_host_file(host_file)
settings = { settings = {
'template_path': os.path.join(base_dir, 'templates'), 'template_path': os.path.join(base_dir, 'templates'),
...@@ -329,6 +334,7 @@ def main(): ...@@ -329,6 +334,7 @@ def main():
settings.update( settings.update(
debug=options.debug, debug=options.debug,
host_file=host_file, host_file=host_file,
host_keys=host_keys,
policy=get_policy_class(options.policy)() policy=get_policy_class(options.policy)()
) )
app = tornado.web.Application(handlers, **settings) app = tornado.web.Application(handlers, **settings)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment