LLVM API Documentation

ScalarEvolution.cpp

Go to the documentation of this file.
00001 //===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
00002 //
00003 //                     The LLVM Compiler Infrastructure
00004 //
00005 // This file is distributed under the University of Illinois Open Source
00006 // License. See LICENSE.TXT for details.
00007 //
00008 //===----------------------------------------------------------------------===//
00009 //
00010 // This file contains the implementation of the scalar evolution analysis
00011 // engine, which is used primarily to analyze expressions involving induction
00012 // variables in loops.
00013 //
00014 // There are several aspects to this library.  First is the representation of
00015 // scalar expressions, which are represented as subclasses of the SCEV class.
00016 // These classes are used to represent certain types of subexpressions that we
00017 // can handle.  These classes are reference counted, managed by the SCEVHandle
00018 // class.  We only create one SCEV of a particular shape, so pointer-comparisons
00019 // for equality are legal.
00020 //
00021 // One important aspect of the SCEV objects is that they are never cyclic, even
00022 // if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
00023 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
00024 // recurrence) then we represent it directly as a recurrence node, otherwise we
00025 // represent it as a SCEVUnknown node.
00026 //
00027 // In addition to being able to represent expressions of various types, we also
00028 // have folders that are used to build the *canonical* representation for a
00029 // particular expression.  These folders are capable of using a variety of
00030 // rewrite rules to simplify the expressions.
00031 //
00032 // Once the folders are defined, we can implement the more interesting
00033 // higher-level code, such as the code that recognizes PHI nodes of various
00034 // types, computes the execution count of a loop, etc.
00035 //
00036 // TODO: We should use these routines and value representations to implement
00037 // dependence analysis!
00038 //
00039 //===----------------------------------------------------------------------===//
00040 //
00041 // There are several good references for the techniques used in this analysis.
00042 //
00043 //  Chains of recurrences -- a method to expedite the evaluation
00044 //  of closed-form functions
00045 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
00046 //
00047 //  On computational properties of chains of recurrences
00048 //  Eugene V. Zima
00049 //
00050 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
00051 //  Robert A. van Engelen
00052 //
00053 //  Efficient Symbolic Analysis for Optimizing Compilers
00054 //  Robert A. van Engelen
00055 //
00056 //  Using the chains of recurrences algebra for data dependence testing and
00057 //  induction variable substitution
00058 //  MS Thesis, Johnie Birch
00059 //
00060 //===----------------------------------------------------------------------===//
00061 
00062 #define DEBUG_TYPE "scalar-evolution"
00063 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
00064 #include "llvm/Constants.h"
00065 #include "llvm/DerivedTypes.h"
00066 #include "llvm/GlobalVariable.h"
00067 #include "llvm/Instructions.h"
00068 #include "llvm/Analysis/ConstantFolding.h"
00069 #include "llvm/Analysis/LoopInfo.h"
00070 #include "llvm/Assembly/Writer.h"
00071 #include "llvm/Transforms/Scalar.h"
00072 #include "llvm/Support/CFG.h"
00073 #include "llvm/Support/CommandLine.h"
00074 #include "llvm/Support/Compiler.h"
00075 #include "llvm/Support/ConstantRange.h"
00076 #include "llvm/Support/InstIterator.h"
00077 #include "llvm/Support/ManagedStatic.h"
00078 #include "llvm/Support/MathExtras.h"
00079 #include "llvm/Support/Streams.h"
00080 #include "llvm/ADT/Statistic.h"
00081 #include <ostream>
00082 #include <algorithm>
00083 #include <cmath>
00084 using namespace llvm;
00085 
// Counters declared with the STATISTIC macro (llvm/ADT/Statistic.h) that
// record how loop trip counts were (or were not) computed by this analysis.
STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

// Hidden command-line knob (-scalar-evolution-max-iterations, default 100):
// the maximum number of iterations SCEV will symbolically execute when trying
// to compute the trip count of a constant-derived loop by brute force.
static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant derived loop"),
                        cl::init(100));

