Skip to content

[MLIR] Prevent invalid IR from being passed outside of RemoveDeadValues #121079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 162 additions & 60 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,55 @@ using namespace mlir::dataflow;

namespace {

// The structures below are filled with the operations and arguments to erase.
// This separates the analysis phase from the tree-modification phase;
// otherwise the analysis would operate on a half-deleted tree, which is
// incorrect.

/// Deferred cleanup record for a function-like op: the arguments and results
/// to erase from `funcOp` once analysis is finished.
struct CleanupFunction {
// Function whose signature is to be trimmed.
FunctionOpInterface funcOp;
// Bit i is set iff the i-th argument is to be erased; passed to
// `funcOp.eraseArguments` during cleanup.
BitVector nonLiveArgs;
// Bit i is set iff the i-th result is to be erased; passed to
// `funcOp.eraseResults` during cleanup.
BitVector nonLiveRets;
};

/// Deferred cleanup record for erasing a subset of the operands of `op`.
struct CleanupOperands {
// Operation whose operands are to be erased.
Operation *op;
// Bit i is set iff the i-th operand is to be erased; passed to
// `op->eraseOperands` during cleanup.
BitVector nonLiveOperands;
};

/// Deferred cleanup record for erasing a subset of the results of `op`.
struct CleanupResults {
// Operation whose results are to be erased.
Operation *op;
// Bit i is set iff the i-th result is to be erased; passed together with
// `op` to `dropUsesAndEraseResults` during cleanup.
BitVector nonLiveResults;
};

/// Deferred cleanup record for erasing a subset of the arguments of a block.
struct CleanupBlockArgs {
// Block whose arguments are to be erased.
Block *b;
// Bit i is set iff the i-th block argument is to be erased. During cleanup
// the record is skipped when the block's argument count no longer equals the
// size of this vector.
BitVector nonLiveArgs;
};

/// Deferred cleanup record for erasing a subset of the operands that a branch
/// op passes to one of its successors.
struct CleanupSuccessorOperands {
// Branch op whose successor operands are to be erased.
BranchOpInterface branch;
// Successor index, used with `branch.getSuccessorOperands(index)`.
unsigned index;
// Bit i is set iff the i-th successor operand is to be erased. During cleanup
// the record is skipped when the successor operand count no longer equals the
// size of this vector.
BitVector nonLiveOperands;
};

/// Everything scheduled for removal, collected while walking the IR and
/// applied afterwards by `cleanup`. `cleanup` processes the lists in the order
/// in which they are declared here.
struct CleanupList {
// Operations to erase entirely (uses are dropped first).
SmallVector<Operation *> operations;
// Values whose uses are to be dropped.
SmallVector<Value> values;
// Function-like ops whose arguments/results are to be erased.
SmallVector<CleanupFunction> functions;
// Operations with operands to erase.
SmallVector<CleanupOperands> operands;
// Operations with results to erase.
SmallVector<CleanupResults> results;
// Blocks with arguments to erase.
SmallVector<CleanupBlockArgs> blocks;
// Branch ops with successor operands to erase.
SmallVector<CleanupSuccessorOperands> successorOperands;
};

// Some helper functions...

/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
static bool hasLive(ValueRange values, const DenseSet<Value> &deletionSet,
RunLivenessAnalysis &la) {
for (Value value : values) {
// If there is a null value, it implies that it was dropped during the
// execution of this pass, implying that it was non-live.
if (!value)
if (deletionSet.contains(value))
continue;

const Liveness *liveness = la.getLiveness(value);
Expand All @@ -92,11 +132,13 @@ static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {

/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
static BitVector markLives(ValueRange values,
const DenseSet<Value> &deletionSet,
RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);

for (auto [index, value] : llvm::enumerate(values)) {
if (!value) {
if (deletionSet.contains(value)) {
lives.reset(index);
continue;
}
Expand All @@ -115,6 +157,21 @@ static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
return lives;
}

/// Record values that are scheduled for removal.
///
/// `deletionSet` tracks the values that will be erased once the analysis
/// phase is over. For every position `i` in `range`, the i-th value is added
/// to `deletionSet` iff bit `i` of `nonLive` is set. `nonLive` is indexed by
/// position in `range`, so it must have at least `range.size()` bits.
static void updateDeletionSet(DenseSet<Value> &deletionSet, ValueRange range,
                              const BitVector &nonLive) {
  for (auto [index, value] : llvm::enumerate(range)) {
    if (nonLive[index])
      deletionSet.insert(value);
  }
}

/// Convenience overload: mark the results of `op` whose bit in `nonLive` is
/// set as scheduled for removal by adding them to `deletionSet`. `nonLive` is
/// indexed by result number.
static void updateDeletionSet(DenseSet<Value> &deletionSet, Operation *op,
                              const BitVector &nonLive) {
  updateDeletionSet(deletionSet, op->getResults(), nonLive);
}

/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
/// is 1.
static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
Expand Down Expand Up @@ -174,43 +231,44 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// It is assumed that `op` is simple. Here, a simple op is one which isn't a
/// function-like op, a call-like op, a region branch op, a branch op, a region
/// branch terminator op, or return-like.
static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
static void cleanSimpleOp(CleanupList &cl, DenseSet<Value> &deletionSet,
Operation *op, RunLivenessAnalysis &la) {
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), deletionSet, la))
return;

op->dropAllUses();
op->erase();
cl.operations.push_back(op);
updateDeletionSet(deletionSet, op, BitVector(op->getNumResults(), true));
}

/// Clean a function-like op `funcOp`, given the liveness information in `la`
/// and the IR in `module`. Here, cleaning means:
/// (1) Dropping the uses of its unnecessary (non-live) arguments,
/// (2) Erasing these arguments,
/// (3) Erasing their corresponding operands from its callers,
/// (2) Erasing their corresponding operands from its callers,
/// (3) Erasing these arguments,
/// (4) Erasing its unnecessary terminator operands (return values that are
/// non-live across all callers),
/// (5) Dropping the uses of these return values from its callers, AND
/// (6) Erasing these return values
/// iff it is not public or external.
static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
static void cleanFuncOp(CleanupList &cl, DenseSet<Value> &deletionSet,
FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la) {
if (funcOp.isPublic() || funcOp.isExternal())
return;

// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
BitVector nonLiveArgs = markLives(arguments, la);
BitVector nonLiveArgs = markLives(arguments, deletionSet, la);
nonLiveArgs = nonLiveArgs.flip();

// Do (1).
for (auto [index, arg] : llvm::enumerate(arguments))
if (arg && nonLiveArgs[index])
arg.dropAllUses();
if (arg && nonLiveArgs[index]) {
cl.values.push_back(arg);
deletionSet.insert(arg);
}

// Do (2).
funcOp.eraseArguments(nonLiveArgs);

// Do (3).
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
Expand All @@ -222,7 +280,7 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
for (int index : nonLiveArgs.set_bits())
nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
callOp->eraseOperands(nonLiveCallOperands);
cl.operands.push_back({callOp, nonLiveCallOperands});
}

// Get the list of unnecessary terminator operands (return values that are
Expand Down Expand Up @@ -253,26 +311,27 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
BitVector liveCallRets = markLives(callOp->getResults(), la);
BitVector liveCallRets = markLives(callOp->getResults(), deletionSet, la);
nonLiveRets &= liveCallRets.flip();
}

// Do (4).
// Do (3).
// Note that in the absence of control flow ops forcing the control to go from
// the entry (first) block to the other blocks, the control never reaches any
// block other than the entry block, because every block has a terminator.
for (Block &block : funcOp.getBlocks()) {
Operation *returnOp = block.getTerminator();
if (returnOp && returnOp->getNumOperands() == numReturns)
returnOp->eraseOperands(nonLiveRets);
cl.operands.push_back({returnOp, nonLiveRets});
}
funcOp.eraseResults(nonLiveRets);
cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});

