1
0
mirror of https://github.com/moparisthebest/wget synced 2024-07-03 16:38:41 -04:00

PEP8'ify the Python Test Suite

* testenv/conf/{__init__,authentication,files_crawled,
      hook_sample,reject_header,server_files}.py: Aesthetic changes to
      meet Python PEP8 guidelines
    * testenv/exc/{server_error,test_failed}.py: Same
    * testenv/misc/{colour_terminal,wget_file}.py: Same
    * testenv/server/http/http_server.py: Same
    * testenv/test/base_test.py: Same
This commit is contained in:
Darshit Shah 2015-04-14 10:36:20 +05:30
parent c6af2fddee
commit 8e0dd0d870
12 changed files with 245 additions and 221 deletions

View File

@ -3,6 +3,7 @@ import os
# this file implements the mechanism of conf class auto-registration, # this file implements the mechanism of conf class auto-registration,
# don't modify this file if you have no idea what you're doing # don't modify this file if you have no idea what you're doing
def gen_hook(): def gen_hook():
hook_table = {} hook_table = {}

View File

@ -16,7 +16,7 @@ that.
@rule() @rule()
class Authentication: class Authentication:
def __init__ (self, auth_obj): def __init__(self, auth_obj):
self.auth_type = auth_obj['Type'] self.auth_type = auth_obj['Type']
self.auth_user = auth_obj['User'] self.auth_user = auth_obj['User']
self.auth_pass = auth_obj['Pass'] self.auth_pass = auth_obj['Pass']

View File

@ -23,5 +23,5 @@ class FilesCrawled:
diff = headers.symmetric_difference(remaining) diff = headers.symmetric_difference(remaining)
if diff: if diff:
print_red (str(diff)) print_red(str(diff))
raise TestFailed('Not all files were crawled correctly.') raise TestFailed('Not all files were crawled correctly.')

View File

@ -18,5 +18,5 @@ class SampleHook:
# implement hook here # implement hook here
# if you need the test case instance, refer to test_obj # if you need the test case instance, refer to test_obj
if False: if False:
raise TestFailed ("Reason") raise TestFailed("Reason")
pass pass

View File

@ -9,5 +9,5 @@ requests.
@rule() @rule()
class RejectHeader: class RejectHeader:
def __init__ (self, header_obj): def __init__(self, header_obj):
self.headers = header_obj self.headers = header_obj

View File

@ -20,7 +20,7 @@ class ServerFiles:
def __call__(self, test_obj): def __call__(self, test_obj):
for server, files in zip(test_obj.servers, self.server_files): for server, files in zip(test_obj.servers, self.server_files):
files_content = {f.name: test_obj._replace_substring(f.content) files_content = {f.name: test_obj._replace_substring(f.content)
for f in files} for f in files}
files_rules = {f.name: test_obj.get_server_rules(f) files_rules = {f.name: test_obj.get_server_rules(f)
for f in files} for f in files}
server.server_conf(files_content, files_rules) server.server_conf(files_content, files_rules)

View File

@ -3,8 +3,11 @@ class ServerError (Exception):
""" A custom exception which is raised by the test servers. Often used to """ A custom exception which is raised by the test servers. Often used to
handle control flow. """ handle control flow. """
def __init__ (self, err_message): def __init__(self, err_message):
self.err_message = err_message self.err_message = err_message
class AuthError (ServerError): class AuthError (ServerError):
""" A custom exception raised byt he servers when authentication of the
request fails. """
pass pass

View File

@ -3,5 +3,5 @@ class TestFailed(Exception):
""" A Custom Exception raised by the Test Environment. """ """ A Custom Exception raised by the Test Environment. """
def __init__ (self, error): def __init__(self, error):
self.error = error self.error = error

View File

