#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/LoopUnrollAnalyzer.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <algorithm>
#include <unordered_map>
#include <vector>

using namespace llvm;

namespace {
struct OurLoopUnswitchingPass : public LoopPass {
  std::vector<BasicBlock *> LoopBasicBlocks;
  std::unordered_map<Value *, Value *> VariablesMap;

  static char ID; // Pass identification, replacement for typeid
  OurLoopUnswitchingPass() : LoopPass(ID) {}

  void mapVariables(Loop *L)
  {
    Function *F = L->getHeader()->getParent();
    for (BasicBlock &BB : *F) {
      for (Instruction &I : BB) {
        if (isa<LoadInst>(&I)) {
          VariablesMap[&I] = I.getOperand(0);
        }
      }
    }
  }

  BasicBlock *copyBasicBlock(BasicBlock *BlockToCopy)
  {
    std::unordered_map<Value *, Value *>Mapping;
    Instruction *Clone;
    BasicBlock *NewBasicBlock =
        BasicBlock::Create(BlockToCopy->getContext(), "", BlockToCopy->getParent(), LoopBasicBlocks.front());

    IRBuilder<>Builder(NewBasicBlock->getContext());
    Builder.SetInsertPoint(NewBasicBlock);

    for (Instruction &I : *BlockToCopy) {
      Clone = I.clone();
      Builder.Insert(Clone);
      Mapping[&I] = Clone;

      for (size_t i = 0; i < Clone->getNumOperands(); i++) {
        if (Mapping.find(Clone->getOperand(i)) != Mapping.end()) {
          Clone->setOperand(i, Mapping[Clone->getOperand(i)]);
        }
      }
    }

    return NewBasicBlock;
  }

  BasicBlock *findCompareBasicBlock(std::vector<BasicBlock *>LoopBasicBlocks)
  {
    for (size_t i = 1; i < LoopBasicBlocks.size(); i++) {
      for (Instruction &I : *LoopBasicBlocks[i]) {
        if (isa<CmpInst>(&I)) {
         return LoopBasicBlocks[i];
        }
      }
    }

    return nullptr;
  }

  BasicBlock *CopyLoop(Loop *L)
  {
    std::unordered_map<Value *, Value *>Mapping;
    std::unordered_map<BasicBlock *, BasicBlock *>BlocksMapping;
    Instruction *Clone;
    BasicBlock *NewBlock;
    std::vector<BasicBlock *> LoopBasicBlocksCopy;
    IRBuilder<>Builder(L->getExitBlock()->getParent()->getContext());

    int block = 0;
    for (BasicBlock *BB : LoopBasicBlocks) {
      NewBlock = BasicBlock::Create(BB->getContext(), "", BB->getParent(), L->getExitBlock());
      LoopBasicBlocksCopy.push_back(NewBlock);
      BlocksMapping[BB] = NewBlock;
    }

    for (size_t i = 0; i < LoopBasicBlocks.size(); i++) {
      Builder.SetInsertPoint(LoopBasicBlocksCopy[i]);

      for (Instruction &I : *LoopBasicBlocks[i]) {
        Clone = I.clone();
        Mapping[&I] = Clone;
        Builder.Insert(Clone);

        for (size_t j = 0; j < Clone->getNumOperands(); j++) {
          if (Mapping.find(Clone->getOperand(j)) != Mapping.end()) {
            Clone->setOperand(j, Mapping[Clone->getOperand(j)]);
          }
        }
      }
    }

    for (BasicBlock *Copy : LoopBasicBlocksCopy) {
      for (size_t i = 0; i < Copy->getTerminator()->getNumSuccessors(); i++) {
        if (BlocksMapping.find(Copy->getTerminator()->getSuccessor(i)) != BlocksMapping.end()) {
          Copy->getTerminator()->setSuccessor(i, BlocksMapping[Copy->getTerminator()->getSuccessor(i)]);
        }
      }
    }

    BasicBlock *CompareBasicBlock = findCompareBasicBlock(LoopBasicBlocksCopy);
    LoopBasicBlocksCopy.front()->getTerminator()->setSuccessor(0, CompareBasicBlock->getTerminator()->getSuccessor(1));
    CompareBasicBlock->getTerminator()->getSuccessor(0)->eraseFromParent();
    CompareBasicBlock->eraseFromParent();

    return LoopBasicBlocksCopy.front();
  }

  void unswitchLoop(Loop *L)
  {
    BasicBlock *CompareBasicBlock = findCompareBasicBlock(LoopBasicBlocks);
    BasicBlock *NewBasicBlock = copyBasicBlock(CompareBasicBlock);
    L->getLoopPreheader()->getTerminator()->setSuccessor(0, NewBasicBlock);
    NewBasicBlock->getTerminator()->setSuccessor(0, LoopBasicBlocks.front());
    NewBasicBlock->getTerminator()->setSuccessor(1, L->getExitBlock());

    BasicBlock *CopyHeader = CopyLoop(L);
    NewBasicBlock->getTerminator()->setSuccessor(1, CopyHeader);

    LoopBasicBlocks.front()->getTerminator()->setSuccessor(0, CompareBasicBlock->getTerminator()->getSuccessor(0));
    CompareBasicBlock->getTerminator()->getSuccessor(1)->eraseFromParent();
    CompareBasicBlock->eraseFromParent();
  }

  bool canUnswitch()
  {
    BasicBlock *CompareBasicBlock = findCompareBasicBlock(LoopBasicBlocks);
    Value *Var1, *Var2;

    for (Instruction &I : *CompareBasicBlock) {
      if (isa<CmpInst>(&I)) {
        Var1 = VariablesMap[I.getOperand(0)];
        Var2 = VariablesMap[I.getOperand(1)];
      }
    }


    for (BasicBlock *BB : LoopBasicBlocks) {
      for (Instruction &I : *BB) {
        if (isa<StoreInst>(&I)) {
          if (Var1 != nullptr && I.getOperand(1) == Var1) {
            return false;
          }

          if (Var2 != nullptr && I.getOperand(1) == Var2) {
            return false;
          }
        }
      }
    }

    return true;
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    mapVariables(L);
    LoopBasicBlocks = L->getBlocksVector();
    if (canUnswitch()) {
      unswitchLoop(L);
    }

    return true;
  }
}; // end of struct OurLoopUnswitchingPass
}  // end of anonymous namespace

char OurLoopUnswitchingPass::ID = 0;

// Registers the pass with the legacy pass manager under the name
// "loop-unswitching". The description was previously empty, which left the
// pass listed without any text.
static RegisterPass<OurLoopUnswitchingPass> X("loop-unswitching",
                                              "Our loop unswitching pass",
                                              false /* Only looks at CFG */,
                                              false /* Analysis Pass */);