// Do (5) and (6).
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
dropUsesAndEraseResults(callOp, nonLiveRets);
cl.results.push_back({callOp, nonLiveRets});
updateDeletionSet(deletionSet, callOp, nonLiveRets);
}
}

Expand All @@ -297,18 +356,19 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
/// It is important to note that values in this op flow from operands and
/// terminator operands (successor operands) to arguments and results (successor
/// inputs).
static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
static void cleanRegionBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la) {
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
liveResults = markLives(regionBranchOp->getResults(), la);
liveResults = markLives(regionBranchOp->getResults(), deletionSet, la);
};

// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
for (Region &region : regionBranchOp->getRegions()) {
SmallVector<Value> arguments(region.front().getArguments());
BitVector regionLiveArgs = markLives(arguments, la);
BitVector regionLiveArgs = markLives(arguments, deletionSet, la);
liveArgs[&region] = regionLiveArgs;
}
};
Expand Down Expand Up @@ -497,9 +557,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// It could never be live because of this op but its liveness could have been
// attributed to something else.
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), la)) {
regionBranchOp->dropAllUses();
regionBranchOp->erase();
!hasLive(regionBranchOp->getResults(), deletionSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}

Expand Down Expand Up @@ -538,29 +597,29 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
terminatorOperandsToKeep);

