Browse Source

add helper functions take from the existing acme_updater and acme-tlsa-mail scripts, create main function

Helmut Pozimski 7 years ago
parent
commit
08bca0f599
3 changed files with 247 additions and 1 deletions
  1. 6 1
      acme-updater
  2. 230 0
      amulib/helpers.py
  3. 11 0
      amulib/main.py

+ 6 - 1
acme-updater

@@ -1,2 +1,7 @@
 #! /usr/bin/env python3
-# -*- coding: utf8 -*-
+# -*- coding: utf8 -*-
+
+from amulib.main import main
+
+if __name__ == "__main__":
+    main()

+ 230 - 0
amulib/helpers.py

@@ -0,0 +1,230 @@
+"""
+collection of helper functions used in other modules of acme-updater.
+"""
+# -*- coding: utf8 -*-
+
+import logging
+import datetime
+import os
+import shutil
+import hashlib
+
+import OpenSSL
+import dns.tsigkeyring
+import dns.update
+import dns.query
+
+LOGGER = logging.getLogger("acme-updater")
+
+
+def parse_apache_vhost(file_obj):
+    """
+    Parses a given vhost file and extracts the main domain,
+    the certificate file and the TLS key file.
+
+    :param file_obj: file obj pointing to a vhost to parse
+    :return: list of tuples with domains and found certificates
+    :rtype: list
+    """
+    vhost_started = False
+    parsed_info = []
+    cert_path = ""
+    key_path = ""
+    main_domain = ""
+    for line in file_obj:
+        if "<VirtualHost" in line:
+            vhost_started = True
+        elif "</VirtualHost" in line and vhost_started:
+            vhost_started = False
+            if cert_path and key_path and main_domain:
+                parsed_info.append((main_domain, cert_path, key_path))
+                LOGGER.debug(
+                    "Found vhost with main domain %s, certificate %s and key "
+                    "file %s", main_domain, cert_path, key_path)
+                cert_path = ""
+                key_path = ""
+                main_domain = ""
+        elif "ServerName" in line:
+            main_domain = line.strip().rsplit()[1]
+        elif "SSLCertificateFile" in line:
+            cert_path = line.strip().rsplit()[1]
+        elif "SSLCertificateKeyFile" in line:
+            key_path = line.strip().rsplit()[1]
+    return parsed_info
+
+
+def parse_asn1_time(timestamp):
+    """
+    parses an ANS1 timestamp as returned by OpenSSL and turns it into a python
+    datetime object.
+
+    :param timestamp: ASN1 timestamp
+    :type timestamp: str
+    :return: timestamp as datetime object
+    :rtype: datetime
+    """
+    year = int(timestamp[:4])
+    month = int(timestamp[4:6])
+    day = int(timestamp[6:8])
+    date = datetime.datetime(year, month, day)
+    return date
+
+
+def copy_file(source, destination, backup=True):
+    """
+    Copies a file from the given source file to
+    the given destination and optionally creates a copy of the
+    destination file.
+
+    :param source: source file path
+    :type source: str
+    :param destination: destination file path
+    :type destination: str
+    :param backup: whether to take a backup of the destination file before \
+    overwriting it
+    :type backup: bool
+    :return: success
+    :rtype: bool
+    """
+    backup_file = destination + ".bak_%s" % datetime.datetime.now().strftime(
+        "%Y%m%d%H%M%S")
+    if backup:
+        try:
+            shutil.copy(destination, backup_file)
+        except IOError:
+            LOGGER.error("Creating of backup file for %s failed!", destination)
+            return False
+    try:
+        shutil.copy(source, destination)
+    except IOError:
+        LOGGER.error("Copying of file %s to %s failed!",
+                     source, destination)
+    else:
+        os.chmod(destination, 0o0644)
+        os.chown(destination, 0, 0)
+        return True
+
+
+def check_renewal(cert, cert_path):
+    """
+    Checks if the certificate has been renewed.
+
+    :param cert: the certificate that needs to be checked
+    :type cert: OpenSSL.crypto.X509
+    :param cert_path: absolute path to the certificate file
+    :type cert_path: str
+    :return: renewal status
+    :rtype: bool
+    """
+    try:
+        with open(cert_path, "r") as acme_cert_file:
+            acme_cert_text = acme_cert_file.read()
+    except IOError:
+        LOGGER.error("Could not open certificate %s in acme "
+                     "state directory", cert_path)
+    else:
+        x509_acme_cert = OpenSSL.crypto.load_certificate(
+            OpenSSL.crypto.FILETYPE_PEM, acme_cert_text
+        )
+        expiry_date = x509_acme_cert.get_notAfter().decode("utf-8")
+        expiry_datetime = parse_asn1_time(expiry_date)
+        if expiry_datetime < datetime.datetime.utcnow():
+            LOGGER.warning("Certificate %s is expired and no newer "
+                           "one is available, bailing out!", cert_path)
+            return False
+        else:
+            serial_current_cert = cert.get_serial_number()
+            serial_acme_cert = x509_acme_cert.get_serial_number()
+            if serial_current_cert == serial_acme_cert:
+                LOGGER.debug("Cert %s matches with the one "
+                             "installed, nothing to do.", cert_path)
+                return False
+            else:
+                return True
+
+
+def create_tlsa_hash(cert):
+    """
+    Creates an tlsa 3 1 1 hash to create TLSA records for a given certificate
+    :param cert: certificate to be used
+    :type cert: OpenSSL.crypto.X509
+    :return: sha256 has of the public key
+    :rtype: str
+    """
+    pubkey = cert.get_pubkey()
+    pubkey_der = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
+                                               pubkey)
+    sha256 = hashlib.sha256()
+    sha256.update(pubkey_der)
+    hexdigest = sha256.hexdigest()
+    return hexdigest
+
+
+def get_tsig_key(named_key_path):
+    """
+    Reads the named session key and generates a keyring object for it.
+
+    :param named_key_path: Path to the named session key
+    :type named_key_path: str
+    :return: keyring, algorithm
+    :rtype: tuple
+    """
+    key_name = None
+    key_algorithm = None
+    secret = None
+    try:
+        with open(named_key_path, "r") as bind_key:
+            for line in bind_key:
+                if "key" in line:
+                    key_name = line.split(" ")[1].strip("\"")
+                elif "algorithm" in line:
+                    key_algorithm = line.strip().split(" ")[1].strip(";")
+                elif "secret" in line:
+                    secret = line.strip().split(" ")[1].strip("\"").strip(";")
+    except IOError:
+        LOGGER.error("Error while opening the bind session key")
+        return None, None
+    else:
+        if key_name and key_algorithm and secret:
+            keyring = dns.tsigkeyring.from_text({
+                key_name: secret
+            })
+            return keyring, key_algorithm
+        else:
+            return None, None
+
+
+def update_tlsa_record(zone, tlsa_port, digest, keyring, keyalgorithm,
+                       subdomain="", ttl=300, protocol="tcp"):
+    """
+    Updates the tlsa record on the DNS server.
+
+    :param zone: Zone of the (sub) domain
+    :type zone: str
+    :param tlsa_port: port for the tlsa record
+    :type tlsa_port: str
+    :param digest: cryptographic hash of the certificate public key
+    :type digest: str
+    :param keyring: keyring object
+    :type keyring: dict
+    :param keyalgorithm: algorithm used for the tsig key
+    :type keyalgorithm: str
+    :param subdomain: subdomain to create the tlsa record for
+    :type subdomain: str
+    :param ttl: TTL to use for the TLSA record
+    :type ttl: int
+    :param protocol: protocol for the TLSA record
+    :type protocol: str
+    :returns: response of the operation
+    :rtype: dns.message.Message
+    """
+    update = dns.update.Update(zone, keyring=keyring,
+                               keyalgorithm=keyalgorithm)
+    tlsa_content = "3 1 1 %s" % digest
+    if subdomain:
+        tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, subdomain)
+    else:
+        tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, zone)
+    update.replace(tlsa_record, ttl, "tlsa", tlsa_content)
+    response = dns.query.tcp(update, 'localhost')
+    return response

+ 11 - 0
amulib/main.py

@@ -0,0 +1,11 @@
+# -*- coding: utf8 -*-
+
+import argparse
+
+
+def main():
+    """
+    Main function of acme-updater.
+    """
+    parser = argparse.ArgumentParser()
+    parser.parse_args()