Skip to content

Commit 25c8c9f

Browse files
ezhulenev authored and Google-ML-Automation committed
[xla:cpu:xnn] Add a very basic single-threaded XnnDotThunk
- extract dot_lib library with code shared between dot thunk implementations
- add xnn_status conversion to xnn_interop

PiperOrigin-RevId: 707427096
1 parent 7b99a6c commit 25c8c9f

13 files changed

+699
-121
lines changed

xla/backends/cpu/runtime/BUILD

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,37 @@ cc_library(
678678
],
679679
)
680680

681+
# Shared helpers (shapes, buffer uses, validation) for the dot thunk
# implementations (Eigen-based DotThunk and the XNNPACK-backed XnnDotThunk).
cc_library(
    name = "dot_lib",
    srcs = ["dot_lib.cc"],
    hdrs = ["dot_lib.h"],
    deps = [
        ":thunk",
        "//xla:shape_util",
        "//xla:types",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/runtime:buffer_use",
        "//xla/service:buffer_assignment",
        "//xla/stream_executor:device_memory",
        "//xla/tsl/concurrency:async_value",
        "//xla/tsl/framework/contraction:eigen_contraction_kernel",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@eigen_archive//:eigen3",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:statusor",
        "@tsl//tsl/profiler/lib:traceme",
    ],
)
711+
681712
cc_library(
682713
name = "dot_thunk",
683714
srcs = [
@@ -691,6 +722,7 @@ cc_library(
691722
],
692723
hdrs = ["dot_thunk.h"],
693724
deps = [
725+
":dot_lib",
694726
":thunk",
695727
"//xla:shape_util",
696728
"//xla:types",

xla/backends/cpu/runtime/dot_lib.cc

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/backends/cpu/runtime/dot_lib.h"
17+
18+
#include <cstdint>
19+
#include <functional>
20+
#include <numeric>
21+
#include <utility>
22+
#include <vector>
23+
24+
#include "absl/algorithm/container.h"
25+
#include "absl/container/inlined_vector.h"
26+
#include "absl/status/statusor.h"
27+
#include "absl/strings/str_join.h"
28+
#include "xla/layout_util.h"
29+
#include "xla/runtime/buffer_use.h"
30+
#include "xla/shape.h"
31+
#include "xla/shape_util.h"
32+
#include "xla/util.h"
33+
34+
namespace xla::cpu {
35+
36+
// Returns the buffer uses of a dot operation: both operands are read-only
// and the result buffer is written.
absl::InlinedVector<BufferUse, 4> DotBufferUses(const DotSlices& slices) {
  absl::InlinedVector<BufferUse, 4> uses;
  uses.push_back(BufferUse::Read(slices.lhs_buffer));
  uses.push_back(BufferUse::Read(slices.rhs_buffer));
  uses.push_back(BufferUse::Write(slices.out_buffer));
  return uses;
}
41+
42+
// Validates the dot dimension numbers and operand/result layouts, and folds
// all batch dimensions into a single `batch_size`, leaving rank <= 2 matmul
// shapes that the runtime dot thunk implementations can execute as a 2D
// contraction.
//
// Returns InvalidArgument if any shape is not dim0-major, if the batch
// dimensions are not the leading dimensions [0, num_batch_dims) on both
// sides, or if the remaining matmul shapes have rank greater than 2.
absl::StatusOr<DotShape> GetDotShape(DotDimensionNumbers dot_dimensions,
                                     const Shape& lhs_shape,
                                     const Shape& rhs_shape,
                                     const Shape& out_shape) {
  // Runtime kernels only understand dim0-major (descending-minor) layouts.
  bool dim0_major =
      LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
      LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
      LayoutUtil::IsMonotonicWithDim0Major(out_shape.layout());
  if (!dim0_major) {
    return InvalidArgument(
        "DotThunk requires all operands and outputs to be in "
        "dim0-major layout: lhs_shape=[%s], rhs_shape=[%s], out_shape=[%s]",
        lhs_shape.ToString(true), rhs_shape.ToString(true),
        out_shape.ToString(true));
  }

  // Batch dimensions must be the leading dimensions [0, 1, ..., k-1] on both
  // sides, so they can be collapsed into a single loop over `batch_size`.
  std::vector<int64_t> expected_batch_dims(
      dot_dimensions.lhs_batch_dimensions().size());
  absl::c_iota(expected_batch_dims, 0);

  bool batch_dims_ok =
      absl::c_equal(dot_dimensions.lhs_batch_dimensions(),
                    expected_batch_dims) &&
      absl::c_equal(dot_dimensions.rhs_batch_dimensions(),
                    expected_batch_dims);
  if (!batch_dims_ok) {
    return InvalidArgument(
        "Batch dimensions must be contiguous and start at 0: "
        "lhs_batch_dims=[%s], rhs_batch_dims=[%s]",
        absl::StrJoin(dot_dimensions.lhs_batch_dimensions(), ","),
        absl::StrJoin(dot_dimensions.rhs_batch_dimensions(), ","));
  }

  // Product of the leading batch dimensions of the output shape.
  int64_t num_batch_dims = expected_batch_dims.size();
  int64_t batch_size = 1;
  for (int64_t i = 0; i < num_batch_dims; ++i) {
    batch_size *= out_shape.dimensions(i);
  }

  // Strip batch dimensions to obtain the per-batch matmul shapes.
  Shape lhs_matmul_shape =
      ShapeUtil::DeleteDimensions(expected_batch_dims, lhs_shape);
  Shape rhs_matmul_shape =
      ShapeUtil::DeleteDimensions(expected_batch_dims, rhs_shape);
  Shape out_matmul_shape =
      ShapeUtil::DeleteDimensions(expected_batch_dims, out_shape);

  // The per-batch matmul must be rank 2 or less so it can be represented as
  // an Eigen 2D contraction.
  if (lhs_matmul_shape.rank() > 2 || rhs_matmul_shape.rank() > 2 ||
      out_matmul_shape.rank() > 2) {
    return InvalidArgument(
        "MatMul shape must be rank 2 or less: lhs=%s, rhs=%s, out=%s",
        lhs_matmul_shape.ToString(true), rhs_matmul_shape.ToString(true),
        out_matmul_shape.ToString(true));
  }

  return DotShape{
      batch_size,
      std::move(lhs_matmul_shape),
      std::move(rhs_matmul_shape),
      std::move(out_matmul_shape),
  };
}
97+
98+
} // namespace xla::cpu

xla/backends/cpu/runtime/dot_lib.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_
17+
#define XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_
18+
19+
#include <cstdint>
20+
21+
#include "absl/container/inlined_vector.h"
22+
#include "absl/status/statusor.h"
23+
#include "xla/runtime/buffer_use.h"
24+
#include "xla/service/buffer_assignment.h"
25+
#include "xla/shape.h"
26+
27+
namespace xla::cpu {
28+
29+
// Allocation slices of the dot operation.
30+
struct DotSlices {
31+
BufferAllocation::Slice lhs_buffer;
32+
Shape lhs_shape;
33+
34+
BufferAllocation::Slice rhs_buffer;
35+
Shape rhs_shape;
36+
37+
BufferAllocation::Slice out_buffer;
38+
Shape out_shape;
39+
};
40+
41+
// Shape of the batched dot operation supported by the XLA:CPU runtime.
42+
struct DotShape {
43+
// Product of batch dimensions.
44+
int64_t batch_size;
45+
46+
// Shapes of the non-batch matrix-multiplication for the dot operation
47+
Shape lhs_matmul_shape;
48+
Shape rhs_matmul_shape;
49+
Shape out_matmul_shape;
50+
};
51+
52+
// Returns buffer uses of the dot operation.
53+
absl::InlinedVector<BufferUse, 4> DotBufferUses(const DotSlices& slices);
54+
55+
// Verifies dot dimensions and shapes and returns the shape of the dot operation
56+
// in a form that is convenient for the runtime implementation.
57+
absl::StatusOr<DotShape> GetDotShape(DotDimensionNumbers dot_dimensions,
58+
const Shape& lhs_shape,
59+
const Shape& rhs_shape,
60+
const Shape& out_shape);
61+
62+
} // namespace xla::cpu
63+
64+
#endif // XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_

0 commit comments

Comments
 (0)