@@ -328,6 +328,54 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
328
328
fn apply_attrs_callsite ( & self , bx : & mut Builder < ' _ , ' ll , ' tcx > , callsite : & ' ll Value ) ;
329
329
}
330
330
331
+ fn equate_ty < ' ll > ( cx : & CodegenCx < ' ll , ' _ > , rust_ty : & ' ll Type , llvm_ty : & ' ll Type ) -> bool {
332
+ if rust_ty == llvm_ty {
333
+ return true ;
334
+ }
335
+ match cx. type_kind ( llvm_ty) {
336
+ TypeKind :: X86_AMX => {
337
+ // we will insert casts from/to x86amx in callsite, so this is fine
338
+ if cx. type_kind ( rust_ty) == TypeKind :: Vector {
339
+ let element_count = cx. vector_length ( rust_ty) ;
340
+ let element_ty = cx. element_type ( rust_ty) ;
341
+ let element_size_bits = match cx. type_kind ( element_ty) {
342
+ TypeKind :: Half => 16 ,
343
+ TypeKind :: Float => 32 ,
344
+ TypeKind :: Double => 64 ,
345
+ TypeKind :: FP128 => 128 ,
346
+ TypeKind :: Integer => cx. int_width ( element_ty) ,
347
+ TypeKind :: Pointer => cx. int_width ( cx. isize_ty ) ,
348
+ _ => bug ! (
349
+ "Vector element type `{element_ty:?}` not one of integer, float or pointer"
350
+ ) ,
351
+ } ;
352
+ element_size_bits * element_count as u64 == 8192
353
+ } else {
354
+ false
355
+ }
356
+ }
357
+ TypeKind :: BFloat => rust_ty == cx. type_i16 ( ) ,
358
+ TypeKind :: Vector => {
359
+ let element_count = cx. vector_length ( llvm_ty) ;
360
+ let element_ty = cx. element_type ( llvm_ty) ;
361
+
362
+ if element_ty == cx. type_bf16 ( ) {
363
+ rust_ty == cx. type_vector ( cx. type_i16 ( ) , element_count as u64 )
364
+ } else {
365
+ false
366
+ }
367
+ }
368
+ _ => false ,
369
+ }
370
+ }
371
+
372
+ macro_rules! error_exit {
373
+ ( $( $t: tt) * ) => {
374
+ eprintln!( $( $t) * ) ;
375
+ :: std:: process:: exit( 101 )
376
+ } ;
377
+ }
378
+
331
379
impl < ' ll , ' tcx > FnAbiLlvmExt < ' ll , ' tcx > for FnAbi < ' tcx , Ty < ' tcx > > {
332
380
fn llvm_return_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type {
333
381
match & self . ret . mode {
@@ -405,78 +453,58 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
405
453
let actual_return_ty = self . llvm_return_type ( cx) ;
406
454
let actual_argument_tys = self . llvm_argument_types ( cx) ;
407
455
408
- if name. starts_with ( b"llvm." )
409
- && let Some ( ( intrinsic, type_params) ) = cx. parse_intrinsic_name ( name)
410
- {
411
- let fn_ty = cx. intrinsic_type ( intrinsic, & type_params) ;
456
+ let is_llvm_intrinsic = name. starts_with ( b"llvm." ) ;
412
457
458
+ if is_llvm_intrinsic && let Some ( ( fn_ty, _) ) = cx. get_intrinsic_from_name ( name) {
413
459
let expected_return_ty = cx. get_return_type ( fn_ty) ;
414
460
let expected_argument_tys = cx. func_params_types ( fn_ty) ;
415
461
416
- let equate_ty = |rust_ty, llvm_ty| {
417
- if rust_ty == llvm_ty {
418
- return true ;
419
- }
420
- match cx. type_kind ( llvm_ty) {
421
- TypeKind :: X86_AMX => {
422
- // we will insert casts from/to x86amx in callsite, so this is fine
423
- if cx. type_kind ( rust_ty) == TypeKind :: Vector {
424
- let element_count = cx. vector_length ( rust_ty) ;
425
- let element_ty = cx. element_type ( rust_ty) ;
426
- let element_size_bits = match cx. type_kind ( element_ty) {
427
- TypeKind :: Half => 16 ,
428
- TypeKind :: Float => 32 ,
429
- TypeKind :: Double => 64 ,
430
- TypeKind :: FP128 => 128 ,
431
- TypeKind :: Integer => cx. int_width ( element_ty) ,
432
- TypeKind :: Pointer => cx. int_width ( cx. isize_ty ) ,
433
- _ => bug ! (
434
- "Vector element type `{element_ty:?}` not one of integer, float or pointer"
435
- ) ,
436
- } ;
437
- element_size_bits * element_count as u64 == 8192
438
- } else {
439
- false
440
- }
441
- }
442
- TypeKind :: BFloat => rust_ty == cx. type_i16 ( ) ,
443
- TypeKind :: Vector => {
444
- let element_count = cx. vector_length ( llvm_ty) ;
445
- let element_ty = cx. element_type ( llvm_ty) ;
446
-
447
- if element_ty == cx. type_bf16 ( ) {
448
- rust_ty == cx. type_vector ( cx. type_i16 ( ) , element_count as u64 )
449
- } else {
450
- false
451
- }
452
- }
453
- _ => false ,
454
- }
455
- } ;
456
-
457
462
if actual_argument_tys. len ( ) != expected_argument_tys. len ( ) {
458
- todo ! ( "A very friendly error msg" )
463
+ error_exit ! (
464
+ "Intrinsic signature mismatch: expected {} arguments for `{}`, found {} arguments" ,
465
+ expected_argument_tys. len( ) ,
466
+ str :: from_utf8( name) . unwrap( ) ,
467
+ actual_argument_tys. len( ) ,
468
+ ) ;
459
469
}
460
470
461
- if !equate_ty ( actual_return_ty, expected_return_ty) {
462
- todo ! ( "A very friendly error msg" )
471
+ if !equate_ty ( cx, actual_return_ty, expected_return_ty) {
472
+ error_exit ! (
473
+ "Intrinsic signature mismatch: expected {expected_return_ty:?} as return type for `{}`, found {actual_return_ty:?}" ,
474
+ str :: from_utf8( name) . unwrap( )
475
+ ) ;
463
476
}
464
- for ( actual_argument_ty, expected_argument_ty) in
465
- zip ( actual_argument_tys, expected_argument_tys)
477
+ for ( idx , ( actual_argument_ty, expected_argument_ty) ) in
478
+ zip ( actual_argument_tys, expected_argument_tys) . enumerate ( )
466
479
{
467
- if !equate_ty ( actual_argument_ty, expected_argument_ty) {
468
- todo ! ( "A very friendly error msg" )
480
+ if !equate_ty ( cx, actual_argument_ty, expected_argument_ty) {
481
+ error_exit ! (
482
+ "Intrinsic signature mismatch: expected {expected_argument_ty:?} as argument {idx} for `{}`, found {actual_argument_ty:?}" ,
483
+ str :: from_utf8( name) . unwrap( )
484
+ ) ;
469
485
}
470
486
}
471
487
472
- fn_ty
488
+ // todo: check validity of the intrinsic via `getIntrinsicSignature`
489
+ return fn_ty;
490
+ }
491
+
492
+ let llfn = if self . c_variadic {
493
+ cx. type_variadic_func ( & actual_argument_tys, actual_return_ty)
473
494
} else {
474
- if self . c_variadic {
475
- cx. type_variadic_func ( & actual_argument_tys, actual_return_ty)
476
- } else {
477
- cx. type_func ( & actual_argument_tys, actual_return_ty)
478
- }
495
+ cx. type_func ( & actual_argument_tys, actual_return_ty)
496
+ } ;
497
+
498
+ if is_llvm_intrinsic {
499
+ // either the intrinsic is invalid or it needs to be upgraded
500
+ tracing:: error!(
501
+ "Using invalid or upgradeable intrinsic `{}`" ,
502
+ str :: from_utf8( name) . unwrap( )
503
+ ) ;
504
+ // todo: check if it's upgradeable, otherwise error
479
505
}
506
+
507
+ llfn
480
508
}
481
509
482
510
fn ptr_to_llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type {
0 commit comments