110 Candidate() =
default;
113 : CandidateKind(CT),
Base(B),
Index(Idx), Stride(S),
Ins(I) {}
115 Kind CandidateKind = Invalid;
124 Value *Stride =
nullptr;
144 Candidate *Basis =
nullptr;
161 bool doInitialization(
Module &M)
override {
171 bool isBasisFor(
const Candidate &Basis,
const Candidate &
C);
179 bool isSimplestForm(
const Candidate &
C);
186 void allocateCandidatesAndFindBasisForAdd(
Instruction *
I);
190 void allocateCandidatesAndFindBasisForAdd(
Value *LHS,
Value *RHS,
193 void allocateCandidatesAndFindBasisForMul(
Instruction *
I);
197 void allocateCandidatesAndFindBasisForMul(
Value *LHS,
Value *RHS,
206 Value *S, uint64_t ElementSize,
216 void rewriteCandidateWithBasis(
const Candidate &
C,
const Candidate &Basis);
221 void factorArrayIndex(
Value *ArrayIdx,
const SCEV *
Base, uint64_t ElementSize,
228 static Value *emitBump(
const Candidate &Basis,
const Candidate &
C,
230 bool &BumpWithUglyGEP);
236 std::list<Candidate> Candidates;
241 std::vector<Instruction *> UnlinkedInstructions;
249 "Straight line strength reduction",
false,
false)
257 return new StraightLineStrengthReduce();
260 bool StraightLineStrengthReduce::isBasisFor(
const Candidate &Basis,
261 const Candidate &
C) {
262 return (Basis.Ins != C.Ins &&
265 Basis.Ins->getType() == C.Ins->getType() &&
267 DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) &&
269 Basis.Base == C.Base && Basis.Stride == C.Stride &&
270 Basis.CandidateKind == C.CandidateKind);
291 bool StraightLineStrengthReduce::isFoldable(
const Candidate &C,
303 unsigned NumNonZeroIndices = 0;
306 if (ConstIdx ==
nullptr || !ConstIdx->
isZero())
309 return NumNonZeroIndices <= 1;
312 bool StraightLineStrengthReduce::isSimplestForm(
const Candidate &C) {
315 return C.Index->isOne() || C.Index->isMinusOne();
317 if (C.CandidateKind == Candidate::Mul) {
319 return C.Index->isZero();
323 return ((C.Index->isOne() || C.Index->isMinusOne()) &&
336 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
339 Candidate
C(CT, B, Idx, S, I);
353 if (!isFoldable(C, TTI, DL) && !isSimplestForm(C)) {
355 unsigned NumIterations = 0;
357 static const unsigned MaxNumIterations = 50;
358 for (
auto Basis = Candidates.rbegin();
359 Basis != Candidates.rend() && NumIterations < MaxNumIterations;
360 ++Basis, ++NumIterations) {
361 if (isBasisFor(*Basis, C)) {
369 Candidates.push_back(C);
372 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
376 allocateCandidatesAndFindBasisForAdd(I);
378 case Instruction::Mul:
379 allocateCandidatesAndFindBasisForMul(I);
381 case Instruction::GetElementPtr:
382 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(I));
387 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
390 if (!isa<IntegerType>(I->
getType()))
395 allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
397 allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
400 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
406 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), Idx, S,
I);
411 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), Idx, S,
I);
415 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), One, RHS,
432 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
439 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS,
I);
445 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS,
I);
449 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
454 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
458 if (!isa<IntegerType>(I->
getType()))
463 allocateCandidatesAndFindBasisForMul(LHS, RHS, I);
466 allocateCandidatesAndFindBasisForMul(RHS, LHS, I);
470 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
479 IntPtrTy, Idx->
getSExtValue() * (int64_t)ElementSize,
true);
480 allocateCandidatesAndFindBasis(
Candidate::GEP, B, ScaledIdx, S, I);
483 void StraightLineStrengthReduce::factorArrayIndex(
Value *ArrayIdx,
485 uint64_t ElementSize,
488 allocateCandidatesAndFindBasisForGEP(
490 ArrayIdx, ElementSize,
GEP);
491 Value *LHS =
nullptr;
507 allocateCandidatesAndFindBasisForGEP(Base, RHS, LHS, ElementSize, GEP);
514 allocateCandidatesAndFindBasisForGEP(Base, PowerOf2, LHS, ElementSize, GEP);
518 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
533 const SCEV *OrigIndexExpr = IndexExprs[I - 1];
534 IndexExprs[I - 1] = SE->getZero(OrigIndexExpr->
getType());
538 const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs);
545 factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
550 Value *TruncatedArrayIdx =
nullptr;
556 factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
559 IndexExprs[I - 1] = OrigIndexExpr;
571 Value *StraightLineStrengthReduce::emitBump(
const Candidate &Basis,
575 bool &BumpWithUglyGEP) {
576 APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue();
578 APInt IndexOffset = Idx - BasisIdx;
580 BumpWithUglyGEP =
false;
585 cast<GetElementPtrInst>(Basis.Ins)->getResultElementType()));
591 BumpWithUglyGEP =
true;
596 if (IndexOffset == 1)
610 return Builder.
CreateShl(ExtendedStride, Exponent);
612 if ((-IndexOffset).isPowerOf2()) {
619 return Builder.
CreateMul(ExtendedStride, Delta);
622 void StraightLineStrengthReduce::rewriteCandidateWithBasis(
623 const Candidate &C,
const Candidate &Basis) {
624 assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base &&
625 C.Stride == Basis.Stride);
628 assert(Basis.Ins->getParent() !=
nullptr &&
"the basis is unlinked");
634 if (!C.Ins->getParent())
638 bool BumpWithUglyGEP;
639 Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP);
640 Value *Reduced =
nullptr;
641 switch (C.CandidateKind) {
643 case Candidate::Mul: {
648 Reduced = Builder.
CreateSub(Basis.Ins, NegBump);
662 Reduced = Builder.
CreateAdd(Basis.Ins, Bump);
669 bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
670 if (BumpWithUglyGEP) {
672 unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
688 Reduced = Builder.
CreateGEP(
nullptr, Basis.Ins, Bump);
696 C.Ins->replaceAllUsesWith(Reduced);
699 C.Ins->removeFromParent();
700 UnlinkedInstructions.push_back(C.Ins);
707 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
708 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
709 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
713 for (
auto &I : *(Node->getBlock()))
714 allocateCandidatesAndFindBasis(&I);
718 while (!Candidates.empty()) {
719 const Candidate &C = Candidates.back();
720 if (C.Basis !=
nullptr) {
721 rewriteCandidateWithBasis(C, *C.Basis);
723 Candidates.pop_back();
727 for (
auto *UnlinkedInst : UnlinkedInstructions) {
728 for (
unsigned I = 0,
E = UnlinkedInst->getNumOperands(); I !=
E; ++
I) {
729 Value *
Op = UnlinkedInst->getOperand(I);
730 UnlinkedInst->setOperand(I,
nullptr);
733 UnlinkedInst->deleteValue();
735 bool Ret = !UnlinkedInstructions.empty();
736 UnlinkedInstructions.clear();
Value * CreateInBoundsGEP(Value *Ptr, ArrayRef< Value *> IdxList, const Twine &Name="")
FunctionPass * createStraightLineStrengthReducePass()
A parsed version of the target data layout string in and methods for querying it. ...
Value * getPointerOperand()
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
APInt sext(unsigned width) const
Sign extend to a new width.
GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2)
This class represents lattice values for constants.
A Module instance is used to store all the information related to an LLVM module. ...
void push_back(const T &Elt)
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
The main scalar evolution driver.
static void unifyBitWidth(APInt &A, APInt &B)
static void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
LLVMContext & getContext() const
All values hold a context through their type.
unsigned getPointerSizeInBits(unsigned AS=0) const
Layout pointer size, in bits FIXME: The defaults need to be removed once all of the backends/clients ...
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
bool isVectorTy() const
True if this is an instance of VectorType.
unsigned getBitWidth() const
getBitWidth - Return the bitwidth of this constant.
unsigned getBitWidth() const
Return the number of bits in the APInt.
bool match(Val *V, const Pattern &P)
AnalysisUsage & addRequired()
#define INITIALIZE_PASS_DEPENDENCY(depName)
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Type * getSourceElementType() const
This file implements a class to represent arbitrary precision integral constant values and operations...
static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP)
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
void initializeStraightLineStrengthReducePass(PassRegistry &)
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Type * getType() const
All values are typed, get the type of this value.
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
const APInt & getValue() const
Return the constant as an APInt value reference.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) INITIALIZE_PASS_END(StraightLineStrengthReduce
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void takeName(Value *V)
Transfer the name from V to this value.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree...
Value * getOperand(unsigned i) const
unsigned getAddressSpace() const
Returns the address space of this instruction's pointer type.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs ...
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space...
static bool runOnFunction(Function &F, bool PostInlining)
bool isAllOnesValue() const
Determine if all bits are set.
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
The instances of the Type class are immutable: once they are created, they are never changed...
BinaryOp_match< LHS, RHS, Instruction::Or > m_Or(const LHS &L, const RHS &R)
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This is an important base class in LLVM.
Straight line strength reduction
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static const unsigned UnknownAddressSpace
Represent the analysis usage information of a pass.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
FunctionPass class - This class is used to implement most global optimizations.
Class to represent integer types.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is a trivially dead instruction, delete it.
Type * getIndexedType() const
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
CastClass_match< OpTy, Instruction::SExt > m_SExt(const OpTy &Op)
Matches SExt.
Value * CreateGEP(Value *Ptr, ArrayRef< Value *> IdxList, const Twine &Name="")
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
unsigned getNumOperands() const
This is the shared class of boolean and integer constants.
Type * getType() const
Return the LLVM type of this SCEV expression.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Module.h This file contains the declarations for the Module class.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
unsigned logBase2() const
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
Class for arbitrary precision integers.
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
uint64_t getTypeAllocSize(Type *Ty) const
Returns the offset in bytes between successive objects of the specified type, including alignment pad...
This class represents an analyzed expression in the program.
unsigned getIntegerBitWidth() const
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoSignedWrap > m_NSWShl(const LHS &L, const RHS &R)
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return true if LHS and RHS have no common bits set.
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
iterator_range< df_iterator< T > > depth_first(const T &G)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoSignedWrap > m_NSWMul(const LHS &L, const RHS &R)
Legacy analysis pass which computes a DominatorTree.
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
gep_type_iterator gep_type_begin(const User *GEP)