#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"

#include <limits>
#include <unordered_map>
#include <vector>

using namespace llvm;

namespace {
struct SimpleStrengthReductionPass : public ModulePass {
  std::unordered_map<Value *, Value *> VariablesMap;
  Value *PHIValueTrue;
  Value *PHIValueFalse;
  std::vector<Instruction *> InstructionsToCopy;

  static char ID;
  SimpleStrengthReductionPass() : ModulePass(ID) {}
  std::vector<Instruction *> InstructionsToRemove;

  bool isConstant(Value *Operand) {
    return isa<ConstantInt>(Operand);
  }

  int getConstValue(Value *Value) {
    ConstantInt *ConstInt = dyn_cast<ConstantInt>(Value);
    return ConstInt->getSExtValue();
  }

  bool isPowerOfTwo(int value) {
    return (value & std::numeric_limits<int>::max()) == value;
  }

  int powerOfTwo(int value) {
    int mask = 1, power = 0;

    while (1) {
      if ((value & mask) != 0) {
        return power;
      }
      power++;
      mask <<= 1;
    }
  }

  BasicBlock *createTrueBasicBlock(BasicBlock &Current, Value *Variable) {
    Function *F = Current.getParent();
    IRBuilder<> Builder(F->getContext());
    BasicBlock *TrueBasicBlock = BasicBlock::Create(F->getContext(), "", F, nullptr);

    Builder.SetInsertPoint(TrueBasicBlock);
    Builder.CreateLoad(Type::getInt32Ty(F->getContext()), Variable);
    Builder.CreateRet(ConstantInt::get(Type::getInt32Ty(F->getContext()), 0));

    return TrueBasicBlock;
  }

  BasicBlock *createFalseBasicBlock(BasicBlock &Current, Value *Variable) {
    Function *F = Current.getParent();
    IRBuilder<> Builder(F->getContext());
    BasicBlock *FalseBasicBlock = BasicBlock::Create(F->getContext(), "", F, nullptr);

    Builder.SetInsertPoint(FalseBasicBlock);
    Instruction *Tmp = (Instruction *) Builder.CreateLoad(Type::getInt32Ty(F->getContext()), Variable);
    Builder.CreateSub(ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), Tmp);
    Builder.CreateRet(ConstantInt::get(Type::getInt32Ty(F->getContext()), 0));

    return FalseBasicBlock;
  }

  BasicBlock *createUpdateBasicBlock(BasicBlock &Current, Value ) {
    Function *F = Current.getParent();
    IRBuilder<> Builder(F->getContext());
    BasicBlock *FalseBasicBlock = BasicBlock::Create(F->getContext(), "", F, nullptr);

    Builder.SetInsertPoint(FalseBasicBlock);
    Instruction *Tmp = (Instruction *) Builder.CreateLoad(Type::getInt32Ty(F->getContext()), Variable);
    Builder.CreateSub(ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), Tmp);
    Builder.CreateRet(ConstantInt::get(Type::getInt32Ty(F->getContext()), 0));

    return FalseBasicBlock;
  }


  bool runOnModule(Module &M) override {
    for (Function &F : M) {
      for (BasicBlock &BB : F) {
        for (Instruction &Instr : BB) {
          IRBuilder<> Builder(Instr.getContext());
          if (BinaryOperator *BinaryOp = dyn_cast<BinaryOperator>(&Instr)) {
            if (MulOperator *Mul = dyn_cast<MulOperator>(BinaryOp)) {
              (void) Mul;
              Value *FirstOperand = Instr.getOperand(0);
              Value *SecondOperand = Instr.getOperand(1);
              if (isConstant(FirstOperand) && isPowerOfTwo(getConstValue(FirstOperand))) {
                Instruction *ShiftLeft = (Instruction *) Builder.CreateShl(SecondOperand, ConstantInt::get(Type::getInt32Ty(Instr.getContext()), powerOfTwo(getConstValue(FirstOperand))));
                ShiftLeft->insertAfter(&Instr);
                Instr.replaceAllUsesWith(ShiftLeft);
                InstructionsToRemove.push_back(&Instr);
              } else if (isConstant(SecondOperand) && isPowerOfTwo(getConstValue(SecondOperand))){
                Instruction *ShiftLeft = (Instruction *) Builder.CreateShl(FirstOperand, ConstantInt::get(Type::getInt32Ty(Instr.getContext()), powerOfTwo(getConstValue(SecondOperand))));
                ShiftLeft->insertAfter(&Instr);
                Instr.replaceAllUsesWith(ShiftLeft);
                InstructionsToRemove.push_back(&Instr);
              }
            }
            else if (std::string(Instr.getOpcodeName()) == "srem") {
              if (isConstant(Instr.getOperand(1)) && getConstValue(Instr.getOperand(1)) == 2) {
                Instruction *AndInstr = (Instruction *) Builder.CreateAnd(Instr.getOperand(0), ConstantInt::get(Type::getInt32Ty(Instr.getContext()), 2));
                AndInstr->insertAfter(&Instr);
                Instr.replaceAllUsesWith(AndInstr);
                InstructionsToRemove.push_back(&Instr);
              }
            }
          }
          else if (CallInst *CallInstr = dyn_cast<CallInst>(&Instr)) {
            if (CallInstr->getCalledFunction()->getName().str() == "abs") {
              IRBuilder<> Builder(Instr.getContext());
              Value *Variable = CallInstr->getOperand(0);
              Value *Pointer = Instr.getPrevNonDebugInstruction(true)->getOperand(0);
              Instruction *Cmp = (Instruction *) Builder.CreateICmpSGE(Variable, ConstantInt::get(Type::getInt32Ty(Instr.getContext()), 0));
              Cmp->insertAfter(&Instr);

              BasicBlock *TrueBasicBlock = createTrueBasicBlock(BB, Pointer);
              BasicBlock *FalseBasicBlock = createFalseBasicBlock(BB, Pointer);

              BasicBlock
              BranchInst *Br = Builder.CreateCondBr(Cmp, TrueBasicBlock, FalseBasicBlock);
              Br->insertAfter(Cmp);
            }
          }
        }
      }
    }

    for (Instruction *Instr : InstructionsToRemove) {
      Instr->eraseFromParent();
    }

    return true;
  }
};
}

char SimpleStrengthReductionPass::ID = 0;

// Register the pass with the legacy pass manager under the command-line
// name "simple-strength-reduction"; both boolean flags are false.
static RegisterPass<SimpleStrengthReductionPass>
    X("simple-strength-reduction",
      "Simple strength reduction pass",
      /*CFGOnly=*/false,
      /*is_analysis=*/false);