[Mlir-commits] [mlir] [mlir][vector] VectorEmulateNarrowType uses deinterleave (PR #94946)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 10 02:35:32 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Mubashar Ahmad (mub-at-arm)
<details>
<summary>Changes</summary>
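VectorEmulateNarrowType now uses the
vector.deinterleave operation when lowering
arith.trunci to sub-byte (e.g. i4) types,
replacing the previous pair of vector.shuffle
ops. This also lifts the former restriction of
that rewrite to 1-D vectors.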
---
Full diff: https://github.com/llvm/llvm-project/pull/94946.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+8-29)
- (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+12-5)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 59b6cb3ae667a..3f8f359de0d6c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -922,39 +922,23 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
"Expected i8 type");
// 1. De-interleave low and high i8 elements.
- int64_t vecDimSize = srcVecType.getShape().back();
- SmallVector<int64_t> deinterleaveLowMaskValues;
- SmallVector<int64_t> deinterleaveHighMaskValues;
- assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
- deinterleaveLowMaskValues.reserve(vecDimSize / 2);
- deinterleaveHighMaskValues.reserve(vecDimSize / 2);
- for (int i = 0, end = vecDimSize; i < end; i += 2) {
- deinterleaveLowMaskValues.push_back(i);
- deinterleaveHighMaskValues.push_back(i + 1);
- }
-
- auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
- loc, srcValue, srcValue,
- rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
- auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
- loc, srcValue, srcValue,
- rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+ auto deinterleaving = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
// 2. Zero out the upper side of each low i8 element.
constexpr int8_t i8LowBitMask = 0x0F;
Value zeroOutMask = rewriter.create<arith::ConstantOp>(
- loc,
- DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
- Value zeroOutLow =
- rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+ loc, DenseElementsAttr::get(deinterleaving.getResultVectorType(),
+ i8LowBitMask));
+ Value zeroOutLow = rewriter.create<arith::AndIOp>(
+ loc, deinterleaving.getRes1(), zeroOutMask);
// 3. Move high i4 values to upper side of the byte.
constexpr int8_t bitsToShift = 4;
- VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+ VectorType deinterI8VecType = deinterleaving.getResultVectorType();
auto shiftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
- Value shlHigh =
- rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+ Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaving.getRes2(),
+ shiftValues);
// 4. Merge high and low i4 values.
auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
@@ -1178,11 +1162,6 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
if (!srcVecType || !dstVecType)
return failure();
- // Only single dim vectors are supported until we have
- // `vector.deinterleave`.
- if (srcVecType.getRank() != 1)
- return failure();
-
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 614b2d4945348..12927d9547647 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -268,8 +268,7 @@ func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
-// CHECK: %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
-// CHECK: %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
@@ -283,8 +282,7 @@ func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
-// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
@@ -300,16 +298,25 @@ func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
// CHECK-NOT: vector.shli
// CHECK-NOT: vector.ori
// CHECK: arith.trunci
+// CHECK: vector.deinterleave
%0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
return %0 : vector<8x32xi4>
}
+// CHECK-LABEL: func.func @aligned_trunci_nd(
+func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
+ // CHECK: arith.trunci
+ // CHECK: vector.deinterleave
+ %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
+ return %0 : vector<3x8x32xi4>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK: %[[EXT:.*]] = vector.interleave
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK: %{{.*}}, %{{.*}} = vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
return %0 : vector<16x8xi4>
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/94946
More information about the Mlir-commits
mailing list