[Mlir-commits] [mlir] [mlir][Vector] Add support for sub-byte transpose emulation (PR #80110)

Han-Chung Wang llvmlistbot at llvm.org
Wed Jan 31 00:09:31 PST 2024


================
@@ -1052,6 +1052,52 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite a sub-byte vector transpose into a sequence of instructions that
+/// perform the transpose on wider (byte) element types.
+/// For example:
+///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+///
+///   is rewritten as:
+///
+///   %0 = arith.extsi %a : vector<8x16xi4> to vector<8x16xi8>
+///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+///   %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
+///
+struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Precondition: sub-byte integer transpose.
+    constexpr unsigned minNativeBitwidth = 8;
+    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
+    if (srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth)
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "not a sub-byte transpose");
+
+    // Perform the rewrite.
+    Location loc = transposeOp.getLoc();
+    // Signed/unsigned interpretation shouldn't matter here as we are just
+    // transposing the elements and truncating them back to the original size.
+    // TODO: Use unsigned extension (more efficient) when emulation or backend
+    // support is available.
+    auto srcNativeVecType =
+        srcSubByteVecType.cloneWith(std::nullopt, rewriter.getI8Type());
----------------
hanhanW wrote:

I think `rewriter.getIntegerType(minNativeBitwidth)` is better for consistency. 

https://github.com/llvm/llvm-project/pull/80110


More information about the Mlir-commits mailing list