[Numpy-svn] r3098 - trunk/numpy/oldnumeric
numpy-svn at scipy.org
numpy-svn at scipy.org
Tue Aug 29 14:29:58 EDT 2006
Author: oliphant
Date: 2006-08-29 13:29:53 -0500 (Tue, 29 Aug 2006)
New Revision: 3098
Added:
trunk/numpy/oldnumeric/fix_default_axis.py
Log:
Add a module/script to fix the default axis issue for code already converted to NumPy
Added: trunk/numpy/oldnumeric/fix_default_axis.py
===================================================================
--- trunk/numpy/oldnumeric/fix_default_axis.py 2006-08-29 17:56:21 UTC (rev 3097)
+++ trunk/numpy/oldnumeric/fix_default_axis.py 2006-08-29 18:29:53 UTC (rev 3098)
@@ -0,0 +1,292 @@
+"""
+This module adds the default axis argument to code which did not specify it
+for the functions where the default was changed in NumPy.
+
+The functions changed are
+
+add -1 ( all second argument)
+======
+nansum
+nanmax
+nanmin
+nanargmax
+nanargmin
+argmax
+argmin
+compress 3
+
+
+add 0
+======
+take 3
+repeat 3
+sum # might cause problems with builtin.
+product
+sometrue
+alltrue
+cumsum
+cumproduct
+average
+ptp
+cumprod
+prod
+std
+mean
+"""
+__all__ = ['convertfile', 'convertall', 'converttree',
+ 'convertfile2','convertall2', 'converttree2']
+
import sys
import os
import re
import glob
import string


# Functions whose axis is the third positional argument: a call may already
# pass two arguments (one top-level comma) before axis=... is appended.
_args3 = ['compress', 'take', 'repeat']
# Functions that get an explicit axis=-1.
_funcm1 = ['nansum', 'nanmax', 'nanmin', 'nanargmax', 'nanargmin',
           'argmax', 'argmin', 'compress']
# Functions that get an explicit axis=0.
_func0 = ['take', 'repeat', 'sum', 'product', 'sometrue', 'alltrue',
          'cumsum', 'cumproduct', 'average', 'ptp', 'cumprod', 'prod',
          'std', 'mean']

_all = _func0 + _funcm1

# function name -> compiled pattern matching "name(" (or "name  (").
func_re = {}
for _name in _all:
    func_re[_name] = re.compile(r"%s\s*[(]" % _name)
del _name

# Characters that may not immediately precede a matched function name: a
# match preceded by one of them is the tail of a longer identifier.
# ASCII on purpose -- string.uppercase/lowercase are locale-dependent in
# Python 2 and do not exist in Python 3.
disallowed = '_' + string.ascii_letters + string.digits

def _add_axis(fstr, name, repl):
    """Append ``,<repl>`` to the qualifying calls of *name* in the text *fstr*.

    A call qualifies when
      * the matched name is not the tail of a longer identifier,
      * it is unqualified or qualified by numpy, N, numarray, numerix or
        Numeric (any other ``obj.name(`` is treated as a method call),
      * it is not a ``def``/``class`` statement,
      * its closing parenthesis is found, and
      * it passes at least one argument but no more than the function took
        before the axis (one argument, or two for the names in _args3).

    Parentheses are balanced naively: brackets, strings and comments are
    not understood, so e.g. a comma inside ``[...]`` counts as an argument
    separator and such a call is conservatively skipped.

    Returns (newcode, alter): the rewritten text (*fstr* itself when nothing
    changed) and the number of calls that were modified.
    """
    if name in _args3:
        # axis is the third positional argument, so one comma is allowed
        allowed_comma = 1
    else:
        allowed_comma = 0
    length = len(fstr)
    closes = []     # index of the closing ')' of every call to rewrite
    for obj in func_re[name].finditer(fstr):
        start, end = obj.span()
        # A match at offset 0 has no predecessor; fstr[-1] would wrongly
        # look at the last character of the text.
        if start > 0:
            prev = fstr[start - 1]
        else:
            prev = '\n'
        if prev in disallowed:
            continue            # e.g. "cumsum(" when looking for "sum"
        if prev == '.':
            for module in ('numpy', 'N', 'numarray', 'numerix', 'Numeric'):
                if fstr.endswith(module, 0, start - 1):
                    break
            else:
                continue        # method call on some other object
        if prev in ' \t':
            # Skip definitions such as "def sum (x):" or "class sum (...)".
            k = start - 2
            while k >= 0 and fstr[k] in ' \t':
                k -= 1
            if (k >= 2 and fstr[k - 2:k + 1] == 'def') or \
               (k >= 4 and fstr[k - 4:k + 1] == 'class'):
                continue
        # Find the matching ')' while counting top-level commas.
        k = end
        depth = 1
        ncommas = 0
        while depth and k < length:
            char = fstr[k]
            if char == ')':
                depth -= 1
            elif char == '(':
                depth += 1
            elif char == ',' and depth == 1:
                ncommas += 1
                if ncommas > allowed_comma:
                    break       # axis (or more) already passed positionally
            k += 1
        if depth:
            continue            # too many arguments, or unterminated call
        close = k - 1
        if not fstr[end:close].strip():
            continue            # "name()" would become "name(,axis=...)"
        closes.append(close)
    if not closes:
        return fstr, 0
    # Nested calls of the same function report their closing parens out of
    # order (outer first); splicing in ascending order keeps every edit.
    closes.sort()
    pieces = []
    last = 0
    for close in closes:
        pieces.append(fstr[last:close])
        pieces.append(',' + repl)
        last = close
    pieces.append(fstr[last:])
    return ''.join(pieces), len(closes)

