LLVM  8.0.1
R600OpenCLImageTypeLoweringPass.cpp
Go to the documentation of this file.
1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
19 /// kernel metadata with kernel_arg_type and kernel_arg_base_type set to "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //
26 //===----------------------------------------------------------------------===//
27 
#include "AMDGPU.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
52 
53 using namespace llvm;
54 
55 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
56 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
57 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
59  "llvm.OpenCL.sampler.get.resource.id";
60 
61 static StringRef ImageSizeArgMDType = "__llvm_image_size";
62 static StringRef ImageFormatArgMDType = "__llvm_image_format";
63 
64 static StringRef KernelsMDNodeName = "opencl.kernels";
66  "kernel_arg_addr_space",
67  "kernel_arg_access_qual",
68  "kernel_arg_type",
69  "kernel_arg_base_type",
70  "kernel_arg_type_qual"};
71 static const unsigned NumKernelArgMDNodes = 5;
72 
73 namespace {
74 
75 using MDVector = SmallVector<Metadata *, 8>;
76 struct KernelArgMD {
77  MDVector ArgVector[NumKernelArgMDNodes];
78 };
79 
80 } // end anonymous namespace
81 
82 static inline bool
83 IsImageType(StringRef TypeString) {
84  return TypeString == "image2d_t" || TypeString == "image3d_t";
85 }
86 
87 static inline bool
88 IsSamplerType(StringRef TypeString) {
89  return TypeString == "sampler_t";
90 }
91 
92 static Function *
94  if (!Node)
95  return nullptr;
96 
97  size_t NumOps = Node->getNumOperands();
98  if (NumOps != NumKernelArgMDNodes + 1)
99  return nullptr;
100 
101  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
102  if (!F)
103  return nullptr;
104 
105  // Sanity checks.
106  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
107  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
108  MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
109  if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
110  return nullptr;
111  if (!ArgNode->getOperand(0))
112  return nullptr;
113 
114  // FIXME: It should be possible to do image lowering when some metadata
115  // args missing or not in the expected order.
116  MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
117  if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
118  return nullptr;
119  }
120 
121  return F;
122 }
123 
124 static StringRef
125 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
126  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
127  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
128 }
129 
130 static StringRef
131 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
132  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
133  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
134 }
135 
136 static MDVector
137 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
138  MDVector Res;
139  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
140  MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
141  Res.push_back(Node->getOperand(OpIdx));
142  }
143  return Res;
144 }
145 
146 static void
147 PushArgMD(KernelArgMD &MD, const MDVector &V) {
148  assert(V.size() == NumKernelArgMDNodes);
149  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
150  MD.ArgVector[i].push_back(V[i]);
151  }
152 }
153 
154 namespace {
155 
156 class R600OpenCLImageTypeLoweringPass : public ModulePass {
157  static char ID;
158 
160  Type *Int32Type;
161  Type *ImageSizeType;
162  Type *ImageFormatType;
163  SmallVector<Instruction *, 4> InstsToErase;
164 
165  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
166  Argument &ImageSizeArg,
167  Argument &ImageFormatArg) {
168  bool Modified = false;
169 
170  for (auto &Use : ImageArg.uses()) {
171  auto Inst = dyn_cast<CallInst>(Use.getUser());
172  if (!Inst) {
173  continue;
174  }
175 
176  Function *F = Inst->getCalledFunction();
177  if (!F)
178  continue;
179 
180  Value *Replacement = nullptr;
181  StringRef Name = F->getName();
182  if (Name.startswith(GetImageResourceIDFunc)) {
183  Replacement = ConstantInt::get(Int32Type, ResourceID);
184  } else if (Name.startswith(GetImageSizeFunc)) {
185  Replacement = &ImageSizeArg;
186  } else if (Name.startswith(GetImageFormatFunc)) {
187  Replacement = &ImageFormatArg;
188  } else {
189  continue;
190  }
191 
192  Inst->replaceAllUsesWith(Replacement);
193  InstsToErase.push_back(Inst);
194  Modified = true;
195  }
196 
197  return Modified;
198  }
199 
200  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
201  bool Modified = false;
202 
203  for (const auto &Use : SamplerArg.uses()) {
204  auto Inst = dyn_cast<CallInst>(Use.getUser());
205  if (!Inst) {
206  continue;
207  }
208 
209  Function *F = Inst->getCalledFunction();
210  if (!F)
211  continue;
212 
213  Value *Replacement = nullptr;
214  StringRef Name = F->getName();
215  if (Name == GetSamplerResourceIDFunc) {
216  Replacement = ConstantInt::get(Int32Type, ResourceID);
217  } else {
218  continue;
219  }
220 
221  Inst->replaceAllUsesWith(Replacement);
222  InstsToErase.push_back(Inst);
223  Modified = true;
224  }
225 
226  return Modified;
227  }
228 
229  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
230  uint32_t NumReadOnlyImageArgs = 0;
231  uint32_t NumWriteOnlyImageArgs = 0;
232  uint32_t NumSamplerArgs = 0;
233 
234  bool Modified = false;
235  InstsToErase.clear();
236  for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
237  Argument &Arg = *ArgI;
238  StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
239 
240  // Handle image types.
241  if (IsImageType(Type)) {
242  StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
243  uint32_t ResourceID;
244  if (AccessQual == "read_only") {
245  ResourceID = NumReadOnlyImageArgs++;
246  } else if (AccessQual == "write_only") {
247  ResourceID = NumWriteOnlyImageArgs++;
248  } else {
249  llvm_unreachable("Wrong image access qualifier.");
250  }
251 
252  Argument &SizeArg = *(++ArgI);
253  Argument &FormatArg = *(++ArgI);
254  Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
255 
256  // Handle sampler type.
257  } else if (IsSamplerType(Type)) {
258  uint32_t ResourceID = NumSamplerArgs++;
259  Modified |= replaceSamplerUses(Arg, ResourceID);
260  }
261  }
262  for (unsigned i = 0; i < InstsToErase.size(); ++i) {
263  InstsToErase[i]->eraseFromParent();
264  }
265 
266  return Modified;
267  }
268 
269  std::tuple<Function *, MDNode *>
270  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
271  bool Modified = false;
272 
273  FunctionType *FT = F->getFunctionType();
274  SmallVector<Type *, 8> ArgTypes;
275 
276  // Metadata operands for new MDNode.
277  KernelArgMD NewArgMDs;
278  PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
279 
280  // Add implicit arguments to the signature.
281  for (unsigned i = 0; i < FT->getNumParams(); ++i) {
282  ArgTypes.push_back(FT->getParamType(i));
283  MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
284  PushArgMD(NewArgMDs, ArgMD);
285 
286  if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
287  continue;
288 
289  // Add size implicit argument.
290  ArgTypes.push_back(ImageSizeType);
291  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
292  PushArgMD(NewArgMDs, ArgMD);
293 
294  // Add format implicit argument.
295  ArgTypes.push_back(ImageFormatType);
296  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
297  PushArgMD(NewArgMDs, ArgMD);
298 
299  Modified = true;
300  }
301  if (!Modified) {
302  return std::make_tuple(nullptr, nullptr);
303  }
304 
305  // Create function with new signature and clone the old body into it.
306  auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
307  auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
308  ValueToValueMapTy VMap;
309  auto NewFArgIt = NewF->arg_begin();
310  for (auto &Arg: F->args()) {
311  auto ArgName = Arg.getName();
312  NewFArgIt->setName(ArgName);
313  VMap[&Arg] = &(*NewFArgIt++);
314  if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
315  (NewFArgIt++)->setName(Twine("__size_") + ArgName);
316  (NewFArgIt++)->setName(Twine("__format_") + ArgName);
317  }
318  }
320  CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
321 
322  // Build new MDNode.
323  SmallVector<Metadata *, 6> KernelMDArgs;
324  KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
325  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
326  KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
327  MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
328 
329  return std::make_tuple(NewF, NewMDNode);
330  }
331 
332  bool transformKernels(Module &M) {
333  NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
334  if (!KernelsMDNode)
335  return false;
336 
337  bool Modified = false;
338  for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
339  MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
340  Function *F = GetFunctionFromMDNode(KernelMDNode);
341  if (!F)
342  continue;
343 
344  Function *NewF;
345  MDNode *NewMDNode;
346  std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
347  if (NewF) {
348  // Replace old function and metadata with new ones.
349  F->eraseFromParent();
350  M.getFunctionList().push_back(NewF);
351  M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
352  NewF->getAttributes());
353  KernelsMDNode->setOperand(i, NewMDNode);
354 
355  F = NewF;
356  KernelMDNode = NewMDNode;
357  Modified = true;
358  }
359 
360  Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
361  }
362 
363  return Modified;
364  }
365 
366 public:
367  R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
368 
369  bool runOnModule(Module &M) override {
370  Context = &M.getContext();
371  Int32Type = Type::getInt32Ty(M.getContext());
372  ImageSizeType = ArrayType::get(Int32Type, 3);
373  ImageFormatType = ArrayType::get(Int32Type, 2);
374 
375  return transformKernels(M);
376  }
377 
378  StringRef getPassName() const override {
379  return "R600 OpenCL Image Type Pass";
380  }
381 };
382 
383 } // end anonymous namespace
384 
386 
388  return new R600OpenCLImageTypeLoweringPass();
389 }
iterator_range< use_iterator > uses()
Definition: Value.h:355
This class represents an incoming formal argument to a Function.
Definition: Argument.h:30
LLVMContext & Context
MDNode * getOperand(unsigned i) const
Definition: Metadata.cpp:1081
static StringRef AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
This class represents lattice values for constants.
Definition: AllocatorList.h:24
Type * getParamType(unsigned i) const
Parameter type accessors.
Definition: DerivedTypes.h:135
Constant * getOrInsertFunction(StringRef Name, FunctionType *T, AttributeList AttributeList)
Look up the specified function in the module symbol table.
Definition: Module.cpp:144
static StringRef GetImageFormatFunc
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:65
ModulePass * createR600OpenCLImageTypeLoweringPass()
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:454
This class represents a function call, abstracting a target machine's calling convention.
This file contains the declarations for metadata subclasses.
arg_iterator arg_end()
Definition: Function.h:680
Metadata node.
Definition: Metadata.h:864
F(f)
const MDOperand & getOperand(unsigned I) const
Definition: Metadata.h:1069
This defines the Use class.
static StringRef ImageFormatArgMDType
void setOperand(unsigned I, MDNode *New)
Definition: Metadata.cpp:1089
static StringRef KernelsMDNodeName
static StringRef GetSamplerResourceIDFunc
A tuple of MDNodes.
Definition: Metadata.h:1326
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, bool ModuleLevelChanges, SmallVectorImpl< ReturnInst *> &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values...
amdgpu Simplify well known AMD library false Value Value const Twine & Name
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:244
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
static bool IsImageType(StringRef TypeString)
unsigned getNumOperands() const
Definition: Metadata.cpp:1077
LLVM_NODISCARD LLVM_ATTRIBUTE_ALWAYS_INLINE bool startswith(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:267
User * getUser() const LLVM_READONLY
Returns the User that contains this Use.
Definition: Use.cpp:41
Class to represent function types.
Definition: DerivedTypes.h:103
NamedMDNode * getNamedMetadata(const Twine &Name) const
Return the first NamedMDNode in the module with the specified name.
Definition: Module.cpp:252
static StringRef ImageSizeArgMDType
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:224
LinkageTypes getLinkage() const
Definition: GlobalValue.h:451
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
static ConstantAsMetadata * get(Constant *C)
Definition: Metadata.h:410
StringRef getString() const
Definition: Metadata.cpp:464
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata *> MDs)
Definition: Metadata.h:1166
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:136
const FunctionListType & getFunctionList() const
Get the Module's list of functions (constant).
Definition: Module.h:530
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:69
static MDVector GetArgMD(MDNode *KernelMDNode, unsigned OpIdx)
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:139
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:297
arg_iterator arg_begin()
Definition: Function.h:671
size_t size() const
Definition: SmallVector.h:53
static void PushArgMD(KernelArgMD &MD, const MDVector &V)
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:847
Module.h This file contains the declarations for the Module class.
Type * getReturnType() const
Definition: DerivedTypes.h:124
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:622
static StringRef KernelArgMDNodeNames[]
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:164
void push_back(pointer val)
Definition: ilist.h:313
static const unsigned NumKernelArgMDNodes
static StringRef GetImageResourceIDFunc
unsigned getArgNo() const
Return the index of this formal argument in its containing function.
Definition: Argument.h:48
amdgpu Simplify well known AMD library false Value Value * Arg
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:176
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:214
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:225
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:581
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
static Function * GetFunctionFromMDNode(MDNode *Node)
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:214
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:73
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
A single uniqued string.
Definition: Metadata.h:604
static bool IsSamplerType(StringRef TypeString)
static StringRef GetImageSizeFunc
unsigned getNumOperands() const
Return number of MDNode operands.
Definition: Metadata.h:1075
static StringRef ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
iterator_range< arg_iterator > args()
Definition: Function.h:689