From 8c480f86c83b45d8922ae2b7353874ed499f3df1 Mon Sep 17 00:00:00 2001 From: Arthur de Jong Date: Mon, 12 Dec 2011 22:59:00 +0000 Subject: define the search separately from the request git-svn-id: http://arthurdejong.org/svn/nss-pam-ldapd/nss-pam-ldapd@1571 ef36b2f9-881f-0410-afb5-c4e39611909c --- pynslcd/alias.py | 5 ++- pynslcd/common.py | 126 +++++++++++++++++++++++++++++++++------------------- pynslcd/ether.py | 14 +++--- pynslcd/group.py | 39 ++++++++++------ pynslcd/host.py | 5 ++- pynslcd/netgroup.py | 5 ++- pynslcd/network.py | 5 ++- pynslcd/passwd.py | 39 +++++----------- pynslcd/protocol.py | 5 ++- pynslcd/rpc.py | 5 ++- pynslcd/service.py | 5 ++- pynslcd/shadow.py | 5 ++- 12 files changed, 153 insertions(+), 105 deletions(-) (limited to 'pynslcd') diff --git a/pynslcd/alias.py b/pynslcd/alias.py index fe7bd7b..a19382e 100644 --- a/pynslcd/alias.py +++ b/pynslcd/alias.py @@ -28,12 +28,15 @@ attmap = common.Attributes(cn='cn', rfc822MailMember='rfc822MailMember') filter = '(objectClass=nisMailAlias)' -class AliasRequest(common.Request): +class Search(common.Search): case_insensitive = ('cn', ) limit_attributes = ('cn', ) required = ('cn', 'rfc822MailMember') + +class AliasRequest(common.Request): + def write(self, dn, attributes, parameters): # get values names = attributes['cn'] diff --git a/pynslcd/common.py b/pynslcd/common.py index 359b220..a44f708 100644 --- a/pynslcd/common.py +++ b/pynslcd/common.py @@ -54,12 +54,10 @@ def validate_name(name): raise ValueError('%r: invalid user name' % name) -class Request(object): +class Search(object): """ - Request handler class. Subclasses are expected to handle actual requests - and should implement the following members: - - action - the NSLCD_ACTION_* action that should trigger this handler + Class that performs a search. Subclasses are expected to define the actual + searches and should implement the following members: case_sensitive - check that these attributes are present in the response if they were in the request @@ -69,10 +67,7 @@ class Request(object): required - attributes that are required canonical_first - search the DN for these attributes and ensure that they are listed first in the attribute values - read_parameters() - a function that reads the request parameters of the - request stream mk_filter() (optional) - function that returns the LDAP search filter - write() - function that writes a single LDAP entry to the result stream The module that contains the Request class can also contain the following definitions: @@ -84,38 +79,61 @@ class Request(object): """ - def __init__(self, fp, conn, calleruid): - self.fp = fp - self.conn = conn - self.calleruid = calleruid + canonical_first = [] + required = [] + case_sensitive = [] + case_insensitive = [] + limit_attributes = [] + + def __init__(self, conn, base=None, scope=None, filter=None, attributes=None, + parameters=None): # load information from module that defines the class + self.conn = conn module = sys.modules[self.__module__] self.attmap = getattr(module, 'attmap', None) - self.filter = getattr(module, 'filter', None) - self.bases = getattr(module, 'bases', cfg.bases) - self.scope = getattr(module, 'scope', cfg.scope) + self.filter = filter or getattr(module, 'filter', None) + self.parameters = parameters or {} + if base: + self.bases = [base] + else: + self.bases = getattr(module, 'bases', cfg.bases) + self.scope = scope or getattr(module, 'scope', cfg.scope) + self.attributes = attributes or self.attmap.attributes() + + def __iter__(self): + return self() - def read_parameters(self, fp): - """This method should read the parameters from ths stream and - store them in self.""" - pass + def __call__(self): + # get search results + filter = self.mk_filter() + for base in self.bases: + # do the LDAP search + try: + for entry in self.conn.search_s(base, self.scope, filter, self.attributes): + if entry[0]: + entry = self.handle_entry(entry[0], entry[1]) + if entry: + yield entry + except ldap.NO_SUCH_OBJECT: + # FIXME: log message + pass - def mk_filter(self, parameters): + def mk_filter(self): """Return the active search filter (based on the read parameters).""" - if parameters: + if self.parameters: return '(&%s(%s))' % (self.filter, ')('.join('%s=%s' % (self.attmap[attribute], ldap.filter.escape_filter_chars(str(value))) - for attribute, value in parameters.items())) + for attribute, value in self.parameters.items())) return self.filter - def handle_entry(self, dn, attributes, parameters): + def handle_entry(self, dn, attributes): """Handle an entry with the specified attributes, filtering it with the request parameters where needed.""" # translate the attributes using the attribute mapping attributes = self.attmap.translate(attributes) # make sure value from DN is first value - for attr in getattr(self, 'canonical_first', []): + for attr in self.canonical_first: primary_value = get_rdn_value(dn, self.attmap[attr]) if primary_value: values = attributes[attr] @@ -123,44 +141,60 @@ class Request(object): values.remove(primary_value) attributes[attr] = [primary_value] + values # check that these attributes have at least one value - for attr in getattr(self, 'required', []): + for attr in self.required: if not attributes.get(attr, None): print '%s: attribute %s not found' % (dn, self.attmap[attr]) return # check that requested attribute is present (case sensitive) - for attr in getattr(self, 'case_sensitive', []): - value = parameters.get(attr, None) + for attr in self.case_sensitive: + value = self.parameters.get(attr, None) if value and str(value) not in attributes[attr]: print '%s: attribute %s does not contain %r value' % (dn, self.attmap[attr], value) return # not found, skip entry # check that requested attribute is present (case insensitive) - for attr in getattr(self, 'case_insensitive', []): - value = parameters.get(attr, None) + for attr in self.case_insensitive: + value = self.parameters.get(attr, None) if value and str(value).lower() not in (x.lower() for x in attributes[attr]): print '%s: attribute %s does not contain %r value' % (dn, self.attmap[attr], value) return # not found, skip entry # limit attribute values to requested value - for attr in getattr(self, 'limit_attributes', []): - if attr in parameters: - attributes[attr] = [parameters[attr]] - # write the result entry - self.write(dn, attributes, parameters) + for attr in self.limit_attributes: + if attr in self.parameters: + attributes[attr] = [self.parameters[attr]] + # return the entry + return dn, attributes + + +class Request(object): + """ + Request handler class. Subclasses are expected to handle actual requests + and should implement the following members: + + action - the NSLCD_ACTION_* action that should trigger this handler + + read_parameters() - a function that reads the request parameters of the + request stream + write() - function that writes a single LDAP entry to the result stream + + """ + + def __init__(self, fp, conn, calleruid): + self.fp = fp + self.conn = conn + self.calleruid = calleruid + module = sys.modules[self.__module__] + self.search = getattr(module, 'Search', None) + + def read_parameters(self, fp): + """This method should read the parameters from ths stream and + store them in self.""" + pass def handle_request(self, parameters): """This method handles the request based on the parameters read with read_parameters().""" - # get search results - for base in self.bases: - # do the LDAP search - try: - res = self.conn.search_s(base, self.scope, self.mk_filter(parameters), - self.attmap.attributes()) - for entry in res: - if entry[0]: - self.handle_entry(entry[0], entry[1], parameters) - except ldap.NO_SUCH_OBJECT: - # FIXME: log message - pass + for dn, attributes in self.search(conn=self.conn, parameters=parameters): + self.write(dn, attributes, parameters) # write the final result code self.fp.write_int32(constants.NSLCD_RESULT_END) diff --git a/pynslcd/ether.py b/pynslcd/ether.py index 05dea72..b26dcc7 100644 --- a/pynslcd/ether.py +++ b/pynslcd/ether.py @@ -38,23 +38,19 @@ attmap = common.Attributes(cn='cn', macAddress='macAddress') filter = '(objectClass=ieee802Device)' -class EtherRequest(common.Request): +class Search(common.Search): case_insensitive = ('cn', ) limit_attributes = ('cn', 'macAddress') required = ('cn', 'macAddress') + +class EtherRequest(common.Request): + def write(self, dn, attributes, parameters): - # get names + # get values names = attributes['cn'] - # get addresses and convert to binary form addresses = [ether_aton(x) for x in attributes['macAddress']] - if 'macAddress' in parameters: - address = ether_aton(parameters['macAddress']) - if address not in addresses: - print 'value %r for attribute %s not found in %s' % (parameters['macAddress'], attmap['macAddress'], dn) - return - addresses = ( address, ) # write results for name in names: for ether in addresses: diff --git a/pynslcd/group.py b/pynslcd/group.py index aacc44e..00a39eb 100644 --- a/pynslcd/group.py +++ b/pynslcd/group.py @@ -40,10 +40,33 @@ attmap = common.Attributes(cn='cn', filter = '(|(objectClass=posixGroup)(objectClass=groupOfNames))' -class GroupRequest(common.Request): +class Search(common.Search): case_sensitive = ('cn', ) limit_attributes = ('cn', 'gidNumber') + + def __init__(self, *args, **kwargs): + super(Search, self).__init__(*args, **kwargs) + if attmap['member'] and 'memberUid' in self.parameters: + # set up our own attributes that leave out membership attributes + self.attmap = common.Attributes(self.attmap) + del self.attmap['memberUid'] + del self.attmap['member'] + + def mk_filter(self): + # we still need a custom mk_filter because this is an | query + if attmap['member'] and 'memberUid' in self.parameters: + memberuid = self.parameters['memberUid'] + dn = uid2dn(self.conn, memberuid) + if dn: + return '(&%s(|(%s=%s)(%s=%s)))' % ( self.filter, + attmap['memberUid'], ldap.filter.escape_filter_chars(memberuid), + attmap['member'], ldap.filter.escape_filter_chars(dn) ) + return super(Search, self).mk_filter() + + +class GroupRequest(common.Request): + wantmembers = True def write(self, dn, attributes, parameters): @@ -96,6 +119,8 @@ class GroupByGidRequest(GroupRequest): return dict(gidNumber=fp.read_gid_t()) + + class GroupByMemberRequest(GroupRequest): action = constants.NSLCD_ACTION_GROUP_BYMEMBER @@ -113,18 +138,6 @@ class GroupByMemberRequest(GroupRequest): common.validate_name(memberuid) return dict(memberUid=memberuid) - def mk_filter(self, parameters): - # we still need a custom mk_filter because this is an | query - memberuid = parameters['memberUid'] - if attmap['member']: - dn = uid2dn(self.conn, memberuid) - if dn: - return '(&%s(|(%s=%s)(%s=%s)))' % ( self.filter, - attmap['memberUid'], ldap.filter.escape_filter_chars(memberuid), - attmap['member'], ldap.filter.escape_filter_chars(dn) ) - return '(&%s(%s=%s))' % ( self.filter, - attmap['memberUid'], ldap.filter.escape_filter_chars(memberuid) ) - class GroupAllRequest(GroupRequest): diff --git a/pynslcd/host.py b/pynslcd/host.py index c2f8074..e81d8dc 100644 --- a/pynslcd/host.py +++ b/pynslcd/host.py @@ -28,11 +28,14 @@ attmap = common.Attributes(cn='cn', ipHostNumber='ipHostNumber') filter = '(objectClass=ipHost)' -class HostRequest(common.Request): +class Search(common.Search): canonical_first = ('cn', ) required = ('cn', ) + +class HostRequest(common.Request): + def write(self, dn, attributes, parameters): # get values hostnames = attributes['cn'] diff --git a/pynslcd/netgroup.py b/pynslcd/netgroup.py index 1c74d2d..2b3a45f 100644 --- a/pynslcd/netgroup.py +++ b/pynslcd/netgroup.py @@ -34,11 +34,14 @@ attmap = common.Attributes(cn='cn', filter = '(objectClass=nisNetgroup)' -class NetgroupRequest(common.Request): +class Search(common.Search): case_sensitive = ('cn', ) required = ('cn', ) + +class NetgroupRequest(common.Request): + def write(self, dn, attributes, parameters): # write the netgroup triples for triple in attributes['nisNetgroupTriple']: diff --git a/pynslcd/network.py b/pynslcd/network.py index e5149c4..2887a61 100644 --- a/pynslcd/network.py +++ b/pynslcd/network.py @@ -29,11 +29,14 @@ attmap = common.Attributes(cn='cn', filter = '(objectClass=ipNetwork)' -class NetworkRequest(common.Request): +class Search(common.Search): canonical_first = ('cn', ) required = ('cn', ) + +class NetworkRequest(common.Request): + def write(self, dn, attributes, parameters): # get values networknames = attributes['cn'] diff --git a/pynslcd/passwd.py b/pynslcd/passwd.py index f575fac..f35be8b 100644 --- a/pynslcd/passwd.py +++ b/pynslcd/passwd.py @@ -37,13 +37,16 @@ filter = '(objectClass=posixAccount)' bases = ( 'ou=people,dc=test,dc=tld', ) -class PasswdRequest(common.Request): +class Search(common.Search): case_sensitive = ('uid', 'uidNumber', ) limit_attributes = ('uid', 'uidNumber', ) required = ('uid', 'uidNumber', 'gidNumber', 'gecos', 'homeDirectory', 'loginShell') + +class PasswdRequest(common.Request): + def write(self, dn, attributes, parameters): # get values names = attributes['uid'] @@ -95,32 +98,12 @@ class PasswdAllRequest(PasswdRequest): action = constants.NSLCD_ACTION_PASSWD_ALL -# FIXME: have something in common that does this -def do_search(conn, flt=None, base=None): - mybases = ( base, ) if base else bases - flt = flt or filter - import cfg - # perform a search for each search base - for base in mybases: - # do the LDAP search - try: - scope = locals().get('scope', cfg.scope) - res = conn.search_s(base, scope, flt, [attmap['uid']]) - for entry in res: - if entry[0]: - yield entry - except ldap.NO_SUCH_OBJECT: - # FIXME: log message - pass - def uid2entry(conn, uid): """Look up the user by uid and return the LDAP entry or None if the user was not found.""" - myfilter = '(&%s(%s=%s))' % ( filter, - attmap['uid'], ldap.filter.escape_filter_chars(uid) ) - for dn, attributes in do_search(conn, myfilter): - if uid in attributes[attmap['uid']]: - return dn, attributes + for dn, attributes in Search(conn, parameters=dict(uid=uid)): + return dn, attributes + def uid2dn(conn, uid): """Look up the user by uid and return the DN or None if the user was @@ -131,11 +114,9 @@ def uid2dn(conn, uid): # FIXME: use cache of dn2uid and try to use DN to get uid attribute + def dn2uid(conn, dn): """Look up the user by dn and return a uid or None if the user was not found.""" - try: - for dn, attributes in do_search(conn, base=dn): - return attributes[attmap['uid']][0] - except ldap.NO_SUCH_OBJECT: - return None + for dn, attributes in Search(conn, base=dn): + return attributes['uid'][0] diff --git a/pynslcd/protocol.py b/pynslcd/protocol.py index d7587ce..0f358cb 100644 --- a/pynslcd/protocol.py +++ b/pynslcd/protocol.py @@ -28,12 +28,15 @@ attmap = common.Attributes(cn='cn', ipProtocolNumber='ipProtocolNumber') filter = '(objectClass=ipProtocol)' -class ProtocolRequest(common.Request): +class Search(common.Search): case_sensitive = ('cn', ) canonical_first = ('cn', ) required = ('cn', 'ipProtocolNumber') + +class ProtocolRequest(common.Request): + def write(self, dn, attributes, parameters): # get values names = attributes['cn'] diff --git a/pynslcd/rpc.py b/pynslcd/rpc.py index 2a7e434..2c7aa85 100644 --- a/pynslcd/rpc.py +++ b/pynslcd/rpc.py @@ -28,12 +28,15 @@ attmap = common.Attributes(cn='cn', oncRpcNumber='oncRpcNumber') filter = '(objectClass=oncRpc)' -class RpcRequest(common.Request): +class Search(common.Search): case_sensitive = ('cn', ) canonical_first = ('cn', ) required = ('cn', 'oncRpcNumber') + +class RpcRequest(common.Request): + def write(self, dn, attributes, parameters): # get values names = attributes['cn'] diff --git a/pynslcd/service.py b/pynslcd/service.py index f0cb9b6..6923236 100644 --- a/pynslcd/service.py +++ b/pynslcd/service.py @@ -31,13 +31,16 @@ attmap = common.Attributes(cn='cn', filter = '(objectClass=ipService)' -class ServiceRequest(common.Request): +class Search(common.Search): case_sensitive = ('cn', 'ipServiceProtocol') limit_attributes = ('ipServiceProtocol', ) canonical_first = ('cn', ) required = ('cn', 'ipServicePort', 'ipServiceProtocol') + +class ServiceRequest(common.Request): + def write(self, dn, attributes, parameters): # get values names = attributes['cn'] diff --git a/pynslcd/shadow.py b/pynslcd/shadow.py index 73b8fea..e8e5f52 100644 --- a/pynslcd/shadow.py +++ b/pynslcd/shadow.py @@ -37,12 +37,15 @@ filter = '(objectClass=shadowAccount)' bases = ( 'ou=people,dc=test,dc=tld', ) -class ShadowRequest(common.Request): +class Search(common.Search): case_sensitive = ('uid', ) limit_attributes = ('uid', ) required = ('uid', ) + +class ShadowRequest(common.Request): + def write(self, dn, attributes, parameters): # get name and check against requested name names = attributes['uid'] -- cgit v1.2.3