Skip to content

Commit af3c3ec

Browse files
authored
[AArch64] recognise trn1/trn2 with flipped operands (#169858)
This PR is very similar to #167235, but applied to `trn` rather than `zip`. There are two further differences: - The `@combine_v8i16_8first` and `@combine_v8i16_8firstundef` test cases in `arm64-zip.ll` didn't have equivalents in `arm64-trn.ll`, so this PR adds new test cases `@vtrni8_8first`, `@vtrni8_9first`, `@vtrni8_89first_undef`. - `AArch64TTIImpl::getShuffleCost` calls `isZIPMask`, but not `isTRNMask`. It relies on `Kind == TTI::SK_Transpose` instead (which in turn is based on `ShuffleVectorInst::isTransposeMask` through `improveShuffleKindFromMask`). Therefore, this PR does not itself influence the slp-vectorizer. In a follow-up PR, I intend to override `AArch64TTIImpl::improveShuffleKindFromMask` to ensure we get `ShuffleKind::SK_Transpose` based on the new `isTRNMask`. In fact, that follow-up change is the actual motivation for this PR, as it will result in ```C++ int8x16_t g(int8_t x) { return (int8x16_t) { 0, x, 1, x, 2, x, 3, x, 4, x, 5, x, 6, x, 7, x }; } ``` from #137447 being optimised by the slp-vectorizer.
1 parent 8a115b6 commit af3c3ec

File tree

6 files changed

+202
-123
lines changed

6 files changed

+202
-123
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14965,9 +14965,10 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
1496514965
unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
1496614966
return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2);
1496714967
}
14968-
if (isTRNMask(ShuffleMask, NumElts, WhichResult)) {
14968+
if (isTRNMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) {
1496914969
unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
14970-
return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2);
14970+
return DAG.getNode(Opc, DL, V1.getValueType(), OperandOrder == 0 ? V1 : V2,
14971+
OperandOrder == 0 ? V2 : V1);
1497114972
}
1497214973

1497314974
if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
@@ -16679,7 +16680,7 @@ bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
1667916680
isREVMask(M, EltSize, NumElts, 16) ||
1668016681
isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
1668116682
isSingletonEXTMask(M, VT, DummyUnsigned) ||
16682-
isTRNMask(M, NumElts, DummyUnsigned) ||
16683+
isTRNMask(M, NumElts, DummyUnsigned, DummyUnsigned) ||
1668316684
isUZPMask(M, NumElts, DummyUnsigned) ||
1668416685
isZIPMask(M, NumElts, DummyUnsigned, DummyUnsigned) ||
1668516686
isTRN_v_undef_Mask(M, VT, DummyUnsigned) ||
@@ -31798,10 +31799,13 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
3179831799
OperandOrder == 0 ? Op1 : Op2,
3179931800
OperandOrder == 0 ? Op2 : Op1));
3180031801

31801-
if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
31802+
if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult,
31803+
OperandOrder)) {
3180231804
unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
31803-
return convertFromScalableVector(
31804-
DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
31805+
SDValue TRN =
31806+
DAG.getNode(Opc, DL, ContainerVT, OperandOrder == 0 ? Op1 : Op2,
31807+
OperandOrder == 0 ? Op2 : Op1);
31808+
return convertFromScalableVector(DAG, VT, TRN);
3180531809
}
3180631810

3180731811
if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult == 0)

llvm/lib/Target/AArch64/AArch64PerfectShuffle.h

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6699,33 +6699,53 @@ inline bool isUZPMask(ArrayRef<int> M, unsigned NumElts,
66996699
}
67006700

