Skip to content

Commit 78e7573

Browse files
committed
AMD SPIR-V work graphs extensions
1 parent adffd31 commit 78e7573

File tree

59 files changed

+3187
-112
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+3187
-112
lines changed

tools/clang/include/clang/AST/HlslTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,10 @@ bool IsHLSLObjectWithImplicitMemberAccess(clang::QualType type);
488488
bool IsHLSLObjectWithImplicitROMemberAccess(clang::QualType type);
489489
bool IsHLSLRWNodeInputRecordType(clang::QualType type);
490490
bool IsHLSLRONodeInputRecordType(clang::QualType type);
491+
bool IsHLSLDispatchNodeInputRecordType(clang::QualType type);
492+
bool IsHLSLNodeRecordArrayType(clang::QualType type);
491493
bool IsHLSLNodeOutputType(clang::QualType type);
494+
bool IsHLSLEmptyNodeRecordType(clang::QualType type);
492495

493496
DXIL::NodeIOKind GetNodeIOType(clang::QualType type);
494497

@@ -498,6 +501,8 @@ bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT);
498501
bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT);
499502
bool IsHLSLAggregateType(clang::QualType type);
500503
clang::QualType GetHLSLResourceResultType(clang::QualType type);
504+
clang::QualType GetHLSLNodeIOResultType(clang::ASTContext &astContext,
505+
clang::QualType type);
501506
unsigned GetHLSLResourceTemplateUInt(clang::QualType type);
502507
bool IsIncompleteHLSLResourceArrayType(clang::ASTContext &context,
503508
clang::QualType type);

tools/clang/include/clang/SPIRV/FeatureManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ enum class Extension {
5757
KHR_ray_query,
5858
EXT_shader_image_int64,
5959
KHR_physical_storage_buffer,
60+
AMD_shader_enqueue,
6061
KHR_vulkan_memory_model,
6162
NV_compute_shader_derivatives,
6263
KHR_compute_shader_derivatives,

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,25 @@ class SpirvBuilder {
437437
QualType resultType, NonSemanticDebugPrintfInstructions instId,
438438
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation);
439439

440+
SpirvInstruction *createIsNodePayloadValid(SpirvInstruction *payloadArray,
441+
SpirvInstruction *nodeIndex,
442+
SourceLocation);
443+
444+
SpirvInstruction *createNodePayloadArrayLength(SpirvInstruction *payloadArray,
445+
SourceLocation);
446+
447+
SpirvInstruction *createAllocateNodePayloads(QualType resultType,
448+
spv::Scope allocationScope,
449+
SpirvInstruction *shaderIndex,
450+
SpirvInstruction *recordCount,
451+
SourceLocation);
452+
453+
void createEnqueueOutputNodePayloads(SpirvInstruction *payload,
454+
SourceLocation);
455+
456+
SpirvInstruction *createFinishWritingNodePayload(SpirvInstruction *payload,
457+
SourceLocation);
458+
440459
/// \brief Creates an OpMemoryBarrier or OpControlBarrier instruction with the
441460
/// given flags. If execution scope (exec) is provided, an OpControlBarrier
442461
/// is created; otherwise an OpMemoryBarrier is created.
@@ -766,6 +785,7 @@ class SpirvBuilder {
766785
llvm::ArrayRef<SpirvConstant *> constituents,
767786
bool specConst = false);
768787
SpirvConstant *getConstantNull(QualType);
788+
SpirvConstant *getConstantString(llvm::StringRef str, bool specConst = false);
769789
SpirvUndef *getUndef(QualType);
770790

771791
SpirvString *createString(llvm::StringRef str);

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ struct RuntimeArrayTypeMapInfo {
101101
}
102102
};
103103