@ -18,22 +18,23 @@ codes on;y add clutter. """
T_COLORS = { T_COLORS = {
'PURPLE' : '\033[95m', 'PURPLE': '\033[95m',
'BLUE' : '\033[94m', 'BLUE': '\033[94m',
'GREEN' : '\033[92m', 'GREEN': '\033[92m',
'YELLOW' : '\033[93m', 'YELLOW': '\033[93m',
'RED' : '\033[91m', 'RED': '\033[91m',
'ENDC' : '\033[0m' 'ENDC': '\033[0m'
} }
system = True if platform.system() in ( 'Linux', 'Darwin' ) else False system = True if platform.system() in ('Linux', 'Darwin') else False
check = False if getenv("MAKE_CHECK") == 'True' else True check = False if getenv("MAKE_CHECK") == 'True' else True
def printer (color, string):
if sys.stdout.isatty() and system and check: def printer(color, string):
print (T_COLORS.get (color) + string + T_COLORS.get ('ENDC')) if sys.stdout.isatty() and system and check:
print(T_COLORS.get(color) + string + T_COLORS.get('ENDC'))
else: else:
print (string) print(string)
print_blue = partial(printer, 'BLUE') print_blue = partial(printer, 'BLUE')

View File

@ -3,7 +3,7 @@ class WgetFile:
""" WgetFile is a File Data Container object """ """ WgetFile is a File Data Container object """
def __init__ ( def __init__(
self, self,
name, name,
content="Test Contents", content="Test Contents",

View File

@ -10,50 +10,55 @@ import socket
import os import os
class StoppableHTTPServer (HTTPServer): class StoppableHTTPServer(HTTPServer):
""" This class extends the HTTPServer class from default http.server library """ This class extends the HTTPServer class from default http.server library
in Python 3. The StoppableHTTPServer class is capable of starting an HTTP in Python 3. The StoppableHTTPServer class is capable of starting an HTTP
server that serves a virtual set of files made by the WgetFile class and server that serves a virtual set of files made by the WgetFile class and
has most of its properties configurable through the server_conf() has most of its properties configurable through the server_conf()
method. """ method. """
request_headers = list () request_headers = list()
""" Define methods for configuring the Server. """ """ Define methods for configuring the Server. """
def server_conf (self, filelist, conf_dict): def server_conf(self, filelist, conf_dict):
""" Set Server Rules and File System for this instance. """ """ Set Server Rules and File System for this instance. """
self.server_configs = conf_dict self.server_configs = conf_dict
self.fileSys = filelist self.fileSys = filelist
def get_req_headers (self): def get_req_headers(self):
return self.request_headers return self.request_headers
class HTTPSServer (StoppableHTTPServer): class HTTPSServer(StoppableHTTPServer):
""" The HTTPSServer class extends the StoppableHTTPServer class with """ The HTTPSServer class extends the StoppableHTTPServer class with
additional support for secure connections through SSL. """ additional support for secure connections through SSL. """
def __init__ (self, address, handler): def __init__(self, address, handler):
import ssl import ssl
BaseServer.__init__ (self, address, handler) BaseServer.__init__(self, address, handler)
# step one up because test suite change directory away from $srcdir (don't do that !!!) # step one up because test suite change directory away from $srcdir
CERTFILE = os.path.abspath(os.path.join('..', os.getenv('srcdir', '.'), 'certs', 'server-cert.pem')) # (don't do that !!!)
KEYFILE = os.path.abspath(os.path.join('..', os.getenv('srcdir', '.'), 'certs', 'server-key.pem')) CERTFILE = os.path.abspath(os.path.join('..',
fop = open (CERTFILE) os.getenv('srcdir', '.'),
print (fop.readline()) 'certs',
self.socket = ssl.wrap_socket ( 'server-cert.pem'))
sock = socket.socket (self.address_family, self.socket_type), KEYFILE = os.path.abspath(os.path.join('..',
ssl_version = ssl.PROTOCOL_TLSv1, os.getenv('srcdir', '.'),
certfile = CERTFILE, 'certs',
keyfile = KEYFILE, 'server-key.pem'))
server_side = True self.socket = ssl.wrap_socket(
sock=socket.socket(self.address_family, self.socket_type),
ssl_version=ssl.PROTOCOL_TLSv1,
certfile=CERTFILE,
keyfile=KEYFILE,
server_side=True
) )
self.server_bind() self.server_bind()
self.server_activate() self.server_activate()
class _Handler (BaseHTTPRequestHandler): class _Handler(BaseHTTPRequestHandler):
""" This is a private class which tells the server *HOW* to handle each """ This is a private class which tells the server *HOW* to handle each
request. For each HTTP Request Command that the server should be capable of request. For each HTTP Request Command that the server should be capable of
responding to, there must exist a do_REQUESTNAME() method which details the responding to, there must exist a do_REQUESTNAME() method which details the
@ -61,7 +66,7 @@ class _Handler (BaseHTTPRequestHandler):
in this class are auxilliary methods created to help in processing certain in this class are auxilliary methods created to help in processing certain
requests. """ requests. """
def get_rule_list (self, name): def get_rule_list(self, name):
return self.rules.get(name) return self.rules.get(name)
# The defailt protocol version of the server we run is HTTP/1.1 not # The defailt protocol version of the server we run is HTTP/1.1 not
@ -70,23 +75,23 @@ class _Handler (BaseHTTPRequestHandler):
""" Define functions for various HTTP Requests. """ """ Define functions for various HTTP Requests. """
def do_HEAD (self): def do_HEAD(self):
self.send_head ("HEAD") self.send_head("HEAD")
def do_GET (self): def do_GET(self):
""" Process HTTP GET requests. This is the same as processing HEAD """ Process HTTP GET requests. This is the same as processing HEAD
requests and then actually transmitting the data to the client. If requests and then actually transmitting the data to the client. If
send_head() does not specify any "start" offset, we send the complete send_head() does not specify any "start" offset, we send the complete
data, else transmit only partial data. """ data, else transmit only partial data. """
content, start = self.send_head ("GET") content, start = self.send_head("GET")
if content: if content:
if start is None: if start is None:
self.wfile.write (content.encode ('utf-8')) self.wfile.write(content.encode('utf-8'))
else: else:
self.wfile.write (content.encode ('utf-8')[start:]) self.wfile.write(content.encode('utf-8')[start:])
def do_POST (self): def do_POST(self):
""" According to RFC 7231 sec 4.3.3, if the resource requested in a POST """ According to RFC 7231 sec 4.3.3, if the resource requested in a POST
request does not exist on the server, the first POST request should request does not exist on the server, the first POST request should
create that resource. PUT requests are otherwise used to create a create that resource. PUT requests are otherwise used to create a
@ -100,70 +105,70 @@ class _Handler (BaseHTTPRequestHandler):
path = self.path[1:] path = self.path[1:]
if path in self.server.fileSys: if path in self.server.fileSys:
self.rules = self.server.server_configs.get (path) self.rules = self.server.server_configs.get(path)
if not self.rules: if not self.rules:
self.rules = dict () self.rules = dict()
if not self.custom_response (): if not self.custom_response():
return (None, None) return(None, None)
body_data = self.get_body_data () body_data = self.get_body_data()
self.send_response (200) self.send_response(200)
self.add_header ("Content-type", "text/plain") self.add_header("Content-type", "text/plain")
content = self.server.fileSys.pop (path) + "\n" + body_data content = self.server.fileSys.pop(path) + "\n" + body_data
total_length = len (content) total_length = len(content)
self.server.fileSys[path] = content self.server.fileSys[path] = content
self.add_header ("Content-Length", total_length) self.add_header("Content-Length", total_length)
self.add_header ("Location", self.path) self.add_header("Location", self.path)
self.finish_headers () self.finish_headers()
try: try:
self.wfile.write (content.encode ('utf-8')) self.wfile.write(content.encode('utf-8'))
except Exception: except Exception:
pass pass
else: else:
self.send_put (path) self.send_put(path)
def do_PUT (self): def do_PUT(self):
path = self.path[1:] path = self.path[1:]
self.rules = self.server.server_configs.get (path) self.rules = self.server.server_configs.get(path)
if not self.custom_response (): if not self.custom_response():
return (None, None) return(None, None)
self.send_put (path) self.send_put(path)
""" End of HTTP Request Method Handlers. """ """ End of HTTP Request Method Handlers. """
""" Helper functions for the Handlers. """ """ Helper functions for the Handlers. """
def parse_range_header (self, header_line, length): def parse_range_header(self, header_line, length):
import re import re
if header_line is None: if header_line is None:
return None return None
if not header_line.startswith ("bytes="): if not header_line.startswith("bytes="):
raise ServerError ("Cannot parse header Range: %s" % raise ServerError("Cannot parse header Range: %s" %
(header_line)) (header_line))
regex = re.match (r"^bytes=(\d*)\-$", header_line) regex = re.match(r"^bytes=(\d*)\-$", header_line)
range_start = int (regex.group (1)) range_start = int(regex.group(1))
if range_start >= length: if range_start >= length:
raise ServerError ("Range Overflow") raise ServerError("Range Overflow")
return range_start return range_start
def get_body_data (self): def get_body_data(self):
cLength_header = self.headers.get ("Content-Length") cLength_header = self.headers.get("Content-Length")
cLength = int (cLength_header) if cLength_header is not None else 0 cLength = int(cLength_header) if cLength_header is not None else 0
body_data = self.rfile.read (cLength).decode ('utf-8') body_data = self.rfile.read(cLength).decode('utf-8')
return body_data return body_data
def send_put (self, path): def send_put(self, path):
if path in self.server.fileSys: if path in self.server.fileSys:
self.server.fileSys.pop (path, None) self.server.fileSys.pop(path, None)
self.send_response (204) self.send_response(204)
else: else:
self.rules = dict () self.rules = dict()
self.send_response (201) self.send_response(201)
body_data = self.get_body_data () body_data = self.get_body_data()
self.server.fileSys[path] = body_data self.server.fileSys[path] = body_data
self.add_header ("Location", self.path) self.add_header("Location", self.path)
self.finish_headers () self.finish_headers()
""" This empty method is called automatically when all the rules are """ This empty method is called automatically when all the rules are
processed for a given request. However, send_header() should only be called processed for a given request. However, send_header() should only be called
@ -173,17 +178,17 @@ class _Handler (BaseHTTPRequestHandler):
finish_headers() instead of end_headers(). The finish_headers() method finish_headers() instead of end_headers(). The finish_headers() method
takes care of sending the appropriate headers before completing the takes care of sending the appropriate headers before completing the
response. """ response. """
def SendHeader (self, header_obj): def SendHeader(self, header_obj):
pass pass
def send_cust_headers (self): def send_cust_headers(self):
header_obj = self.get_rule_list ('SendHeader') header_obj = self.get_rule_list('SendHeader')
if header_obj: if header_obj:
for header in header_obj.headers: for header in header_obj.headers:
self.add_header (header, header_obj.headers[header]) self.add_header(header, header_obj.headers[header])
def finish_headers (self): def finish_headers(self):
self.send_cust_headers () self.send_cust_headers()
try: try:
for keyword, value in self._headers_dict.items(): for keyword, value in self._headers_dict.items():
self.send_header(keyword, value) self.send_header(keyword, value)
@ -191,46 +196,46 @@ class _Handler (BaseHTTPRequestHandler):
self._headers_dict.clear() self._headers_dict.clear()
except AttributeError: except AttributeError:
pass pass
self.end_headers () self.end_headers()
def Response (self, resp_obj): def Response(self, resp_obj):
self.send_response (resp_obj.response_code) self.send_response(resp_obj.response_code)
self.finish_headers () self.finish_headers()
raise ServerError ("Custom Response code sent.") raise ServerError("Custom Response code sent.")
def custom_response (self): def custom_response(self):
codes = self.get_rule_list ('Response') codes = self.get_rule_list('Response')
if codes: if codes:
self.send_response (codes.response_code) self.send_response(codes.response_code)
self.finish_headers () self.finish_headers()
return False return False
else: else:
return True return True
def add_header (self, keyword, value): def add_header(self, keyword, value):
if not hasattr (self, "_headers_dict"): if not hasattr(self, "_headers_dict"):
self._headers_dict = dict() self._headers_dict = dict()
self._headers_dict[keyword.lower()] = value self._headers_dict[keyword.lower()] = value
def base64 (self, data): def base64(self, data):
string = b64encode (data.encode ('utf-8')) string = b64encode(data.encode('utf-8'))
return string.decode ('utf-8') return string.decode('utf-8')
""" Send an authentication challenge. """ Send an authentication challenge.
This method calls self.send_header() directly instead of using the This method calls self.send_header() directly instead of using the
add_header() method because sending multiple WWW-Authenticate headers add_header() method because sending multiple WWW-Authenticate headers
actually makes sense and we do use that feature in some tests. """ actually makes sense and we do use that feature in some tests. """
def send_challenge (self, auth_type): def send_challenge(self, auth_type):
auth_type = auth_type.lower() auth_type = auth_type.lower()
if auth_type == "both": if auth_type == "both":
self.send_challenge ("basic") self.send_challenge("basic")
self.send_challenge ("digest") self.send_challenge("digest")
return return
if auth_type == "basic": if auth_type == "basic":
challenge_str = 'BasIc realm="Wget-Test"' challenge_str = 'BasIc realm="Wget-Test"'
elif auth_type == "digest" or auth_type == "both_inline": elif auth_type == "digest" or auth_type == "both_inline":
self.nonce = md5 (str (random ()).encode ('utf-8')).hexdigest() self.nonce = md5(str(random()).encode('utf-8')).hexdigest()
self.opaque = md5 (str (random ()).encode ('utf-8')).hexdigest() self.opaque = md5(str(random()).encode('utf-8')).hexdigest()
# 'DIgest' to provoke a Wget failure with turkish locales # 'DIgest' to provoke a Wget failure with turkish locales
challenge_str = 'DIgest realm="Test", nonce="%s", opaque="%s"' % ( challenge_str = 'DIgest realm="Test", nonce="%s", opaque="%s"' % (
self.nonce, self.nonce,
@ -239,18 +244,18 @@ class _Handler (BaseHTTPRequestHandler):
if auth_type == "both_inline": if auth_type == "both_inline":
# 'BasIc' to provoke a Wget failure with turkish locales # 'BasIc' to provoke a Wget failure with turkish locales
challenge_str = 'BasIc realm="Wget-Test", ' + challenge_str challenge_str = 'BasIc realm="Wget-Test", ' + challenge_str
self.send_header ("WWW-Authenticate", challenge_str) self.send_header("WWW-Authenticate", challenge_str)
def authorize_basic (self, auth_header, auth_rule): def authorize_basic(self, auth_header, auth_rule):
if auth_header is None or auth_header.split(' ')[0].lower() != 'basic': if auth_header is None or auth_header.split(' ')[0].lower() != 'basic':
return False return False
else: else:
self.user = auth_rule.auth_user self.user = auth_rule.auth_user
self.passw = auth_rule.auth_pass self.passw = auth_rule.auth_pass
auth_str = "basic " + self.base64 (self.user + ":" + self.passw) auth_str = "basic " + self.base64(self.user + ":" + self.passw)
return True if auth_str.lower() == auth_header.lower() else False return True if auth_str.lower() == auth_header.lower() else False
def parse_auth_header (self, auth_header): def parse_auth_header(self, auth_header):
n = len("digest ") n = len("digest ")
auth_header = auth_header[n:].strip() auth_header = auth_header[n:].strip()
items = auth_header.split(", ") items = auth_header.split(", ")
@ -258,38 +263,39 @@ class _Handler (BaseHTTPRequestHandler):
keyvals = [(k.strip(), v.strip().replace('"', '')) for k, v in keyvals] keyvals = [(k.strip(), v.strip().replace('"', '')) for k, v in keyvals]
return dict(keyvals) return dict(keyvals)
def KD (self, secret, data): def KD(self, secret, data):
return self.H (secret + ":" + data) return self.H(secret + ":" + data)
def H (self, data): def H(self, data):
return md5 (data.encode ('utf-8')).hexdigest () return md5(data.encode('utf-8')).hexdigest()
def A1 (self): def A1(self):
return "%s:%s:%s" % (self.user, "Test", self.passw) return "%s:%s:%s" % (self.user, "Test", self.passw)
def A2 (self, params): def A2(self, params):
return "%s:%s" % (self.command, params["uri"]) return "%s:%s" % (self.command, params["uri"])
def check_response (self, params): def check_response(self, params):
if "qop" in params: if "qop" in params:
data_str = params['nonce'] \ data_str = params['nonce'] \
+ ":" + params['nc'] \ + ":" + params['nc'] \
+ ":" + params['cnonce'] \ + ":" + params['cnonce'] \
+ ":" + params['qop'] \ + ":" + params['qop'] \
+ ":" + self.H (self.A2 (params)) + ":" + self.H(self.A2(params))
else: else:
data_str = params['nonce'] + ":" + self.H (self.A2 (params)) data_str = params['nonce'] + ":" + self.H(self.A2(params))
resp = self.KD (self.H (self.A1 ()), data_str) resp = self.KD(self.H(self.A1()), data_str)
return True if resp == params['response'] else False return True if resp == params['response'] else False
def authorize_digest (self, auth_header, auth_rule): def authorize_digest(self, auth_header, auth_rule):
if auth_header is None or auth_header.split(' ')[0].lower() != 'digest': if auth_header is None or \
auth_header.split(' ')[0].lower() != 'digest':
return False return False
else: else:
self.user = auth_rule.auth_user self.user = auth_rule.auth_user
self.passw = auth_rule.auth_pass self.passw = auth_rule.auth_pass
params = self.parse_auth_header (auth_header) params = self.parse_auth_header(auth_header)
if self.user != params['username'] or \ if self.user != params['username'] or \
self.nonce != params['nonce'] or \ self.nonce != params['nonce'] or \
self.opaque != params['opaque']: self.opaque != params['opaque']:
@ -298,67 +304,72 @@ class _Handler (BaseHTTPRequestHandler):
for attrib in req_attribs: for attrib in req_attribs:
if attrib not in params: if attrib not in params:
return False return False
if not self.check_response (params): if not self.check_response(params):
return False return False
def authorize_both (self, auth_header, auth_rule): def authorize_both(self, auth_header, auth_rule):
return False return False
def authorize_both_inline (self, auth_header, auth_rule): def authorize_both_inline(self, auth_header, auth_rule):
return False return False
def Authentication (self, auth_rule): def Authentication(self, auth_rule):
try: try:
self.handle_auth (auth_rule) self.handle_auth(auth_rule)
except AuthError as se: except AuthError as se:
self.send_response (401, "Authorization Required") self.send_response(401, "Authorization Required")
self.send_challenge (auth_rule.auth_type) self.send_challenge(auth_rule.auth_type)
self.finish_headers () self.finish_headers()
raise se raise se
def handle_auth (self, auth_rule): def handle_auth(self, auth_rule):
is_auth = True is_auth = True
auth_header = self.headers.get ("Authorization") auth_header = self.headers.get("Authorization")
required_auth = auth_rule.auth_type.lower() required_auth = auth_rule.auth_type.lower()
if required_auth == "both" or required_auth == "both_inline": if required_auth == "both" or required_auth == "both_inline":
auth_type = auth_header.split(' ')[0].lower() if auth_header else required_auth if auth_header:
auth_type = auth_header.split(' ')[0].lower()
else:
auth_type = required_auth
else: else:
auth_type = required_auth auth_type = required_auth
try: try:
assert hasattr (self, "authorize_" + auth_type) assert hasattr(self, "authorize_" + auth_type)
is_auth = getattr (self, "authorize_" + auth_type) (auth_header, auth_rule) is_auth = getattr(self, "authorize_" + auth_type)(auth_header,
auth_rule)
except AssertionError: except AssertionError:
raise AuthError ("Authentication Mechanism " + auth_type + " not supported") raise AuthError("Authentication Mechanism %s not supported" %
auth_type)
except AttributeError as ae: except AttributeError as ae:
raise AuthError (ae.__str__()) raise AuthError(ae.__str__())
if is_auth is False: if is_auth is False:
raise AuthError ("Unable to Authenticate") raise AuthError("Unable to Authenticate")
def ExpectHeader(self, header_obj):
def ExpectHeader (self, header_obj):
exp_headers = header_obj.headers exp_headers = header_obj.headers
for header_line in exp_headers: for header_line in exp_headers:
header_recd = self.headers.get (header_line) header_recd = self.headers.get(header_line)
if header_recd is None or header_recd != exp_headers[header_line]: if header_recd is None or header_recd != exp_headers[header_line]:
self.send_error (400, "Expected Header " + header_line + " not found") self.send_error(400, "Expected Header %s not found" %
self.finish_headers () header_line)
raise ServerError ("Header " + header_line + " not found") self.finish_headers()
raise ServerError("Header " + header_line + " not found")
def RejectHeader(self, header_obj):
def RejectHeader (self, header_obj):
rej_headers = header_obj.headers rej_headers = header_obj.headers
for header_line in rej_headers: for header_line in rej_headers:
header_recd = self.headers.get (header_line) header_recd = self.headers.get(header_line)
if header_recd is not None and header_recd == rej_headers[header_line]: if not header_recd and header_recd == rej_headers[header_line]:
self.send_error (400, 'Blacklisted Header ' + header_line + ' received') self.send_error(400, 'Blacklisted Header %s received' %
self.finish_headers () header_line)
raise ServerError ("Header " + header_line + ' received') self.finish_headers()
raise ServerError("Header " + header_line + ' received')
def __log_request (self, method): def __log_request(self, method):
req = method + " " + self.path req = method + " " + self.path
self.server.request_headers.append (req) self.server.request_headers.append(req)
def send_head (self, method): def send_head(self, method):
""" Common code for GET and HEAD Commands. """ Common code for GET and HEAD Commands.
This method is overriden to use the fileSys dict. This method is overriden to use the fileSys dict.
@ -372,88 +383,89 @@ class _Handler (BaseHTTPRequestHandler):
else: else:
path = self.path[1:] path = self.path[1:]
self.__log_request (method) self.__log_request(method)
if path in self.server.fileSys: if path in self.server.fileSys:
self.rules = self.server.server_configs.get (path) self.rules = self.server.server_configs.get(path)
content = self.server.fileSys.get (path) content = self.server.fileSys.get(path)
content_length = len (content) content_length = len(content)
for rule_name in self.rules: for rule_name in self.rules:
try: try:
assert hasattr (self, rule_name) assert hasattr(self, rule_name)
getattr (self, rule_name) (self.rules [rule_name]) getattr(self, rule_name)(self.rules[rule_name])
except AssertionError as ae: except AssertionError as ae:
msg = "Rule " + rule_name + " not defined" msg = "Rule " + rule_name + " not defined"
self.send_error (500, msg) self.send_error(500, msg)
return (None, None) return(None, None)
except AuthError as ae: except AuthError as ae:
print (ae.__str__()) print(ae.__str__())
return (None, None) return(None, None)
except ServerError as se: except ServerError as se:
print (se.__str__()) print(se.__str__())
return (content, None) return(content, None)
try: try:
self.range_begin = self.parse_range_header ( self.range_begin = self.parse_range_header(
self.headers.get ("Range"), content_length) self.headers.get("Range"), content_length)
except ServerError as ae: except ServerError as ae:
# self.log_error("%s", ae.err_message) # self.log_error("%s", ae.err_message)
if ae.err_message == "Range Overflow": if ae.err_message == "Range Overflow":
self.send_response (416) self.send_response(416)
self.finish_headers () self.finish_headers()
return (None, None) return(None, None)
else: else:
self.range_begin = None self.range_begin = None
if self.range_begin is None: if self.range_begin is None:
self.send_response (200) self.send_response(200)
else: else:
self.send_response (206) self.send_response(206)
self.add_header ("Accept-Ranges", "bytes") self.add_header("Accept-Ranges", "bytes")
self.add_header ("Content-Range", self.add_header("Content-Range",
"bytes %d-%d/%d" % (self.range_begin, "bytes %d-%d/%d" % (self.range_begin,
content_length - 1, content_length - 1,
content_length)) content_length))
content_length -= self.range_begin content_length -= self.range_begin
cont_type = self.guess_type (path) cont_type = self.guess_type(path)
self.add_header ("Content-Type", cont_type) self.add_header("Content-Type", cont_type)
self.add_header ("Content-Length", content_length) self.add_header("Content-Length", content_length)
self.finish_headers () self.finish_headers()
return (content, self.range_begin) return(content, self.range_begin)
else: else:
self.send_error (404, "Not Found") self.send_error(404, "Not Found")
return (None, None) return(None, None)
def guess_type (self, path): def guess_type(self, path):
base_name = basename ("/" + path) base_name = basename("/" + path)
name, ext = splitext (base_name) name, ext = splitext(base_name)
extension_map = { extension_map = {
".txt" : "text/plain", ".txt": "text/plain",
".css" : "text/css", ".css": "text/css",
".html" : "text/html" ".html": "text/html"
} }
return extension_map.get(ext, "text/plain") return extension_map.get(ext, "text/plain")
class HTTPd (threading.Thread):
class HTTPd(threading.Thread):
server_class = StoppableHTTPServer server_class = StoppableHTTPServer
handler = _Handler handler = _Handler
def __init__ (self, addr=None): def __init__(self, addr=None):
threading.Thread.__init__ (self) threading.Thread.__init__(self)
if addr is None: if addr is None:
addr = ('localhost', 0) addr = ('localhost', 0)
self.server_inst = self.server_class (addr, self.handler) self.server_inst = self.server_class(addr, self.handler)
self.server_address = self.server_inst.socket.getsockname()[:2] self.server_address = self.server_inst.socket.getsockname()[:2]
def run (self): def run(self):
self.server_inst.serve_forever () self.server_inst.serve_forever()
def server_conf (self, file_list, server_rules): def server_conf(self, file_list, server_rules):
self.server_inst.server_conf (file_list, server_rules) self.server_inst.server_conf(file_list, server_rules)
class HTTPSd (HTTPd): class HTTPSd(HTTPd):
server_class = HTTPSServer server_class = HTTPSServer
# vim: set ts=4 sts=4 sw=4 tw=80 et : # vim: set ts=4 sts=4 sw=4 tw=79 et :

View File

@ -28,9 +28,10 @@ class BaseTest:
Attributes should not be defined outside __init__. Attributes should not be defined outside __init__.
""" """
self.name = name self.name = name
self.pre_configs = pre_hook or {} # if pre_hook == None, then # if pre_hook == None, then {} (an empty dict object) is passed to
# {} (an empty dict object) is # self.pre_configs
# passed to self.pre_configs self.pre_configs = pre_hook or {}
self.test_params = test_params or {} self.test_params = test_params or {}
self.post_configs = post_hook or {} self.post_configs = post_hook or {}
self.protocols = protocols self.protocols = protocols
@ -109,11 +110,16 @@ class BaseTest:
if gdb == "1": if gdb == "1":
cmd_line = 'gdb --args %s %s ' % (wget_path, wget_options) cmd_line = 'gdb --args %s %s ' % (wget_path, wget_options)
elif valgrind == "1": elif valgrind == "1":
cmd_line = 'valgrind --error-exitcode=301 --leak-check=yes --track-origins=yes %s %s ' % (wget_path, wget_options) cmd_line = 'valgrind --error-exitcode=301 ' \
'--leak-check=yes ' \
'--track-origins=yes ' \
'%s %s ' % (wget_path, wget_options)
elif valgrind not in ("", "0"): elif valgrind not in ("", "0"):
cmd_line = '%s %s %s ' % (os.getenv("VALGRIND_TESTS", ""), wget_path, wget_options) cmd_line = '%s %s %s ' % (os.getenv("VALGRIND_TESTS", ""),
wget_path,
wget_options)
else: else:
cmd_line = '%s %s ' % (wget_path, wget_options) cmd_line = '%s %s ' % (wget_path, wget_options)
for protocol, urls, domain in zip(self.protocols, for protocol, urls, domain in zip(self.protocols,
self.urls, self.urls,
@ -139,12 +145,12 @@ class BaseTest:
if not os.getenv("NO_CLEANUP"): if not os.getenv("NO_CLEANUP"):
shutil.rmtree(self.get_test_dir()) shutil.rmtree(self.get_test_dir())
except: except:
print ("Unknown Exception while trying to remove Test Environment.") print("Unknown Exception while trying to remove Test Environment.")
def _exit_test (self): def _exit_test(self):
self.__test_cleanup() self.__test_cleanup()
def begin (self): def begin(self):
return 0 if self.tests_passed else 100 return 0 if self.tests_passed else 100
def call_test(self): def call_test(self):
@ -181,16 +187,17 @@ class BaseTest:
def post_hook_call(self): def post_hook_call(self):
self.hook_call(self.post_configs, 'Post Test Function') self.hook_call(self.post_configs, 'Post Test Function')
def _replace_substring (self, string): def _replace_substring(self, string):
""" """
Replace first occurrence of "{{name}}" in @string with "getattr(self, name)". Replace first occurrence of "{{name}}" in @string with
"getattr(self, name)".
""" """
pattern = re.compile (r'\{\{\w+\}\}') pattern = re.compile(r'\{\{\w+\}\}')
match_obj = pattern.search (string) match_obj = pattern.search(string)
if match_obj is not None: if match_obj is not None:
rep = match_obj.group() rep = match_obj.group()
temp = getattr (self, rep.strip ('{}')) temp = getattr(self, rep.strip('{}'))
string = string.replace (rep, temp) string = string.replace(rep, temp)
return string return string
def instantiate_server_by(self, protocol): def instantiate_server_by(self, protocol):