helpers.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. """
  2. collection of helper functions used in other modules of acme-updater.
  3. """
  4. # -*- coding: utf8 -*-
  5. import logging
  6. import datetime
  7. import os
  8. import shutil
  9. import hashlib
  10. import OpenSSL
  11. import dns.tsigkeyring
  12. import dns.update
  13. import dns.query
  14. LOGGER = logging.getLogger("acme-updater")
  15. def parse_apache_vhost(file_obj):
  16. """
  17. Parses a given vhost file and extracts the main domain,
  18. the certificate file and the TLS key file.
  19. :param file_obj: file obj pointing to a vhost to parse
  20. :return: list of tuples with domains and found certificates
  21. :rtype: list
  22. """
  23. vhost_started = False
  24. parsed_info = []
  25. cert_path = ""
  26. key_path = ""
  27. main_domain = ""
  28. for line in file_obj:
  29. if "<VirtualHost" in line:
  30. vhost_started = True
  31. elif "</VirtualHost" in line and vhost_started:
  32. vhost_started = False
  33. if cert_path and key_path and main_domain:
  34. parsed_info.append((main_domain, cert_path, key_path))
  35. LOGGER.debug(
  36. "Found vhost with main domain %s, certificate %s and key "
  37. "file %s", main_domain, cert_path, key_path)
  38. cert_path = ""
  39. key_path = ""
  40. main_domain = ""
  41. elif "ServerName" in line:
  42. main_domain = line.strip().rsplit()[1]
  43. elif "SSLCertificateFile" in line:
  44. cert_path = line.strip().rsplit()[1]
  45. elif "SSLCertificateKeyFile" in line:
  46. key_path = line.strip().rsplit()[1]
  47. return parsed_info
  48. def parse_asn1_time(timestamp):
  49. """
  50. parses an ANS1 timestamp as returned by OpenSSL and turns it into a python
  51. datetime object.
  52. :param timestamp: ASN1 timestamp
  53. :type timestamp: str
  54. :return: timestamp as datetime object
  55. :rtype: datetime
  56. """
  57. year = int(timestamp[:4])
  58. month = int(timestamp[4:6])
  59. day = int(timestamp[6:8])
  60. date = datetime.datetime(year, month, day)
  61. return date
  62. def copy_file(source, destination, backup=True):
  63. """
  64. Copies a file from the given source file to
  65. the given destination and optionally creates a copy of the
  66. destination file.
  67. :param source: source file path
  68. :type source: str
  69. :param destination: destination file path
  70. :type destination: str
  71. :param backup: whether to take a backup of the destination file before \
  72. overwriting it
  73. :type backup: bool
  74. :return: success
  75. :rtype: bool
  76. """
  77. backup_file = destination + ".bak_%s" % datetime.datetime.now().strftime(
  78. "%Y%m%d%H%M%S")
  79. if backup:
  80. try:
  81. shutil.copy(destination, backup_file)
  82. except IOError:
  83. LOGGER.error("Creating of backup file for %s failed!", destination)
  84. return False
  85. try:
  86. shutil.copy(source, destination)
  87. except IOError:
  88. LOGGER.error("Copying of file %s to %s failed!",
  89. source, destination)
  90. else:
  91. os.chmod(destination, 0o0644)
  92. os.chown(destination, 0, 0)
  93. return True
  94. def check_renewal(cert, cert_path):
  95. """
  96. Checks if the certificate has been renewed.
  97. :param cert: the certificate that needs to be checked
  98. :type cert: OpenSSL.crypto.X509
  99. :param cert_path: absolute path to the certificate file
  100. :type cert_path: str
  101. :return: renewal status
  102. :rtype: bool
  103. """
  104. try:
  105. with open(cert_path, "r") as acme_cert_file:
  106. acme_cert_text = acme_cert_file.read()
  107. except IOError:
  108. LOGGER.error("Could not open certificate %s in acme "
  109. "state directory", cert_path)
  110. else:
  111. x509_acme_cert = OpenSSL.crypto.load_certificate(
  112. OpenSSL.crypto.FILETYPE_PEM, acme_cert_text
  113. )
  114. expiry_date = x509_acme_cert.get_notAfter().decode("utf-8")
  115. expiry_datetime = parse_asn1_time(expiry_date)
  116. if expiry_datetime < datetime.datetime.utcnow():
  117. LOGGER.warning("Certificate %s is expired and no newer "
  118. "one is available, bailing out!", cert_path)
  119. return False
  120. else:
  121. serial_current_cert = cert.get_serial_number()
  122. serial_acme_cert = x509_acme_cert.get_serial_number()
  123. if serial_current_cert == serial_acme_cert:
  124. LOGGER.debug("Cert %s matches with the one "
  125. "installed, nothing to do.", cert_path)
  126. return False
  127. else:
  128. return True
  129. def create_tlsa_hash(cert):
  130. """
  131. Creates an tlsa 3 1 1 hash to create TLSA records for a given certificate
  132. :param cert: certificate to be used
  133. :type cert: OpenSSL.crypto.X509
  134. :return: sha256 has of the public key
  135. :rtype: str
  136. """
  137. pubkey = cert.get_pubkey()
  138. pubkey_der = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
  139. pubkey)
  140. sha256 = hashlib.sha256()
  141. sha256.update(pubkey_der)
  142. hexdigest = sha256.hexdigest()
  143. return hexdigest
  144. def get_tsig_key(named_key_path):
  145. """
  146. Reads the named session key and generates a keyring object for it.
  147. :param named_key_path: Path to the named session key
  148. :type named_key_path: str
  149. :return: keyring, algorithm
  150. :rtype: tuple
  151. """
  152. key_name = None
  153. key_algorithm = None
  154. secret = None
  155. try:
  156. with open(named_key_path, "r") as bind_key:
  157. for line in bind_key:
  158. if "key" in line:
  159. key_name = line.split(" ")[1].strip("\"")
  160. elif "algorithm" in line:
  161. key_algorithm = line.strip().split(" ")[1].strip(";")
  162. elif "secret" in line:
  163. secret = line.strip().split(" ")[1].strip("\"").strip(";")
  164. except IOError:
  165. LOGGER.error("Error while opening the bind session key")
  166. return None, None
  167. else:
  168. if key_name and key_algorithm and secret:
  169. keyring = dns.tsigkeyring.from_text({
  170. key_name: secret
  171. })
  172. return keyring, key_algorithm
  173. else:
  174. return None, None
  175. def update_tlsa_record(zone, tlsa_port, digest, keyring, keyalgorithm,
  176. subdomain="", ttl=300, protocol="tcp"):
  177. """
  178. Updates the tlsa record on the DNS server.
  179. :param zone: Zone of the (sub) domain
  180. :type zone: str
  181. :param tlsa_port: port for the tlsa record
  182. :type tlsa_port: str
  183. :param digest: cryptographic hash of the certificate public key
  184. :type digest: str
  185. :param keyring: keyring object
  186. :type keyring: dict
  187. :param keyalgorithm: algorithm used for the tsig key
  188. :type keyalgorithm: str
  189. :param subdomain: subdomain to create the tlsa record for
  190. :type subdomain: str
  191. :param ttl: TTL to use for the TLSA record
  192. :type ttl: int
  193. :param protocol: protocol for the TLSA record
  194. :type protocol: str
  195. :returns: response of the operation
  196. :rtype: dns.message.Message
  197. """
  198. update = dns.update.Update(zone, keyring=keyring,
  199. keyalgorithm=keyalgorithm)
  200. tlsa_content = "3 1 1 %s" % digest
  201. if subdomain:
  202. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, subdomain)
  203. else:
  204. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, zone)
  205. update.replace(tlsa_record, ttl, "tlsa", tlsa_content)
  206. response = dns.query.tcp(update, 'localhost')
  207. return response