Skip to content

Commit e32c45f

Browse files
authored
Rollup merge of rust-lang#55448 - Mokosha:SortAtIndex, r=bluss
Add 'partition_at_index/_by/_by_key' for slices. This is an analog to C++'s std::nth_element (a.k.a. quickselect). Corresponds to tracking bug rust-lang#55300.
2 parents 4c2be9c + 40faae8 commit e32c45f

File tree

4 files changed

+353
-0
lines changed

4 files changed

+353
-0
lines changed

src/libcore/slice/mod.rs

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,153 @@ impl<T> [T] {
15821582
sort::quicksort(self, |a, b| f(a).lt(&f(b)));
15831583
}
15841584

1585+
/// Reorder the slice such that the element at `index` is at its final sorted position.
1586+
///
1587+
/// This reordering has the additional property that any value at position `i < index` will be
1588+
/// less than or equal to any value at a position `j > index`. Additionally, this reordering is
1589+
/// unstable (i.e. any number of equal elements may end up at position `index`), in-place
1590+
/// (i.e. does not allocate), and `O(n)` worst-case. This function is also/ known as "kth
1591+
/// element" in other libraries. It returns a triplet of the following values: all elements less
1592+
/// than the one at the given index, the value at the given index, and all elements greater than
1593+
/// the one at the given index.
1594+
///
1595+
/// # Current implementation
1596+
///
1597+
/// The current algorithm is based on the quickselect portion of the same quicksort algorithm
1598+
/// used for [`sort_unstable`].
1599+
///
1600+
/// [`sort_unstable`]: #method.sort_unstable
1601+
///
1602+
/// # Panics
1603+
///
1604+
/// Panics when `index >= len()`, meaning it always panics on empty slices.
1605+
///
1606+
/// # Examples
1607+
///
1608+
/// ```
1609+
/// #![feature(slice_partition_at_index)]
1610+
///
1611+
/// let mut v = [-5i32, 4, 1, -3, 2];
1612+
///
1613+
/// // Find the median
1614+
/// v.partition_at_index(2);
1615+
///
1616+
/// // We are only guaranteed the slice will be one of the following, based on the way we sort
1617+
/// // about the specified index.
1618+
/// assert!(v == [-3, -5, 1, 2, 4] ||
1619+
/// v == [-5, -3, 1, 2, 4] ||
1620+
/// v == [-3, -5, 1, 4, 2] ||
1621+
/// v == [-5, -3, 1, 4, 2]);
1622+
/// ```
1623+
#[unstable(feature = "slice_partition_at_index", issue = "55300")]
1624+
#[inline]
1625+
pub fn partition_at_index(&mut self, index: usize) -> (&mut [T], &mut T, &mut [T])
1626+
where T: Ord
1627+
{
1628+
let mut f = |a: &T, b: &T| a.lt(b);
1629+
sort::partition_at_index(self, index, &mut f)
1630+
}
1631+
1632+
/// Reorder the slice with a comparator function such that the element at `index` is at its
1633+
/// final sorted position.
1634+
///
1635+
/// This reordering has the additional property that any value at position `i < index` will be
1636+
/// less than or equal to any value at a position `j > index` using the comparator function.
1637+
/// Additionally, this reordering is unstable (i.e. any number of equal elements may end up at
1638+
/// position `index`), in-place (i.e. does not allocate), and `O(n)` worst-case. This function
1639+
/// is also known as "kth element" in other libraries. It returns a triplet of the following
1640+
/// values: all elements less than the one at the given index, the value at the given index,
1641+
/// and all elements greater than the one at the given index, using the provided comparator
1642+
/// function.
1643+
///
1644+
/// # Current implementation
1645+
///
1646+
/// The current algorithm is based on the quickselect portion of the same quicksort algorithm
1647+
/// used for [`sort_unstable`].
1648+
///
1649+
/// [`sort_unstable`]: #method.sort_unstable
1650+
///
1651+
/// # Panics
1652+
///
1653+
/// Panics when `index >= len()`, meaning it always panics on empty slices.
1654+
///
1655+
/// # Examples
1656+
///
1657+
/// ```
1658+
/// #![feature(slice_partition_at_index)]
1659+
///
1660+
/// let mut v = [-5i32, 4, 1, -3, 2];
1661+
///
1662+
/// // Find the median as if the slice were sorted in descending order.
1663+
/// v.partition_at_index_by(2, |a, b| b.cmp(a));
1664+
///
1665+
/// // We are only guaranteed the slice will be one of the following, based on the way we sort
1666+
/// // about the specified index.
1667+
/// assert!(v == [2, 4, 1, -5, -3] ||
1668+
/// v == [2, 4, 1, -3, -5] ||
1669+
/// v == [4, 2, 1, -5, -3] ||
1670+
/// v == [4, 2, 1, -3, -5]);
1671+
/// ```
1672+
#[unstable(feature = "slice_partition_at_index", issue = "55300")]
1673+
#[inline]
1674+
pub fn partition_at_index_by<F>(&mut self, index: usize, mut compare: F)
1675+
-> (&mut [T], &mut T, &mut [T])
1676+
where F: FnMut(&T, &T) -> Ordering
1677+
{
1678+
let mut f = |a: &T, b: &T| compare(a, b) == Less;
1679+
sort::partition_at_index(self, index, &mut f)
1680+
}
1681+
1682+
/// Reorder the slice with a key extraction function such that the element at `index` is at its
1683+
/// final sorted position.
1684+
///
1685+
/// This reordering has the additional property that any value at position `i < index` will be
1686+
/// less than or equal to any value at a position `j > index` using the key extraction function.
1687+
/// Additionally, this reordering is unstable (i.e. any number of equal elements may end up at
1688+
/// position `index`), in-place (i.e. does not allocate), and `O(n)` worst-case. This function
1689+
/// is also known as "kth element" in other libraries. It returns a triplet of the following
1690+
/// values: all elements less than the one at the given index, the value at the given index, and
1691+
/// all elements greater than the one at the given index, using the provided key extraction
1692+
/// function.
1693+
///
1694+
/// # Current implementation
1695+
///
1696+
/// The current algorithm is based on the quickselect portion of the same quicksort algorithm
1697+
/// used for [`sort_unstable`].
1698+
///
1699+
/// [`sort_unstable`]: #method.sort_unstable
1700+
///
1701+
/// # Panics
1702+
///
1703+
/// Panics when `index >= len()`, meaning it always panics on empty slices.
1704+
///
1705+
/// # Examples
1706+
///
1707+
/// ```
1708+
/// #![feature(slice_partition_at_index)]
1709+
///
1710+
/// let mut v = [-5i32, 4, 1, -3, 2];
1711+
///
1712+
/// // Return the median as if the array were sorted according to absolute value.
1713+
/// v.partition_at_index_by_key(2, |a| a.abs());
1714+
///
1715+
/// // We are only guaranteed the slice will be one of the following, based on the way we sort
1716+
/// // about the specified index.
1717+
/// assert!(v == [1, 2, -3, 4, -5] ||
1718+
/// v == [1, 2, -3, -5, 4] ||
1719+
/// v == [2, 1, -3, 4, -5] ||
1720+
/// v == [2, 1, -3, -5, 4]);
1721+
/// ```
1722+
#[unstable(feature = "slice_partition_at_index", issue = "55300")]
1723+
#[inline]
1724+
pub fn partition_at_index_by_key<K, F>(&mut self, index: usize, mut f: F)
1725+
-> (&mut [T], &mut T, &mut [T])
1726+
where F: FnMut(&T) -> K, K: Ord
1727+
{
1728+
let mut g = |a: &T, b: &T| f(a).lt(&f(b));
1729+
sort::partition_at_index(self, index, &mut g)
1730+
}
1731+
15851732
/// Moves all consecutive repeated elements to the end of the slice according to the
15861733
/// [`PartialEq`] trait implementation.
15871734
///