// Do (1).
regionBranchOp->eraseOperands(operandsToKeep.flip());
cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});

// Do (2.a) and (2.b).
for (Region &region : regionBranchOp->getRegions()) {
assert(!region.empty() && "expected a non-empty region in an op "
"implementing `RegionBranchOpInterface`");
for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) {
if (argsToKeep[&region][index])
continue;
if (arg)
arg.dropAllUses();
}
region.front().eraseArguments(argsToKeep[&region].flip());
BitVector argsToRemove = argsToKeep[&region].flip();
cl.blocks.push_back({&region.front(), argsToRemove});
updateDeletionSet(deletionSet, region.front().getArguments(), argsToRemove);
}

// Do (2.c).
for (Region &region : regionBranchOp->getRegions()) {
Operation *terminator = region.front().getTerminator();
terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip());
cl.operands.push_back(
{terminator, terminatorOperandsToKeep[terminator].flip()});
}

// Do (3) and (4).
dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
BitVector resultsToRemove = resultsToKeep.flip();
updateDeletionSet(deletionSet, regionBranchOp.getOperation(),
resultsToRemove);
cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
}

// 1. Iterate over each successor block of the given BranchOpInterface
Expand All @@ -572,7 +631,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// c. Mark each operand as live or dead based on the analysis.
// 3. Remove dead operands from the branch operation and arguments accordingly

static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
static void cleanBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
BranchOpInterface branchOp, RunLivenessAnalysis &la) {
unsigned numSuccessors = branchOp->getNumSuccessors();

// Do (1)
Expand All @@ -588,22 +648,60 @@ static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
operandValues.push_back(successorOperands[operandIdx]);
}

BitVector successorLiveOperands = markLives(operandValues, la);

// Do (3)
for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
if (!successorLiveOperands[argIdx]) {
if (successorBlock->getNumArguments() < successorOperands.size()) {
// if block was cleaned through a different code path
// we only need to remove operands from the invocation
successorOperands.erase(argIdx);
continue;
}
BitVector successorNonLive =
markLives(operandValues, deletionSet, la).flip();
updateDeletionSet(deletionSet, successorBlock->getArguments(),
successorNonLive);
cl.blocks.push_back({successorBlock, successorNonLive});
cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
}
}

void cleanup(CleanupList &cl) {
for (auto &op : cl.operations) {
op->dropAllUses();
op->erase();
}

for (auto &v : cl.values) {
v.dropAllUses();
}

successorBlock->getArgument(argIdx).dropAllUses();
successorOperands.erase(argIdx);
successorBlock->eraseArgument(argIdx);
}
for (auto &f : cl.functions) {
f.funcOp.eraseArguments(f.nonLiveArgs);
f.funcOp.eraseResults(f.nonLiveRets);
}

for (auto &o : cl.operands) {
o.op->eraseOperands(o.nonLiveOperands);
}

for (auto &r : cl.results) {
dropUsesAndEraseResults(r.op, r.nonLiveResults);
}

for (auto &b : cl.blocks) {
// blocks that are accessed via multiple codepaths processed once
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
Comment on lines +709 to +714
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it important to iterate from the end? Also, why not:

Suggested change
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
for (auto [idx, nonLiveArg] : llvm::enumerate(b.nonLiveArgs)) {
if (!nonLiveArg)
continue;
b.b->getArgument(idx).dropAllUses();
b.b->eraseArgument(idx);
}

Or use `llvm::for_each`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The collection of arguments changes while we iterate. If we iterated from the start, the indices would become invalid, because each deletion invalidates all subsequent indices; that is why the iteration is reversed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth adding a note.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, comment added. Thank you!

}
for (auto &op : cl.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.index);
// blocks that are accessed via multiple codepaths processed once
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
successorOperands.erase(i);
}
}
}
Expand All @@ -616,24 +714,28 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();
DenseSet<Value> deletionSet;
CleanupList cl;

module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
cleanFuncOp(funcOp, module, la);
cleanFuncOp(cl, deletionSet, funcOp, module, la);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
cleanRegionBranchOp(regionBranchOp, la);
cleanRegionBranchOp(cl, deletionSet, regionBranchOp, la);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
cleanBranchOp(branchOp, la);
cleanBranchOp(cl, deletionSet, branchOp, la);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
} else {
cleanSimpleOp(op, la);
cleanSimpleOp(cl, deletionSet, op, la);
}
});

cleanup(cl);
}

std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
Expand Down
Loading
Loading