Skip to content

Commit a0e90ad

Browse files
authored
Fix regression from last PR (#223)
1 parent a7c50e8 commit a0e90ad

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

text_extensions_for_pandas/array/test_token_span.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,19 @@ def test_as_frame(self):
366366
)
367367
self.assertEqual(len(df), len(arr))
368368

369+
def test_multi_doc(self):
370+
arr1 = self._make_spans()
371+
372+
text2 = "Hello world."
373+
tokens2 = SpanArray(text2, [0, 6], [5, 11])
374+
arr2 = TokenSpanArray(tokens2, [0, 0], [1, 2])
375+
376+
series = pd.concat([pd.Series(arr1), pd.Series(arr2)])
377+
self.assertFalse(series.array.is_single_document)
378+
self.assertEqual(2, len(series.array.split_by_document()))
379+
self._assertArrayEquals(arr1, series.array.split_by_document()[0])
380+
self._assertArrayEquals(arr2, series.array.split_by_document()[1])
381+
369382

370383
@pytest.mark.skipif(LooseVersion(pa.__version__) < LooseVersion("2.0.0"),
371384
reason="Nested dictionaries only supported in Arrow >= 2.0.0")

text_extensions_for_pandas/array/token_span.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -983,20 +983,20 @@ def is_single_document(self) -> bool:
983983
# More than one tokenization and at least one span. Check whether
984984
# every span has the same text.
985985

986-
# Find the first text ID that is not NA
987-
first_text_id = None
988-
for b, t in zip(self._begins, self._text_ids):
986+
# Find the first span that is not NA
987+
first_target_text = None
988+
for b, t in zip(self._begin_tokens, self.target_text):
989989
if b != Span.NULL_OFFSET_VALUE:
990-
first_text_id = t
990+
first_target_text = t
991991
break
992-
if first_text_id is None:
992+
if first_target_text is None:
993993
# Special case: All NAs --> Zero documents
994994
return True
995995
return not np.any(np.logical_and(
996996
# Row is not null...
997-
np.not_equal(self._begins, Span.NULL_OFFSET_VALUE),
997+
np.not_equal(self._begin_tokens, Span.NULL_OFFSET_VALUE),
998998
# ...and is over a different text than the first row's text ID
999-
np.not_equal(self._text_ids, first_text_id)))
999+
np.not_equal(self.target_text, first_target_text)))
10001000

10011001
def split_by_document(self) -> List["SpanArray"]:
10021002
"""

0 commit comments

Comments
 (0)