+def _import_change(fstr, names):
+ # Four possibilities
+ # 1.) import numpy with subsequent use of numpy.<name>
+ # change this to import numpy.oldnumeric as numpy
+ # 2.) import numpy as XXXX with subsequent use of
+ # XXXX.<name> ==> import numpy.oldnumeric as XXXX
+ # 3.) from numpy import *
+ # with subsequent use of one of the names
+ # 4.) from numpy import ..., <name>, ... (could span multiple
+ # lines. ==> remove all names from list and
+ # add from numpy.oldnumeric import <name>
+
+ num = 0
+ # case 1
+ importstr = "import numpy"
+ ind = fstr.find(importstr)
+ if (ind > 0):
+ found = 0
+ for name in names:
+ ind2 = fstr.find("numpy.%s" % name, ind)
+ if (ind2 > 0):
+ found = 1
+ break
+ if found:
+ fstr = "%s%s%s" % (fstr[:ind], "import numpy.oldnumeric as numpy",
+ fstr[ind+len(importstr):])
+ num += 1
+
+ # case 2
+ importre = re.compile("""import numpy as ([A-Za-z0-9_]+)""")
+ modules = importre.findall(fstr)
+ if len(modules) > 0:
+ for module in modules:
+ found = 0
+ for name in names:
+ ind2 = fstr.find("%s.%s" % (module, name))
+ if (ind2 > 0):
+ found = 1
+ break
+ if found:
+ importstr = "import numpy as %s" % module
+ ind = fstr.find(importstr)
+ fstr = "%s%s%s" % (fstr[:ind],
+ "import numpy.oldnumeric as %s" % module,
+ fstr[ind+len(importstr):])
+ num += 1
+
+ # case 3
+ importstr = "from numpy import *"
+ ind = fstr.find(importstr)
+ if (ind > 0):
+ found = 0
+ for name in names:
+ ind2 = fstr.find(name, ind)
+ if (ind2 > 0) and fstr[ind2-1] not in disallowed:
+ found = 1
+ break
+ if found:
+ fstr = "%s%s%s" % (fstr[:ind],
+ "from numpy.oldnumeric import *",
+ fstr[ind+len(importstr):])
+ num += 1
+
+ # case 4
+ ind = 0
+ importstr = "from numpy import"
+ N = len(importstr)
+ while 1:
+ ind = fstr.find(importstr, ind)
+ if (ind < 0):
+ break
+ ind += N
+ ptr = ind+1
+ stack = 1
+ while stack:
+ if fstr[ptr] == '\\':
+ stack += 1
+ elif fstr[ptr] == '\n':
+ stack -= 1
+ ptr += 1
+ substr = fstr[ind:ptr]
+ found = 0
+ substr = substr.replace('\n',' ')
+ substr = substr.replace('\\','')
+ importnames = [x.strip() for x in substr.split(',')]
+ # determine if any of names are in importnames
+ addnames = []
+ for name in names:
+ if name in importnames:
+ importnames.remove(name)
+ addnames.append(name)
+ if len(addnames) > 0:
+ fstr = "%s%s\n%s\n%s" % \
+ (fstr[:ind],
+ "from numpy import %s" % \
+ ", ".join(importnames),
+ "from numpy.oldnumeric import %s" % \
+ ", ".join(addnames),
+ fstr[ptr:])
+ num += 1
+
+ return fstr, num
+
def add_axis(fstr, import_change=False):
    """Make the default axis explicit in the source text *fstr*.

    By default every qualifying call of a function in _funcm1 gets
    ``axis=-1`` appended, then every qualifying call of a function in
    _func0 gets ``axis=0``.  With *import_change* true the calls are left
    untouched and the NumPy imports are redirected to numpy.oldnumeric
    instead.

    Returns (new_text, count), where count is the number of edits made.
    """
    if import_change:
        return _import_change(fstr, _funcm1 + _func0)
    total = 0
    for group, keyword in ((_funcm1, 'axis=-1'), (_func0, 'axis=0')):
        for func_name in group:
            fstr, count = _add_axis(fstr, func_name, keyword)
            total += count
    return fstr, total


