Skip to content

Commit 4c03e15

Browse files
committed
generic recursive type with importer impl
1 parent 3761ad0 commit 4c03e15

File tree

6 files changed

+235
-76
lines changed

6 files changed

+235
-76
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,70 @@ def LLVM_DIExpressionAttr : LLVM_Attr<"DIExpression", "di_expression"> {
301301
let assemblyFormat = "`<` ( `[` $operations^ `]` ) : (``)? `>`";
302302
}
303303

304+
//===----------------------------------------------------------------------===//
305+
// DIRecursiveTypeAttr
306+
//===----------------------------------------------------------------------===//
307+
308+
def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
309+
/*traits=*/[], "DITypeAttr"> {
310+
let description = [{
311+
This attribute enables recursive DITypes. There are two modes for this
312+
attribute.
313+
314+
1. If `baseType` is present:
315+
- This type is considered a recursive declaration (rec-decl).
316+
- The `baseType` is a self-recursive type identified by the `id` field.
317+
318+
2. If `baseType` is not present:
319+
- This type is considered a recursive self reference (rec-self).
320+
- This DIRecursiveType itself is a placeholder type that should be
321+
conceptually replaced with the closest parent DIRecursiveType with the
322+
same `id` field.
323+
324+
e.g. To represent a linked list struct:
325+
326+
#rec_self = di_recursive_type<self_id = 0>
327+
#ptr = di_derived_type<baseType: #rec_self, ...>
328+
#field = di_derived_type<name = "next", baseType: #ptr, ...>
329+
#struct = di_composite_type<name = "Node", elements: #field, ...>
330+
#rec = di_recursive_type<self_id = 0, baseType: #struct>
331+
332+
#var = di_local_variable<type = #rec, ...>
333+
334+
Note that a rec-self without an outer rec-decl with the same id is
335+
conceptually the same as an "unbound" variable. The context needs to provide
336+
meaning to the rec-self.
337+
338+
This can be avoided by calling the `getUnfoldedBaseType()` method on a
339+
rec-decl, which returns the `baseType` with all matching rec-self instances
340+
replaced with this rec-decl again. This is useful, for example, for fetching
341+
a field out of a recursive struct and maintaining the legality of the field
342+
type.
343+
}];
344+
345+
let parameters = (ins
346+
"DistinctAttr":$id,
347+
OptionalParameter<"DITypeAttr">:$baseType
348+
);
349+
350+
let builders = [
351+
AttrBuilderWithInferredContext<(ins "DistinctAttr":$id), [{
352+
return $_get(id.getContext(), id, nullptr);
353+
}]>
354+
];
355+
356+
let extraClassDeclaration = [{
357+
/// Whether this node represents a self-reference.
358+
bool isRecSelf() { return !getBaseType(); }
359+
360+
/// Get the `baseType` with all instances of the corresponding rec-self
361+
/// replaced with this attribute. This can only be called if `!isRecSelf()`.
362+
DITypeAttr getUnfoldedBaseType();
363+
}];
364+
365+
let assemblyFormat = "`<` struct(params) `>`";
366+
}
367+
304368
//===----------------------------------------------------------------------===//
305369
// DINullTypeAttr
306370
//===----------------------------------------------------------------------===//
@@ -526,14 +590,15 @@ def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram",
526590
OptionalParameter<"unsigned">:$line,
527591
OptionalParameter<"unsigned">:$scopeLine,
528592
"DISubprogramFlags":$subprogramFlags,
529-
OptionalParameter<"DISubroutineTypeAttr">:$type
593+
OptionalParameter<"DIRecursiveTypeAttrOf<DISubroutineTypeAttr>">:$type
530594
);
531595
let builders = [
532596
AttrBuilderWithInferredContext<(ins
533597
"DistinctAttr":$id, "DICompileUnitAttr":$compileUnit,
534598
"DIScopeAttr":$scope, "StringRef":$name, "StringRef":$linkageName,
535599
"DIFileAttr":$file, "unsigned":$line, "unsigned":$scopeLine,
536-
"DISubprogramFlags":$subprogramFlags, "DISubroutineTypeAttr":$type
600+
"DISubprogramFlags":$subprogramFlags,
601+
"DIRecursiveTypeAttrOf<DISubroutineTypeAttr>":$type
537602
), [{
538603
MLIRContext *ctx = file.getContext();
539604
return $_get(ctx, id, compileUnit, scope, StringAttr::get(ctx, name),

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ class DILocalScopeAttr : public DIScopeAttr {
5252
};
5353

5454
/// This class represents a LLVM attribute that describes a debug info type.
55-
class DITypeAttr : public DINodeAttr {
55+
class DITypeAttr : public DIScopeAttr {
5656
public:
57-
using DINodeAttr::DINodeAttr;
57+
using DIScopeAttr::DIScopeAttr;
5858

5959
/// Support LLVM type casting.
6060
static bool classof(Attribute attr);
@@ -74,6 +74,10 @@ class TBAANodeAttr : public Attribute {
7474
}
7575
};
7676

77+
// Forward declare.
78+
template <typename BaseType>
79+
class DIRecursiveTypeAttrOf;
80+
7781
// Inline the LLVM generated Linkage enum and utility.
7882
// This is only necessary to isolate the "enum generated code" from the
7983
// attribute definition itself.
@@ -87,4 +91,31 @@ using linkage::Linkage;
8791
#define GET_ATTRDEF_CLASSES
8892
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc"
8993

94+
namespace mlir {
95+
namespace LLVM {
96+
/// This class represents either a concrete attr, or a DIRecursiveTypeAttr
97+
/// containing such a concrete attr.
98+
template <typename BaseType>
99+
class DIRecursiveTypeAttrOf : public DITypeAttr {
100+
public:
101+
static_assert(std::is_base_of_v<DITypeAttr, BaseType>);
102+
using DITypeAttr::DITypeAttr;
103+
/// Support LLVM type casting.
104+
static bool classof(Attribute attr) {
105+
if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(attr))
106+
return llvm::isa<BaseType>(rec.getBaseType());
107+
return llvm::isa<BaseType>(attr);
108+
}
109+
110+
DIRecursiveTypeAttrOf(BaseType baseType) : DITypeAttr(baseType) {}
111+
112+
BaseType getUnfoldedBaseType() {
113+
if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(this))
114+
return llvm::cast<BaseType>(rec.getUnfoldedBaseType());
115+
return llvm::cast<BaseType>(this);
116+
}
117+
};
118+
} // namespace LLVM
119+
} // namespace mlir
120+
90121
#endif // MLIR_DIALECT_LLVMIR_LLVMATTRS_H_

mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ bool DINodeAttr::classof(Attribute attr) {
6868
//===----------------------------------------------------------------------===//
6969

7070
bool DIScopeAttr::classof(Attribute attr) {
71-
return llvm::isa<DICompileUnitAttr, DICompositeTypeAttr, DIFileAttr,
72-
DILocalScopeAttr, DIModuleAttr, DINamespaceAttr>(attr);
71+
return llvm::isa<DICompileUnitAttr, DIFileAttr, DILocalScopeAttr,
72+
DIModuleAttr, DINamespaceAttr, DITypeAttr>(attr);
7373
}
7474

7575
//===----------------------------------------------------------------------===//
@@ -86,8 +86,9 @@ bool DILocalScopeAttr::classof(Attribute attr) {
8686
//===----------------------------------------------------------------------===//
8787

8888
bool DITypeAttr::classof(Attribute attr) {
89-
return llvm::isa<DINullTypeAttr, DIBasicTypeAttr, DICompositeTypeAttr,
90-
DIDerivedTypeAttr, DISubroutineTypeAttr>(attr);
89+
return llvm::isa<DIRecursiveTypeAttr, DINullTypeAttr, DIBasicTypeAttr,
90+
DICompositeTypeAttr, DIDerivedTypeAttr,
91+
DISubroutineTypeAttr>(attr);
9192
}
9293

9394
//===----------------------------------------------------------------------===//
@@ -185,6 +186,19 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
185186
});
186187
}
187188

189+
//===----------------------------------------------------------------------===//
190+
// DIRecursiveTypeAttr
191+
//===----------------------------------------------------------------------===//
192+
/// Returns `baseType` with every nested DIRecursiveTypeAttr that carries this
/// attribute's `id` replaced by this attribute. Must only be called on a
/// recursive declaration, i.e. when `!isRecSelf()`.
DITypeAttr DIRecursiveTypeAttr::getUnfoldedBaseType() {
  assert(!isRecSelf() && "cannot get baseType from a rec-self type");
  const DistinctAttr selfId = getId();
  Attribute unfolded = getBaseType().replace(
      [&](DIRecursiveTypeAttr rec)
          -> std::optional<std::pair<Attribute, WalkResult>> {
        // Attributes with a different id are left for the default handling.
        if (rec.getId() != selfId)
          return std::nullopt;
        // Skip the sub-elements of the replacement: it contains the very same
        // self-reference again, so descending into it would only repeat this
        // substitution.
        return std::make_pair(Attribute(*this), WalkResult::skip());
      });
  return llvm::cast<DITypeAttr>(unfolded);
}
201+
188202
//===----------------------------------------------------------------------===//
189203
// TargetFeaturesAttr
190204
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/DebugImporter.cpp

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,9 @@ DICompileUnitAttr DebugImporter::translateImpl(llvm::DICompileUnit *node) {
5151
std::optional<DIEmissionKind> emissionKind =
5252
symbolizeDIEmissionKind(node->getEmissionKind());
5353
return DICompileUnitAttr::get(
54-
context, DistinctAttr::create(UnitAttr::get(context)),
55-
node->getSourceLanguage(), translate(node->getFile()),
56-
getStringAttrOrNull(node->getRawProducer()), node->isOptimized(),
57-
emissionKind.value());
54+
context, getOrCreateDistinctID(node), node->getSourceLanguage(),
55+
translate(node->getFile()), getStringAttrOrNull(node->getRawProducer()),
56+
node->isOptimized(), emissionKind.value());
5857
}
5958

6059
DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
@@ -64,11 +63,7 @@ DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
6463
assert(element && "expected a non-null element type");
6564
elements.push_back(translate(element));
6665
}
67-
// Drop the elements parameter if a cyclic dependency is detected. We
68-
// currently cannot model these cycles and thus drop the parameter if
69-
// required. A cyclic dependency is detected if one of the element nodes
70-
// translates to a nullptr since the node is already on the translation stack.
71-
// TODO: Support debug metadata with cyclic dependencies.
66+
// Drop the elements parameter if any of the elements are invalid.
7267
if (llvm::is_contained(elements, nullptr))
7368
elements.clear();
7469
DITypeAttr baseType = translate(node->getBaseType());
@@ -84,7 +79,7 @@ DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
8479
}
8580

