#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include <unordered_map>
#include <vector>

using namespace llvm;

namespace {
struct OurCSEPass : public FunctionPass {
  std::unordered_map<Value *, Value *> VariablesMap;
  std::vector<Instruction *> SavedInstructions;

  static char ID;
  OurCSEPass() : FunctionPass(ID) {}

  void mapVariables(Function &F) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        if (isa<LoadInst>(&I)) {
          VariablesMap[&I] = I.getOperand(0);
        }
      }
    }
  }

  bool shouldSave(Instruction *I) {
    return isa<CallInst>(I) || isa<AddOperator>(I) || isa<SubOperator>(I) ||
        isa<MulOperator>(I) || isa<SDivOperator>(I);
  }

  bool haveTheSameType(Instruction *I1, Instruction *I2) {
    return I1->getOpcode() == I2->getOpcode();
  }

  bool isConstantInt(Value *Operand) {
    return isa<ConstantInt>(Operand);
  }

  int getValue(Value *Operand) {
    ConstantInt *Const = dyn_cast<ConstantInt>(Operand);
    return Const->getSExtValue();
  }

  bool isCommutative(Instruction *I) {
    return isa<AddOperator>(I) || isa<MulOperator>(I);
  }

  bool haveTheSameOperands(Instruction *I1, Instruction *I2) {
    return areTheSameOperand(I1->getOperand(0), I2->getOperand(0)) &&
           areTheSameOperand(I1->getOperand(1), I2->getOperand(1));
  }

  bool areTheSameOperand(Value *Operand1, Value *Operand2) {
    if (isConstantInt(Operand1) && isConstantInt(Operand2)) {
      return getValue(Operand1) == getValue(Operand2);
    }

    if (isConstantInt(Operand1) || isConstantInt(Operand2)) {
      return false;
    }

    return VariablesMap[Operand1] == VariablesMap[Operand2];
  }

  bool haveTheSameOperandsCommutative(Instruction *I1, Instruction *I2) {
    Value *Operand11 = I1->getOperand(0), *Operand12 = I1->getOperand(1),
          *Operand21 = I2->getOperand(0), *Operand22 = I2->getOperand(1);

    return (areTheSameOperand(Operand11, Operand21) && areTheSameOperand(Operand12, Operand22)) ||
           (areTheSameOperand(Operand11, Operand22) && areTheSameOperand(Operand12, Operand21));
  }

  bool haveTheSameOperandsCall(Instruction *I1, Instruction *I2) {
    if (I1->getNumOperands() != I2->getNumOperands()) {
      return false;
    }

    for (size_t i = 0; i < I1->getNumOperands(); i++) {
      if (!areTheSameOperand(I1->getOperand(i), I2->getOperand(i))) {
        return false;
      }
    }

    return true;
  }

  Instruction *isAlreadySaved(Instruction *I) {
    for (Instruction *SavedInstr : SavedInstructions) {
      if (haveTheSameType(I, SavedInstr)) {
        if (CallInst *Call1 = dyn_cast<CallInst>(I)) {
          CallInst *Call2 = dyn_cast<CallInst>(SavedInstr);

          if (Call1->getCalledFunction() != Call2->getCalledFunction()) {
            return nullptr;
          }

          if (haveTheSameOperandsCall(I, SavedInstr)) {
            return SavedInstr;
          }
        }
        else if (isCommutative(I)) {
          if (haveTheSameOperandsCommutative(I, SavedInstr)) {
            return SavedInstr;
          }
        }
        else if (haveTheSameOperands(I, SavedInstr)) {
          return SavedInstr;
        }
      }
    }

    return nullptr;
  }

  void maybeRemoveFromSaved(Instruction *I) {
    std::vector<Instruction *> InstructionsToRemove;

    Value *StoreOperand = I->getOperand(1);

    for (Instruction *SavedInstr : SavedInstructions) {
      for (size_t i = 0; i < SavedInstr->getNumOperands(); i++) {
        if (VariablesMap[SavedInstr->getOperand(i)] == StoreOperand) {
          InstructionsToRemove.push_back(SavedInstr);
        }
      }
    }

    for (Instruction *InstructionToRemove : InstructionsToRemove) {
      SavedInstructions.erase(std::find(SavedInstructions.begin(), SavedInstructions.end(), InstructionToRemove));
    }
  }

  bool runOnFunction(Function &F) override {
    VariablesMap.clear();
    SavedInstructions.clear();

    mapVariables(F);

    Instruction *SavedInstruction;

    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        if (shouldSave(&I)) {
          if ((SavedInstruction = isAlreadySaved(&I)) != nullptr) {
            I.replaceAllUsesWith(SavedInstruction);
          }
          else {
            SavedInstructions.push_back(&I);
          }
        }
        else if (isa<StoreInst>(&I)) {
          maybeRemoveFromSaved(&I);
        }
      }
    }
    return true;
  }
}; // end of struct Hello
}  // end of anonymous namespace

// The legacy pass infrastructure identifies the pass by the address of ID;
// the stored value itself is irrelevant.
char OurCSEPass::ID = 0;

// Register the pass under the command-line name "our-cse" (usable, e.g., as
// `opt -our-cse` with the legacy pass manager). It is neither CFG-only nor an
// analysis pass.
static RegisterPass<OurCSEPass> X(/*PassArg=*/"our-cse",
                                  /*Name=*/"Our simple implementation of CSE",
                                  /*CFGOnly=*/false,
                                  /*is_analysis=*/false);
