LLVM  8.0.1
ScalarizeMaskedMemIntrin.cpp
Go to the documentation of this file.
1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2 // intrinsics
3 //
4 // The LLVM Compiler Infrastructure
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This pass replaces masked memory intrinsics - when unsupported by the target
12 // - with a chain of basic blocks, that deal with the elements one-by-one if the
13 // appropriate mask bit is set.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <cassert>
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
41 
42 namespace {
43 
44 class ScalarizeMaskedMemIntrin : public FunctionPass {
45  const TargetTransformInfo *TTI = nullptr;
46 
47 public:
48  static char ID; // Pass identification, replacement for typeid
49 
50  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
52  }
53 
54  bool runOnFunction(Function &F) override;
55 
56  StringRef getPassName() const override {
57  return "Scalarize Masked Memory Intrinsics";
58  }
59 
60  void getAnalysisUsage(AnalysisUsage &AU) const override {
62  }
63 
64 private:
65  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
66  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
67 };
68 
69 } // end anonymous namespace
70 
72 
73 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74  "Scalarize unsupported masked memory intrinsics", false, false)
75 
77  return new ScalarizeMaskedMemIntrin();
78 }
79 
82  if (!C)
83  return false;
84 
85  unsigned NumElts = Mask->getType()->getVectorNumElements();
86  for (unsigned i = 0; i != NumElts; ++i) {
87  Constant *CElt = C->getAggregateElement(i);
88  if (!CElt || !isa<ConstantInt>(CElt))
89  return false;
90  }
91 
92  return true;
93 }
94 
95 // Translate a masked load intrinsic like
96 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
97 // <16 x i1> %mask, <16 x i32> %passthru)
98 // to a chain of basic blocks, with loading element one-by-one if
99 // the appropriate mask bit is set
100 //
101 // %1 = bitcast i8* %addr to i32*
102 // %2 = extractelement <16 x i1> %mask, i32 0
103 // br i1 %2, label %cond.load, label %else
104 //
105 // cond.load: ; preds = %0
106 // %3 = getelementptr i32* %1, i32 0
107 // %4 = load i32* %3
108 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
109 // br label %else
110 //
111 // else: ; preds = %0, %cond.load
112 // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
113 // %6 = extractelement <16 x i1> %mask, i32 1
114 // br i1 %6, label %cond.load1, label %else2
115 //
116 // cond.load1: ; preds = %else
117 // %7 = getelementptr i32* %1, i32 1
118 // %8 = load i32* %7
119 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
120 // br label %else2
121 //
122 // else2: ; preds = %else, %cond.load1
123 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
124 // %10 = extractelement <16 x i1> %mask, i32 2
125 // br i1 %10, label %cond.load4, label %else5
126 //
127 static void scalarizeMaskedLoad(CallInst *CI) {
128  Value *Ptr = CI->getArgOperand(0);
129  Value *Alignment = CI->getArgOperand(1);
130  Value *Mask = CI->getArgOperand(2);
131  Value *Src0 = CI->getArgOperand(3);
132 
133  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
134  VectorType *VecType = cast<VectorType>(CI->getType());
135 
136  Type *EltTy = VecType->getElementType();
137 
138  IRBuilder<> Builder(CI->getContext());
139  Instruction *InsertPt = CI;
140  BasicBlock *IfBlock = CI->getParent();
141 
142  Builder.SetInsertPoint(InsertPt);
143  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
144 
145  // Short-cut if the mask is all-true.
146  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
147  Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
148  CI->replaceAllUsesWith(NewI);
149  CI->eraseFromParent();
150  return;
151  }
152 
153  // Adjust alignment for the scalar instruction.
154  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
155  // Bitcast %addr fron i8* to EltTy*
156  Type *NewPtrType =
157  EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
158  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
159  unsigned VectorWidth = VecType->getNumElements();
160 
161  // The result vector
162  Value *VResult = Src0;
163 
164  if (isConstantIntVector(Mask)) {
165  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
166  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
167  continue;
168  Value *Gep =
169  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
170  LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
171  VResult =
172  Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
173  }
174  CI->replaceAllUsesWith(VResult);
175  CI->eraseFromParent();
176  return;
177  }
178 
179  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
180  // Fill the "else" block, created in the previous iteration
181  //
182  // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
183  // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
184  // br i1 %mask_1, label %cond.load, label %else
185  //
186 
187  Value *Predicate =
188  Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
189 
190  // Create "cond" block
191  //
192  // %EltAddr = getelementptr i32* %1, i32 0
193  // %Elt = load i32* %EltAddr
194  // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
195  //
196  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
197  "cond.load");
198  Builder.SetInsertPoint(InsertPt);
199 
200  Value *Gep =
201  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
202  LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
203  Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
204  Builder.getInt32(Idx));
205 
206  // Create "else" block, fill it in the next iteration
207  BasicBlock *NewIfBlock =
208  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
209  Builder.SetInsertPoint(InsertPt);
210  Instruction *OldBr = IfBlock->getTerminator();
211  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
212  OldBr->eraseFromParent();
213  BasicBlock *PrevIfBlock = IfBlock;
214  IfBlock = NewIfBlock;
215 
216  // Create the phi to join the new and previous value.
217  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
218  Phi->addIncoming(NewVResult, CondBlock);
219  Phi->addIncoming(VResult, PrevIfBlock);
220  VResult = Phi;
221  }
222 
223  CI->replaceAllUsesWith(VResult);
224  CI->eraseFromParent();
225 }
226 
227 // Translate a masked store intrinsic, like
228 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
229 // <16 x i1> %mask)
230 // to a chain of basic blocks, that stores element one-by-one if
231 // the appropriate mask bit is set
232 //
233 // %1 = bitcast i8* %addr to i32*
234 // %2 = extractelement <16 x i1> %mask, i32 0
235 // br i1 %2, label %cond.store, label %else
236 //
237 // cond.store: ; preds = %0
238 // %3 = extractelement <16 x i32> %val, i32 0
239 // %4 = getelementptr i32* %1, i32 0
240 // store i32 %3, i32* %4
241 // br label %else
242 //
243 // else: ; preds = %0, %cond.store
244 // %5 = extractelement <16 x i1> %mask, i32 1
245 // br i1 %5, label %cond.store1, label %else2
246 //
247 // cond.store1: ; preds = %else
248 // %6 = extractelement <16 x i32> %val, i32 1
249 // %7 = getelementptr i32* %1, i32 1
250 // store i32 %6, i32* %7
251 // br label %else2
252 // . . .
253 static void scalarizeMaskedStore(CallInst *CI) {
254  Value *Src = CI->getArgOperand(0);
255  Value *Ptr = CI->getArgOperand(1);
256  Value *Alignment = CI->getArgOperand(2);
257  Value *Mask = CI->getArgOperand(3);
258 
259  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
260  VectorType *VecType = cast<VectorType>(Src->getType());
261 
262  Type *EltTy = VecType->getElementType();
263 
264  IRBuilder<> Builder(CI->getContext());
265  Instruction *InsertPt = CI;
266  BasicBlock *IfBlock = CI->getParent();
267  Builder.SetInsertPoint(InsertPt);
268  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
269 
270  // Short-cut if the mask is all-true.
271  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
272  Builder.CreateAlignedStore(Src, Ptr, AlignVal);
273  CI->eraseFromParent();
274  return;
275  }
276 
277  // Adjust alignment for the scalar instruction.
278  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
279  // Bitcast %addr fron i8* to EltTy*
280  Type *NewPtrType =
281  EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
282  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
283  unsigned VectorWidth = VecType->getNumElements();
284 
285  if (isConstantIntVector(Mask)) {
286  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
287  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
288  continue;
289  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
290  Value *Gep =
291  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
292  Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
293  }
294  CI->eraseFromParent();
295  return;
296  }
297 
298  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
299  // Fill the "else" block, created in the previous iteration
300  //
301  // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
302  // br i1 %mask_1, label %cond.store, label %else
303  //
304  Value *Predicate =
305  Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
306 
307  // Create "cond" block
308  //
309  // %OneElt = extractelement <16 x i32> %Src, i32 Idx
310  // %EltAddr = getelementptr i32* %1, i32 0
311  // %store i32 %OneElt, i32* %EltAddr
312  //
313  BasicBlock *CondBlock =
314  IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
315  Builder.SetInsertPoint(InsertPt);
316 
317  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
318  Value *Gep =
319  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
320  Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
321 
322  // Create "else" block, fill it in the next iteration
323  BasicBlock *NewIfBlock =
324  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
325  Builder.SetInsertPoint(InsertPt);
326  Instruction *OldBr = IfBlock->getTerminator();
327  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
328  OldBr->eraseFromParent();
329  IfBlock = NewIfBlock;
330  }
331  CI->eraseFromParent();
332 }
333 
334 // Translate a masked gather intrinsic like
335 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
336 // <16 x i1> %Mask, <16 x i32> %Src)
337 // to a chain of basic blocks, with loading element one-by-one if
338 // the appropriate mask bit is set
339 //
340 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
341 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
342 // br i1 %Mask0, label %cond.load, label %else
343 //
344 // cond.load:
345 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
346 // %Load0 = load i32, i32* %Ptr0, align 4
347 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
348 // br label %else
349 //
350 // else:
351 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
352 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
353 // br i1 %Mask1, label %cond.load1, label %else2
354 //
355 // cond.load1:
356 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
357 // %Load1 = load i32, i32* %Ptr1, align 4
358 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
359 // br label %else2
360 // . . .
361 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
362 // ret <16 x i32> %Result
363 static void scalarizeMaskedGather(CallInst *CI) {
364  Value *Ptrs = CI->getArgOperand(0);
365  Value *Alignment = CI->getArgOperand(1);
366  Value *Mask = CI->getArgOperand(2);
367  Value *Src0 = CI->getArgOperand(3);
368 
369  VectorType *VecType = cast<VectorType>(CI->getType());
370 
371  IRBuilder<> Builder(CI->getContext());
372  Instruction *InsertPt = CI;
373  BasicBlock *IfBlock = CI->getParent();
374  Builder.SetInsertPoint(InsertPt);
375  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
376 
377  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
378 
379  // The result vector
380  Value *VResult = Src0;
381  unsigned VectorWidth = VecType->getNumElements();
382 
383  // Shorten the way if the mask is a vector of constants.
384  if (isConstantIntVector(Mask)) {
385  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
386  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
387  continue;
388  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
389  "Ptr" + Twine(Idx));
390  LoadInst *Load =
391  Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
392  VResult = Builder.CreateInsertElement(
393  VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
394  }
395  CI->replaceAllUsesWith(VResult);
396  CI->eraseFromParent();
397  return;
398  }
399 
400  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
401  // Fill the "else" block, created in the previous iteration
402  //
403  // %Mask1 = extractelement <16 x i1> %Mask, i32 1
404  // br i1 %Mask1, label %cond.load, label %else
405  //
406 
407  Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
408  "Mask" + Twine(Idx));
409 
410  // Create "cond" block
411  //
412  // %EltAddr = getelementptr i32* %1, i32 0
413  // %Elt = load i32* %EltAddr
414  // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
415  //
416  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
417  Builder.SetInsertPoint(InsertPt);
418 
419  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
420  "Ptr" + Twine(Idx));
421  LoadInst *Load =
422  Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
423  Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
424  Builder.getInt32(Idx),
425  "Res" + Twine(Idx));
426 
427  // Create "else" block, fill it in the next iteration
428  BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
429  Builder.SetInsertPoint(InsertPt);
430  Instruction *OldBr = IfBlock->getTerminator();
431  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
432  OldBr->eraseFromParent();
433  BasicBlock *PrevIfBlock = IfBlock;
434  IfBlock = NewIfBlock;
435 
436  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
437  Phi->addIncoming(NewVResult, CondBlock);
438  Phi->addIncoming(VResult, PrevIfBlock);
439  VResult = Phi;
440  }
441 
442  CI->replaceAllUsesWith(VResult);
443  CI->eraseFromParent();
444 }
445 
446 // Translate a masked scatter intrinsic, like
447 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
448 // <16 x i1> %Mask)
449 // to a chain of basic blocks, that stores element one-by-one if
450 // the appropriate mask bit is set.
451 //
452 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
453 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
454 // br i1 %Mask0, label %cond.store, label %else
455 //
456 // cond.store:
457 // %Elt0 = extractelement <16 x i32> %Src, i32 0
458 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
459 // store i32 %Elt0, i32* %Ptr0, align 4
460 // br label %else
461 //
462 // else:
463 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
464 // br i1 %Mask1, label %cond.store1, label %else2
465 //
466 // cond.store1:
467 // %Elt1 = extractelement <16 x i32> %Src, i32 1
468 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
469 // store i32 %Elt1, i32* %Ptr1, align 4
470 // br label %else2
471 // . . .
473  Value *Src = CI->getArgOperand(0);
474  Value *Ptrs = CI->getArgOperand(1);
475  Value *Alignment = CI->getArgOperand(2);
476  Value *Mask = CI->getArgOperand(3);
477 
478  assert(isa<VectorType>(Src->getType()) &&
479  "Unexpected data type in masked scatter intrinsic");
480  assert(isa<VectorType>(Ptrs->getType()) &&
481  isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
482  "Vector of pointers is expected in masked scatter intrinsic");
483 
484  IRBuilder<> Builder(CI->getContext());
485  Instruction *InsertPt = CI;
486  BasicBlock *IfBlock = CI->getParent();
487  Builder.SetInsertPoint(InsertPt);
488  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
489 
490  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
491  unsigned VectorWidth = Src->getType()->getVectorNumElements();
492 
493  // Shorten the way if the mask is a vector of constants.
494  if (isConstantIntVector(Mask)) {
495  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
496  if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
497  continue;
498  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
499  "Elt" + Twine(Idx));
500  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
501  "Ptr" + Twine(Idx));
502  Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
503  }
504  CI->eraseFromParent();
505  return;
506  }
507 
508  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
509  // Fill the "else" block, created in the previous iteration
510  //
511  // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
512  // br i1 %Mask1, label %cond.store, label %else
513  //
514  Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
515  "Mask" + Twine(Idx));
516 
517  // Create "cond" block
518  //
519  // %Elt1 = extractelement <16 x i32> %Src, i32 1
520  // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
521  // %store i32 %Elt1, i32* %Ptr1
522  //
523  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
524  Builder.SetInsertPoint(InsertPt);
525 
526  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
527  "Elt" + Twine(Idx));
528  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
529  "Ptr" + Twine(Idx));
530  Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
531 
532  // Create "else" block, fill it in the next iteration
533  BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
534  Builder.SetInsertPoint(InsertPt);
535  Instruction *OldBr = IfBlock->getTerminator();
536  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
537  OldBr->eraseFromParent();
538  IfBlock = NewIfBlock;
539  }
540  CI->eraseFromParent();
541 }
542 
544  bool EverMadeChange = false;
545 
546  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
547 
548  bool MadeChange = true;
549  while (MadeChange) {
550  MadeChange = false;
551  for (Function::iterator I = F.begin(); I != F.end();) {
552  BasicBlock *BB = &*I++;
553  bool ModifiedDTOnIteration = false;
554  MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
555 
556  // Restart BB iteration if the dominator tree of the Function was changed
557  if (ModifiedDTOnIteration)
558  break;
559  }
560 
561  EverMadeChange |= MadeChange;
562  }
563 
564  return EverMadeChange;
565 }
566 
567 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
568  bool MadeChange = false;
569 
570  BasicBlock::iterator CurInstIterator = BB.begin();
571  while (CurInstIterator != BB.end()) {
572  if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
573  MadeChange |= optimizeCallInst(CI, ModifiedDT);
574  if (ModifiedDT)
575  return true;
576  }
577 
578  return MadeChange;
579 }
580 
581 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
582  bool &ModifiedDT) {
584  if (II) {
585  switch (II->getIntrinsicID()) {
586  default:
587  break;
589  // Scalarize unsupported vector masked load
590  if (!TTI->isLegalMaskedLoad(CI->getType())) {
592  ModifiedDT = true;
593  return true;
594  }
595  return false;
597  if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
599  ModifiedDT = true;
600  return true;
601  }
602  return false;
604  if (!TTI->isLegalMaskedGather(CI->getType())) {
606  ModifiedDT = true;
607  return true;
608  }
609  return false;
611  if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
613  ModifiedDT = true;
614  return true;
615  }
616  return false;
617  }
618  }
619 
620  return false;
621 }
Type * getVectorElementType() const
Definition: Type.h:371
uint64_t CallInst * C
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks &#39;this&#39; from the containing basic block and deletes it.
Definition: Instruction.cpp:68
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
static void scalarizeMaskedScatter(CallInst *CI)
This class represents lattice values for constants.
Definition: AllocatorList.h:24
iterator end()
Definition: Function.h:658
This class represents a function call, abstracting a target machine&#39;s calling convention.
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:705
F(f)
An instruction for reading from memory.
Definition: Instructions.h:168
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:138
void initializeScalarizeMaskedMemIntrinPass(PassRegistry &)
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:269
FunctionPass * createScalarizeMaskedMemIntrinPass()
createScalarizeMaskedMemIntrinPass - Replace masked load, store, gather and scatter intrinsics with s...
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1135
AnalysisUsage & addRequired()
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
Definition: Type.cpp:652
static void scalarizeMaskedGather(CallInst *CI)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:743
uint64_t getNumElements() const
Definition: DerivedTypes.h:359
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:245
#define DEBUG_TYPE
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
iterator begin()
Definition: Function.h:656
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
Definition: Constants.cpp:335
constexpr uint64_t MinAlign(uint64_t A, uint64_t B)
A and B are either alignments or offsets.
Definition: MathExtras.h:610
static bool runOnFunction(Function &F, bool PostInlining)
Wrapper pass for TargetTransformInfo.
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
This is an important base class in LLVM.
Definition: Constant.h:42
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Represent the analysis usage information of a pass.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:285
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:51
Iterator for intrusive lists based on ilist_node.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE, "Scalarize unsupported masked memory intrinsics", false, false) FunctionPass *llvm
iterator end()
Definition: BasicBlock.h:271
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
unsigned getVectorNumElements() const
Definition: DerivedTypes.h:462
Class to represent vector types.
Definition: DerivedTypes.h:393
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:311
static void scalarizeMaskedLoad(CallInst *CI)
#define I(x, y, z)
Definition: MD5.cpp:58
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="")
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:408
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:73
std::underlying_type< E >::type Mask()
Get a bitmask with 1s in all places up to the high-order bit of E&#39;s largest value.
Definition: BitmaskEnum.h:81
Type * getElementType() const
Definition: DerivedTypes.h:360
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
This pass exposes codegen information to IR-level passes.
static void scalarizeMaskedStore(CallInst *CI)
static bool isConstantIntVector(Value *Mask)
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:44
const BasicBlock * getParent() const
Definition: Instruction.h:67