|
- # SPDX-FileCopyrightText: 2016-2023 Helmut Pozimski <helmut@pozimski.eu>
- #
- # SPDX-License-Identifier: GPL-2.0-only
- # -*- coding: utf8 -*-
- """
- collection of helper functions used in other modules of acme-updater.
- """
- import logging
- import datetime
- import os
- import shutil
- import hashlib
- import subprocess
- import OpenSSL
- import dns.tsigkeyring
- import dns.update
- import dns.query
- from typing import List
- from amulib.vhost_entry import ApacheVhostEntry
# Module-wide logger shared by all helpers below; records are emitted on the
# "acme-updater" logger so they follow whatever configuration that logger has.
LOGGER = logging.getLogger(name="acme-updater")
def parse_apache_vhost(file_obj) -> List[ApacheVhostEntry]:
    """
    Parses a given vhost file and extracts, for every <VirtualHost> section,
    the main domain, the certificate file, the TLS key file and all domains
    (ServerName and ServerAlias values) contained within the vhost.

    A section only produces an entry if it defines a ServerName, an
    SSLCertificateFile and an SSLCertificateKeyFile. Lines starting with "#"
    are treated as comments and ignored.

    :param file_obj: iterable of lines (e.g. an open file object) of a vhost
        configuration to parse
    :return: one ApacheVhostEntry per complete vhost section
    :rtype: list
    """
    vhost_started = False
    parsed_info = []
    cert_path = ""
    key_path = ""
    main_domain = ""
    domains = set()
    for line in file_obj:
        stripped = line.strip()
        # Commented-out directives must not be picked up as live config.
        if stripped.startswith("#"):
            continue
        tokens = stripped.split()
        if "<VirtualHost" in stripped:
            vhost_started = True
        elif "</VirtualHost" in stripped and vhost_started:
            vhost_started = False
            if cert_path and key_path and main_domain and domains:
                parsed_info.append(
                    ApacheVhostEntry(main_domain, cert_path, key_path, domains))
                LOGGER.debug(
                    "Found vhost with main domain %s, certificate %s and key "
                    "file %s", main_domain, cert_path, key_path)
            cert_path = ""
            key_path = ""
            main_domain = ""
            # Start a fresh set for the next vhost: each entry must own its
            # domains instead of all entries sharing one growing set.
            domains = set()
        elif len(tokens) < 2:
            # Blank line or directive without an argument: nothing to extract.
            continue
        elif "ServerName" in stripped:
            main_domain = tokens[1]
            domains.add(tokens[1])
        elif "SSLCertificateFile" in stripped:
            cert_path = tokens[1]
        elif "SSLCertificateKeyFile" in stripped:
            key_path = tokens[1]
        elif "ServerAlias" in stripped:
            # ServerAlias may list several names on one line; keep all of them.
            domains.update(tokens[1:])
    return parsed_info
def parse_asn1_time(timestamp):
    """
    Parses an ASN.1 timestamp as returned by OpenSSL (e.g. by
    ``X509.get_notAfter()``, format ``YYYYMMDDHHMMSSZ``) and turns it into a
    naive python datetime object.

    Hour, minute and second are used when present; a bare ``YYYYMMDD``
    value yields midnight of that day. A trailing timezone designator such
    as ``Z`` is ignored, so the result carries no tzinfo.

    :param timestamp: ASN1 timestamp
    :type timestamp: str or bytes
    :return: timestamp as datetime object
    :rtype: datetime
    """
    if isinstance(timestamp, bytes):
        timestamp = timestamp.decode("ascii")
    year = int(timestamp[0:4])
    month = int(timestamp[4:6])
    day = int(timestamp[6:8])
    # Collect hour, minute and second two digits at a time, stopping at the
    # first field that is missing or non-numeric (e.g. the trailing "Z").
    time_fields = []
    for start in (8, 10, 12):
        chunk = timestamp[start:start + 2]
        if len(chunk) != 2 or not chunk.isdigit():
            break
        time_fields.append(int(chunk))
    return datetime.datetime(year, month, day, *time_fields)
def copy_file(source, destination, backup=True):
    """
    Copies a file from the given source file to the given destination and
    optionally creates a backup copy of the destination file first.

    After a successful copy the destination is set to mode 0644 and owned
    by uid 0 / gid 0.

    :param source: source file path
    :type source: str
    :param destination: destination file path
    :type destination: str
    :param backup: whether to take a backup of the destination file before \
overwriting it
    :type backup: bool
    :return: True on success, False if the backup or the copy failed
    :rtype: bool
    :raises OSError: if changing mode or ownership of the destination fails
    """
    if backup and not create_backup_copy(destination):
        return False
    try:
        shutil.copy(source, destination)
    except IOError:
        LOGGER.error("Copying of file %s to %s failed!",
                     source, destination)
        # Report the failure explicitly instead of falling through to an
        # implicit None.
        return False
    os.chmod(destination, 0o644)
    os.chown(destination, 0, 0)
    return True