// Register the pass under the name "scalar-evolution".
// NOTE(review): the trailing (false, true) arguments are presumably the
// "CFG only" and "is analysis" flags of RegisterPass — confirm against the
// PassSupport.h of this LLVM version.
static RegisterPass<ScalarEvolution>
R("scalar-evolution", "Scalar Evolution Analysis", false, true);
// Pass identifier object for ScalarEvolution.
char ScalarEvolution::ID = 0;
00104 
00105 //===----------------------------------------------------------------------===//
00106 //                           SCEV class definitions
00107 //===----------------------------------------------------------------------===//
00108 
00109 //===----------------------------------------------------------------------===//
00110 // Implementation of the SCEV class.
00111 //
00112 SCEV::~SCEV() {}
00113 void SCEV::dump() const {
00114   print(cerr);
00115   cerr << '\n';
00116 }
00117 
00118 uint32_t SCEV::getBitWidth() const {
00119   if (const IntegerType* ITy = dyn_cast<IntegerType>(getType()))
00120     return ITy->getBitWidth();
00121   return 0;
00122 }
00123 
00124 bool SCEV::isZero() const {
00125   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
00126     return SC->getValue()->isZero();
00127   return false;
00128 }
00129 
00130 
00131 SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {}
00132 
00133 bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
00134   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
00135   return false;
00136 }
00137 
00138 const Type *SCEVCouldNotCompute::getType() const {
00139   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
00140   return 0;
00141 }
00142 
00143 bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
00144   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
00145   return false;
00146 }
00147 
00148 SCEVHandle SCEVCouldNotCompute::
00149 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
00150                                   const SCEVHandle &Conc,
00151                                   ScalarEvolution &SE) const {
00152   return this;
00153 }
00154 
00155 void SCEVCouldNotCompute::print(std::ostream &OS) const {
00156   OS << "***COULDNOTCOMPUTE***";
00157 }
00158 
00159 bool SCEVCouldNotCompute::classof(const SCEV *S) {
00160   return S->getSCEVType() == scCouldNotCompute;
00161 }
00162 
00163 
00164 // SCEVConstants - Only allow the creation of one SCEVConstant for any
00165 // particular value.  Don't use a SCEVHandle here, or else the object will
00166 // never be deleted!
00167 static ManagedStatic<std::map<ConstantInt*, SCEVConstant*> > SCEVConstants;
00168 
00169 
00170 SCEVConstant::~SCEVConstant() {
00171   SCEVConstants->erase(V);
00172 }
00173 
00174 SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) {
00175   SCEVConstant *&R = (*SCEVConstants)[V];
00176   if (R == 0) R = new SCEVConstant(V);
00177   return R;
00178 }
00179 
00180 SCEVHandle ScalarEvolution::getConstant(const APInt& Val) {
00181   return getConstant(ConstantInt::get(Val));
00182 }
00183 
00184 const Type *SCEVConstant::getType() const { return V->getType(); }
00185 
00186 void SCEVConstant::print(std::ostream &OS) const {
00187   WriteAsOperand(OS, V, false);
00188 }
00189 
00190 // SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any
00191 // particular input.  Don't use a SCEVHandle here, or else the object will
00192 // never be deleted!
00193 static ManagedStatic<std::map<std::pair<SCEV*, const Type*>, 
00194                      SCEVTruncateExpr*> > SCEVTruncates;
00195 
00196 SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty)
00197   : SCEV(scTruncate), Op(op), Ty(ty) {
00198   assert(Op->getType()->isInteger() && Ty->isInteger() &&
00199          "Cannot truncate non-integer value!");
00200   assert(Op->getType()->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()
00201          && "This is not a truncating conversion!");
00202 }
00203 
00204 SCEVTruncateExpr::~SCEVTruncateExpr() {
00205   SCEVTruncates->erase(std::make_pair(Op, Ty));
00206 }
00207 
00208 void SCEVTruncateExpr::print(std::ostream &OS) const {
00209   OS << "(truncate " << *Op << " to " << *Ty << ")";
00210 }
00211 
00212 // SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any
00213 // particular input.  Don't use a SCEVHandle here, or else the object will never
00214 // be deleted!
00215 static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
00216                      SCEVZeroExtendExpr*> > SCEVZeroExtends;
00217 
00218 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty)
00219   : SCEV(scZeroExtend), Op(op), Ty(ty) {
00220   assert(Op->getType()->isInteger() && Ty->isInteger() &&
00221          "Cannot zero extend non-integer value!");
00222   assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
00223          && "This is not an extending conversion!");
00224 }
00225 
00226 SCEVZeroExtendExpr::~SCEVZeroExtendExpr() {
00227   SCEVZeroExtends->erase(std::make_pair(Op, Ty));
00228 }
00229 
00230 void SCEVZeroExtendExpr::print(std::ostream &OS) const {
00231   OS << "(zeroextend " << *Op << " to " << *Ty << ")";
00232 }
00233 
00234 // SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
00235 // particular input.  Don't use a SCEVHandle here, or else the object will never
00236 // be deleted!
00237 static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
00238                      SCEVSignExtendExpr*> > SCEVSignExtends;
00239 
00240 SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
00241   : SCEV(scSignExtend), Op(op), Ty(ty) {
00242   assert(Op->getType()->isInteger() && Ty->isInteger() &&
00243          "Cannot sign extend non-integer value!");
00244   assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
00245          && "This is not an extending conversion!");
00246 }
00247 
00248 SCEVSignExtendExpr::~SCEVSignExtendExpr() {
00249   SCEVSignExtends->erase(std::make_pair(Op, Ty));
00250 }
00251 
00252 void SCEVSignExtendExpr::print(std::ostream &OS) const {
00253   OS << "(signextend " << *Op << " to " << *Ty << ")";
00254 }
00255 
00256 // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
00257 // particular input.  Don't use a SCEVHandle here, or else the object will never
00258 // be deleted!
00259 static ManagedStatic<std::map<std::pair<unsigned, std::vector<SCEV*> >,
00260                      SCEVCommutativeExpr*> > SCEVCommExprs;
00261 
00262 SCEVCommutativeExpr::~SCEVCommutativeExpr() {
00263   SCEVCommExprs->erase(std::make_pair(getSCEVType(),
00264                                       std::vector<SCEV*>(Operands.begin(),
00265                                                          Operands.end())));
00266 }
00267 
00268 void SCEVCommutativeExpr::print(std::ostream &OS) const {
00269   assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
00270   const char *OpStr = getOperationStr();
00271   OS << "(" << *Operands[0];
00272   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
00273     OS << OpStr << *Operands[i];
00274   OS << ")";
00275 }
00276 
00277 SCEVHandle SCEVCommutativeExpr::
00278 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
00279                                   const SCEVHandle &Conc,
00280                                   ScalarEvolution &SE) const {
00281   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
00282     SCEVHandle H =
00283       getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
00284     if (H != getOperand(i)) {
00285       std::vector<SCEVHandle> NewOps;
00286       NewOps.reserve(getNumOperands());
00287       for (unsigned j = 0; j != i; ++j)
00288         NewOps.push_back(getOperand(j));
00289       NewOps.push_back(H);
00290       for (++i; i != e; ++i)
00291         NewOps.push_back(getOperand(i)->
00292                          replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
00293 
00294       if (isa<SCEVAddExpr>(this))
00295         return SE.getAddExpr(NewOps);
00296       else if (isa<SCEVMulExpr>(this))
00297         return SE.getMulExpr(NewOps);
00298       else if (isa<SCEVSMaxExpr>(this))
00299         return SE.getSMaxExpr(NewOps);
00300       else if (isa<SCEVUMaxExpr>(this))
00301         return SE.getUMaxExpr(NewOps);
00302       else
00303         assert(0 && "Unknown commutative expr!");
00304     }
00305   }
00306   return this;
00307 }
00308 
00309 
00310 // SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
00311 // input.  Don't use a SCEVHandle here, or else the object will never be
00312 // deleted!
00313 static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>, 
00314                      SCEVUDivExpr*> > SCEVUDivs;
00315 
00316 SCEVUDivExpr::~SCEVUDivExpr() {
00317   SCEVUDivs->erase(std::make_pair(LHS, RHS));
00318 }
00319 
00320 void SCEVUDivExpr::print(std::ostream &OS) const {
00321   OS << "(" << *LHS << " /u " << *RHS << ")";
00322 }
00323 
00324 const Type *SCEVUDivExpr::getType() const {
00325   return LHS->getType();
00326 }
00327 
00328 
00329 // SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular
00330 // input.  Don't use a SCEVHandle here, or else the object will never be
00331 // deleted!
00332 static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>, 
00333                      SCEVSDivExpr*> > SCEVSDivs;
00334 
00335 SCEVSDivExpr::~SCEVSDivExpr() {
00336   SCEVSDivs->erase(std::make_pair(LHS, RHS));
00337 }
00338 
00339 void SCEVSDivExpr::print(std::ostream &OS) const {
00340   OS << "(" << *LHS << " /s " << *RHS << ")";
00341 }
00342 
00343 const Type *SCEVSDivExpr::getType() const {
00344   return LHS->getType();
00345 }
00346 
00347 
00348 // SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any
00349 // particular input.  Don't use a SCEVHandle here, or else the object will never
00350 // be deleted!
00351 static ManagedStatic<std::map<std::pair<const Loop *, std::vector<SCEV*> >,
00352                      SCEVAddRecExpr*> > SCEVAddRecExprs;
00353 
00354 SCEVAddRecExpr::~SCEVAddRecExpr() {
00355   SCEVAddRecExprs->erase(std::make_pair(L,
00356                                         std::vector<SCEV*>(Operands.begin(),
00357                                                            Operands.end())));
00358 }
00359 
00360 SCEVHandle SCEVAddRecExpr::
00361 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
00362                                   const SCEVHandle &Conc,
00363                                   ScalarEvolution &SE) const {
00364   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
00365     SCEVHandle H =
00366       getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
00367     if (H != getOperand(i)) {
00368       std::vector<SCEVHandle> NewOps;
00369       NewOps.reserve(getNumOperands());
00370       for (unsigned j = 0; j != i; ++j)
00371         NewOps.push_back(getOperand(j));
00372       NewOps.push_back(H);
00373       for (++i; i != e; ++i)
00374         NewOps.push_back(getOperand(i)->
00375                          replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
00376 
00377       return SE.getAddRecExpr(NewOps, L);
00378     }
00379   }
00380   return this;
00381 }
00382 
00383 
00384 bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
00385   // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't
00386   // contain L and if the start is invariant.
00387   return !QueryLoop->contains(L->getHeader()) &&
00388          getOperand(0)->isLoopInvariant(QueryLoop);
00389 }
00390 
00391 
00392 void SCEVAddRecExpr::print(std::ostream &OS) const {
00393   OS << "{" << *Operands[0];
00394   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
00395     OS << ",+," << *Operands[i];
00396   OS << "}<" << L->getHeader()->getName() + ">";
00397 }
00398 
00399 // SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular
00400 // value.  Don't use a SCEVHandle here, or else the object will never be
00401 // deleted!
00402 static ManagedStatic<std::map<Value*, SCEVUnknown*> > SCEVUnknowns;
00403 
00404 SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); }
00405 
00406 bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
00407   // All non-instruction values are loop invariant.  All instructions are loop
00408   // invariant if they are not contained in the specified loop.
00409   if (Instruction *I = dyn_cast<Instruction>(V))
00410     return !L->contains(I->getParent());
00411   return true;
00412 }
00413 
00414 const Type *SCEVUnknown::getType() const {
00415   return V->getType();
00416 }
00417 
00418 void SCEVUnknown::print(std::ostream &OS) const {
00419   WriteAsOperand(OS, V, false);
00420 }
00421 
00422 //===----------------------------------------------------------------------===//
00423 //                               SCEV Utilities
00424 //===----------------------------------------------------------------------===//
00425 
00426 namespace {
00427   /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
00428   /// than the complexity of the RHS.  This comparator is used to canonicalize
00429   /// expressions.
00430   struct VISIBILITY_HIDDEN SCEVComplexityCompare {
00431     bool operator()(const SCEV *LHS, const SCEV *RHS) const {
00432       return LHS->getSCEVType() < RHS->getSCEVType();
00433     }
00434   };
00435 }
00436 
/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically non-decreasing.
///
/// "Complexity" is the SCEV kind (getSCEVType()), as compared by
/// SCEVComplexityCompare.  Duplicates are detected by pointer identity, which
/// coincides with value identity because only one SCEV of a given shape is
/// ever created.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.  (std::sort is not stable, so the order within one
/// complexity class depends on the input order, but never on addresses.)
///
/// @param Ops  The list to reorder in place.
static void GroupByComplexity(std::vector<SCEVHandle> &Ops) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    if (SCEVComplexityCompare()(Ops[1], Ops[0]))
      std::swap(Ops[0], Ops[1]);
    return;
  }

  // Do the rough sort by complexity.
  std::sort(Ops.begin(), Ops.end(), SCEVComplexityCompare());

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  //
  // The outer loop stops at i == e-2.  At that point only the pair
  // (e-2, e-1) remains, and those two elements are already adjacent, so
  // there is nothing left to move.  e >= 3 here, so e-2 does not underflow.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.  Only the run of equal-complexity elements after i
    // needs scanning, since the vector is sorted by complexity.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        // Advance i onto the moved copy.  S is unchanged, so the scan keeps
        // collecting further copies of S right behind it.
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}
00480 
00481 
00482 
00483 //===----------------------------------------------------------------------===//
00484 //                      Simple SCEV method implementations
00485 //===----------------------------------------------------------------------===//
00486 
00487 /// getIntegerSCEV - Given an integer or FP type, create a constant for the
00488 /// specified signed integer value and return a SCEV for the constant.
00489 SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
00490   Constant *C;
00491   if (Val == 0)
00492     C = Constant::getNullValue(Ty);
00493   else if (Ty->isFloatingPoint())
00494     C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : 
00495                                 APFloat::IEEEdouble, Val));
00496   else 
00497     C = ConstantInt::get(Ty, Val);
00498   return getUnknown(C);
00499 }
00500 
00501 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
00502 ///
00503 SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) {
00504   if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
00505     return getUnknown(ConstantExpr::getNeg(VC->getValue()));
00506 
00507   return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(V->getType())));
00508 }
00509 
00510 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
00511 SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) {
00512   if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
00513     return getUnknown(ConstantExpr::getNot(VC->getValue()));
00514 
00515   SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(V->getType()));
00516   return getMinusSCEV(AllOnes, V);
00517 }
00518 
00519 /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
00520 ///
00521 SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
00522                                          const SCEVHandle &RHS) {
00523   // X - Y --> X + -Y
00524   return getAddExpr(LHS, getNegativeSCEV(RHS));
00525 }
00526 
00527 
00528 /// BinomialCoefficient - Compute BC(It, K).  The result has width W.
00529 // Assume, K > 0.
00530 static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
00531                                       ScalarEvolution &SE,
00532                                       const IntegerType* ResultTy) {
00533   // Handle the simplest case efficiently.
00534   if (K == 1)
00535     return SE.getTruncateOrZeroExtend(It, ResultTy);
00536 
00537   // We are using the following formula for BC(It, K):
00538   //
00539   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
00540   //
00541   // Suppose, W is the bitwidth of the return value.  We must be prepared for
00542   // overflow.  Hence, we must assure that the result of our computation is
00543   // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
00544   // safe in modular arithmetic.
00545   //
00546   // However, this code doesn't use exactly that formula; the formula it uses
00547   // is something like the following, where T is the number of factors of 2 in 
00548   // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
00549   // exponentiation:
00550   //
00551   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
00552   //
00553   // This formula is trivially equivalent to the previous formula.  However,
00554   // this formula can be implemented much more efficiently.  The trick is that
00555   // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
00556   // arithmetic.  To do exact division in modular arithmetic, all we have
00557   // to do is multiply by the inverse.  Therefore, this step can be done at
00558   // width W.
00559   // 
00560   // The next issue is how to safely do the division by 2^T.  The way this
00561   // is done is by doing the multiplication step at a width of at least W + T
00562   // bits.  This way, the bottom W+T bits of the product are accurate. Then,
00563   // when we perform the division by 2^T (which is equivalent to a right shift
00564   // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
00565   // truncated out after the division by 2^T.
00566   //
00567   // In comparison to just directly using the first formula, this technique
00568   // is much more efficient; using the first formula requires W * K bits,
00569   // but this formula less than W + K bits. Also, the first formula requires
00570   // a division step, whereas this formula only requires multiplies and shifts.
00571   //
00572   // It doesn't matter whether the subtraction step is done in the calculation
00573   // width or the input iteration count's width; if the subtraction overflows,
00574   // the result must be zero anyway.  We prefer here to do it in the width of
00575   // the induction variable because it helps a lot for certain cases; CodeGen
00576   // isn't smart enough to ignore the overflow, which leads to much less
00577   // efficient code if the width of the subtraction is wider than the native
00578   // register width.
00579   //
00580   // (It's possible to not widen at all by pulling out factors of 2 before
00581   // the multiplication; for example, K=2 can be calculated as
00582   // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
00583   // extra arithmetic, so it's not an obvious win, and it gets
00584   // much more complicated for K > 3.)
00585 
00586   // Protection from insane SCEVs; this bound is conservative,
00587   // but it probably doesn't matter.
00588   if (K > 1000)
00589     return new SCEVCouldNotCompute();
00590 
00591   unsigned W = ResultTy->getBitWidth();
00592 
00593   // Calculate K! / 2^T and T; we divide out the factors of two before
00594   // multiplying for calculating K! / 2^T to avoid overflow.
00595   // Other overflow doesn't matter because we only care about the bottom
00596   // W bits of the result.
00597   APInt OddFactorial(W, 1);
00598   unsigned T = 1;
00599   for (unsigned i = 3; i <= K; ++i) {
00600     APInt Mult(W, i);
00601     unsigned TwoFactors = Mult.countTrailingZeros();
00602     T += TwoFactors;
00603     Mult = Mult.lshr(TwoFactors);
00604     OddFactorial *= Mult;
00605   }
00606 
00607   // We need at least W + T bits for the multiplication step
00608   // FIXME: A temporary hack; we round up the bitwidths
00609   // to the nearest power of 2 to be nice to the code generator.
00610   unsigned CalculationBits = 1U << Log2_32_Ceil(W + T);
00611   // FIXME: Temporary hack to avoid generating integers that are too wide.
00612   // Although, it's not completely clear how to determine how much
00613   // widening is safe; for example, on X86, we can't really widen
00614   // beyond 64 because we need to be able to do multiplication
00615   // that's CalculationBits wide, but on X86-64, we can safely widen up to
00616   // 128 bits.
00617   if (CalculationBits > 64)
00618     return new SCEVCouldNotCompute();
00619 
00620   // Calcuate 2^T, at width T+W.
00621   APInt DivFactor = APInt(CalculationBits, 1).shl(T);
00622 
00623   // Calculate the multiplicative inverse of K! / 2^T;
00624   // this multiplication factor will perform the exact division by
00625   // K! / 2^T.
00626   APInt Mod = APInt::getSignedMinValue(W+1);
00627   APInt MultiplyFactor = OddFactorial.zext(W+1);
00628   MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
00629   MultiplyFactor = MultiplyFactor.trunc(W);
00630 
00631   // Calculate the product, at width T+W
00632   const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
00633   SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
00634   for (unsigned i = 1; i != K; ++i) {
00635     SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
00636     Dividend = SE.getMulExpr(Dividend,
00637                              SE.getTruncateOrZeroExtend(S, CalculationTy));
00638   }
00639 
00640   // Divide by 2^T
00641   SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
00642 
00643   // Truncate the result, and divide by K! / 2^T.
00644 
00645   return SE.getMulExpr(SE.getConstant(MultiplyFactor),
00646                        SE.getTruncateOrZeroExtend(DivResult, ResultTy));
00647 }
00648 
00649 /// evaluateAtIteration - Return the value of this chain of recurrences at
00650 /// the specified iteration number.  We can evaluate this recurrence by
00651 /// multiplying each element in the chain by the binomial coefficient
00652 /// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
00653 ///
00654 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
00655 ///
00656 /// where BC(It, k) stands for binomial coefficient.
00657 ///
00658 SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
00659                                                ScalarEvolution &SE) const {
00660   SCEVHandle Result = getStart();
00661   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
00662     // The computation is correct in the face of overflow provided that the
00663     // multiplication is performed _after_ the evaluation of the binomial
00664     // coefficient.
00665     SCEVHandle Coeff = BinomialCoefficient(It, i, SE,
00666                                            cast<IntegerType>(getType()));
00667     if (isa<SCEVCouldNotCompute>(Coeff))
00668       return Coeff;
00669 
00670     Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
00671   }
00672   return Result;
00673 }
00674 
00675 //===----------------------------------------------------------------------===//
00676 //                    SCEV Expression folder implementations
00677 //===----------------------------------------------------------------------===//
00678 
00679 SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, const Type *Ty) {
00680   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
00681     return getUnknown(
00682         ConstantExpr::getTrunc(SC->getValue(), Ty));
00683 
00684   // If the input value is a chrec scev made out of constants, truncate
00685   // all of the constants.
00686   if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
00687     std::vector<SCEVHandle> Operands;
00688     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
00689       // FIXME: This should allow truncation of other expression types!
00690       if (isa<SCEVConstant>(AddRec->getOperand(i)))
00691         Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
00692       else
00693         break;
00694     if (Operands.size() == AddRec->getNumOperands())
00695       return getAddRecExpr(Operands, AddRec->getLoop());
00696   }
00697 
00698   SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)];
00699   if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty);
00700   return Result;
00701 }
00702 
00703 SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, const Type *Ty) {
00704   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
00705     return getUnknown(
00706         ConstantExpr::getZExt(SC->getValue(), Ty));
00707 
00708   // FIXME: If the input value is a chrec scev, and we can prove that the value
00709   // did not overflow the old, smaller, value, we can zero extend all of the
00710   // operands (often constants).  This would allow analysis of something like
00711   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
00712 
00713   SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)];
00714   if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty);
00715   return Result;
00716 }
00717 
00718 SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type *Ty) {
00719   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
00720     return getUnknown(
00721         ConstantExpr::getSExt(SC->getValue(), Ty));
00722 
00723   // FIXME: If the input value is a chrec scev, and we can prove that the value
00724   // did not overflow the old, smaller, value, we can sign extend all of the
00725   // operands (often constants).  This would allow analysis of something like
00726   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
00727 
00728   SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
00729   if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
00730   return Result;
00731 }
00732 
00733 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion
00734 /// of the input value to the specified type.  If the type must be
00735 /// extended, it is zero extended.
00736 SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V,
00737                                                     const Type *Ty) {
00738   const Type *SrcTy = V->getType();
00739   assert(SrcTy->isInteger() && Ty->isInteger() &&
00740          "Cannot truncate or zero extend with non-integer arguments!");
00741   if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits())
00742     return V;  // No conversion
00743   if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits())
00744     return getTruncateExpr(V, Ty);
00745   return getZeroExtendExpr(V, Ty);
00746 }
00747 
00748 // get - Get a canonical add expression, or something simpler if possible.
00749 SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) {
00750   assert(!Ops.empty() && "Cannot get empty add!");
00751   if (Ops.size() == 1) return Ops[0];
00752 
00753   // Sort by complexity, this groups all similar expression types together.
00754   GroupByComplexity(Ops);
00755 
00756   // If there are any constants, fold them together.
00757   unsigned Idx = 0;
00758   if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
00759     ++Idx;
00760     assert(Idx < Ops.size());
00761     while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
00762       // We found two constants, fold them together!
00763       ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + 
00764                                            RHSC->getValue()->getValue());
00765       Ops[0] = getConstant(Fold);
00766       Ops.erase(Ops.begin()+1);  // Erase the folded element
00767       if (Ops.size() == 1) return Ops[0];
00768       LHSC = cast<SCEVConstant>(Ops[0]);
00769     }
00770 
00771     // If we are left with a constant zero being added, strip it off.
00772     if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
00773       Ops.erase(Ops.begin());
00774       --Idx;
00775     }
00776   }
00777 
00778   if (Ops.size() == 1) return Ops[0];
00779 
00780   // Okay, check to see if the same value occurs in the operand list twice.  If
00781   // so, merge them together into an multiply expression.  Since we sorted the
00782   // list, these values are required to be adjacent.
00783   const Type *Ty = Ops[0]->getType();
00784   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
00785     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
00786       // Found a match, merge the two values into a multiply, and add any
00787       // remaining values to the result.
00788       SCEVHandle Two = getIntegerSCEV(2, Ty);
00789       SCEVHandle Mul = getMulExpr(Ops[i], Two);
00790       if (Ops.size() == 2)
00791         return Mul;
00792       Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
00793       Ops.push_back(Mul);
00794       return getAddExpr(Ops);
00795     }
00796 
00797   // Now we know the first non-constant operand.  Skip past any cast SCEVs.
00798   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
00799     ++Idx;
00800 
00801   // If there are add operands they would be next.
00802   if (Idx < Ops.size()) {
00803     bool DeletedAdd = false;
00804     while (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
00805       // If we have an add, expand the add operands onto the end of the operands
00806       // list.
00807       Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
00808       Ops.erase(Ops.begin()+Idx);
00809       DeletedAdd = true;
00810     }
00811 
00812     // If we deleted at least one add, we added operands to the end of the list,
00813     // and they are not necessarily sorted.  Recurse to resort and resimplify
00814     // any operands we just aquired.
00815     if (DeletedAdd)
00816       return getAddExpr(Ops);
00817   }
00818 
00819   // Skip over the add expression until we get to a multiply.
00820   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
00821     ++Idx;
00822 
00823   // If we are adding something to a multiply expression, make sure the
00824   // something is not already an operand of the multiply.  If so, merge it into
00825   // the multiply.
00826   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
00827     SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
00828     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
00829       SCEV *MulOpSCEV = Mul->getOperand(MulOp);
00830       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
00831         if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) {
00832           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
00833           SCEVHandle InnerMul = Mul->getOperand(MulOp == 0);
00834           if (Mul->getNumOperands() != 2) {
00835             // If the multiply has more than two operands, we must get the
00836             // Y*Z term.
00837             std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
00838             MulOps.erase(MulOps.begin()+MulOp);
00839             InnerMul = getMulExpr(MulOps);
00840           }
00841           SCEVHandle One = getIntegerSCEV(1, Ty);
00842           SCEVHandle AddOne = getAddExpr(InnerMul, One);
00843           SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]);
00844           if (Ops.size() == 2) return OuterMul;
00845           if (AddOp < Idx) {
00846             Ops.erase(Ops.begin()+AddOp);
00847             Ops.erase(Ops.begin()+Idx-1);
00848           } else {
00849             Ops.erase(Ops.begin()+Idx);
00850             Ops.erase(Ops.begin()+AddOp-1);
00851           }
00852           Ops.push_back(OuterMul);
00853           return getAddExpr(Ops);
00854         }
00855 
00856       // Check this multiply against other multiplies being added together.
00857       for (unsigned OtherMulIdx = Idx+1;
00858            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
00859            ++OtherMulIdx) {
00860         SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
00861         // If MulOp occurs in OtherMul, we can fold the two multiplies
00862         // together.
00863         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
00864              OMulOp != e; ++OMulOp)
00865           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
00866             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
00867             SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0);
00868             if (Mul->getNumOperands() != 2) {
00869               std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
00870               MulOps.erase(MulOps.begin()+MulOp);
00871               InnerMul1 = getMulExpr(MulOps);
00872             }
00873             SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0);
00874             if (OtherMul->getNumOperands() != 2) {
00875               std::vector<SCEVHandle> MulOps(OtherMul->op_begin(),
00876                                              OtherMul->op_end());
00877               MulOps.erase(MulOps.begin()+OMulOp);
00878               InnerMul2 = getMulExpr(MulOps);
00879             }
00880             SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
00881             SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
00882             if (Ops.size() == 2) return OuterMul;
00883             Ops.erase(Ops.begin()+Idx);
00884             Ops.erase(Ops.begin()+OtherMulIdx-1);
00885             Ops.push_back(OuterMul);
00886             return getAddExpr(Ops);
00887           }
00888       }
00889     }
00890   }
00891 
00892   // If there are any add recurrences in the operands list, see if any other
00893   // added values are loop invariant.  If so, we can fold them into the
00894   // recurrence.
00895   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
00896     ++Idx;
00897 
00898   // Scan over all recurrences, trying to fold loop invariants into them.
00899   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
00900     // Scan all of the other operands to this add and add them to the vector if
00901     // they are loop invariant w.r.t. the recurrence.
00902     std::vector<SCEVHandle> LIOps;
00903     SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
00904     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
00905       if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
00906         LIOps.push_back(Ops[i]);
00907         Ops.erase(Ops.begin()+i);
00908         --i; --e;
00909       }
00910 
00911     // If we found some loop invariants, fold them into the recurrence.
00912     if (!LIOps.empty()) {
00913       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
00914       LIOps.push_back(AddRec->getStart());
00915 
00916       std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end());
00917       AddRecOps[0] = getAddExpr(LIOps);
00918 
00919       SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->