8681
DIDerivedTypeAttr DebugImporter::translateImpl(llvm::DIDerivedType *node) {
87-
// Return nullptr if the base type is a cyclic dependency.
82+
// Return nullptr if the base type is invalid.
8883
DITypeAttr baseType = translate(node->getBaseType());
8984
if (node->getBaseType() && !baseType)
9085
return nullptr;
@@ -179,14 +174,14 @@ DISubprogramAttr DebugImporter::translateImpl(llvm::DISubprogram *node) {
179174
// Only definitions require a distinct identifier.
180175
mlir::DistinctAttr id;
181176
if (node->isDistinct())
182-
id = DistinctAttr::create(UnitAttr::get(context));
177+
id = getOrCreateDistinctID(node);
183178
std::optional<DISubprogramFlags> subprogramFlags =
184179
symbolizeDISubprogramFlags(node->getSubprogram()->getSPFlags());
185-
// Return nullptr if the scope or type is a cyclic dependency.
186-
DIScopeAttr scope = translate(node->getScope());
180+
// Return nullptr if the scope or type is invalid.
181+
DIScopeAttr scope = cast<DIScopeAttr>(translate(node->getScope()));
187182
if (node->getScope() && !scope)
188183
return nullptr;
189-
DISubroutineTypeAttr type = translate(node->getType());
184+
DIRecursiveTypeAttrOf<DISubroutineTypeAttr> type = translate(node->getType());
190185
if (node->getType() && !type)
191186
return nullptr;
192187
return DISubprogramAttr::get(context, id, translate(node->getUnit()), scope,
@@ -229,7 +224,7 @@ DebugImporter::translateImpl(llvm::DISubroutineType *node) {
229224
}
230225
types.push_back(translate(type));
231226
}
232-
// Return nullptr if any of the types is a cyclic dependency.
227+
// Return nullptr if any of the types is invalid.
233228
if (llvm::is_contained(types, nullptr))
234229
return nullptr;
235230
return DISubroutineTypeAttr::get(context, node->getCC(), types);
@@ -247,12 +242,47 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
247242
if (DINodeAttr attr = nodeToAttr.lookup(node))
248243
return attr;
249244

250-
// Return nullptr if a cyclic dependency is detected since the same node is
251-
// being traversed twice. This check avoids infinite recursion if the debug
252-
// metadata contains cycles.
253-
if (!translationStack.insert(node))
254-
return nullptr;
255-
auto guard = llvm::make_scope_exit([&]() { translationStack.pop_back(); });
245+
// If a cyclic dependency is detected since the same node is being traversed
246+
// twice, emit a recursive self type, and mark the duplicate node on the
247+
// translationStack so it can emit a recursive decl type.
248+
auto *typeNode = dyn_cast<llvm::DIType>(node);
249+
if (typeNode) {
250+
auto [iter, inserted] = typeTranslationStack.try_emplace(typeNode, nullptr);
251+
if (!inserted) {
252+
// The original node may have already been assigned a recursive ID from
253+
// a different self-reference. Use that if possible.
254+
DistinctAttr recId = iter->second;
255+
if (!recId) {
256+
recId = DistinctAttr::create(UnitAttr::get(context));
257+
iter->second = recId;
258+
}
259+
unboundRecursiveSelfRefs.back().insert(recId);
260+
return DIRecursiveTypeAttr::get(recId);
261+
}
262+
} else {
263+
bool inserted =
264+
nonTypeTranslationStack.insert({node, typeTranslationStack.size()});
265+
assert(inserted && "recursion is only supported via DITypes");
266+
}
267+
268+
unboundRecursiveSelfRefs.emplace_back();
269+
270+
auto guard = llvm::make_scope_exit([&]() {
271+
if (typeNode)
272+
typeTranslationStack.pop_back();
273+
else
274+
nonTypeTranslationStack.pop_back();
275+
276+
// Copy unboundRecursiveSelfRefs down to the previous level.
277+
if (unboundRecursiveSelfRefs.size() == 1)
278+
assert(unboundRecursiveSelfRefs.back().empty() &&
279+
"internal error: unbound recursive self reference at top level.");
280+
else
281+
unboundRecursiveSelfRefs[unboundRecursiveSelfRefs.size() - 2].insert(
282+
unboundRecursiveSelfRefs.back().begin(),
283+
unboundRecursiveSelfRefs.back().end());
284+
unboundRecursiveSelfRefs.pop_back();
285+
});
256286

257287
// Convert the debug metadata if possible.
258288
auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -289,7 +319,21 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
289319
return nullptr;
290320
};
291321
if (DINodeAttr attr = translateNode(node)) {
292-
nodeToAttr.insert({node, attr});
322+
// If this node was marked as recursive, wrap with a recursive type.
323+
if (typeNode) {
324+
if (DistinctAttr id = typeTranslationStack.lookup(typeNode)) {
325+
DITypeAttr typeAttr = cast<DITypeAttr>(attr);
326+
attr = DIRecursiveTypeAttr::get(context, id, typeAttr);
327+
328+
// Remove the unbound recursive attr.
329+
AttrTypeReplacer replacer;
330+
unboundRecursiveSelfRefs.back().erase(id);
331+
}
332+
}
333+
334+
// Only cache fully self-contained nodes.
335+
if (unboundRecursiveSelfRefs.back().empty())
336+
nodeToAttr.try_emplace(node, attr);
293337
return attr;
294338
}
295339
return nullptr;
@@ -346,3 +390,10 @@ StringAttr DebugImporter::getStringAttrOrNull(llvm::MDString *stringNode) {
346390
return StringAttr();
347391
return StringAttr::get(context, stringNode->getString());
348392
}
393+
394+
/// Returns the DistinctAttr recorded for `node` in `nodeToDistinctAttr`,
/// creating and recording a fresh one when no non-null entry exists yet.
DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
  // try_emplace default-constructs (null) the mapped value for unseen nodes.
  auto entry = nodeToDistinctAttr.try_emplace(node).first;
  if (entry->second)
    return entry->second;
  entry->second = DistinctAttr::create(UnitAttr::get(context));
  return entry->second;
}

