Skip to content

[CodeGen] Use 128bits for LaneBitmask. #111157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llvm/include/llvm/CodeGen/RDFLiveness.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ namespace std {

template <> struct hash<llvm::rdf::detail::NodeRef> {
std::size_t operator()(llvm::rdf::detail::NodeRef R) const {
return std::hash<llvm::rdf::NodeId>{}(R.first) ^
std::hash<llvm::LaneBitmask::Type>{}(R.second.getAsInteger());
return llvm::hash_value<llvm::rdf::NodeId>(R.first) ^
llvm::hash_value(R.second.getAsPair());
}
};

Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/CodeGen/RDFRegisters.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ struct RegisterRef {
}

size_t hash() const {
return std::hash<RegisterId>{}(Reg) ^
std::hash<LaneBitmask::Type>{}(Mask.getAsInteger());
return llvm::hash_value<RegisterId>(Reg) ^
llvm::hash_value(Mask.getAsPair());
}

static constexpr bool isRegId(unsigned Id) {
Expand Down
146 changes: 97 additions & 49 deletions llvm/include/llvm/MC/LaneBitmask.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,72 +29,120 @@
#ifndef LLVM_MC_LANEBITMASK_H
#define LLVM_MC_LANEBITMASK_H

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Printable.h"
#include "llvm/Support/raw_ostream.h"
#include <utility>

namespace llvm {

struct LaneBitmask {
// When changing the underlying type, change the format string as well.
using Type = uint64_t;
enum : unsigned { BitWidth = 8*sizeof(Type) };
constexpr static const char *const FormatStr = "%016llX";
struct LaneBitmask {
static constexpr unsigned int BitWidth = 128;

constexpr LaneBitmask() = default;
explicit constexpr LaneBitmask(Type V) : Mask(V) {}

constexpr bool operator== (LaneBitmask M) const { return Mask == M.Mask; }
constexpr bool operator!= (LaneBitmask M) const { return Mask != M.Mask; }
constexpr bool operator< (LaneBitmask M) const { return Mask < M.Mask; }
constexpr bool none() const { return Mask == 0; }
constexpr bool any() const { return Mask != 0; }
constexpr bool all() const { return ~Mask == 0; }

constexpr LaneBitmask operator~() const {
return LaneBitmask(~Mask);
}
constexpr LaneBitmask operator|(LaneBitmask M) const {
return LaneBitmask(Mask | M.Mask);
}
constexpr LaneBitmask operator&(LaneBitmask M) const {
return LaneBitmask(Mask & M.Mask);
}
LaneBitmask &operator|=(LaneBitmask M) {
Mask |= M.Mask;
return *this;
}
LaneBitmask &operator&=(LaneBitmask M) {
Mask &= M.Mask;
return *this;
explicit LaneBitmask(APInt V) {
switch (V.getBitWidth()) {
case BitWidth:
Mask[0] = V.getRawData()[0];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should use APInt::extractBitsAsZExtValue() instead of exposing APInt's internals.

Mask[1] = V.getRawData()[1];
break;
default:
llvm_unreachable("Unsupported bitwidth");
}
}
constexpr explicit LaneBitmask(uint64_t Lo = 0, uint64_t Hi = 0) : Mask{Lo, Hi} {}

constexpr Type getAsInteger() const { return Mask; }
constexpr bool operator==(LaneBitmask M) const {
return Mask[0] == M.Mask[0] && Mask[1] == M.Mask[1];
}
constexpr bool operator!=(LaneBitmask M) const {
return Mask[0] != M.Mask[0] || Mask[1] != M.Mask[1];
}
constexpr bool operator<(LaneBitmask M) const {
return Mask[1] < M.Mask[1] || (Mask[1] == M.Mask[1] && Mask[0] < M.Mask[0]);
}
constexpr bool none() const { return Mask[0] == 0 && Mask[1] == 0; }
constexpr bool any() const { return Mask[0] != 0 || Mask[1] != 0; }
constexpr bool all() const { return ~Mask[0] == 0 && ~Mask[1] == 0; }

unsigned getNumLanes() const { return llvm::popcount(Mask); }
unsigned getHighestLane() const {
return Log2_64(Mask);
}
constexpr LaneBitmask operator~() const { return LaneBitmask(~Mask[0], ~Mask[1]); }
constexpr LaneBitmask operator|(LaneBitmask M) const {
return LaneBitmask(Mask[0] | M.Mask[0], Mask[1] | M.Mask[1]);
}
constexpr LaneBitmask operator&(LaneBitmask M) const {
return LaneBitmask(Mask[0] & M.Mask[0], Mask[1] & M.Mask[1]);
}
LaneBitmask &operator|=(LaneBitmask M) {
Mask[0] |= M.Mask[0];
Mask[1] |= M.Mask[1];
return *this;
}
LaneBitmask &operator&=(LaneBitmask M) {
Mask[0] &= M.Mask[0];
Mask[1] &= M.Mask[1];
return *this;
}

static constexpr LaneBitmask getNone() { return LaneBitmask(0); }
static constexpr LaneBitmask getAll() { return ~LaneBitmask(0); }
static constexpr LaneBitmask getLane(unsigned Lane) {
return LaneBitmask(Type(1) << Lane);
}
APInt getAsAPInt() const { return APInt(BitWidth, {Mask[0], Mask[1]}); }
constexpr std::pair<uint64_t, uint64_t> getAsPair() const { return {Mask[0], Mask[1]}; }

private:
Type Mask = 0;
};
unsigned getNumLanes() const {
return Mask[1] ? llvm::popcount(Mask[1]) + llvm::popcount(Mask[0])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would hope you could rely on modern host machines having a fast popcount instruction so you don't need to special-case Mask[1] here, but I'm not sure.

: llvm::popcount(Mask[0]);
}
unsigned getHighestLane() const {
return Mask[1] ? Log2_64(Mask[1]) + 64 : Log2_64(Mask[0]);
}

/// Create Printable object to print LaneBitmasks on a \ref raw_ostream.
inline Printable PrintLaneMask(LaneBitmask LaneMask) {
return Printable([LaneMask](raw_ostream &OS) {
OS << format(LaneBitmask::FormatStr, LaneMask.getAsInteger());
});
static constexpr LaneBitmask getNone() { return LaneBitmask(0, 0); }
static constexpr LaneBitmask getAll() { return ~LaneBitmask(0, 0); }
static constexpr LaneBitmask getLane(unsigned Lane) {
return Lane >= 64 ? LaneBitmask(0, 1ULL << (Lane % 64))
: LaneBitmask(1ULL << Lane, 0);
}

private:
uint64_t Mask[2];
};

/// Create Printable object to print LaneBitmasks on a \ref raw_ostream.
/// If \p FormatAsCLiterals is true, it will print the bitmask as
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the option? Please can we just print it as a single hex literal with lots of digits?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly to avoid having to update many tests that rely on a pretty printed format. Changing it would require changing various tests (which I'm happy to do if that's preferred).

/// a hexadecimal C literal with zero padding, or a list of such C literals if
/// the value cannot be represented in 64 bits.
/// For example (FormatAsCLiterals == true)
/// bitmask '1' => "0x0000000000000001"
/// bitmask '1 << 64' => "0x0000000000000000,0x0000000000000001"
/// (FormatAsCLiterals == false)
/// bitmask '1' => "00000000000000000000000000000001"
/// bitmask '1 << 64' => "00000000000000010000000000000000"
inline Printable PrintLaneMask(LaneBitmask LaneMask,
                               bool FormatAsCLiterals = false) {
  return Printable([LaneMask, FormatAsCLiterals](raw_ostream &OS) {
    // In C-literal mode the mask is emitted in 64-bit pieces, least
    // significant piece first; otherwise it is emitted as one full-width
    // piece.
    const unsigned ChunkBits = FormatAsCLiterals ? 64 : LaneBitmask::BitWidth;
    const char *Prefix = FormatAsCLiterals ? "0x" : "";

    SmallString<64> Digits;
    APInt Remaining = LaneMask.getAsAPInt();
    for (;;) {
      APInt Chunk = Remaining.trunc(ChunkBits);

      Digits.clear();
      Chunk.toString(Digits, 16, /*Signed=*/false,
                     /*formatAsCLiteral=*/false);

      // Zero-pad to the full chunk width: one hex digit per four leading zero
      // bits, minus the single digit an all-zero chunk still produces.
      unsigned Padding = Chunk.countLeadingZeros() / 4;
      if (Chunk.isZero())
        --Padding;
      OS << Prefix << std::string(Padding, '0') << Digits.str();

      // Stop once no set bits remain above the piece just printed.
      Remaining = Remaining.lshr(ChunkBits);
      if (!Remaining.getActiveBits())
        break;
      OS << ",";
    }
  });
}

} // end namespace llvm

#endif // LLVM_MC_LANEBITMASK_H
45 changes: 34 additions & 11 deletions llvm/lib/CodeGen/MIRParser/MIParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,17 +870,40 @@ bool MIParser::parseBasicBlockLiveins(MachineBasicBlock &MBB) {
lex();
LaneBitmask Mask = LaneBitmask::getAll();
if (consumeIfPresent(MIToken::colon)) {
// Parse lane mask.
if (Token.isNot(MIToken::IntegerLiteral) &&
Token.isNot(MIToken::HexLiteral))
return error("expected a lane mask");
static_assert(sizeof(LaneBitmask::Type) == sizeof(uint64_t),
"Use correct get-function for lane mask");
LaneBitmask::Type V;
if (getUint64(V))
return error("invalid lane mask value");
Mask = LaneBitmask(V);
lex();
if (consumeIfPresent(MIToken::lparen)) {
// We need to parse a list of literals
SmallVector<uint64_t, 2> Literals;
while (true) {
if (Token.isNot(MIToken::HexLiteral))
return error("expected a lane mask");
APInt V;
getHexUint(V);
Literals.push_back(V.getZExtValue());
// Lex past literal
lex();
if (Token.is(MIToken::rparen))
break;
else if (Token.isNot(MIToken::comma))
return error("expected a comma");
// Lex past comma
lex();
}
// Lex past rparen
lex();
Mask = LaneBitmask(APInt(LaneBitmask::BitWidth, Literals));
} else {
// Parse lane mask.
APInt V;
if (Token.is(MIToken::IntegerLiteral)) {
uint64_t UV;
if (getUint64(UV))
return error("invalid lane mask value");
V = APInt(LaneBitmask::BitWidth, UV);
} else if (getHexUint(V))
return error("expected a lane mask");
Mask = LaneBitmask(APInt(LaneBitmask::BitWidth, V.getZExtValue()));
lex();
}
}
MBB.addLiveIn(Reg, Mask);
} while (consumeIfPresent(MIToken::comma));
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/CodeGen/MIRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,8 +732,14 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {
OS << ", ";
First = false;
OS << printReg(LI.PhysReg, &TRI);
if (!LI.LaneMask.all())
OS << ":0x" << PrintLaneMask(LI.LaneMask);
if (!LI.LaneMask.all()) {
OS << ":";
if (LI.LaneMask.getAsAPInt().getActiveBits() <= 64)
OS << PrintLaneMask(LI.LaneMask, /*FormatAsCLiterals=*/true);
else
OS << '(' << PrintLaneMask(LI.LaneMask, /*FormatAsCLiterals=*/true)
<< ')';
}
}
OS << "\n";
HasLineAttributes = true;
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/CodeGen/RDFRegisters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,11 @@ raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P) {
if (P.Mask.none())
return OS << ":*none*";

LaneBitmask::Type Val = P.Mask.getAsInteger();
if ((Val & 0xffff) == Val)
return OS << ':' << format("%04llX", Val);
if ((Val & 0xffffffff) == Val)
return OS << ':' << format("%08llX", Val);
APInt Val = P.Mask.getAsAPInt();
if (Val.getActiveBits() <= 16)
return OS << ':' << format("%04llX", Val.getZExtValue());
if (Val.getActiveBits() <= 32)
return OS << ':' << format("%08llX", Val.getZExtValue());
return OS << ':' << PrintLaneMask(P.Mask);
}

Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/AMDGPU/SIRegisterInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,10 @@ class SIRegisterInfo final : public AMDGPUGenRegisterInfo {
static unsigned getNumCoveredRegs(LaneBitmask LM) {
// The assumption is that every lo16 subreg is an even bit and every hi16
// is an adjacent odd bit or vice versa.
uint64_t Mask = LM.getAsInteger();
APInt MaskV = LM.getAsAPInt();
assert(MaskV.getActiveBits() <= 64 &&
"uint64_t is insufficient to represent lane bitmask operation");
uint64_t Mask = MaskV.getZExtValue();
uint64_t Even = Mask & 0xAAAAAAAAAAAAAAAAULL;
Mask = (Even >> 1) | Mask;
uint64_t Odd = Mask & 0x5555555555555555ULL;
Expand Down
18 changes: 18 additions & 0 deletions llvm/test/CodeGen/AArch64/lanebitmask.mir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
# RUN: llc -o - %s -mtriple=aarch64 -stop-before=greedy | FileCheck %s
---
name: test_parse_lanebitmask
tracksRegLiveness: true
liveins:
- { reg: '$h0' }
- { reg: '$s1' }
body: |
bb.0:
liveins: $h0:0x0000000000000001, $s1:(0x0000000000000001,0x0000000000000000)
; CHECK-LABEL: name: test_parse_lanebitmask
; CHECK: liveins: $h0:0x0000000000000001, $s1:0x0000000000000001, $h0, $s1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: RET_ReallyLR
RET_ReallyLR
...

2 changes: 1 addition & 1 deletion llvm/unittests/CodeGen/MFCommon.inc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BogusRegisterInfo : public TargetRegisterInfo {
public:
BogusRegisterInfo()
: TargetRegisterInfo(nullptr, BogusRegisterClasses, BogusRegisterClasses,
nullptr, nullptr, nullptr, LaneBitmask(~0u), nullptr,
nullptr, nullptr, nullptr, LaneBitmask::getAll(), nullptr,
nullptr) {
InitMCRegisterInfo(nullptr, 0, 0, 0, nullptr, 0, nullptr, 0, nullptr,
nullptr, nullptr, nullptr, nullptr, 0, nullptr);
Expand Down
2 changes: 1 addition & 1 deletion llvm/unittests/MC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ add_llvm_unittest(MCTests
Disassembler.cpp
DwarfLineTables.cpp
DwarfLineTableHeaders.cpp
LaneBitmaskTest.cpp
MCInstPrinter.cpp
StringTableBuilderTest.cpp
TargetRegistry.cpp
MCDisassemblerTest.cpp
)

69 changes: 69 additions & 0 deletions llvm/unittests/MC/LaneBitmaskTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//===------------------ LaneBitmaskTest.cpp -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "gtest/gtest.h"
#include "llvm/MC/LaneBitmask.h"
#include "llvm/Support/raw_ostream.h"
#include <string>

using namespace llvm;

// Checks construction, complement/and/or, ordering and the lane queries of
// LaneBitmask, including lanes that live in the upper 64-bit half.
TEST(LaneBitmaskTest, Basic) {
  // getAll() and getNone() are bitwise complements of each other.
  EXPECT_EQ(LaneBitmask::getAll(), ~LaneBitmask::getNone());
  EXPECT_EQ(LaneBitmask::getNone(), ~LaneBitmask::getAll());
  EXPECT_EQ(LaneBitmask::getLane(0) | LaneBitmask::getLane(1), LaneBitmask(3));
  EXPECT_EQ(LaneBitmask(3) & LaneBitmask::getLane(1), LaneBitmask::getLane(1));

  // Round trip through APInt.
  EXPECT_EQ(LaneBitmask(APInt(128, 42)).getAsAPInt(), APInt(128, 42));

  // getNumLanes()/getHighestLane() return unsigned; use unsigned literals so
  // the comparison inside EXPECT_EQ does not mix signedness.
  EXPECT_EQ(LaneBitmask(3).getNumLanes(), 2u);
  EXPECT_EQ((LaneBitmask::getLane(0) | LaneBitmask::getLane(64)).getNumLanes(),
            2u);
  EXPECT_EQ(LaneBitmask::getLane(0).getHighestLane(), 0u);
  EXPECT_EQ(LaneBitmask::getLane(64).getHighestLane(), 64u);
  EXPECT_EQ(LaneBitmask::getLane(127).getHighestLane(), 127u);

  // Lane 64 is bit 0 of the high word.
  EXPECT_EQ(LaneBitmask::getLane(64), LaneBitmask(0, 1));

  // Ordering compares the high word first, then the low word.
  EXPECT_LT(LaneBitmask::getLane(64), LaneBitmask::getLane(65));
  EXPECT_LT(LaneBitmask::getLane(63), LaneBitmask::getLane(64));
  EXPECT_LT(LaneBitmask::getLane(62), LaneBitmask::getLane(63));
  EXPECT_LT(LaneBitmask::getLane(64),
            LaneBitmask::getLane(64) | LaneBitmask::getLane(0));

  // Compound assignment operators.
  LaneBitmask X(1);
  X |= LaneBitmask(2);
  EXPECT_EQ(X, LaneBitmask(3));

  LaneBitmask Y(3);
  Y &= LaneBitmask(1);
  EXPECT_EQ(Y, LaneBitmask(1));
}

// Checks the textual forms produced by PrintLaneMask for masks that fit in
// 64 bits and for masks that need both 64-bit halves.
TEST(LaneBitmaskTest, Print) {
  // Streams the mask through PrintLaneMask and hands back the produced text.
  auto Render = [](LaneBitmask Mask, bool FormatAsCLiterals) {
    std::string Text;
    raw_string_ostream Stream(Text);
    Stream << PrintLaneMask(Mask, FormatAsCLiterals);
    return std::string(Stream.str());
  };

  EXPECT_STREQ(Render(LaneBitmask::getAll(), /*FormatAsCLiterals=*/true).data(),
               "0xFFFFFFFFFFFFFFFF,0xFFFFFFFFFFFFFFFF");

  EXPECT_STREQ(
      Render(LaneBitmask::getAll(), /*FormatAsCLiterals=*/false).data(),
      "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF");

  EXPECT_STREQ(
      Render(LaneBitmask::getLane(0), /*FormatAsCLiterals=*/true).data(),
      "0x0000000000000001");

  EXPECT_STREQ(
      Render(LaneBitmask::getLane(63), /*FormatAsCLiterals=*/true).data(),
      "0x8000000000000000");

  EXPECT_STREQ(
      Render(LaneBitmask::getNone(), /*FormatAsCLiterals=*/true).data(),
      "0x0000000000000000");

  EXPECT_STREQ(
      Render(LaneBitmask::getLane(64), /*FormatAsCLiterals=*/true).data(),
      "0x0000000000000000,0x0000000000000001");
}
Loading
Loading