# cache.py - caching layer for pynslcd
#
# Copyright (C) 2012-2019 Arthur de Jong
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA

import datetime
import os
import sqlite3
import sys


# TODO: probably create a config table

# FIXME: have some way to remove stale entries from the cache if all items
# from LDAP are queried (perhaps use TTL from all request)


class regroup(object):  # noqa: N801 (this has an iterator name)
    """Regroup the results in the group column by the key columns.

    Get entries from a queryset that has multiple result rows per wanted
    entry by combining multiple values. E.g.

      1, 2, 3
      1, 2, 4
      1, 2, 5

    into

      1, 2, [3, 4, 5]

    Rows that belong to one entry must be adjacent in the input; a key that
    shows up again after a different key starts a new output row. None
    values in the group column are skipped, so an entry whose group column
    only contains None gets an empty list.
    """

    def __init__(self, results, group_by=None, group_column=None):
        """Wrap an iterable of rows.

        results: iterable of indexable rows (e.g. tuples or sqlite3.Row)
        group_by: the row indexes that together form the entry key; None
            (the default) means no key columns, so all rows are merged
            into a single output row
        group_column: the row index whose values are collected in a list
        """
        # accept None (the default) as "no key columns" instead of failing
        # in tuple(None)
        self.group_by = tuple(group_by) if group_by is not None else ()
        self.group_column = group_column
        self.it = iter(results)
        # one shared sentinel so that currkey == tgtkey on the first call
        # to next(), which makes it fetch the first row
        self.tgtkey = self.currkey = self.currvalue = object()

    def keyfunc(self, row):
        """Return the key of the row: the values of the group_by columns."""
        return tuple(row[x] for x in self.group_by)

    def __iter__(self):
        return self

    def next(self):
        """Return the next regrouped row as a list.

        Raises StopIteration when the underlying results are exhausted.
        """
        # currkey only equals tgtkey on the first call (sentinel) or when
        # the previous group ran into the end of the results; in the latter
        # case next(self.it) raises StopIteration again, ending iteration
        while self.currkey == self.tgtkey:
            self.currvalue = next(self.it)  # Exit on StopIteration
            self.currkey = self.keyfunc(self.currvalue)
        self.tgtkey = self.currkey
        # take the first row of the group as the base of the output row
        row = list(self.currvalue)
        # replace the group column with the values of all rows in the group
        row[self.group_column] = list(self._grouper(self.tgtkey))
        return row

    # Python 3 iterator protocol
    def __next__(self):
        return self.next()

    def _grouper(self, tgtkey):
        """Generate the group column values of consecutive rows with tgtkey.

        Leaves currvalue/currkey at the first row of the next group (or at
        the last row when the results are exhausted).
        """
        try:
            while self.currkey == tgtkey:
                value = self.currvalue[self.group_column]
                if value is not None:
                    yield value
                self.currvalue = next(self.it)
                self.currkey = self.keyfunc(self.currvalue)
        except StopIteration:
            # the end of the results also ends the current group
            pass


class Query(object):
    """Helper class to build an SQL query for the cache."""

    def __init__(self, query):
        """Start from a base SELECT statement without a WHERE clause."""
        self.query = query
        self.wheres = []
        self.parameters = []

    def add_where(self, where, parameters):
        """Add a condition (ANDed with the others) and its ? parameters."""
        self.wheres.append(where)
        self.parameters += parameters

    def execute(self, con):
        """Run the query on the connection and return the cursor."""
        query = self.query
        if self.wheres:
            query += ' WHERE ' + ' AND '.join(self.wheres)
        cursor = con.cursor()
        return cursor.execute(query, self.parameters)


