Commit c2c81aae by Sheng

Use method initialize to deny forbidden acesss

parent 7e5a1703
...@@ -569,19 +569,19 @@ class TestAppWithTrustedStream(OtherTestBase): ...@@ -569,19 +569,19 @@ class TestAppWithTrustedStream(OtherTestBase):
def test_with_forbidden_get_request(self): def test_with_forbidden_get_request(self):
response = self.fetch('/', method='GET') response = self.fetch('/', method='GET')
self.assertEqual(response.code, 403) self.assertEqual(response.code, 403)
self.assertIn(b'403: Forbidden', response.body) self.assertIn('Forbidden', response.error.message)
def test_with_forbidden_post_request(self): def test_with_forbidden_post_request(self):
response = self.fetch('/', method='POST', body=urlencode(self.body), response = self.fetch('/', method='POST', body=urlencode(self.body),
headers=self.headers) headers=self.headers)
self.assertEqual(response.code, 200) self.assertEqual(response.code, 403)
self.assertIn(b'"status": "Forbidden"', response.body) self.assertIn('Forbidden', response.error.message)
def test_with_forbidden_put_request(self): def test_with_forbidden_put_request(self):
response = self.fetch('/', method='PUT', body=urlencode(self.body), response = self.fetch('/', method='PUT', body=urlencode(self.body),
headers=self.headers) headers=self.headers)
self.assertEqual(response.code, 403) self.assertEqual(response.code, 403)
self.assertIn(b'403: Forbidden', response.body) self.assertIn('Forbidden', response.error.message)
class TestAppNotFoundHandler(OtherTestBase): class TestAppNotFoundHandler(OtherTestBase):
......
...@@ -43,9 +43,13 @@ class MixinHandler(object): ...@@ -43,9 +43,13 @@ class MixinHandler(object):
'Server': 'TornadoServer' 'Server': 'TornadoServer'
} }
def prepare(self): def initialize(self):
if self.is_forbidden(): if self.is_forbidden():
raise tornado.web.HTTPError(403) self.request.connection.stream.write(
b'%s 403 Forbidden\r\n\r\n' % to_bytes(self.request.version)
)
self.request.connection.close()
raise ValueError('Accesss denied')
def is_forbidden(self): def is_forbidden(self):
""" """
...@@ -105,10 +109,9 @@ class MixinHandler(object): ...@@ -105,10 +109,9 @@ class MixinHandler(object):
class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler): class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
def initialize(self): def initialize(self):
pass super(NotFoundHandler, self).initialize()
def prepare(self): def prepare(self):
super(NotFoundHandler, self).prepare()
raise tornado.web.HTTPError(404) raise tornado.web.HTTPError(404)
...@@ -122,6 +125,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ...@@ -122,6 +125,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.privatekey_filename = None self.privatekey_filename = None
self.debug = self.settings.get('debug', False) self.debug = self.settings.get('debug', False)
self.result = dict(id=None, status=None, encoding=None) self.result = dict(id=None, status=None, encoding=None)
super(IndexHandler, self).initialize()
def write_error(self, status_code, **kwargs): def write_error(self, status_code, **kwargs):
if self.request.method != 'POST' or not swallow_http_errors: if self.request.method != 'POST' or not swallow_http_errors:
...@@ -322,6 +326,7 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): ...@@ -322,6 +326,7 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def initialize(self, loop): def initialize(self, loop):
self.loop = loop self.loop = loop
self.worker_ref = None self.worker_ref = None
super(WsockHandler, self).initialize()
def open(self): def open(self):
self.src_addr = self.get_client_addr() self.src_addr = self.get_client_addr()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment