diff options
author | Arthur de Jong <arthur@arthurdejong.org> | 2013-08-12 22:42:56 +0200 |
---|---|---|
committer | Arthur de Jong <arthur@arthurdejong.org> | 2013-08-17 12:31:36 +0200 |
commit | d66162ad308878d2f3fb505a05742798283a8854 (patch) | |
tree | a7736f6cc1a1c065f69fb02ee49dfc8054d32886 | |
parent | bfe22cc93563d86c8c35cd068810d7f8ac2dee33 (diff) |
Use retrieve_by, group_by and group_columns in the cache
This removes custom retrieve() functions and Query classes from the
database modules and uses retrieve_sql, retrieve_by, group_by and
group_columns to make a custom retrieval query.
In the cache module this completely replaces how the query grouping is
done. The Query class is now only used inside the cache and the
CnAliasedQuery, RowGrouper and related classes have been removed.
-rw-r--r-- | pynslcd/alias.py | 16 | ||||
-rw-r--r-- | pynslcd/cache.py | 130 | ||||
-rw-r--r-- | pynslcd/group.py | 18 | ||||
-rw-r--r-- | pynslcd/host.py | 51 | ||||
-rw-r--r-- | pynslcd/netgroup.py | 15 | ||||
-rw-r--r-- | pynslcd/network.py | 51 | ||||
-rw-r--r-- | pynslcd/protocol.py | 24 | ||||
-rw-r--r-- | pynslcd/rpc.py | 24 | ||||
-rw-r--r-- | pynslcd/service.py | 60 |
9 files changed, 229 insertions, 160 deletions
diff --git a/pynslcd/alias.py b/pynslcd/alias.py index d5ae390..371ac2e 100644 --- a/pynslcd/alias.py +++ b/pynslcd/alias.py @@ -60,11 +60,17 @@ class Cache(cache.Cache): ON `alias_member_cache`.`alias` = `alias_cache`.`cn` ''' - def retrieve(self, parameters): - query = cache.Query(self.retrieve_sql, parameters) - # return results, returning the members as a list - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('rfc822MailMember', )): - yield row['cn'], row['rfc822MailMember'] + retrieve_by = dict( + rfc822MailMember=''' + `cn` IN ( + SELECT `a`.`alias` + FROM `alias_member_cache` `a` + WHERE `a`.`rfc822MailMember` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # rfc822MailMember class AliasRequest(common.Request): diff --git a/pynslcd/cache.py b/pynslcd/cache.py index 796bef4..3974fef 100644 --- a/pynslcd/cache.py +++ b/pynslcd/cache.py @@ -30,18 +30,50 @@ import sqlite3 # FIXME: have some way to remove stale entries from the cache if all items from LDAP are queried (perhas use TTL from all request) +class regroup(object): + + def __init__(self, results, group_by=None, group_column=None): + """Regroup the results in the group column by the key columns.""" + self.group_by = tuple(group_by) + self.group_column = group_column + self.it = iter(results) + self.tgtkey = self.currkey = self.currvalue = object() + + def keyfunc(self, row): + return tuple(row[x] for x in self.group_by) + + def __iter__(self): + return self + + def next(self): + # find a start row + while self.currkey == self.tgtkey: + self.currvalue = next(self.it) # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) + self.tgtkey = self.currkey + # turn the result row into a list of columns + row = list(self.currvalue) + # replace the group column + row[self.group_column] = list(self._grouper(self.tgtkey)) + return row + + def _grouper(self, tgtkey): + """Generate the group columns.""" + while self.currkey == tgtkey: + value = 
self.currvalue[self.group_column] + if value is not None: + yield value + self.currvalue = next(self.it) # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) + + class Query(object): + """Helper class to build an SQL query for the cache.""" - def __init__(self, query, parameters=None): + def __init__(self, query): self.query = query self.wheres = [] self.parameters = [] - if parameters: - for k, v in parameters.items(): - self.add_where('`%s` = ?' % k, [v]) - - def add_query(self, query): - self.query += ' ' + query def add_where(self, where, parameters): self.wheres.append(where) @@ -51,64 +83,17 @@ class Query(object): query = self.query if self.wheres: query += ' WHERE ' + ' AND '.join(self.wheres) - c = con.cursor() - return c.execute(query, self.parameters) - - -class CnAliasedQuery(Query): - - sql = ''' - SELECT `%(table)s_cache`.*, - `%(table)s_alias_cache`.`cn` AS `alias` - FROM `%(table)s_cache` - LEFT JOIN `%(table)s_alias_cache` - ON `%(table)s_alias_cache`.`%(table)s` = `%(table)s_cache`.`cn` - ''' - - cn_join = ''' - LEFT JOIN `%(table)s_alias_cache` `cn_alias` - ON `cn_alias`.`%(table)s` = `%(table)s_cache`.`cn` - ''' - - def __init__(self, table, parameters): - args = dict(table=table) - super(CnAliasedQuery, self).__init__(self.sql % args) - for k, v in parameters.items(): - if k == 'cn': - self.add_query(self.cn_join % args) - self.add_where('(`%(table)s_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)' % args, [v, v]) - else: - self.add_where('`%s` = ?' 
% k, [v]) - - -class RowGrouper(object): - """Pass in query results and group the results by a certain specified - list of columns.""" - - def __init__(self, results, groupby, columns): - self.groupby = groupby - self.columns = columns - self.results = itertools.groupby(results, key=self.keyfunc) - - def __iter__(self): - return self - - def keyfunc(self, row): - return tuple(row[x] for x in self.groupby) - - def next(self): - groupcols, rows = self.results.next() - tmp = dict((x, list()) for x in self.columns) - for row in rows: - for col in self.columns: - if row[col] is not None: - tmp[col].append(row[col]) - result = dict(row) - result.update(tmp) - return result + cursor = con.cursor() + return cursor.execute(query, self.parameters) class Cache(object): + """The description of the cache.""" + + retrieve_sql = None + retrieve_by = dict() + group_by = () + group_columns = () def __init__(self): self.con = _get_connection() @@ -154,12 +139,25 @@ class Cache(object): ''' % (self.tables[n + 1]), ((values[0], x) for x in vlist)) def retrieve(self, parameters): - """Retrieve all items from the cache based on the parameters supplied.""" - query = Query(''' + """Retrieve all items from the cache based on the parameters + supplied.""" + query = Query(self.retrieve_sql or ''' SELECT * FROM %s - ''' % self.tables[0], parameters) - return (list(x)[:-1] for x in query.execute(self.con)) + ''' % self.tables[0]) + if parameters: + for k, v in parameters.items(): + where = self.retrieve_by.get(k, '`%s`.`%s` = ?' 
% (self.tables[0], k)) + query.add_where(where, where.count('?') * [v]) + # group by + # FIXME: find a nice way to turn group_by and group_columns into names + results = query.execute(self.con) + group_by = list(self.group_by + self.group_columns) + for column in self.group_columns[::-1]: + group_by.pop() + results = regroup(results, group_by, column) + # strip the mtime from the results + return (list(x)[:-1] for x in results) def __enter__(self): return self.con.__enter__(); diff --git a/pynslcd/group.py b/pynslcd/group.py index 2028f1e..10e3423 100644 --- a/pynslcd/group.py +++ b/pynslcd/group.py @@ -99,13 +99,17 @@ class Cache(cache.Cache): ON `group_member_cache`.`group` = `group_cache`.`cn` ''' - def retrieve(self, parameters): - query = cache.Query(self.retrieve_sql, parameters) - # return results returning the members as a set - q = itertools.groupby(query.execute(self.con), - key=lambda x: (x['cn'], x['userPassword'], x['gidNumber'])) - for k, v in q: - yield k + (set(x['memberUid'] for x in v if x['memberUid'] is not None), ) + retrieve_by = dict( + memberUid=''' + `cn` IN ( + SELECT `a`.`group` + FROM `group_member_cache` `a` + WHERE `a`.`memberUid` = ?) 
+ ''', + ) + + group_by = (0, ) # cn + group_columns = (3, ) # memberUid class GroupRequest(common.Request): diff --git a/pynslcd/host.py b/pynslcd/host.py index 91c3fa0..04f5337 100644 --- a/pynslcd/host.py +++ b/pynslcd/host.py @@ -34,23 +34,6 @@ class Search(search.LDAPSearch): required = ('cn', ) -class HostQuery(cache.CnAliasedQuery): - - sql = ''' - SELECT `host_cache`.`cn` AS `cn`, - `host_alias_cache`.`cn` AS `alias`, - `host_address_cache`.`ipHostNumber` AS `ipHostNumber` - FROM `host_cache` - LEFT JOIN `host_alias_cache` - ON `host_alias_cache`.`host` = `host_cache`.`cn` - LEFT JOIN `host_address_cache` - ON `host_address_cache`.`host` = `host_cache`.`cn` - ''' - - def __init__(self, parameters): - super(HostQuery, self).__init__('host', parameters) - - class Cache(cache.Cache): tables = ('host_cache', 'host_alias_cache', 'host_address_cache') @@ -73,10 +56,36 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `host_address_idx` ON `host_address_cache`(`host`); ''' - def retrieve(self, parameters): - query = HostQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipHostNumber', )): - yield row['cn'], row['alias'], row['ipHostNumber'] + retrieve_sql = ''' + SELECT `host_cache`.`cn` AS `cn`, + `host_alias_cache`.`cn` AS `alias`, + `host_address_cache`.`ipHostNumber` AS `ipHostNumber`, + `host_cache`.`mtime` AS `mtime` + FROM `host_cache` + LEFT JOIN `host_alias_cache` + ON `host_alias_cache`.`host` = `host_cache`.`cn` + LEFT JOIN `host_address_cache` + ON `host_address_cache`.`host` = `host_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `host_cache`.`cn` = ? OR + `host_cache`.`cn` IN ( + SELECT `by_alias`.`host` + FROM `host_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ipHostNumber=''' + `host_cache`.`cn` IN ( + SELECT `by_ipHostNumber`.`host` + FROM `host_address_cache` `by_ipHostNumber` + WHERE `by_ipHostNumber`.`ipHostNumber` = ?) 
+ ''', + ) + + group_by = (0, ) # cn + group_columns = (1, 2) # alias, ipHostNumber class HostRequest(common.Request): diff --git a/pynslcd/netgroup.py b/pynslcd/netgroup.py index 1de60bf..d86e38c 100644 --- a/pynslcd/netgroup.py +++ b/pynslcd/netgroup.py @@ -63,6 +63,21 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `netgroup_membe_idx` ON `netgroup_member_cache`(`netgroup`); ''' + retrieve_sql = ''' + SELECT `netgroup_cache`.`cn` AS `cn`, + `netgroup_triple_cache`.`nisNetgroupTriple` AS `nisNetgroupTriple`, + `netgroup_member_cache`.`memberNisNetgroup` AS `memberNisNetgroup`, + `netgroup_cache`.`mtime` AS `mtime` + FROM `netgroup_cache` + LEFT JOIN `netgroup_triple_cache` + ON `netgroup_triple_cache`.`netgroup` = `netgroup_cache`.`cn` + LEFT JOIN `netgroup_member_cache` + ON `netgroup_member_cache`.`netgroup` = `netgroup_cache`.`cn` + ''' + + group_by = (0, ) # cn + group_columns = (1, 2) # nisNetgroupTriple, memberNisNetgroup + class NetgroupRequest(common.Request): diff --git a/pynslcd/network.py b/pynslcd/network.py index bf49b4d..01bf6c2 100644 --- a/pynslcd/network.py +++ b/pynslcd/network.py @@ -35,23 +35,6 @@ class Search(search.LDAPSearch): required = ('cn', ) -class NetworkQuery(cache.CnAliasedQuery): - - sql = ''' - SELECT `network_cache`.`cn` AS `cn`, - `network_alias_cache`.`cn` AS `alias`, - `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber` - FROM `network_cache` - LEFT JOIN `network_alias_cache` - ON `network_alias_cache`.`network` = `network_cache`.`cn` - LEFT JOIN `network_address_cache` - ON `network_address_cache`.`network` = `network_cache`.`cn` - ''' - - def __init__(self, parameters): - super(NetworkQuery, self).__init__('network', parameters) - - class Cache(cache.Cache): tables = ('network_cache', 'network_alias_cache', 'network_address_cache') @@ -74,10 +57,36 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `network_address_idx` ON `network_address_cache`(`network`); ''' - def retrieve(self, parameters): - 
query = NetworkQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipNetworkNumber', )): - yield row['cn'], row['alias'], row['ipNetworkNumber'] + retrieve_sql = ''' + SELECT `network_cache`.`cn` AS `cn`, + `network_alias_cache`.`cn` AS `alias`, + `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber`, + `network_cache`.`mtime` AS `mtime` + FROM `network_cache` + LEFT JOIN `network_alias_cache` + ON `network_alias_cache`.`network` = `network_cache`.`cn` + LEFT JOIN `network_address_cache` + ON `network_address_cache`.`network` = `network_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `network_cache`.`cn` = ? OR + `network_cache`.`cn` IN ( + SELECT `by_alias`.`network` + FROM `network_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ipNetworkNumber=''' + `network_cache`.`cn` IN ( + SELECT `by_ipNetworkNumber`.`network` + FROM `network_address_cache` `by_ipNetworkNumber` + WHERE `by_ipNetworkNumber`.`ipNetworkNumber` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, 2) # alias, ipNetworkNumber class NetworkRequest(common.Request): diff --git a/pynslcd/protocol.py b/pynslcd/protocol.py index 122673d..1472c04 100644 --- a/pynslcd/protocol.py +++ b/pynslcd/protocol.py @@ -52,10 +52,26 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `protocol_alias_idx` ON `protocol_alias_cache`(`protocol`); ''' - def retrieve(self, parameters): - query = cache.CnAliasedQuery('protocol', parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )): - yield row['cn'], row['alias'], row['ipProtocolNumber'] + retrieve_sql = ''' + SELECT `protocol_cache`.`cn` AS `cn`, `protocol_alias_cache`.`cn` AS `alias`, + `ipProtocolNumber`, `mtime` + FROM `protocol_cache` + LEFT JOIN `protocol_alias_cache` + ON `protocol_alias_cache`.`protocol` = `protocol_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `protocol_cache`.`cn` = ? 
OR + `protocol_cache`.`cn` IN ( + SELECT `by_alias`.`protocol` + FROM `protocol_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # alias class ProtocolRequest(common.Request): diff --git a/pynslcd/rpc.py b/pynslcd/rpc.py index 98a0ecc..2a241fd 100644 --- a/pynslcd/rpc.py +++ b/pynslcd/rpc.py @@ -52,10 +52,26 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `rpc_alias_idx` ON `rpc_alias_cache`(`rpc`); ''' - def retrieve(self, parameters): - query = cache.CnAliasedQuery('rpc', parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )): - yield row['cn'], row['alias'], row['oncRpcNumber'] + retrieve_sql = ''' + SELECT `rpc_cache`.`cn` AS `cn`, `rpc_alias_cache`.`cn` AS `alias`, + `oncRpcNumber`, `mtime` + FROM `rpc_cache` + LEFT JOIN `rpc_alias_cache` + ON `rpc_alias_cache`.`rpc` = `rpc_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `rpc_cache`.`cn` = ? OR + `rpc_cache`.`cn` IN ( + SELECT `by_alias`.`rpc` + FROM `rpc_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # alias class RpcRequest(common.Request): diff --git a/pynslcd/service.py b/pynslcd/service.py index 6f55cc1..c27f485 100644 --- a/pynslcd/service.py +++ b/pynslcd/service.py @@ -40,33 +40,6 @@ class Search(search.LDAPSearch): required = ('cn', 'ipServicePort', 'ipServiceProtocol') -class ServiceQuery(cache.CnAliasedQuery): - - sql = ''' - SELECT `service_cache`.*, - `service_alias_cache`.`cn` AS `alias` - FROM `service_cache` - LEFT JOIN `service_alias_cache` - ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort` - AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` - ''' - - cn_join = ''' - LEFT JOIN `service_alias_cache` `cn_alias` - ON `cn_alias`.`ipServicePort` = `service_cache`.`ipServicePort` - AND `cn_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` - ''' 
- - def __init__(self, parameters): - super(ServiceQuery, self).__init__('service', {}) - for k, v in parameters.items(): - if k == 'cn': - self.add_query(self.cn_join) - self.add_where('(`service_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)', [v, v]) - else: - self.add_where('`service_cache`.`%s` = ?' % k, [v]) - - class Cache(cache.Cache): tables = ('service_cache', 'service_alias_cache') @@ -90,6 +63,34 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `service_alias_idx2` ON `service_alias_cache`(`ipServiceProtocol`); ''' + retrieve_sql = ''' + SELECT `service_cache`.`cn` AS `cn`, + `service_alias_cache`.`cn` AS `alias`, + `service_cache`.`ipServicePort`, + `service_cache`.`ipServiceProtocol`, + `mtime` + FROM `service_cache` + LEFT JOIN `service_alias_cache` + ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort` + AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` + ''' + + retrieve_by = dict( + cn=''' + ( `service_cache`.`cn` = ? OR + 0 < ( + SELECT COUNT(*) + FROM `service_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ? + AND `by_alias`.`ipServicePort` = `service_cache`.`ipServicePort` + AND `by_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` + )) + ''', + ) + + group_by = (0, 2, 3) # cn, ipServicePort, ipServiceProtocol + group_columns = (1, ) # alias + def store(self, name, aliases, port, protocol): self.con.execute(''' INSERT OR REPLACE INTO `service_cache` @@ -107,11 +108,6 @@ class Cache(cache.Cache): (?, ?, ?) ''', ((port, protocol, alias) for alias in aliases)) - def retrieve(self, parameters): - query = ServiceQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', 'ipServicePort', 'ipServiceProtocol'), ('alias', )): - yield row['cn'], row['alias'], row['ipServicePort'], row['ipServiceProtocol'] - class ServiceRequest(common.Request): |