[pypy-commit] pypy py3k: port datetime arg handling cleanups from 72e79a8305c7 and 43e61ecb2e40

bdkearns noreply at buildbot.pypy.org
Tue Mar 5 05:15:06 CET 2013


Author: Brian Kearns <bdkearns at gmail.com>
Branch: py3k
Changeset: r62049:5956d289d64f
Date: 2013-03-04 23:04 -0500
http://bitbucket.org/pypy/pypy/changeset/5956d289d64f/

Log:	port datetime arg handling cleanups from 72e79a8305c7 and
	43e61ecb2e40

diff --git a/lib-python/3/datetime.py b/lib-python/3/datetime.py
--- a/lib-python/3/datetime.py
+++ b/lib-python/3/datetime.py
@@ -267,9 +267,23 @@
                          " -timedelta(hours=24) and timedelta(hours=24)"
                          % (name, offset))
 
+def _check_int_field(value):  # return value as an int or raise TypeError
+    if isinstance(value, int):
+        return value
+    if not isinstance(value, float):  # floats, incl. subclasses, are rejected
+        # Look __int__ up on the type (as int() does) and call it outside any
+        # try block, so exceptions raised by a user's __int__ propagate as-is.
+        to_int = getattr(type(value), '__int__', None)
+        if to_int is not None:
+            value = to_int(value)
+            if isinstance(value, int):
+                return value
+    raise TypeError('an integer is required')
+
 def _check_date_fields(year, month, day):
-    if not isinstance(year, int):
-        raise TypeError('int expected')
+    year = _check_int_field(year)
+    month = _check_int_field(month)
+    day = _check_int_field(day)
     if not MINYEAR <= year <= MAXYEAR:
         raise ValueError('year must be in %d..%d' % (MINYEAR, MAXYEAR), year)
     if not 1 <= month <= 12:
@@ -277,10 +291,13 @@
     dim = _days_in_month(year, month)
     if not 1 <= day <= dim:
         raise ValueError('day must be in 1..%d' % dim, day)
+    return year, month, day
 
 def _check_time_fields(hour, minute, second, microsecond):
-    if not isinstance(hour, int):
-        raise TypeError('int expected')
+    hour = _check_int_field(hour)
+    minute = _check_int_field(minute)
+    second = _check_int_field(second)
+    microsecond = _check_int_field(microsecond)
     if not 0 <= hour <= 23:
         raise ValueError('hour must be in 0..23', hour)
     if not 0 <= minute <= 59:
@@ -289,6 +306,7 @@
         raise ValueError('second must be in 0..59', second)
     if not 0 <= microsecond <= 999999:
         raise ValueError('microsecond must be in 0..999999', microsecond)
+    return hour, minute, second, microsecond
 
 def _check_tzinfo_arg(tz):
     if tz is not None and not isinstance(tz, tzinfo):
@@ -674,7 +692,7 @@
             self = object.__new__(cls)
             self.__setstate(year)
             return self
-        _check_date_fields(year, month, day)
+        year, month, day = _check_date_fields(year, month, day)
         self = object.__new__(cls)
         self._year = year
         self._month = month
@@ -797,7 +815,7 @@
             month = self._month
         if day is None:
             day = self._day
-        _check_date_fields(year, month, day)
+        year, month, day = _check_date_fields(year, month, day)
         return date(year, month, day)
 
     # Comparisons of date objects with other.
@@ -1033,8 +1051,8 @@
             self = object.__new__(cls)
             self.__setstate(hour, minute or None)
             return self
+        hour, minute, second, microsecond = _check_time_fields(hour, minute, second, microsecond)
         _check_tzinfo_arg(tzinfo)
-        _check_time_fields(hour, minute, second, microsecond)
         self = object.__new__(cls)
         self._hour = hour
         self._minute = minute
@@ -1263,7 +1281,7 @@
             microsecond = self.microsecond
         if tzinfo is True:
             tzinfo = self.tzinfo
-        _check_time_fields(hour, minute, second, microsecond)
+        hour, minute, second, microsecond = _check_time_fields(hour, minute, second, microsecond)
         _check_tzinfo_arg(tzinfo)
         return time(hour, minute, second, microsecond, tzinfo)
 
@@ -1320,8 +1338,8 @@
             self = date.__new__(cls, year[:4])
             self.__setstate(year, month)
             return self
-        _check_date_fields(year, month, day)
-        _check_time_fields(hour, minute, second, microsecond)
+        year, month, day = _check_date_fields(year, month, day)
+        hour, minute, second, microsecond = _check_time_fields(hour, minute, second, microsecond)
         _check_tzinfo_arg(tzinfo)
         self = object.__new__(cls)
         self._year = year
@@ -1487,8 +1505,8 @@
             microsecond = self.microsecond
         if tzinfo is True:
             tzinfo = self.tzinfo
-        _check_date_fields(year, month, day)
-        _check_time_fields(hour, minute, second, microsecond)
+        year, month, day = _check_date_fields(year, month, day)
+        hour, minute, second, microsecond = _check_time_fields(hour, minute, second, microsecond)
         _check_tzinfo_arg(tzinfo)
         return datetime(year, month, day, hour, minute, second,
                           microsecond, tzinfo)
diff --git a/lib-python/3/test/datetimetester.py b/lib-python/3/test/datetimetester.py
--- a/lib-python/3/test/datetimetester.py
+++ b/lib-python/3/test/datetimetester.py
@@ -3688,6 +3688,50 @@
         a = datetime.timedelta()
         with self.assertRaises(AttributeError): a.abc = 1
 
+    def test_check_arg_types(self):
+        """Fields accept ints and objects with __int__, but never floats."""
+        import decimal
+        class Number:
+            def __init__(self, value):
+                self.value = value
+            def __int__(self):
+                return self.value
+        d10 = decimal.Decimal(10)
+        d11 = decimal.Decimal('10.9')  # Decimal.__int__ truncates this to 10
+        c10 = Number(10)
+        expected = datetime.datetime(10, 10, 10, 10, 10, 10, 10)
+        self.assertEqual(datetime.datetime(d10, d10, d10, d10, d10, d10, d10), expected)
+        self.assertEqual(datetime.datetime(d11, d11, d11, d11, d11, d11, d11), expected)
+        self.assertEqual(datetime.datetime(c10, c10, c10, c10, c10, c10, c10), expected)
+
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, '10')
+
+        f10 = Number(10.9)  # an __int__ that returns a float is rejected
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, f10)
+
+        class Float(float):
+            pass
+        s10 = Float(10.9)  # float subclasses are rejected as well
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, s10)
+
+        with self.assertRaises(TypeError):
+            datetime.datetime(10., 10, 10)
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10., 10)
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, 10.)
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, 10, 10.)
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, 10, 10, 10.)
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, 10, 10, 10, 10.)
+        with self.assertRaises(TypeError):
+            datetime.datetime(10, 10, 10, 10, 10, 10, 10.)
+
 def test_main():
     support.run_unittest(__name__)
 


More information about the pypy-commit mailing list