#define DEBUG_TYPE "scalarize-masked-mem-intrin"

StringRef getPassName() const override {
  return "Scalarize Masked Memory Intrinsics";
}
bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
74 "Scalarize unsupported masked memory intrinsics",
false,
false)
77 return new ScalarizeMaskedMemIntrin();
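For orientation, a minimal sketch of driving this pass through the legacy pass manager; the wrapper function and its name are assumptions for illustration, not part of this file:

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
// ...plus the header that declares createScalarizeMaskedMemIntrinPass().

// Hypothetical helper: scalarize all unsupported masked intrinsics in M.
void runScalarization(llvm::Module &M) {
  llvm::legacy::PassManager PM;
  PM.add(llvm::createScalarizeMaskedMemIntrinPass());
  PM.run(M); // mutates M in place
}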
// isConstantIntVector(): every element of the mask must be a ConstantInt.
for (unsigned i = 0; i != NumElts; ++i) {
  Constant *CElt = C->getAggregateElement(i);
  if (!CElt || !isa<ConstantInt>(CElt))
    return false;
}
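A mask that passes this check is simply a ConstantVector of i1 ConstantInts; a small sketch of building one (the helper name and the fixed width are illustrative):

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"

// Hypothetical helper: build the <4 x i1> mask <true, false, true, true>.
llvm::Constant *makeMask(llvm::LLVMContext &Ctx) {
  llvm::SmallVector<llvm::Constant *, 4> Bits = {
      llvm::ConstantInt::getTrue(Ctx), llvm::ConstantInt::getFalse(Ctx),
      llvm::ConstantInt::getTrue(Ctx), llvm::ConstantInt::getTrue(Ctx)};
  return llvm::ConstantVector::get(Bits);
}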
// In scalarizeMaskedLoad(CI), which expands @llvm.masked.load:
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

Builder.SetInsertPoint(InsertPt);
Builder.SetCurrentDebugLocation(CI->getDebugLoc());

// Shortcut: an all-ones mask is just an ordinary vector load.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
  Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
  return;
}

// The scalar element accesses may be less aligned than the vector access.
AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
Value *VResult = Src0;

// Constant mask: emit an unconditional load for every enabled lane.
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    continue;
  Value *Gep =
      Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
  LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
  VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
}
// ...

// Variable mask: guard each lane's load with a branch on its mask bit.
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
  Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
  // ... split off a "cond.load" block (see the sketch after this function) ...
  Builder.SetInsertPoint(InsertPt);
  Value *Gep =
      Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
  LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
  Value *NewVResult =
      Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
  // ... split off the "else" block ...
  Builder.SetInsertPoint(InsertPt);
  IfBlock = NewIfBlock;
  // Merge the conditionally loaded lane with the previous value.
  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
  // ...
}
// In scalarizeMaskedStore(CI), which expands @llvm.masked.store:
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

Builder.SetInsertPoint(InsertPt);
Builder.SetCurrentDebugLocation(CI->getDebugLoc());

// Shortcut: an all-ones mask is just an ordinary vector store.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
  Builder.CreateAlignedStore(Src, Ptr, AlignVal);
  CI->eraseFromParent();
  return;
}

AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);

// Constant mask: store every enabled lane unconditionally.
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    continue;
  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
  Value *Gep =
      Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
  Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
}
// ...

// Variable mask: guard each lane's store with a branch on its mask bit.
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
  Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
  // ... split off a "cond.store" block ...
  Builder.SetInsertPoint(InsertPt);
  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
  Value *Gep =
      Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
  Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
  // ... split off the "else" block; no PHI is needed for a store ...
  Builder.SetInsertPoint(InsertPt);
  IfBlock = NewIfBlock;
}
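The call this routine consumes is normally produced through IRBuilder; a hedged sketch of emitting one (the helper name and alignment are illustrative; the unsigned-alignment overload matches this vintage of the API):

#include "llvm/IR/IRBuilder.h"

// Hypothetical helper: Val is the vector to store, Ptr the destination,
// Mask a <N x i1> vector; alignment is given in bytes.
void emitMaskedStore(llvm::IRBuilder<> &B, llvm::Value *Val, llvm::Value *Ptr,
                     llvm::Value *Mask) {
  B.CreateMaskedStore(Val, Ptr, /*Align=*/4, Mask);
}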
// In scalarizeMaskedGather(CI), which expands @llvm.masked.gather:
Builder.SetInsertPoint(InsertPt);
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

Builder.SetCurrentDebugLocation(CI->getDebugLoc());

Value *VResult = Src0;

// Constant mask: load every enabled lane through its own pointer.
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
    continue;
  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                            "Ptr" + Twine(Idx));
  LoadInst *Load =
      Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
  VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
                                        "Res" + Twine(Idx));
}
// ...

// Variable mask: guard each lane's load with a branch on its mask bit.
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
  Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                  "Mask" + Twine(Idx));
  // ... split off a "cond.load" block ...
  Builder.SetInsertPoint(InsertPt);
  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                            "Ptr" + Twine(Idx));
  LoadInst *Load =
      Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
  Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
                                                  Builder.getInt32(Idx),
                                                  "Res" + Twine(Idx));
  // ... split off the "else" block ...
  Builder.SetInsertPoint(InsertPt);
  IfBlock = NewIfBlock;
  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
  // ...
}
479 "Unexpected data type in masked scatter intrinsic");
482 "Vector of pointers is expected in masked scatter intrinsic");
487 Builder.SetInsertPoint(InsertPt);
488 Builder.SetCurrentDebugLocation(CI->
getDebugLoc());
490 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
495 for (
unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
496 if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
498 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
500 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
502 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
508 for (
unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
514 Value *
Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
515 "Mask" +
Twine(Idx));
524 Builder.SetInsertPoint(InsertPt);
526 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
528 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
530 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
534 Builder.SetInsertPoint(InsertPt);
538 IfBlock = NewIfBlock;
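And the scatter input, for symmetry (again a hedged sketch with illustrative names):

// Hypothetical helper: Val holds the lanes to write, Ptrs the per-lane
// destinations; lanes whose Mask bit is false are left untouched.
void emitMaskedScatter(llvm::IRBuilder<> &B, llvm::Value *Val,
                       llvm::Value *Ptrs, llvm::Value *Mask) {
  B.CreateMaskedScatter(Val, Ptrs, /*Align=*/4, Mask);
}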
// In ScalarizeMaskedMemIntrin::runOnFunction(F):
bool EverMadeChange = false;

TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
// (see the getAnalysisUsage sketch after this function)

bool MadeChange = true;
while (MadeChange) {
  MadeChange = false;
  for (Function::iterator I = F.begin(); I != F.end();) {
    BasicBlock *BB = &*I++;
    bool ModifiedDTOnIteration = false;
    MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

    // Restart block iteration if scalarization changed the CFG.
    if (ModifiedDTOnIteration)
      break;
  }

  EverMadeChange |= MadeChange;
}
return EverMadeChange;
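The getAnalysis call above only resolves because the pass declares its dependency on TargetTransformInfo; the override is elided from the fragments, so this is a reconstruction of its usual shape:

void getAnalysisUsage(AnalysisUsage &AU) const override {
  AU.addRequired<TargetTransformInfoWrapperPass>();
}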
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  // Post-increment keeps the iterator valid if the call is scalarized away.
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  // Scalarize only the forms the target cannot lower natively.
  // For @llvm.masked.load:
  if (!TTI->isLegalMaskedLoad(CI->getType())) {
    scalarizeMaskedLoad(CI);
    ModifiedDT = true;
    return true;
  }
  // ...
  // For @llvm.masked.gather:
  if (!TTI->isLegalMaskedGather(CI->getType())) {
    scalarizeMaskedGather(CI);
    ModifiedDT = true;
    return true;
  }
  // ...
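These legality checks sit inside a dispatch over the intrinsic ID; a hedged sketch of that dispatch shape using the IntrinsicInst wrapper (the exact case list in this file may differ):

if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
  switch (II->getIntrinsicID()) {
  default:
    break;
  case Intrinsic::masked_load:
    // gated by TTI->isLegalMaskedLoad above
    break;
  case Intrinsic::masked_gather:
    // gated by TTI->isLegalMaskedGather above
    break;
  }
}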