LLVM  8.0.1
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
11 /// Fix bitcasted functions.
12 ///
13 /// WebAssembly requires caller and callee signatures to match, however in LLVM,
14 /// some amount of slop is vaguely permitted. Detect mismatch by looking for
15 /// bitcasts of functions and rewrite them to use wrapper functions instead.
16 ///
17 /// This doesn't catch all cases, such as when a function's address is taken in
18 /// one place and casted in another, but it works for many common cases.
19 ///
20 /// Note that LLVM already optimizes away function bitcasts in common cases by
21 /// dropping arguments as needed, so this pass only ends up getting used in less
22 /// common cases.
23 ///
24 //===----------------------------------------------------------------------===//
25 
26 #include "WebAssembly.h"
27 #include "llvm/IR/CallSite.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Debug.h"
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
38 
39 namespace {
40 class FixFunctionBitcasts final : public ModulePass {
41  StringRef getPassName() const override {
42  return "WebAssembly Fix Function Bitcasts";
43  }
44 
45  void getAnalysisUsage(AnalysisUsage &AU) const override {
46  AU.setPreservesCFG();
48  }
49 
50  bool runOnModule(Module &M) override;
51 
52 public:
53  static char ID;
54  FixFunctionBitcasts() : ModulePass(ID) {}
55 };
56 } // End anonymous namespace
57 
59 INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
60  "Fix mismatching bitcasts for WebAssembly", false, false)
61 
63  return new FixFunctionBitcasts();
64 }
65 
66 // Recursively descend the def-use lists from V to find non-bitcast users of
67 // bitcasts of V.
68 static void FindUses(Value *V, Function &F,
69  SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
70  SmallPtrSetImpl<Constant *> &ConstantBCs) {
71  for (Use &U : V->uses()) {
72  if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser()))
73  FindUses(BC, F, Uses, ConstantBCs);
74  else if (U.get()->getType() != F.getType()) {
75  CallSite CS(U.getUser());
76  if (!CS)
77  // Skip uses that aren't immediately called
78  continue;
79  Value *Callee = CS.getCalledValue();
80  if (Callee != V)
81  // Skip calls where the function isn't the callee
82  continue;
83  if (isa<Constant>(U.get())) {
84  // Only add constant bitcasts to the list once; they get RAUW'd
85  auto c = ConstantBCs.insert(cast<Constant>(U.get()));
86  if (!c.second)
87  continue;
88  }
89  Uses.push_back(std::make_pair(&U, &F));
90  }
91  }
92 }
93 
94 // Create a wrapper function with type Ty that calls F (which may have a
95 // different type). Attempt to support common bitcasted function idioms:
96 // - Call with more arguments than needed: arguments are dropped
97 // - Call with fewer arguments than needed: arguments are filled in with undef
98 // - Return value is not needed: drop it
99 // - Return value needed but not present: supply an undef
100 //
101 // If the all the argument types of trivially castable to one another (i.e.
102 // I32 vs pointer type) then we don't create a wrapper at all (return nullptr
103 // instead).
104 //
105 // If there is a type mismatch that we know would result in an invalid wasm
106 // module then generate wrapper that contains unreachable (i.e. abort at
107 // runtime). Such programs are deep into undefined behaviour territory,
108 // but we choose to fail at runtime rather than generate and invalid module
109 // or fail at compiler time. The reason we delay the error is that we want
110 // to support the CMake which expects to be able to compile and link programs
111 // that refer to functions with entirely incorrect signatures (this is how
112 // CMake detects the existence of a function in a toolchain).
113 //
114 // For bitcasts that involve struct types we don't know at this stage if they
115 // would be equivalent at the wasm level and so we can't know if we need to
116 // generate a wrapper.
118  Module *M = F->getParent();
119 
121  F->getName() + "_bitcast", M);
122  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
123  const DataLayout &DL = BB->getModule()->getDataLayout();
124 
125  // Determine what arguments to pass.
127  Function::arg_iterator AI = Wrapper->arg_begin();
128  Function::arg_iterator AE = Wrapper->arg_end();
131  bool TypeMismatch = false;
132  bool WrapperNeeded = false;
133 
134  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
135  Type *RtnType = Ty->getReturnType();
136 
137  if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
138  (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
139  (ExpectedRtnType != RtnType))
140  WrapperNeeded = true;
141 
142  for (; AI != AE && PI != PE; ++AI, ++PI) {
143  Type *ArgType = AI->getType();
144  Type *ParamType = *PI;
145 
146  if (ArgType == ParamType) {
147  Args.push_back(&*AI);
148  } else {
149  if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
150  Instruction *PtrCast =
151  CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
152  BB->getInstList().push_back(PtrCast);
153  Args.push_back(PtrCast);
154  } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
155  LLVM_DEBUG(dbgs() << "CreateWrapper: struct param type in bitcast: "
156  << F->getName() << "\n");
157  WrapperNeeded = false;
158  } else {
159  LLVM_DEBUG(dbgs() << "CreateWrapper: arg type mismatch calling: "
160  << F->getName() << "\n");
161  LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
162  << *ParamType << " Got: " << *ArgType << "\n");
163  TypeMismatch = true;
164  break;
165  }
166  }
167  }
168 
169  if (WrapperNeeded && !TypeMismatch) {
170  for (; PI != PE; ++PI)
171  Args.push_back(UndefValue::get(*PI));
172  if (F->isVarArg())
173  for (; AI != AE; ++AI)
174  Args.push_back(&*AI);
175 
176  CallInst *Call = CallInst::Create(F, Args, "", BB);
177 
178  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
179  Type *RtnType = Ty->getReturnType();
180  // Determine what value to return.
181  if (RtnType->isVoidTy()) {
182  ReturnInst::Create(M->getContext(), BB);
183  } else if (ExpectedRtnType->isVoidTy()) {
184  LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
185  ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
186  } else if (RtnType == ExpectedRtnType) {
187  ReturnInst::Create(M->getContext(), Call, BB);
188  } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
189  DL)) {
190  Instruction *Cast =
191  CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
192  BB->getInstList().push_back(Cast);
193  ReturnInst::Create(M->getContext(), Cast, BB);
194  } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
195  LLVM_DEBUG(dbgs() << "CreateWrapper: struct return type in bitcast: "
196  << F->getName() << "\n");
197  WrapperNeeded = false;
198  } else {
199  LLVM_DEBUG(dbgs() << "CreateWrapper: return type mismatch calling: "
200  << F->getName() << "\n");
201  LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
202  << " Got: " << *RtnType << "\n");
203  TypeMismatch = true;
204  }
205  }
206 
207  if (TypeMismatch) {
208  // Create a new wrapper that simply contains `unreachable`.
209  Wrapper->eraseFromParent();
211  F->getName() + "_bitcast_invalid", M);
212  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
213  new UnreachableInst(M->getContext(), BB);
214  Wrapper->setName(F->getName() + "_bitcast_invalid");
215  } else if (!WrapperNeeded) {
216  LLVM_DEBUG(dbgs() << "CreateWrapper: no wrapper needed: " << F->getName()
217  << "\n");
218  Wrapper->eraseFromParent();
219  return nullptr;
220  }
221  LLVM_DEBUG(dbgs() << "CreateWrapper: " << F->getName() << "\n");
222  return Wrapper;
223 }
224 
225 // Test whether a main function with type FuncTy should be rewritten to have
226 // type MainTy.
228  // Only fix the main function if it's the standard zero-arg form. That way,
229  // the standard cases will work as expected, and users will see signature
230  // mismatches from the linker for non-standard cases.
231  return FuncTy->getReturnType() == MainTy->getReturnType() &&
232  FuncTy->getNumParams() == 0 &&
233  !FuncTy->isVarArg();
234 }
235 
236 bool FixFunctionBitcasts::runOnModule(Module &M) {
237  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");
238 
239  Function *Main = nullptr;
240  CallInst *CallMain = nullptr;
242  SmallPtrSet<Constant *, 2> ConstantBCs;
243 
244  // Collect all the places that need wrappers.
245  for (Function &F : M) {
246  FindUses(&F, F, Uses, ConstantBCs);
247 
248  // If we have a "main" function, and its type isn't
249  // "int main(int argc, char *argv[])", create an artificial call with it
250  // bitcasted to that type so that we generate a wrapper for it, so that
251  // the C runtime can call it.
252  if (F.getName() == "main") {
253  Main = &F;
254  LLVMContext &C = M.getContext();
255  Type *MainArgTys[] = {Type::getInt32Ty(C),
257  FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
258  /*isVarArg=*/false);
259  if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
260  LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
261  << *F.getFunctionType() << "\n");
262  Value *Args[] = {UndefValue::get(MainArgTys[0]),
263  UndefValue::get(MainArgTys[1])};
264  Value *Casted =
265  ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
266  CallMain = CallInst::Create(Casted, Args, "call_main");
267  Use *UseMain = &CallMain->getOperandUse(2);
268  Uses.push_back(std::make_pair(UseMain, &F));
269  }
270  }
271  }
272 
274 
275  for (auto &UseFunc : Uses) {
276  Use *U = UseFunc.first;
277  Function *F = UseFunc.second;
278  PointerType *PTy = cast<PointerType>(U->get()->getType());
280 
281  // If the function is casted to something like i8* as a "generic pointer"
282  // to be later casted to something else, we can't generate a wrapper for it.
283  // Just ignore such casts for now.
284  if (!Ty)
285  continue;
286 
287  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
288  if (Pair.second)
289  Pair.first->second = CreateWrapper(F, Ty);
290 
291  Function *Wrapper = Pair.first->second;
292  if (!Wrapper)
293  continue;
294 
295  if (isa<Constant>(U->get()))
296  U->get()->replaceAllUsesWith(Wrapper);
297  else
298  U->set(Wrapper);
299  }
300 
301  // If we created a wrapper for main, rename the wrapper so that it's the
302  // one that gets called from startup.
303  if (CallMain) {
304  Main->setName("__original_main");
305  Function *MainWrapper =
306  cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
307  delete CallMain;
308  if (Main->isDeclaration()) {
309  // The wrapper is not needed in this case as we don't need to export
310  // it to anyone else.
311  MainWrapper->eraseFromParent();
312  } else {
313  // Otherwise give the wrapper the same linkage as the original main
314  // function, so that it can be called from the same places.
315  MainWrapper->setName("main");
316  MainWrapper->setLinkage(Main->getLinkage());
317  MainWrapper->setVisibility(Main->getVisibility());
318  }
319  }
320 
321  return true;
322 }
void setVisibility(VisibilityTypes V)
Definition: GlobalValue.h:239
bool isVarArg() const
isVarArg - Return true if this function takes a variable number of arguments.
Definition: Function.h:177
uint64_t CallInst * C
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:111
iterator_range< use_iterator > uses()
Definition: Value.h:355
This class represents an incoming formal argument to a Function.
Definition: Argument.h:30
bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy)
This class represents lattice values for constants.
Definition: AllocatorList.h:24
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:65
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
This class represents a function call, abstracting a target machine's calling convention.
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space...
Definition: Type.cpp:630
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:57
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:705
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", Instruction *InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.
const Use & getOperandUse(unsigned i) const
Definition: User.h:183
arg_iterator arg_end()
Definition: Function.h:680
F(f)
param_iterator param_end() const
Definition: DerivedTypes.h:129
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end...
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op...
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:344
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:221
amdgpu aa AMDGPU Address space based Alias Analysis Wrapper
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:244
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:42
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:285
Class to represent function types.
Definition: DerivedTypes.h:103
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:92
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:245
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE, "Fix mismatching bitcasts for WebAssembly", false, false) ModulePass *llvm
bool isVarArg() const
Definition: DerivedTypes.h:123
LinkageTypes getLinkage() const
Definition: GlobalValue.h:451
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
amdgpu Simplify well known AMD library false Value * Callee
Class to represent pointers.
Definition: DerivedTypes.h:467
static Function * CreateWrapper(Function *F, FunctionType *Ty)
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:1773
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:141
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:136
VisibilityTypes getVisibility() const
Definition: GlobalValue.h:233
Value * getCalledValue() const
Definition: InstrTypes.h:1174
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:69
This function has undefined behavior.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:139
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:371
param_iterator param_begin() const
Definition: DerivedTypes.h:128
Represent the analysis usage information of a pass.
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:297
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:100
arg_iterator arg_begin()
Definition: Function.h:671
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1415
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs, and aliases.
Definition: Value.cpp:529
size_t size() const
Definition: SmallVector.h:53
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
Definition: Type.cpp:220
ModulePass * createWebAssemblyFixFunctionBitcasts()
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
This is a &#39;vector&#39; (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:847
Module.h This file contains the declarations for the Module class.
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:126
Type * getReturnType() const
Definition: DerivedTypes.h:124
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:286
void setLinkage(LinkageTypes LT)
Definition: GlobalValue.h:445
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:133
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:164
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:176
static void FindUses(Value *V, Function &F, SmallVectorImpl< std::pair< Use *, Function *>> &Uses, SmallPtrSetImpl< Constant *> &ConstantBCs)
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:214
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:225
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:214
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:206
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:566
LLVM Value Representation.
Definition: Value.h:73
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
#define LLVM_DEBUG(X)
Definition: Debug.h:123
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
Type * getElementType() const
Definition: DerivedTypes.h:486
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:274
bool isStructTy() const
True if this is an instance of StructType.
Definition: Type.h:218