diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp | 79 |
1 files changed, 52 insertions, 27 deletions
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index b6713730bfa9..a82be5b973cf 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -256,6 +256,12 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", false, false) +static Type *getPtrOrVecOfPtrsWithNewAS(Type *Ty, unsigned NewAddrSpace) { + assert(Ty->isPtrOrPtrVectorTy()); + PointerType *NPT = PointerType::get(Ty->getContext(), NewAddrSpace); + return Ty->getWithNewType(NPT); +} + // Check whether that's no-op pointer bicast using a pair of // `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over // different address spaces. @@ -301,14 +307,14 @@ static bool isAddressExpression(const Value &V, const DataLayout &DL, switch (Op->getOpcode()) { case Instruction::PHI: - assert(Op->getType()->isPointerTy()); + assert(Op->getType()->isPtrOrPtrVectorTy()); return true; case Instruction::BitCast: case Instruction::AddrSpaceCast: case Instruction::GetElementPtr: return true; case Instruction::Select: - return Op->getType()->isPointerTy(); + return Op->getType()->isPtrOrPtrVectorTy(); case Instruction::Call: { const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V); return II && II->getIntrinsicID() == Intrinsic::ptrmask; @@ -373,6 +379,24 @@ bool InferAddressSpacesImpl::rewriteIntrinsicOperands(IntrinsicInst *II, case Intrinsic::ptrmask: // This is handled as an address expression, not as a use memory operation. return false; + case Intrinsic::masked_gather: { + Type *RetTy = II->getType(); + Type *NewPtrTy = NewV->getType(); + Function *NewDecl = + Intrinsic::getDeclaration(M, II->getIntrinsicID(), {RetTy, NewPtrTy}); + II->setArgOperand(0, NewV); + II->setCalledFunction(NewDecl); + return true; + } + case Intrinsic::masked_scatter: { + Type *ValueTy = II->getOperand(0)->getType(); + Type *NewPtrTy = NewV->getType(); + Function *NewDecl = + Intrinsic::getDeclaration(M, II->getIntrinsicID(), {ValueTy, NewPtrTy}); + II->setArgOperand(1, NewV); + II->setCalledFunction(NewDecl); + return true; + } default: { Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); if (!Rewrite) @@ -394,6 +418,14 @@ void InferAddressSpacesImpl::collectRewritableIntrinsicOperands( appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), PostorderStack, Visited); break; + case Intrinsic::masked_gather: + appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), + PostorderStack, Visited); + break; + case Intrinsic::masked_scatter: + appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(1), + PostorderStack, Visited); + break; default: SmallVector<int, 2> OpIndexes; if (TTI->collectFlatAddressOperands(OpIndexes, IID)) { @@ -412,7 +444,7 @@ void InferAddressSpacesImpl::collectRewritableIntrinsicOperands( void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack( Value *V, PostorderStackTy &PostorderStack, DenseSet<Value *> &Visited) const { - assert(V->getType()->isPointerTy()); + assert(V->getType()->isPtrOrPtrVectorTy()); // Generic addressing expressions may be hidden in nested constant // expressions. @@ -460,8 +492,7 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { // addressing calculations may also be faster. for (Instruction &I : instructions(F)) { if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { - if (!GEP->getType()->isVectorTy()) - PushPtrOperand(GEP->getPointerOperand()); + PushPtrOperand(GEP->getPointerOperand()); } else if (auto *LI = dyn_cast<LoadInst>(&I)) PushPtrOperand(LI->getPointerOperand()); else if (auto *SI = dyn_cast<StoreInst>(&I)) @@ -480,14 +511,12 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { } else if (auto *II = dyn_cast<IntrinsicInst>(&I)) collectRewritableIntrinsicOperands(II, PostorderStack, Visited); else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) { - // FIXME: Handle vectors of pointers - if (Cmp->getOperand(0)->getType()->isPointerTy()) { + if (Cmp->getOperand(0)->getType()->isPtrOrPtrVectorTy()) { PushPtrOperand(Cmp->getOperand(0)); PushPtrOperand(Cmp->getOperand(1)); } } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) { - if (!ASC->getType()->isVectorTy()) - PushPtrOperand(ASC->getPointerOperand()); + PushPtrOperand(ASC->getPointerOperand()); } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) { if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI)) PushPtrOperand( @@ -529,8 +558,7 @@ static Value *operandWithNewAddressSpaceOrCreateUndef( SmallVectorImpl<const Use *> *UndefUsesToFix) { Value *Operand = OperandUse.get(); - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast<PointerType>(Operand->getType()), NewAddrSpace); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAddrSpace); if (Constant *C = dyn_cast<Constant>(Operand)) return ConstantExpr::getAddrSpaceCast(C, NewPtrTy); @@ -543,8 +571,7 @@ static Value *operandWithNewAddressSpaceOrCreateUndef( if (I != PredicatedAS.end()) { // Insert an addrspacecast on that operand before the user. unsigned NewAS = I->second; - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast<PointerType>(Operand->getType()), NewAS); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS); auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy); NewI->insertBefore(Inst); NewI->setDebugLoc(Inst->getDebugLoc()); @@ -572,8 +599,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( const ValueToValueMapTy &ValueWithNewAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, SmallVectorImpl<const Use *> *UndefUsesToFix) const { - Type *NewPtrType = PointerType::getWithSamePointeeType( - cast<PointerType>(I->getType()), NewAddrSpace); + Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace); if (I->getOpcode() == Instruction::AddrSpaceCast) { Value *Src = I->getOperand(0); @@ -607,8 +633,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( if (AS != UninitializedAddressSpace) { // For the assumed address space, insert an `addrspacecast` to make that // explicit. - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast<PointerType>(I->getType()), AS); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS); auto *NewI = new AddrSpaceCastInst(I, NewPtrTy); NewI->insertAfter(I); return NewI; @@ -617,7 +642,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( // Computes the converted pointer operands. SmallVector<Value *, 4> NewPointerOperands; for (const Use &OperandUse : I->operands()) { - if (!OperandUse.get()->getType()->isPointerTy()) + if (!OperandUse.get()->getType()->isPtrOrPtrVectorTy()) NewPointerOperands.push_back(nullptr); else NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef( @@ -629,7 +654,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( case Instruction::BitCast: return new BitCastInst(NewPointerOperands[0], NewPtrType); case Instruction::PHI: { - assert(I->getType()->isPointerTy()); + assert(I->getType()->isPtrOrPtrVectorTy()); PHINode *PHI = cast<PHINode>(I); PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues()); for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) { @@ -648,7 +673,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( return NewGEP; } case Instruction::Select: - assert(I->getType()->isPointerTy()); + assert(I->getType()->isPtrOrPtrVectorTy()); return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], NewPointerOperands[2], "", nullptr, I); case Instruction::IntToPtr: { @@ -674,10 +699,10 @@ static Value *cloneConstantExprWithNewAddressSpace( ConstantExpr *CE, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL, const TargetTransformInfo *TTI) { - Type *TargetType = CE->getType()->isPointerTy() - ? PointerType::getWithSamePointeeType( - cast<PointerType>(CE->getType()), NewAddrSpace) - : CE->getType(); + Type *TargetType = + CE->getType()->isPtrOrPtrVectorTy() + ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace) + : CE->getType(); if (CE->getOpcode() == Instruction::AddrSpaceCast) { // Because CE is flat, the source address space must be specific. @@ -1226,9 +1251,9 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) { unsigned NewAS = NewV->getType()->getPointerAddressSpace(); if (ASC->getDestAddressSpace() == NewAS) { - if (!cast<PointerType>(ASC->getType()) - ->hasSameElementTypeAs( - cast<PointerType>(NewV->getType()))) { + if (!cast<PointerType>(ASC->getType()->getScalarType()) + ->hasSameElementTypeAs( + cast<PointerType>(NewV->getType()->getScalarType()))) { BasicBlock::iterator InsertPos; if (Instruction *NewVInst = dyn_cast<Instruction>(NewV)) InsertPos = std::next(NewVInst->getIterator()); |