helpers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. # SPDX-FileCopyrightText: 2016-2023 Helmut Pozimski <helmut@pozimski.eu>
  2. #
  3. # SPDX-License-Identifier: GPL-2.0-only
  4. # -*- coding: utf8 -*-
  5. """
  6. collection of helper functions used in other modules of acme-updater.
  7. """
  8. import logging
  9. import datetime
  10. import os
  11. import shutil
  12. import hashlib
  13. import subprocess
  14. import OpenSSL
  15. import dns.tsigkeyring
  16. import dns.update
  17. import dns.query
  18. from typing import List
  19. from amulib.vhost_entry import ApacheVhostEntry
  20. LOGGER = logging.getLogger("acme-updater")
  21. def parse_apache_vhost(file_obj) -> List[ApacheVhostEntry]:
  22. """
  23. Parses a given vhost file and extracts the main domain,
  24. the certificate file, the TLS key file and all domains contained
  25. within the vhost.
  26. :param file_obj: file obj pointing to a vhost to parse
  27. :return: list of tuples with domains and found certificates
  28. :rtype: list
  29. """
  30. vhost_started = False
  31. parsed_info = []
  32. cert_path = ""
  33. key_path = ""
  34. main_domain = ""
  35. domains = set()
  36. for line in file_obj:
  37. if "<VirtualHost" in line:
  38. vhost_started = True
  39. elif "</VirtualHost" in line and vhost_started:
  40. vhost_started = False
  41. if cert_path and key_path and main_domain and domains:
  42. parsed_info.append(ApacheVhostEntry(main_domain, cert_path, key_path, domains))
  43. LOGGER.debug(
  44. "Found vhost with main domain %s, certificate %s and key "
  45. "file %s", main_domain, cert_path, key_path)
  46. cert_path = ""
  47. key_path = ""
  48. main_domain = ""
  49. elif "ServerName" in line:
  50. main_domain = line.strip().rsplit()[1]
  51. domains.add(line.strip().rsplit()[1])
  52. elif "SSLCertificateFile" in line:
  53. cert_path = line.strip().rsplit()[1]
  54. elif "SSLCertificateKeyFile" in line:
  55. key_path = line.strip().rsplit()[1]
  56. elif "ServerAlias" in line:
  57. for domain in line.strip().rsplit()[1].split(" "):
  58. domains.add(domain)
  59. return parsed_info
  60. def parse_asn1_time(timestamp):
  61. """
  62. parses an ANS1 timestamp as returned by OpenSSL and turns it into a python
  63. datetime object.
  64. :param timestamp: ASN1 timestamp
  65. :type timestamp: str
  66. :return: timestamp as datetime object
  67. :rtype: datetime
  68. """
  69. year = int(timestamp[:4])
  70. month = int(timestamp[4:6])
  71. day = int(timestamp[6:8])
  72. date = datetime.datetime(year, month, day)
  73. return date
  74. def copy_file(source, destination, backup=True):
  75. """
  76. Copies a file from the given source file to
  77. the given destination and optionally creates a copy of the
  78. destination file.
  79. :param source: source file path
  80. :type source: str
  81. :param destination: destination file path
  82. :type destination: str
  83. :param backup: whether to take a backup of the destination file before \
  84. overwriting it
  85. :type backup: bool
  86. :return: success
  87. :rtype: bool
  88. """
  89. if backup:
  90. if not create_backup_copy(destination):
  91. return False
  92. try:
  93. shutil.copy(source, destination)
  94. except IOError:
  95. LOGGER.error("Copying of file %s to %s failed!",
  96. source, destination)
  97. else:
  98. os.chmod(destination, 0o0644)
  99. os.chown(destination, 0, 0)
  100. return True
  101. def create_backup_copy(source):
  102. """
  103. creates a backup file of a specified source file.
  104. :param source: source file path
  105. :type source: str
  106. :return: success
  107. :rtype: bool
  108. """
  109. backup_file = source + ".bak_%s" % datetime.datetime.now().strftime(
  110. "%Y%m%d%H%M%S")
  111. try:
  112. shutil.copy(source, backup_file)
  113. except IOError:
  114. LOGGER.error("Creating of backup file for %s failed!", source)
  115. return False
  116. else:
  117. return True
  118. def check_renewal(cert, cert_path):
  119. """
  120. Checks if the certificate has been renewed.
  121. :param cert: the certificate that needs to be checked
  122. :type cert: OpenSSL.crypto.X509
  123. :param cert_path: absolute path to the certificate file
  124. :type cert_path: str
  125. :return: renewal status
  126. :rtype: bool
  127. """
  128. try:
  129. with open(cert_path, "r") as acme_cert_file:
  130. acme_cert_text = acme_cert_file.read()
  131. except IOError:
  132. LOGGER.error("Could not open certificate %s in acme "
  133. "state directory", cert_path)
  134. else:
  135. x509_acme_cert = OpenSSL.crypto.load_certificate(
  136. OpenSSL.crypto.FILETYPE_PEM, acme_cert_text
  137. )
  138. expiry_date = x509_acme_cert.get_notAfter().decode("utf-8")
  139. expiry_datetime = parse_asn1_time(expiry_date)
  140. if expiry_datetime < datetime.datetime.utcnow():
  141. LOGGER.warning("Certificate %s is expired and no newer "
  142. "one is available, bailing out!", cert_path)
  143. return False
  144. else:
  145. serial_current_cert = cert.get_serial_number()
  146. serial_acme_cert = x509_acme_cert.get_serial_number()
  147. if serial_current_cert == serial_acme_cert:
  148. LOGGER.debug("Cert %s matches with the one "
  149. "installed, nothing to do.", cert_path)
  150. return False
  151. else:
  152. return True
  153. def create_tlsa_hash(cert):
  154. """
  155. Creates an tlsa 3 1 1 hash to create TLSA records for a given certificate
  156. :param cert: certificate to be used
  157. :type cert: OpenSSL.crypto.X509
  158. :return: sha256 has of the public key
  159. :rtype: str
  160. """
  161. pubkey = cert.get_pubkey()
  162. pubkey_der = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
  163. pubkey)
  164. sha256 = hashlib.sha256()
  165. sha256.update(pubkey_der)
  166. hexdigest = sha256.hexdigest()
  167. return hexdigest
  168. def get_tsig_key(named_key_path):
  169. """
  170. Reads the named session key and generates a keyring object for it.
  171. :param named_key_path: Path to the named session key
  172. :type named_key_path: str
  173. :return: keyring, algorithm
  174. :rtype: tuple
  175. """
  176. key_name = None
  177. key_algorithm = None
  178. secret = None
  179. try:
  180. with open(named_key_path, "r") as bind_key:
  181. for line in bind_key:
  182. if "key" in line:
  183. key_name = line.split(" ")[1].strip("\"")
  184. elif "algorithm" in line:
  185. key_algorithm = line.strip().split(" ")[1].strip(";")
  186. elif "secret" in line:
  187. secret = line.strip().split(" ")[1].strip("\"").strip(";")
  188. except IOError:
  189. LOGGER.error("Error while opening the bind session key")
  190. return None, None
  191. else:
  192. if key_name and key_algorithm and secret:
  193. keyring = dns.tsigkeyring.from_text({
  194. key_name: secret
  195. })
  196. return keyring, key_algorithm
  197. else:
  198. return None, None
  199. def update_tlsa_record(zone, tlsa_port, digest, keyring, keyalgorithm,
  200. subdomain="", ttl=300, protocol="tcp",
  201. dns_server="localhost"):
  202. """
  203. Updates the tlsa record on the DNS server.
  204. :param zone: Zone of the (sub) domain
  205. :type zone: str
  206. :param tlsa_port: port for the tlsa record
  207. :type tlsa_port: str
  208. :param digest: cryptographic hash of the certificate public key
  209. :type digest: str
  210. :param keyring: keyring object
  211. :type keyring: dict
  212. :param keyalgorithm: algorithm used for the tsig key
  213. :type keyalgorithm: str
  214. :param subdomain: subdomain to create the tlsa record for
  215. :type subdomain: str
  216. :param ttl: TTL to use for the TLSA record
  217. :type ttl: int
  218. :param protocol: protocol for the TLSA record
  219. :type protocol: str
  220. :param dns_server: DNS server to use to create TLSA records
  221. :type dns_server: str
  222. :returns: response of the operation
  223. :rtype: dns.message.Message
  224. """
  225. update = dns.update.Update(zone, keyring=keyring,
  226. keyalgorithm=keyalgorithm)
  227. tlsa_content = "3 1 1 %s" % digest
  228. if subdomain:
  229. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, subdomain)
  230. else:
  231. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, zone)
  232. update.replace(tlsa_record, ttl, "tlsa", tlsa_content)
  233. response = dns.query.tcp(update, dns_server)
  234. return response
  235. def get_log_level(input_level=""):
  236. """
  237. Determines the log level to use based on a string.
  238. :param input_level: String representing the desired log level.
  239. :type input_level: str
  240. :return: corresponding log level of the logging module
  241. :rtype: int
  242. """
  243. if input_level.lower() == "debug":
  244. return logging.DEBUG
  245. elif input_level.lower() == "error":
  246. return logging.ERROR
  247. else:
  248. return logging.INFO
  249. def create_tlsa_records(domain, port, certificate, named_key_path,
  250. dns_server):
  251. """
  252. Creates tlsa records for the specified (sub-)domain
  253. :param domain: (sub-)domain the records are to be created for
  254. :type domain: str
  255. :param port: port to use for the record
  256. :type port: str
  257. :param certificate: certificate object used for record creation
  258. :type certificate: OpenSSL.crypto.X509
  259. :param named_key_path: path to the named session key
  260. :type named_key_path: str
  261. :param dns_server: DNS server to use to create TLSA records
  262. :type dns_server: str
  263. """
  264. hash_digest = create_tlsa_hash(certificate)
  265. zone = "%s.%s" % (domain.split(".")[-2], domain.split(".")[-1])
  266. tsig, keyalgo = get_tsig_key(named_key_path)
  267. update_tlsa_record(zone, port, hash_digest, tsig, keyalgo, domain,
  268. dns_server=dns_server)
  269. def get_subject_alt_name(certificate) -> list:
  270. """
  271. Extracts the subjectAltName entries from a X509 certficiate
  272. :param certificate: the certificate to extract the subjectAltName \
  273. entries from
  274. :type certificate: OpenSSL.crypto.X509
  275. :return: list of hostnames
  276. :rtype: list
  277. """
  278. alt_names = []
  279. for i in range(0, certificate.get_extension_count(), 1):
  280. if certificate.get_extension(i).get_short_name() == b"subjectAltName":
  281. extension_string = str(certificate.get_extension(i))
  282. for entry in extension_string.split(","):
  283. alt_names.append(entry.split(":")[1])
  284. break
  285. return alt_names
  286. def restart_service(service_name: str):
  287. if os.path.exists("/run/systemd/system"):
  288. subprocess.call(["/usr/bin/systemctl", "restart", service_name])
  289. else:
  290. subprocess.call(["/etc/init.d/%s" % service_name,
  291. "restart"])