Browse Source

convert the Db class to singleton to avoid passing to all functions that need the object

Helmut Pozimski 6 years ago
parent
commit
508f7e5385
4 changed files with 86 additions and 102 deletions
  1. 7 9
      lib_stov/configuration.py
  2. 16 3
      lib_stov/database.py
  3. 15 17
      lib_stov/main.py
  4. 48 73
      lib_stov/program.py

+ 7 - 9
lib_stov/configuration.py

@@ -62,7 +62,7 @@ class Conf(object):
             "check_title": "no"
         }
 
-        self.__explanations = {
+        self._explanations = {
             "database": _("the name of your database file"),
             "downloaddir": _("the directory where downloaded videos are "
                              "saved"),
@@ -171,9 +171,8 @@ class Conf(object):
                 or int(self.values["config_version"]) < current_version:
             self.values["config_version"] = str(current_version)
             return False
-        else:
-            self.values["config_version"] = current_version
-            return True
+        self.values["config_version"] = current_version
+        return True
 
     def update_config(self):
         """Update the configuration to the latest version"""
@@ -192,9 +191,8 @@ class Conf(object):
                 int(currentdbversion):
             self.values["db_version"] = str(currentdbversion)
             return False
-        else:
-            self.values["db_version"] = str(currentdbversion)
-            return True
+        self.values["db_version"] = str(currentdbversion)
+        return True
 
     def assist(self):
         """ Ask the user to set all required configuration parameters """
@@ -203,8 +201,8 @@ class Conf(object):
                 "configuration of stov. \nThe default value will be "
                 "displayed in brackets.\n"
                 "Please specify now :\n"))
