Skip to content

Commit 81a7b64

Browse files
authored
[MLIR][LLVM] Recursion importer handle repeated self-references (#87295)
Followup to this discussion: #80251 (comment). The previous debug importer was correct but inefficient. For cases with mutual recursion that contain more than one back-edge, each back-edge would result in a new translated instance. This is because the previous implementation never caches any translated result with unbounded self-references. This means all translation inside a recursive context is performed from scratch, which will incur repeated run-time cost as well as repeated attribute sub-trees in the translated IR (differing only in their `recId`s). This PR refactors the importer to handle caching inside a recursive context. - In the presence of unbound self-refs, the translation result is cached in a separate cache that keeps track of the set of dependent unbound self-refs. - A dependent cache entry is valid only when all the unbound self-refs are in scope. Whenever a cached entry goes out of scope, it will be removed the next time it is looked up.
1 parent 9fd2e2c commit 81a7b64

File tree

5 files changed

+398
-79
lines changed

5 files changed

+398
-79
lines changed

mlir/lib/Target/LLVMIR/DebugImporter.cpp

Lines changed: 131 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/Location.h"
1414
#include "llvm/ADT/STLExtras.h"
1515
#include "llvm/ADT/ScopeExit.h"
16+
#include "llvm/ADT/SetOperations.h"
1617
#include "llvm/ADT/TypeSwitch.h"
1718
#include "llvm/BinaryFormat/Dwarf.h"
1819
#include "llvm/IR/Constants.h"
@@ -25,6 +26,10 @@ using namespace mlir;
2526
using namespace mlir::LLVM;
2627
using namespace mlir::LLVM::detail;
2728

29+
DebugImporter::DebugImporter(ModuleOp mlirModule)
30+
: recursionPruner(mlirModule.getContext()),
31+
context(mlirModule.getContext()), mlirModule(mlirModule) {}
32+
2833
Location DebugImporter::translateFuncLocation(llvm::Function *func) {
2934
llvm::DISubprogram *subprogram = func->getSubprogram();
3035
if (!subprogram)
@@ -246,42 +251,13 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
246251
if (DINodeAttr attr = nodeToAttr.lookup(node))
247252
return attr;
248253

249-
// If the node type is capable of being recursive, check if it's seen before.
250-
auto recSelfCtor = getRecSelfConstructor(node);
251-
if (recSelfCtor) {
252-
// If a cyclic dependency is detected since the same node is being traversed
253-
// twice, emit a recursive self type, and mark the duplicate node on the
254-
// translationStack so it can emit a recursive decl type.
255-
auto [iter, inserted] = translationStack.try_emplace(node, nullptr);
256-
if (!inserted) {
257-
// The original node may have already been assigned a recursive ID from
258-
// a different self-reference. Use that if possible.
259-
DistinctAttr recId = iter->second;
260-
if (!recId) {
261-
recId = DistinctAttr::create(UnitAttr::get(context));
262-
iter->second = recId;
263-
}
264-
unboundRecursiveSelfRefs.back().insert(recId);
265-
return cast<DINodeAttr>(recSelfCtor(recId));
266-
}
267-
}
268-
269-
unboundRecursiveSelfRefs.emplace_back();
270-
271-
auto guard = llvm::make_scope_exit([&]() {
272-
if (recSelfCtor)
273-
translationStack.pop_back();
254+
// Register with the recursive translator. If it can be handled without
255+
// recursing into it, return the result immediately.
256+
if (DINodeAttr attr = recursionPruner.pruneOrPushTranslationStack(node))
257+
return attr;
274258

275-
// Copy unboundRecursiveSelfRefs down to the previous level.
276-
if (unboundRecursiveSelfRefs.size() == 1)
277-
assert(unboundRecursiveSelfRefs.back().empty() &&
278-
"internal error: unbound recursive self reference at top level.");
279-
else
280-
unboundRecursiveSelfRefs[unboundRecursiveSelfRefs.size() - 2].insert(
281-
unboundRecursiveSelfRefs.back().begin(),
282-
unboundRecursiveSelfRefs.back().end());
283-
unboundRecursiveSelfRefs.pop_back();
284-
});
259+
auto guard = llvm::make_scope_exit(
260+
[&]() { recursionPruner.popTranslationStack(node); });
285261

286262
// Convert the debug metadata if possible.
287263
auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -318,22 +294,130 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
318294
return nullptr;
319295
};
320296
if (DINodeAttr attr = translateNode(node)) {
321-
// If this node was marked as recursive, set its recId.
322-
if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(attr)) {
323-
if (DistinctAttr recId = translationStack.lookup(node)) {
324-
attr = cast<DINodeAttr>(recType.withRecId(recId));
325-
// Remove the unbound recursive ID from the set of unbound self
326-
// references in the translation stack.
327-
unboundRecursiveSelfRefs.back().erase(recId);
297+
auto [result, isSelfContained] =
298+
recursionPruner.finalizeTranslation(node, attr);
299+
// Only cache fully self-contained nodes.
300+
if (isSelfContained)
301+
nodeToAttr.try_emplace(node, result);
302+
return result;
303+
}
304+
return nullptr;
305+
}
306+
307+
//===----------------------------------------------------------------------===//
308+
// RecursionPruner
309+
//===----------------------------------------------------------------------===//
310+
311+
/// Get the `getRecSelf` constructor for the translated type of `node` if its
312+
/// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
313+
static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
314+
getRecSelfConstructor(llvm::DINode *node) {
315+
using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
316+
return TypeSwitch<llvm::DINode *, CtorType>(node)
317+
.Case([&](llvm::DICompositeType *) {
318+
return CtorType(DICompositeTypeAttr::getRecSelf);
319+
})
320+
.Default(CtorType());
321+
}
322+
323+
DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
324+
llvm::DINode *node) {
325+
// If the node type is capable of being recursive, check if it's seen
326+
// before.
327+
auto recSelfCtor = getRecSelfConstructor(node);
328+
if (recSelfCtor) {
329+
// If a cyclic dependency is detected since the same node is being
330+
// traversed twice, emit a recursive self type, and mark the duplicate
331+
// node on the translationStack so it can emit a recursive decl type.
332+
auto [iter, inserted] = translationStack.try_emplace(node);
333+
if (!inserted) {
334+
// The original node may have already been assigned a recursive ID from
335+
// a different self-reference. Use that if possible.
336+
DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
337+
if (!recSelf) {
338+
DistinctAttr recId = nodeToRecId.lookup(node);
339+
if (!recId) {
340+
recId = DistinctAttr::create(UnitAttr::get(context));
341+
nodeToRecId[node] = recId;
342+
}
343+
recSelf = recSelfCtor(recId);
344+
iter->second.recSelf = recSelf;
328345
}
346+
// Inject the self-ref into the previous layer.
347+
translationStack.back().second.unboundSelfRefs.insert(recSelf);
348+
return cast<DINodeAttr>(recSelf);
329349
}
350+
}
330351

331-
// Only cache fully self-contained nodes.
332-
if (unboundRecursiveSelfRefs.back().empty())
333-
nodeToAttr.try_emplace(node, attr);
334-
return attr;
352+
return lookup(node);
353+
}
354+
355+
std::pair<DINodeAttr, bool>
356+
DebugImporter::RecursionPruner::finalizeTranslation(llvm::DINode *node,
357+
DINodeAttr result) {
358+
// If `node` is not a potentially recursive type, it will not be on the
359+
// translation stack. Nothing to set in this case.
360+
if (translationStack.empty())
361+
return {result, true};
362+
if (translationStack.back().first != node)
363+
return {result, translationStack.back().second.unboundSelfRefs.empty()};
364+
365+
TranslationState &state = translationStack.back().second;
366+
367+
// If this node is actually recursive, set the recId onto `result`.
368+
if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
369+
auto recType = cast<DIRecursiveTypeAttrInterface>(result);
370+
result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
371+
// Remove this recSelf from the set of unbound selfRefs.
372+
state.unboundSelfRefs.erase(recSelf);
335373
}
336-
return nullptr;
374+
375+
// Insert the result into our internal cache if it's not self-contained.
376+
if (!state.unboundSelfRefs.empty()) {
377+
auto [_, inserted] = dependentCache.try_emplace(
378+
node, DependentTranslation{result, state.unboundSelfRefs});
379+
assert(inserted && "invalid state: caching the same DINode twice");
380+
return {result, false};
381+
}
382+
return {result, true};
383+
}
384+
385+
void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
386+
// If `node` is not a potentially recursive type, it will not be on the
387+
// translation stack. Nothing to handle in this case.
388+
if (translationStack.empty() || translationStack.back().first != node)
389+
return;
390+
391+
// At the end of the stack, all unbound self-refs must be resolved already,
392+
// and the entire cache should be accounted for.
393+
TranslationState &currLayerState = translationStack.back().second;
394+
if (translationStack.size() == 1) {
395+
assert(currLayerState.unboundSelfRefs.empty() &&
396+
"internal error: unbound recursive self reference at top level.");
397+
translationStack.pop_back();
398+
return;
399+
}
400+
401+
// Copy unboundSelfRefs down to the previous level.
402+
TranslationState &nextLayerState = (++translationStack.rbegin())->second;
403+
nextLayerState.unboundSelfRefs.insert(currLayerState.unboundSelfRefs.begin(),
404+
currLayerState.unboundSelfRefs.end());
405+
translationStack.pop_back();
406+
}
407+
408+
DINodeAttr DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
409+
auto cacheIter = dependentCache.find(node);
410+
if (cacheIter == dependentCache.end())
411+
return {};
412+
413+
DependentTranslation &entry = cacheIter->second;
414+
if (llvm::set_is_subset(entry.unboundSelfRefs,
415+
translationStack.back().second.unboundSelfRefs))
416+
return entry.attr;
417+
418+
// Stale cache entry.
419+
dependentCache.erase(cacheIter);
420+
return {};
337421
}
338422

