diff options
-rw-r--r-- | pynslcd/alias.py | 16 | ||||
-rw-r--r-- | pynslcd/cache.py | 130 | ||||
-rw-r--r-- | pynslcd/group.py | 18 | ||||
-rw-r--r-- | pynslcd/host.py | 51 | ||||
-rw-r--r-- | pynslcd/netgroup.py | 15 | ||||
-rw-r--r-- | pynslcd/network.py | 51 | ||||
-rw-r--r-- | pynslcd/protocol.py | 24 | ||||
-rw-r--r-- | pynslcd/rpc.py | 24 | ||||
-rw-r--r-- | pynslcd/service.py | 60 |
9 files changed, 229 insertions, 160 deletions
diff --git a/pynslcd/alias.py b/pynslcd/alias.py index d5ae390..371ac2e 100644 --- a/pynslcd/alias.py +++ b/pynslcd/alias.py @@ -60,11 +60,17 @@ class Cache(cache.Cache): ON `alias_member_cache`.`alias` = `alias_cache`.`cn` ''' - def retrieve(self, parameters): - query = cache.Query(self.retrieve_sql, parameters) - # return results, returning the members as a list - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('rfc822MailMember', )): - yield row['cn'], row['rfc822MailMember'] + retrieve_by = dict( + rfc822MailMember=''' + `cn` IN ( + SELECT `a`.`alias` + FROM `alias_member_cache` `a` + WHERE `a`.`rfc822MailMember` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # rfc822MailMember class AliasRequest(common.Request): diff --git a/pynslcd/cache.py b/pynslcd/cache.py index 796bef4..3974fef 100644 --- a/pynslcd/cache.py +++ b/pynslcd/cache.py @@ -30,18 +30,50 @@ import sqlite3 # FIXME: have some way to remove stale entries from the cache if all items from LDAP are queried (perhas use TTL from all request) +class regroup(object): + + def __init__(self, results, group_by=None, group_column=None): + """Regroup the results in the group column by the key columns.""" + self.group_by = tuple(group_by) + self.group_column = group_column + self.it = iter(results) + self.tgtkey = self.currkey = self.currvalue = object() + + def keyfunc(self, row): + return tuple(row[x] for x in self.group_by) + + def __iter__(self): + return self + + def next(self): + # find a start row + while self.currkey == self.tgtkey: + self.currvalue = next(self.it) # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) + self.tgtkey = self.currkey + # turn the result row into a list of columns + row = list(self.currvalue) + # replace the group column + row[self.group_column] = list(self._grouper(self.tgtkey)) + return row + + def _grouper(self, tgtkey): + """Generate the group columns.""" + while self.currkey == tgtkey: + value = self.currvalue[self.group_column] + if value is not None: + yield value + self.currvalue = next(self.it) # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) + + class Query(object): + """Helper class to build an SQL query for the cache.""" - def __init__(self, query, parameters=None): + def __init__(self, query): self.query = query self.wheres = [] self.parameters = [] - if parameters: - for k, v in parameters.items(): - self.add_where('`%s` = ?' % k, [v]) - - def add_query(self, query): - self.query += ' ' + query def add_where(self, where, parameters): self.wheres.append(where) @@ -51,64 +83,17 @@ class Query(object): query = self.query if self.wheres: query += ' WHERE ' + ' AND '.join(self.wheres) - c = con.cursor() - return c.execute(query, self.parameters) - - -class CnAliasedQuery(Query): - - sql = ''' - SELECT `%(table)s_cache`.*, - `%(table)s_alias_cache`.`cn` AS `alias` - FROM `%(table)s_cache` - LEFT JOIN `%(table)s_alias_cache` - ON `%(table)s_alias_cache`.`%(table)s` = `%(table)s_cache`.`cn` - ''' - - cn_join = ''' - LEFT JOIN `%(table)s_alias_cache` `cn_alias` - ON `cn_alias`.`%(table)s` = `%(table)s_cache`.`cn` - ''' - - def __init__(self, table, parameters): - args = dict(table=table) - super(CnAliasedQuery, self).__init__(self.sql % args) - for k, v in parameters.items(): - if k == 'cn': - self.add_query(self.cn_join % args) - self.add_where('(`%(table)s_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)' % args, [v, v]) - else: - self.add_where('`%s` = ?' % k, [v]) - - -class RowGrouper(object): - """Pass in query results and group the results by a certain specified - list of columns.""" - - def __init__(self, results, groupby, columns): - self.groupby = groupby - self.columns = columns - self.results = itertools.groupby(results, key=self.keyfunc) - - def __iter__(self): - return self - - def keyfunc(self, row): - return tuple(row[x] for x in self.groupby) - - def next(self): - groupcols, rows = self.results.next() - tmp = dict((x, list()) for x in self.columns) - for row in rows: - for col in self.columns: - if row[col] is not None: - tmp[col].append(row[col]) - result = dict(row) - result.update(tmp) - return result + cursor = con.cursor() + return cursor.execute(query, self.parameters) class Cache(object): + """The description of the cache.""" + + retrieve_sql = None + retrieve_by = dict() + group_by = () + group_columns = () def __init__(self): self.con = _get_connection() @@ -154,12 +139,25 @@ class Cache(object): ''' % (self.tables[n + 1]), ((values[0], x) for x in vlist)) def retrieve(self, parameters): - """Retrieve all items from the cache based on the parameters supplied.""" - query = Query(''' + """Retrieve all items from the cache based on the parameters + supplied.""" + query = Query(self.retrieve_sql or ''' SELECT * FROM %s - ''' % self.tables[0], parameters) - return (list(x)[:-1] for x in query.execute(self.con)) + ''' % self.tables[0]) + if parameters: + for k, v in parameters.items(): + where = self.retrieve_by.get(k, '`%s`.`%s` = ?' % (self.tables[0], k)) + query.add_where(where, where.count('?') * [v]) + # group by + # FIXME: find a nice way to turn group_by and group_columns into names + results = query.execute(self.con) + group_by = list(self.group_by + self.group_columns) + for column in self.group_columns[::-1]: + group_by.pop() + results = regroup(results, group_by, column) + # strip the mtime from the results + return (list(x)[:-1] for x in results) def __enter__(self): return self.con.__enter__(); diff --git a/pynslcd/group.py b/pynslcd/group.py index 2028f1e..10e3423 100644 --- a/pynslcd/group.py +++ b/pynslcd/group.py @@ -99,13 +99,17 @@ class Cache(cache.Cache): ON `group_member_cache`.`group` = `group_cache`.`cn` ''' - def retrieve(self, parameters): - query = cache.Query(self.retrieve_sql, parameters) - # return results returning the members as a set - q = itertools.groupby(query.execute(self.con), - key=lambda x: (x['cn'], x['userPassword'], x['gidNumber'])) - for k, v in q: - yield k + (set(x['memberUid'] for x in v if x['memberUid'] is not None), ) + retrieve_by = dict( + memberUid=''' + `cn` IN ( + SELECT `a`.`group` + FROM `group_member_cache` `a` + WHERE `a`.`memberUid` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (3, ) # memberUid class GroupRequest(common.Request): diff --git a/pynslcd/host.py b/pynslcd/host.py index 91c3fa0..04f5337 100644 --- a/pynslcd/host.py +++ b/pynslcd/host.py @@ -34,23 +34,6 @@ class Search(search.LDAPSearch): required = ('cn', ) -class HostQuery(cache.CnAliasedQuery): - - sql = ''' - SELECT `host_cache`.`cn` AS `cn`, - `host_alias_cache`.`cn` AS `alias`, - `host_address_cache`.`ipHostNumber` AS `ipHostNumber` - FROM `host_cache` - LEFT JOIN `host_alias_cache` - ON `host_alias_cache`.`host` = `host_cache`.`cn` - LEFT JOIN `host_address_cache` - ON `host_address_cache`.`host` = `host_cache`.`cn` - ''' - - def __init__(self, parameters): - super(HostQuery, self).__init__('host', parameters) - - class Cache(cache.Cache): tables = ('host_cache', 'host_alias_cache', 'host_address_cache') @@ -73,10 +56,36 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `host_address_idx` ON `host_address_cache`(`host`); ''' - def retrieve(self, parameters): - query = HostQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipHostNumber', )): - yield row['cn'], row['alias'], row['ipHostNumber'] + retrieve_sql = ''' + SELECT `host_cache`.`cn` AS `cn`, + `host_alias_cache`.`cn` AS `alias`, + `host_address_cache`.`ipHostNumber` AS `ipHostNumber`, + `host_cache`.`mtime` AS `mtime` + FROM `host_cache` + LEFT JOIN `host_alias_cache` + ON `host_alias_cache`.`host` = `host_cache`.`cn` + LEFT JOIN `host_address_cache` + ON `host_address_cache`.`host` = `host_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `host_cache`.`cn` = ? OR + `host_cache`.`cn` IN ( + SELECT `by_alias`.`host` + FROM `host_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ipHostNumber=''' + `host_cache`.`cn` IN ( + SELECT `by_ipHostNumber`.`host` + FROM `host_address_cache` `by_ipHostNumber` + WHERE `by_ipHostNumber`.`ipHostNumber` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, 2) # alias, ipHostNumber class HostRequest(common.Request): diff --git a/pynslcd/netgroup.py b/pynslcd/netgroup.py index 1de60bf..d86e38c 100644 --- a/pynslcd/netgroup.py +++ b/pynslcd/netgroup.py @@ -63,6 +63,21 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `netgroup_membe_idx` ON `netgroup_member_cache`(`netgroup`); ''' + retrieve_sql = ''' + SELECT `netgroup_cache`.`cn` AS `cn`, + `netgroup_triple_cache`.`nisNetgroupTriple` AS `nisNetgroupTriple`, + `netgroup_member_cache`.`memberNisNetgroup` AS `memberNisNetgroup`, + `netgroup_cache`.`mtime` AS `mtime` + FROM `netgroup_cache` + LEFT JOIN `netgroup_triple_cache` + ON `netgroup_triple_cache`.`netgroup` = `netgroup_cache`.`cn` + LEFT JOIN `netgroup_member_cache` + ON `netgroup_member_cache`.`netgroup` = `netgroup_cache`.`cn` + ''' + + group_by = (0, ) # cn + group_columns = (1, 2) # nisNetgroupTriple, memberNisNetgroup + class NetgroupRequest(common.Request): diff --git a/pynslcd/network.py b/pynslcd/network.py index bf49b4d..01bf6c2 100644 --- a/pynslcd/network.py +++ b/pynslcd/network.py @@ -35,23 +35,6 @@ class Search(search.LDAPSearch): required = ('cn', ) -class NetworkQuery(cache.CnAliasedQuery): - - sql = ''' - SELECT `network_cache`.`cn` AS `cn`, - `network_alias_cache`.`cn` AS `alias`, - `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber` - FROM `network_cache` - LEFT JOIN `network_alias_cache` - ON `network_alias_cache`.`network` = `network_cache`.`cn` - LEFT JOIN `network_address_cache` - ON `network_address_cache`.`network` = `network_cache`.`cn` - ''' - - def __init__(self, parameters): - super(NetworkQuery, self).__init__('network', parameters) - - class Cache(cache.Cache): tables = ('network_cache', 'network_alias_cache', 'network_address_cache') @@ -74,10 +57,36 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `network_address_idx` ON `network_address_cache`(`network`); ''' - def retrieve(self, parameters): - query = NetworkQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipNetworkNumber', )): - yield row['cn'], row['alias'], row['ipNetworkNumber'] + retrieve_sql = ''' + SELECT `network_cache`.`cn` AS `cn`, + `network_alias_cache`.`cn` AS `alias`, + `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber`, + `network_cache`.`mtime` AS `mtime` + FROM `network_cache` + LEFT JOIN `network_alias_cache` + ON `network_alias_cache`.`network` = `network_cache`.`cn` + LEFT JOIN `network_address_cache` + ON `network_address_cache`.`network` = `network_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `network_cache`.`cn` = ? OR + `network_cache`.`cn` IN ( + SELECT `by_alias`.`network` + FROM `network_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ipNetworkNumber=''' + `network_cache`.`cn` IN ( + SELECT `by_ipNetworkNumber`.`network` + FROM `network_address_cache` `by_ipNetworkNumber` + WHERE `by_ipNetworkNumber`.`ipNetworkNumber` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, 2) # alias, ipNetworkNumber class NetworkRequest(common.Request): diff --git a/pynslcd/protocol.py b/pynslcd/protocol.py index 122673d..1472c04 100644 --- a/pynslcd/protocol.py +++ b/pynslcd/protocol.py @@ -52,10 +52,26 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `protocol_alias_idx` ON `protocol_alias_cache`(`protocol`); ''' - def retrieve(self, parameters): - query = cache.CnAliasedQuery('protocol', parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )): - yield row['cn'], row['alias'], row['ipProtocolNumber'] + retrieve_sql = ''' + SELECT `protocol_cache`.`cn` AS `cn`, `protocol_alias_cache`.`cn` AS `alias`, + `ipProtocolNumber`, `mtime` + FROM `protocol_cache` + LEFT JOIN `protocol_alias_cache` + ON `protocol_alias_cache`.`protocol` = `protocol_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `protocol_cache`.`cn` = ? OR + `protocol_cache`.`cn` IN ( + SELECT `by_alias`.`protocol` + FROM `protocol_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # alias class ProtocolRequest(common.Request): diff --git a/pynslcd/rpc.py b/pynslcd/rpc.py index 98a0ecc..2a241fd 100644 --- a/pynslcd/rpc.py +++ b/pynslcd/rpc.py @@ -52,10 +52,26 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `rpc_alias_idx` ON `rpc_alias_cache`(`rpc`); ''' - def retrieve(self, parameters): - query = cache.CnAliasedQuery('rpc', parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )): - yield row['cn'], row['alias'], row['oncRpcNumber'] + retrieve_sql = ''' + SELECT `rpc_cache`.`cn` AS `cn`, `rpc_alias_cache`.`cn` AS `alias`, + `oncRpcNumber`, `mtime` + FROM `rpc_cache` + LEFT JOIN `rpc_alias_cache` + ON `rpc_alias_cache`.`rpc` = `rpc_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `rpc_cache`.`cn` = ? OR + `rpc_cache`.`cn` IN ( + SELECT `by_alias`.`rpc` + FROM `rpc_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # alias class RpcRequest(common.Request): diff --git a/pynslcd/service.py b/pynslcd/service.py index 6f55cc1..c27f485 100644 --- a/pynslcd/service.py +++ b/pynslcd/service.py @@ -40,33 +40,6 @@ class Search(search.LDAPSearch): required = ('cn', 'ipServicePort', 'ipServiceProtocol') -class ServiceQuery(cache.CnAliasedQuery): - - sql = ''' - SELECT `service_cache`.*, - `service_alias_cache`.`cn` AS `alias` - FROM `service_cache` - LEFT JOIN `service_alias_cache` - ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort` - AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` - ''' - - cn_join = ''' - LEFT JOIN `service_alias_cache` `cn_alias` - ON `cn_alias`.`ipServicePort` = `service_cache`.`ipServicePort` - AND `cn_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` - ''' - - def __init__(self, parameters): - super(ServiceQuery, self).__init__('service', {}) - for k, v in parameters.items(): - if k == 'cn': - self.add_query(self.cn_join) - self.add_where('(`service_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)', [v, v]) - else: - self.add_where('`service_cache`.`%s` = ?' % k, [v]) - - class Cache(cache.Cache): tables = ('service_cache', 'service_alias_cache') @@ -90,6 +63,34 @@ class Cache(cache.Cache): CREATE INDEX IF NOT EXISTS `service_alias_idx2` ON `service_alias_cache`(`ipServiceProtocol`); ''' + retrieve_sql = ''' + SELECT `service_cache`.`cn` AS `cn`, + `service_alias_cache`.`cn` AS `alias`, + `service_cache`.`ipServicePort`, + `service_cache`.`ipServiceProtocol`, + `mtime` + FROM `service_cache` + LEFT JOIN `service_alias_cache` + ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort` + AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` + ''' + + retrieve_by = dict( + cn=''' + ( `service_cache`.`cn` = ? OR + 0 < ( + SELECT COUNT(*) + FROM `service_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ? + AND `by_alias`.`ipServicePort` = `service_cache`.`ipServicePort` + AND `by_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` + )) + ''', + ) + + group_by = (0, 2, 3) # cn, ipServicePort, ipServiceProtocol + group_columns = (1, ) # alias + def store(self, name, aliases, port, protocol): self.con.execute(''' INSERT OR REPLACE INTO `service_cache` @@ -107,11 +108,6 @@ class Cache(cache.Cache): (?, ?, ?) ''', ((port, protocol, alias) for alias in aliases)) - def retrieve(self, parameters): - query = ServiceQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', 'ipServicePort', 'ipServiceProtocol'), ('alias', )): - yield row['cn'], row['alias'], row['ipServicePort'], row['ipServiceProtocol'] - class ServiceRequest(common.Request): |