[Mlir-commits] [mlir] 6d68ef4 - [mlir][Vector] Canonicalize create_mask(transpose)

Diego Caballero llvmlistbot at llvm.org
Thu Mar 16 07:37:46 PDT 2023


Author: Diego Caballero
Date: 2023-03-16T14:35:52Z
New Revision: 6d68ef4e38fb3e701a8aacede568f57b73ebcfbf

URL: https://github.com/llvm/llvm-project/commit/6d68ef4e38fb3e701a8aacede568f57b73ebcfbf
DIFF: https://github.com/llvm/llvm-project/commit/6d68ef4e38fb3e701a8aacede568f57b73ebcfbf.diff

LOG: [mlir][Vector] Canonicalize create_mask(transpose)

When applying vector masking, we may create a mask and then transpose it.
Transpositions are extremely expensive, so this patch introduces a new
canonicalization pattern that removes the transpose operation and creates a
new transposed mask.

Differential Revision: https://reviews.llvm.org/D146193

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0180c7192fb2c..9796693b4b6cd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5265,13 +5265,38 @@ class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Rewrites `vector.transpose(vector.create_mask(...))` as one create_mask
+/// whose bound operands are reordered by the transpose permutation.
+class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value source = transposeOp.getVector();
+    auto maskOp = source.getDefiningOp<vector::CreateMaskOp>();
+    if (!maskOp)
+      return failure();
+
+    // Result dim i of the transpose is source dim perm[i], so bound i of the
+    // new mask is bound perm[i] of the original one.
+    SmallVector<int64_t> perm;
+    transposeOp.getTransp(perm);
+    auto oldBounds = maskOp.getOperands();
+    SmallVector<Value> bounds(oldBounds.begin(), oldBounds.end());
+    applyPermutationToVector(bounds, perm);
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        transposeOp, transposeOp.getResultVectorType(), bounds);
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results
-      .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
-          context);
+  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast>(context);
+  results.add<TransposeFolder, FoldTransposeSplat>(context);
 }
 
 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 053e3620cab2e..f82540c28f829 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -54,6 +54,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
 
 // -----
 
+// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
+func.func @create_mask_transpose_to_transposed_create_mask(
+  %dim0: index, %dim1: index, %dim2: index) -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
+  // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
+  // CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<2x3x4xi1>
+  %mask_t = vector.transpose %mask, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
+  return %mask, %mask_t : vector<2x3x4xi1>, vector<4x2x3xi1>
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0


        


More information about the Mlir-commits mailing list