339423
//===----------------------------------------------------------------------===//
@@ -394,13 +478,3 @@ DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
394478
id = DistinctAttr::create(UnitAttr::get(context));
395479
return id;
396480
}
397-
398-
function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
399-
DebugImporter::getRecSelfConstructor(llvm::DINode *node) {
400-
using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
401-
return TypeSwitch<llvm::DINode *, CtorType>(node)
402-
.Case([&](llvm::DICompositeType *concreteNode) {
403-
return CtorType(DICompositeTypeAttr::getRecSelf);
404-
})
405-
.Default(CtorType());
406-
}

mlir/lib/Target/LLVMIR/DebugImporter.h

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ namespace detail {
2929

3030
class DebugImporter {
3131
public:
32-
DebugImporter(ModuleOp mlirModule)
33-
: context(mlirModule.getContext()), mlirModule(mlirModule) {}
32+
DebugImporter(ModuleOp mlirModule);
3433

3534
/// Translates the given LLVM debug location to an MLIR location.
3635
Location translateLoc(llvm::DILocation *loc);
@@ -86,24 +85,102 @@ class DebugImporter {
8685
/// for it, or create a new one if not.
8786
DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
8887

89-
/// Get the `getRecSelf` constructor for the translated type of `node` if its
90-
/// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
91-
function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
92-
getRecSelfConstructor(llvm::DINode *node);
93-
9488
/// A mapping between LLVM debug metadata and the corresponding attribute.
9589
DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
9690
/// A mapping between distinct LLVM debug metadata nodes and the corresponding
9791
/// distinct id attribute.
9892
DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
9993

100-
/// A stack that stores the metadata nodes that are being traversed. The stack
101-
/// is used to detect cyclic dependencies during the metadata translation.
102-
/// A node is pushed with a null value. If it is ever seen twice, it is given
103-
/// a recursive id attribute, indicating that it is a recursive node.
104-
llvm::MapVector<llvm::DINode *, DistinctAttr> translationStack;
105-
/// All the unbound recursive self references in the translation stack.
106-
SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;
94+
/// Translation helper for recursive DINodes.
95+
/// Works alongside a stack-based DINode translator (the "main translator")
96+
/// for gracefully handling DINodes that are recursive.
97+
///
98+
/// Usage:
99+
/// - Before translating a node, call `pruneOrPushTranslationStack` to see if
100+
/// the pruner can preempt this translation. If this is a node that the
101+
/// pruner already knows how to handle, it will return the translated
102+
/// DINodeAttr.
103+
/// - After a node is successfully translated by the main translator, call
104+
/// `finalizeTranslation` to save the translated result with the pruner, and
105+
/// give it a chance to further modify the result.
106+
/// - Regardless of success or failure by the main translator, always call
107+
/// `popTranslationStack` at the end of translating a node. This is
108+
/// necessary to keep the internal book-keeping in sync.
109+
///
110+
/// This helper maintains an internal cache so that no recursive type will
111+
/// be translated more than once by the main translator.
112+
/// This internal cache is different from the cache maintained by the main
113+
/// translator because it may store nodes that are not self-contained (i.e.
114+
/// contain unbounded recursive self-references).
115+
class RecursionPruner {
116+
public:
117+
RecursionPruner(MLIRContext *context) : context(context) {}
118+
119+
/// If this node is a recursive instance that was previously seen, returns a
120+
/// self-reference. If this node was previously cached, returns the cached
121+
/// result. Otherwise, returns null attr, and a translation stack frame is
122+
/// created for this node. Expects `finalizeTranslation` &
123+
/// `popTranslationStack` to be called on this node later.
124+
DINodeAttr pruneOrPushTranslationStack(llvm::DINode *node);
125+
126+
/// Register the translated result of `node`. Returns the finalized result
127+
/// (with recId if recursive) and whether the result is self-contained
128+
/// (i.e. contains no unbound self-refs).
129+
std::pair<DINodeAttr, bool> finalizeTranslation(llvm::DINode *node,
130+
DINodeAttr result);
131+
132+
/// Pop off a frame from the translation stack after a node is done being
133+
/// translated.
134+
void popTranslationStack(llvm::DINode *node);
135+
136+
private:
137+
/// Returns the cached result (if exists) or null.
138+
/// The cache entry will be removed if not all of its dependent self-refs
139+
/// exists.
140+
DINodeAttr lookup(llvm::DINode *node);
141+
142+
MLIRContext *context;
143+
144+
/// A cached translation that contains the translated attribute as well
145+
/// as any unbound self-references that it depends on.
146+
struct DependentTranslation {
147+
/// The translated attr. May contain unbound self-references for other
148+
/// recursive attrs.
149+
DINodeAttr attr;
150+
/// The set of unbound self-refs that this cached entry refers to. All
151+
/// these self-refs must exist for the cached entry to be valid.
152+
DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
153+
};
154+
/// A mapping between LLVM debug metadata and the corresponding attribute.
155+
/// Only contains those with unboundSelfRefs. Fully self-contained attrs
156+
/// will be cached by the outer main translator.
157+
DenseMap<llvm::DINode *, DependentTranslation> dependentCache;
158+
159+
/// Each potentially recursive node will have a TranslationState pushed onto
160+
/// the `translationStack` to keep track of whether this node is actually
161+
/// recursive (i.e. has self-references inside), and other book-keeping.
162+
struct TranslationState {
163+
/// The rec-self if this node is indeed a recursive node (i.e. another
164+
/// instance of itself is seen while translating it). Null if this node
165+
/// has not been seen again deeper in the translation stack.
166+
DIRecursiveTypeAttrInterface recSelf;
167+
/// All the unbound recursive self references in this layer of the
168+
/// translation stack.
169+
DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
170+
};
171+
/// A stack that stores the metadata nodes that are being traversed. The
172+
/// stack is used to handle cyclic dependencies during metadata translation.
173+
/// Each node is pushed with an empty TranslationState. If it is ever seen
174+
/// later when the stack is deeper, the node is recursive, and its
175+
/// TranslationState is assigned a recSelf.
176+
llvm::MapVector<llvm::DINode *, TranslationState> translationStack;
177+
178+
/// A mapping between DINodes that are recursive, and their assigned recId.
179+
/// This is kept so that repeated occurrences of the same node can reuse the
180+
/// same ID and be deduplicated.
181+
DenseMap<llvm::DINode *, DistinctAttr> nodeToRecId;
182+
};
183+
RecursionPruner recursionPruner;
107184

108185
MLIRContext *context;
109186
ModuleOp mlirModule;

mlir/lib/Target/LLVMIR/DebugTranslation.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216,18 +216,15 @@ DebugTranslation::translateImpl(DIGlobalVariableAttr attr) {
216216
llvm::DIType *
217217
DebugTranslation::translateRecursive(DIRecursiveTypeAttrInterface attr) {
218218
DistinctAttr recursiveId = attr.getRecId();
219-
if (attr.isRecSelf()) {
220-
auto *iter = recursiveTypeMap.find(recursiveId);
221-
assert(iter != recursiveTypeMap.end() && "unbound DI recursive self type");
219+
if (auto *iter = recursiveTypeMap.find(recursiveId);
220+
iter != recursiveTypeMap.end()) {
222221
return iter->second;
222+
} else {
223+
assert(!attr.isRecSelf() && "unbound DI recursive self type");
223224
}
224225

225226
auto setRecursivePlaceholder = [&](llvm::DIType *placeholder) {
226-
[[maybe_unused]] auto [iter, inserted] =
227-
recursiveTypeMap.try_emplace(recursiveId, placeholder);
228-
(void)iter;
229-
(void)inserted;
230-
assert(inserted && "illegal reuse of recursive id");
227+
recursiveTypeMap.try_emplace(recursiveId, placeholder);
231228
};
232229

233230
llvm::DIType *result =

0 commit comments

Comments
 (0)