helpers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # SPDX-FileCopyrightText: 2016-2023 Helmut Pozimski <helmut@pozimski.eu>
  2. #
  3. # SPDX-License-Identifier: GPL-2.0-only
  4. # -*- coding: utf8 -*-
  5. """
  6. collection of helper functions used in other modules of acme-updater.
  7. """
  8. import logging
  9. import datetime
  10. import os
  11. import shutil
  12. import hashlib
  13. import subprocess
  14. import OpenSSL
  15. import dns.tsigkeyring
  16. import dns.update
  17. import dns.query
  18. LOGGER = logging.getLogger("acme-updater")
  19. def parse_apache_vhost(file_obj):
  20. """
  21. Parses a given vhost file and extracts the main domain,
  22. the certificate file, the TLS key file and all domains contained
  23. within the vhost.
  24. :param file_obj: file obj pointing to a vhost to parse
  25. :return: list of tuples with domains and found certificates
  26. :rtype: list
  27. """
  28. vhost_started = False
  29. parsed_info = []
  30. cert_path = ""
  31. key_path = ""
  32. main_domain = ""
  33. domains = set()
  34. for line in file_obj:
  35. if "<VirtualHost" in line:
  36. vhost_started = True
  37. elif "</VirtualHost" in line and vhost_started:
  38. vhost_started = False
  39. if cert_path and key_path and main_domain and domains:
  40. parsed_info.append((main_domain, cert_path, key_path, domains))
  41. LOGGER.debug(
  42. "Found vhost with main domain %s, certificate %s and key "
  43. "file %s", main_domain, cert_path, key_path)
  44. cert_path = ""
  45. key_path = ""
  46. main_domain = ""
  47. elif "ServerName" in line:
  48. main_domain = line.strip().rsplit()[1]
  49. domains.add(line.strip().rsplit()[1])
  50. elif "SSLCertificateFile" in line:
  51. cert_path = line.strip().rsplit()[1]
  52. elif "SSLCertificateKeyFile" in line:
  53. key_path = line.strip().rsplit()[1]
  54. elif "ServerAlias" in line:
  55. for domain in line.strip().rsplit()[1].split(" "):
  56. domains.add(domain)
  57. return parsed_info
  58. def parse_asn1_time(timestamp):
  59. """
  60. parses an ANS1 timestamp as returned by OpenSSL and turns it into a python
  61. datetime object.
  62. :param timestamp: ASN1 timestamp
  63. :type timestamp: str
  64. :return: timestamp as datetime object
  65. :rtype: datetime
  66. """
  67. year = int(timestamp[:4])
  68. month = int(timestamp[4:6])
  69. day = int(timestamp[6:8])
  70. date = datetime.datetime(year, month, day)
  71. return date
  72. def copy_file(source, destination, backup=True):
  73. """
  74. Copies a file from the given source file to
  75. the given destination and optionally creates a copy of the
  76. destination file.
  77. :param source: source file path
  78. :type source: str
  79. :param destination: destination file path
  80. :type destination: str
  81. :param backup: whether to take a backup of the destination file before \
  82. overwriting it
  83. :type backup: bool
  84. :return: success
  85. :rtype: bool
  86. """
  87. if backup:
  88. if not create_backup_copy(destination):
  89. return False
  90. try:
  91. shutil.copy(source, destination)
  92. except IOError:
  93. LOGGER.error("Copying of file %s to %s failed!",
  94. source, destination)
  95. else:
  96. os.chmod(destination, 0o0644)
  97. os.chown(destination, 0, 0)
  98. return True
  99. def create_backup_copy(source):
  100. """
  101. creates a backup file of a specified source file.
  102. :param source: source file path
  103. :type source: str
  104. :return: success
  105. :rtype: bool
  106. """
  107. backup_file = source + ".bak_%s" % datetime.datetime.now().strftime(
  108. "%Y%m%d%H%M%S")
  109. try:
  110. shutil.copy(source, backup_file)
  111. except IOError:
  112. LOGGER.error("Creating of backup file for %s failed!", source)
  113. return False
  114. else:
  115. return True
  116. def check_renewal(cert, cert_path):
  117. """
  118. Checks if the certificate has been renewed.
  119. :param cert: the certificate that needs to be checked
  120. :type cert: OpenSSL.crypto.X509
  121. :param cert_path: absolute path to the certificate file
  122. :type cert_path: str
  123. :return: renewal status
  124. :rtype: bool
  125. """
  126. try:
  127. with open(cert_path, "r") as acme_cert_file:
  128. acme_cert_text = acme_cert_file.read()
  129. except IOError:
  130. LOGGER.error("Could not open certificate %s in acme "
  131. "state directory", cert_path)
  132. else:
  133. x509_acme_cert = OpenSSL.crypto.load_certificate(
  134. OpenSSL.crypto.FILETYPE_PEM, acme_cert_text
  135. )
  136. expiry_date = x509_acme_cert.get_notAfter().decode("utf-8")
  137. expiry_datetime = parse_asn1_time(expiry_date)
  138. if expiry_datetime < datetime.datetime.utcnow():
  139. LOGGER.warning("Certificate %s is expired and no newer "
  140. "one is available, bailing out!", cert_path)
  141. return False
  142. else:
  143. serial_current_cert = cert.get_serial_number()
  144. serial_acme_cert = x509_acme_cert.get_serial_number()
  145. if serial_current_cert == serial_acme_cert:
  146. LOGGER.debug("Cert %s matches with the one "
  147. "installed, nothing to do.", cert_path)
  148. return False
  149. else:
  150. return True
  151. def create_tlsa_hash(cert):
  152. """
  153. Creates an tlsa 3 1 1 hash to create TLSA records for a given certificate
  154. :param cert: certificate to be used
  155. :type cert: OpenSSL.crypto.X509
  156. :return: sha256 has of the public key
  157. :rtype: str
  158. """
  159. pubkey = cert.get_pubkey()
  160. pubkey_der = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
  161. pubkey)
  162. sha256 = hashlib.sha256()
  163. sha256.update(pubkey_der)
  164. hexdigest = sha256.hexdigest()
  165. return hexdigest
  166. def get_tsig_key(named_key_path):
  167. """
  168. Reads the named session key and generates a keyring object for it.
  169. :param named_key_path: Path to the named session key
  170. :type named_key_path: str
  171. :return: keyring, algorithm
  172. :rtype: tuple
  173. """
  174. key_name = None
  175. key_algorithm = None
  176. secret = None
  177. try:
  178. with open(named_key_path, "r") as bind_key:
  179. for line in bind_key:
  180. if "key" in line:
  181. key_name = line.split(" ")[1].strip("\"")
  182. elif "algorithm" in line:
  183. key_algorithm = line.strip().split(" ")[1].strip(";")
  184. elif "secret" in line:
  185. secret = line.strip().split(" ")[1].strip("\"").strip(";")
  186. except IOError:
  187. LOGGER.error("Error while opening the bind session key")
  188. return None, None
  189. else:
  190. if key_name and key_algorithm and secret:
  191. keyring = dns.tsigkeyring.from_text({
  192. key_name: secret
  193. })
  194. return keyring, key_algorithm
  195. else:
  196. return None, None
  197. def update_tlsa_record(zone, tlsa_port, digest, keyring, keyalgorithm,
  198. subdomain="", ttl=300, protocol="tcp",
  199. dns_server="localhost"):
  200. """
  201. Updates the tlsa record on the DNS server.
  202. :param zone: Zone of the (sub) domain
  203. :type zone: str
  204. :param tlsa_port: port for the tlsa record
  205. :type tlsa_port: str
  206. :param digest: cryptographic hash of the certificate public key
  207. :type digest: str
  208. :param keyring: keyring object
  209. :type keyring: dict
  210. :param keyalgorithm: algorithm used for the tsig key
  211. :type keyalgorithm: str
  212. :param subdomain: subdomain to create the tlsa record for
  213. :type subdomain: str
  214. :param ttl: TTL to use for the TLSA record
  215. :type ttl: int
  216. :param protocol: protocol for the TLSA record
  217. :type protocol: str
  218. :param dns_server: DNS server to use to create TLSA records
  219. :type dns_server: str
  220. :returns: response of the operation
  221. :rtype: dns.message.Message
  222. """
  223. update = dns.update.Update(zone, keyring=keyring,
  224. keyalgorithm=keyalgorithm)
  225. tlsa_content = "3 1 1 %s" % digest
  226. if subdomain:
  227. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, subdomain)
  228. else:
  229. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, zone)
  230. update.replace(tlsa_record, ttl, "tlsa", tlsa_content)
  231. response = dns.query.tcp(update, dns_server)
  232. return response
  233. def get_log_level(input_level=""):
  234. """
  235. Determines the log level to use based on a string.
  236. :param input_level: String representing the desired log level.
  237. :type input_level: str
  238. :return: corresponding log level of the logging module
  239. :rtype: int
  240. """
  241. if input_level.lower() == "debug":
  242. return logging.DEBUG
  243. elif input_level.lower() == "error":
  244. return logging.ERROR
  245. else:
  246. return logging.INFO
  247. def create_tlsa_records(domain, port, certificate, named_key_path,
  248. dns_server):
  249. """
  250. Creates tlsa records for the specified (sub-)domain
  251. :param domain: (sub-)domain the records are to be created for
  252. :type domain: str
  253. :param port: port to use for the record
  254. :type port: str
  255. :param certificate: certificate object used for record creation
  256. :type certificate: OpenSSL.crypto.X509
  257. :param named_key_path: path to the named session key
  258. :type named_key_path: str
  259. :param dns_server: DNS server to use to create TLSA records
  260. :type dns_server: str
  261. """
  262. hash_digest = create_tlsa_hash(certificate)
  263. zone = "%s.%s" % (domain.split(".")[-2], domain.split(".")[-1])
  264. tsig, keyalgo = get_tsig_key(named_key_path)
  265. update_tlsa_record(zone, port, hash_digest, tsig, keyalgo, domain,
  266. dns_server=dns_server)
  267. def get_subject_alt_name(certificate) -> list:
  268. """
  269. Extracts the subjectAltName entries from a X509 certficiate
  270. :param certificate: the certificate to extract the subjectAltName \
  271. entries from
  272. :type certificate: OpenSSL.crypto.X509
  273. :return: list of hostnames
  274. :rtype: list
  275. """
  276. alt_names = []
  277. for i in range(0, certificate.get_extension_count(), 1):
  278. if certificate.get_extension(i).get_short_name() == b"subjectAltName":
  279. extension_string = str(certificate.get_extension(i))
  280. for entry in extension_string.split(","):
  281. alt_names.append(entry.split(":")[1])
  282. break
  283. return alt_names
  284. def restart_service(service_name: str):
  285. if os.path.exists("/run/systemd/system"):
  286. subprocess.call(["/usr/bin/systemctl", "restart", service_name])
  287. else:
  288. subprocess.call(["/etc/init.d/%s" % service_name,
  289. "restart"])