199 #define DEBUG_TYPE "loop-predication" 201 STATISTIC(TotalConsidered,
"Number of guards considered");
202 STATISTIC(TotalWidened,
"Number of checks widened");
204 using namespace llvm;
222 cl::desc(
"scale factor for the latch probability. Value should be greater " 223 "than 1. Lower values are ignored"));
226 class LoopPredication {
235 : Pred(Pred), IV(IV), Limit(Limit) {}
238 dbgs() <<
"LoopICmp Pred = " << Pred <<
", IV = " << *IV
239 <<
", Limit = " << *Limit <<
"\n";
251 bool isSupportedStep(
const SCEV* Step);
261 bool CanExpand(
const SCEV* S);
282 bool isLoopProfitableToPredicate();
296 bool isSafeToTruncateWideIVType(
Type *RangeCheckType);
303 : SE(SE), BPI(BPI){};
304 bool runOnLoop(
Loop *L);
307 class LoopPredicationLegacyPass :
public LoopPass {
310 LoopPredicationLegacyPass() :
LoopPass(ID) {
322 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
324 getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
325 LoopPredication LP(SE, &BPI);
326 return LP.runOnLoop(L);
334 "Loop predication",
false,
false)
341 return new LoopPredicationLegacyPass();
349 Function *
F = L.getHeader()->getParent();
351 LoopPredication LP(&AR.
SE, BPI);
352 if (!LP.runOnLoop(&L))
362 if (isa<SCEVCouldNotCompute>(LHSS))
365 if (isa<SCEVCouldNotCompute>(RHSS))
379 return LoopICmp(Pred, AR, RHSS);
389 assert(Ty == RHS->
getType() &&
"expandCheck operands have different types?");
400 LoopPredication::generateLoopLatchCheck(
Type *RangeCheckType) {
402 auto *LatchType = LatchCheck.IV->getType();
403 if (RangeCheckType == LatchType)
408 if (!isSafeToTruncateWideIVType(RangeCheckType))
412 LoopICmp NewLatchCheck;
413 NewLatchCheck.Pred = LatchCheck.Pred;
416 if (!NewLatchCheck.IV)
418 NewLatchCheck.Limit = SE->
getTruncateExpr(LatchCheck.Limit, RangeCheckType);
420 <<
"can be represented as range check type:" 421 << *RangeCheckType <<
"\n");
422 LLVM_DEBUG(
dbgs() <<
"LatchCheck.IV: " << *NewLatchCheck.IV <<
"\n");
423 LLVM_DEBUG(
dbgs() <<
"LatchCheck.Limit: " << *NewLatchCheck.Limit <<
"\n");
424 return NewLatchCheck;
427 bool LoopPredication::isSupportedStep(
const SCEV* Step) {
431 bool LoopPredication::CanExpand(
const SCEV* S) {
436 LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
438 auto *Ty = RangeCheck.IV->getType();
445 const SCEV *GuardStart = RangeCheck.IV->getStart();
446 const SCEV *GuardLimit = RangeCheck.Limit;
447 const SCEV *LatchStart = LatchCheck.IV->getStart();
448 const SCEV *LatchLimit = LatchCheck.Limit;
454 if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
455 !CanExpand(LatchLimit) || !CanExpand(RHS)) {
459 auto LimitCheckPred =
468 expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt);
469 auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred,
470 GuardStart, GuardLimit, InsertAt);
471 return Builder.
CreateAnd(FirstIterationCheck, LimitCheck);
475 LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
477 auto *Ty = RangeCheck.IV->getType();
478 const SCEV *GuardStart = RangeCheck.IV->getStart();
479 const SCEV *GuardLimit = RangeCheck.Limit;
480 const SCEV *LatchLimit = LatchCheck.Limit;
481 if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
482 !CanExpand(LatchLimit)) {
488 auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
489 if (RangeCheck.IV != PostDecLatchCheckIV) {
491 << *PostDecLatchCheckIV
492 <<
" and RangeCheckIV: " << *RangeCheck.IV <<
"\n");
501 auto LimitCheckPred =
504 GuardStart, GuardLimit, InsertAt);
505 auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
506 SE->
getOne(Ty), InsertAt);
507 return Builder.
CreateAnd(FirstIterationCheck, LimitCheck);
523 auto RangeCheck = parseLoopICmp(ICI);
525 LLVM_DEBUG(
dbgs() <<
"Failed to parse the loop latch condition!\n");
532 << RangeCheck->Pred <<
")!\n");
535 auto *RangeCheckIV = RangeCheck->IV;
536 if (!RangeCheckIV->isAffine()) {
540 auto *Step = RangeCheckIV->getStepRecurrence(*SE);
543 if (!isSupportedStep(Step)) {
544 LLVM_DEBUG(
dbgs() <<
"Range check and latch have IVs different steps!\n");
547 auto *Ty = RangeCheckIV->getType();
548 auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty);
549 if (!CurrLatchCheckOpt) {
551 "corresponding to range type: " 556 LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
560 CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
561 "Range and latch steps should be of same type!");
562 if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
563 LLVM_DEBUG(
dbgs() <<
"Range and latch have different step values!\n");
568 return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
572 return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
577 bool LoopPredication::widenGuardConditions(
IntrinsicInst *Guard,
596 unsigned NumWidened = 0;
598 Value *Condition = Worklist.pop_back_val();
599 if (!Visited.
insert(Condition).second)
605 Worklist.push_back(LHS);
606 Worklist.push_back(RHS);
610 if (
ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
611 if (
auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) {
612 Checks.
push_back(NewRangeCheck.getValue());
620 }
while (Worklist.size() != 0);
625 TotalWidened += NumWidened;
629 Value *LastCheck =
nullptr;
630 for (
auto *
Check : Checks)
661 "One of the latch's destinations must be the header");
665 auto Result = parseLoopICmp(Pred, LHS, RHS);
667 LLVM_DEBUG(
dbgs() <<
"Failed to parse the loop latch condition!\n");
673 if (!Result->IV->isAffine()) {
678 auto *Step = Result->IV->getStepRecurrence(*SE);
679 if (!isSupportedStep(Step)) {
680 LLVM_DEBUG(
dbgs() <<
"Unsupported loop stride(" << *Step <<
")!\n");
695 if (IsUnsupportedPredicate(Step, Result->Pred)) {
696 LLVM_DEBUG(
dbgs() <<
"Unsupported loop latch predicate(" << Result->Pred
704 bool LoopPredication::isSafeToTruncateWideIVType(
Type *RangeCheckType) {
709 "Expected latch check IV type to be larger than range check operand " 715 if (!Limit || !Start)
729 return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
730 Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
733 bool LoopPredication::isLoopProfitableToPredicate() {
741 if (ExitEdges.
size() == 1)
750 assert(LatchBlock &&
"Should have a single latch at this point!");
751 auto *LatchTerm = LatchBlock->getTerminator();
752 assert(LatchTerm->getNumSuccessors() == 2 &&
753 "expected to be an exiting block with 2 succs!");
754 unsigned LatchBrExitIdx =
755 LatchTerm->getSuccessor(0) == L->
getHeader() ? 1 : 0;
762 if (ScaleFactor < 1) {
765 <<
"Ignored user setting for loop-predication-latch-probability-scale: " 770 const auto LatchProbabilityThreshold =
771 LatchExitProbability * ScaleFactor;
773 for (
const auto &ExitEdge : ExitEdges) {
778 if (ExitingBlockProbability > LatchProbabilityThreshold)
787 bool LoopPredication::runOnLoop(
Loop *
Loop) {
798 if (!GuardDecl || GuardDecl->use_empty())
801 DL = &M->getDataLayout();
807 auto LatchCheckOpt = parseLoopLatchICmp();
810 LatchCheck = *LatchCheckOpt;
815 if (!isLoopProfitableToPredicate()) {
822 for (
const auto BB : L->
blocks())
832 bool Changed =
false;
833 for (
auto *Guard : Guards)
834 Changed |= widenGuardConditions(Guard, Expander);
Pass interface - Implemented by all 'passes'.
static bool Check(DecodeStatus &Out, DecodeStatus In)
BinaryOp_match< LHS, RHS, Instruction::And > m_And(const LHS &L, const RHS &R)
A parsed version of the target data layout string, and methods for querying it. ...
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
BlockT * getLoopLatch() const
If there is a single latch block for this loop, return it.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
STATISTIC(TotalConsidered, "Number of guards considered")
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
This class represents lattice values for constants.
A Module instance is used to store all the information related to an LLVM module. ...
void push_back(const T &Elt)
The main scalar evolution driver.
BlockT * getLoopPreheader() const
If there is a preheader for this loop, return it.
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isMonotonicPredicate(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred, bool &Increasing)
Return true if, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or...
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) INITIALIZE_PASS_END(LoopPredicationLegacyPass
void dump() const
Support for debugging, callable in GDB: V->dump()
bool match(Val *V, const Pattern &P)
AnalysisUsage & addRequired()
const Module * getModule() const
Return the module owning the function this basic block belongs to, or nullptr if the function does no...
#define INITIALIZE_PASS_DEPENDENCY(depName)
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE, etc.
const Loop * getLoop() const
static cl::opt< float > LatchExitProbabilityScale("loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0), cl::desc("scale factor for the latch probability. Value should be greater " "than 1. Lower values are ignored"))
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
BlockT * getHeader() const
Analysis pass which computes BranchProbabilityInfo.
This node represents a polynomial recurrence on the trip count of the specified loop.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block...
Legacy analysis pass which computes BranchProbabilityInfo.
Value * getOperand(unsigned i) const
static cl::opt< bool > EnableIVTruncation("loop-predication-enable-iv-truncation", cl::Hidden, cl::init(true))
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS. ...
initializer< Ty > init(const Ty &Val)
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
A set of analyses that are preserved following a run of a transformation pass.
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM Basic Block Representation.
The instances of the Type class are immutable: once they are created, they are never changed...
ConstantInt * getTrue()
Get the constant value for i1 true.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV *> &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
brc_match< Cond_t > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
Represent the analysis usage information of a pass.
Pass * createLoopPredicationPass()
This instruction compares its operands according to the predicate given to the constructor.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Value * expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I)
Insert code to directly compute the specified SCEV expression into the program.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
BranchProbability getEdgeProbability(const BasicBlock *Src, unsigned IndexInSuccessors) const
Get an edge's probability, relative to other out-edges of the Src.
void getExitEdges(SmallVectorImpl< Edge > &ExitEdges) const
Return all pairs of (inside_block,outside_block).
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements...
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
Type * getType() const
Return the LLVM type of this SCEV expression.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty)
An analysis over an "inner" IR unit that provides access to an analysis manager over a "outer" IR uni...
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Module.h This file contains the declarations for the Module class.
Predicate getFlippedStrictnessPredicate() const
For predicate of kind "is X or equal to 0" returns the predicate "is X".
bool isGuard(const User *U)
Returns true iff U has semantics of a guard.
void setOperand(unsigned i, Value *Val)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
static cl::opt< bool > EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true))
This class uses information about analyzed scalars to rewrite expressions in canonical form...
iterator insert(iterator I, T &&Elt)
uint64_t getTypeSizeInBits(Type *Ty) const
Size examples:
Predicate getPredicate() const
Return the predicate for this instruction.
Analysis providing branch probability information.
This class represents an analyzed expression in the program.
LLVM_NODISCARD bool empty() const
unsigned greater or equal
Represents a single loop in the control flow graph.
void getLoopAnalysisUsage(AnalysisUsage &AU)
Helper to consistently add the set of standard passes to a loop pass's AnalysisUsage.
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
static cl::opt< bool > SkipProfitabilityChecks("loop-predication-skip-profitability-checks", cl::Hidden, cl::init(false))
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void initializeLoopPredicationLegacyPassPass(PassRegistry &)
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
bool isOne() const
Return true if the expression is a constant one.
LLVM Value Representation.
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
A container for analyses that lazily runs them and caches their results.
bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE)
Return true if the given expression is safe to expand in the sense that all materialized values are s...
iterator_range< block_iterator > blocks() const
A wrapper class for inspecting calls to intrinsic functions.
This class represents a constant integer value.
CmpClass_match< LHS, RHS, ICmpInst, ICmpInst::Predicate > m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R)