[Mlir-commits] [mlir] [mlir][Vector] Replace `vector.shuffle` with `vector.interleave` in vector narrow type emulation (PR #82550)
Diego Caballero
llvmlistbot at llvm.org
Wed Feb 21 16:56:26 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/82550
>From de8a47ae91437a53b483ba72b04c9660df06fe37 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 21 Feb 2024 23:03:21 +0000
Subject: [PATCH 1/3] [mlir][Vector] Replace `vector.shuffle` with
`vector.interleave` in vector narrow type emulation
This PR replaces the generation of `vector.shuffle` with
`vector.interleave` in the i4 conversions in vector narrow type
emulation. The multi-dimensional semantics of `vector.interleave` allow
us to enable these conversion emulations for multi-dimensional vectors
as well. The rank-1 restriction is now enforced only by
`BitCastRewriter::commonPrecondition`.
---
.../Transforms/VectorEmulateNarrowType.cpp | 17 ++--
.../Vector/vector-rewrite-narrow-types.mlir | 82 +++++++++++++------
2 files changed, 67 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb66708407b4..9ebe36cd3861e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
VectorType preconditionType,
Operation *op) {
- if (!preconditionType || preconditionType.getRank() != 1 ||
- preconditionType.isScalable())
- return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+ if (!preconditionType || preconditionType.isScalable())
+ return rewriter.notifyMatchFailure(op, "scalable vector");
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
return rewriter.notifyMatchFailure(op, "types are not vector");
+ if (!preconditionType || preconditionType.getRank() != 1)
+ return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
return commonConversionPrecondition(rewriter, preconditionType, op);
}
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
interleaveMaskValues.push_back(i + (vecDimSize / 2));
}
- return rewriter.create<vector::ShuffleOp>(
- loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+ return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
namespace {
@@ -1008,8 +1009,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1018,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
template <typename ConversionOpType>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a81664b81..94e78ce40a3c19 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
return %0 : vector<8xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_extsi_base_case(
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK-NOT: arith.extsi
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
return %0 : vector<8xi8>
}
// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: shuffle
- // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+ %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+ return %0 : vector<8x32xf32>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK: %[[EXT:.*]] = vector.interleave
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
return %0 : vector<16x8xi4>
}
// CHECK-LABEL: func.func @i7_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
return %0 : vector<16x8xi7>
}
>From 87f772ff7352a47018849c39ac05175309a190e5 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 00:10:27 +0000
Subject: [PATCH 2/3] Address review feedback: drop the unused shuffle mask
---
.../Vector/Transforms/VectorEmulateNarrowType.cpp | 9 +--------
1 file changed, 1 insertion(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 9ebe36cd3861e0..41a778b8496ef0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -873,14 +873,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
- // 3. Interleave low and high i8 elements using a shuffle.
- SmallVector<int64_t> interleaveMaskValues;
- interleaveMaskValues.reserve(vecDimSize);
- for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
- interleaveMaskValues.push_back(i);
- interleaveMaskValues.push_back(i + (vecDimSize / 2));
- }
-
+ // 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
>From 4a2fa749f3cf746baeffa81b3a15554b6bd88448 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 00:55:54 +0000
Subject: [PATCH 3/3] Remove unused `vecDimSize` variable
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 41a778b8496ef0..fc11ae63e718a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -857,7 +857,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
"Expected i4 type");
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- int64_t vecDimSize = srcVecType.getShape().back();
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
More information about the Mlir-commits
mailing list