104+
// Provides DenseMapInfo for NodePayloadArrayType so we can create a DenseSet of
105+
// node payload array types.
106+
struct NodePayloadArrayTypeMapInfo {
107+
static inline NodePayloadArrayType *getEmptyKey() { return nullptr; }
108+
static inline NodePayloadArrayType *getTombstoneKey() { return nullptr; }
109+
static unsigned getHashValue(const NodePayloadArrayType *Val) {
110+
return llvm::hash_combine(Val->getElementType(), Val->getNodeDecl());
111+
}
112+
static bool isEqual(const NodePayloadArrayType *LHS,
113+
const NodePayloadArrayType *RHS) {
114+
// Either both are null, or both should have the same underlying type.
115+
return (LHS == RHS) || (LHS && RHS && *LHS == *RHS);
116+
}
117+
};
118+
104119
// Provides DenseMapInfo for ImageType so we can create a DenseSet of
105120
// image types.
106121
struct ImageTypeMapInfo {
@@ -273,6 +288,9 @@ class SpirvContext {
273288
const RuntimeArrayType *
274289
getRuntimeArrayType(const SpirvType *elemType,
275290
llvm::Optional<uint32_t> arrayStride);
291+
const NodePayloadArrayType *
292+
getNodePayloadArrayType(const SpirvType *elemType,
293+
const ParmVarDecl *nodeDecl);
276294

277295
const StructType *getStructType(
278296
llvm::ArrayRef<StructType::FieldInfo> fields, llvm::StringRef name,
@@ -349,6 +367,7 @@ class SpirvContext {
349367
bool isDS() const { return curShaderModelKind == ShaderModelKind::Domain; }
350368
bool isCS() const { return curShaderModelKind == ShaderModelKind::Compute; }
351369
bool isLib() const { return curShaderModelKind == ShaderModelKind::Library; }
370+
bool isNode() const { return curShaderModelKind == ShaderModelKind::Node; }
352371
bool isRay() const {
353372
return curShaderModelKind >= ShaderModelKind::RayGeneration &&
354373
curShaderModelKind <= ShaderModelKind::Callable;
@@ -440,6 +459,31 @@ class SpirvContext {
440459
instructionsWithLoweredType.end();
441460
}
442461

462+
void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) {
463+
auto iter = dispatchGridIndices.find(decl);
464+
if (iter == dispatchGridIndices.end()) {
465+
dispatchGridIndices[decl] = index;
466+
}
467+
}
468+
469+
llvm::Optional<unsigned> getDispatchGridIndex(const RecordDecl *decl) {
470+
auto iter = dispatchGridIndices.find(decl);
471+
if (iter != dispatchGridIndices.end()) {
472+
return iter->second;
473+
}
474+
return llvm::None;
475+
}
476+
477+
void registerNodeDeclPayloadType(const NodePayloadArrayType *type,
478+
const ParmVarDecl *decl) {
479+
nodeDecls[decl] = type;
480+
}
481+
482+
const NodePayloadArrayType *getNodeDeclPayloadType(const ParmVarDecl *decl) {
483+
auto iter = nodeDecls.find(decl);
484+
return iter == nodeDecls.end() ? nullptr : iter->second;
485+
}
486+
443487
private:
444488
/// \brief The allocator used to create SPIR-V entity objects.
445489
///
@@ -484,6 +528,8 @@ class SpirvContext {
484528
llvm::DenseSet<const ArrayType *, ArrayTypeMapInfo> arrayTypes;
485529
llvm::DenseSet<const RuntimeArrayType *, RuntimeArrayTypeMapInfo>
486530
runtimeArrayTypes;
531+
llvm::DenseSet<const NodePayloadArrayType *, NodePayloadArrayTypeMapInfo>
532+
nodePayloadArrayTypes;
487533
llvm::SmallVector<const StructType *, 8> structTypes;
488534
llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
489535
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
@@ -510,6 +556,9 @@ class SpirvContext {
510556
llvm::StringMap<RichDebugInfo> debugInfo;
511557
SpirvDebugInstruction *currentLexicalScope;
512558

559+
// Mapping from graphics node input record types to member decoration maps.
560+
llvm::MapVector<const RecordDecl *, unsigned> dispatchGridIndices;
561+
513562
// Mapping from SPIR-V type to debug type instruction.
514563
// The purpose is not to generate several DebugType* instructions for the same
515564
// type if the type is used for several variables.
@@ -541,6 +590,10 @@ class SpirvContext {
541590

542591
// Set of instructions that already have lowered SPIR-V types.
543592
llvm::DenseSet<const SpirvInstruction *> instructionsWithLoweredType;
593+
594+
// Mapping from shader entry function parameter declaration to node payload
595+
// array type.
596+
llvm::MapVector<const ParmVarDecl *, const NodePayloadArrayType *> nodeDecls;
544597
};
545598

546599
} // end namespace spirv

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class SpirvInstruction {
7070
IK_ConstantInteger,
7171
IK_ConstantFloat,
7272
IK_ConstantComposite,
73+
IK_ConstantString,
7374
IK_ConstantNull,
7475

7576
// Pointer <-> uint conversions.
@@ -168,6 +169,13 @@ class SpirvInstruction {
168169
IK_DebugTypeMember,
169170
IK_DebugTypeTemplate,
170171
IK_DebugTypeTemplateParameter,
172+
173+
// For workgraph instructions
174+
IK_IsNodePayloadValid,
175+
IK_NodePayloadArrayLength,
176+
IK_AllocateNodePayloads,
177+
IK_EnqueueNodePayloads,
178+
IK_FinishWritingNodePayload,
171179
};
172180

173181
// All instruction classes should include a releaseMemory method.
@@ -443,9 +451,13 @@ class SpirvExecutionMode : public SpirvExecutionModeBase {
443451

444452
bool invokeVisitor(Visitor *v) override;
445453

454+
SpirvFunction *getEntryPoint() const { return entryPoint; }
455+
spv::ExecutionMode getExecutionMode() const { return execMode; }
446456
llvm::ArrayRef<uint32_t> getParams() const { return params; }
447457

448458
private:
459+
SpirvFunction *entryPoint;
460+
spv::ExecutionMode execMode;
449461
llvm::SmallVector<uint32_t, 4> params;
450462
};
451463

@@ -1059,6 +1071,119 @@ class SpirvBarrier : public SpirvInstruction {
10591071
llvm::Optional<spv::Scope> executionScope;
10601072
};
10611073

1074+
/// \brief OpIsNodePayloadValidAMDX instruction
1075+
class SpirvIsNodePayloadValid : public SpirvInstruction {
1076+
public:
1077+
SpirvIsNodePayloadValid(QualType resultType, SourceLocation loc,
1078+
SpirvInstruction *payloadArray,
1079+
SpirvInstruction *nodeIndex);
1080+
1081+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvIsNodePayloadValid)
1082+
1083+
// For LLVM-style RTTI
1084+
static bool classof(const SpirvInstruction *inst) {
1085+
return inst->getKind() == IK_IsNodePayloadValid;
1086+
}
1087+
1088+
bool invokeVisitor(Visitor *v) override;
1089+
1090+
SpirvInstruction *getPayloadArray() { return payloadArray; }
1091+
SpirvInstruction *getNodeIndex() { return nodeIndex; }
1092+
1093+
private:
1094+
SpirvInstruction *payloadArray;
1095+
SpirvInstruction *nodeIndex;
1096+
};
1097+
1098+
/// \brief OpNodePayloadArrayLengthAMDX instruction
1099+
class SpirvNodePayloadArrayLength : public SpirvInstruction {
1100+
public:
1101+
SpirvNodePayloadArrayLength(QualType resultType, SourceLocation loc,
1102+
SpirvInstruction *payloadArray);
1103+
1104+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNodePayloadArrayLength)
1105+
1106+
// For LLVM-style RTTI
1107+
static bool classof(const SpirvInstruction *inst) {
1108+
return inst->getKind() == IK_NodePayloadArrayLength;
1109+
}
1110+
1111+
bool invokeVisitor(Visitor *v) override;
1112+
1113+
SpirvInstruction *getPayloadArray() { return payloadArray; }
1114+
1115+
private:
1116+
SpirvInstruction *payloadArray;
1117+
};
1118+
1119+
/// \brief OpAllocateNodePayloadsAMDX instruction
1120+
class SpirvAllocateNodePayloads : public SpirvInstruction {
1121+
public:
1122+
SpirvAllocateNodePayloads(QualType resultType, SourceLocation loc,
1123+
spv::Scope allocationScope,
1124+
SpirvInstruction *shaderIndex,
1125+
SpirvInstruction *recordCount);
1126+
1127+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvAllocateNodePayloads)
1128+
1129+
// For LLVM-style RTTI
1130+
static bool classof(const SpirvInstruction *inst) {
1131+
return inst->getKind() == IK_AllocateNodePayloads;
1132+
}
1133+
1134+
bool invokeVisitor(Visitor *v) override;
1135+
1136+
spv::Scope getAllocationScope() { return allocationScope; }
1137+
SpirvInstruction *getShaderIndex() { return shaderIndex; }
1138+
SpirvInstruction *getRecordCount() { return recordCount; }
1139+
1140+
private:
1141+
spv::Scope allocationScope;
1142+
SpirvInstruction *shaderIndex;
1143+
SpirvInstruction *recordCount;
1144+
};
1145+
1146+
/// \brief OpReleaseOutputNodePayloadAMDX instruction
1147+
class SpirvEnqueueNodePayloads : public SpirvInstruction {
1148+
public:
1149+
SpirvEnqueueNodePayloads(SourceLocation loc, SpirvInstruction *payload);
1150+
1151+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEnqueueNodePayloads)
1152+
1153+
// For LLVM-style RTTI
1154+
static bool classof(const SpirvInstruction *inst) {
1155+
return inst->getKind() == IK_EnqueueNodePayloads;
1156+
}
1157+
1158+
bool invokeVisitor(Visitor *v) override;
1159+
1160+
SpirvInstruction *getPayload() { return payload; }
1161+
1162+
private:
1163+
SpirvInstruction *payload;
1164+
};
1165+
1166+
/// \brief OpFinishWritingNodePayloadAMDX instruction
1167+
class SpirvFinishWritingNodePayload : public SpirvInstruction {
1168+
public:
1169+
SpirvFinishWritingNodePayload(QualType resultType, SourceLocation loc,
1170+
SpirvInstruction *payload);
1171+
1172+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvFinishWritingNodePayload)
1173+
1174+
// For LLVM-style RTTI
1175+
static bool classof(const SpirvInstruction *inst) {
1176+
return inst->getKind() == IK_FinishWritingNodePayload;
1177+
}
1178+
1179+
bool invokeVisitor(Visitor *v) override;
1180+
1181+
SpirvInstruction *getPayload() { return payload; }
1182+
1183+
private:
1184+
SpirvInstruction *payload;
1185+
};
1186+
10621187
/// \brief Represents SPIR-V binary operation instructions.
10631188
///
10641189
/// This class includes:
@@ -1355,6 +1480,27 @@ class SpirvConstantNull : public SpirvConstant {
13551480
bool operator==(const SpirvConstantNull &that) const;
13561481
};
13571482

1483+
class SpirvConstantString : public SpirvConstant {
1484+
public:
1485+
SpirvConstantString(llvm::StringRef stringLiteral, bool isSpecConst = false);
1486+
1487+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantString)
1488+
1489+
// For LLVM-style RTTI
1490+
static bool classof(const SpirvInstruction *inst) {
1491+
return inst->getKind() == IK_ConstantString;
1492+
}
1493+
1494+
bool invokeVisitor(Visitor *v) override;
1495+
1496+
bool operator==(const SpirvConstantString &that) const;
1497+
1498+
llvm::StringRef getString() const { return str; }
1499+
1500+
private:
1501+
std::string str;
1502+
};
1503+
13581504
class SpirvConvertPtrToU : public SpirvInstruction {
13591505
public:
13601506
SpirvConvertPtrToU(SpirvInstruction *ptr, QualType type,

tools/clang/include/clang/SPIRV/SpirvType.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class SpirvType {
5454
TK_SampledImage,
5555
TK_Array,
5656
TK_RuntimeArray,
57+
TK_NodePayloadArrayAMD,
5758
TK_Struct,
5859
TK_Pointer,
5960
TK_ForwardPointer,
@@ -294,6 +295,26 @@ class RuntimeArrayType : public SpirvType {
294295
llvm::Optional<uint32_t> stride;
295296
};
296297

298+
class NodePayloadArrayType : public SpirvType {
299+
public:
300+
NodePayloadArrayType(const SpirvType *elemType, const ParmVarDecl *decl)
301+
: SpirvType(TK_NodePayloadArrayAMD), elementType(elemType),
302+
nodeDecl(decl) {}
303+
304+
static bool classof(const SpirvType *t) {
305+
return t->getKind() == TK_NodePayloadArrayAMD;
306+
}
307+
308+
bool operator==(const NodePayloadArrayType &that) const;
309+
310+
const SpirvType *getElementType() const { return elementType; }
311+
const ParmVarDecl *getNodeDecl() const { return nodeDecl; }
312+
313+
private:
314+
const SpirvType *elementType;
315+
const ParmVarDecl *nodeDecl;
316+
};
317+
297318
// The StructType is the lowered type that best represents what a structure type
298319
// is in SPIR-V. Contains all necessary information for properly emitting a
299320
// SPIR-V structure type.
@@ -630,6 +651,8 @@ bool SpirvType::isOrContainsType(const SpirvType *type) {
630651
return isOrContainsType<T, Bitwidth>(pointerType->getPointeeType());
631652
if (const auto *raType = dyn_cast<RuntimeArrayType>(type))
632653
return isOrContainsType<T, Bitwidth>(raType->getElementType());
654+
if (const auto *npaType = dyn_cast<NodePayloadArrayType>(type))
655+
return isOrContainsType<T, Bitwidth>(npaType->getElementType());
633656
if (const auto *imgType = dyn_cast<ImageType>(type))
634657
return isOrContainsType<T, Bitwidth>(imgType->getSampledType());
635658
if (const auto *sampledImageType = dyn_cast<SampledImageType>(type))

0 commit comments

Comments
 (0)