67016701
/// Return true for trn1 or trn2 masks of the form:
6702-
/// <0, 8, 2, 10, 4, 12, 6, 14> or
6703-
/// <1, 9, 3, 11, 5, 13, 7, 15>
6702+
/// <0, 8, 2, 10, 4, 12, 6, 14> (WhichResultOut = 0, OperandOrderOut = 0) or
6703+
/// <1, 9, 3, 11, 5, 13, 7, 15> (WhichResultOut = 1, OperandOrderOut = 0) or
6704+
/// <8, 0, 10, 2, 12, 4, 14, 6> (WhichResultOut = 0, OperandOrderOut = 1) or
6705+
/// <9, 1, 11, 3, 13, 5, 15, 7> (WhichResultOut = 1, OperandOrderOut = 1) or
67046706
inline bool isTRNMask(ArrayRef<int> M, unsigned NumElts,
6705-
unsigned &WhichResultOut) {
6707+
unsigned &WhichResultOut, unsigned &OperandOrderOut) {
67066708
if (NumElts % 2 != 0)
67076709
return false;
6708-
// Check the first non-undef element for trn1 vs trn2.
6709-
unsigned WhichResult = 2;
6710+
6711+
// "Result" corresponds to "WhichResultOut", selecting between trn1 and trn2.
6712+
// "Order" corresponds to "OperandOrderOut", selecting the order of operands
6713+
// for the instruction (flipped or not).
6714+
bool Result0Order0 = true; // WhichResultOut = 0, OperandOrderOut = 0
6715+
bool Result1Order0 = true; // WhichResultOut = 1, OperandOrderOut = 0
6716+
bool Result0Order1 = true; // WhichResultOut = 0, OperandOrderOut = 1
6717+
bool Result1Order1 = true; // WhichResultOut = 1, OperandOrderOut = 1
6718+
// Check all elements match.
67106719
for (unsigned i = 0; i != NumElts; i += 2) {
67116720
if (M[i] >= 0) {
6712-
WhichResult = ((unsigned)M[i] == i ? 0 : 1);
6713-
break;
6721+
unsigned EvenElt = (unsigned)M[i];
6722+
if (EvenElt != i)
6723+
Result0Order0 = false;
6724+
if (EvenElt != i + 1)
6725+
Result1Order0 = false;
6726+
if (EvenElt != NumElts + i)
6727+
Result0Order1 = false;
6728+
if (EvenElt != NumElts + i + 1)
6729+
Result1Order1 = false;
67146730
}
67156731
if (M[i + 1] >= 0) {
6716-
WhichResult = ((unsigned)M[i + 1] == i + NumElts ? 0 : 1);
6717-
break;
6732+
unsigned OddElt = (unsigned)M[i + 1];
6733+
if (OddElt != NumElts + i)
6734+
Result0Order0 = false;
6735+
if (OddElt != NumElts + i + 1)
6736+
Result1Order0 = false;
6737+
if (OddElt != i)
6738+
Result0Order1 = false;
6739+
if (OddElt != i + 1)
6740+
Result1Order1 = false;
67186741
}
67196742
}
6720-
if (WhichResult == 2)
6743+
6744+
if (Result0Order0 + Result1Order0 + Result0Order1 + Result1Order1 != 1)
67216745
return false;
67226746

6723-
for (unsigned i = 0; i < NumElts; i += 2) {
6724-
if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) ||
6725-
(M[i + 1] >= 0 && (unsigned)M[i + 1] != i + NumElts + WhichResult))
6726-
return false;
6727-
}
6728-
WhichResultOut = WhichResult;
6747+
WhichResultOut = (Result0Order0 || Result0Order1) ? 0 : 1;
6748+
OperandOrderOut = (Result0Order0 || Result1Order0) ? 0 : 1;
67296749
return true;
67306750
}
67316751

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,15 @@ bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
215215
ShuffleVectorPseudo &MatchInfo) {
216216
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
217217
unsigned WhichResult;
218+
unsigned OperandOrder;
218219
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
219220
Register Dst = MI.getOperand(0).getReg();
220221
unsigned NumElts = MRI.getType(Dst).getNumElements();
221-
if (!isTRNMask(ShuffleMask, NumElts, WhichResult))
222+
if (!isTRNMask(ShuffleMask, NumElts, WhichResult, OperandOrder))
222223
return false;
223224
unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
224-
Register V1 = MI.getOperand(1).getReg();
225-
Register V2 = MI.getOperand(2).getReg();
225+
Register V1 = MI.getOperand(OperandOrder == 0 ? 1 : 2).getReg();
226+
Register V2 = MI.getOperand(OperandOrder == 0 ? 2 : 1).getReg();
226227
MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
227228
return true;
228229
}

llvm/test/CodeGen/AArch64/arm64-trn.ll

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,63 @@ define <4 x float> @vtrnQf(ptr %A, ptr %B) nounwind {
246246
ret <4 x float> %tmp5
247247
}
248248

