On 04/02/2015 11:54 AM, Petr Viktorin wrote:
On 03/31/2015 12:11 PM, Petr Vobornik wrote:
The only difference is the lack of support for utf-8 encoded str (as
input). I don't know how important that support is.

I don't think that support is too important (assuming IPA doesn't use
it!). However, the behavior with this patch is dangerous.
It allows unicode and ASCII strings, but fails on non-ASCII strings.
That means things will usually work, but as soon as a non-ASCII
component is introduced at the wrong place, you get an error.

Restoring support for utf-8 encoded str looks easy to do; here's a patch
you can squash in. Or did I miss something?

I also had to fix creation of AVAs to support utf-8 encoded str as input for attr and value (separately).
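
A minimal sketch of the normalization idea, written for Python 3, where
`bytes` plays the role of the Python 2 utf-8 encoded `str`. The helper name
`normalize_ava_input` is illustrative only; the patch below has its own
Python 2 version:

```python
def normalize_ava_input(val):
    """Return val as utf-8 encoded bytes.

    bytes are assumed to be utf-8 already and pass through unchanged;
    text is encoded; anything else is converted to text first.
    """
    if isinstance(val, bytes):
        return val
    if not isinstance(val, str):
        val = str(val)
    return val.encode('utf-8')


attr = normalize_ava_input('cn')
value = normalize_ava_input('Jos\u00e9')  # non-ASCII text input
same = normalize_ava_input(value)         # already-encoded input
```

Accepting both forms for attr and value separately avoids the case where
things work with ASCII and fail only once a non-ASCII component shows up.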


Maybe it could be attached to the ticket:
https://fedorahosted.org/freeipa/ticket/4947
-----
DN code was optimized to be faster if DNs are created from string. This
is the major use case, since most DNs come from LDAP.

With this patch, DN creation is about 8-10x faster (with 30K-100K DNs).

Second major use case: deepcopy in LDAPEntry is about 20x faster, thanks
to a custom __deepcopy__ function.

The major change is that DN is no longer internally composed of RDNs
and AVAs; instead it keeps the data in OpenLDAP format, the same as the
output of the str2dn function. Therefore, for immutable DNs, no other
transformations are required on instantiation.

The format is:

DN: [RDN, RDN,...]
RDN: [AVA, AVA,...]
AVA: ['utf-8 encoded str - attr', 'utf-8 encoded str - value', FLAG]
FLAG: int
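
As an illustration, the nested structure might look like this for a
three-RDN DN. This is a hand-built sketch in Python 3 (`bytes` stands in
for the utf-8 encoded str), and the flag value 0 is only a placeholder:

```python
# One DN: a list of RDNs; each RDN is a list of AVAs;
# each AVA is [attr, value, flag] with utf-8 encoded attr and value.
FLAG = 0  # placeholder flag value, an assumption for this sketch

dn = [
    [[b'cn', b'Bob', FLAG], [b'ou', b'people', FLAG]],  # multi-valued RDN
    [[b'dc', b'example', FLAG]],
    [[b'dc', b'com', FLAG]],
]

first_rdn = dn[0]
first_ava = first_rdn[0]
attr = first_ava[0].decode('utf-8')   # decode only when text is needed
value = first_ava[1].decode('utf-8')
```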

Indexing a DN object constructs an RDN, which is just a wrapper around
the RDN part of the OpenLDAP representation. Indexing an RDN constructs
an AVA in the same fashion.

An EditableAVA or EditableRDN obtained from an EditableDN shares the
respective lists of the OpenLDAP representation, so that a change of
value or attr is reflected in the parent object.


Looks good. A couple of comments:

RDN.to_openldap: _avas always has 3 components, right? I'd prefer
`list(a)` over `[a[0], a[1], a[2]]`. Similarly for tuple in __add__
and RDN._avas_from_sequence.

Fixed


DN._rdns_from_value: the error message at the end is wrong, RDN is also
accepted. (And, `type(value)` would be more informative than
`value.__class__.__name__`.)

Fixed


You can optimize __deepcopy__ for immutable DNs even further: just
return self!

Fixed, but kept part for EditableDN
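
Roughly, the resulting behavior can be sketched like this in Python 3.
`SketchDN` and `SketchEditableDN` are hypothetical stand-ins, not the real
classes:

```python
import copy


class SketchDN:
    """Hypothetical stand-in for an immutable DN."""
    is_mutable = False

    def __init__(self, rdns):
        self.rdns = rdns

    def __deepcopy__(self, memo):
        if not self.is_mutable:
            return self  # nothing can change, so sharing is safe
        # Editable objects need their own nested lists.
        clone = self.__class__.__new__(self.__class__)
        clone.rdns = [[list(ava) for ava in rdn] for rdn in self.rdns]
        return clone


class SketchEditableDN(SketchDN):
    """Hypothetical stand-in for an editable DN."""
    is_mutable = True


frozen = SketchDN([[(b'cn', b'Bob', 0)]])
editable = SketchEditableDN([[[b'cn', b'Bob', 0]]])
frozen_copy = copy.deepcopy(frozen)
editable_copy = copy.deepcopy(editable)
editable_copy.rdns[0][0][1] = b'Alice'  # must not leak into the original
```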


In DN.find & rfind, RDNs are not accepted but the error message says
they are.

messages fixed


You removed the newline at end of file.


Newline re-added.
--
Petr Vobornik
From 6289202ca5c5d24c1b07754d19a292e66cfd5df2 Mon Sep 17 00:00:00 2001
From: Petr Vobornik <pvobo...@redhat.com>
Date: Wed, 25 Mar 2015 13:39:43 +0100
Subject: [PATCH] performance: faster DN implementation

DN code was optimized to be faster if DNs are created from string. This is
the major use case, since most DNs come from LDAP.

With this patch, DN creation is about 8-10x faster (with 30K-100K DNs).

Second major use case: deepcopy in LDAPEntry is about 20x faster, thanks to
a custom __deepcopy__ function.

The major change is that DN is no longer internally composed of RDNs and
AVAs; instead it keeps the data in OpenLDAP format, the same as the output
of the str2dn function. Therefore, for immutable DNs, no other
transformations are required on instantiation.

The format is:

DN: [RDN, RDN,...]
RDN: [AVA, AVA,...]
AVA: ['utf-8 encoded str - attr', 'utf-8 encoded str - value', FLAG]
FLAG: int

Indexing a DN object constructs an RDN, which is just a wrapper around the
RDN part of the OpenLDAP representation. Indexing an RDN constructs an AVA
in the same fashion.

An EditableAVA or EditableRDN obtained from an EditableDN shares the
respective lists of the OpenLDAP representation, so that a change of value
or attr is reflected in the parent object.
---
 ipapython/dn.py                    | 595 ++++++++++++++++++-------------------
 ipatests/test_ipapython/test_dn.py |  17 +-
 2 files changed, 306 insertions(+), 306 deletions(-)

diff --git a/ipapython/dn.py b/ipapython/dn.py
index 834291fbe8696622162efa5193622d74f11f25ca..5b6570770d587937c87380f7ea19e999c3d8867d 100644
--- a/ipapython/dn.py
+++ b/ipapython/dn.py
@@ -497,6 +497,97 @@ def _adjust_indices(start, end, length):
 
     return start, end
 
+
+def _normalize_ava_input(val):
+    if not isinstance(val, basestring):
+        val = unicode(val).encode('utf-8')
+    elif isinstance(val, unicode):
+        val = val.encode('utf-8')
+    return val
+
+
+def str2rdn(value):
+    try:
+        rdns = str2dn(value.encode('utf-8'))
+    except DECODING_ERROR:
+        raise ValueError("malformed AVA string = \"%s\"" % value)
+    if len(rdns) != 1:
+        raise ValueError("multiple RDN's specified by \"%s\"" % (value))
+    return rdns[0]
+
+
+def get_ava(*args, **kwds):
+    """
+    Get AVA from args in open ldap format(raw). Optimized for construction
+    from openldap format.
+
+    Allowed formats of argument list:
+    1) three args - open ldap format (attr and value have to be utf-8 encoded):
+        a) ['attr', 'value', 0]
+    2) two args:
+        a) ['attr', 'value']
+    3) one arg:
+        a) [('attr', 'value')]
+        b) [['attr', 'value']]
+        c) [AVA(..)]
+        d) ['attr=value']
+    """
+    ava = None
+    l = len(args)
+    if l == 3:  # raw values - constructed FROM RDN
+        if kwds.get('mutable', False):
+            ava = args
+        else:
+            ava = (args[0], args[1], args[2])
+    elif l == 2:  # user defined values
+        ava = [_normalize_ava_input(args[0]), _normalize_ava_input(args[1]), 0]
+    elif l == 1:  # slow mode, tuple, string,
+        arg = args[0]
+        if isinstance(arg, AVA):
+            ava = arg.to_openldap()
+        elif isinstance(arg, (tuple, list)):
+            if len(arg) != 2:
+                raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (arg))
+            ava = [_normalize_ava_input(arg[0]), _normalize_ava_input(arg[1]), 0]
+        elif isinstance(arg, basestring):
+            rdn = str2rdn(arg)
+            if len(rdn) > 1:
+                raise TypeError("multiple AVA's specified by \"%s\"" % (arg))
+            ava = list(rdn[0])
+        else:
+            raise TypeError("with 1 argument, argument must be str, unicode, tuple or list, got %s instead" %
+                            arg.__class__.__name__)
+    else:
+        raise TypeError("invalid number of arguments. 1-3 allowed")
+    return ava
+
+
+def sort_avas(rdn):
+    if len(rdn) <= 1:
+        return
+    rdn.sort(cmp=cmp_avas)
+
+
+def cmp_avas(a, b):
+    r = cmp(a[0].lower(), b[0].lower())
+    if r == 0:
+        r = cmp(a[1].lower(), b[1].lower())
+    return r
+
+
+def cmp_rdns(a, b):
+
+    l = len(a)
+    r = cmp(l, len(b))
+    if r != 0:
+        return r
+
+    for i, ava_a in enumerate(a):
+        r = cmp_avas(ava_a, b[i])
+        if r != 0:
+            return r
+    return 0
+
 class AVA(object):
     '''
     AVA(arg0, ...)
@@ -552,100 +643,51 @@ class AVA(object):
     syntax with proper escaping.
     '''
     is_mutable = False
-    flags = 0
 
     def __init__(self, *args, **kwds):
-        if len(args) == 1:
-            arg = args[0]
-            if isinstance(arg, AVA):
-                ava = (arg.attr, arg.value)
-            elif isinstance(arg, basestring):
-                try:
-                    rdns = str2dn(arg.encode('utf-8'))
-                except DECODING_ERROR:
-                    raise ValueError("malformed AVA string = \"%s\"" % arg)
-                if len(rdns) != 1:
-                    raise ValueError("multiple RDN's specified by \"%s\"" % (arg))
-                rdn = rdns[0]
-                if len(rdn) != 1:
-                    raise ValueError("multiple AVA's specified by \"%s\"" % (arg))
-                ava = rdn[0]
-            elif isinstance(arg, (tuple, list)):
-                ava = arg
-                if len(ava) != 2:
-                    raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (ava))
-            else:
-                raise TypeError("with 1 argument, argument must be str,unicode,tuple or list, got %s instead" % \
-                                arg.__class__.__name__)
-
-            attr  = ava[0]
-            value = ava[1]
-        elif len(args) == 2:
-            attr  = args[0]
-            value = args[1]
-        else:
-            raise TypeError("takes 1 or 2 arguments (%d given)" % (len(args)))
-
-        self._set_attr(attr)
-        self._set_value(value)
+        self._ava = get_ava(*args, **{'mutable': self.is_mutable})
 
     def _get_attr(self):
-        return self._attr_unicode
+        return self._ava[0].decode('utf-8')
 
     def _set_attr(self, new_attr):
-        # Scalars only
-        if isinstance(new_attr, (tuple, list)):
-            raise TypeError("attr must be scalar, got %s" % type(new_attr))
-
         try:
-            if isinstance(new_attr, unicode):
-                self._attr_unicode = new_attr
-            elif isinstance(new_attr, str):
-                self._attr_unicode = new_attr.decode('utf-8')
-            else:
-                self._attr_unicode = unicode(new_attr)
+            self._ava[0] = _normalize_ava_input(new_attr)
         except Exception, e:
-            raise ValueError('unable to convert attr "%s" to unicode: %s' % (new_attr, e))
+            raise ValueError('unable to convert attr "%s": %s' % (new_attr, e))
 
-    attr  = property(_get_attr)
+    attr = property(_get_attr)
 
     def _get_value(self):
-        return self._value_unicode
+        return self._ava[1].decode('utf-8')
 
     def _set_value(self, new_value):
-        # Scalars only
-        if isinstance(new_value, (tuple, list)):
-            raise TypeError("value must be scalar, got %s" % type(new_value))
-
         try:
-            if isinstance(new_value, unicode):
-                self._value_unicode  = new_value
-            elif isinstance(new_value, str):
-                self._value_unicode  = new_value.decode('utf-8')
-            else:
-                self._value_unicode  = unicode(new_value)
+            self._ava[1] = _normalize_ava_input(new_value)
         except Exception, e:
-            raise ValueError('unable to convert value "%s" to unicode: %s' % (new_value, e))
+            raise ValueError('unable to convert value "%s": %s' % (new_value, e))
 
     value = property(_get_value)
 
-    def _to_openldap(self):
-        return [[(self._attr_unicode.encode('utf-8'), self._value_unicode.encode('utf-8'), self.flags)]]
+    def to_openldap(self):
+        return list(self._ava)
 
     def __str__(self):
-        return dn2str(self._to_openldap())
+        return dn2str([[self.to_openldap()]])
 
     def __repr__(self):
         return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__())
 
     def __getitem__(self, key):
-        if isinstance(key, basestring):
-            if key == self._attr_unicode:
-                return self._value_unicode
+
+        if key == 0:
+            return self.attr
+        elif key == 1:
+            return self.value
+        elif key == self.attr:
+            return self.value
+        else:
             raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
-        else:
-            raise TypeError("unsupported type for AVA indexing, must be basestring; not %s" % \
-                                (key.__class__.__name__))
 
     def __hash__(self):
         # Hash is computed from AVA's string representation because it's immutable.
@@ -682,8 +724,7 @@ class AVA(object):
             return False
 
         # Perform comparison between objects of same type
-        return self._attr_unicode.lower() == other.attr.lower() and \
-            self._value_unicode.lower() == other.value.lower()
+        return cmp_avas(self._ava, other._ava) == 0
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -694,11 +735,7 @@ class AVA(object):
         if not isinstance(other, AVA):
             raise TypeError("expected AVA but got %s" % (other.__class__.__name__))
 
-        result = cmp(self._attr_unicode.lower(), other.attr.lower())
-        if result != 0:
-            return result
-        result = cmp(self._value_unicode.lower(), other.value.lower())
-        return result
+        return cmp_avas(self._ava, other._ava)
 
 class EditableAVA(AVA):
     '''
@@ -826,113 +863,96 @@ class RDN(object):
     '''
 
     is_mutable = False
-    flags = 0
     AVA_type = AVA
 
     def __init__(self, *args, **kwds):
-        self.avas = self._avas_from_sequence(args)
-        self.avas.sort()
+        self._avas = self._avas_from_sequence(args, kwds.get('raw', False))
 
-    def _ava_from_value(self, value):
-        if isinstance(value, AVA):
-            return self.AVA_type(value.attr, value.value)
-        elif isinstance(value, RDN):
-            avas = []
-            for ava in value.avas:
-                avas.append(self.AVA_type(ava.attr, ava.value))
-            if len(avas) == 1:
-                return avas[0]
-            else:
-                return avas
-        elif isinstance(value, basestring):
-            try:
-                rdns = str2dn(value.encode('utf-8'))
-                if len(rdns) != 1:
-                    raise ValueError("multiple RDN's specified by \"%s\"" % (value))
-                rdn = rdns[0]
-                if len(rdn) == 1:
-                    return self.AVA_type(rdn[0][0], rdn[0][1])
-                else:
-                    avas = []
-                    for ava_tuple in rdn:
-                        avas.append(self.AVA_type(ava_tuple[0], ava_tuple[1]))
-                    return avas
-            except DECODING_ERROR:
-                raise ValueError("malformed RDN string = \"%s\"" % value)
-        elif isinstance(value, (tuple, list)):
-            if len(value) != 2:
-                raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (value))
-            return self.AVA_type(value)
-        else:
-            raise TypeError("must be str,unicode,tuple, or AVA, got %s instead" % \
-                            value.__class__.__name__)
-
-
-    def _avas_from_sequence(self, seq):
+    def _avas_from_sequence(self, args, raw=False):
         avas = []
+        sort = 0
+        ava_count = len(args)
 
-        for item in seq:
-            ava = self._ava_from_value(item)
-            if isinstance(ava, list):
-                avas.extend(ava)
-            else:
-                avas.append(ava)
+        if raw:  # fast raw mode
+            try:
+                if self.is_mutable:
+                    avas = args
+                else:
+                    for arg in args:
+                        avas.append((arg[0], arg[1], arg[2]))
+            except (IndexError, KeyError, TypeError):
+                raise TypeError('all AVA values in RAW mode must be in open ldap format')
+        elif ava_count == 1 and isinstance(args[0], basestring):
+            avas = str2rdn(args[0])
+            sort = 1
+        elif ava_count == 1 and isinstance(args[0], RDN):
+            avas = args[0].to_openldap()
+        elif ava_count > 0:
+            sort = 1
+            for arg in args:
+                avas.append(get_ava(arg))
+        if sort:
+            sort_avas(avas)
         return avas
 
-    def _to_openldap(self):
-        return [[(ava.attr.encode('utf-8'), ava.value.encode('utf-8'), self.flags) for ava in self.avas]]
+    def to_openldap(self):
+        return [list(a) for a in self._avas]
 
     def __str__(self):
-        return dn2str(self._to_openldap())
+        return dn2str([self.to_openldap()])
 
     def __repr__(self):
         return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__())
 
+    def _get_ava(self, ava):
+        return self.AVA_type(*ava)
+
     def _next(self):
-        for ava in self.avas:
-            yield ava
+        for ava in self._avas:
+            yield self._get_ava(ava)
 
     def __iter__(self):
         return self._next()
 
     def __len__(self):
-        return len(self.avas)
+        return len(self._avas)
 
     def __getitem__(self, key):
-        if isinstance(key, (int, long, slice)):
-            return self.avas[key]
+        if isinstance(key, (int, long)):
+            return self._get_ava(self._avas[key])
+        if isinstance(key, slice):
+            return [self._get_ava(ava) for ava in self._avas[key]]
         elif isinstance(key, basestring):
-            for ava in self.avas:
-                if key == ava.attr:
-                    return ava.value
+            for ava in self._avas:
+                if key == ava[0].decode('utf-8'):
+                    return ava[1].decode('utf-8')
             raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
         else:
             raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
                                 (key.__class__.__name__))
 
     def _get_attr(self):
-        if len(self.avas) == 0:
+        if len(self._avas) == 0:
             raise IndexError("No AVA's in this RDN")
-        return self.avas[0].attr
+        return self._avas[0][0].decode('utf-8')
 
     def _set_attr(self, new_attr):
-        if len(self.avas) == 0:
+        if len(self._avas) == 0:
             raise IndexError("No AVA's in this RDN")
 
-        self.avas[0].attr = new_attr
+        self._avas[0][0] = unicode(new_attr).encode('utf-8')
 
     attr  = property(_get_attr)
 
     def _get_value(self):
-        if len(self.avas) == 0:
+        if len(self._avas) == 0:
             raise IndexError("No AVA's in this RDN")
-        return self.avas[0].value
+        return self._avas[0][1].decode('utf-8')
 
     def _set_value(self, new_value):
-        if len(self.avas) == 0:
+        if len(self._avas) == 0:
             raise IndexError("No AVA's in this RDN")
-
-        self.avas[0].value = new_value
+        self._avas[0][1] = unicode(new_value).encode('utf-8')
 
     value = property(_get_value)
 
@@ -959,7 +979,7 @@ class RDN(object):
             return False
 
         # Perform comparison between objects of same type
-        return self.avas == other.avas
+        return cmp_rdns(self._avas, other._avas) == 0
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -968,32 +988,23 @@ class RDN(object):
         if not isinstance(other, RDN):
             raise TypeError("expected RDN but got %s" % (other.__class__.__name__))
 
-        result = cmp(len(self), len(other))
-        if result != 0:
-            return result
-        i = 0
-        while i < len(self):
-            result = cmp(self[i], other[i])
-            if result != 0:
-                return result
-            i += 1
-        return 0
+        return cmp_rdns(self._avas, other._avas)
 
     def __add__(self, other):
         result = self.__class__(self)
         if isinstance(other, RDN):
-            for ava in other.avas:
-                result.avas.append(self.AVA_type(ava.attr, ava.value))
+            for ava in other._avas:
+                result._avas.append((ava[0], ava[1], ava[2]))
         elif isinstance(other, AVA):
-            result.avas.append(self.AVA_type(other.attr, other.value))
+            result._avas.append(other.to_openldap())
         elif isinstance(other, basestring):
             rdn = self.__class__(other)
-            for ava in rdn.avas:
-                result.avas.append(self.AVA_type(ava.attr, ava.value))
+            for ava in rdn._avas:
+                result._avas.append((ava[0], ava[1], ava[2]))
         else:
             raise TypeError("expected RDN, AVA or basestring but got %s" % (other.__class__.__name__))
 
-        result.avas.sort()
+        sort_avas(result._avas)
         return result
 
 class EditableRDN(RDN):
@@ -1016,24 +1027,22 @@ class EditableRDN(RDN):
     AVA_type = EditableAVA
 
     def __setitem__(self, key, value):
+
         if isinstance(key, (int, long)):
-            new_ava = self._ava_from_value(value)
-            if isinstance(new_ava, list):
-                raise TypeError("cannot assign multiple AVA's to single entry")
-            self.avas[key] = new_ava
+            self._avas[key] = get_ava(value)
         elif isinstance(key, slice):
             avas = self._avas_from_sequence(value)
-            self.avas[key] = avas
+            self._avas[key] = avas
         elif isinstance(key, basestring):
-            new_ava = self._ava_from_value(value)
-            if isinstance(new_ava, list):
+            if isinstance(value, list):
                 raise TypeError("cannot assign multiple AVA's to single entry")
+            new_ava = get_ava(value)
             found = False
             i = 0
-            while i < len(self.avas):
-                if key == self.avas[i].attr:
+            while i < len(self._avas):
+                if key == self._avas[i][0].decode('utf-8'):
                     found = True
-                    self.avas[i] = new_ava
+                    self._avas[i] = new_ava
                     break
                 i += 1
             if not found:
@@ -1041,7 +1050,7 @@ class EditableRDN(RDN):
         else:
             raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
                                 (key.__class__.__name__))
-        self.avas.sort()
+        sort_avas(self._avas)
 
     attr  = property(RDN._get_attr, RDN._set_attr)
     value = property(RDN._get_value, RDN._set_value)
@@ -1051,18 +1060,15 @@ class EditableRDN(RDN):
         # If __iadd__ is not available Python will emulate += by
         # replacing the lhs object with the result of __add__ (if available).
         if isinstance(other, RDN):
-            for ava in other.avas:
-                self.avas.append(self.AVA_type(ava.attr, ava.value))
+            self._avas.extend(other.to_openldap())
         elif isinstance(other, AVA):
-            self.avas.append(self.AVA_type(other.attr, other.value))
+            self._avas.append(other.to_openldap())
         elif isinstance(other, basestring):
-            rdn = self.__class__(other)
-            for ava in rdn.avas:
-                self.avas.append(self.AVA_type(ava.attr, ava.value))
+            self._avas.extend(self._avas_from_sequence([other]))
         else:
             raise TypeError("expected RDN, AVA or basestring but got %s" % (other.__class__.__name__))
 
-        self.avas.sort()
+        sort_avas(self._avas)
         return self
 
 class DN(object):
@@ -1213,72 +1219,74 @@ class DN(object):
     '''
 
     is_mutable = False
-    flags = 0
     AVA_type = AVA
     RDN_type = RDN
 
     def __init__(self, *args, **kwds):
         self.rdns = self._rdns_from_sequence(args)
 
-    def _rdn_from_value(self, value):
-        if isinstance(value, RDN):
-            return self.RDN_type(value)
+    def _copy_rdns(self, rdns=None):
+        if not rdns:
+            rdns = self.rdns
+        return [[list(a) for a in rdn] for rdn in rdns]
+
+    def _rdns_from_value(self, value):
+        if isinstance(value, basestring):
+            try:
+                if isinstance(value, unicode):
+                    value = value.encode('utf-8')
+                rdns = str2dn(value)
+                if self.is_mutable:
+                    rdns = self._copy_rdns(rdns)  # AVAs to be list instead of tuple
+            except DECODING_ERROR:
+                raise ValueError("malformed RDN string = \"%s\"" % value)
+            for rdn in rdns:
+                sort_avas(rdn)
         elif isinstance(value, DN):
-            rdns = []
-            for rdn in value.rdns:
-                rdns.append(self.RDN_type(rdn))
-            if len(rdns) == 1:
-                return rdns[0]
-            else:
-                return rdns
-        elif isinstance(value, basestring):
-            rdns = []
-            try:
-                dn_list = str2dn(value.encode('utf-8'))
-                for rdn_list in dn_list:
-                    avas = []
-                    for ava_tuple in rdn_list:
-                        avas.append(self.AVA_type(ava_tuple[0], ava_tuple[1]))
-                    rdn = self.RDN_type(*avas)
-                    rdns.append(rdn)
-            except DECODING_ERROR:
-                raise ValueError("malformed RDN string = \"%s\"" % value)
-            if len(rdns) == 1:
-                return rdns[0]
-            else:
-                return rdns
-        elif isinstance(value, (tuple, list)):
-            if len(value) != 2:
-                raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (value))
-            rdn = self.RDN_type(value)
-            return rdn
+            rdns = value._copy_rdns()
+        elif isinstance(value, (tuple, list, AVA)):
+            ava = get_ava(value)
+            rdns = [[ava]]
+        elif isinstance(value, RDN):
+            rdns = [value.to_openldap()]
         else:
-            raise TypeError("must be str,unicode,tuple, or RDN, got %s instead" % \
-                            value.__class__.__name__)
+            raise TypeError("must be str, unicode, tuple, or RDN or DN, got %s instead" %
+                            type(value))
+        return rdns
 
     def _rdns_from_sequence(self, seq):
         rdns = []
 
         for item in seq:
-            rdn = self._rdn_from_value(item)
-            if isinstance(rdn, list):
-                rdns.extend(rdn)
-            else:
-                rdns.append(rdn)
+            rdn = self._rdns_from_value(item)
+            rdns.extend(rdn)
         return rdns
 
-    def _to_openldap(self):
-        return [[(ava.attr.encode('utf-8'), ava.value.encode('utf-8'), self.flags) for ava in rdn] for rdn in self.rdns]
+    def __deepcopy__(self, memo):
+        if self.is_mutable:
+            cls = self.__class__
+            clone = cls.__new__(cls)
+            clone.rdns = self._copy_rdns()
+            return clone
+        return self
+
+    def _get_rdn(self, rdn):
+        return self.RDN_type(*rdn, **{'raw': True})
 
     def __str__(self):
-        return dn2str(self._to_openldap())
+        try:
+            return dn2str(self.rdns)
+        except Exception, e:
+            print len(self.rdns)
+            print self.rdns
+            raise
 
     def __repr__(self):
         return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__())
 
     def _next(self):
         for rdn in self.rdns:
-            yield rdn
+            yield self._get_rdn(rdn)
 
     def __iter__(self):
         return self._next()
@@ -1287,12 +1295,20 @@ class DN(object):
         return len(self.rdns)
 
     def __getitem__(self, key):
-        if isinstance(key, (int, long, slice)):
-            return self.rdns[key]
+        if isinstance(key, (int, long)):
+            return self._get_rdn(self.rdns[key])
+        if isinstance(key, slice):
+            cls = self.__class__
+            new_dn = cls.__new__(cls)
+            new_dn.rdns = self.rdns[key]
+            if self.is_mutable:
+                new_dn.rdns = self._copy_rdns(new_dn.rdns)
+            return new_dn
         elif isinstance(key, basestring):
             for rdn in self.rdns:
-                if key == rdn.attr:
-                    return rdn.value
+                for ava in rdn:
+                    if key == ava[0].decode('utf-8'):
+                        return ava[1].decode('utf-8')
             raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
         else:
             raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \
@@ -1305,11 +1321,16 @@ class DN(object):
         # hash value between two objects which compare as equal but
         # differ in case must yield the same hash value.
 
-        return hash(str(self).lower())
+        str_dn = ';,'.join([
+            '++'.join(
+                ['=='.join((atype, avalue or '')) for atype,avalue,dummy in rdn]
+            ) for rdn in self.rdns
+        ])
+        return hash(str_dn.lower())
 
     def __eq__(self, other):
-        # Try coercing string to DN, if successful compare to coerced object
-        if isinstance(other, basestring):
+        # Try coercing to DN, if successful compare to coerced object
+        if isinstance(other, (basestring, RDN, AVA)):
             try:
                 other_dn = DN(other)
                 return self.__eq__(other_dn)
@@ -1321,7 +1342,7 @@ class DN(object):
             return False
 
         # Perform comparison between objects of same type
-        return self.rdns == other.rdns
+        return self.__cmp__(other) == 0
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -1336,31 +1357,21 @@ class DN(object):
         return self._cmp_sequence(other, 0, len(self))
 
     def _cmp_sequence(self, pattern, self_start, pat_len):
+
         self_idx = self_start
+        self_len = len(self)
         pat_idx = 0
+        #  and self_idx < self_len
         while pat_idx < pat_len:
-            result = cmp(self[self_idx], pattern[pat_idx])
-            if result != 0:
-                return result
+            r = cmp_rdns(self.rdns[self_idx], pattern.rdns[pat_idx])
+            if r != 0:
+                return r
             self_idx += 1
             pat_idx += 1
         return 0
 
     def __add__(self, other):
-        result = self.__class__(self)
-        if isinstance(other, DN):
-            for rdn in other.rdns:
-                result.rdns.append(self.RDN_type(rdn))
-        elif isinstance(other, RDN):
-            result.rdns.append(self.RDN_type(other))
-        elif isinstance(other, basestring):
-            dn = self.__class__(other)
-            for rdn in dn.rdns:
-                result.rdns.append(rdn)
-        else:
-            raise TypeError("expected DN, RDN or basestring but got %s" % (other.__class__.__name__))
-
-        return result
+        return self.__class__(self, other)
 
     # The implementation of startswith, endswith, tailmatch, adjust_indices
     # was based on the Python's stringobject.c implementation
@@ -1402,10 +1413,10 @@ class DN(object):
         arguments. Returns 0 if not found and 1 if found.
         '''
 
+        if isinstance(pattern, RDN):
+            pattern = DN(pattern)
         if isinstance(pattern, DN):
             pat_len = len(pattern)
-        elif isinstance(pattern, RDN):
-            pat_len = 1
         else:
             raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__))
 
@@ -1423,16 +1434,16 @@ class DN(object):
             if end-pat_len >= start:
                 start = end - pat_len
 
-        if isinstance(pattern, DN):
-            if end-start >= pat_len:
-                return not self._cmp_sequence(pattern, start, pat_len)
-            return 0
-        else:
-            return self.rdns[start] == pattern
+        if end-start >= pat_len:
+            return not self._cmp_sequence(pattern, start, pat_len)
+        return 0
+
 
     def __contains__(self, other):
         'Return the outcome of the test other in self. Note the reversed operands.'
 
+        if isinstance(other, RDN):
+            other = DN(other)
         if isinstance(other, DN):
             other_len = len(other)
             end = len(self) - other_len
@@ -1443,16 +1454,13 @@ class DN(object):
                     return True
                 i += 1
             return False
-
-        elif isinstance(other, RDN):
-            return other in self.rdns
         else:
             raise TypeError("expected DN or RDN but got %s" % (other.__class__.__name__))
 
 
     def find(self, pattern, start=None, end=None):
         '''
-        Return the lowest index in the DN where pattern DN (or RDN) is found,
+        Return the lowest index in the DN where pattern DN is found,
         such that pattern is contained in the range [start, end]. Optional
         arguments start and end are interpreted as in slice notation. Return
         -1 if pattern is not found.
@@ -1460,10 +1468,8 @@ class DN(object):
 
         if isinstance(pattern, DN):
             pat_len = len(pattern)
-        elif isinstance(pattern, RDN):
-            pat_len = 1
         else:
-            raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__))
+            raise TypeError("expected DN but got %s" % (pattern.__class__.__name__))
 
         self_len = len(self)
 
@@ -1476,19 +1482,14 @@ class DN(object):
 
         i = start
         stop = max(start, end - pat_len)
-        if isinstance(pattern, DN):
-            while i <= stop:
-                result = self._cmp_sequence(pattern, i, pat_len)
-                if result == 0:
-                    return i
-                i += 1
-            return -1
-        else:
-            while i <= stop:
-                if self.rdns[i] == pattern:
-                    return i
-                i += 1
-            return -1
+
+        while i <= stop:
+            result = self._cmp_sequence(pattern, i, pat_len)
+            if result == 0:
+                return i
+            i += 1
+        return -1
+
 
     def index(self, pattern, start=None, end=None):
         '''
@@ -1502,7 +1503,7 @@ class DN(object):
 
     def rfind(self, pattern, start=None, end=None):
         '''
-        Return the highest index in the DN where pattern DN (or RDN) is found,
+        Return the highest index in the DN where pattern DN is found,
         such that pattern is contained in the range [start, end]. Optional
         arguments start and end are interpreted as in slice notation. Return
         -1 if pattern is not found.
@@ -1510,10 +1511,8 @@ class DN(object):
 
         if isinstance(pattern, DN):
             pat_len = len(pattern)
-        elif isinstance(pattern, RDN):
-            pat_len = 1
         else:
-            raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__))
+            raise TypeError("expected DN but got %s" % (pattern.__class__.__name__))
 
         self_len = len(self)
 
@@ -1526,19 +1525,13 @@ class DN(object):
 
         i = max(start, min(end, self_len - pat_len))
         stop = start
-        if isinstance(pattern, DN):
-            while i >= stop:
-                result = self._cmp_sequence(pattern, i, pat_len)
-                if result == 0:
-                    return i
-                i -= 1
-            return -1
-        else:
-            while i >= stop:
-                if self.rdns[i] == pattern:
-                    return i
-                i -= 1
-            return -1
+
+        while i >= stop:
+            result = self._cmp_sequence(pattern, i, pat_len)
+            if result == 0:
+                return i
+            i -= 1
+        return -1
 
     def rindex(self, pattern, start=None, end=None):
         '''
@@ -1573,23 +1566,23 @@ class EditableDN(DN):
 
     def __setitem__(self, key, value):
         if isinstance(key, (int, long)):
-            new_rdn = self._rdn_from_value(value)
-            if isinstance(new_rdn, list):
+            new_rdns = self._rdns_from_value(value)
+            if len(new_rdns) > 1:
                 raise TypeError("cannot assign multiple RDN's to single entry")
-            self.rdns[key] = new_rdn
+            self.rdns[key] = new_rdns[0]
         elif isinstance(key, slice):
             rdns = self._rdns_from_sequence(value)
             self.rdns[key] = rdns
         elif isinstance(key, basestring):
-            new_rdn = self._rdn_from_value(value)
-            if isinstance(new_rdn, list):
+            new_rdns = self._rdns_from_value(value)
+            if len(new_rdns) > 1:
                 raise TypeError("cannot assign multiple values to single entry")
             found = False
             i = 0
             while i < len(self.rdns):
-                if key == self.rdns[i].attr:
+                if key == self.rdns[i][0][0].decode('utf-8'):
                     found = True
-                    self.rdns[i] = new_rdn
+                    self.rdns[i] = new_rdns[0]
                     break
                 i += 1
             if not found:
@@ -1602,10 +1595,9 @@ class EditableDN(DN):
         # If __iadd__ is not available Python will emulate += by
         # replacing the lhs object with the result of __add__ (if available).
         if isinstance(other, DN):
-            for rdn in other.rdns:
-                self.rdns.append(self.RDN_type(rdn))
+            self.rdns.extend(other._copy_rdns())
         elif isinstance(other, RDN):
-            self.rdns.append(self.RDN_type(other))
+            self.rdns.append(other.to_openldap())
         elif isinstance(other, basestring):
             dn = self.__class__(other)
             self.__iadd__(dn)
@@ -1627,7 +1619,11 @@ class EditableDN(DN):
         for slice indices.
         '''
 
-        self.rdns.insert(i, self._rdn_from_value(x))
+        rdns = self._rdns_from_value(x)
+        if len(rdns) > 1:
+            raise TypeError("cannot assign multiple RDN's to single entry")
+
+        self.rdns.insert(i, rdns[0])
 
     def replace(self, old, new, count=sys.maxsize):
         '''
@@ -1656,3 +1652,4 @@ class EditableDN(DN):
             start = index + pat_len
 
         return n_replaced
+
diff --git a/ipatests/test_ipapython/test_dn.py b/ipatests/test_ipapython/test_dn.py
index 60802b70c988dd864de04e63be56fb4111213c85..77e35eb4795453eb386a36e98d4e5be8f4a75ba9 100644
--- a/ipatests/test_ipapython/test_dn.py
+++ b/ipatests/test_ipapython/test_dn.py
@@ -129,9 +129,9 @@ class TestAVA(unittest.TestCase):
             with self.assertRaises(TypeError):
                 AVA_class()
 
-            # Create with more than 2 args should fail
+            # Create with more than 3 args should fail
             with self.assertRaises(TypeError):
-                AVA_class(self.attr1, self.value1, self.attr1)
+                AVA_class(self.attr1, self.value1, self.attr1, self.attr1)
 
             # Create with 1 arg which is not string should fail
             with self.assertRaises(TypeError):
@@ -164,11 +164,14 @@ class TestAVA(unittest.TestCase):
 
             self.assertEqual(ava1[self.attr1], self.value1)
 
+            self.assertEqual(ava1[0], self.attr1)
+            self.assertEqual(ava1[1], self.value1)
+
             with self.assertRaises(KeyError):
                 ava1['foo']
 
-            with self.assertRaises(TypeError):
-                ava1[0]
+            with self.assertRaises(KeyError):
+                ava1[3]
 
     def test_properties(self):
         for AVA_class in (AVA, EditableAVA):
@@ -1413,7 +1416,7 @@ class TestDN(unittest.TestCase):
             dn3 = DN_class(self.dn3)
 
             self.assertEqual(len(dn1), 1)
-            self.assertEqual(dn1[:], [self.rdn1])
+            self.assertEqual(dn1[:], self.rdn1)
             for i, ava in enumerate(dn1):
                 if i == 0:
                     self.assertEqual(ava, self.rdn1)
@@ -1421,7 +1424,7 @@ class TestDN(unittest.TestCase):
                     self.fail("got iteration index %d, but len=%d" % (i, len(self.rdn1)))
 
             self.assertEqual(len(dn2), 1)
-            self.assertEqual(dn2[:], [self.rdn2])
+            self.assertEqual(dn2[:], self.rdn2)
             for i, ava in enumerate(dn2):
                 if i == 0:
                     self.assertEqual(ava, self.rdn2)
@@ -1429,7 +1432,7 @@ class TestDN(unittest.TestCase):
                     self.fail("got iteration index %d, but len=%d" % (i, len(self.rdn2)))
 
             self.assertEqual(len(dn3), 2)
-            self.assertEqual(dn3[:], [self.rdn1, self.rdn2])
+            self.assertEqual(dn3[:], DN_class(self.rdn1, self.rdn2))
             for i, ava in enumerate(dn3):
                 if i == 0:
                     self.assertEqual(ava, self.rdn1)
-- 
2.1.0
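
For readers following the startswith and __contains__ hunks above: the
patch wraps an RDN pattern in a one-element DN, so only the sequence
comparison path remains. Below is a minimal sketch of that idea. It
assumes hypothetical stand-in classes (SketchRDN, SketchDN), not the
real ipapython.dn code, and uses a placeholder flag value of 1 in the
[attr, value, FLAG] layout.

```python
class SketchRDN(object):
    """Hypothetical stand-in for RDN: a single AVA in [attr, value, flag] form."""

    def __init__(self, attr, value):
        # The flag value 1 is a placeholder chosen for this sketch.
        self.avas = [[attr, value, 1]]

    def to_openldap(self):
        # Return a copy so callers cannot mutate our internal list.
        return [list(ava) for ava in self.avas]


class SketchDN(object):
    """Hypothetical stand-in for DN: a list of RDNs in the list-based format."""

    def __init__(self, *rdns):
        self.rdns = [rdn.to_openldap() for rdn in rdns]

    def __len__(self):
        return len(self.rdns)

    def __contains__(self, other):
        # Normalize an RDN to a one-element DN, as the patch does,
        # so a single comparison loop handles both cases.
        if isinstance(other, SketchRDN):
            other = SketchDN(other)
        if not isinstance(other, SketchDN):
            raise TypeError("expected DN or RDN but got %s"
                            % other.__class__.__name__)
        other_len = len(other)
        # Slide a window of other_len RDNs across self and compare.
        for i in range(len(self) - other_len + 1):
            if self.rdns[i:i + other_len] == other.rdns:
                return True
        return False
```

With this shape, `SketchRDN('dc', 'example') in dn` and
`SketchDN(SketchRDN('dc', 'example'), SketchRDN('dc', 'com')) in dn`
go through the same loop; any other operand type raises TypeError.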

-- 
Manage your subscription for the Freeipa-devel mailing list:
https://www.redhat.com/mailman/listinfo/freeipa-devel
Contribute to FreeIPA: http://www.freeipa.org/page/Contribute/Code
