Skip to content

Commit 9f2f039

Browse files
committed
Adding ScalarAggregateOptions to Any and All kernels.
1 parent 8f001fc commit 9f2f039

File tree

9 files changed

+187
-100
lines changed

9 files changed

+187
-100
lines changed

cpp/src/arrow/compute/api_aggregate.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,14 @@ Result<Datum> MinMax(const Datum& value, const ScalarAggregateOptions& options,
4545
return CallFunction("min_max", {value}, &options, ctx);
4646
}
4747

48-
Result<Datum> Any(const Datum& value, ExecContext* ctx) {
49-
return CallFunction("any", {value}, ctx);
48+
Result<Datum> Any(const Datum& value, const ScalarAggregateOptions& options,
49+
ExecContext* ctx) {
50+
return CallFunction("any", {value}, &options, ctx);
5051
}
5152

52-
Result<Datum> All(const Datum& value, ExecContext* ctx) {
53-
return CallFunction("all", {value}, ctx);
53+
Result<Datum> All(const Datum& value, const ScalarAggregateOptions& options,
54+
ExecContext* ctx) {
55+
return CallFunction("all", {value}, &options, ctx);
5456
}
5557

5658
Result<Datum> Mode(const Datum& value, const ModeOptions& options, ExecContext* ctx) {

cpp/src/arrow/compute/api_aggregate.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,30 +201,38 @@ Result<Datum> MinMax(
201201
/// \brief Test whether any element in a boolean array evaluates to true.
202202
///
203203
/// This function returns true if any of the elements in the array evaluates
204-
/// to true and false otherwise. Null values are skipped.
204+
/// to true and false otherwise. Null values are ignored by default.
205205
///
206206
/// \param[in] value input datum, expecting a boolean array
207+
/// \param[in] options see ScalarAggregateOptions for more information
207208
/// \param[in] ctx the function execution context, optional
208209
/// \return resulting datum as a BooleanScalar
209210
///
210211
/// \since 3.0.0
211212
/// \note API not yet finalized
212213
ARROW_EXPORT
213-
Result<Datum> Any(const Datum& value, ExecContext* ctx = NULLPTR);
214+
Result<Datum> Any(
215+
const Datum& value,
216+
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
217+
ExecContext* ctx = NULLPTR);
214218

215219
/// \brief Test whether all elements in a boolean array evaluate to true.
216220
///
217221
/// This function returns true if all of the elements in the array evaluate
218-
/// to true and false otherwise. Null values are skipped.
222+
/// to true and false otherwise. Null values are ignored by default.
219223
///
220224
/// \param[in] value input datum, expecting a boolean array
225+
/// \param[in] options see ScalarAggregateOptions for more information
221226
/// \param[in] ctx the function execution context, optional
222227
/// \return resulting datum as a BooleanScalar
223228

224229
/// \since 3.0.0
225230
/// \note API not yet finalized
226231
ARROW_EXPORT
227-
Result<Datum> All(const Datum& value, ExecContext* ctx = NULLPTR);
232+
Result<Datum> All(
233+
const Datum& value,
234+
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
235+
ExecContext* ctx = NULLPTR);
228236

229237
/// \brief Calculate the modal (most common) value of a numeric array
230238
///

cpp/src/arrow/compute/kernels/aggregate_basic.cc

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,19 @@ Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
142142
// Any implementation
143143

144144
struct BooleanAnyImpl : public ScalarAggregator {
145+
explicit BooleanAnyImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
146+
145147
Status Consume(KernelContext*, const ExecBatch& batch) override {
146148
// short-circuit if seen a True already
147-
if (this->any == true) {
149+
if (options.skip_nulls && this->any == true) {
150+
return Status::OK();
151+
}
152+
// short-circuit if seen a null already
153+
if (!options.skip_nulls && this->has_nulls) {
148154
return Status::OK();
149155
}
150-
151156
const auto& data = *batch[0].array();
157+
this->has_nulls = data.GetNullCount() > 0;
152158
arrow::internal::OptionalBinaryBitBlockCounter counter(
153159
data.buffers[0], data.offset, data.buffers[1], data.offset, data.length);
154160
int64_t position = 0;
@@ -166,32 +172,48 @@ struct BooleanAnyImpl : public ScalarAggregator {
166172
Status MergeFrom(KernelContext*, KernelState&& src) override {
167173
const auto& other = checked_cast<const BooleanAnyImpl&>(src);
168174
this->any |= other.any;
175+
this->has_nulls |= other.has_nulls;
169176
return Status::OK();
170177
}
171178

172-
Status Finalize(KernelContext*, Datum* out) override {
173-
out->value = std::make_shared<BooleanScalar>(this->any);
179+
Status Finalize(KernelContext* ctx, Datum* out) override {
180+
if (!options.skip_nulls && this->has_nulls) {
181+
out->value = std::make_shared<BooleanScalar>();
182+
} else {
183+
out->value = std::make_shared<BooleanScalar>(this->any);
184+
}
174185
return Status::OK();
175186
}
176187

177188
bool any = false;
189+
bool has_nulls = false;
190+
ScalarAggregateOptions options;
178191
};
179192

180193
Result<std::unique_ptr<KernelState>> AnyInit(KernelContext*, const KernelInitArgs& args) {
181-
return ::arrow::internal::make_unique<BooleanAnyImpl>();
194+
const ScalarAggregateOptions options =
195+
static_cast<const ScalarAggregateOptions&>(*args.options);
196+
return ::arrow::internal::make_unique<BooleanAnyImpl>(
197+
static_cast<const ScalarAggregateOptions&>(*args.options));
182198
}
183199

184200
// ----------------------------------------------------------------------
185201
// All implementation
186202

187203
struct BooleanAllImpl : public ScalarAggregator {
204+
explicit BooleanAllImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
205+
188206
Status Consume(KernelContext*, const ExecBatch& batch) override {
189207
// short-circuit if seen a false already
190-
if (this->all == false) {
208+
if (options.skip_nulls && this->all == false) {
209+
return Status::OK();
210+
}
211+
// short-circuit if seen a null already
212+
if (!options.skip_nulls && this->has_nulls) {
191213
return Status::OK();
192214
}
193-
194215
const auto& data = *batch[0].array();
216+
this->has_nulls = data.GetNullCount() > 0;
195217
arrow::internal::OptionalBinaryBitBlockCounter counter(
196218
data.buffers[1], data.offset, data.buffers[0], data.offset, data.length);
197219
int64_t position = 0;
@@ -210,19 +232,27 @@ struct BooleanAllImpl : public ScalarAggregator {
210232
Status MergeFrom(KernelContext*, KernelState&& src) override {
211233
const auto& other = checked_cast<const BooleanAllImpl&>(src);
212234
this->all &= other.all;
235+
this->has_nulls |= other.has_nulls;
213236
return Status::OK();
214237
}
215238

216239
Status Finalize(KernelContext*, Datum* out) override {
217-
out->value = std::make_shared<BooleanScalar>(this->all);
240+
if (!options.skip_nulls && this->has_nulls) {
241+
out->value = std::make_shared<BooleanScalar>();
242+
} else {
243+
out->value = std::make_shared<BooleanScalar>(this->all);
244+
}
218245
return Status::OK();
219246
}
220247

221248
bool all = true;
249+
bool has_nulls = false;
250+
ScalarAggregateOptions options;
222251
};
223252

224253
Result<std::unique_ptr<KernelState>> AllInit(KernelContext*, const KernelInitArgs& args) {
225-
return ::arrow::internal::make_unique<BooleanAllImpl>();
254+
return ::arrow::internal::make_unique<BooleanAllImpl>(
255+
static_cast<const ScalarAggregateOptions&>(*args.options));
226256
}
227257

228258
// ----------------------------------------------------------------------
@@ -408,12 +438,16 @@ const FunctionDoc min_max_doc{"Compute the minimum and maximum values of a numer
408438
"ScalarAggregateOptions"};
409439

410440
const FunctionDoc any_doc{"Test whether any element in a boolean array evaluates to true",
411-
("Null values are ignored."),
412-
{"array"}};
441+
("Null values are ignored by default.\n"
442+
"This can be changed through ScalarAggregateOptions."),
443+
{"array"},
444+
"ScalarAggregateOptions"};
413445

414446
const FunctionDoc all_doc{"Test whether all elements in a boolean array evaluate to true",
415-
("Null values are ignored."),
416-
{"array"}};
447+
("Null values are ignored by default.\n"
448+
"This can be changed through ScalarAggregateOptions."),
449+
{"array"},
450+
"ScalarAggregateOptions"};
417451

418452
const FunctionDoc index_doc{"Find the index of the first occurrence of a given value",
419453
("The result is always computed as an int64_t, regardless\n"
@@ -497,12 +531,14 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
497531
DCHECK_OK(registry->AddFunction(std::move(func)));
498532

499533
// any
500-
func = std::make_shared<ScalarAggregateFunction>("any", Arity::Unary(), &any_doc);
534+
func = std::make_shared<ScalarAggregateFunction>("any", Arity::Unary(), &any_doc,
535+
&default_scalar_aggregate_options);
501536
aggregate::AddBasicAggKernels(aggregate::AnyInit, {boolean()}, boolean(), func.get());
502537
DCHECK_OK(registry->AddFunction(std::move(func)));
503538

504539
// all
505-
func = std::make_shared<ScalarAggregateFunction>("all", Arity::Unary(), &all_doc);
540+
func = std::make_shared<ScalarAggregateFunction>("all", Arity::Unary(), &all_doc,
541+
&default_scalar_aggregate_options);
506542
aggregate::AddBasicAggKernels(aggregate::AllInit, {boolean()}, boolean(), func.get());
507543
DCHECK_OK(registry->AddFunction(std::move(func)));
508544

0 commit comments

Comments
 (0)