helpers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # SPDX-FileCopyrightText: 2016-2017 Helmut Pozimski <helmut@pozimski.eu>
  2. #
  3. # SPDX-License-Identifier: GPL-2.0-only
  4. # -*- coding: utf8 -*-
  5. """
  6. collection of helper functions used in other modules of acme-updater.
  7. """
  8. import logging
  9. import datetime
  10. import os
  11. import shutil
  12. import hashlib
  13. import OpenSSL
  14. import dns.tsigkeyring
  15. import dns.update
  16. import dns.query
  17. LOGGER = logging.getLogger("acme-updater")
  18. def parse_apache_vhost(file_obj):
  19. """
  20. Parses a given vhost file and extracts the main domain,
  21. the certificate file, the TLS key file and all domains contained
  22. within the vhost.
  23. :param file_obj: file obj pointing to a vhost to parse
  24. :return: list of tuples with domains and found certificates
  25. :rtype: list
  26. """
  27. vhost_started = False
  28. parsed_info = []
  29. cert_path = ""
  30. key_path = ""
  31. main_domain = ""
  32. domains = set()
  33. for line in file_obj:
  34. if "<VirtualHost" in line:
  35. vhost_started = True
  36. elif "</VirtualHost" in line and vhost_started:
  37. vhost_started = False
  38. if cert_path and key_path and main_domain and domains:
  39. parsed_info.append((main_domain, cert_path, key_path, domains))
  40. LOGGER.debug(
  41. "Found vhost with main domain %s, certificate %s and key "
  42. "file %s", main_domain, cert_path, key_path)
  43. cert_path = ""
  44. key_path = ""
  45. main_domain = ""
  46. elif "ServerName" in line:
  47. main_domain = line.strip().rsplit()[1]
  48. domains.add(line.strip().rsplit()[1])
  49. elif "SSLCertificateFile" in line:
  50. cert_path = line.strip().rsplit()[1]
  51. elif "SSLCertificateKeyFile" in line:
  52. key_path = line.strip().rsplit()[1]
  53. elif "ServerAlias" in line:
  54. for domain in line.strip().rsplit()[1].split(" "):
  55. domains.add(domain)
  56. return parsed_info
  57. def parse_asn1_time(timestamp):
  58. """
  59. parses an ANS1 timestamp as returned by OpenSSL and turns it into a python
  60. datetime object.
  61. :param timestamp: ASN1 timestamp
  62. :type timestamp: str
  63. :return: timestamp as datetime object
  64. :rtype: datetime
  65. """
  66. year = int(timestamp[:4])
  67. month = int(timestamp[4:6])
  68. day = int(timestamp[6:8])
  69. date = datetime.datetime(year, month, day)
  70. return date
  71. def copy_file(source, destination, backup=True):
  72. """
  73. Copies a file from the given source file to
  74. the given destination and optionally creates a copy of the
  75. destination file.
  76. :param source: source file path
  77. :type source: str
  78. :param destination: destination file path
  79. :type destination: str
  80. :param backup: whether to take a backup of the destination file before \
  81. overwriting it
  82. :type backup: bool
  83. :return: success
  84. :rtype: bool
  85. """
  86. if backup:
  87. if not create_backup_copy(destination):
  88. return False
  89. try:
  90. shutil.copy(source, destination)
  91. except IOError:
  92. LOGGER.error("Copying of file %s to %s failed!",
  93. source, destination)
  94. else:
  95. os.chmod(destination, 0o0644)
  96. os.chown(destination, 0, 0)
  97. return True
  98. def create_backup_copy(source):
  99. """
  100. creates a backup file of a specified source file.
  101. :param source: source file path
  102. :type source: str
  103. :return: success
  104. :rtype: bool
  105. """
  106. backup_file = source + ".bak_%s" % datetime.datetime.now().strftime(
  107. "%Y%m%d%H%M%S")
  108. try:
  109. shutil.copy(source, backup_file)
  110. except IOError:
  111. LOGGER.error("Creating of backup file for %s failed!", source)
  112. return False
  113. else:
  114. return True
  115. def check_renewal(cert, cert_path):
  116. """
  117. Checks if the certificate has been renewed.
  118. :param cert: the certificate that needs to be checked
  119. :type cert: OpenSSL.crypto.X509
  120. :param cert_path: absolute path to the certificate file
  121. :type cert_path: str
  122. :return: renewal status
  123. :rtype: bool
  124. """
  125. try:
  126. with open(cert_path, "r") as acme_cert_file:
  127. acme_cert_text = acme_cert_file.read()
  128. except IOError:
  129. LOGGER.error("Could not open certificate %s in acme "
  130. "state directory", cert_path)
  131. else:
  132. x509_acme_cert = OpenSSL.crypto.load_certificate(
  133. OpenSSL.crypto.FILETYPE_PEM, acme_cert_text
  134. )
  135. expiry_date = x509_acme_cert.get_notAfter().decode("utf-8")
  136. expiry_datetime = parse_asn1_time(expiry_date)
  137. if expiry_datetime < datetime.datetime.utcnow():
  138. LOGGER.warning("Certificate %s is expired and no newer "
  139. "one is available, bailing out!", cert_path)
  140. return False
  141. else:
  142. serial_current_cert = cert.get_serial_number()
  143. serial_acme_cert = x509_acme_cert.get_serial_number()
  144. if serial_current_cert == serial_acme_cert:
  145. LOGGER.debug("Cert %s matches with the one "
  146. "installed, nothing to do.", cert_path)
  147. return False
  148. else:
  149. return True
  150. def create_tlsa_hash(cert):
  151. """
  152. Creates an tlsa 3 1 1 hash to create TLSA records for a given certificate
  153. :param cert: certificate to be used
  154. :type cert: OpenSSL.crypto.X509
  155. :return: sha256 has of the public key
  156. :rtype: str
  157. """
  158. pubkey = cert.get_pubkey()
  159. pubkey_der = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
  160. pubkey)
  161. sha256 = hashlib.sha256()
  162. sha256.update(pubkey_der)
  163. hexdigest = sha256.hexdigest()
  164. return hexdigest
  165. def get_tsig_key(named_key_path):
  166. """
  167. Reads the named session key and generates a keyring object for it.
  168. :param named_key_path: Path to the named session key
  169. :type named_key_path: str
  170. :return: keyring, algorithm
  171. :rtype: tuple
  172. """
  173. key_name = None
  174. key_algorithm = None
  175. secret = None
  176. try:
  177. with open(named_key_path, "r") as bind_key:
  178. for line in bind_key:
  179. if "key" in line:
  180. key_name = line.split(" ")[1].strip("\"")
  181. elif "algorithm" in line:
  182. key_algorithm = line.strip().split(" ")[1].strip(";")
  183. elif "secret" in line:
  184. secret = line.strip().split(" ")[1].strip("\"").strip(";")
  185. except IOError:
  186. LOGGER.error("Error while opening the bind session key")
  187. return None, None
  188. else:
  189. if key_name and key_algorithm and secret:
  190. keyring = dns.tsigkeyring.from_text({
  191. key_name: secret
  192. })
  193. return keyring, key_algorithm
  194. else:
  195. return None, None
  196. def update_tlsa_record(zone, tlsa_port, digest, keyring, keyalgorithm,
  197. subdomain="", ttl=300, protocol="tcp",
  198. dns_server="localhost"):
  199. """
  200. Updates the tlsa record on the DNS server.
  201. :param zone: Zone of the (sub) domain
  202. :type zone: str
  203. :param tlsa_port: port for the tlsa record
  204. :type tlsa_port: str
  205. :param digest: cryptographic hash of the certificate public key
  206. :type digest: str
  207. :param keyring: keyring object
  208. :type keyring: dict
  209. :param keyalgorithm: algorithm used for the tsig key
  210. :type keyalgorithm: str
  211. :param subdomain: subdomain to create the tlsa record for
  212. :type subdomain: str
  213. :param ttl: TTL to use for the TLSA record
  214. :type ttl: int
  215. :param protocol: protocol for the TLSA record
  216. :type protocol: str
  217. :param dns_server: DNS server to use to create TLSA records
  218. :type dns_server: str
  219. :returns: response of the operation
  220. :rtype: dns.message.Message
  221. """
  222. update = dns.update.Update(zone, keyring=keyring,
  223. keyalgorithm=keyalgorithm)
  224. tlsa_content = "3 1 1 %s" % digest
  225. if subdomain:
  226. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, subdomain)
  227. else:
  228. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, zone)
  229. update.replace(tlsa_record, ttl, "tlsa", tlsa_content)
  230. response = dns.query.tcp(update, dns_server)
  231. return response
  232. def get_log_level(input_level=""):
  233. """
  234. Determines the log level to use based on a string.
  235. :param input_level: String representing the desired log level.
  236. :type input_level: str
  237. :return: corresponding log level of the logging module
  238. :rtype: int
  239. """
  240. if input_level.lower() == "debug":
  241. return logging.DEBUG
  242. elif input_level.lower() == "error":
  243. return logging.ERROR
  244. else:
  245. return logging.INFO
  246. def create_tlsa_records(domain, port, certificate, named_key_path,
  247. dns_server):
  248. """
  249. Creates tlsa records for the specified (sub-)domain
  250. :param domain: (sub-)domain the records are to be created for
  251. :type domain: str
  252. :param port: port to use for the record
  253. :type port: str
  254. :param certificate: certificate object used for record creation
  255. :type certificate: OpenSSL.crypto.X509
  256. :param named_key_path: path to the named session key
  257. :type named_key_path: str
  258. :param dns_server: DNS server to use to create TLSA records
  259. :type dns_server: str
  260. """
  261. hash_digest = create_tlsa_hash(certificate)
  262. zone = "%s.%s" % (domain.split(".")[-2], domain.split(".")[-1])
  263. tsig, keyalgo = get_tsig_key(named_key_path)
  264. update_tlsa_record(zone, port, hash_digest, tsig, keyalgo, domain,
  265. dns_server=dns_server)
  266. def get_subject_alt_name(certificate):
  267. """
  268. Extracts the subjectAltName entries from a X509 certficiate
  269. :param certificate: the certificate to extract the subjectAltName \
  270. entries from
  271. :type certificate: OpenSSL.crypto.X509
  272. :return: list of hostnames
  273. :rtype: list
  274. """
  275. list = []
  276. for i in range(0, certificate.get_extension_count(), 1):
  277. if certificate.get_extension(i).get_short_name() == b"subjectAltName":
  278. extension_string = str(certificate.get_extension(i))
  279. for entry in extension_string.split(","):
  280. list.append(entry.split(":")[1])
  281. break
  282. return list