@@ -42254,6 +42254,38 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4225442254 }
4225542255 }
4225642256
42257+ // Detect pattern for AVX2+ variable shifts (shl, lshr) for inf precision.
42258+ if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC &&
42259+ SupportedVectorVarShift(VT.getSimpleVT(), Subtarget, ISD::SHL)) {
42260+ ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
42261+
42262+ // Check if one of the arms of the VSELECT is a zero vector. If it's on the
42263+ // left side invert the predicate to simplify logic below.
42264+ SDValue Other;
42265+ if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
42266+ Other = RHS;
42267+ CC = ISD::getSetCCInverse(CC, VT.getVectorElementType());
42268+ } else if (ISD::isBuildVectorAllZeros(RHS.getNode())) {
42269+ Other = LHS;
42270+ }
42271+
42272+ // Look for the following patterns (>> becomes vsrlv):
42273+ // y < 32 ? x << y : 0 --> vshlv(x, y)
42274+ // y <= 31 ? x << y : 0 --> vshlv(x, y)
42275+ APInt CondRHS;
42276+ if (Other && Other.getNumOperands() == 2 &&
42277+ DAG.isEqualTo(Other.getOperand(1), Cond.getOperand(0)) &&
42278+ (Other.getOpcode() == ISD::SHL || Other.getOpcode() == ISD::SRL) &&
42279+ ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), CondRHS)) {
42280+
42281+ // Replace ISD::SHL or ISD::SHR with appropriate AVX2 vector-vector shift.
42282+ unsigned op = Other.getOpcode() == ISD::SHL ? X86ISD::VSHLV : X86ISD::VSRLV;
42283+ if ((CC == ISD::SETULT && CondRHS == VT.getScalarSizeInBits()) ||
42284+ (CC == ISD::SETULE && CondRHS == VT.getScalarSizeInBits() - 1))
42285+ return DAG.getNode(op, DL, VT, Other.getOperand(0), Other.getOperand(1));
42286+ }
42287+ }
42288+
4225742289 if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget))
4225842290 return V;
4225942291
0 commit comments