#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize;
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;

  // One entry of the expansion: the width of the load in bytes and its offset
  // from the base pointers.
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {}

    unsigned LoadSize;
    uint64_t Offset;
  };
  using LoadEntryVector = SmallVector<LoadEntry, 8>;
  LoadEntryVector LoadSequence;
  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();
  Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
                                 uint64_t OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
                            unsigned MaxNumLoads,
                            unsigned &NumLoadsNonOneByte);
  static LoadEntryVector
  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);
public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned MaxLoadsPerBlockForZeroCmp,
                  const DataLayout &TheDataLayout);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};
MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  while (Size && !LoadSizes.empty()) {
    const unsigned LoadSize = LoadSizes.front();
    const uint64_t NumLoadsForThisSize = Size / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads exceeds what the target
      // allows.
      return {};
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1)
        ++NumLoadsNonOneByte;
      Size = Size % LoadSize;
    }
    LoadSizes = LoadSizes.drop_front();
  }
  return LoadSequence;
}
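// Illustrative note (not part of the original source): with LoadSizes =
// {8, 4, 2, 1} and Size = 15, the greedy sequence is {8,0}, {4,8}, {2,12},
// {1,14} -- four loads, three of which are wider than one byte.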
MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // These sizes are already handled by the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // Do as many non-overlapping loads as possible starting from the beginning.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // There remain 0 to MaxLoadSize - 1 bytes to handle with one overlapping
  // load.
  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
  // Bail if there is nothing left to overlap; the greedy approach already
  // handles this case.
  if (Size == 0)
    return {};
  // Bail if the number of loads (non-overlapping plus the overlapping one)
  // is larger than the maximum allowed.
  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    return {};

  // Add the non-overlapping loads.
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    LoadSequence.push_back({MaxLoadSize, Offset});
    Offset += MaxLoadSize;
  }

  // Add the last overlapping load.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
  NumLoadsNonOneByte = 1;
  return LoadSequence;
}
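// Illustrative note (not part of the original source): for Size = 15 with
// MaxLoadSize = 8, the overlapping sequence is {8,0} followed by {8,7}; the
// last load re-reads byte 7, so the whole range is covered with two loads
// instead of the four produced by the greedy decomposition.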
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
    const unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout)
    : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0),
      NumLoadsPerBlockForZeroCmp(MaxLoadsPerBlockForZeroCmp),
      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    LoadSizes = LoadSizes.drop_front();
  }
  assert(!LoadSizes.empty() && "cannot load Size bytes");
  MaxLoadSize = LoadSizes.front();
  // Compute the decomposition.
  unsigned GreedyNumLoadsNonOneByte = 0;
  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, MaxNumLoads,
                                           GreedyNumLoadsNonOneByte);
  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
  // If the target allows overlapping loads and the greedy sequence is not
  // already optimal, prefer the (shorter) overlapping sequence.
  if (Options.AllowOverlappingLoads &&
      (LoadSequence.empty() || LoadSequence.size() > 2)) {
    unsigned OverlappingNumLoadsNonOneByte = 0;
    auto OverlappingLoads = computeOverlappingLoadSequence(
        Size, MaxLoadSize, MaxNumLoads, OverlappingNumLoadsNonOneByte);
    if (!OverlappingLoads.empty() &&
        (LoadSequence.empty() ||
         OverlappingLoads.size() < LoadSequence.size())) {
      LoadSequence = OverlappingLoads;
      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    }
  }
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}
unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}
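// Illustrative note (not part of the original source): a zero-equality
// expansion packs several loads per block, so 7 loads with 2 loads per block
// need ceil(7 / 2) = 4 blocks; a general expansion uses one block per load.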
void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}
Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
                                                Type *LoadSizeType,
                                                uint64_t OffsetBytes) {
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    Source = Builder.CreateGEP(
        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
        ConstantInt::get(ByteType, OffsetBytes));
  }
  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
}
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned OffsetBytes) {
  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  Value *Source1 =
      getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
  Value *Source2 =
      getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit to EndBlock if a difference was found; otherwise continue to
    // the next LoadCmpBlock.
    Value *Cmp =
        Builder.CreateICmpNE(Diff, ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block branches to EndBlock unconditionally.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}
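// Illustrative note (not part of the original source): because both bytes are
// zero-extended to i32 before the subtraction, the sign of Diff matches the
// sign memcmp would return for a mismatch at this byte.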
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff = nullptr;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, emit the comparison right before the memcmp
  // call; otherwise emit into the dedicated load/compare block.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combination is the widest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
                                             CurLoadEntry.Offset);
    Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
                                             CurLoadEntry.Offset);

    // Get a constant or load a value for each source address.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // Multiple loads per block: accumulate xor results for the or-reduction.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // A single load per block: just compare the loaded values directly.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result is left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }

    assert(Diff && "Failed to find comparison diff");
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
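// Illustrative note (not part of the original source): with four loads in a
// zero-equality block, the xor results a, b, c, d are reduced pairwise to
// (a|b) and (c|d), then to ((a|b)|(c|d)), and a single icmp ne against zero
// decides whether any of the four pairs differed.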
void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit to ResultBlock if a difference was found; otherwise continue to
  // the next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge from the last LoadCmpBlock to EndBlock with a value of 0,
  // since the early exit to ResultBlock was not taken.
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
                                           CurLoadEntry.Offset);
  Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
                                           CurLoadEntry.Offset);

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the result-block phi nodes, but only if the
  // result is not used in a zero-equality comparison.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit to ResultBlock if a difference was found; otherwise continue to
  // the next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge from the last LoadCmpBlock to EndBlock with a value of 0,
  // since the early exit to ResultBlock was not taken.
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if the memcmp result is only compared against zero, the
  // result block can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  // Otherwise select -1 or 1 based on an unsigned comparison of the phi'd
  // source values.
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);
  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}
void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}
Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares: zext both values to i32 and
    // subtract them to get the negative, zero, or positive result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting two extended compare bits: sub (ugt, ult).
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}
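// Illustrative note (not part of the original source): sub(zext(icmp ugt),
// zext(icmp ult)) yields exactly -1, 0, or 1 for the wide case, which is a
// valid memcmp result without introducing control flow; the narrow case can
// subtract the zero-extended values directly.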
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI->getIterator(), "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If the memcmp result is not just compared against zero, we need to know
    // which source was larger, so keep the loaded values in result-block phis.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}
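// Illustrative note (not part of the original source): for a multi-block
// expansion the resulting CFG is a chain of "loadbb" blocks that either fall
// through to the next block or jump to "res_block", which computes the -1/0/1
// result; both paths meet in "endblock", where phi.res selects the final value
// that replaces the call.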
  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (!Options)
    return false;

  const unsigned MaxNumLoads =
      TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());

  MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
                            IsUsedForZeroCmp, NumLoadsPerBlock, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace the call with the result of the expansion and erase it.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();
  return true;
}
    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering *TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }
bool ExpandMemCmpPass::runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                                  const TargetTransformInfo *TTI,
                                  const TargetLowering *TL,
                                  const DataLayout &DL) {
  for (Instruction &I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
        Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
      return true;
    }
  }
  return false;
}
PreservedAnalyses ExpandMemCmpPass::runImpl(Function &F,
                                            const TargetLibraryInfo *TLI,
                                            const TargetTransformInfo *TTI,
                                            const TargetLowering *TL) {
  const DataLayout &DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
857 "Expand memcmp() to load/stores",
false,
false)
864 return new ExpandMemCmpPass();