helpers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # This file is part of acme-updater, written by Helmut Pozimski 2016-2017.
  2. #
  3. # stov is free software: you can redistribute it and/or modify
  4. # it under the terms of the GNU General Public License as published by
  5. # the Free Software Foundation, version 2 of the License.
  6. #
  7. # stov is distributed in the hope that it will be useful,
  8. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. # GNU General Public License for more details.
  11. #
  12. # You should have received a copy of the GNU General Public License
  13. # along with stov. If not, see <http://www.gnu.org/licenses/>.
  14. # -*- coding: utf8 -*-
  15. """
  16. collection of helper functions used in other modules of acme-updater.
  17. """
  18. import logging
  19. import datetime
  20. import os
  21. import shutil
  22. import hashlib
  23. import OpenSSL
  24. import dns.tsigkeyring
  25. import dns.update
  26. import dns.query
  27. LOGGER = logging.getLogger("acme-updater")
  28. def parse_apache_vhost(file_obj):
  29. """
  30. Parses a given vhost file and extracts the main domain,
  31. the certificate file, the TLS key file and all domains contained
  32. within the vhost.
  33. :param file_obj: file obj pointing to a vhost to parse
  34. :return: list of tuples with domains and found certificates
  35. :rtype: list
  36. """
  37. vhost_started = False
  38. parsed_info = []
  39. cert_path = ""
  40. key_path = ""
  41. main_domain = ""
  42. domains = set()
  43. for line in file_obj:
  44. if "<VirtualHost" in line:
  45. vhost_started = True
  46. elif "</VirtualHost" in line and vhost_started:
  47. vhost_started = False
  48. if cert_path and key_path and main_domain and domains:
  49. parsed_info.append((main_domain, cert_path, key_path, domains))
  50. LOGGER.debug(
  51. "Found vhost with main domain %s, certificate %s and key "
  52. "file %s", main_domain, cert_path, key_path)
  53. cert_path = ""
  54. key_path = ""
  55. main_domain = ""
  56. elif "ServerName" in line:
  57. main_domain = line.strip().rsplit()[1]
  58. domains.add(line.strip().rsplit()[1])
  59. elif "SSLCertificateFile" in line:
  60. cert_path = line.strip().rsplit()[1]
  61. elif "SSLCertificateKeyFile" in line:
  62. key_path = line.strip().rsplit()[1]
  63. elif "ServerAlias" in line:
  64. for domain in line.strip().rsplit()[1].split(" "):
  65. domains.add(domain)
  66. return parsed_info
  67. def parse_asn1_time(timestamp):
  68. """
  69. parses an ANS1 timestamp as returned by OpenSSL and turns it into a python
  70. datetime object.
  71. :param timestamp: ASN1 timestamp
  72. :type timestamp: str
  73. :return: timestamp as datetime object
  74. :rtype: datetime
  75. """
  76. year = int(timestamp[:4])
  77. month = int(timestamp[4:6])
  78. day = int(timestamp[6:8])
  79. date = datetime.datetime(year, month, day)
  80. return date
  81. def copy_file(source, destination, backup=True):
  82. """
  83. Copies a file from the given source file to
  84. the given destination and optionally creates a copy of the
  85. destination file.
  86. :param source: source file path
  87. :type source: str
  88. :param destination: destination file path
  89. :type destination: str
  90. :param backup: whether to take a backup of the destination file before \
  91. overwriting it
  92. :type backup: bool
  93. :return: success
  94. :rtype: bool
  95. """
  96. if backup:
  97. if not create_backup_copy(destination):
  98. return False
  99. try:
  100. shutil.copy(source, destination)
  101. except IOError:
  102. LOGGER.error("Copying of file %s to %s failed!",
  103. source, destination)
  104. else:
  105. os.chmod(destination, 0o0644)
  106. os.chown(destination, 0, 0)
  107. return True
  108. def create_backup_copy(source):
  109. """
  110. creates a backup file of a specified source file.
  111. :param source: source file path
  112. :type source: str
  113. :return: success
  114. :rtype: bool
  115. """
  116. backup_file = source + ".bak_%s" % datetime.datetime.now().strftime(
  117. "%Y%m%d%H%M%S")
  118. try:
  119. shutil.copy(source, backup_file)
  120. except IOError:
  121. LOGGER.error("Creating of backup file for %s failed!", source)
  122. return False
  123. else:
  124. return True
  125. def check_renewal(cert, cert_path):
  126. """
  127. Checks if the certificate has been renewed.
  128. :param cert: the certificate that needs to be checked
  129. :type cert: OpenSSL.crypto.X509
  130. :param cert_path: absolute path to the certificate file
  131. :type cert_path: str
  132. :return: renewal status
  133. :rtype: bool
  134. """
  135. try:
  136. with open(cert_path, "r") as acme_cert_file:
  137. acme_cert_text = acme_cert_file.read()
  138. except IOError:
  139. LOGGER.error("Could not open certificate %s in acme "
  140. "state directory", cert_path)
  141. else:
  142. x509_acme_cert = OpenSSL.crypto.load_certificate(
  143. OpenSSL.crypto.FILETYPE_PEM, acme_cert_text
  144. )
  145. expiry_date = x509_acme_cert.get_notAfter().decode("utf-8")
  146. expiry_datetime = parse_asn1_time(expiry_date)
  147. if expiry_datetime < datetime.datetime.utcnow():
  148. LOGGER.warning("Certificate %s is expired and no newer "
  149. "one is available, bailing out!", cert_path)
  150. return False
  151. else:
  152. serial_current_cert = cert.get_serial_number()
  153. serial_acme_cert = x509_acme_cert.get_serial_number()
  154. if serial_current_cert == serial_acme_cert:
  155. LOGGER.debug("Cert %s matches with the one "
  156. "installed, nothing to do.", cert_path)
  157. return False
  158. else:
  159. return True
  160. def create_tlsa_hash(cert):
  161. """
  162. Creates an tlsa 3 1 1 hash to create TLSA records for a given certificate
  163. :param cert: certificate to be used
  164. :type cert: OpenSSL.crypto.X509
  165. :return: sha256 has of the public key
  166. :rtype: str
  167. """
  168. pubkey = cert.get_pubkey()
  169. pubkey_der = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
  170. pubkey)
  171. sha256 = hashlib.sha256()
  172. sha256.update(pubkey_der)
  173. hexdigest = sha256.hexdigest()
  174. return hexdigest
  175. def get_tsig_key(named_key_path):
  176. """
  177. Reads the named session key and generates a keyring object for it.
  178. :param named_key_path: Path to the named session key
  179. :type named_key_path: str
  180. :return: keyring, algorithm
  181. :rtype: tuple
  182. """
  183. key_name = None
  184. key_algorithm = None
  185. secret = None
  186. try:
  187. with open(named_key_path, "r") as bind_key:
  188. for line in bind_key:
  189. if "key" in line:
  190. key_name = line.split(" ")[1].strip("\"")
  191. elif "algorithm" in line:
  192. key_algorithm = line.strip().split(" ")[1].strip(";")
  193. elif "secret" in line:
  194. secret = line.strip().split(" ")[1].strip("\"").strip(";")
  195. except IOError:
  196. LOGGER.error("Error while opening the bind session key")
  197. return None, None
  198. else:
  199. if key_name and key_algorithm and secret:
  200. keyring = dns.tsigkeyring.from_text({
  201. key_name: secret
  202. })
  203. return keyring, key_algorithm
  204. else:
  205. return None, None
  206. def update_tlsa_record(zone, tlsa_port, digest, keyring, keyalgorithm,
  207. subdomain="", ttl=300, protocol="tcp"):
  208. """
  209. Updates the tlsa record on the DNS server.
  210. :param zone: Zone of the (sub) domain
  211. :type zone: str
  212. :param tlsa_port: port for the tlsa record
  213. :type tlsa_port: str
  214. :param digest: cryptographic hash of the certificate public key
  215. :type digest: str
  216. :param keyring: keyring object
  217. :type keyring: dict
  218. :param keyalgorithm: algorithm used for the tsig key
  219. :type keyalgorithm: str
  220. :param subdomain: subdomain to create the tlsa record for
  221. :type subdomain: str
  222. :param ttl: TTL to use for the TLSA record
  223. :type ttl: int
  224. :param protocol: protocol for the TLSA record
  225. :type protocol: str
  226. :returns: response of the operation
  227. :rtype: dns.message.Message
  228. """
  229. update = dns.update.Update(zone, keyring=keyring,
  230. keyalgorithm=keyalgorithm)
  231. tlsa_content = "3 1 1 %s" % digest
  232. if subdomain:
  233. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, subdomain)
  234. else:
  235. tlsa_record = "_%s._%s.%s." % (tlsa_port, protocol, zone)
  236. update.replace(tlsa_record, ttl, "tlsa", tlsa_content)
  237. response = dns.query.tcp(update, 'localhost')
  238. return response
  239. def get_log_level(input_level=""):
  240. """
  241. Determines the log level to use based on a string.
  242. :param input_level: String representing the desired log level.
  243. :type input_level: str
  244. :return: corresponding log level of the logging module
  245. :rtype: int
  246. """
  247. if input_level.lower() == "debug":
  248. return logging.DEBUG
  249. elif input_level.lower() == "error":
  250. return logging.ERROR
  251. else:
  252. return logging.INFO
  253. def create_tlsa_records(domain, port, certificate, named_key_path):
  254. """
  255. Creates tlsa records for the specified (sub-)domain
  256. :param domain: (sub-)domain the records are to be created for
  257. :type domain: str
  258. :param port: port to use for the record
  259. :type port: str
  260. :param certificate: certificate object used for record creation
  261. :type certificate: OpenSSL.crypto.X509
  262. :param named_key_path: path to the named session key
  263. :type named_key_path: str
  264. """
  265. hash_digest = create_tlsa_hash(certificate)
  266. zone = "%s.%s" % (domain.split(".")[-2], domain.split(".")[-1])
  267. tsig, keyalgo = get_tsig_key(named_key_path)
  268. update_tlsa_record(zone, port, hash_digest, tsig, keyalgo, domain)
  269. def get_subject_alt_name(certificate):
  270. """
  271. Extracts the subjectAltName entries from a X509 certficiate
  272. :param certificate: the certificate to extract the subjectAltName \
  273. entries from
  274. :type certificate: OpenSSL.crypto.X509
  275. :return: list of hostnames
  276. :rtype: list
  277. """
  278. list = []
  279. for i in range(0, certificate.get_extension_count(), 1):
  280. if certificate.get_extension(i).get_short_name() == b"subjectAltName":
  281. extension_string = str(certificate.get_extension(i))
  282. for entry in extension_string.split(","):
  283. list.append(entry.split(":")[1])
  284. break
  285. return list