#include "llvm/Pass.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Function.h"

using namespace llvm;

namespace {
struct OurConstantFolding : public FunctionPass {
  std::vector<Instruction *> InstructionsToRemove;

  static char ID;
  OurConstantFolding() : FunctionPass(ID) {}

  void handleBinaryOperator(Instruction &Instr)
  {
    ConstantInt *Lhs, *Rhs;

    if (!(Lhs = dyn_cast<ConstantInt>(Instr.getOperand(0)))) {
      return;
    }

    if (!(Rhs = dyn_cast<ConstantInt>(Instr.getOperand(1)))) {
      return;
    }

    int LhsValue = Lhs->getSExtValue(), RhsValue = Rhs->getSExtValue(), ConstIntValue;

    if (isa<AddOperator>(&Instr)) {
      ConstIntValue = LhsValue + RhsValue;
    }
    else if (isa<MulOperator>(&Instr)) {
      ConstIntValue = LhsValue * RhsValue;
    }
    else if (isa<SDivOperator>(&Instr)) {
      if (RhsValue == 0) {
        errs() << "Division by 0!\n";
        return ;
      }
      ConstIntValue = LhsValue / RhsValue;
    }
    else if (isa<SubOperator>(&Instr)) {
      ConstIntValue = LhsValue - RhsValue;
    }
    else {
      return ;
    }

    Instr.replaceAllUsesWith(ConstantInt::get(Type::getInt32Ty(Instr.getContext()), ConstIntValue));
  }

  void handleCompare(Instruction &Instr)
  {
    ICmpInst *Cmp = dyn_cast<ICmpInst>(&Instr);
    auto Pred = Cmp->getPredicate();

    ConstantInt *Lhs, *Rhs;

    if (!(Lhs = dyn_cast<ConstantInt>(Instr.getOperand(0)))) {
      return;
    }

    if (!(Rhs = dyn_cast<ConstantInt>(Instr.getOperand(1)))) {
      return;
    }

    int LhsValue = Lhs->getSExtValue(), RhsValue = Rhs->getSExtValue();
    bool Value;

    if (Pred == CmpInst::ICMP_EQ) {
      Value = LhsValue == RhsValue;
    }
    else if (Pred == CmpInst::ICMP_NE) {
      Value = LhsValue != RhsValue;
    }
    else if (Pred == CmpInst::ICMP_SGT) {
      Value = LhsValue > RhsValue;
    }
    else if (Pred == CmpInst::ICMP_SLT) {
      Value = LhsValue < RhsValue;
    }
    else if (Pred == CmpInst::ICMP_SGE) {
      Value = LhsValue >= RhsValue;
    }
    else if (Pred == CmpInst::ICMP_SLE) {
      Value = LhsValue <= RhsValue;
    }

    ConstantInt *BooleanValue = ConstantInt::get(Type::getInt1Ty(Instr.getContext()), Value);

    Instr.replaceAllUsesWith(BooleanValue);
  }

  void handleBranch(Instruction &Instr)
  {
    BranchInst *Br = dyn_cast<BranchInst>(&Instr);

    if (Br->isConditional()) {
      ConstantInt *Condition = dyn_cast<ConstantInt>(Br->getCondition());

      if (!Condition) {
        return ;
      }

      if (Condition->getZExtValue() == 1) {
        BranchInst::Create(Br->getSuccessor(0), Br->getParent());
      }
      else {
        BranchInst::Create(Br->getSuccessor(1), Br->getParent());
      }

      InstructionsToRemove.push_back(&Instr);
    }
  }


  void iterateInstructions(Function &F)
  {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        if (isa<BinaryOperator>(&I)) {
          handleBinaryOperator(I);
        } else if (isa<ICmpInst>(&I)) {
          handleCompare(I);
        } else if (isa<BranchInst>(&I)) {
          handleBranch(I);
        }
      }
    }

    for (Instruction *Instr : InstructionsToRemove) {
      Instr->eraseFromParent();
    }
  }

  bool runOnFunction(Function &F) override {
    iterateInstructions(F);
    return true;
  }
}; // end of struct OurConstantFolding
}  // end of anonymous namespace

char OurConstantFolding::ID = 0;

// Static registration with the legacy pass registry under the command-line
// name "our-constant-folding". Both flags are left false: the pass is not
// marked CFG-only and is not registered as an analysis pass.
static RegisterPass<OurConstantFolding>
    OurConstantFoldingRegistration("our-constant-folding",
                                   "OurConstantFolding pass",
                                   /*CFGOnly=*/false,
                                   /*is_analysis=*/false);