LLVM API Documentation

Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

LowerSwitch.cpp

Go to the documentation of this file.
00001 //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
00002 //
00003 //                     The LLVM Compiler Infrastructure
00004 //
00005 // This file is distributed under the University of Illinois Open Source
00006 // License. See LICENSE.TXT for details.
00007 //
00008 //===----------------------------------------------------------------------===//
00009 //
00010 // The LowerSwitch transformation rewrites switch instructions with a sequence
00011 // of branches, which allows targets to get away with not implementing the
00012 // switch instruction until it is convenient.
00013 //
00014 //===----------------------------------------------------------------------===//
00015 
00016 #include "llvm/Transforms/Scalar.h"
00017 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
00018 #include "llvm/Constants.h"
00019 #include "llvm/Function.h"
00020 #include "llvm/Instructions.h"
00021 #include "llvm/Pass.h"
00022 #include "llvm/ADT/STLExtras.h"
00023 #include "llvm/Support/Debug.h"
00024 #include "llvm/Support/Compiler.h"
00025 #include "llvm/Support/raw_ostream.h"
00026 #include <algorithm>
00027 using namespace llvm;
00028 
00029 namespace {
00030   /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
00031   /// instructions.  Note that this cannot be a BasicBlock pass because it
00032   /// modifies the CFG!
00033   class VISIBILITY_HIDDEN LowerSwitch : public FunctionPass {
00034   public:
00035     static char ID; // Pass identification, replacement for typeid
00036     LowerSwitch() : FunctionPass((intptr_t) &ID) {} 
00037 
00038     virtual bool runOnFunction(Function &F);
00039     
00040     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
00041       // This is a cluster of orthogonal Transforms
00042       AU.addPreserved<UnifyFunctionExitNodes>();
00043       AU.addPreservedID(PromoteMemoryToRegisterID);
00044       AU.addPreservedID(LowerInvokePassID);
00045       AU.addPreservedID(LowerAllocationsID);
00046     }
00047 
00048     struct CaseRange {
00049       Constant* Low;
00050       Constant* High;
00051       BasicBlock* BB;
00052 
00053       CaseRange() : Low(0), High(0), BB(0) { }
00054       CaseRange(Constant* low, Constant* high, BasicBlock* bb) :
00055         Low(low), High(high), BB(bb) { }
00056     };
00057 
00058     typedef std::vector<CaseRange>           CaseVector;
00059     typedef std::vector<CaseRange>::iterator CaseItr;
00060   private:
00061     void processSwitchInst(SwitchInst *SI);
00062 
00063     BasicBlock* switchConvert(CaseItr Begin, CaseItr End, Value* Val,
00064                               BasicBlock* OrigBlock, BasicBlock* Default);
00065     BasicBlock* newLeafBlock(CaseRange& Leaf, Value* Val,
00066                              BasicBlock* OrigBlock, BasicBlock* Default);
00067     unsigned Clusterify(CaseVector& Cases, SwitchInst *SI);
00068   };
00069 
00070   /// The comparison function for sorting the switch case values in the vector.
00071   /// WARNING: Case ranges should be disjoint!
00072   struct CaseCmp {
00073     bool operator () (const LowerSwitch::CaseRange& C1,
00074                       const LowerSwitch::CaseRange& C2) {
00075 
00076       const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low);
00077       const ConstantInt* CI2 = cast<const ConstantInt>(C2.High);
00078       return CI1->getValue().slt(CI2->getValue());
00079     }
00080   };
00081 }
00082 
00083 char LowerSwitch::ID = 0;
00084 static RegisterPass<LowerSwitch>
00085 X("lowerswitch", "Lower SwitchInst's to branches");
00086 
00087 // Publically exposed interface to pass...
00088 const PassInfo *const llvm::LowerSwitchID = &X;
00089 // createLowerSwitchPass - Interface to this file...
00090 FunctionPass *llvm::createLowerSwitchPass() {
00091   return new LowerSwitch();
00092 }
00093 
00094 bool LowerSwitch::runOnFunction(Function &F) {
00095   bool Changed = false;
00096 
00097   for (Function::iterator I = F.begin(), E = F.end(); I != E; ) {
00098     BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks
00099 
00100     if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) {
00101       Changed = true;
00102       processSwitchInst(SI);
00103     }
00104   }
00105 
00106   return Changed;
00107 }
00108 
00109 // operator<< - Used for debugging purposes.
00110 //
00111 static std::ostream& operator<<(std::ostream &O,
00112                                 const LowerSwitch::CaseVector &C) {
00113   O << "[";
00114 
00115   for (LowerSwitch::CaseVector::const_iterator B = C.begin(),
00116          E = C.end(); B != E; ) {
00117     O << *B->Low << " -" << *B->High;
00118     if (++B != E) O << ", ";
00119   }
00120 
00121   return O << "]";
00122 }
00123 
00124 static OStream& operator<<(OStream &O, const LowerSwitch::CaseVector &C) {
00125   if (O.stream()) *O.stream() << C;
00126   return O;
00127 }
00128 
00129 // switchConvert - Convert the switch statement into a binary lookup of
00130 // the case values. The function recursively builds this tree.
00131 //
00132 BasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End,
00133                                        Value* Val, BasicBlock* OrigBlock,
00134                                        BasicBlock* Default)
00135 {
00136   unsigned Size = End - Begin;
00137 
00138   if (Size == 1)
00139     return newLeafBlock(*Begin, Val, OrigBlock, Default);
00140 
00141   unsigned Mid = Size / 2;
00142   std::vector<CaseRange> LHS(Begin, Begin + Mid);
00143   DOUT << "LHS: " << LHS << "\n";
00144   std::vector<CaseRange> RHS(Begin + Mid, End);
00145   DOUT << "RHS: " << RHS << "\n";
00146 
00147   CaseRange& Pivot = *(Begin + Mid);
00148   DEBUG(errs() << "Pivot ==> " 
00149                << cast<ConstantInt>(Pivot.Low)->getValue() << " -"
00150                << cast<ConstantInt>(Pivot.High)->getValue() << "\n";
00151         errs().flush());
00152 
00153   BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val,
00154                                       OrigBlock, Default);
00155   BasicBlock* RBranch = switchConvert(RHS.begin(), RHS.end(), Val,
00156                                       OrigBlock, Default);
00157 
00158   // Create a new node that checks if the value is < pivot. Go to the
00159   // left branch if it is and right branch if not.
00160   Function* F = OrigBlock->getParent();
00161   BasicBlock* NewNode = BasicBlock::Create("NodeBlock");
00162   Function::iterator FI = OrigBlock;
00163   F->getBasicBlockList().insert(++FI, NewNode);
00164 
00165   ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot");
00166   NewNode->getInstList().push_back(Comp);
00167   BranchInst::Create(LBranch, RBranch, Comp, NewNode);
00168   return NewNode;
00169 }
00170 
00171 // newLeafBlock - Create a new leaf block for the binary lookup tree. It
00172 // checks if the switch's value == the case's value. If not, then it
00173 // jumps to the default branch. At this point in the tree, the value
00174 // can't be another valid case value, so the jump to the "default" branch
00175 // is warranted.
00176 //
00177 BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val,
00178                                       BasicBlock* OrigBlock,
00179                                       BasicBlock* Default)
00180 {
00181   Function* F = OrigBlock->getParent();
00182   BasicBlock* NewLeaf = BasicBlock::Create("LeafBlock");
00183   Function::iterator FI = OrigBlock;
00184   F->getBasicBlockList().insert(++FI, NewLeaf);
00185 
00186   // Emit comparison
00187   ICmpInst* Comp = NULL;
00188   if (Leaf.Low == Leaf.High) {
00189     // Make the seteq instruction...
00190     Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val, Leaf.Low,
00191                         "SwitchLeaf", NewLeaf);
00192   } else {
00193     // Make range comparison
00194     if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) {
00195       // Val >= Min && Val <= Hi --> Val <= Hi
00196       Comp = new ICmpInst(ICmpInst::ICMP_SLE, Val, Leaf.High,
00197                           "SwitchLeaf", NewLeaf);
00198     } else if (cast<ConstantInt>(Leaf.Low)->isZero()) {
00199       // Val >= 0 && Val <= Hi --> Val <=u Hi
00200       Comp = new ICmpInst(ICmpInst::ICMP_ULE, Val, Leaf.High,
00201                           "SwitchLeaf", NewLeaf);      
00202     } else {
00203       // Emit V-Lo <=u Hi-Lo
00204       Constant* NegLo = ConstantExpr::getNeg(Leaf.Low);
00205       Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo,
00206                                                    Val->getName()+".off",
00207                                                    NewLeaf);
00208       Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High);
00209       Comp = new ICmpInst(ICmpInst::ICMP_ULE, Add, UpperBound,
00210                           "SwitchLeaf", NewLeaf);
00211     }
00212   }
00213 
00214   // Make the conditional branch...
00215   BasicBlock* Succ = Leaf.BB;
00216   BranchInst::Create(Succ, Default, Comp, NewLeaf);
00217 
00218   // If there were any PHI nodes in this successor, rewrite one entry
00219   // from OrigBlock to come from NewLeaf.
00220   for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
00221     PHINode* PN = cast<PHINode>(I);
00222     // Remove all but one incoming entries from the cluster
00223     uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() -
00224                      cast<ConstantInt>(Leaf.Low)->getSExtValue();    
00225     for (uint64_t j = 0; j < Range; ++j) {
00226       PN->removeIncomingValue(OrigBlock);
00227     }
00228     
00229     int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
00230     assert(BlockIdx != -1 && "Switch didn't go to this successor??");
00231     PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf);
00232   }
00233 
00234   return NewLeaf;
00235 }
00236 
00237 // Clusterify - Transform simple list of Cases into list of CaseRange's
00238 unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
00239   unsigned numCmps = 0;
00240 
00241   // Start with "simple" cases
00242   for (unsigned i = 1; i < SI->getNumSuccessors(); ++i)
00243     Cases.push_back(CaseRange(SI->getSuccessorValue(i),
00244                               SI->getSuccessorValue(i),
00245                               SI->getSuccessor(i)));
00246   std::sort(Cases.begin(), Cases.end(), CaseCmp());
00247 
00248   // Merge case into clusters
00249   if (Cases.size()>=2)
00250     for (CaseItr I=Cases.begin(), J=next(Cases.begin()); J!=Cases.end(); ) {
00251       int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue();
00252       int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue();
00253       BasicBlock* nextBB = J->BB;
00254       BasicBlock* currentBB = I->BB;
00255 
00256       // If the two neighboring cases go to the same destination, merge them
00257       // into a single case.
00258       if ((nextValue-currentValue==1) && (currentBB == nextBB)) {
00259         I->High = J->High;
00260         J = Cases.erase(J);
00261       } else {
00262         I = J++;
00263       }
00264     }
00265 
00266   for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) {
00267     if (I->Low != I->High)
00268       // A range counts double, since it requires two compares.
00269       ++numCmps;
00270   }
00271 
00272   return numCmps;
00273 }
00274 
00275 // processSwitchInst - Replace the specified switch instruction with a sequence
00276 // of chained if-then insts in a balanced binary search.
00277 //
00278 void LowerSwitch::processSwitchInst(SwitchInst *SI) {
00279   BasicBlock *CurBlock = SI->getParent();
00280   BasicBlock *OrigBlock = CurBlock;
00281   Function *F = CurBlock->getParent();
00282   Value *Val = SI->getOperand(0);  // The value we are switching on...
00283   BasicBlock* Default = SI->getDefaultDest();
00284 
00285   // If there is only the default destination, don't bother with the code below.
00286   if (SI->getNumOperands() == 2) {
00287     BranchInst::Create(SI->getDefaultDest(), CurBlock);
00288     CurBlock->getInstList().erase(SI);
00289     return;
00290   }
00291 
00292   // Create a new, empty default block so that the new hierarchy of
00293   // if-then statements go to this and the PHI nodes are happy.
00294   BasicBlock* NewDefault = BasicBlock::Create("NewDefault");
00295   F->getBasicBlockList().insert(Default, NewDefault);
00296 
00297   BranchInst::Create(Default, NewDefault);
00298 
00299   // If there is an entry in any PHI nodes for the default edge, make sure
00300   // to update them as well.
00301   for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) {
00302     PHINode *PN = cast<PHINode>(I);
00303     int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
00304     assert(BlockIdx != -1 && "Switch didn't go to this successor??");
00305     PN->setIncomingBlock((unsigned)BlockIdx, NewDefault);
00306   }
00307 
00308   // Prepare cases vector.
00309   CaseVector Cases;
00310   unsigned numCmps = Clusterify(Cases, SI);
00311 
00312   DOUT << "Clusterify finished. Total clusters: " << Cases.size()
00313        << ". Total compares: " << numCmps << "\n";
00314   DOUT << "Cases: " << Cases << "\n";
00315   
00316   BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val,
00317                                           OrigBlock, NewDefault);
00318 
00319   // Branch to our shiny new if-then stuff...
00320   BranchInst::Create(SwitchBlock, OrigBlock);
00321 
00322   // We are now done with the switch instruction, delete it.
00323   CurBlock->getInstList().erase(SI);
00324 }



This web site is hosted by the Computer Science Department at the University of Illinois at Urbana-Champaign.