class Cache(object):
    """The description of the cache.

    Subclasses must define create_sql (the script run by create()) and may
    override the class attributes below. If a subclass does not set tables,
    a single '<module name>_cache' table is used.
    """

    # custom SELECT used by retrieve(); None selects * from the first table
    retrieve_sql = None
    # maps retrieve() parameter names to SQL conditions; every ? in the
    # condition is bound to the parameter value
    retrieve_by = {}
    # indexes of the result columns that identify an entry
    group_by = ()
    # indexes of the result columns that retrieve() turns into lists
    group_columns = ()

    def __init__(self):
        self.con = _get_connection()
        # name of the module that defines the (sub)class; used as table name
        # prefix and as key column name in the one-to-many tables
        self.db = sys.modules[self.__module__].__name__
        if not hasattr(self, 'tables'):
            self.tables = ['%s_cache' % self.db]
        self.create()

    def create(self):
        """Create the needed tables if necessary (by running create_sql)."""
        self.con.executescript(self.create_sql)

    def store(self, *values):
        """Store the values in the cache for the specified table.

        The order of the values is the order returned by the
        Request.convert() function. Flat values, followed by the current
        time, are written to the first table with INSERT OR REPLACE. Each
        list, tuple or set value replaces the rows for values[0] in the
        following tables of self.tables, in order; those tables hold
        (key, value) pairs.
        """
        # split the values into simple (flat) values and one-to-many values
        simple_values = []
        multi_values = []
        for v in values:
            if isinstance(v, (list, tuple, set)):
                multi_values.append(v)
            else:
                simple_values.append(v)
        # insert the simple values, with the modification time appended
        simple_values.append(datetime.datetime.now())
        args = ', '.join(len(simple_values) * ('?', ))
        self.con.execute('''
            INSERT OR REPLACE INTO %s
            VALUES
            (%s)
            ''' % (self.tables[0], args), simple_values)
        # insert the one-to-many values
        for n, vlist in enumerate(multi_values):
            self.con.execute('''
                DELETE FROM %s
                WHERE `%s` = ?
                ''' % (self.tables[n + 1], self.db), (values[0], ))
            self.con.executemany('''
                INSERT INTO %s
                VALUES
                (?, ?)
                ''' % (self.tables[n + 1]), ((values[0], x) for x in vlist))

    def retrieve(self, parameters):
        """Retrieve all items from the cache based on the parameters supplied.

        parameters maps column (or retrieve_by) names to values; all
        conditions must match. Returns a generator of row lists in which
        the group_columns are regrouped into lists and the last column
        (the modification time) is removed.
        """
        query = Query(self.retrieve_sql or '''
            SELECT *
            FROM %s
            ''' % self.tables[0])
        if parameters:
            for k, v in parameters.items():
                # NOTE(review): k is interpolated into the SQL as a column
                # name; this assumes parameter names come from trusted code
                where = self.retrieve_by.get(
                    k, '`%s`.`%s` = ?' % (self.tables[0], k))
                query.add_where(where, where.count('?') * [v])
        # group by
        # FIXME: find a nice way to turn group_by and group_columns into names
        results = query.execute(self.con)
        group_by = list(self.group_by + self.group_columns)
        # handle the group columns last-to-first: each pass keys on the entry
        # columns plus the group columns before the current one (regroup
        # copies group_by, so popping afterwards does not affect it)
        for column in self.group_columns[::-1]:
            group_by.pop()
            results = regroup(results, group_by, column)
        # strip the mtime from the results
        return (list(x)[:-1] for x in results)

    def __enter__(self):
        # use the sqlite3 connection as transaction context manager
        return self.con.__enter__()

    def __exit__(self, *args):
        return self.con.__exit__(*args)


# the connection to the sqlite database
_connection = None


# FIXME: make thread safe (is this needed the way the caches are initialised?)
def _get_connection():
    """Return the module-wide sqlite3 connection, creating it on first use.

    The database file is /tmp/pynslcd_cache.sqlite; its directory is created
    when missing. The connection converts columns by their declared type,
    returns sqlite3.Row rows and is not restricted to the creating thread.
    Later calls return the same connection object.
    """
    global _connection
    if _connection is not None:
        return _connection
    path = '/tmp/pynslcd_cache.sqlite'
    parent = os.path.dirname(path)
    if not os.path.isdir(parent):
        os.mkdir(parent)
    # NOTE(review): a fixed file name in world-writable /tmp can be
    # pre-created or symlinked by another local user; consider a private
    # directory
    con = sqlite3.connect(path, detect_types=sqlite3.PARSE_DECLTYPES,
                          check_same_thread=False)
    con.row_factory = sqlite3.Row
    # trade durability for speed; the SQL comments below spell out the risks
    con.executescript('''
        -- store temporary tables in memory
        PRAGMA temp_store = MEMORY;
        -- disable sync() on database (corruption on disk failure)
        PRAGMA synchronous = OFF;
        -- put journal in memory (corruption if crash during transaction)
        PRAGMA journal_mode = MEMORY;
        ''')
    _connection = con
    return _connection