@@ -299,7 +299,8 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
299
299
ml::GEPOp CreateGep (TypedValue<mlir::RankedTensorType> tensor,
300
300
Value linear_index, mlir::ImplicitLocOpBuilder& b) {
301
301
Type element_type = tensor.getType ().getElementType ();
302
- if (element_type == b.getI4Type ()) {
302
+ if (element_type.isIntOrFloat () &&
303
+ element_type.getIntOrFloatBitWidth () == 4 ) {
303
304
element_type = b.getI8Type ();
304
305
}
305
306
auto ptr = ml::LLVMPointerType::get (b.getContext ());
@@ -328,7 +329,8 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
328
329
auto linear_index = GetLinearIndex (op.getIndices (), b);
329
330
Type element_type = op.getTensor ().getType ().getElementType ();
330
331
Value is_low_nibble = nullptr ;
331
- if (element_type == rewriter.getI4Type ()) {
332
+ if (element_type.isIntOrFloat () &&
333
+ element_type.getIntOrFloatBitWidth () == 4 ) {
332
334
std::tie (linear_index, is_low_nibble) =
333
335
GetI4IndexAndNibble (linear_index, b);
334
336
}
@@ -342,7 +344,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
342
344
auto high_value = b.create <mlir::arith::ShRUIOp>(
343
345
load, b.create <mlir::arith::ConstantIntOp>(4 , load.getType ()));
344
346
load = b.create <mlir::arith::TruncIOp>(
345
- op. getType (),
347
+ rewriter. getI4Type (),
346
348
b.create <mlir::arith::SelectOp>(is_low_nibble, load, high_value));
347
349
}
348
350
@@ -378,6 +380,7 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
378
380
379
381
auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
380
382
op.getSource ());
383
+ mlir::Type source_element_type = source.getType ().getElementType ();
381
384
382
385
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
383
386
auto linear_index = GetLinearIndex (op.getIndices (), b);
@@ -386,7 +389,9 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
386
389
if (vector_type.getElementType ().isInteger (1 )) {
387
390
vector_type = vector_type.cloneWith (std::nullopt, b.getI8Type ());
388
391
}
389
- if (op.getVectorType ().getElementType ().isInteger (4 )) {
392
+ mlir::Type gep_element_type = vector_type.getElementType ();
393
+ if (gep_element_type.isIntOrFloat () &&
394
+ gep_element_type.getIntOrFloatBitWidth () == 4 ) {
390
395
linear_index = b.create <arith::ShRUIOp>(
391
396
linear_index,
392
397
b.create <arith::ConstantIntOp>(1 , linear_index.getType ()));
@@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
397
402
auto llvm_vector_type = converter.convertType (vector_type);
398
403
auto loaded = b.create <ml::LoadOp>(llvm_vector_type, gep).getResult ();
399
404
400
- if (source. getType (). getElementType () .isInteger (1 )) {
405
+ if (source_element_type .isInteger (1 )) {
401
406
Value zero = b.create <mlir::arith::ConstantOp>(
402
407
mlir::DenseElementsAttr::get (vector_type, b.getI8IntegerAttr (0 )));
403
408
loaded = b.create <arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
404
- } else if (source.getType ().getElementType ().isInteger (4 )) {
409
+ } else if (source_element_type.isIntOrFloat () &&
410
+ source_element_type.getIntOrFloatBitWidth () == 4 ) {
405
411
// LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
406
412
// elements.
407
413
loaded = PermutePairsInVector (loaded, b);
@@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
430
436
auto scalar_value = op.getScalar ();
431
437
432
438
// For i4 we store 2 values into one byte. This needs special handling here.
433
- if (tensor_dest.getType ().getElementType () == rewriter.getI4Type ()) {
439
+ if (tensor_dest.getType ().getElementType ().isIntOrFloat () &&
440
+ tensor_dest.getType ().getElementType ().getIntOrFloatBitWidth () == 4 ) {
434
441
// We need to use directly op.getDest() as input, otherwise the following
435
442
// rewrite might remove the only user of it.
436
443
tensor_dest = op.getDest ();
@@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
448
455
auto tensor_dest_i8 =
449
456
b.create <UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
450
457
.getResult (0 );
458
+ if (scalar_value.getType () != rewriter.getI4Type ()) {
459
+ scalar_value =
460
+ b.create <arith::BitcastOp>(rewriter.getI4Type (), scalar_value);
461
+ }
451
462
scalar_value = b.create <mlir::arith::ExtUIOp>(ty, scalar_value);
452
463
453
464
// We need AtomicRMWOp because it can happen that different threads try to
@@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern<vector::TransferWriteOp> {
507
518
auto linear_index = GetLinearIndex (op.getIndices (), b);
508
519
509
520
mlir::Value vector_value = op.getVector ();
510
- if (op.getVectorType ().getElementType ().isInteger (1 )) {
521
+ mlir::Type vector_element_type = op.getVectorType ().getElementType ();
522
+ if (vector_element_type.isInteger (1 )) {
511
523
vector_value = b.create <arith::ExtUIOp>(
512
524
op.getVectorType ().cloneWith (std::nullopt, b.getI8Type ()),
513
525
vector_value);
514
526
}
515
- if (op.getVectorType ().getElementType ().isInteger (4 )) {
527
+ if (vector_element_type.isIntOrFloat () &&
528
+ vector_element_type.getIntOrFloatBitWidth () == 4 ) {
516
529
linear_index = b.create <arith::ShRUIOp>(
517
530
linear_index,
518
531
b.create <arith::ConstantIntOp>(1 , linear_index.getType ()));
@@ -575,21 +588,19 @@ ml::GlobalOp CreateGlobalOp(mlir::Attribute value,
575
588
// Needed to support complex element type.
576
589
mlir::LLVMTypeConverter converter (b.getContext ());
577
590
auto llvm_element_type = converter.convertType (element_type);
578
- if (mlir::isa<mlir::IntegerType>(element_type)) {
579
- int bit_width = mlir::cast<mlir::IntegerType>(element_type).getWidth ();
580
- if (bit_width == 4 ) {
581
- num_elements = CeilOfRatio<int64_t >(num_elements, 2 );
582
- llvm_element_type = b.getI8Type ();
583
- auto unpacked_data =
584
- mlir::cast<mlir::DenseElementsAttr>(value).getRawData ();
585
- std::vector<char > packed_data (num_elements);
586
- absl::Span<char > packed_data_span =
587
- absl::MakeSpan (packed_data.data (), packed_data.size ());
588
- PackIntN (4 , unpacked_data, packed_data_span);
589
- value = mlir::DenseElementsAttr::getFromRawBuffer (
590
- mlir::RankedTensorType::get ({num_elements}, llvm_element_type),
591
- packed_data);
592
- }
591
+ if (element_type.isIntOrFloat () &&
592
+ element_type.getIntOrFloatBitWidth () == 4 ) {
593
+ num_elements = CeilOfRatio<int64_t >(num_elements, 2 );
594
+ llvm_element_type = b.getI8Type ();
595
+ auto unpacked_data =
596
+ mlir::cast<mlir::DenseElementsAttr>(value).getRawData ();
597
+ std::vector<char > packed_data (num_elements);
598
+ absl::Span<char > packed_data_span =
599
+ absl::MakeSpan (packed_data.data (), packed_data.size ());
600
+ PackIntN (4 , unpacked_data, packed_data_span);
601
+ value = mlir::DenseElementsAttr::getFromRawBuffer (
602
+ mlir::RankedTensorType::get ({num_elements}, llvm_element_type),
603
+ packed_data);
593
604
}
594
605
auto array_ty = ml::LLVMArrayType::get (llvm_element_type, num_elements);
595
606
std::string name;
0 commit comments