-        for value in self.__explanations:
-            print(self.__explanations[value] + " [" + self.values[value] +
+        for value in self._explanations:
+            print(self._explanations[value] + " [" + self.values[value] +
                   "]:" +
                   " ")
             user_input = input()

+ 16 - 3
lib_stov/database.py

@@ -28,10 +28,12 @@ LOGGER = logging.getLogger("stov")
 
 
 class Db(object):
-    """This class is used to cosntruct the module which will take care of all
+    """This class is used to construct the module which will take care of all
     database related operations like opening the database, reading from and
     writing to it.
     """
+    _instance = None
+
     def __init__(self, path, version):
         """Constructor of the db class, populates the object with the relevant
         attributes, connects to the database and creates it if asked to.
@@ -52,6 +54,11 @@ class Db(object):
 
         self.__connection.close()
 
+    def __new__(cls, *args, **kwargs):
+        if not Db._instance:
+            Db._instance = super(Db, cls).__new__(cls)
+        return Db._instance
+
     def _execute_statement(self, statement, argument=None):
         """Executes a statement, works as a wrapper around cursor execute."""
 
@@ -70,6 +77,13 @@ class Db(object):
             self.__connection.commit()
             return result
 
+    @staticmethod
+    def get_instance():
+        """ Return the singleton instance of Db"""
+        if Db._instance:
+            return Db._instance
+        return None
+
     def populate(self):
         """Populates the database with the initial structure."""
 
@@ -206,8 +220,7 @@ class Db(object):
         result = cursor.fetchall()
         if not result:
             return False
-        else:
-            return True
+        return True
 
     def insert_video(self, video, subscription_id):
         """Inserts a video with the given data into the database"""

+ 15 - 17
lib_stov/main.py

@@ -34,40 +34,38 @@ def main():
     helpers.create_lock()
     conf = helpers.setup_configuration(arguments)
     logger = logging.getLogger("stov")
-    database = helpers.setup_database(conf)
+    helpers.setup_database(conf)
     helpers.find_youtubedl(conf)
-    program.initialize_sites(database)
+    program.initialize_sites()
     if arguments.add:
         if arguments.site:
-            program.add_subscription(conf, database,
-                                     arguments.channel,
+            program.add_subscription(conf, arguments.channel,
                                      arguments.searchparameter,
-                                     arguments.playlist,
-                                     arguments.site)
+                                     arguments.playlist, arguments.site)
         else:
-            program.add_subscription(conf, database, arguments.channel,
+            program.add_subscription(conf, arguments.channel,
                                      arguments.searchparameter,
                                      arguments.playlist)
     elif arguments.lssites:
-        program.list_sites(database)
+        program.list_sites()
     elif arguments.list:
-        program.list_subscriptions(conf, database)
+        program.list_subscriptions(conf)
     elif arguments.deleteid:
-        program.delete_subscription(database, arguments.deleteid)
+        program.delete_subscription(arguments.deleteid)
     elif arguments.update is not None:
-        program.update_subscriptions(database, conf, arguments.update)
+        program.update_subscriptions(conf, arguments.update)
     elif arguments.download is not None:
-        program.download_notify(database, conf, arguments.download)
+        program.download_notify(conf, arguments.download)
     elif arguments.subscriptionid:
-        program.list_videos(database, conf, arguments.subscriptionid)
+        program.list_videos(conf, arguments.subscriptionid)
     elif arguments.catchup:
-        program.catchup(database, arguments.catchup)
+        program.catchup(arguments.catchup)
     elif arguments.cleanup:
-        program.clean_database(database, conf)
+        program.clean_database(conf)
     elif arguments.enableid:
-        program.change_subscription_state(database, arguments.enableid, True)
+        program.change_subscription_state(arguments.enableid, True)
     elif arguments.disableid:
-        program.change_subscription_state(database, arguments.disableid, False)
+        program.change_subscription_state(arguments.disableid, False)
     elif arguments.license:
         program.print_license()
     elif arguments.version:

+ 48 - 73
lib_stov/program.py

@@ -29,19 +29,18 @@ from email.mime.multipart import MIMEMultipart
 
 from lib_stov import subscription
 from lib_stov import stov_exceptions
+from lib_stov.database import Db
 
 LOGGER = logging.getLogger("stov")
+DATABASE = Db.get_instance()
 
 
-def add_subscription(conf, database, channel="",
-                     search="", playlist="", site="youtube"):
+def add_subscription(conf, channel="", search="", playlist="", site="youtube"):
     """
     Takes care of adding a new subscription to the database.
 
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param site: site the subscription is about to be created for
     :type site: str
     :param channel: optional channel name
@@ -85,12 +84,12 @@ def add_subscription(conf, database, channel="",
         sys.exit(1)
     try:
         subscription_data = new_subscription.add_sub()
-        site_id = database.get_site_id(subscription_data[6])
+        site_id = DATABASE.get_site_id(subscription_data[6])
         new_sub_data = (subscription_data[0], subscription_data[1],
                         subscription_data[2], subscription_data[3],
                         subscription_data[4], subscription_data[5],
                         site_id)
-        subscription_id = database.insert_subscription(new_sub_data)
+        subscription_id = DATABASE.insert_subscription(new_sub_data)
         new_subscription.set_id(subscription_id)
     except stov_exceptions.DBWriteAccessFailedException as error:
         LOGGER.error(error)
@@ -107,10 +106,10 @@ def add_subscription(conf, database, channel="",
     except stov_exceptions.NoDataFromYoutubeAPIException as error:
         LOGGER.error(error)
     for video in new_subscription.parsed_response.videos:
-        if not database.video_in_database(video.video_id):
+        if not DATABASE.video_in_database(video.video_id):
             if new_subscription.check_string_match(video):
                 try:
-                    database.insert_video(video, new_subscription.get_id())
+                    DATABASE.insert_video(video, new_subscription.get_id())
                 except stov_exceptions.DBWriteAccessFailedException as error:
                     LOGGER.error(error)
                     sys.exit(1)
@@ -121,16 +120,14 @@ def add_subscription(conf, database, channel="",
                 _(" successfully added"))
 
 
-def list_subscriptions(conf, database):
+def list_subscriptions(conf):
     """
     Prints a list of subscriptions from the database.
 
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
-    :param database: database object
-    :type database: lib_stov.database.Db
     """
-    subscriptions_list = database.get_subscriptions(conf)
+    subscriptions_list = DATABASE.get_subscriptions(conf)
     sub_state = None
     if subscriptions_list:
         LOGGER.info(_("ID Title Site"))
@@ -145,17 +142,15 @@ def list_subscriptions(conf, database):
         LOGGER.info(_("No subscriptions added yet, add one!"))
 
 
-def delete_subscription(database, sub_id):
+def delete_subscription(sub_id):
     """
     Deletes a specified subscription from the database
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param sub_id: ID of the subscription to be deleted
     :type sub_id: int
     """
     try:
-        database.delete_subscription(sub_id)
+        DATABASE.delete_subscription(sub_id)
     except stov_exceptions.SubscriptionNotFoundException as error:
         LOGGER.error(error)
         sys.exit(1)
@@ -166,21 +161,19 @@ def delete_subscription(database, sub_id):
         LOGGER.info(_("Subscription deleted successfully!"))
 
 
-def update_subscriptions(database, conf, subscriptions=None):
+def update_subscriptions(conf, subscriptions=None):
     """
     Updates data about videos in a subscription.
 
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param subscriptions: list of subscriptions to update
     :type subscriptions: list
     """
-    subscriptions_list = get_subscriptions(conf, database, subscriptions)
+    subscriptions_list = get_subscriptions(conf, subscriptions)
     for element in subscriptions_list:
         LOGGER.debug(_("Updating subscription %s"), element.get_title())
-        videos = database.get_videos(element.get_id(), conf)
+        videos = DATABASE.get_videos(element.get_id(), conf)
         element.gather_videos(videos)
         try:
             element.update_data()
@@ -189,10 +182,10 @@ def update_subscriptions(database, conf, subscriptions=None):
         except stov_exceptions.NoDataFromYoutubeAPIException as error:
             LOGGER.error(error)
         for video in element.parsed_response.videos:
-            if not database.video_in_database(video.video_id):
+            if not DATABASE.video_in_database(video.video_id):
                 if element.check_string_match(video):
                     try:
-                        database.insert_video(video, element.get_id())
+                        DATABASE.insert_video(video, element.get_id())
                     except stov_exceptions.DBWriteAccessFailedException as \
                             error:
                         LOGGER.error(error)
@@ -202,14 +195,12 @@ def update_subscriptions(database, conf, subscriptions=None):
                                        "database."), video.title)
 
 
-def download_videos(database, conf, subscriptions=None):
+def download_videos(conf, subscriptions=None):
     """
     Downloads videos that haven't been previously downloaded.
 
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param subscriptions: list of subscriptions to consider for downloading
     :type subscriptions: list
     :return: tuple containing (in that order) downloaded videos, failed \
@@ -217,26 +208,26 @@ def download_videos(database, conf, subscriptions=None):
     :rtype: tuple
     """
     video_titles = []
-    subscriptions_list = get_subscriptions(conf, database, subscriptions)
+    subscriptions_list = get_subscriptions(conf, subscriptions)
     videos_downloaded = 0
     videos_failed = 0
     for sub in subscriptions_list:
-        videos = database.get_videos(sub.get_id(), conf)
+        videos = DATABASE.get_videos(sub.get_id(), conf)
         sub.gather_videos(videos)
         try:
             sub.download_videos()
         except stov_exceptions.SubscriptionDisabledException as error:
             LOGGER.debug(error)
         for entry in sub.downloaded_videos:
-            database.update_video_download_status(entry.get_id(), 1)
+            DATABASE.update_video_download_status(entry.get_id(), 1)
             video_titles.append(entry.title)
         videos_downloaded = len(video_titles)
         videos_failed = videos_failed + sub.failed_videos_count
         for video in sub.failed_videos:
             try:
-                database.update_video_fail_count(video.failcnt, video.get_id())
+                DATABASE.update_video_fail_count(video.failcnt, video.get_id())
                 if video.failcnt >= int(conf.values["maxfails"]):
-                    database.disable_failed_video(video.get_id())
+                    DATABASE.disable_failed_video(video.get_id())
             except stov_exceptions.DBWriteAccessFailedException as error:
                 LOGGER.error(error)
                 sys.exit(1)
@@ -320,19 +311,17 @@ def send_email(conf, msg):
         server_connection.quit()
 
 
-def list_videos(database, conf, sub_id):
+def list_videos(conf, sub_id):
     """
     Lists all videos in a specified subscription
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
     :param sub_id: ID of the subscription
     :type sub_id: int
     """
     try:
-        data = database.get_subscription(sub_id)
+        data = DATABASE.get_subscription(sub_id)
     except stov_exceptions.DBWriteAccessFailedException as error:
         LOGGER.error(error)
         sys.exit(1)
@@ -346,7 +335,7 @@ def list_videos(database, conf, sub_id):
                                    directory=data[0][5],
                                    disabled=data[0][6],
                                    site=data[0][7], conf=conf)
-            videos = database.get_videos(sub.get_id(), conf)
+            videos = DATABASE.get_videos(sub.get_id(), conf)
             sub.gather_videos(videos)
             videos_list = sub.print_videos()
             for video in videos_list:
@@ -356,24 +345,22 @@ def list_videos(database, conf, sub_id):
                            "try again."))
 
 
-def catchup(database, sub_id):
+def catchup(sub_id):
     """
     Marks all videos in a subscription as downloaded
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param sub_id: ID of the subscription
     :type sub_id: int
     """
     try:
-        sub_data = database.get_subscription_title(sub_id)
+        sub_data = DATABASE.get_subscription_title(sub_id)
     except stov_exceptions.DBWriteAccessFailedException as error:
         LOGGER.error(error)
         sys.exit(1)
     else:
         if sub_data:
             try:
-                database.mark_video_downloaded(sub_id)
+                DATABASE.mark_video_downloaded(sub_id)
             except stov_exceptions.DBWriteAccessFailedException as error:
                 LOGGER.error(error)
         else:
@@ -381,47 +368,43 @@ def catchup(database, sub_id):
                            "please check if the ID given is correct."))
 
 
-def clean_database(database, conf):
+def clean_database(conf):
     """
     Initiates a database cleanup, deleting all videos that are no longer
     in the scope of the query and vacuuming the database to free up space.
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
     """
-    subscription_list = database.get_subscriptions(conf)
+    subscription_list = DATABASE.get_subscriptions(conf)
     for element in subscription_list:
-        videos = database.get_videos(element.get_id(), conf)
+        videos = DATABASE.get_videos(element.get_id(), conf)
         element.check_and_delete(videos)
         for delete_video in element.to_delete:
             LOGGER.debug(_("Deleting video %s from "
                            "database"), delete_video.title)
             try:
-                database.delete_video(delete_video.get_id())
+                DATABASE.delete_video(delete_video.get_id())
             except stov_exceptions.DBWriteAccessFailedException as error:
                 LOGGER.error(error)
                 sys.exit(1)
     try:
-        database.vacuum()
+        DATABASE.vacuum()
     except stov_exceptions.DBWriteAccessFailedException as error:
         LOGGER.error(error)
         sys.exit(1)
 
 
-def change_subscription_state(database, sub_id, enable=False):
+def change_subscription_state(sub_id, enable=False):
     """
     Enables or disables a subscription.
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param sub_id: ID of the subscription
     :type sub_id: int
     :param enable: whether to enable or disable the subscription
     :type enable: bool
     """
-    subscription_state = database.get_subscription(sub_id)
+    subscription_state = DATABASE.get_subscription(sub_id)
     try:
         if enable:
             if int(subscription_state[0][6]) == 0:
@@ -429,7 +412,7 @@ def change_subscription_state(database, sub_id, enable=False):
                              sub_id)
             elif int(subscription_state[0][6]) == 1:
                 try:
-                    database.change_subscription_state(sub_id, 0)
+                    DATABASE.change_subscription_state(sub_id, 0)
                 except stov_exceptions.DBWriteAccessFailedException as error:
                     LOGGER.error(error)
                     sys.exit(1)
@@ -441,7 +424,7 @@ def change_subscription_state(database, sub_id, enable=False):
                              sub_id)
             elif int(subscription_state[0][6]) == 0:
                 try:
-                    database.change_subscription_state(sub_id, 1)
+                    DATABASE.change_subscription_state(sub_id, 1)
                 except stov_exceptions.DBWriteAccessFailedException as error:
                     LOGGER.error(error)
                     sys.exit(1)
@@ -471,19 +454,17 @@ def print_license():
         along with stov.  If not, see <http://www.gnu.org/licenses/>.""")
 
 
-def download_notify(database, conf, subscriptions=None):
+def download_notify(conf, subscriptions=None):
     """
     starts an update of not yet downloaded videos and notifies the user
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
     :param subscriptions: list of subscriptions to consider for downloading
     :type subscriptions: list
     """
     videos_downloaded, videos_failed, video_titles = \
-        download_videos(database, conf, subscriptions)
+        download_videos(conf, subscriptions)
     if videos_downloaded > 0 and conf.values["notify"] == "yes":
         msg = compose_email(conf, videos_downloaded, video_titles)
         send_email(conf, msg)
@@ -502,49 +483,43 @@ def download_notify(database, conf, subscriptions=None):
                            "parameter in your configuration."))
 
 
