#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

using PMACPair = std::pair<BinOpChain*, BinOpChain*>;
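// Overview of the bookkeeping below (paraphrased): a Reduction records the
// phi node and accumulating integer add where pattern matching starts; each
// BinOpChain records one multiply whose narrow (16-bit) operands may be
// executed in parallel; a PMACPair is two such chains that can be combined
// into a single SMLAD call.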
struct OpChain {
  Instruction *Root;
  ValueList   AllValues;
  MemInstList VecLd;           // Loads that will feed the SMLAD.
  MemLocList  MemLocs;         // All memory locations read by this tree.
  bool        ReadOnly = true;

  OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
  virtual ~OpChain() = default;

  void SetMemoryLocations() {
    const auto Size = LocationSize::unknown();
    for (auto *V : AllValues) {
      if (auto *I = dyn_cast<Instruction>(V)) {
        if (I->mayWriteToMemory())
          ReadOnly = false;
        if (auto *Ld = dyn_cast<LoadInst>(V))
          MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
      }
    }
  }

  unsigned size() const { return AllValues.size(); }
};
// A binary operation (here, a multiply) whose narrow operand chains are
// tracked separately as LHS and RHS.
struct BinOpChain : public OpChain {
  ValueList LHS;          // The left-hand, narrow, operands.
  ValueList RHS;          // The right-hand, narrow, operands.
  bool Exchange = false;  // Operands must be swapped (smladx form).

  BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
    OpChain(I, lhs), LHS(lhs), RHS(rhs) {
      for (auto *V : RHS)
        AllValues.push_back(V);
    }

  // Return true if the operand lists of this chain and Other match up
  // pairwise: same length and same operation (or both constants) at each
  // position.
  bool AreSymmetrical(BinOpChain *Other);
};
struct Reduction {
  // ...
  OpChainList  MACCandidates;  // The MAC candidates for this reduction.
  PMACPairList PMACPairs;      // Candidate pairs that can execute in parallel.
};
class ARMParallelDSP : public LoopPass {
  // ...
  std::map<LoadInst*, LoadInst*> LoadPairs;
  std::map<LoadInst*, SmallVector<LoadInst*, 4>> SequentialLoads;

  bool RecordSequentialLoads(BasicBlock *Header);
  // ...
  void CreateParallelMACPairs(Reduction &R);
  bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
    // ...
    SE  = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    AA  = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    DT  = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    LI  = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    auto &TPC = getAnalysis<TargetPassConfig>();
    // ...
    if (Header != TheLoop->getLoopLatch()) {
      LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    if (!ST->allowsUnalignedMem()) {
      LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    if (!ST->hasDSP()) {
      LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                           "ARMParallelDSP\n");
      return false;
    }

    bool Changes = false;
    // ...
    if (!RecordSequentialLoads(Header)) {
      // ...
      return false;
    }

    Changes = MatchSMLAD(F);
    // ...
  }
  // ...
};
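// The template below, IsNarrowSequence<MaxBitWidth>, answers one question:
// is this mul operand really a MaxBitWidth-wide (here 16-bit) value, e.g. a
// sign-extended i16 load, once truncs and extends are looked through? Only
// operands that pass this check can be packed into the halfwords of an
// SMLAD operand. (Summary comment; the full check has more cases.)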
template<unsigned MaxBitWidth>
static bool IsNarrowSequence(Value *V, ValueList &VL) {
  // ...
  Value *Val, *LHS, *RHS;
  if (match(V, m_Trunc(m_Value(Val)))) {
    if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
      return IsNarrowSequence<MaxBitWidth>(Val, VL);
  }
  // ...
  if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) {
    LLVM_DEBUG(dbgs() << /* ... */
               cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() << "\n");
    return false;
  }
  // ...
}
template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
    // ... (volatile or atomic accesses are rejected)
  }
  // ...
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}
bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *Header) {
  SmallVector<LoadInst*, 8> Loads;
  for (auto &I : *Header) {
    // ... (collect the simple loads of the block into Loads)
  }

  std::map<LoadInst*, LoadInst*> BaseLoads;

  for (auto *Ld0 : Loads) {
    for (auto *Ld1 : Loads) {
      // ...
      if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
        LoadPairs[Ld0] = Ld1;
        if (BaseLoads.count(Ld0)) {
          LoadInst *Base = BaseLoads[Ld0];
          BaseLoads[Ld1] = Base;
          SequentialLoads[Base].push_back(Ld1);
        } else {
          BaseLoads[Ld1] = Ld0;
          SequentialLoads[Ld0].push_back(Ld1);
        }
      }
    }
  }
  return LoadPairs.size() > 1;
}
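// Illustration (not from this file): for loads L0, L1, L2 of a[i], a[i+1],
// a[i+2] at consecutive addresses, the maps end up roughly as
//   LoadPairs:       { L0 -> L1, L1 -> L2 }
//   SequentialLoads: { L0 -> [L1, L2] }
// i.e. LoadPairs links each load to its immediate successor and
// SequentialLoads groups a whole run under its base load.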
void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
  OpChainList &Candidates = R.MACCandidates;
  PMACPairList &PMACPairs = R.PMACPairs;
  const unsigned Elems = Candidates.size();
  // ...

  auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
    if (!PMul0->AreSymmetrical(PMul1))
      return false;

    // Walk the operand lists pairwise; Ld0/Ld1 are the loads feeding the
    // LHS of each mul, Ld2/Ld3 the loads feeding the RHS.
    for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
      // ...
      if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
        return false;

      LLVM_DEBUG(dbgs() << "\t Ld0: " << *Ld0 << "\n"
                        << "\t Ld1: " << *Ld1 << "\n"
                        << "and operands " << x + 2 << ":\n"
                        << "\t Ld2: " << *Ld2 << "\n"
                        << "\t Ld3: " << *Ld3 << "\n");

      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
        if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
          // Sequential, but in reverse order: use the exchanging form.
          PMul1->Exchange = true;
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        }
      } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
                 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        // Only the second operand can be exchanged, so swap the muls.
        PMul0->Exchange = true;
        PMACPairs.push_back(std::make_pair(PMul1, PMul0));
        return true;
      }
    }
    return false;
  };

  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      // ...
      BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
      if (Paired.count(PMul1->Root))
        continue;
      // ...
      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(PMul0, PMul1)) {
        // ... (record both roots in Paired and move on)
      }
    }
  }
}
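// Worked example (illustrative): with muls m0 = a[i]*b[i] and
// m1 = a[i+1]*b[i+1], the LHS loads (a[i], a[i+1]) and RHS loads
// (b[i], b[i+1]) are both sequential, so (m0, m1) becomes one PMACPair and
// later a single smlad. If the b-loads appear in the opposite order across
// the two muls (m0 uses b[i+1], m1 uses b[i]), Exchange is set on one chain
// and the smladx form is emitted instead.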
bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
  Instruction *Acc = Reduction.Phi;
  Instruction *InsertAfter = Reduction.AccIntAdd;

  for (auto &Pair : Reduction.PMACPairs) {
    BinOpChain *PMul0 = Pair.first;
    BinOpChain *PMul1 = Pair.second;
    LLVM_DEBUG(dbgs() << "- "; PMul0->Root->dump();
               dbgs() << "- "; PMul1->Root->dump());

    auto *VecLd0 = cast<LoadInst>(PMul0->VecLd[0]);
    auto *VecLd1 = cast<LoadInst>(PMul1->VecLd[0]);
    Acc = CreateSMLADCall(VecLd0, VecLd1, Acc, PMul1->Exchange, InsertAfter);
    // ...
  }

  if (Acc != Reduction.Phi) {
    // ...
    Reduction.AccIntAdd->replaceAllUsesWith(Acc);
    return true;
  }
  return false;
}
static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
                            ReductionList &Reductions) {
  RecurrenceDescriptor RecDesc;
  const bool HasFnNoNaNAttr =
    F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
  // ...
  for (PHINode &Phi : Header->phis()) {
    const auto *Ty = Phi.getType();
    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
      continue;

    const bool IsReduction =
      RecurrenceDescriptor::AddReductionVar(&Phi,
                                            RecurrenceDescriptor::RK_IntegerAdd,
                                            TheLoop, HasFnNoNaNAttr, RecDesc);
    // ... (Acc is the accumulating add that feeds the phi from the latch)
    Reductions.push_back(Reduction(&Phi, Acc));
  }

  LLVM_DEBUG(
    dbgs() << "\nAccumulating integer additions (reductions) found:\n";
    for (auto &R : Reductions) {
      dbgs() << "- ";  R.Phi->dump();
      dbgs() << "-> "; R.AccIntAdd->dump();
    }
  );
}
505 "expected mul instruction");
508 if (IsNarrowSequence<16>(MulOp0, LHS) &&
509 IsNarrowSequence<16>(MulOp1, RHS)) {
511 Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
static void MatchParallelMACSequences(Reduction &R,
                                      OpChainList &Candidates) {
  // ...
  std::function<bool(Value*)> Match =
    [&Candidates, &Match](Value *V) -> bool {
    // ...
    switch (I->getOpcode()) {
    // ...
    case Instruction::Mul: {
      // ...
      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1))
        AddMACCandidate(Candidates, I, MulOp0, MulOp1);
      return false;
    }
    case Instruction::SExt:
      return Match(I->getOperand(0));
    }
    return false;
  };

  // ...
  LLVM_DEBUG(dbgs() << /* ... */ Candidates.size() << " candidates.\n");
}
static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
                            Instructions &Writes) {
  for (auto &I : *Header) {
    if (I.mayReadFromMemory())
      Reads.push_back(&I);
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
  }
}
static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
                       Instructions &Writes, OpChainList &MACCandidates) {
  // ...
  for (auto &MAC : MACCandidates) {
    // ...
    for (auto *I : Writes) {
      // ...
      assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
      for (auto &MemLoc : MAC->MemLocs) {
        // Bail out if this write may mod/ref one of the MAC's loads.
        if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
                                          ModRefInfo::ModRef)))
          return true;
      }
    }
  }
  return false;
}
static bool CheckMACMemory(OpChainList &Candidates) {
  for (auto &C : Candidates) {
    // ...
    C->SetMemoryLocations();
    ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
    ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;

    // Step by 2 to skip over the expected extend instructions.
    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
        return false;
    }
  }
  return true;
}
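// Illustrative source pattern (not taken verbatim from this file) for what
// MatchSMLAD below looks for: a reduction of sign-extended i16 multiplies,
//
//   int acc = 0;
//   for (int i = 0; i < N; i += 2)
//     acc += (int)a[i] * (int)b[i] + (int)a[i + 1] * (int)b[i + 1];
//
// When a[i]/a[i+1] and b[i]/b[i+1] are adjacent 16-bit loads, each pair of
// multiply-accumulates can be replaced by one i32 load per array plus a
// single call to the smlad intrinsic.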
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  BasicBlock *Header = L->getHeader();
  LLVM_DEBUG(dbgs() << "Header block:\n"; Header->dump();
             dbgs() << "Loop info:\n\n"; L->dump());

  bool Changed = false;
  ReductionList Reductions;
  MatchReductions(F, L, Header, Reductions);

  for (auto &R : Reductions) {
    OpChainList MACCandidates;
    MatchParallelMACSequences(R, MACCandidates);
    if (!CheckMACMemory(MACCandidates))
      continue;

    R.MACCandidates = std::move(MACCandidates);

    LLVM_DEBUG(for (auto &M : R.MACCandidates)
                 M->Root->dump(););
  }

  // Collect all instructions that may read or write memory; the alias
  // checks run per reduction below.
  Instructions Reads, Writes;
  AliasCandidates(Header, Reads, Writes);

  for (auto &R : Reductions) {
    if (AreAliased(AA, Reads, Writes, R.MACCandidates))
      return false;
    CreateParallelMACPairs(R);
    Changed |= InsertParallelMACs(R);
  }

  return Changed;
}
static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
                               const Type *LoadTy) {
  // ... (bitcast the base pointer to LoadTy* and emit an aligned load)
}

Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
                                             Instruction *Acc, bool Exchange,
                                             Instruction *InsertAfter) {
  LLVM_DEBUG(dbgs() << "- " << *VecLd0 << "\n"
                    << "- " << *VecLd1 << "\n"
                    << "- " << *Acc << "\n"
                    << "Exchange: " << Exchange << "\n");
  // ...
  CallInst *Call = Builder.CreateCall(SMLAD, Args);
  // ...
}
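// For reference, smlad(x, y, acc) computes
//   acc + sext(x[15:0]) * sext(y[15:0]) + sext(x[31:16]) * sext(y[31:16])
// and smladx is the same with the two halfwords of y swapped, which is why
// reversed-but-sequential loads set the Exchange flag above.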
bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
  auto CompareValueList = [](const ValueList &VL0,
                             const ValueList &VL1) {
    if (VL0.size() != VL1.size()) {
      LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
                        << VL0.size() << " != " << VL1.size() << "\n");
      return false;
    }

    const unsigned Pairs = VL0.size();
    LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");

    for (unsigned i = 0; i < Pairs; ++i) {
      const Value *V0 = VL0[i];
      const Value *V1 = VL1[i];
      const auto *Inst0 = dyn_cast<Instruction>(V0);
      const auto *Inst1 = dyn_cast<Instruction>(V1);

      if (!Inst0 || !Inst1)
        return false;

      if (Inst0->isSameOperationAs(Inst1))
        continue;

      const APInt *C0, *C1;
      // ... (otherwise require both values to be matching constant ints)
    }
    return true;
  };

  return CompareValueList(LHS, Other->LHS) &&
         CompareValueList(RHS, Other->RHS);
}
Pass *createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform loops to use DSP intrinsics", false, false)