|
| 1 | +#include "X86.h" |
| 2 | +#include "X86InstrInfo.h" |
| 3 | +#include "llvm/CodeGen/MachineBasicBlock.h" |
| 4 | +#include "llvm/CodeGen/MachineFunctionPass.h" |
| 5 | +#include "llvm/CodeGen/MachineInstr.h" |
| 6 | +#include "llvm/CodeGen/MachineInstrBuilder.h" |
| 7 | +#include "llvm/CodeGen/Register.h" |
| 8 | +#include <utility> |
| 9 | +#include <vector> |
| 10 | + |
| 11 | +using namespace llvm; |
| 12 | + |
| 13 | +namespace { |
| 14 | +class X86SadikovPassFMA : public MachineFunctionPass { |
| 15 | +public: |
| 16 | + static char ID; |
| 17 | + |
| 18 | + X86SadikovPassFMA() : MachineFunctionPass(ID) {} |
| 19 | + |
| 20 | + bool runOnMachineFunction(MachineFunction &MF) override { |
| 21 | + bool instructions_changed = false; |
| 22 | + |
| 23 | + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); |
| 24 | + |
| 25 | + std::vector<std::pair<MachineInstr *, MachineInstr *>> pairs; |
| 26 | + |
| 27 | + for (MachineBasicBlock &BB : MF) { |
| 28 | + for (auto MI = BB.begin(); MI != BB.end(); ++MI) { |
| 29 | + MachineInstr *mulInstr = &(*MI); |
| 30 | + if (mulInstr->getOpcode() == X86::MULPDrr || |
| 31 | + mulInstr->getOpcode() == X86::MULPDrm) { |
| 32 | + MachineInstr *addInstr = nullptr; |
| 33 | + for (auto MI_2 = std::next(MI); MI_2 != BB.end(); ++MI_2) { |
| 34 | + if (MI_2->readsRegister(mulInstr->getOperand(0).getReg())) { |
| 35 | + if ((MI_2->getOpcode() == X86::ADDPDrr || |
| 36 | + MI_2->getOpcode() == X86::ADDPDrm) && |
| 37 | + addInstr == nullptr && |
| 38 | + mulInstr->getOperand(0).getReg() == |
| 39 | + MI_2->getOperand(1).getReg()) { |
| 40 | + addInstr = &(*MI_2); |
| 41 | + } else { |
| 42 | + addInstr = nullptr; |
| 43 | + break; |
| 44 | + } |
| 45 | + } |
| 46 | + } |
| 47 | + if (addInstr) { |
| 48 | + if (addInstr->getOperand(1).getReg() != |
| 49 | + addInstr->getOperand(2).getReg()) { |
| 50 | + pairs.emplace_back(mulInstr, addInstr); |
| 51 | + } |
| 52 | + } |
| 53 | + } |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + for (auto [mulInstr, addInstr] : pairs) { |
| 58 | + MIMetadata MIMD(*mulInstr); |
| 59 | + MachineBasicBlock &BB = *mulInstr->getParent(); |
| 60 | + |
| 61 | + BuildMI(BB, mulInstr, MIMD, TII->get(X86::VFMADD213PDr), |
| 62 | + addInstr->getOperand(0).getReg()) |
| 63 | + .addReg(mulInstr->getOperand(1).getReg()) |
| 64 | + .addReg(mulInstr->getOperand(2).getReg()) |
| 65 | + .addReg(addInstr->getOperand(2).getReg()); |
| 66 | + |
| 67 | + mulInstr->eraseFromParent(); |
| 68 | + addInstr->eraseFromParent(); |
| 69 | + |
| 70 | + instructions_changed = true; |
| 71 | + } |
| 72 | + |
| 73 | + return instructions_changed; |
| 74 | + } |
| 75 | +}; |
| 76 | + |
| 77 | +char X86SadikovPassFMA::ID = 0; |
| 78 | + |
| 79 | +} // namespace |
| 80 | + |
| 81 | +static RegisterPass<X86SadikovPassFMA> X("x86-sadikov-pass-fma", |
| 82 | + "X86 Sadikov Pass FMA", false, false); |
0 commit comments