|
62 | 62 | #include "llvm/Support/Casting.h"
|
63 | 63 | #include "llvm/Support/ErrorHandling.h"
|
64 | 64 | #include <cstdint>
|
| 65 | +#include <deque> |
65 | 66 | #include <optional>
|
| 67 | +#include <set> |
66 | 68 |
|
67 | 69 | using namespace cir;
|
68 | 70 | using namespace llvm;
|
@@ -561,13 +563,25 @@ class CIRBrCondOpLowering
|
561 | 563 | mlir::LogicalResult
|
562 | 564 | matchAndRewrite(mlir::cir::BrCondOp brOp, OpAdaptor adaptor,
|
563 | 565 | mlir::ConversionPatternRewriter &rewriter) const override {
|
564 |
| - auto condition = adaptor.getCond(); |
565 |
| - auto i1Condition = rewriter.create<mlir::LLVM::TruncOp>( |
566 |
| - brOp.getLoc(), rewriter.getI1Type(), condition); |
| 566 | + mlir::Value i1Condition; |
| 567 | + |
| 568 | + if (auto defOp = adaptor.getCond().getDefiningOp()) { |
| 569 | + if (auto zext = dyn_cast<mlir::LLVM::ZExtOp>(defOp)) { |
| 570 | + if (zext->use_empty() && |
| 571 | + zext->getOperand(0).getType() == rewriter.getI1Type()) { |
| 572 | + i1Condition = zext->getOperand(0); |
| 573 | + rewriter.eraseOp(zext); |
| 574 | + } |
| 575 | + } |
| 576 | + } |
| 577 | + |
| 578 | + if (!i1Condition) |
| 579 | + i1Condition = rewriter.create<mlir::LLVM::TruncOp>( |
| 580 | + brOp.getLoc(), rewriter.getI1Type(), adaptor.getCond()); |
| 581 | + |
567 | 582 | rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>(
|
568 |
| - brOp, i1Condition.getResult(), brOp.getDestTrue(), |
569 |
| - adaptor.getDestOperandsTrue(), brOp.getDestFalse(), |
570 |
| - adaptor.getDestOperandsFalse()); |
| 583 | + brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(), |
| 584 | + brOp.getDestFalse(), adaptor.getDestOperandsFalse()); |
571 | 585 |
|
572 | 586 | return mlir::success();
|
573 | 587 | }
|
@@ -771,90 +785,6 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
|
771 | 785 | }
|
772 | 786 | };
|
773 | 787 |
|
774 |
| -class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> { |
775 |
| -public: |
776 |
| - using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern; |
777 |
| - |
778 |
| - mlir::LogicalResult |
779 |
| - matchAndRewrite(mlir::cir::IfOp ifOp, OpAdaptor adaptor, |
780 |
| - mlir::ConversionPatternRewriter &rewriter) const override { |
781 |
| - mlir::OpBuilder::InsertionGuard guard(rewriter); |
782 |
| - auto loc = ifOp.getLoc(); |
783 |
| - auto emptyElse = ifOp.getElseRegion().empty(); |
784 |
| - |
785 |
| - auto *currentBlock = rewriter.getInsertionBlock(); |
786 |
| - auto *remainingOpsBlock = |
787 |
| - rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); |
788 |
| - mlir::Block *continueBlock; |
789 |
| - if (ifOp->getResults().size() == 0) |
790 |
| - continueBlock = remainingOpsBlock; |
791 |
| - else |
792 |
| - llvm_unreachable("NYI"); |
793 |
| - |
794 |
| - // Inline then region |
795 |
| - auto *thenBeforeBody = &ifOp.getThenRegion().front(); |
796 |
| - auto *thenAfterBody = &ifOp.getThenRegion().back(); |
797 |
| - rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock); |
798 |
| - |
799 |
| - rewriter.setInsertionPointToEnd(thenAfterBody); |
800 |
| - if (auto thenYieldOp = |
801 |
| - dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) { |
802 |
| - rewriter.replaceOpWithNewOp<mlir::cir::BrOp>( |
803 |
| - thenYieldOp, thenYieldOp.getArgs(), continueBlock); |
804 |
| - } |
805 |
| - |
806 |
| - rewriter.setInsertionPointToEnd(continueBlock); |
807 |
| - |
808 |
| - // Has else region: inline it. |
809 |
| - mlir::Block *elseBeforeBody = nullptr; |
810 |
| - mlir::Block *elseAfterBody = nullptr; |
811 |
| - if (!emptyElse) { |
812 |
| - elseBeforeBody = &ifOp.getElseRegion().front(); |
813 |
| - elseAfterBody = &ifOp.getElseRegion().back(); |
814 |
| - rewriter.inlineRegionBefore(ifOp.getElseRegion(), thenAfterBody); |
815 |
| - } else { |
816 |
| - elseBeforeBody = elseAfterBody = continueBlock; |
817 |
| - } |
818 |
| - |
819 |
| - rewriter.setInsertionPointToEnd(currentBlock); |
820 |
| - |
821 |
| - // FIXME: CIR always lowers !cir.bool to i8 type. |
822 |
| - // In this reason CIR CodeGen often emits the redundant zext + trunc |
823 |
| - // sequence that prevents lowering of llvm.expect in |
824 |
| - // LowerExpectIntrinsicPass. |
825 |
| - // We should fix that in a more appropriate way. But as a temporary solution |
826 |
| - // just avoid the redundant casts here. |
827 |
| - mlir::Value condition; |
828 |
| - auto zext = |
829 |
| - dyn_cast<mlir::LLVM::ZExtOp>(adaptor.getCondition().getDefiningOp()); |
830 |
| - if (zext && zext->getOperand(0).getType() == rewriter.getI1Type()) { |
831 |
| - condition = zext->getOperand(0); |
832 |
| - if (zext->use_empty()) |
833 |
| - rewriter.eraseOp(zext); |
834 |
| - } else { |
835 |
| - auto trunc = rewriter.create<mlir::LLVM::TruncOp>( |
836 |
| - loc, rewriter.getI1Type(), adaptor.getCondition()); |
837 |
| - condition = trunc.getRes(); |
838 |
| - } |
839 |
| - |
840 |
| - rewriter.create<mlir::LLVM::CondBrOp>(loc, condition, thenBeforeBody, |
841 |
| - elseBeforeBody); |
842 |
| - |
843 |
| - if (!emptyElse) { |
844 |
| - rewriter.setInsertionPointToEnd(elseAfterBody); |
845 |
| - if (auto elseYieldOp = |
846 |
| - dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) { |
847 |
| - rewriter.replaceOpWithNewOp<mlir::cir::BrOp>( |
848 |
| - elseYieldOp, elseYieldOp.getArgs(), continueBlock); |
849 |
| - } |
850 |
| - } |
851 |
| - |
852 |
| - rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
853 |
| - |
854 |
| - return mlir::success(); |
855 |
| - } |
856 |
| -}; |
857 |
| - |
858 | 788 | class CIRScopeOpLowering
|
859 | 789 | : public mlir::OpConversionPattern<mlir::cir::ScopeOp> {
|
860 | 790 | public:
|
@@ -937,9 +867,7 @@ struct ConvertCIRToLLVMPass
|
937 | 867 | }
|
938 | 868 | void runOnOperation() final;
|
939 | 869 |
|
940 |
| - virtual StringRef getArgument() const override { |
941 |
| - return "cir-to-llvm-internal"; |
942 |
| - } |
| 870 | + virtual StringRef getArgument() const override { return "cir-flat-to-llvm"; } |
943 | 871 | };
|
944 | 872 |
|
945 | 873 | class CIRCallLowering : public mlir::OpConversionPattern<mlir::cir::CallOp> {
|
@@ -3081,7 +3009,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
|
3081 | 3009 | CIRLoopOpInterfaceLowering, CIRBrCondOpLowering, CIRPtrStrideOpLowering,
|
3082 | 3010 | CIRCallLowering, CIRUnaryOpLowering, CIRBinOpLowering, CIRShiftOpLowering,
|
3083 | 3011 | CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRAllocaLowering,
|
3084 |
| - CIRFuncLowering, CIRScopeOpLowering, CIRCastOpLowering, CIRIfLowering, |
| 3012 | + CIRFuncLowering, CIRScopeOpLowering, CIRCastOpLowering, |
3085 | 3013 | CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRVAStartLowering,
|
3086 | 3014 | CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
|
3087 | 3015 | CIRTernaryOpLowering, CIRGetMemberOpLowering, CIRSwitchOpLowering,
|
@@ -3241,6 +3169,64 @@ static void buildCtorDtorList(
|
3241 | 3169 | builder.create<mlir::LLVM::ReturnOp>(loc, result);
|
3242 | 3170 | }
|
3243 | 3171 |
|
| 3172 | +// The unreachable code is not lowered by the applyPartialConversion function |
| 3173 | +// since it traverses blocks in the dominance order. At the same time we |
| 3174 | +// do need to lower such code - otherwise verification errors occur. |
| 3175 | +// For instance, the next CIR code: |
| 3176 | +// |
| 3177 | +// cir.func @foo(%arg0: !s32i) -> !s32i { |
| 3178 | +// %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool |
| 3179 | +// cir.if %4 { |
| 3180 | +// %5 = cir.const(#cir.int<1> : !s32i) : !s32i |
| 3181 | +// cir.return %5 : !s32i |
| 3182 | +// } else { |
| 3183 | +// %5 = cir.const(#cir.int<0> : !s32i) : !s32i |
| 3184 | +// cir.return %5 : !s32i |
| 3185 | +// } |
| 3186 | +// cir.return %arg0 : !s32i |
| 3187 | +// } |
| 3188 | +// |
| 3189 | +// contains an unreachable return operation (the last one). After the flattening |
| 3190 | +// pass it will be placed into the unreachable block. And the possible error |
| 3191 | +// after the lowering pass is: error: 'cir.return' op expects parent op to be |
| 3192 | +// one of 'cir.func, cir.scope, cir.if ... The reason is that this operation was |
| 3193 | +// not lowered and the new parent is llvm.func. |
| 3194 | +// |
| 3195 | +// In the future we may want to get rid of this function and use DCE pass or |
| 3196 | +// something similar. But now we need to guarantee the absence of the dialect |
| 3197 | +// verification errors. |
| 3198 | +void collect_unreachable(mlir::Operation *parent, |
| 3199 | + llvm::SmallVector<mlir::Operation *> &ops) { |
| 3200 | + |
| 3201 | + llvm::SmallVector<mlir::Block *> unreachable_blocks; |
| 3202 | + parent->walk([&](mlir::Block *blk) { |
| 3203 | + if (blk->hasNoPredecessors() && !blk->isEntryBlock()) |
| 3204 | + unreachable_blocks.push_back(blk); |
| 3205 | + }); |
| 3206 | + |
| 3207 | + std::set<mlir::Block *> visited; |
| 3208 | + for (auto *root : unreachable_blocks) { |
| 3209 | + // We create a work list for each unreachable block. |
| 3210 | + // Thus we traverse operations in some order. |
| 3211 | + std::deque<mlir::Block *> workList; |
| 3212 | + workList.push_back(root); |
| 3213 | + |
| 3214 | + while (!workList.empty()) { |
| 3215 | + auto *blk = workList.back(); |
| 3216 | + workList.pop_back(); |
| 3217 | + if (visited.count(blk)) |
| 3218 | + continue; |
| 3219 | + visited.emplace(blk); |
| 3220 | + |
| 3221 | + for (auto &op : *blk) |
| 3222 | + ops.push_back(&op); |
| 3223 | + |
| 3224 | + for (auto it = blk->succ_begin(); it != blk->succ_end(); ++it) |
| 3225 | + workList.push_back(*it); |
| 3226 | + } |
| 3227 | + } |
| 3228 | +} |
| 3229 | + |
3244 | 3230 | void ConvertCIRToLLVMPass::runOnOperation() {
|
3245 | 3231 | auto module = getOperation();
|
3246 | 3232 | mlir::DataLayout dataLayout(module);
|
@@ -3280,7 +3266,11 @@ void ConvertCIRToLLVMPass::runOnOperation() {
|
3280 | 3266 | getOperation()->removeAttr("cir.sob");
|
3281 | 3267 | getOperation()->removeAttr("cir.lang");
|
3282 | 3268 |
|
3283 |
| - if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
| 3269 | + llvm::SmallVector<mlir::Operation *> ops; |
| 3270 | + ops.push_back(module); |
| 3271 | + collect_unreachable(module, ops); |
| 3272 | + |
| 3273 | + if (failed(applyPartialConversion(ops, target, std::move(patterns)))) |
3284 | 3274 | signalPassFailure();
|
3285 | 3275 |
|
3286 | 3276 | // Emit the llvm.global_ctors array.
|
|
0 commit comments