def makenewfile(name, filestr):
    """Write *filestr* to the file *name*, replacing any existing content."""
    # open() instead of the Python 2-only file(); try/finally closes the
    # handle even when the write fails.
    fid = open(name, 'w')
    try:
        fid.write(filestr)
    finally:
        fid.close()

def getfile(name):
    """Return the entire contents of the file *name* as a string."""
    # open() instead of the Python 2-only file(); the handle is closed
    # even if read() raises.
    fid = open(name)
    try:
        return fid.read()
    finally:
        fid.close()

def copyfile(name, fstr):
    """Save *fstr*, the pre-conversion text of *name*, as a backup file.

    The backup name is *name* with its extension replaced by '.orig'
    (e.g. 'foo.py' -> 'foo.orig').
    """
    root = os.path.splitext(name)[0]
    makenewfile(root + '.orig', fstr)

def convertfile(filename, import_change=False):
    """Make the default axis explicit in the Python source file *filename*.

    The file is rewritten only when add_axis() reports at least one change.
    In that case a "Changing  <filename>" line is written to standard output
    and the original text is first saved by copyfile() (as <base>.orig,
    with the extension replaced).  With *import_change* true, the NumPy
    imports are redirected to numpy.oldnumeric instead of the calls being
    edited.
    """
    filestr = getfile(filename)
    newstr, total = add_axis(filestr, import_change)
    if total > 0:
        # Same text the Python 2 statement 'print "Changing ", filename'
        # produced (two spaces), but valid on Python 3 as well.
        sys.stdout.write("Changing  %s\n" % filename)
        copyfile(filename, filestr)
        makenewfile(filename, newstr)
    sys.stdout.flush()

def fromargs(args):
    """Command-line driver: convert every file named in args[1:].

    *args* is an argv-style list whose first item is the program name.
    When no file name is given, a usage line is written to standard error
    and 2 is returned; otherwise each file is passed to convertfile() and
    0 is returned.
    """
    if len(args) < 2:
        if args:
            prog = args[0]
        else:
            prog = 'fix_default_axis.py'
        sys.stderr.write('usage: %s file.py [file.py ...]\n' % prog)
        return 2
    for filename in args[1:]:
        convertfile(filename)
    return 0

def convertall(direc=os.path.curdir, import_change=False):
    """Run convertfile() on every *.py file directly inside *direc*.

    Subdirectories are not visited (see converttree).  convertfile()
    backs up each file it changes before rewriting it; *import_change*
    is passed through to it unchanged.
    """
    pattern = os.path.join(direc, '*.py')
    for path in glob.glob(pattern):
        convertfile(path, import_change)

def _func(arg, dirname, fnames):
    """Directory visitor: convert the .py files found in *dirname*.

    *arg* carries the import_change flag.  *fnames* is ignored because
    convertall() lists the directory itself.
    """
    flag = arg
    convertall(dirname, import_change=flag)

def converttree(direc=os.path.curdir, import_change=False):
    """Convert all .py files in *direc* and in every directory below it.

    Each directory is handed to convertall(); *import_change* is passed
    through unchanged.
    """
    # os.walk replaces os.path.walk, which Python 3 removed.
    for dirpath, dirnames, filenames in os.walk(direc):
        _func(import_change, dirpath, filenames)

if __name__ == '__main__':
    # Propagate fromargs' return value (if any) as the process exit status;
    # sys.exit(None) exits with status 0.
    sys.exit(fromargs(sys.argv))
More information about the Numpy-svn
mailing list