Skip to content

Commit e8fdb86

Browse files
authored
Set formemset (rust-lang#881)
1 parent 34e85a2 commit e8fdb86

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,11 @@ void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2,
675675
}
676676
}
677677

678+
void EnzymeSetForMemSet(LLVMValueRef inst1) {
679+
Instruction *I1 = cast<Instruction>(unwrap(inst1));
680+
I1->setMetadata("enzyme_formemset", MDNode::get(I1->getContext(), {}));
681+
}
682+
678683
void EnzymeSetMustCache(LLVMValueRef inst1) {
679684
Instruction *I1 = cast<Instruction>(unwrap(inst1));
680685
I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {}));

enzyme/Enzyme/CacheUtility.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -901,8 +901,13 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
901901
}
902902
}
903903

904-
if (ZeroInst)
904+
if (ZeroInst) {
905+
if (ZeroInst->getOperand(0) != malloccall) {
906+
scopeInstructions[alloc].push_back(
907+
cast<Instruction>(ZeroInst->getOperand(0)));
908+
}
905909
scopeInstructions[alloc].push_back(ZeroInst);
910+
}
906911
storealloc = allocationBuilder.CreateStore(firstallocation, storeInto);
907912

908913
scopeAllocs[alloc].push_back(malloccall);

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,6 +2214,8 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
22142214
toadd = scopeAllocs[found2->second.first][0];
22152215
for (auto u : toadd->users()) {
22162216
if (auto ci = dyn_cast<CastInst>(u)) {
2217+
if (hasMetadata(ci, "enzyme_formemset"))
2218+
continue;
22172219
toadd = ci;
22182220
}
22192221
}

0 commit comments

Comments
 (0)