src/libcore/slice/sort.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,3 +691,92 @@ pub fn quicksort<T, F>(v: &mut [T], mut is_less: F)
691691

692692
recurse(v, &mut is_less, None, limit);
693693
}
694+
695+
fn partition_at_index_loop<'a, T, F>( mut v: &'a mut [T], mut index: usize, is_less: &mut F
696+
, mut pred: Option<&'a T>) where F: FnMut(&T, &T) -> bool
697+
{
698+
loop {
699+
// For slices of up to this length it's probably faster to simply sort them.
700+
const MAX_INSERTION: usize = 10;
701+
if v.len() <= MAX_INSERTION {
702+
insertion_sort(v, is_less);
703+
return;
704+
}
705+
706+
// Choose a pivot
707+
let (pivot, _) = choose_pivot(v, is_less);
708+
709+
// If the chosen pivot is equal to the predecessor, then it's the smallest element in the
710+
// slice. Partition the slice into elements equal to and elements greater than the pivot.
711+
// This case is usually hit when the slice contains many duplicate elements.
712+
if let Some(p) = pred {
713+
if !is_less(p, &v[pivot]) {
714+
let mid = partition_equal(v, pivot, is_less);
715+
716+
// If we've passed our index, then we're good.
717+
if mid > index {
718+
return;
719+
}
720+
721+
// Otherwise, continue sorting elements greater than the pivot.
722+
v = &mut v[mid..];
723+
index = index - mid;
724+
pred = None;
725+
continue;
726+
}
727+
}
728+
729+
let (mid, _) = partition(v, pivot, is_less);
730+
731+
// Split the slice into `left`, `pivot`, and `right`.
732+
let (left, right) = {v}.split_at_mut(mid);
733+
let (pivot, right) = right.split_at_mut(1);
734+
let pivot = &pivot[0];
735+
736+
if mid < index {
737+
v = right;
738+
index = index - mid - 1;
739+
pred = Some(pivot);
740+
} else if mid > index {
741+
v = left;
742+
} else {
743+
// If mid == index, then we're done, since partition() guaranteed that all elements
744+
// after mid are greater than or equal to mid.
745+
return;
746+
}
747+
}
748+
}
749+
750+
pub fn partition_at_index<T, F>(v: &mut [T], index: usize, mut is_less: F)
751+
-> (&mut [T], &mut T, &mut [T]) where F: FnMut(&T, &T) -> bool
752+
{
753+
use cmp::Ordering::Less;
754+
use cmp::Ordering::Greater;
755+
756+
if index >= v.len() {
757+
panic!("partition_at_index index {} greater than length of slice {}", index, v.len());
758+
}
759+
760+
if mem::size_of::<T>() == 0 {
761+
// Sorting has no meaningful behavior on zero-sized types. Do nothing.
762+
} else if index == v.len() - 1 {
763+
// Find max element and place it in the last position of the array. We're free to use
764+
// `unwrap()` here because we know v must not be empty.
765+
let (max_index, _) = v.iter().enumerate().max_by(
766+
|&(_, x), &(_, y)| if is_less(x, y) { Less } else { Greater }).unwrap();
767+
v.swap(max_index, index);
768+
} else if index == 0 {
769+
// Find min element and place it in the first position of the array. We're free to use
770+
// `unwrap()` here because we know v must not be empty.
771+
let (min_index, _) = v.iter().enumerate().min_by(
772+
|&(_, x), &(_, y)| if is_less(x, y) { Less } else { Greater }).unwrap();
773+
v.swap(min_index, index);
774+
} else {
775+
partition_at_index_loop(v, index, &mut is_less, None);
776+
}
777+
778+
let (left, right) = v.split_at_mut(index);
779+
let (pivot, right) = right.split_at_mut(1);
780+
let pivot = &mut pivot[0];
781+
(left, pivot, right)
782+
}

