Commit 37299468 by Sheng

Added is_valid_hostname to utils

parent f6100207
...@@ -70,6 +70,14 @@ class TestApp(AsyncHTTPTestCase): ...@@ -70,6 +70,14 @@ class TestApp(AsyncHTTPTestCase):
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertIn(b'"status": "The port field is required"', response.body) self.assertIn(b'"status": "The port field is required"', response.body)
body = 'hostname=127.0.0&port=22&username=&password'
response = self.fetch('/', method='POST', body=body)
self.assertIn(b'"status": "Invalid hostname', response.body)
body = 'hostname=http://www.googe.com&port=22&username=&password'
response = self.fetch('/', method='POST', body=body)
self.assertIn(b'"status": "Invalid hostname', response.body)
body = 'hostname=127.0.0.1&port=port&username=&password' body = 'hostname=127.0.0.1&port=port&username=&password'
response = self.fetch('/', method='POST', body=body) response = self.fetch('/', method='POST', body=body)
self.assertIn(b'"status": "Invalid port', response.body) self.assertIn(b'"status": "Invalid port', response.body)
......
import unittest import unittest
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address, from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
is_valid_port, to_str, to_bytes) is_valid_port, is_valid_hostname, to_str, to_bytes)
class TestUitls(unittest.TestCase): class TestUitls(unittest.TestCase):
...@@ -34,3 +34,14 @@ class TestUitls(unittest.TestCase): ...@@ -34,3 +34,14 @@ class TestUitls(unittest.TestCase):
self.assertTrue(is_valid_port(80)) self.assertTrue(is_valid_port(80))
self.assertFalse(is_valid_port(0)) self.assertFalse(is_valid_port(0))
self.assertFalse(is_valid_port(65536)) self.assertFalse(is_valid_port(65536))
def test_is_valid_hostname(self):
self.assertTrue(is_valid_hostname('google.com'))
self.assertTrue(is_valid_hostname('google.com.'))
self.assertTrue(is_valid_hostname('www.google.com'))
self.assertTrue(is_valid_hostname('www.google.com.'))
self.assertFalse(is_valid_hostname('.www.google.com'))
self.assertFalse(is_valid_hostname('http://www.google.com'))
self.assertFalse(is_valid_hostname('https://www.google.com'))
self.assertFalse(is_valid_hostname('127.0.0.1'))
self.assertFalse(is_valid_hostname('::1'))
...@@ -11,8 +11,10 @@ import tornado.web ...@@ -11,8 +11,10 @@ import tornado.web
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from webssh.worker import Worker, recycle_worker, workers from webssh.worker import Worker, recycle_worker, workers
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address, from webssh.utils import (
is_valid_port, to_bytes, to_str, UnicodeType) is_valid_ipv4_address, is_valid_ipv6_address, is_valid_port,
is_valid_hostname, to_bytes, to_str, UnicodeType
)
try: try:
from concurrent.futures import Future from concurrent.futures import Future
...@@ -98,6 +100,13 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -98,6 +100,13 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
'wrong password for decrypting the private key.') 'wrong password for decrypting the private key.')
return pkey return pkey
def get_hostname(self):
value = self.get_value('hostname')
if not (is_valid_hostname(value) | is_valid_ipv4_address(value) |
is_valid_ipv6_address(value)):
raise ValueError('Invalid hostname {}'.format(value))
return value
def get_port(self): def get_port(self):
value = self.get_value('port') value = self.get_value('port')
try: try:
...@@ -117,7 +126,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -117,7 +126,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
return value return value
def get_args(self): def get_args(self):
hostname = self.get_value('hostname') hostname = self.get_hostname()
port = self.get_port() port = self.get_port()
username = self.get_value('username') username = self.get_value('username')
password = self.get_argument('password') password = self.get_argument('password')
......
import ipaddress import ipaddress
import re
try: try:
from types import UnicodeType from types import UnicodeType
...@@ -38,3 +40,20 @@ def is_valid_ipv6_address(ipstr): ...@@ -38,3 +40,20 @@ def is_valid_ipv6_address(ipstr):
def is_valid_port(port): def is_valid_port(port):
return 0 < port < 65536 return 0 < port < 65536
def is_valid_hostname(hostname):
if hostname[-1] == ".":
# strip exactly one dot from the right, if present
hostname = hostname[:-1]
if len(hostname) > 253:
return False
labels = hostname.split(".")
# the TLD must be not all-numeric
if re.match(r"[0-9]+$", labels[-1]):
return False
allowed = re.compile(r"(?!-)[a-z0-9-]{1,63}(?<!-)$", re.IGNORECASE)
return all(allowed.match(label) for label in labels)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment