diff --git a/requirements-test.txt b/requirements-test.txt index 131ea0e8..6c530e58 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -3,4 +3,5 @@ coveralls<=1.2.0 mock<4.0.0 # Pinned because version 4 dropped support for Python 2.7 codecov +testfixtures<7.0.0 pre-commit diff --git a/ssm/agents.py b/ssm/agents.py index 07ebeed7..9f40e7bf 100644 --- a/ssm/agents.py +++ b/ssm/agents.py @@ -300,6 +300,10 @@ def run_receiver(protocol, brokers, project, token, cp, log, dn_file): dns = get_dns(dn_file, log) ssm.set_dns(dns) + log.info('Fetching banned DNs.') + banned_dns = get_banned_dns(log, cp) + ssm.set_banned_dns(banned_dns) + except Exception as e: log.fatal('Failed to initialise SSM: %s', e) log.info(LOG_BREAK) @@ -392,3 +396,27 @@ def get_dns(dn_file, log): log.debug('%s DNs found.', len(dns)) return dns + + +def get_banned_dns(log, cp): + """Retrieve the list of banned dns""" + banned_dns = [] + try: + banned_dns_path = cp.get('auth', 'banned-dns') + banned_dns_file = os.path.normpath( + os.path.expandvars(banned_dns_path)) + except ConfigParser.NoOptionError: + banned_dns_file = None + + with open(banned_dns_file, 'r') as f: + lines = f.readlines() + for line in lines: + if line.isspace() or line.strip().startswith('#'): + continue + elif line.strip().startswith('/'): + banned_dns.append(line.strip()) + else: + log.warning('DN in banned dns list is not in ' + 'the correct format: %s', line) + + return banned_dns diff --git a/ssm/ssm2.py b/ssm/ssm2.py index 9cbc28e2..a70e16f9 100644 --- a/ssm/ssm2.py +++ b/ssm/ssm2.py @@ -90,6 +90,7 @@ def __init__(self, hosts_and_ports, qpath, cert, key, dest=None, listen=None, self._dest = dest self._valid_dns = [] + self._banned_dns = [] self._pidfile = pidfile # Used to differentiate between STOMP and AMS methods @@ -189,6 +190,10 @@ def set_dns(self, dn_list): """Set the list of DNs which are allowed to sign incoming messages.""" self._valid_dns = dn_list + def set_banned_dns(self, banned_dn_list): + """Set the list of banned dns, so their messages can be dropped.""" + self._banned_dns = banned_dn_list + ########################################################################## # Methods called by stomppy ########################################################################## @@ -283,8 +288,10 @@ def _handle_msg(self, text): Namely: - decrypt if necessary - verify signature + - send an error message if the message wasn't sent from a valid DN - Return plain-text message, signer's DN and an error/None. """ + if text is None or text == '': warning = 'Empty text passed to _handle_msg.' log.warning(warning) @@ -307,10 +314,23 @@ def _handle_msg(self, text): log.error(error) return None, None, error - if signer not in self._valid_dns: + # If the message has been sent from a banned DN, + # set a specific error message that can be + # checked for later. + if signer in self._banned_dns: + warning = 'Signer is in the banned DNs list' + log.warning(warning) + return None, signer, warning + + # Else, if the signer is not in valid DNs list, + # but also not a banned dn, + # set a specific error message + elif signer not in self._valid_dns: warning = 'Signer not in valid DNs list: %s' % signer log.warning(warning) return None, signer, warning + + # Else, the message has been sent from a valid DN else: log.info('Valid signer: %s', signer) @@ -320,9 +340,15 @@ def _save_msg_to_queue(self, body, empaid): """Extract message contents and add to the accept or reject queue.""" extracted_msg, signer, err_msg = self._handle_msg(body) try: + # If the warning states the message was sent from a banned DN, + # don't send the message to the reject queue. + # Instead, drop the message (don't send it to any queue) + if err_msg == "Signer is in the banned DNs list": + log.warning("Message dropped as was sent from a banned dn: %s", signer) + # If the message is empty or the error message is not empty # then reject the message. - if extracted_msg is None or err_msg is not None: + elif extracted_msg is None or err_msg is not None: if signer is None: # crypto failed signer = 'Not available.' elif extracted_msg is not None: diff --git a/test/test_ssm.py b/test/test_ssm.py index 5f96fd78..51773e32 100644 --- a/test/test_ssm.py +++ b/test/test_ssm.py @@ -4,11 +4,19 @@ import shutil import tempfile import unittest -from subprocess import call +from mock import patch +from subprocess import call, Popen, PIPE +import logging + +# For Python 2.7, make sure testfixtures version is < 7.0.0 +# testfixtures version 6.18.5 works fine +# run: pip install testfixtures==6.18.5 +from testfixtures import LogCapture from ssm.message_directory import MessageDirectory from ssm.ssm2 import Ssm2, Ssm2Exception +logging.basicConfig(level=logging.INFO) class TestSsm(unittest.TestCase): ''' @@ -134,6 +142,206 @@ def test_ssm_init_non_dirq(self): # Assert the outbound queue is of the expected type. self.assertTrue(isinstance(ssm._outq, MessageDirectory)) +class TestMsgToQueue(unittest.TestCase): + ''' + Class used for testing how messages are sent to queues based + upon the DN that sent them. + The _handle_msg() function called by _save_msg_to_queue() + (the function we are testing) is being mocked, as it fails + due to client signers and certificates not matching. + It is easier to mock out the failing function instead, as it is + not the function we are testing in this test. To test + _save_msg_to_queue, we are assuming the message's certificate and + signer match. + We then use the log output from ssm2.py to see if the messages + are sent to the queues we expect them to be. + ''' + + def setUp(self): + # Create temporary directory for message queues and pidfiles + self.dir_path = tempfile.mkdtemp() + + # Set up a test directory and certificates + self._tmp_dir = tempfile.mkdtemp(prefix='ssm') + + self._brokers = [('not.a.broker', 123)] + self._capath = '/not/a/path' + self._check_crls = False + self._pidfile = self._tmp_dir + '/pidfile' + + self._listen = '/topic/test' + self._dest = '/topic/test' + + self._msgdir = tempfile.mkdtemp(prefix='msgq') + + # Create a key/cert pair + call(['openssl', 'req', '-x509', '-nodes', '-days', '2', + '-newkey', 'rsa:2048', '-keyout', TEST_KEY_FILE, + '-out', TEST_CERT_FILE, '-subj', + '/C=UK/O=STFC/OU=SC/CN=Test Cert']) + + def tearDown(self): + # Remove test directory and all contents + try: + shutil.rmtree(self._msgdir) + shutil.rmtree(self._tmp_dir) + except OSError as e: + print('Error removing temporary directory %s' % self._tmp_dir) + print(e) + + + @patch.object(Ssm2, '_handle_msg') + def test_dns_saved_to_queue(self, mock_handle_msg): + ''' + Test that messages sent from different dns are sent + to the correct queue or dropped, depending on their dn. + The logging output is checked to tell us if the message + was sent to the correct queue, or dropped completely. + Therefore we don't need to create incoming or reject queues. + ''' + + # Create a list of fake valid dns that will send the messages + # These should be sent to the incoming queue + valid_dns = ("/C=UK/O=eScience/OU=CLRC/L=RAL/CN=valid-1.esc.rl.ac.uk", + "/C=UK/O=eScience/OU=CLRC/L=RAL/CN=valid-2.esc.rl.ac.uk", + "/C=UK/O=eScience/OU=CLRC/L=RAL/CN=valid-3.esc.rl.ac.uk") + + # Create a list of fake dns that will result in a rejected message + # These should be sent to the rejected queue + rejected_dns = ("/C=UK/O=eScience/OU=CLRC/L=RAL/CN=rejected-1.esc.rl.ac.uk", + "/C=UK/O=eScience/OU=CLRC/L=RAL/CN=rejected-2.esc.rl.ac.uk", + "/C=UK/O=eScience/OU=CLRC/L=RAL/CN=rejected-3.esc.rl.ac.uk") + + # Create a list of fake banned dns that feature in the banned dn list + # These should be dropped, and not sent to a queue + banned_dns = ("/C=UK/O=eScience/OU=CLRC/L=RAL/CN=banned-1.esc.rl.ac.uk", + "/C=UK/O=eScience/OU=CLRC/L=RAL/CN=banned-2.esc.rl.ac.uk", + "/C=UK/O=eScience/OU=CLRC/L=RAL/CN=banned-3.esc.rl.ac.uk") + + # A custom empaid isn't necessary, and can just be 1 + empaid = "1" + + # Set up an openssl-style CA directory, containing the + # self-signed certificate as its own CA certificate, but with its + # name as .0. + p1 = Popen(['openssl', 'x509', '-subject_hash', '-noout'], + stdin=PIPE, stdout=PIPE, stderr=PIPE, + universal_newlines=True) + + with open(TEST_CERT_FILE, 'r') as test_cert: + cert_string = test_cert.read() + + hash_name, _unused_error = p1.communicate(cert_string) + + self.ca_certpath = os.path.join(TEST_CA_DIR, hash_name.strip() + '.0') + with open(self.ca_certpath, 'w') as ca_cert: + ca_cert.write(cert_string) + + # For each dn in the valid dns list, + # pass it and the message to ssm and use the log output to + # make sure it was dealt with correctly. + for dn in valid_dns: + # Capture the log output so we can use it in assertions + with LogCapture() as log: + message_valid = """APEL-summary-job-message: v0.2 + Site: RAL-LCG2 + Month: 3 + Year: 2010 + GlobalUserName: """ + dn + """ + VO: atlas + VOGroup: /atlas + VORole: Role=production + WallDuration: 234256 + CpuDuration: 2345 + NumberOfJobs: 100 + %%""" + + mock_handle_msg.return_value = message_valid, dn, None + + ssm = Ssm2(self._brokers, self._msgdir, TEST_CERT_FILE, + TEST_KEY_FILE, dest=self._dest, listen=self._listen, + capath=self.ca_certpath) + + ssm._save_msg_to_queue(message_valid, empaid) + + self.assertIn('Message saved to incoming queue', str(log)) + + # For each dn in the rejected dns list, + # pass it and the message to ssm and use the log output to + # make sure it was dealt with correctly. + # As there are several different ways messages can be rejected, + # keep a count to test a different method for each dn + dnCount = 1 + for dn in rejected_dns: + # Capture the log output so we can use it in assertions + with LogCapture() as log: + message_rejected = """APEL-summary-job-message: v0.2 + Site: RAL-LCG2 + Month: 3 + Year: 2010 + GlobalUserName: """ + dn + """ + VO: atlas + VOGroup: /atlas + VORole: Role=production + WallDuration: 234256 + CpuDuration: 2345 + NumberOfJobs: 100 + %%""" + + # Change the reason for method rejection for each dn + if dnCount == 1: + # Pass no message, which will also need an error message + mock_handle_msg.return_value = None, dn, "Empty text passed to _handle_msg" + elif dnCount == 2: + # Pass an error message + mock_handle_msg.return_value = message_rejected, dn, "Signer not in valid DNs list" + else: + # Pass a different error message + mock_handle_msg.return_value = message_rejected, dn, "Failed to verify message" + + ssm = Ssm2(self._brokers, self._msgdir, TEST_CERT_FILE, + TEST_KEY_FILE, dest=self._dest, listen=self._listen, + capath=self.ca_certpath) + + ssm._save_msg_to_queue(message_rejected, empaid) + + self.assertIn('Message saved to reject queue', str(log)) + + dnCount = dnCount + 1 + + # For each dn in the banned dns list, + # pass it and the message to ssm and use the log output to + # make sure it was dealt with correctly. + for dn in banned_dns: + # Capture the log output so we can use it in assertions + with LogCapture() as log: + message_banned = """APEL-summary-job-message: v0.2 + Site: RAL-LCG2 + Month: 3 + Year: 2010 + GlobalUserName: """ + dn + """ + VO: atlas + VOGroup: /atlas + VORole: Role=production + WallDuration: 234256 + CpuDuration: 2345 + NumberOfJobs: 100 + %%""" + + mock_handle_msg.return_value = message_banned, dn, "Signer is in the banned DNs list" + + ssm = Ssm2(self._brokers, self._msgdir, TEST_CERT_FILE, + TEST_KEY_FILE, dest=self._dest, listen=self._listen, + capath=self.ca_certpath) + + ssm._save_msg_to_queue(message_banned, empaid) + + self.assertIn('Message dropped as was sent from a banned dn', str(log)) + + +TEST_KEY_FILE = '/tmp/test.key' + +TEST_CA_DIR='/tmp' TEST_CERT_FILE = '/tmp/test.crt'