-def initialize_sites(database):
+def initialize_sites():
     """
     Adds sites to the database if they are not in there yet.
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     """
     supported_sites = ["youtube", "zdf_mediathek", "twitch"]
-    sites = database.get_sites()
+    sites = DATABASE.get_sites()
     for site in supported_sites:
         site_found = False
         for result in sites:
             if site in result:
                 site_found = True
         if not site_found:
-            database.add_site(site)
+            DATABASE.add_site(site)
     for site in sites:
         if site[1] not in supported_sites:
-            database.remove_site(site[1])
+            DATABASE.remove_site(site[1])
 
 
-def list_sites(database):
+def list_sites():
     """
     Lists the currently supported sites.
 
-    :param database: database object
-    :type database: lib_stov.database.Db
     """
-    sites = database.get_sites()
+    sites = DATABASE.get_sites()
     LOGGER.info(_("Sites currently supported by stov:"))
     for entry in sites:
         LOGGER.info(entry[1])
 
 
-def get_subscriptions(conf, database, subscriptions=None):
+def get_subscriptions(conf, subscriptions=None):
     """
     Retrieves all or only specific subscriptions from the database and
     returns them as a list of subscription objects.
 
     :param conf: configuration object
     :type conf: lib_stov.configuration.Conf
-    :param database: database object
-    :type database: lib_stov.database.Db
     :param subscriptions: list of subscriptions to retrieve
     :type subscriptions: list
     :return: list of subscription objects
@@ -553,7 +528,7 @@ def get_subscriptions(conf, database, subscriptions=None):
     if subscriptions:
         subscriptions_list = []
         for element in subscriptions:
-            data = database.get_subscription(element)
+            data = DATABASE.get_subscription(element)
             if data:
                 sub = subscription.Sub(subscription_id=data[0][0],
                                        title=data[0][1],
@@ -569,5 +544,5 @@ def get_subscriptions(conf, database, subscriptions=None):
                     _("Invalid subscription, please check the list and "
                       "try again."))
     else:
-        subscriptions_list = database.get_subscriptions(conf)
+        subscriptions_list = DATABASE.get_subscriptions(conf)
     return subscriptions_list