@@ -580,6 +580,8 @@ def to_datetime(self, dayfirst=False):
580
580
return DatetimeIndex (self .values )
581
581
582
582
def _assert_can_do_setop (self , other ):
583
+ if not com .is_list_like (other ):
584
+ raise TypeError ('Input must be Index or array-like' )
583
585
return True
584
586
585
587
@property
@@ -1364,16 +1366,14 @@ def union(self, other):
1364
1366
-------
1365
1367
union : Index
1366
1368
"""
1367
- if not hasattr (other , '__iter__' ):
1368
- raise TypeError ( 'Input must be iterable.' )
1369
+ self . _assert_can_do_setop (other )
1370
+ other = _ensure_index ( other )
1369
1371
1370
1372
if len (other ) == 0 or self .equals (other ):
1371
1373
return self
1372
1374
1373
1375
if len (self ) == 0 :
1374
- return _ensure_index (other )
1375
-
1376
- self ._assert_can_do_setop (other )
1376
+ return other
1377
1377
1378
1378
if not is_dtype_equal (self .dtype ,other .dtype ):
1379
1379
this = self .astype ('O' )
@@ -1439,11 +1439,7 @@ def intersection(self, other):
1439
1439
-------
1440
1440
intersection : Index
1441
1441
"""
1442
- if not hasattr (other , '__iter__' ):
1443
- raise TypeError ('Input must be iterable!' )
1444
-
1445
1442
self ._assert_can_do_setop (other )
1446
-
1447
1443
other = _ensure_index (other )
1448
1444
1449
1445
if self .equals (other ):
@@ -1492,9 +1488,7 @@ def difference(self, other):
1492
1488
1493
1489
>>> index.difference(index2)
1494
1490
"""
1495
-
1496
- if not hasattr (other , '__iter__' ):
1497
- raise TypeError ('Input must be iterable!' )
1491
+ self ._assert_can_do_setop (other )
1498
1492
1499
1493
if self .equals (other ):
1500
1494
return Index ([], name = self .name )
@@ -1517,7 +1511,7 @@ def sym_diff(self, other, result_name=None):
1517
1511
Parameters
1518
1512
----------
1519
1513
1520
- other : array-like
1514
+ other : Index or array-like
1521
1515
result_name : str
1522
1516
1523
1517
Returns
@@ -1545,9 +1539,7 @@ def sym_diff(self, other, result_name=None):
1545
1539
>>> idx1 ^ idx2
1546
1540
Int64Index([1, 5], dtype='int64')
1547
1541
"""
1548
- if not hasattr (other , '__iter__' ):
1549
- raise TypeError ('Input must be iterable!' )
1550
-
1542
+ self ._assert_can_do_setop (other )
1551
1543
if not isinstance (other , Index ):
1552
1544
other = Index (other )
1553
1545
result_name = result_name or self .name
@@ -5460,12 +5452,11 @@ def union(self, other):
5460
5452
>>> index.union(index2)
5461
5453
"""
5462
5454
self ._assert_can_do_setop (other )
5455
+ other , result_names = self ._convert_can_do_setop (other )
5463
5456
5464
5457
if len (other ) == 0 or self .equals (other ):
5465
5458
return self
5466
5459
5467
- result_names = self .names if self .names == other .names else None
5468
-
5469
5460
uniq_tuples = lib .fast_unique_multiple ([self .values , other .values ])
5470
5461
return MultiIndex .from_arrays (lzip (* uniq_tuples ), sortorder = 0 ,
5471
5462
names = result_names )
@@ -5483,12 +5474,11 @@ def intersection(self, other):
5483
5474
Index
5484
5475
"""
5485
5476
self ._assert_can_do_setop (other )
5477
+ other , result_names = self ._convert_can_do_setop (other )
5486
5478
5487
5479
if self .equals (other ):
5488
5480
return self
5489
5481
5490
- result_names = self .names if self .names == other .names else None
5491
-
5492
5482
self_tuples = self .values
5493
5483
other_tuples = other .values
5494
5484
uniq_tuples = sorted (set (self_tuples ) & set (other_tuples ))
@@ -5509,18 +5499,10 @@ def difference(self, other):
5509
5499
diff : MultiIndex
5510
5500
"""
5511
5501
self ._assert_can_do_setop (other )
5502
+ other , result_names = self ._convert_can_do_setop (other )
5512
5503
5513
- if not isinstance (other , MultiIndex ):
5514
- if len (other ) == 0 :
5504
+ if len (other ) == 0 :
5515
5505
return self
5516
- try :
5517
- other = MultiIndex .from_tuples (other )
5518
- except :
5519
- raise TypeError ('other must be a MultiIndex or a list of'
5520
- ' tuples' )
5521
- result_names = self .names
5522
- else :
5523
- result_names = self .names if self .names == other .names else None
5524
5506
5525
5507
if self .equals (other ):
5526
5508
return MultiIndex (levels = [[]] * self .nlevels ,
@@ -5537,15 +5519,29 @@ def difference(self, other):
5537
5519
return MultiIndex .from_tuples (difference , sortorder = 0 ,
5538
5520
names = result_names )
5539
5521
5540
- def _assert_can_do_setop (self , other ):
5541
- pass
5542
-
5543
5522
def astype (self , dtype ):
5544
5523
if not is_object_dtype (np .dtype (dtype )):
5545
5524
raise TypeError ('Setting %s dtype to anything other than object '
5546
5525
'is not supported' % self .__class__ )
5547
5526
return self ._shallow_copy ()
5548
5527
5528
+ def _convert_can_do_setop (self , other ):
5529
+ if not isinstance (other , MultiIndex ):
5530
+ if len (other ) == 0 :
5531
+ other = MultiIndex (levels = [[]] * self .nlevels ,
5532
+ labels = [[]] * self .nlevels ,
5533
+ verify_integrity = False )
5534
+ else :
5535
+ msg = 'other must be a MultiIndex or a list of tuples'
5536
+ try :
5537
+ other = MultiIndex .from_tuples (other )
5538
+ except :
5539
+ raise TypeError (msg )
5540
+ result_names = self .names
5541
+ else :
5542
+ result_names = self .names if self .names == other .names else None
5543
+ return other , result_names
5544
+
5549
5545
def insert (self , loc , item ):
5550
5546
"""
5551
5547
Make new MultiIndex inserting new item at location
0 commit comments