diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-11-28 00:35:04 -0800 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-11-28 00:59:58 -0800 |
commit | 73995123775920369251fd401f0fd72d1ad7808e (patch) | |
tree | 63a74580ef8713d51ada87bc42cbd25fef7b3948 | |
parent | f0f8d6e412c62643ef702c1a46352f0ef267a1a1 (diff) | |
download | numpy-73995123775920369251fd401f0fd72d1ad7808e.tar.gz |
ENH: Improve alignment of datetime64 arrays containing NaT
Makes them consistent with timedelta
Fixes #10102
-rw-r--r-- | numpy/core/arrayprint.py | 65 | ||||
-rw-r--r-- | numpy/core/tests/test_arrayprint.py | 24 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 2 |
3 files changed, 62 insertions, 29 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 8399a47b2..460661df7 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -326,7 +326,7 @@ def _get_formatdict(data, **opt): ComplexFloatingFormat(data, prec, fmode, supp, sign, legacy=legacy), 'longcomplexfloat': lambda: ComplexFloatingFormat(data, prec, fmode, supp, sign, legacy=legacy), - 'datetime': lambda: DatetimeFormat(data), + 'datetime': lambda: DatetimeFormat(data, legacy=legacy), 'timedelta': lambda: TimedeltaFormat(data), 'object': lambda: _object_format, 'void': lambda: str_format, @@ -1051,8 +1051,35 @@ class LongComplexFormat(ComplexFloatingFormat): super(LongComplexFormat, self).__init__(*args, **kwargs) -class DatetimeFormat(object): - def __init__(self, x, unit=None, timezone=None, casting='same_kind'): +class _TimelikeFormat(object): + def __init__(self, data): + non_nat = data[~isnat(data)] + if len(non_nat) > 0: + # Max str length of non-NaT elements + max_str_len = max(len(self._format_non_nat(np.max(non_nat))), + len(self._format_non_nat(np.min(non_nat)))) + else: + max_str_len = 0 + if len(non_nat) < data.size: + # data contains a NaT + max_str_len = max(max_str_len, 5) + self._format = '%{}s'.format(max_str_len) + self._nat = "'NaT'".rjust(max_str_len) + + def _format_non_nat(self, x): + # override in subclass + raise NotImplementedError + + def __call__(self, x): + if isnat(x): + return self._nat + else: + return self._format % self._format_non_nat(x) + + +class DatetimeFormat(_TimelikeFormat): + def __init__(self, x, unit=None, timezone=None, casting='same_kind', + legacy=False): # Get the unit from the dtype if unit is None: if x.dtype.kind == 'M': @@ -1065,34 +1092,26 @@ class DatetimeFormat(object): self.timezone = timezone self.unit = unit self.casting = casting + self.legacy = legacy + + # must be called after the above are configured + super(DatetimeFormat, self).__init__(x) def __call__(self, x): + if self.legacy == '1.13': + return self._format_non_nat(x) + return super(DatetimeFormat, self).__call__(x) + + def _format_non_nat(self, x): return "'%s'" % datetime_as_string(x, unit=self.unit, timezone=self.timezone, casting=self.casting) -class TimedeltaFormat(object): - def __init__(self, data): - non_nat = data[~isnat(data)] - if len(non_nat) > 0: - # Max str length of non-NaT elements - max_str_len = max(len(str(np.max(non_nat).astype('i8'))), - len(str(np.min(non_nat).astype('i8')))) - else: - max_str_len = 0 - if len(non_nat) < data.size: - # data contains a NaT - max_str_len = max(max_str_len, 5) - self.format = '%' + str(max_str_len) + 'd' - self._nat = "'NaT'".rjust(max_str_len) - - def __call__(self, x): - if isnat(x): - return self._nat - else: - return self.format % x.astype('i8') +class TimedeltaFormat(_TimelikeFormat): + def _format_non_nat(self, x): + return str(x.astype('i8')) class SubArrayFormat(object): diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py index 32c96221d..9719e8668 100644 --- a/numpy/core/tests/test_arrayprint.py +++ b/numpy/core/tests/test_arrayprint.py @@ -169,15 +169,29 @@ class TestArray2String(object): assert_equal(np.array2string(x), "[('Sarah', [8., 7.]) ('John', [6., 7.])]") - # for issue #5692 - A = np.zeros(shape=10, dtype=[("A", "M8[s]")]) - A[5:].fill(np.datetime64('NaT')) + np.set_printoptions(legacy='1.13') + try: + # for issue #5692 + A = np.zeros(shape=10, dtype=[("A", "M8[s]")]) + A[5:].fill(np.datetime64('NaT')) + assert_equal( + np.array2string(A), + textwrap.dedent("""\ + [('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) + ('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) ('NaT',) ('NaT',) + ('NaT',) ('NaT',) ('NaT',)]""") + ) + finally: + np.set_printoptions(legacy=False) + + # same again, but with non-legacy behavior assert_equal( np.array2string(A), textwrap.dedent("""\ [('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) - ('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) ('NaT',) ('NaT',) - ('NaT',) ('NaT',) ('NaT',)]""") + ('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) ( 'NaT',) + ( 'NaT',) ( 'NaT',) ( 'NaT',) + ( 'NaT',)]""") ) # and again, with timedeltas diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index dc84a039c..638994aee 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -565,7 +565,7 @@ class TestDateTime(object): # Check that one NaT doesn't corrupt subsequent entries a = np.array(['2010', 'NaT', '2030']).astype('M') - assert_equal(str(a), "['2010' 'NaT' '2030']") + assert_equal(str(a), "['2010' 'NaT' '2030']") def test_timedelta_array_str(self): a = np.array([-1, 0, 100], dtype='m') |