#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopUnrollAnalyzer.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Pass.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/IR/IRBuilder.h"
#include <algorithm>

using namespace llvm;

namespace {
struct OurLoopUnswitchingPass : public LoopPass {
  std::vector<BasicBlock *> LoopBasicBlocks;
  std::unordered_map<Value *, Value *> VariablesMap;

  static char ID; // Pass identification, replacement for typeid
  OurLoopUnswitchingPass() : LoopPass(ID) {}

  void mapVariables(Loop *L)
  {
    Function *F = L->getHeader()->getParent();
    for (BasicBlock &BB : *F) {
      for (Instruction &I : BB) {
        if (isa<LoadInst>(&I)) {
          VariablesMap[&I] = I.getOperand(0);
        }
      }
    }
  }

  BasicBlock *findCompareBasicBlock(std::vector<BasicBlock *> LoopBasicBlocks)
  {
    for (size_t i = 1; i < LoopBasicBlocks.size(); i++) {
      for (Instruction &I : *LoopBasicBlocks[i]) {
        if (isa<CmpInst>(&I)) {
          return LoopBasicBlocks[i];
        }
      }
    }

    return nullptr;
  }

  bool canUnswitch()
  {
    BasicBlock *CompareBlock = findCompareBasicBlock(LoopBasicBlocks);
    if (CompareBlock == nullptr) {
      return false;
    }

    Value *Var1, *Var2;

    for (Instruction &I : *CompareBlock) {
      if (isa<CmpInst>(&I)) {
        Var1 = VariablesMap[I.getOperand(0)];
        Var2 = VariablesMap[I.getOperand(1)];
        break;
      }
    }

    for (BasicBlock *BB : LoopBasicBlocks) {
      for (Instruction &I : *BB) {
        if (isa<StoreInst>(&I)) {
          if (I.getOperand(1) == Var1 || I.getOperand(1) == Var2) {
            return false;
          }
        }
      }
    }

    return true;
  }

  BasicBlock *copyBlock(BasicBlock *OriginalBlock)
  {
    std::unordered_map<Value *, Value *> Mapping;
    Instruction *Clone;
    IRBuilder<> Builder(OriginalBlock->getContext());
    BasicBlock *NewBasicBlock = BasicBlock::Create(OriginalBlock->getContext(), "",
                                OriginalBlock->getParent(), LoopBasicBlocks.front());

    Builder.SetInsertPoint(NewBasicBlock);
    for (Instruction &I : *OriginalBlock) {
      Clone = I.clone();
      Builder.Insert(Clone);
      Mapping[&I] = Clone;

      for (size_t i = 0; i < Clone->getNumOperands(); i++) {
        if (Mapping.find(Clone->getOperand(i)) != Mapping.end()) {
          Clone->setOperand(i, Mapping[Clone->getOperand(i)]);
        }
      }
    }

    return NewBasicBlock;
  }

  BasicBlock *copyLoop(Loop *L)
  {
    BasicBlock *Exit = L->getExitBlock();
    std::unordered_map<Value *, Value *> Mapping;
    std::unordered_map<BasicBlock *, BasicBlock *> BasicBlocksMapping;
    Instruction *Clone;
    IRBuilder<> Builder(Exit);
    BasicBlock *NewBasicBlock;
    std::vector<BasicBlock *> LoopBasicBlocksCopy;

    for (BasicBlock *BB : LoopBasicBlocks) {
      NewBasicBlock = BasicBlock::Create(BB->getContext(), "", BB->getParent(), Exit);
      LoopBasicBlocksCopy.push_back(NewBasicBlock);
      BasicBlocksMapping[BB] = NewBasicBlock;
    }

    for (BasicBlock *BB : LoopBasicBlocks) {
      Builder.SetInsertPoint(BasicBlocksMapping[BB]);

      for (Instruction &I : *BB) {
        Clone = I.clone();
        Mapping[&I] = Clone;
        Builder.Insert(Clone);

        for (size_t i = 0; i < Clone->getNumOperands(); i++) {
          if (Mapping.find(Clone->getOperand(i)) != Mapping.end()) {
            Clone->setOperand(i, Mapping[Clone->getOperand(i)]);
          }
        }
      }
    }

    for (BasicBlock *Copy : LoopBasicBlocksCopy) {
      for (size_t i = 0; i < Copy->getTerminator()->getNumSuccessors(); i++) {
        if (BasicBlocksMapping.find(Copy->getTerminator()->getSuccessor(i)) != BasicBlocksMapping.end()) {
          Copy->getTerminator()->setSuccessor(i, BasicBlocksMapping[Copy->getTerminator()->getSuccessor(i)]);
        }
      }
    }

    BasicBlock *CompareBasicBlock = findCompareBasicBlock(LoopBasicBlocksCopy);
    LoopBasicBlocksCopy.front()->getTerminator()->setSuccessor(0, CompareBasicBlock->getTerminator()->getSuccessor(1));
    CompareBasicBlock->getTerminator()->getSuccessor(0)->eraseFromParent();
    CompareBasicBlock->eraseFromParent();

    return LoopBasicBlocksCopy.front();
  }

  void unswitch(Loop *L)
  {
    BasicBlock *CompareBlock = findCompareBasicBlock(LoopBasicBlocks);
    BasicBlock *CopyBlock = copyBlock(CompareBlock);
    L->getLoopPreheader()->getTerminator()->setSuccessor(0, CopyBlock);
    CopyBlock->getTerminator()->setSuccessor(0, LoopBasicBlocks.front());

    BasicBlock *LoopCopy = copyLoop(L);
    CopyBlock->getTerminator()->setSuccessor(1, LoopCopy);

    LoopBasicBlocks.front()->getTerminator()->setSuccessor(0, CompareBlock->getTerminator()->getSuccessor(0));
    CompareBlock->getTerminator()->getSuccessor(1)->eraseFromParent();
    CompareBlock->eraseFromParent();
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    mapVariables(L);
    LoopBasicBlocks = L->getBlocksVector();
    if (canUnswitch()) {
      unswitch(L);
    }
    return true;
  }
}; // end of struct OurLoopUnswitchingPass
}  // end of anonymous namespace

// The address of ID, not its value, identifies the pass to the legacy PM.
char OurLoopUnswitchingPass::ID = 0;

// Registers the pass with the legacy pass manager under the command-line
// name "loop-unswitching" (empty description). It is a transformation that
// does not only look at the CFG.
static RegisterPass<OurLoopUnswitchingPass>
    X("loop-unswitching", "", /*CFGOnly=*/false, /*is_analysis=*/false);