diff --git a/BUILD.bazel b/BUILD.bazel
index 022464ce..8e3eb5c9 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -83,6 +83,8 @@ cc_library(
         "@hwy//:matvec",
         "@hwy//:profiler",
         "@hwy//:thread_pool",
+        "//third_party/intel_dnnl:dnnl",
+        "//third_party/tbb",
     ],
 )
diff --git a/gemma/gemma.h b/gemma/gemma.h
index ea252817..33f6dec1 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -121,7 +121,7 @@ struct RuntimeConfig {
   const ImageTokens *image_tokens = nullptr;
 
   // Whether to use thread spinning to reduce barrier synchronization latency.
-  bool use_spinning = true;
+  bool use_spinning = false;
 
   // End-of-sequence token.
   int eos_id = EOS_ID;
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index b2678a36..3c6a11a3 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -31,11 +31,20 @@
 // After highway.h
 #include "compression/compress-inl.h"
 #include "hwy/contrib/math/math-inl.h"
+#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl.hpp"
+#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl_common.hpp"
+#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl_types.h"
+
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
+using namespace dnnl;
+using tag = memory::format_tag;
+using dt = memory::data_type;
+using dnnl::primitive_attr;
+using dnnl::reorder;
 
 // The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
 // loads, we reuse the same A row for several B columns, which are also loaded
@@ -66,7 +75,9 @@ constexpr size_t kRegRows = kRegCols;
 // expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
 // to bf16 if the native op is available. This will actually demote f32
 // activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
-using MulT = hwy::If;
+// Update MulT so that the Highway matmul always converts its inputs to bf16,
+// matching the dnnl matmul logic.
+using MulT = BF16;
 
 // Loads two vectors at a time with element type MulT from a row of transposed
 // B. Called in a loop over col_ab. No bounds checking because `kRow` is
@@ -450,10 +461,10 @@ HWY_INLINE void MatMulTile(const size_t batch_size, const Mat& A,
 // Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
 // Must not be called concurrently with the same `env`.
 template 
-HWY_NOINLINE void MatMul(const size_t batch_size, const Mat& A,
-                         const Mat& B, const float scale,
-                         const float* HWY_RESTRICT add, MatMulEnv& env,
-                         const Mat& C) {
+HWY_NOINLINE void MatMul_hwy(const size_t batch_size, const Mat& A,
+                             const Mat& B, const float scale,
+                             const float* HWY_RESTRICT add, MatMulEnv& env,
+                             const Mat& C) {
   // PROFILER_ZONE("Matmul");
   HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
   HWY_DASSERT(A.cols == B.cols);
@@ -499,6 +510,148 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat& A,
   });
 }
 
+template <typename T>
+static memory::data_type DnnType();
+
+/// Instantiation for float type. Add similar instantiations for other
+/// types if needed.
+template <>
+memory::data_type DnnType<float>() {
+  return memory::data_type::f32;
+}
+
+template <>
+memory::data_type DnnType<BF16>() {
+  return memory::data_type::bf16;
+}
+
+template <>
+memory::data_type DnnType<SfpStream>() {
+  fprintf(stderr, "DnnType SfpStream is not supported\n");
+  return memory::data_type::bf16;
+}
+template <typename T>
+memory convert_to_bf16(dnnl::engine engine, dnnl::stream engine_stream,
+                       const Mat<T>& source, memory::dims dims,
+                       tag format_tag) {
+  // Write the input matrix to dnnl memory.
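+  // set_data_handle() below points the dnnl::memory at the existing gemma
+  // buffer instead of copying the data; only the f32 -> bf16 reorder below
+  // writes into a new dnnl-managed buffer.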
+  dnnl::memory::desc src_md(dims, DnnType<T>(), format_tag);
+  auto source_mem = memory(src_md, engine);
+  source_mem.set_data_handle(const_cast<T*>(source.ptr + source.ofs));
+  if (std::is_same<T, BF16>::value) {
+    return source_mem;
+  }
+  // When the input is float, convert it to BF16.
+  if (std::is_same<T, float>::value) {
+    auto dst_md = memory::desc(source_mem.get_desc().get_dims(),
+                               dnnl::memory::data_type::bf16, format_tag);
+    dnnl::memory dst_mem(dst_md, engine);
+    auto reorder_pd =
+        reorder::primitive_desc(engine, source_mem.get_desc(), engine, dst_md);
+    auto reorder_prim = reorder(reorder_pd);
+    reorder_prim.execute(engine_stream,
+                         {{DNNL_ARG_FROM, source_mem}, {DNNL_ARG_TO, dst_mem}});
+    return dst_mem;
+  }
+  fprintf(stderr, "Unsupported type\n");
+  return source_mem;
+}
+
+template 
+HWY_NOINLINE void MatMul_dnnl(const size_t batch_size,
+                              const Mat& A,
+                              const Mat& B, const float scale,
+                              const float* HWY_RESTRICT add, MatMulEnv& env,
+                              const Mat& C) {
+  dnnl::engine engine = env.engine;
+  dnnl::stream engine_stream = env.engine_stream;
+
+  // First stage: process the input data.
+  // oneDNN requires inputs and outputs to be wrapped in dnnl::memory
+  // objects, which also enables better performance.
+  // Create memory dims for inputs and outputs.
+  const memory::dim kRowsAC = batch_size, kColsARowsB = A.cols,
+                    kColsBC = C.cols;
+  memory::dims src_dims = {kRowsAC, kColsARowsB};
+  memory::dims weights_dims = {kColsARowsB, kColsBC};
+  memory::dims bias_dims = {1, kColsBC};
+  memory::dims dst_dims = {kRowsAC, kColsBC};
+
+  auto src_format_tag = tag::ab;
+  // `B` is a transposed matrix.
+  auto weights_format_tag = tag::ba;
+  auto dest_format_tag = tag::ab;
+
+  // Create memory descriptors for inputs and outputs.
+  auto src_md = memory::desc(src_dims, DnnType(), src_format_tag);
+  auto weights_md =
+      memory::desc(weights_dims, DnnType(), weights_format_tag);
+  auto scale_md = memory::desc({1}, dt::f32, tag::x);
+  auto dst_md = memory::desc(dst_dims, dt::f32, dest_format_tag);
+
+  auto src_mem =
+      convert_to_bf16(engine, engine_stream, A, src_dims, src_format_tag);
+  auto weights_mem = convert_to_bf16(engine, engine_stream, B, weights_dims,
+                                     weights_format_tag);
+  auto dst_mem = memory(dst_md, engine);
+
+  // Second stage: Create the matmul primitive/operation.
+  // Define the matmul primitive arguments.
+  std::unordered_map<int, memory> matmul_args;
+  matmul_args.insert({DNNL_ARG_SRC, src_mem});
+  matmul_args.insert({DNNL_ARG_WEIGHTS, weights_mem});
+  matmul_args.insert({DNNL_ARG_DST, dst_mem});
+
+  // Attach the scaling factor to the weights so that it is applied to the
+  // multiplication result before the bias is added.
+  auto scale_mem = memory(scale_md, engine, const_cast<float*>(&scale));
+  matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale_mem});
+  dnnl::primitive_attr attr;
+  attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
+
+  // Create the primitive descriptor.
+  matmul::primitive_desc matmul_pd;
+
+  // When there is a bias, add it to matmul_args.
+  if (kAdd) {
+    auto bias_md = memory::desc(bias_dims, dt::f32, tag::ab);
+    auto bias_mem = memory(bias_md, engine);
+    bias_mem.set_data_handle(const_cast<float*>(add));
+    matmul_args.insert({DNNL_ARG_BIAS, bias_mem});
+    matmul_pd = matmul::primitive_desc(engine, src_md, weights_md, bias_md,
+                                       dst_md, attr);
+  } else {
+    matmul_pd =
+        matmul::primitive_desc(engine, src_md, weights_md, dst_md, attr);
+  }
+
+  // Third stage: Execute the matmul primitive/operation.
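+  // execute() only enqueues the primitive on the stream; the wait() that
+  // follows blocks until the matmul has finished and dst_mem holds the f32
+  // results.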
+  auto matmul_prim = matmul(matmul_pd);
+  matmul_prim.execute(engine_stream, matmul_args);
+  engine_stream.wait();
+
+  // Copy the output from dnnl memory to the output matrix,
+  // padding each row when C.stride is greater than C.cols.
+  auto c_mem_ptr = static_cast<float*>(dst_mem.get_data_handle());
+  const hn::ScalableTag<float> df;
+  for (size_t row = 0; row < batch_size; ++row) {
+    hn::StoreU(hn::Zero(df), df, C.ptr + row * C.stride);
+    std::copy(c_mem_ptr + row * C.cols, c_mem_ptr + (row + 1) * C.cols,
+              C.ptr + row * C.stride);
+  }
+}
+
+template 
+HWY_NOINLINE void MatMul(const size_t batch_size, const Mat& A,
+                         const Mat& B, const float scale,
+                         const float* HWY_RESTRICT add, MatMulEnv& env,
+                         const Mat& C) {
+  MatMul_dnnl<kAdd>(batch_size, A, B, scale, add, env, C);
+
+  // To benchmark the hwy matmul, enable it here and disable the dnnl matmul.
+  // MatMul_hwy<kAdd>(batch_size, A, B, scale, add, env, C);
+}
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
diff --git a/ops/matmul.h b/ops/matmul.h
index ecc72b1a..19778332 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -18,11 +18,16 @@
 
 #include 
 
+#include "tbb/global_control.h"
 #include "util/allocator.h"  // RowVectorBatch
 #include "util/threading.h"  // PerClusterPools
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/per_target.h"
+#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl.hpp"
+
+using namespace dnnl;
+using namespace tbb;
 
 namespace gcpp {
 
@@ -81,8 +86,18 @@ class MatMulEnv {
     const size_t num_lp = pools.NumLP();
     const size_t NF = hwy::VectorBytes() / sizeof(float);
     buf_ = RowVectorBatch<float>(num_lp, 16 * NF);
+    setenv("ONEDNN_MAX_CPU_ISA", "AVX512_CORE_AMX", 1);
+    // Enable verbose logging for dnnl when debugging.
+    // setenv("DNNL_VERBOSE", "2", 1);
+    tbb::global_control global_limit(
+        tbb::global_control::max_allowed_parallelism, 128);
+    // Create the execution dnnl::engine.
+    engine = dnnl::engine(dnnl::engine::kind::cpu, 0);
+    // Create the dnnl::stream.
+    engine_stream = dnnl::stream(engine);
   }
-
+  dnnl::stream engine_stream;
+  dnnl::engine engine;
   float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); }
   PerClusterPools& Pools() const { return *pools_; }
   hwy::ThreadPool& Pool() const { return pools_->Inner(0); }
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 2b64f275..01488fbc 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -13,6 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "compression/shared.h"
 #ifndef HWY_DISABLED_TARGETS
 // Exclude HWY_SCALAR due to 2x bf16 -> f32.
 #define HWY_DISABLED_TARGETS HWY_SCALAR
@@ -145,7 +146,7 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
   const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
                       MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
   // Dot(float,BF16) rounds both to BF16.
-  using RefType = hwy::If() && IsF32(), float, BF16>;
+  using RefType = BF16;
   const double epsilon = hwy::ConvertScalarTo(hwy::Epsilon());
   const double tolerance = 200.0 * norm * epsilon;
 
@@ -233,8 +234,13 @@ void TestMatMul(MatMulEnv& env) {
   std::unique_ptr> c_slow =
       GenerateZeroMat(pool);
   const double start_slow = hwy::platform::Now();
-  MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
-             kAdd ? add->data() : nullptr, env, c_slow->data());
+  // Compare the dnnl matmul results with the hwy matmul results.
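+  // The Highway implementation now serves as the reference: c_slow is filled
+  // by MatMul_hwy below and later compared against the dnnl-backed MatMul
+  // output, instead of being filled by MatMulSlow.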
+  MatMul_hwy<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
+                   ConstMat(b_trans->data(), kColsARowsB), scale,
+                   kAdd ? add->data_scale1() : nullptr, env,
+                   MutableMat(c_slow->data(), kColsBC));
+  // MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
+  //            kAdd ? add->data() : nullptr, c_slow->data());
   if (want_bench) {
     PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
                hwy::platform::Now() - start_slow);
@@ -258,9 +264,9 @@ void TestMatMul(MatMulEnv& env) {
 }
 
 void TestAllMatMul() {
+  tbb::global_control global_limit(
+      tbb::global_control::max_allowed_parallelism, 128);
   // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
-  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
-      HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
+  if (HWY_TARGET != HWY_AVX3 && HWY_TARGET != HWY_AVX3_SPR) {
     return;
   }
 
@@ -272,10 +278,10 @@ void TestAllMatMul() {
   using SFP = SfpStream;
 
   // large-scale test: batch_size=128 is better than 64 or 256 for SKX.
-  TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
-  TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
-  TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env);
-  TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env);
+  TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, BF16>(env);
+  TestMatMul<128, 3072, 24576, /*kAdd=*/false, BF16>(env);
+  TestMatMul<1, 24576, 3072, /*kAdd=*/false, BF16>(env);
+  TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, BF16>(env);
 
   // medium-sized square test - temporarily disabled for faster testing.
   if constexpr (false) {
@@ -292,32 +298,22 @@ void TestAllMatMul() {
   TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
   TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
   TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
-  TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env);
-  TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
   TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
   TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
   TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
   TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
-  TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env);
-  TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env);
   TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
   TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
   TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
   TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
-  TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env);
-  TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
   TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
   TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
   TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
   TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
-  TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env);
-  TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env);
   TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
   TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
   TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
   TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
-  TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env);
-  TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)