[Mlir-commits] [mlir] [mlir][Vector] Add support for sub-byte transpose emulation (PR #80110)

Diego Caballero llvmlistbot at llvm.org
Wed Jan 31 16:15:00 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/80110

>From e6edb34f1f0c22d58cae5395c6b3d59c689b6cfd Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 31 Jan 2024 07:40:01 +0000
Subject: [PATCH 1/3] [mlir][Vector] Add support for sub-byte transpose
 emulation

This PR adds a pattern that converts a sub-byte vector transpose into
a sequence of instructions that perform the transpose on i8 vector
elements. Although this rewrite may not achieve peak performance, it
should ensure correctness when dealing with sub-byte transposes.
---
 .../Vector/TransformOps/VectorTransformOps.td |  4 +-
 .../Vector/Transforms/VectorRewritePatterns.h |  4 ++
 .../TransformOps/VectorTransformOps.cpp       |  1 +
 .../Transforms/VectorEmulateNarrowType.cpp    | 51 +++++++++++++++++++
 .../Vector/vector-rewrite-narrow-types.mlir   | 21 ++++++++
 5 files changed, 79 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 3ac6f28dcb938..ce88360aa52e9 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -151,7 +151,7 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_masked_transfers",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Apply opt-in patterns that lower vector.mask operations surrounding 
+    Apply opt-in patterns that lower vector.mask operations surrounding
     side-effecting ops:
       - MaskedTransferReadOpPattern
       - MaskedTransferWriteOpPattern
@@ -376,7 +376,7 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
       - ReorderCastOpsOnBroadcast
       - ReorderElementwiseOpsOnTranspose
 
-    These patterns have the effect of rewriting a vector.multi_reduce into a 
+    These patterns have the effect of rewriting a vector.multi_reduce into a
     vector.contract.
   }];
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 49b74c0c466d2..f5941d32e683f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -371,6 +371,10 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
 
+/// Appends patterns that emulate sub-byte vector transposes on i8 elements.
+void populateVectorTransposeNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 37127ea70f1e5..19922c4295fe0 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -162,6 +162,7 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorNarrowTypeRewritePatterns(patterns);
+  populateVectorTransposeNarrowTypeRewritePatterns(patterns);
 }
 
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0110a8df89aee..193c9a6182b49 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1052,6 +1052,52 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite a sub-byte vector transpose into a sign extension to i8, a
+/// transpose of the i8 vector, and a truncation back to the original type.
+/// For example:
+///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+///
+///   is rewritten as:
+///
+///   %0 = arith.extsi %a : vector<8x16xi4> to vector<8x16xi8>
+///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+///   %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
+///
+struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Precondition: sub-byte integer transpose.
+    constexpr unsigned minNativeBitwidth = 8;
+    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
+    if (srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth)
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "not a sub-byte transpose");
+
+    // Perform the rewrite.
+    Location loc = transposeOp.getLoc();
+    // Signed/unsigned interpretation shouldn't matter here as we are just
+    // transposing the elements and truncating them back to the original size.
+    // TODO: Use unsigned extension (more efficient) when emulation or backend
+    // support is available.
+    auto srcNativeVecType =
+        srcSubByteVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
+                                                  transposeOp.getVector());
+    Value newTranspose = rewriter.create<vector::TransposeOp>(
+        loc, extOp, transposeOp.getPermutation());
+    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
+    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
+                                                 newTranspose);
+    return success();
+  }
+};
+
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1080,3 +1126,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
       patterns.getContext(), benefit.getBenefit() + 1);
 }
+
+void vector::populateVectorTransposeNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index c4fbb4c219b91..02063a81664b8 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -226,6 +226,26 @@ func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func.func @i4_transpose(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
+  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
+  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+  %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+  return %0 : vector<16x8xi4>
+}
+
+// CHECK-LABEL: func.func @i7_transpose(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
+  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
+  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+  %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
+  return %0 : vector<16x8xi7>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
@@ -237,3 +257,4 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+

>From 81fa7181e73265eb102df37733210f0bb483e2a9 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 31 Jan 2024 19:45:14 +0000
Subject: [PATCH 2/3] Feedback

---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp    | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 193c9a6182b49..6d6cd6a419de9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1084,8 +1084,8 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
     // transposing the elements and truncating them back to the original size.
     // TODO: Use unsigned extension (more efficient) when emulation or backend
     // support is available.
-    auto srcNativeVecType =
-        srcSubByteVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    auto srcNativeVecType = srcSubByteVecType.cloneWith(
+        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
     Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                   transposeOp.getVector());
     Value newTranspose = rewriter.create<vector::TransposeOp>(
@@ -1097,7 +1097,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
   }
 };
 
-
 } // namespace
 
 //===----------------------------------------------------------------------===//

>From a20d2dd9deee837043fbac080002e21fdcc5598b Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 1 Feb 2024 00:14:43 +0000
Subject: [PATCH 3/3] Fix

---
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6d6cd6a419de9..36fb66708407b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1074,9 +1074,11 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
     // Precondition: sub-byte integer transpose.
     constexpr unsigned minNativeBitwidth = 8;
     VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
-    if (srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth)
+    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
+        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
       return rewriter.notifyMatchFailure(transposeOp,
                                          "not a sub-byte transpose");
+    }
 
     // Perform the rewrite.
     Location loc = transposeOp.getLoc();



More information about the Mlir-commits mailing list