[Numpy-svn] r3416 - in trunk/numpy/core: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Sun Oct 29 02:38:29 EST 2006


Author: oliphant
Date: 2006-10-29 01:38:25 -0600 (Sun, 29 Oct 2006)
New Revision: 3416

Modified:
   trunk/numpy/core/records.py
   trunk/numpy/core/tests/test_records.py
Log:
Add a test for the recent fixes to recarray.__setattr__.  Replace sb.ndarray with the module-level alias ndarray.

Modified: trunk/numpy/core/records.py
===================================================================
--- trunk/numpy/core/records.py	2006-10-29 06:39:30 UTC (rev 3415)
+++ trunk/numpy/core/records.py	2006-10-29 07:38:25 UTC (rev 3416)
@@ -7,6 +7,8 @@
 import types
 import os
 
+ndarray = sb.ndarray
+
 _byteorderconv = {'b':'>',
                   'l':'<',
                   'n':'=',
@@ -164,7 +166,7 @@
 # If byteorder is given it forces a particular byteorder on all
 #  the fields (and any subfields)
 
-class recarray(sb.ndarray):
+class recarray(ndarray):
     def __new__(subtype, shape, dtype=None, buf=None, offset=0, strides=None,
                 formats=None, names=None, titles=None,
                 byteorder=None, aligned=False):
@@ -175,9 +177,9 @@
             descr = format_parser(formats, names, titles, aligned, byteorder)._descr
 
         if buf is None:
-            self = sb.ndarray.__new__(subtype, shape, (record, descr))
+            self = ndarray.__new__(subtype, shape, (record, descr))
         else:
-            self = sb.ndarray.__new__(subtype, shape, (record, descr),
+            self = ndarray.__new__(subtype, shape, (record, descr),
                                       buffer=buf, offset=offset,
                                       strides=strides)
         return self
@@ -187,7 +189,7 @@
             return object.__getattribute__(self,attr)
         except AttributeError: # attr must be a fieldname
             pass
-        fielddict = sb.ndarray.__getattribute__(self,'dtype').fields
+        fielddict = ndarray.__getattribute__(self,'dtype').fields
         try:
             res = fielddict[attr][:2]
         except (TypeError, KeyError):
@@ -199,7 +201,7 @@
             return obj
         if obj.dtype.char in 'SU':
             return obj.view(chararray)
-        return obj.view(sb.ndarray)
+        return obj.view(ndarray)
 
 # Save the dictionary
 #  If the attr is a field name and not in the saved dictionary
@@ -209,15 +211,22 @@
     def __setattr__(self, attr, val):
         newattr = attr not in self.__dict__
         try:
-            res = object.__setattr__(self, attr, val)
-        except AttributeError:
-            fielddict = sb.ndarray.__getattribute__(self,'dtype').fields
+            ret = object.__setattr__(self, attr, val)
+        except:
+            fielddict = ndarray.__getattribute__(self,'dtype').fields
+            if attr not in fielddict:
+                exctype, value = sys.exc_info()[:2]
+                raise exctype, value
         else:
-            fielddict = sb.ndarray.__getattribute__(self,'dtype').fields
+            fielddict = ndarray.__getattribute__(self,'dtype').fields
             if attr not in fielddict:
-                return res
+                return ret
             if newattr:         # We just added this one
-                object.__delattr__(self, attr)
+                try:            #  or this setattr worked on an internal
+                                #  attribute. 
+                    object.__delattr__(self, attr)
+                except:
+                    return ret
         try:
             res = fielddict[attr][:2]
         except (TypeError,KeyError):
@@ -225,17 +234,17 @@
         return self.setfield(val,*res)
 
     def __getitem__(self, indx):
-        obj = sb.ndarray.__getitem__(self, indx)
-        if (isinstance(obj, sb.ndarray) and obj.dtype.isbuiltin):
-            return obj.view(sb.ndarray)
+        obj = ndarray.__getitem__(self, indx)
+        if (isinstance(obj, ndarray) and obj.dtype.isbuiltin):
+            return obj.view(ndarray)
         return obj
 
     def field(self,attr, val=None):
         if isinstance(attr,int):
-            names = sb.ndarray.__getattribute__(self,'dtype').names
+            names = ndarray.__getattribute__(self,'dtype').names
             attr=names[attr]
 
-        fielddict = sb.ndarray.__getattribute__(self,'dtype').fields
+        fielddict = ndarray.__getattribute__(self,'dtype').fields
 
         res = fielddict[attr][:2]
 
@@ -245,20 +254,20 @@
                 return obj
             if obj.dtype.char in 'SU':
                 return obj.view(chararray)
-            return obj.view(sb.ndarray)
+            return obj.view(ndarray)
         else:
             return self.setfield(val, *res)
 
     def view(self, obj):
         try:
-            if issubclass(obj, sb.ndarray):
-                return sb.ndarray.view(self, obj)
+            if issubclass(obj, ndarray):
+                return ndarray.view(self, obj)
         except TypeError:
             pass
         dtype = sb.dtype(obj)
         if dtype.fields is None:
             return self.__array__().view(dtype)
-        return sb.ndarray.view(self, obj)            
+        return ndarray.view(self, obj)            
     
 def fromarrays(arrayList, dtype=None, shape=None, formats=None,
                names=None, titles=None, aligned=False, byteorder=None):
@@ -288,7 +297,7 @@
         # and determine the formats.
         formats = ''
         for obj in arrayList:
-            if not isinstance(obj, sb.ndarray):
+            if not isinstance(obj, ndarray):
                 raise ValueError, "item in the array list must be an ndarray."
             formats += _typestr[obj.dtype.type]
             if issubclass(obj.dtype.type, nt.flexible):
@@ -536,7 +545,7 @@
     elif isinstance(obj, file):
         return fromfile(obj, dtype=dtype, shape=shape, offset=offset)
 
-    elif isinstance(obj, sb.ndarray):
+    elif isinstance(obj, ndarray):
         if dtype is not None and (obj.dtype != dtype):
             new = obj.view(dtype)
         else:

Modified: trunk/numpy/core/tests/test_records.py
===================================================================
--- trunk/numpy/core/tests/test_records.py	2006-10-29 06:39:30 UTC (rev 3415)
+++ trunk/numpy/core/tests/test_records.py	2006-10-29 07:38:25 UTC (rev 3416)
@@ -70,5 +70,20 @@
         for k in xrange(len(ra)):
             assert ra[k].item() == pa[k].item()
 
+    def check_recarray_conflict_fields(self):
+        ra = rec.array([(1,'abc',2.3),(2,'xyz',4.2),
+                        (3,'wrs',1.3)],
+                       names='field, shape, mean')
+        ra.mean = [1.1,2.2,3.3]
+        assert_array_almost_equal(ra['mean'], [1.1,2.2,3.3])
+        assert type(ra.mean) is type(ra.var)
+        ra.shape = (1,3)
+        assert ra.shape == (1,3)
+        ra.shape = ['A','B','C']
+        assert_array_equal(ra['shape'], [['A','B','C']])
+        ra.field = 5
+        assert_array_equal(ra['field'], [[5,5,5]])
+        assert callable(ra.field)
+
 if __name__ == "__main__":
     NumpyTest().run()




More information about the Numpy-svn mailing list