Arthur de Jong

Open Source / Free Software developer

summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pynslcd/alias.py16
-rw-r--r--pynslcd/cache.py130
-rw-r--r--pynslcd/group.py18
-rw-r--r--pynslcd/host.py51
-rw-r--r--pynslcd/netgroup.py15
-rw-r--r--pynslcd/network.py51
-rw-r--r--pynslcd/protocol.py24
-rw-r--r--pynslcd/rpc.py24
-rw-r--r--pynslcd/service.py60
9 files changed, 229 insertions, 160 deletions
diff --git a/pynslcd/alias.py b/pynslcd/alias.py
index d5ae390..371ac2e 100644
--- a/pynslcd/alias.py
+++ b/pynslcd/alias.py
@@ -60,11 +60,17 @@ class Cache(cache.Cache):
ON `alias_member_cache`.`alias` = `alias_cache`.`cn`
'''
- def retrieve(self, parameters):
- query = cache.Query(self.retrieve_sql, parameters)
- # return results, returning the members as a list
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('rfc822MailMember', )):
- yield row['cn'], row['rfc822MailMember']
+ retrieve_by = dict(
+ rfc822MailMember='''
+ `cn` IN (
+ SELECT `a`.`alias`
+ FROM `alias_member_cache` `a`
+ WHERE `a`.`rfc822MailMember` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, ) # rfc822MailMember
class AliasRequest(common.Request):
diff --git a/pynslcd/cache.py b/pynslcd/cache.py
index 796bef4..3974fef 100644
--- a/pynslcd/cache.py
+++ b/pynslcd/cache.py
@@ -30,18 +30,50 @@ import sqlite3
# FIXME: have some way to remove stale entries from the cache if all items from LDAP are queried (perhas use TTL from all request)
+class regroup(object):
+
+ def __init__(self, results, group_by=None, group_column=None):
+ """Regroup the results in the group column by the key columns."""
+ self.group_by = tuple(group_by)
+ self.group_column = group_column
+ self.it = iter(results)
+ self.tgtkey = self.currkey = self.currvalue = object()
+
+ def keyfunc(self, row):
+ return tuple(row[x] for x in self.group_by)
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ # find a start row
+ while self.currkey == self.tgtkey:
+ self.currvalue = next(self.it) # Exit on StopIteration
+ self.currkey = self.keyfunc(self.currvalue)
+ self.tgtkey = self.currkey
+ # turn the result row into a list of columns
+ row = list(self.currvalue)
+ # replace the group column
+ row[self.group_column] = list(self._grouper(self.tgtkey))
+ return row
+
+ def _grouper(self, tgtkey):
+ """Generate the group columns."""
+ while self.currkey == tgtkey:
+ value = self.currvalue[self.group_column]
+ if value is not None:
+ yield value
+ self.currvalue = next(self.it) # Exit on StopIteration
+ self.currkey = self.keyfunc(self.currvalue)
+
+
class Query(object):
+ """Helper class to build an SQL query for the cache."""
- def __init__(self, query, parameters=None):
+ def __init__(self, query):
self.query = query
self.wheres = []
self.parameters = []
- if parameters:
- for k, v in parameters.items():
- self.add_where('`%s` = ?' % k, [v])
-
- def add_query(self, query):
- self.query += ' ' + query
def add_where(self, where, parameters):
self.wheres.append(where)
@@ -51,64 +83,17 @@ class Query(object):
query = self.query
if self.wheres:
query += ' WHERE ' + ' AND '.join(self.wheres)
- c = con.cursor()
- return c.execute(query, self.parameters)
-
-
-class CnAliasedQuery(Query):
-
- sql = '''
- SELECT `%(table)s_cache`.*,
- `%(table)s_alias_cache`.`cn` AS `alias`
- FROM `%(table)s_cache`
- LEFT JOIN `%(table)s_alias_cache`
- ON `%(table)s_alias_cache`.`%(table)s` = `%(table)s_cache`.`cn`
- '''
-
- cn_join = '''
- LEFT JOIN `%(table)s_alias_cache` `cn_alias`
- ON `cn_alias`.`%(table)s` = `%(table)s_cache`.`cn`
- '''
-
- def __init__(self, table, parameters):
- args = dict(table=table)
- super(CnAliasedQuery, self).__init__(self.sql % args)
- for k, v in parameters.items():
- if k == 'cn':
- self.add_query(self.cn_join % args)
- self.add_where('(`%(table)s_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)' % args, [v, v])
- else:
- self.add_where('`%s` = ?' % k, [v])
-
-
-class RowGrouper(object):
- """Pass in query results and group the results by a certain specified
- list of columns."""
-
- def __init__(self, results, groupby, columns):
- self.groupby = groupby
- self.columns = columns
- self.results = itertools.groupby(results, key=self.keyfunc)
-
- def __iter__(self):
- return self
-
- def keyfunc(self, row):
- return tuple(row[x] for x in self.groupby)
-
- def next(self):
- groupcols, rows = self.results.next()
- tmp = dict((x, list()) for x in self.columns)
- for row in rows:
- for col in self.columns:
- if row[col] is not None:
- tmp[col].append(row[col])
- result = dict(row)
- result.update(tmp)
- return result
+ cursor = con.cursor()
+ return cursor.execute(query, self.parameters)
class Cache(object):
+ """The description of the cache."""
+
+ retrieve_sql = None
+ retrieve_by = dict()
+ group_by = ()
+ group_columns = ()
def __init__(self):
self.con = _get_connection()
@@ -154,12 +139,25 @@ class Cache(object):
''' % (self.tables[n + 1]), ((values[0], x) for x in vlist))
def retrieve(self, parameters):
- """Retrieve all items from the cache based on the parameters supplied."""
- query = Query('''
+ """Retrieve all items from the cache based on the parameters
+ supplied."""
+ query = Query(self.retrieve_sql or '''
SELECT *
FROM %s
- ''' % self.tables[0], parameters)
- return (list(x)[:-1] for x in query.execute(self.con))
+ ''' % self.tables[0])
+ if parameters:
+ for k, v in parameters.items():
+ where = self.retrieve_by.get(k, '`%s`.`%s` = ?' % (self.tables[0], k))
+ query.add_where(where, where.count('?') * [v])
+ # group by
+ # FIXME: find a nice way to turn group_by and group_columns into names
+ results = query.execute(self.con)
+ group_by = list(self.group_by + self.group_columns)
+ for column in self.group_columns[::-1]:
+ group_by.pop()
+ results = regroup(results, group_by, column)
+ # strip the mtime from the results
+ return (list(x)[:-1] for x in results)
def __enter__(self):
return self.con.__enter__();
diff --git a/pynslcd/group.py b/pynslcd/group.py
index 2028f1e..10e3423 100644
--- a/pynslcd/group.py
+++ b/pynslcd/group.py
@@ -99,13 +99,17 @@ class Cache(cache.Cache):
ON `group_member_cache`.`group` = `group_cache`.`cn`
'''
- def retrieve(self, parameters):
- query = cache.Query(self.retrieve_sql, parameters)
- # return results returning the members as a set
- q = itertools.groupby(query.execute(self.con),
- key=lambda x: (x['cn'], x['userPassword'], x['gidNumber']))
- for k, v in q:
- yield k + (set(x['memberUid'] for x in v if x['memberUid'] is not None), )
+ retrieve_by = dict(
+ memberUid='''
+ `cn` IN (
+ SELECT `a`.`group`
+ FROM `group_member_cache` `a`
+ WHERE `a`.`memberUid` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (3, ) # memberUid
class GroupRequest(common.Request):
diff --git a/pynslcd/host.py b/pynslcd/host.py
index 91c3fa0..04f5337 100644
--- a/pynslcd/host.py
+++ b/pynslcd/host.py
@@ -34,23 +34,6 @@ class Search(search.LDAPSearch):
required = ('cn', )
-class HostQuery(cache.CnAliasedQuery):
-
- sql = '''
- SELECT `host_cache`.`cn` AS `cn`,
- `host_alias_cache`.`cn` AS `alias`,
- `host_address_cache`.`ipHostNumber` AS `ipHostNumber`
- FROM `host_cache`
- LEFT JOIN `host_alias_cache`
- ON `host_alias_cache`.`host` = `host_cache`.`cn`
- LEFT JOIN `host_address_cache`
- ON `host_address_cache`.`host` = `host_cache`.`cn`
- '''
-
- def __init__(self, parameters):
- super(HostQuery, self).__init__('host', parameters)
-
-
class Cache(cache.Cache):
tables = ('host_cache', 'host_alias_cache', 'host_address_cache')
@@ -73,10 +56,36 @@ class Cache(cache.Cache):
CREATE INDEX IF NOT EXISTS `host_address_idx` ON `host_address_cache`(`host`);
'''
- def retrieve(self, parameters):
- query = HostQuery(parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipHostNumber', )):
- yield row['cn'], row['alias'], row['ipHostNumber']
+ retrieve_sql = '''
+ SELECT `host_cache`.`cn` AS `cn`,
+ `host_alias_cache`.`cn` AS `alias`,
+ `host_address_cache`.`ipHostNumber` AS `ipHostNumber`,
+ `host_cache`.`mtime` AS `mtime`
+ FROM `host_cache`
+ LEFT JOIN `host_alias_cache`
+ ON `host_alias_cache`.`host` = `host_cache`.`cn`
+ LEFT JOIN `host_address_cache`
+ ON `host_address_cache`.`host` = `host_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `host_cache`.`cn` = ? OR
+ `host_cache`.`cn` IN (
+ SELECT `by_alias`.`host`
+ FROM `host_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ ipHostNumber='''
+ `host_cache`.`cn` IN (
+ SELECT `by_ipHostNumber`.`host`
+ FROM `host_address_cache` `by_ipHostNumber`
+ WHERE `by_ipHostNumber`.`ipHostNumber` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, 2) # alias, ipHostNumber
class HostRequest(common.Request):
diff --git a/pynslcd/netgroup.py b/pynslcd/netgroup.py
index 1de60bf..d86e38c 100644
--- a/pynslcd/netgroup.py
+++ b/pynslcd/netgroup.py
@@ -63,6 +63,21 @@ class Cache(cache.Cache):
CREATE INDEX IF NOT EXISTS `netgroup_membe_idx` ON `netgroup_member_cache`(`netgroup`);
'''
+ retrieve_sql = '''
+ SELECT `netgroup_cache`.`cn` AS `cn`,
+ `netgroup_triple_cache`.`nisNetgroupTriple` AS `nisNetgroupTriple`,
+ `netgroup_member_cache`.`memberNisNetgroup` AS `memberNisNetgroup`,
+ `netgroup_cache`.`mtime` AS `mtime`
+ FROM `netgroup_cache`
+ LEFT JOIN `netgroup_triple_cache`
+ ON `netgroup_triple_cache`.`netgroup` = `netgroup_cache`.`cn`
+ LEFT JOIN `netgroup_member_cache`
+ ON `netgroup_member_cache`.`netgroup` = `netgroup_cache`.`cn`
+ '''
+
+ group_by = (0, ) # cn
+ group_columns = (1, 2) # nisNetgroupTriple, memberNisNetgroup
+
class NetgroupRequest(common.Request):
diff --git a/pynslcd/network.py b/pynslcd/network.py
index bf49b4d..01bf6c2 100644
--- a/pynslcd/network.py
+++ b/pynslcd/network.py
@@ -35,23 +35,6 @@ class Search(search.LDAPSearch):
required = ('cn', )
-class NetworkQuery(cache.CnAliasedQuery):
-
- sql = '''
- SELECT `network_cache`.`cn` AS `cn`,
- `network_alias_cache`.`cn` AS `alias`,
- `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber`
- FROM `network_cache`
- LEFT JOIN `network_alias_cache`
- ON `network_alias_cache`.`network` = `network_cache`.`cn`
- LEFT JOIN `network_address_cache`
- ON `network_address_cache`.`network` = `network_cache`.`cn`
- '''
-
- def __init__(self, parameters):
- super(NetworkQuery, self).__init__('network', parameters)
-
-
class Cache(cache.Cache):
tables = ('network_cache', 'network_alias_cache', 'network_address_cache')
@@ -74,10 +57,36 @@ class Cache(cache.Cache):
CREATE INDEX IF NOT EXISTS `network_address_idx` ON `network_address_cache`(`network`);
'''
- def retrieve(self, parameters):
- query = NetworkQuery(parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipNetworkNumber', )):
- yield row['cn'], row['alias'], row['ipNetworkNumber']
+ retrieve_sql = '''
+ SELECT `network_cache`.`cn` AS `cn`,
+ `network_alias_cache`.`cn` AS `alias`,
+ `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber`,
+ `network_cache`.`mtime` AS `mtime`
+ FROM `network_cache`
+ LEFT JOIN `network_alias_cache`
+ ON `network_alias_cache`.`network` = `network_cache`.`cn`
+ LEFT JOIN `network_address_cache`
+ ON `network_address_cache`.`network` = `network_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `network_cache`.`cn` = ? OR
+ `network_cache`.`cn` IN (
+ SELECT `by_alias`.`network`
+ FROM `network_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ ipNetworkNumber='''
+ `network_cache`.`cn` IN (
+ SELECT `by_ipNetworkNumber`.`network`
+ FROM `network_address_cache` `by_ipNetworkNumber`
+ WHERE `by_ipNetworkNumber`.`ipNetworkNumber` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, 2) # alias, ipNetworkNumber
class NetworkRequest(common.Request):
diff --git a/pynslcd/protocol.py b/pynslcd/protocol.py
index 122673d..1472c04 100644
--- a/pynslcd/protocol.py
+++ b/pynslcd/protocol.py
@@ -52,10 +52,26 @@ class Cache(cache.Cache):
CREATE INDEX IF NOT EXISTS `protocol_alias_idx` ON `protocol_alias_cache`(`protocol`);
'''
- def retrieve(self, parameters):
- query = cache.CnAliasedQuery('protocol', parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )):
- yield row['cn'], row['alias'], row['ipProtocolNumber']
+ retrieve_sql = '''
+ SELECT `protocol_cache`.`cn` AS `cn`, `protocol_alias_cache`.`cn` AS `alias`,
+ `ipProtocolNumber`, `mtime`
+ FROM `protocol_cache`
+ LEFT JOIN `protocol_alias_cache`
+ ON `protocol_alias_cache`.`protocol` = `protocol_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `protocol_cache`.`cn` = ? OR
+ `protocol_cache`.`cn` IN (
+ SELECT `by_alias`.`protocol`
+ FROM `protocol_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, ) # alias
class ProtocolRequest(common.Request):
diff --git a/pynslcd/rpc.py b/pynslcd/rpc.py
index 98a0ecc..2a241fd 100644
--- a/pynslcd/rpc.py
+++ b/pynslcd/rpc.py
@@ -52,10 +52,26 @@ class Cache(cache.Cache):
CREATE INDEX IF NOT EXISTS `rpc_alias_idx` ON `rpc_alias_cache`(`rpc`);
'''
- def retrieve(self, parameters):
- query = cache.CnAliasedQuery('rpc', parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )):
- yield row['cn'], row['alias'], row['oncRpcNumber']
+ retrieve_sql = '''
+ SELECT `rpc_cache`.`cn` AS `cn`, `rpc_alias_cache`.`cn` AS `alias`,
+ `oncRpcNumber`, `mtime`
+ FROM `rpc_cache`
+ LEFT JOIN `rpc_alias_cache`
+ ON `rpc_alias_cache`.`rpc` = `rpc_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `rpc_cache`.`cn` = ? OR
+ `rpc_cache`.`cn` IN (
+ SELECT `by_alias`.`rpc`
+ FROM `rpc_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, ) # alias
class RpcRequest(common.Request):
diff --git a/pynslcd/service.py b/pynslcd/service.py
index 6f55cc1..c27f485 100644
--- a/pynslcd/service.py
+++ b/pynslcd/service.py
@@ -40,33 +40,6 @@ class Search(search.LDAPSearch):
required = ('cn', 'ipServicePort', 'ipServiceProtocol')
-class ServiceQuery(cache.CnAliasedQuery):
-
- sql = '''
- SELECT `service_cache`.*,
- `service_alias_cache`.`cn` AS `alias`
- FROM `service_cache`
- LEFT JOIN `service_alias_cache`
- ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort`
- AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
- '''
-
- cn_join = '''
- LEFT JOIN `service_alias_cache` `cn_alias`
- ON `cn_alias`.`ipServicePort` = `service_cache`.`ipServicePort`
- AND `cn_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
- '''
-
- def __init__(self, parameters):
- super(ServiceQuery, self).__init__('service', {})
- for k, v in parameters.items():
- if k == 'cn':
- self.add_query(self.cn_join)
- self.add_where('(`service_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)', [v, v])
- else:
- self.add_where('`service_cache`.`%s` = ?' % k, [v])
-
-
class Cache(cache.Cache):
tables = ('service_cache', 'service_alias_cache')
@@ -90,6 +63,34 @@ class Cache(cache.Cache):
CREATE INDEX IF NOT EXISTS `service_alias_idx2` ON `service_alias_cache`(`ipServiceProtocol`);
'''
+ retrieve_sql = '''
+ SELECT `service_cache`.`cn` AS `cn`,
+ `service_alias_cache`.`cn` AS `alias`,
+ `service_cache`.`ipServicePort`,
+ `service_cache`.`ipServiceProtocol`,
+ `mtime`
+ FROM `service_cache`
+ LEFT JOIN `service_alias_cache`
+ ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort`
+ AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `service_cache`.`cn` = ? OR
+ 0 < (
+ SELECT COUNT(*)
+ FROM `service_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?
+ AND `by_alias`.`ipServicePort` = `service_cache`.`ipServicePort`
+ AND `by_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
+ ))
+ ''',
+ )
+
+ group_by = (0, 2, 3) # cn, ipServicePort, ipServiceProtocol
+ group_columns = (1, ) # alias
+
def store(self, name, aliases, port, protocol):
self.con.execute('''
INSERT OR REPLACE INTO `service_cache`
@@ -107,11 +108,6 @@ class Cache(cache.Cache):
(?, ?, ?)
''', ((port, protocol, alias) for alias in aliases))
- def retrieve(self, parameters):
- query = ServiceQuery(parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', 'ipServicePort', 'ipServiceProtocol'), ('alias', )):
- yield row['cn'], row['alias'], row['ipServicePort'], row['ipServiceProtocol']
-
class ServiceRequest(common.Request):