LLVM  8.0.1
AMDGPULowerKernelAttributes.cpp
Go to the documentation of this file.
1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file This pass attempts to make use of reqd_work_group_size metadata
11 /// to eliminate loads from the dispatch packet and to constant fold OpenCL
12 /// get_local_size-like functions.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "AMDGPU.h"
17 #include "AMDGPUTargetMachine.h"
18 #include "llvm/Analysis/ValueTracking.h"
19 #include "llvm/CodeGen/Passes.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/PatternMatch.h"
25 #include "llvm/Pass.h"
26 
27 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
28 
29 using namespace llvm;
30 
31 namespace {
32 
33 // Field offsets in hsa_kernel_dispatch_packet_t.
// processUse compares these values against the constant offset it computes
// for each dispatch-pointer access; the workgroup-size fields are matched with
// 2-byte loads and the grid-size fields with 4-byte loads.
34 enum DispatchPackedOffsets {
35  WORKGROUP_SIZE_X = 4,
36  WORKGROUP_SIZE_Y = 6,
37  WORKGROUP_SIZE_Z = 8,
38 
39  GRID_SIZE_X = 12,
40  GRID_SIZE_Y = 16,
41  GRID_SIZE_Z = 20
42 };
43 
/// Module pass that visits every call to the dispatch-pointer function and
/// lets processUse fold loads made through it, using the calling function's
/// "reqd_work_group_size" metadata and "uniform-work-group-size" attribute.
44 class AMDGPULowerKernelAttributes : public ModulePass {
45  Module *Mod = nullptr; // Set in doInitialization; read by processUse and runOnModule.
46 
47 public:
48  static char ID;
49 
50  AMDGPULowerKernelAttributes() : ModulePass(ID) {}
51 
 /// Handle one call to the dispatch-pointer function; returns true if any
 /// use was replaced.
52  bool processUse(CallInst *CI);
53 
54  bool doInitialization(Module &M) override;
55  bool runOnModule(Module &M) override;
56 
57  StringRef getPassName() const override {
58  return "AMDGPU Kernel Attributes";
59  }
60 
 // NOTE(review): processUse rewrites values with replaceAllUsesWith, yet the
 // pass declares that it preserves all analyses. Confirm that no analysis
 // result depends on the replaced loads/selects.
61  void getAnalysisUsage(AnalysisUsage &AU) const override {
62  AU.setPreservesAll();
63  }
64 };
65 
66 } // end anonymous namespace
67 
/// Remember the module being compiled so that processUse and runOnModule can
/// reach it through Mod. Always returns false.
68 bool AMDGPULowerKernelAttributes::doInitialization(Module &M) {
69  Mod = &M;
70  return false;
71 }
72 
73 bool AMDGPULowerKernelAttributes::processUse(CallInst *CI) {
74  Function *F = CI->getParent()->getParent();
75 
76  auto MD = F->getMetadata("reqd_work_group_size");
77  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
78 
79  const bool HasUniformWorkGroupSize =
80  F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true";
81 
82  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
83  return false;
84 
85  Value *WorkGroupSizeX = nullptr;
86  Value *WorkGroupSizeY = nullptr;
87  Value *WorkGroupSizeZ = nullptr;
88 
89  Value *GridSizeX = nullptr;
90  Value *GridSizeY = nullptr;
91  Value *GridSizeZ = nullptr;
92 
93  const DataLayout &DL = Mod->getDataLayout();
94 
95  // We expect to see several GEP users, casted to the appropriate type and
96  // loaded.
97  for (User *U : CI->users()) {
98  if (!U->hasOneUse())
99  continue;
100 
101  int64_t Offset = 0;
102  if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
103  continue;
104 
105  auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
106  if (!BCI || !BCI->hasOneUse())
107  continue;
108 
109  auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
110  if (!Load || !Load->isSimple())
111  continue;
112 
113  unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
114 
115  // TODO: Handle merged loads.
116  switch (Offset) {
117  case WORKGROUP_SIZE_X:
118  if (LoadSize == 2)
119  WorkGroupSizeX = Load;
120  break;
121  case WORKGROUP_SIZE_Y:
122  if (LoadSize == 2)
123  WorkGroupSizeY = Load;
124  break;
125  case WORKGROUP_SIZE_Z:
126  if (LoadSize == 2)
127  WorkGroupSizeZ = Load;
128  break;
129  case GRID_SIZE_X:
130  if (LoadSize == 4)
131  GridSizeX = Load;
132  break;
133  case GRID_SIZE_Y:
134  if (LoadSize == 4)
135  GridSizeY = Load;
136  break;
137  case GRID_SIZE_Z:
138  if (LoadSize == 4)
139  GridSizeZ = Load;
140  break;
141  default:
142  break;
143  }
144  }
145 
146  // Pattern match the code used to handle partial workgroup dispatches in the
147  // library implementation of get_local_size, so the entire function can be
148  // constant folded with a known group size.
149  //
150  // uint r = grid_size - group_id * group_size;
151  // get_local_size = (r < group_size) ? r : group_size;
152  //
153  // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
154  // the grid_size is required to be a multiple of group_size). In this case:
155  //
156  // grid_size - (group_id * group_size) < group_size
157  // ->
158  // grid_size < group_size + (group_id * group_size)
159  //
160  // (grid_size / group_size) < 1 + group_id
161  //
162  // grid_size / group_size is at least 1, so we can conclude the select
163  // condition is false (except for group_id == 0, where the select result is
164  // the same).
165 
166  bool MadeChange = false;
167  Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
168  Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };
169 
170  for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
171  Value *GroupSize = WorkGroupSizes[I];
172  Value *GridSize = GridSizes[I];
173  if (!GroupSize || !GridSize)
174  continue;
175 
176  for (User *U : GroupSize->users()) {
177  auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
178  if (!ZextGroupSize)
179  continue;
180 
181  for (User *ZextUser : ZextGroupSize->users()) {
182  auto *SI = dyn_cast<SelectInst>(ZextUser);
183  if (!SI)
184  continue;
185 
186  using namespace llvm::PatternMatch;
187  auto GroupIDIntrin = I == 0 ?
188  m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
189  (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
190  m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
191 
192  auto SubExpr = m_Sub(m_Specific(GridSize),
193  m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));
194 
195  ICmpInst::Predicate Pred;
196  if (match(SI,
197  m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
198  SubExpr,
199  m_Specific(ZextGroupSize))) &&
200  Pred == ICmpInst::ICMP_ULT) {
201  if (HasReqdWorkGroupSize) {
202  ConstantInt *KnownSize
203  = mdconst::extract<ConstantInt>(MD->getOperand(I));
204  SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
205  SI->getType(),
206  false));
207  } else {
208  SI->replaceAllUsesWith(ZextGroupSize);
209  }
210 
211  MadeChange = true;
212  }
213  }
214  }
215  }
216 
217  if (!HasReqdWorkGroupSize)
218  return MadeChange;
219 
220  // Eliminate any other loads we can from the dispatch packet.
221  for (int I = 0; I < 3; ++I) {
222  Value *GroupSize = WorkGroupSizes[I];
223  if (!GroupSize)
224  continue;
225 
226  ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
227  GroupSize->replaceAllUsesWith(
229  GroupSize->getType(),
230  false));
231  MadeChange = true;
232  }
233 
234  return MadeChange;
235 }
236 
237 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
238 // TargetPassConfig for subtarget.
/// Run processUse on every call to the dispatch-pointer intrinsic declared in
/// the module cached by doInitialization.
///
/// The \p M parameter is not read directly; the module is reached through Mod.
/// Each user of the declaration is assumed to be a call (cast<CallInst>
/// asserts otherwise), and HandledUses makes sure a call that appears more
/// than once in the user list is processed only once.
///
/// \returns true if processUse changed anything; false if the intrinsic is
/// not declared in the module or no call could be folded.
239 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
240  StringRef DispatchPtrName
241  = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
242 
243  Function *DispatchPtr = Mod->getFunction(DispatchPtrName);
244  if (!DispatchPtr) // Dispatch ptr not used.
245  return false;
246 
247  bool MadeChange = false;
248 
249  SmallPtrSet<Instruction *, 4> HandledUses;
250  for (auto *U : DispatchPtr->users()) {
251  CallInst *CI = cast<CallInst>(U);
252  if (HandledUses.insert(CI).second) {
253  if (processUse(CI))
254  MadeChange = true;
255  }
256  }
257 
258  return MadeChange;
259 }
260 
// Register the pass under DEBUG_TYPE. The description string matches the one
// returned by AMDGPULowerKernelAttributes::getPassName().
261 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
262  "AMDGPU Kernel Attributes", false, false)
263 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU Kernel Attributes",
264  false, false)
265 
// Storage for the static ID member that the constructor hands to ModulePass.
266 char AMDGPULowerKernelAttributes::ID = 0;
267 
/// Factory for the pass: returns a newly heap-allocated
/// AMDGPULowerKernelAttributes instance as a ModulePass pointer.
268 ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
269  return new AMDGPULowerKernelAttributes();
270 }
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:111
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)
Definition: PatternMatch.h:654
This class represents lattice values for constants.
Definition: AllocatorList.h:24
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:65
AMDGPU IR optimizations
This class represents zero extension of integer types.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
Definition: PatternMatch.h:701
This class represents a function call, abstracting a target machine's calling convention.
unsigned less than
Definition: InstrTypes.h:671
ModulePass * createAMDGPULowerKernelAttributesPass()
F(f)
An instruction for reading from memory.
Definition: Instructions.h:168
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:48
static Constant * getIntegerCast(Constant *C, Type *Ty, bool isSigned)
Create a ZExt, Bitcast or Trunc for integer -> integer casts.
Definition: Constants.cpp:1613
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:626
This class represents the LLVM 'select' instruction.
MDNode * getMetadata(unsigned KindID) const
Get the current metadata attachments for the given kind, if any.
Definition: Metadata.cpp:1444
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:245
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset...
#define DEBUG_TYPE
This class represents a no-op cast from one type to another.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:371
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:502
Represent the analysis usage information of a pass.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:646
The AMDGPU TargetMachine interface definition for hw codegen targets.
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
This is the shared class of boolean and integer constants.
Definition: Constants.h:84
The access may modify the value stored in memory.
void setPreservesAll()
Set by analyses that do not transform their input at all.
iterator_range< user_iterator > users()
Definition: Value.h:400
StringRef getValueAsString() const
Return the attribute's value as a string.
Definition: Attributes.cpp:195
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:107
#define I(x, y, z)
Definition: MD5.cpp:58
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:225
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
LLVM Value Representation.
Definition: Value.h:73
uint64_t getTypeStoreSize(Type *Ty) const
Returns the maximum number of bytes that may be overwritten by storing the specified type...
Definition: DataLayout.h:419
Attribute getFnAttribute(Attribute::AttrKind Kind) const
Return the attribute for the given attribute kind.
Definition: Function.h:331
INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations", false, false) INITIALIZE_PASS_END(AMDGPULowerKernelAttributes
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
unsigned getNumOperands() const
Return number of MDNode operands.
Definition: Metadata.h:1075
Statically lint checks LLVM IR
Definition: Lint.cpp:193
const BasicBlock * getParent() const
Definition: Instruction.h:67
CmpClass_match< LHS, RHS, ICmpInst, ICmpInst::Predicate > m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R)