mlir/lib/Target/LLVMIR/DebugImporter.h

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,15 @@ class DebugImporter {
5252
/// Infers the metadata type and translates it to MLIR.
5353
template <typename DINodeT>
5454
auto translate(DINodeT *node) {
55-
// Infer the MLIR type from the LLVM metadata type.
56-
using MLIRTypeT = decltype(translateImpl(node));
55+
// Infer the result MLIR type from the LLVM metadata type.
56+
// If the result is a DIType, it can also be wrapped in a recursive type,
57+
// so the result is wrapped into a DIRecursiveTypeAttrOf.
58+
// Otherwise, the exact result type is used.
59+
constexpr bool isDIType = std::is_base_of_v<llvm::DIType, DINodeT>;
60+
using RawMLIRTypeT = decltype(translateImpl(node));
61+
using MLIRTypeT =
62+
std::conditional_t<isDIType, DIRecursiveTypeAttrOf<RawMLIRTypeT>,
63+
RawMLIRTypeT>;
5764
return cast_or_null<MLIRTypeT>(
5865
translate(static_cast<llvm::DINode *>(node)));
5966
}
@@ -82,12 +89,27 @@ class DebugImporter {
8289
/// null attribute otherwise.
8390
StringAttr getStringAttrOrNull(llvm::MDString *stringNode);
8491

92+
DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
93+
8594
/// A mapping between LLVM debug metadata and the corresponding attribute.
8695
DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
87-
88-
/// A stack that stores the metadata nodes that are being traversed. The stack
89-
/// is used to detect cyclic dependencies during the metadata translation.
90-
SetVector<llvm::DINode *> translationStack;
96+
/// A mapping between LLVM debug metadata and the distinct ID attr for DI
97+
/// nodes that require distinction.
98+
DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
99+
100+
/// A stack that stores the metadata type nodes that are being traversed. The
101+
/// stack is used to detect cyclic dependencies during the metadata
102+
/// translation. Nodes are pushed with a null value. If it is ever seen twice,
103+
/// it is given a DistinctAttr, indicating that it is a recursive node and
104+
/// should take on that DistinctAttr as ID.
105+
llvm::MapVector<llvm::DIType *, DistinctAttr> typeTranslationStack;
106+
/// All the unbound recursive self references in the translation stack.
107+
SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;
108+
/// A stack that stores the non-type metadata nodes that are being traversed.
109+
/// Each node is associated with the size of the `typeTranslationStack` at the
110+
/// time of push. This is used to identify a recursion purely in the non-type
111+
/// metadata nodes, which is not supported yet.
112+
SetVector<std::pair<llvm::DINode *, unsigned>> nonTypeTranslationStack;
91113

92114
MLIRContext *context;
93115
ModuleOp mlirModule;

0 commit comments

Comments
 (0)