@@ -18240,65 +18240,211 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
18240
18240
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18241
18241
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18242
18242
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18243
- case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18243
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18244
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18245
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18246
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18247
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18248
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18249
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18250
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18251
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18252
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18253
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18254
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18255
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18256
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18257
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18258
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18259
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18260
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18261
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18262
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18263
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18264
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18265
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18266
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18267
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18268
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18269
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18270
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18271
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18272
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18273
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18274
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18275
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18276
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18277
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18278
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18279
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18280
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18281
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18282
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18283
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18284
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18285
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18286
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18287
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
18244
18288
18245
18289
// These operations perform a matrix multiplication and accumulation of
18246
18290
// the form:
18247
18291
// D = A * B + C
18248
- // The return type always matches the type of matrix C.
18249
- unsigned ArgForMatchingRetType;
18292
+ // We need to specify one type for matrices AB and one for matrices CD.
18293
+ SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18294
+ // Some intrinsics expect "false" as an extra bool argument.
18295
+ bool AppendExtraBoolArg = false;
18250
18296
unsigned BuiltinWMMAOp;
18251
18297
18252
18298
switch (BuiltinID) {
18253
18299
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
18254
18300
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18255
- ArgForMatchingRetType = 2;
18301
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18302
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18303
+ ArgsForMatchingMatrixTypes = {0, 2};
18256
18304
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
18257
18305
break;
18258
18306
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
18259
18307
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18260
- ArgForMatchingRetType = 2;
18308
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18309
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18310
+ ArgsForMatchingMatrixTypes = {0, 2};
18261
18311
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
18262
18312
break;
18313
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18314
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18315
+ AppendExtraBoolArg = true;
18316
+ LLVM_FALLTHROUGH;
18263
18317
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
18264
18318
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18265
- ArgForMatchingRetType = 2 ;
18319
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18266
18320
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
18267
18321
break;
18322
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18323
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18324
+ AppendExtraBoolArg = true;
18325
+ LLVM_FALLTHROUGH;
18268
18326
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
18269
18327
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18270
- ArgForMatchingRetType = 2 ;
18328
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18271
18329
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
18272
18330
break;
18273
18331
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
18274
18332
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18275
- ArgForMatchingRetType = 2 ;
18333
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18276
18334
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
18277
18335
break;
18278
18336
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
18279
18337
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18280
- ArgForMatchingRetType = 2 ;
18338
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18281
18339
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
18282
18340
break;
18283
18341
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18284
18342
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18285
- ArgForMatchingRetType = 4;
18343
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18344
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18345
+ ArgsForMatchingMatrixTypes = {1, 4};
18286
18346
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
18287
18347
break;
18288
18348
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18289
18349
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18290
- ArgForMatchingRetType = 4;
18350
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18351
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18352
+ ArgsForMatchingMatrixTypes = {1, 4};
18291
18353
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
18292
18354
break;
18355
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18356
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18357
+ ArgsForMatchingMatrixTypes = {0, 2};
18358
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18359
+ break;
18360
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18361
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18362
+ ArgsForMatchingMatrixTypes = {0, 2};
18363
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18364
+ break;
18365
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18366
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18367
+ ArgsForMatchingMatrixTypes = {0, 2};
18368
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18369
+ break;
18370
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18371
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18372
+ ArgsForMatchingMatrixTypes = {0, 2};
18373
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18374
+ break;
18375
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18376
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18377
+ ArgsForMatchingMatrixTypes = {1, 4};
18378
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18379
+ break;
18380
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18381
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18382
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18383
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18384
+ break;
18385
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18386
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18387
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18388
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18389
+ break;
18390
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18391
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18392
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18393
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18394
+ break;
18395
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18396
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18397
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18398
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18399
+ break;
18400
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18401
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18402
+ ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
18403
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18404
+ break;
18405
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18406
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18407
+ ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
18408
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18409
+ break;
18410
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18411
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18412
+ ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
18413
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18414
+ break;
18415
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18416
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18417
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18418
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18419
+ break;
18420
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18421
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18422
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18423
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18424
+ break;
18425
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18426
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18427
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18428
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18429
+ break;
18430
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18431
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18432
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18433
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18434
+ break;
18293
18435
}
18294
18436
18295
18437
SmallVector<Value *, 6> Args;
18296
18438
for (int i = 0, e = E->getNumArgs(); i != e; ++i)
18297
18439
Args.push_back(EmitScalarExpr(E->getArg(i)));
18440
+ if (AppendExtraBoolArg)
18441
+ Args.push_back(Builder.getFalse());
18298
18442
18299
- Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18300
- {Args[ArgForMatchingRetType]->getType()});
18443
+ SmallVector<llvm::Type *, 6> ArgTypes;
18444
+ for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18445
+ ArgTypes.push_back(Args[ArgIdx]->getType());
18301
18446
18447
+ Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
18302
18448
return Builder.CreateCall(F, Args);
18303
18449
}
18304
18450
0 commit comments