Skip to content

Commit 1c002fc

Browse files
rok authored and jonkeane committed
ARROW-12499: [C++][Compute] Add ScalarAggregateOptions to Any and All kernels
This resolves [ARROW-12499](https://issues.apache.org/jira/browse/ARROW-12499).

Closes #10476 from rok/ARROW-12499

Lead-authored-by: Rok <[email protected]>
Co-authored-by: Rok Mihevc <[email protected]>
Signed-off-by: Jonathan Keane <[email protected]>
1 parent 6db88a9 commit 1c002fc

File tree

9 files changed

+203
-98
lines changed

9 files changed

+203
-98
lines changed

cpp/src/arrow/compute/api_aggregate.cc

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -155,12 +155,14 @@ Result<Datum> MinMax(const Datum& value, const ScalarAggregateOptions& options,
155155
return CallFunction("min_max", {value}, &options, ctx);
156156
}
157157

158-
Result<Datum> Any(const Datum& value, ExecContext* ctx) {
159-
return CallFunction("any", {value}, ctx);
158+
Result<Datum> Any(const Datum& value, const ScalarAggregateOptions& options,
159+
ExecContext* ctx) {
160+
return CallFunction("any", {value}, &options, ctx);
160161
}
161162

162-
Result<Datum> All(const Datum& value, ExecContext* ctx) {
163-
return CallFunction("all", {value}, ctx);
163+
Result<Datum> All(const Datum& value, const ScalarAggregateOptions& options,
164+
ExecContext* ctx) {
165+
return CallFunction("all", {value}, &options, ctx);
164166
}
165167

166168
Result<Datum> Mode(const Datum& value, const ModeOptions& options, ExecContext* ctx) {

cpp/src/arrow/compute/api_aggregate.h

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -205,30 +205,44 @@ Result<Datum> MinMax(
205205
/// \brief Test whether any element in a boolean array evaluates to true.
206206
///
207207
/// This function returns true if any of the elements in the array evaluates
208-
/// to true and false otherwise. Null values are skipped.
208+
/// to true and false otherwise. Null values are ignored by default.
209+
/// If null values are taken into account by setting ScalarAggregateOptions
210+
/// parameter skip_nulls = false then Kleene logic is used.
211+
/// See KleeneOr for more details on Kleene logic.
209212
///
210213
/// \param[in] value input datum, expecting a boolean array
214+
/// \param[in] options see ScalarAggregateOptions for more information
211215
/// \param[in] ctx the function execution context, optional
212216
/// \return resulting datum as a BooleanScalar
213217
///
214218
/// \since 3.0.0
215219
/// \note API not yet finalized
216220
ARROW_EXPORT
217-
Result<Datum> Any(const Datum& value, ExecContext* ctx = NULLPTR);
221+
Result<Datum> Any(
222+
const Datum& value,
223+
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
224+
ExecContext* ctx = NULLPTR);
218225

219226
/// \brief Test whether all elements in a boolean array evaluate to true.
220227
///
221228
/// This function returns true if all of the elements in the array evaluate
222-
/// to true and false otherwise. Null values are skipped.
229+
/// to true and false otherwise. Null values are ignored by default.
230+
/// If null values are taken into account by setting ScalarAggregateOptions
231+
/// parameter skip_nulls = false then Kleene logic is used.
232+
/// See KleeneAnd for more details on Kleene logic.
223233
///
224234
/// \param[in] value input datum, expecting a boolean array
235+
/// \param[in] options see ScalarAggregateOptions for more information
225236
/// \param[in] ctx the function execution context, optional
226237
/// \return resulting datum as a BooleanScalar
227238

228239
/// \since 3.0.0
229240
/// \note API not yet finalized
230241
ARROW_EXPORT
231-
Result<Datum> All(const Datum& value, ExecContext* ctx = NULLPTR);
242+
Result<Datum> All(
243+
const Datum& value,
244+
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
245+
ExecContext* ctx = NULLPTR);
232246

233247
/// \brief Calculate the modal (most common) value of a numeric array
234248
///

cpp/src/arrow/compute/kernels/aggregate_basic.cc

Lines changed: 51 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -142,13 +142,15 @@ Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
142142
// Any implementation
143143

144144
struct BooleanAnyImpl : public ScalarAggregator {
145+
explicit BooleanAnyImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
146+
145147
Status Consume(KernelContext*, const ExecBatch& batch) override {
146148
// short-circuit if seen a True already
147149
if (this->any == true) {
148150
return Status::OK();
149151
}
150-
151152
const auto& data = *batch[0].array();
153+
this->has_nulls = data.GetNullCount() > 0;
152154
arrow::internal::OptionalBinaryBitBlockCounter counter(
153155
data.buffers[0], data.offset, data.buffers[1], data.offset, data.length);
154156
int64_t position = 0;
@@ -166,32 +168,48 @@ struct BooleanAnyImpl : public ScalarAggregator {
166168
Status MergeFrom(KernelContext*, KernelState&& src) override {
167169
const auto& other = checked_cast<const BooleanAnyImpl&>(src);
168170
this->any |= other.any;
171+
this->has_nulls |= other.has_nulls;
169172
return Status::OK();
170173
}
171174

172-
Status Finalize(KernelContext*, Datum* out) override {
173-
out->value = std::make_shared<BooleanScalar>(this->any);
175+
Status Finalize(KernelContext* ctx, Datum* out) override {
176+
if (!options.skip_nulls && !this->any && this->has_nulls) {
177+
out->value = std::make_shared<BooleanScalar>();
178+
} else {
179+
out->value = std::make_shared<BooleanScalar>(this->any);
180+
}
174181
return Status::OK();
175182
}
176183

177184
bool any = false;
185+
bool has_nulls = false;
186+
ScalarAggregateOptions options;
178187
};
179188

180189
Result<std::unique_ptr<KernelState>> AnyInit(KernelContext*, const KernelInitArgs& args) {
181-
return ::arrow::internal::make_unique<BooleanAnyImpl>();
190+
const ScalarAggregateOptions options =
191+
static_cast<const ScalarAggregateOptions&>(*args.options);
192+
return ::arrow::internal::make_unique<BooleanAnyImpl>(
193+
static_cast<const ScalarAggregateOptions&>(*args.options));
182194
}
183195

184196
// ----------------------------------------------------------------------
185197
// All implementation
186198

187199
struct BooleanAllImpl : public ScalarAggregator {
200+
explicit BooleanAllImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
201+
188202
Status Consume(KernelContext*, const ExecBatch& batch) override {
189203
// short-circuit if seen a false already
190204
if (this->all == false) {
191205
return Status::OK();
192206
}
193-
207+
// short-circuit if seen a null already
208+
if (!options.skip_nulls && this->has_nulls) {
209+
return Status::OK();
210+
}
194211
const auto& data = *batch[0].array();
212+
this->has_nulls = data.GetNullCount() > 0;
195213
arrow::internal::OptionalBinaryBitBlockCounter counter(
196214
data.buffers[1], data.offset, data.buffers[0], data.offset, data.length);
197215
int64_t position = 0;
@@ -210,19 +228,27 @@ struct BooleanAllImpl : public ScalarAggregator {
210228
Status MergeFrom(KernelContext*, KernelState&& src) override {
211229
const auto& other = checked_cast<const BooleanAllImpl&>(src);
212230
this->all &= other.all;
231+
this->has_nulls |= other.has_nulls;
213232
return Status::OK();
214233
}
215234

216235
Status Finalize(KernelContext*, Datum* out) override {
217-
out->value = std::make_shared<BooleanScalar>(this->all);
236+
if (!options.skip_nulls && this->all && this->has_nulls) {
237+
out->value = std::make_shared<BooleanScalar>();
238+
} else {
239+
out->value = std::make_shared<BooleanScalar>(this->all);
240+
}
218241
return Status::OK();
219242
}
220243

221244
bool all = true;
245+
bool has_nulls = false;
246+
ScalarAggregateOptions options;
222247
};
223248

224249
Result<std::unique_ptr<KernelState>> AllInit(KernelContext*, const KernelInitArgs& args) {
225-
return ::arrow::internal::make_unique<BooleanAllImpl>();
250+
return ::arrow::internal::make_unique<BooleanAllImpl>(
251+
static_cast<const ScalarAggregateOptions&>(*args.options));
226252
}
227253

228254
// ----------------------------------------------------------------------
@@ -407,12 +433,22 @@ const FunctionDoc min_max_doc{"Compute the minimum and maximum values of a numer
407433
"ScalarAggregateOptions"};
408434

409435
const FunctionDoc any_doc{"Test whether any element in a boolean array evaluates to true",
410-
("Null values are ignored."),
411-
{"array"}};
436+
("Null values are ignored by default.\n"
437+
"If null values are taken into account by setting "
438+
"ScalarAggregateOptions parameter skip_nulls = false then "
439+
"Kleene logic is used.\n"
440+
"See KleeneOr for more details on Kleene logic."),
441+
{"array"},
442+
"ScalarAggregateOptions"};
412443

413444
const FunctionDoc all_doc{"Test whether all elements in a boolean array evaluate to true",
414-
("Null values are ignored."),
415-
{"array"}};
445+
("Null values are ignored by default.\n"
446+
"If null values are taken into account by setting "
447+
"ScalarAggregateOptions parameter skip_nulls = false then "
448+
"Kleene logic is used.\n"
449+
"See KleeneAnd for more details on Kleene logic."),
450+
{"array"},
451+
"ScalarAggregateOptions"};
416452

417453
const FunctionDoc index_doc{"Find the index of the first occurrence of a given value",
418454
("The result is always computed as an int64_t, regardless\n"
@@ -496,12 +532,14 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
496532
DCHECK_OK(registry->AddFunction(std::move(func)));
497533

498534
// any
499-
func = std::make_shared<ScalarAggregateFunction>("any", Arity::Unary(), &any_doc);
535+
func = std::make_shared<ScalarAggregateFunction>("any", Arity::Unary(), &any_doc,
536+
&default_scalar_aggregate_options);
500537
aggregate::AddBasicAggKernels(aggregate::AnyInit, {boolean()}, boolean(), func.get());
501538
DCHECK_OK(registry->AddFunction(std::move(func)));
502539

503540
// all
504-
func = std::make_shared<ScalarAggregateFunction>("all", Arity::Unary(), &all_doc);
541+
func = std::make_shared<ScalarAggregateFunction>("all", Arity::Unary(), &all_doc,
542+
&default_scalar_aggregate_options);
505543
aggregate::AddBasicAggKernels(aggregate::AllInit, {boolean()}, boolean(), func.get());
506544
DCHECK_OK(registry->AddFunction(std::move(func)));
507545

0 commit comments

Comments
 (0)