def create_backup_copy(source):
    """
    Copies a file next to itself under a timestamped backup name
    (``<source>.bak_YYYYmmddHHMMSS``, local time).

    :param source: path of the file to back up
    :type source: str
    :return: True if the backup was written, False if copying failed
    :rtype: bool
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    target = source + ".bak_%s" % stamp
    try:
        shutil.copy(source, target)
    except IOError:
        LOGGER.error("Creating of backup file for %s failed!", source)
        return False
    return True
def check_renewal(cert, cert_path):
    """
    Checks if the certificate at cert_path is a usable renewal of the
    currently installed certificate.

    :param cert: the currently installed certificate
    :type cert: OpenSSL.crypto.X509
    :param cert_path: absolute path to the certificate file in the acme
        state directory
    :type cert_path: str
    :return: True if the certificate at cert_path is not expired and has a
        different serial number than cert; False otherwise, including when
        it cannot be read or parsed
    :rtype: bool
    """
    try:
        with open(cert_path, "r") as acme_cert_file:
            acme_cert_text = acme_cert_file.read()
    except IOError:
        LOGGER.error("Could not open certificate %s in acme "
                     "state directory", cert_path)
        return False
    try:
        x509_acme_cert = OpenSSL.crypto.load_certificate(
            OpenSSL.crypto.FILETYPE_PEM, acme_cert_text
        )
    except OpenSSL.crypto.Error:
        LOGGER.error("Could not parse certificate %s in acme "
                     "state directory", cert_path)
        return False
    expiry_date = x509_acme_cert.get_notAfter().decode("utf-8")
    expiry_datetime = parse_asn1_time(expiry_date)
    # parse_asn1_time returns a naive datetime (X.509 validity times are
    # UTC), so compare against a naive UTC "now"; utcnow() is deprecated.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    if expiry_datetime < now_utc:
        LOGGER.warning("Certificate %s is expired and no newer "
                       "one is available, bailing out!", cert_path)
        return False
    if cert.get_serial_number() == x509_acme_cert.get_serial_number():
        LOGGER.debug("Cert %s matches with the one "
                     "installed, nothing to do.", cert_path)
        return False
    return True
def create_tlsa_hash(cert):
    """
    Computes the association data for a TLSA "3 1 1" record: the SHA-256
    digest of the certificate's DER-encoded public key.

    :param cert: certificate whose public key is hashed
    :type cert: OpenSSL.crypto.X509
    :return: hex-encoded SHA-256 digest of the public key
    :rtype: str
    """
    public_key = cert.get_pubkey()
    der_bytes = OpenSSL.crypto.dump_publickey(
        OpenSSL.crypto.FILETYPE_ASN1, public_key)
    return hashlib.sha256(der_bytes).hexdigest()
def get_tsig_key(named_key_path):
    """
    Reads the named session key and generates a keyring object for it.

    The file is parsed line by line for ``key <name> {``,
    ``algorithm <alg>;`` and ``secret "<value>";`` statements.

    :param named_key_path: Path to the named session key
    :type named_key_path: str
    :return: (keyring, algorithm), or (None, None) if the file cannot be
        read or lacks a key name, algorithm or secret
    :rtype: tuple
    """
    key_name = None
    key_algorithm = None
    secret = None
    try:
        with open(named_key_path, "r") as bind_key:
            for line in bind_key:
                tokens = line.split()
                if len(tokens) < 2:
                    continue
                # Match on the leading keyword only; a substring test would
                # misread e.g. a base64 secret that happens to contain "key".
                keyword, value = tokens[0], tokens[1]
                if keyword == "key":
                    key_name = value.strip("\"{")
                elif keyword == "algorithm":
                    key_algorithm = value.rstrip(";")
                elif keyword == "secret":
                    # Remove the ";" terminator before the quotes; the
                    # reverse order leaves a stray trailing quote behind.
                    secret = value.rstrip(";").strip("\"")
    except IOError:
        LOGGER.error("Error while opening the bind session key")
        return None, None
    if key_name and key_algorithm and secret:
        keyring = dns.tsigkeyring.from_text({
            key_name: secret
        })
        return keyring, key_algorithm
    return None, None
def update_tlsa_record(zone, tlsa_port, digest, keyring, keyalgorithm,
                       subdomain="", ttl=300, protocol="tcp",
                       dns_server="localhost"):
    """
    Replaces the "3 1 1" TLSA record for a port/protocol pair by sending a
    TSIG-signed dynamic update to the DNS server over TCP.

    The record owner is ``_<port>._<protocol>.<host>.``, where host is the
    subdomain if one is given and the zone itself otherwise.

    :param zone: Zone of the (sub) domain
    :type zone: str
    :param tlsa_port: port for the tlsa record
    :type tlsa_port: str
    :param digest: cryptographic hash of the certificate public key
    :type digest: str
    :param keyring: keyring object
    :type keyring: dict
    :param keyalgorithm: algorithm used for the tsig key
    :type keyalgorithm: str
    :param subdomain: subdomain to create the tlsa record for
    :type subdomain: str
    :param ttl: TTL to use for the TLSA record
    :type ttl: int
    :param protocol: protocol for the TLSA record
    :type protocol: str
    :param dns_server: DNS server to use to create TLSA records
    :type dns_server: str
    :returns: response of the operation
    :rtype: dns.message.Message
    """
    message = dns.update.Update(zone, keyring=keyring,
                                keyalgorithm=keyalgorithm)
    owner = "_%s._%s.%s." % (tlsa_port, protocol, subdomain or zone)
    message.replace(owner, ttl, "tlsa", "3 1 1 %s" % digest)
    return dns.query.tcp(message, dns_server)
def get_log_level(input_level=""):
    """
    Determines the log level to use based on a string.

    Recognised names (case-insensitive, surrounding whitespace ignored) are
    debug, info, warning/warn, error and critical. Anything else, including
    an empty string or None, yields INFO.

    :param input_level: String representing the desired log level.
    :type input_level: str
    :return: corresponding log level of the logging module
    :rtype: int
    """
    levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "warn": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }
    # Treat None (e.g. an unset config option) like an empty string.
    normalized = (input_level or "").strip().lower()
    return levels.get(normalized, logging.INFO)
def create_tlsa_records(domain, port, certificate, named_key_path,
                        dns_server, zone=None):
    """
    Creates tlsa records for the specified (sub-)domain

    :param domain: (sub-)domain the records are to be created for
    :type domain: str
    :param port: port to use for the record
    :type port: str
    :param certificate: certificate object used for record creation
    :type certificate: OpenSSL.crypto.X509
    :param named_key_path: path to the named session key
    :type named_key_path: str
    :param dns_server: DNS server to use to create TLSA records
    :type dns_server: str
    :param zone: zone to send the update to; defaults to the last two
        labels of domain
    :type zone: str
    :return: response of the DNS update, or None if no TSIG key could be
        loaded (in which case no update is sent)
    :rtype: dns.message.Message
    """
    hash_digest = create_tlsa_hash(certificate)
    # Drop a trailing root dot so label splitting does not produce an empty
    # label and the record name is not built with a double dot.
    domain = domain.rstrip(".")
    if zone is None:
        # Heuristic: assumes the zone is "<name>.<tld>"; this is wrong for
        # multi-label suffixes such as co.uk, where zone must be passed in.
        zone = ".".join(domain.split(".")[-2:])
    tsig, keyalgo = get_tsig_key(named_key_path)
    if tsig is None:
        # Without a key the update would be sent unsigned.
        LOGGER.error("No usable TSIG key found in %s, not updating TLSA "
                     "record for %s", named_key_path, domain)
        return None
    return update_tlsa_record(zone, port, hash_digest, tsig, keyalgo, domain,
                              dns_server=dns_server)
def get_subject_alt_name(certificate) -> list:
    """
    Extracts the subjectAltName entries from a X509 certificate

    Only the value part of each entry is returned (the "DNS:" / "IP Address:"
    type prefix is dropped). Entries without a type prefix are skipped.

    :param certificate: the certificate to extract the subjectAltName \
entries from
    :type certificate: OpenSSL.crypto.X509
    :return: list of hostnames
    :rtype: list
    """
    alt_names = []
    for index in range(certificate.get_extension_count()):
        extension = certificate.get_extension(index)
        if extension.get_short_name() != b"subjectAltName":
            continue
        # The parser relies on str() rendering the extension as
        # comma-separated "TYPE:value" entries.
        for entry in str(extension).split(","):
            _, separator, value = entry.partition(":")
            if separator:
                # Split on the first colon only: IPv6 values such as
                # "IP Address:2001:db8::1" contain further colons.
                alt_names.append(value.strip())
        break
    return alt_names
def restart_service(service_name: str) -> int:
    """
    Restarts a system service via systemctl when systemd is running, and
    via its /etc/init.d script otherwise.

    :param service_name: name of the service (or init script) to restart
    :type service_name: str
    :return: exit status of the restart command (0 on success)
    :rtype: int
    """
    # /run/systemd/system only exists when systemd is the running init.
    if os.path.exists("/run/systemd/system"):
        # Resolve systemctl via PATH (it lives in /bin on non-usrmerged
        # systems), falling back to the traditional location.
        systemctl = shutil.which("systemctl") or "/usr/bin/systemctl"
        command = [systemctl, "restart", service_name]
    else:
        command = ["/etc/init.d/%s" % service_name,
                   "restart"]
    return_code = subprocess.call(command)
    if return_code != 0:
        LOGGER.warning("Restarting service %s failed with exit code %s",
                       service_name, return_code)
    return return_code
|