Skip to content

Commit d5bc359

Browse files
committed
X86: add patterns for X86ISD::VSHLV and X86ISD::VSRLV
Replace a VSELECT that zeroes its result when the SHL/SRL shift amount is out of the legal range with the corresponding X86ISD::VSHLV/VSRLV node.
1 parent bed7001 commit d5bc359

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

‎lib/Target/X86/X86ISelLowering.cpp‎

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42254,6 +42254,38 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4225442254
}
4225542255
}
4225642256

42257+
// Detect the pattern for AVX2+ variable shifts (shl, lshr) with "infinite precision" semantics, i.e. an out-of-range shift amount yields zero.
42258+
if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC &&
42259+
SupportedVectorVarShift(VT.getSimpleVT(), Subtarget, ISD::SHL)) {
42260+
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
42261+
42262+
// Check if one of the arms of the VSELECT is a zero vector. If it's on the
42263+
// left side invert the predicate to simplify logic below.
42264+
SDValue Other;
42265+
if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
42266+
Other = RHS;
42267+
CC = ISD::getSetCCInverse(CC, VT.getVectorElementType());
42268+
} else if (ISD::isBuildVectorAllZeros(RHS.getNode())) {
42269+
Other = LHS;
42270+
}
42271+
42272+
// Look for the following patterns (>> becomes vsrlv):
42273+
// y < 32 ? x << y : 0 --> vshlv(x, y)
42274+
// y <= 31 ? x << y : 0 --> vshlv(x, y)
42275+
APInt CondRHS;
42276+
if (Other && Other.getNumOperands() == 2 &&
42277+
DAG.isEqualTo(Other.getOperand(1), Cond.getOperand(0)) &&
42278+
(Other.getOpcode() == ISD::SHL || Other.getOpcode() == ISD::SRL) &&
42279+
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), CondRHS)) {
42280+
42281+
// Replace ISD::SHL or ISD::SRL with the appropriate AVX2 vector-vector shift.
42282+
unsigned op = Other.getOpcode() == ISD::SHL ? X86ISD::VSHLV : X86ISD::VSRLV;
42283+
if ((CC == ISD::SETULT && CondRHS == VT.getScalarSizeInBits()) ||
42284+
(CC == ISD::SETULE && CondRHS == VT.getScalarSizeInBits() - 1))
42285+
return DAG.getNode(op, DL, VT, Other.getOperand(0), Other.getOperand(1));
42286+
}
42287+
}
42288+
4225742289
if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget))
4225842290
return V;
4225942291

0 commit comments

Comments
 (0)