src/libcore/tests/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#![feature(refcell_replace_swap)]
2323
#![feature(slice_patterns)]
2424
#![feature(sort_internals)]
25+
#![feature(slice_partition_at_index)]
2526
#![feature(specialization)]
2627
#![feature(step_trait)]
2728
#![feature(str_internals)]

src/libcore/tests/slice.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,122 @@ fn sort_unstable() {
10791079
assert!(v == [0xDEADBEEF]);
10801080
}
10811081

1082+
#[test]
1083+
fn partition_at_index() {
1084+
use core::cmp::Ordering::{Equal, Greater, Less};
1085+
use rand::rngs::SmallRng;
1086+
use rand::seq::SliceRandom;
1087+
use rand::{FromEntropy, Rng};
1088+
1089+
let mut rng = SmallRng::from_entropy();
1090+
1091+
for len in (2..21).chain(500..501) {
1092+
let mut orig = vec![0; len];
1093+
1094+
for &modulus in &[5, 10, 1000] {
1095+
for _ in 0..10 {
1096+
for i in 0..len {
1097+
orig[i] = rng.gen::<i32>() % modulus;
1098+
}
1099+
1100+
let v_sorted = {
1101+
let mut v = orig.clone();
1102+
v.sort();
1103+
v
1104+
};
1105+
1106+
// Sort in default order.
1107+
for pivot in 0..len {
1108+
let mut v = orig.clone();
1109+
v.partition_at_index(pivot);
1110+
1111+
assert_eq!(v_sorted[pivot], v[pivot]);
1112+
for i in 0..pivot {
1113+
for j in pivot..len {
1114+
assert!(v[i] <= v[j]);
1115+
}
1116+
}
1117+
}
1118+
1119+
// Sort in ascending order.
1120+
for pivot in 0..len {
1121+
let mut v = orig.clone();
1122+
let (left, pivot, right) = v.partition_at_index_by(pivot, |a, b| a.cmp(b));
1123+
1124+
assert_eq!(left.len() + right.len(), len - 1);
1125+
1126+
for l in left {
1127+
assert!(l <= pivot);
1128+
for r in right.iter_mut() {
1129+
assert!(l <= r);
1130+
assert!(pivot <= r);
1131+
}
1132+
}
1133+
}
1134+
1135+
// Sort in descending order.
1136+
let sort_descending_comparator = |a: &i32, b: &i32| b.cmp(a);
1137+
let v_sorted_descending = {
1138+
let mut v = orig.clone();
1139+
v.sort_by(sort_descending_comparator);
1140+
v
1141+
};
1142+
1143+
for pivot in 0..len {
1144+
let mut v = orig.clone();
1145+
v.partition_at_index_by(pivot, sort_descending_comparator);
1146+
1147+
assert_eq!(v_sorted_descending[pivot], v[pivot]);
1148+
for i in 0..pivot {
1149+
for j in pivot..len {
1150+
assert!(v[j] <= v[i]);
1151+
}
1152+
}
1153+
}
1154+
}
1155+
}
1156+
}
1157+
1158+
// Sort at index using a completely random comparison function.
1159+
// This will reorder the elements *somehow*, but won't panic.
1160+
let mut v = [0; 500];
1161+
for i in 0..v.len() {
1162+
v[i] = i as i32;
1163+
}
1164+
1165+
for pivot in 0..v.len() {
1166+
v.partition_at_index_by(pivot, |_, _| *[Less, Equal, Greater].choose(&mut rng).unwrap());
1167+
v.sort();
1168+
for i in 0..v.len() {
1169+
assert_eq!(v[i], i as i32);
1170+
}
1171+
}
1172+
1173+
// Should not panic.
1174+
[(); 10].partition_at_index(0);
1175+
[(); 10].partition_at_index(5);
1176+
[(); 10].partition_at_index(9);
1177+
[(); 100].partition_at_index(0);
1178+
[(); 100].partition_at_index(50);
1179+
[(); 100].partition_at_index(99);
1180+
1181+
let mut v = [0xDEADBEEFu64];
1182+
v.partition_at_index(0);
1183+
assert!(v == [0xDEADBEEF]);
1184+
}
1185+
1186+
#[test]
1187+
#[should_panic(expected = "index 0 greater than length of slice")]
1188+
fn partition_at_index_zero_length() {
1189+
[0i32; 0].partition_at_index(0);
1190+
}
1191+
1192+
#[test]
1193+
#[should_panic(expected = "index 20 greater than length of slice")]
1194+
fn partition_at_index_past_length() {
1195+
[0i32; 10].partition_at_index(20);
1196+
}
1197+
10821198
pub mod memchr {
10831199
use core::slice::memchr::{memchr, memrchr};
10841200

0 commit comments

Comments
 (0)