249+
define <8 x i8> @vtrni8_trn1_flipped(<8 x i8> %A, <8 x i8> %B) nounwind {
250+
; CHECKLE-LABEL: vtrni8_trn1_flipped:
251+
; CHECKLE: // %bb.0:
252+
; CHECKLE-NEXT: trn1 v0.8b, v1.8b, v0.8b
253+
; CHECKLE-NEXT: ret
254+
;
255+
; CHECKBE-LABEL: vtrni8_trn1_flipped:
256+
; CHECKBE: // %bb.0:
257+
; CHECKBE-NEXT: rev64 v0.8b, v0.8b
258+
; CHECKBE-NEXT: rev64 v1.8b, v1.8b
259+
; CHECKBE-NEXT: trn1 v0.8b, v1.8b, v0.8b
260+
; CHECKBE-NEXT: rev64 v0.8b, v0.8b
261+
; CHECKBE-NEXT: ret
262+
%tmp1 = shufflevector <8 x i8> %A, <8 x i8> %B, <8 x i32> <i32 8, i32 0, i32 10, i32 2, i32 12, i32 4, i32 14, i32 6>
263+
ret <8 x i8> %tmp1
264+
}
265+
266+
define <8 x i8> @vtrni8_trn2_flipped(<8 x i8> %A, <8 x i8> %B) nounwind {
267+
; CHECKLE-LABEL: vtrni8_trn2_flipped:
268+
; CHECKLE: // %bb.0:
269+
; CHECKLE-NEXT: trn2 v0.8b, v1.8b, v0.8b
270+
; CHECKLE-NEXT: ret
271+
;
272+
; CHECKBE-LABEL: vtrni8_trn2_flipped:
273+
; CHECKBE: // %bb.0:
274+
; CHECKBE-NEXT: rev64 v0.8b, v0.8b
275+
; CHECKBE-NEXT: rev64 v1.8b, v1.8b
276+
; CHECKBE-NEXT: trn2 v0.8b, v1.8b, v0.8b
277+
; CHECKBE-NEXT: rev64 v0.8b, v0.8b
278+
; CHECKBE-NEXT: ret
279+
%tmp1 = shufflevector <8 x i8> %A, <8 x i8> %B, <8 x i32> <i32 9, i32 1, i32 11, i32 3, i32 13, i32 5, i32 15, i32 7>
280+
ret <8 x i8> %tmp1
281+
}
282+
283+
define <8 x i8> @vtrni8_both_flipped_with_poison_values(<8 x i8> %A, <8 x i8> %B) nounwind {
284+
; CHECKLE-LABEL: vtrni8_both_flipped_with_poison_values:
285+
; CHECKLE: // %bb.0:
286+
; CHECKLE-NEXT: trn1 v2.8b, v1.8b, v0.8b
287+
; CHECKLE-NEXT: trn2 v0.8b, v1.8b, v0.8b
288+
; CHECKLE-NEXT: add v0.8b, v2.8b, v0.8b
289+
; CHECKLE-NEXT: ret
290+
;
291+
; CHECKBE-LABEL: vtrni8_both_flipped_with_poison_values:
292+
; CHECKBE: // %bb.0:
293+
; CHECKBE-NEXT: rev64 v0.8b, v0.8b
294+
; CHECKBE-NEXT: rev64 v1.8b, v1.8b
295+
; CHECKBE-NEXT: trn1 v2.8b, v1.8b, v0.8b
296+
; CHECKBE-NEXT: trn2 v0.8b, v1.8b, v0.8b
297+
; CHECKBE-NEXT: add v0.8b, v2.8b, v0.8b
298+
; CHECKBE-NEXT: rev64 v0.8b, v0.8b
299+
; CHECKBE-NEXT: ret
300+
%tmp1 = shufflevector <8 x i8> %A, <8 x i8> %B, <8 x i32> <i32 poison, i32 0, i32 poison, i32 2, i32 poison, i32 4, i32 14, i32 6>
301+
%tmp2 = shufflevector <8 x i8> %A, <8 x i8> %B, <8 x i32> <i32 poison, i32 1, i32 poison, i32 3, i32 13, i32 5, i32 15, i32 poison>
302+
%tmp3 = add <8 x i8> %tmp1, %tmp2
303+
ret <8 x i8> %tmp3
304+
}
305+
249306
; Undef shuffle indices (even at the start of the shuffle mask) should not prevent matching to VTRN:
250307

251308
define <8 x i8> @vtrni8_undef(ptr %A, ptr %B) nounwind {

llvm/test/CodeGen/AArch64/fixed-vector-deinterleave.ll

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@ define {<2 x half>, <2 x half>} @vector_deinterleave_v2f16_v4f16(<4 x half> %vec
66
; CHECK-SD-LABEL: vector_deinterleave_v2f16_v4f16:
77
; CHECK-SD: // %bb.0:
88
; CHECK-SD-NEXT: // kill: def $d0 killed $d0 def $q0
9-
; CHECK-SD-NEXT: dup v2.2s, v0.s[1]
10-
; CHECK-SD-NEXT: mov v1.16b, v2.16b
11-
; CHECK-SD-NEXT: zip1 v2.4h, v0.4h, v2.4h
12-
; CHECK-SD-NEXT: mov v1.h[0], v0.h[1]
9+
; CHECK-SD-NEXT: dup v1.2s, v0.s[1]
10+
; CHECK-SD-NEXT: zip1 v2.4h, v0.4h, v1.4h
11+
; CHECK-SD-NEXT: trn2 v1.4h, v0.4h, v1.4h
1312
; CHECK-SD-NEXT: fmov d0, d2
14-
; CHECK-SD-NEXT: // kill: def $d1 killed $d1 killed $q1
1513
; CHECK-SD-NEXT: ret
1614
;
1715
; CHECK-GI-LABEL: vector_deinterleave_v2f16_v4f16:

0 commit comments

Comments
 (0)