[Mlir-commits] [mlir] 1d9dcca - [mlir][vector] Add canonicalization pattern for shape_cast(create_mask)

Andrzej Warzynski llvmlistbot at llvm.org
Thu Aug 24 00:15:57 PDT 2023


Author: Andrzej Warzynski
Date: 2023-08-24T07:14:42Z
New Revision: 1d9dcca9e3cd36cbaf0759a4b084c808d1aee044

URL: https://github.com/llvm/llvm-project/commit/1d9dcca9e3cd36cbaf0759a4b084c808d1aee044
DIFF: https://github.com/llvm/llvm-project/commit/1d9dcca9e3cd36cbaf0759a4b084c808d1aee044.diff

LOG: [mlir][vector] Add canonicalization pattern for shape_cast(create_mask)

This is primarily to avoid trailing unit dims:
```
%1 = vector.create_mask %c1, %dim_0, %c1, %c1 : vector<1x4x1x1xi1>
%2 = vector.shape_cast %1 : vector<1x4x1x1xi1> to vector<1x4xi1>
```
becomes:
```
%1 = vector.create_mask %c1, %dim_0 : vector<1x4xi1>
```

Differential Revision: https://reviews.llvm.org/D158111

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fbf81cf2b79e70..4e9364611b257d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4779,6 +4779,114 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Helper function that computes a new vector type based on the input vector
+/// type by removing the trailing (non-scalable) unit dims:
+///
+///   vector<4x1x1xi1>   --> vector<4xi1>
+///   vector<4x[1]x1xi1> --> vector<4x[1]xi1>
+static VectorType trimTrailingOneDims(VectorType oldType) {
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  ArrayRef<int64_t> newShape = oldShape;
+
+  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
+  ArrayRef<bool> newScalableDims = oldScalableDims;
+
+  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
+    newShape = newShape.drop_back(1);
+    newScalableDims = newScalableDims.drop_back(1);
+  }
+
+  // Make sure we have at least 1 dimension.
+  // TODO: Add support for 0-D vectors.
+  if (newShape.empty()) {
+    newShape = oldShape.take_back();
+    newScalableDims = oldScalableDims.take_back();
+  }
+
+  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+}
+
+/// Folds qualifying shape_cast(create_mask) / shape_cast(constant_mask) into
+/// a new create_mask / constant_mask.
+///
+/// Looks at `vector.shape_cast` Ops that simply "drop" trailing unit dims. If
+/// the source mask uses 1 as the mask size for every dropped dim (e.g. `%c1`
+/// below), then it is safe to fold the shape_cast into the mask op.
+///
+/// BEFORE:
+///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
+///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
+/// AFTER:
+///    %2 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
+class ShapeCastCreateMaskFolderTrailingOneDim final
+    : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
+                                PatternRewriter &rewriter) const override {
+    Value shapeOpSrc = shapeOp->getOperand(0);
+    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
+    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
+    if (!createMaskOp && !constantMaskOp)
+      return failure();
+
+    VectorType shapeOpResTy = shapeOp.getResultVectorType();
+    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
+
+    // Only casts that drop exactly the trailing unit dims qualify. Since
+    // trimTrailingOneDims never adds dims and keeps at least one, this also
+    // guarantees 1 <= resRank <= srcRank, so numDimsToDrop is non-negative.
+    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
+    if (newVecType != shapeOpResTy)
+      return failure();
+
+    int64_t numDimsToDrop = shapeOpSrcTy.getRank() - shapeOpResTy.getRank();
+
+    // No unit dims to drop
+    if (numDimsToDrop == 0)
+      return failure();
+
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+
+      // Every dropped dim must use the constant 1 as its mask size, otherwise
+      // the fold is not known to be safe.
+      auto isConstantOne = [](Value operand) {
+        auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
+        return constant && constant.value() == 1;
+      };
+      if (!llvm::all_of(maskOperands.take_back(numDimsToDrop), isConstantOne))
+        return failure();
+      SmallVector<Value> newMaskOperands =
+          maskOperands.drop_back(numDimsToDrop);
+
+      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
+                                                        newMaskOperands);
+      return success();
+    }
+
+    // Otherwise the source must be a constant_mask.
+    assert(constantMaskOp && "expected vector.constant_mask as the source");
+    ArrayRef<Attribute> maskDimSizes =
+        constantMaskOp.getMaskDimSizes().getValue();
+
+    // Same check as for create_mask, but on the static mask dim sizes.
+    auto isOne = [](Attribute dimSize) {
+      return cast<IntegerAttr>(dimSize).getValue() == 1;
+    };
+    if (!llvm::all_of(maskDimSizes.take_back(numDimsToDrop), isOne))
+      return failure();
+
+    ArrayAttr newMaskDimSizesAttr =
+        rewriter.getArrayAttr(maskDimSizes.drop_back(numDimsToDrop));
+
+    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
+                                                        newMaskDimSizesAttr);
+    return success();
+  }
+};
+
 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
 /// This only applies when the shape of the broadcast source
 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
@@ -4831,7 +4939,9 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
+  results.add<ShapeCastConstantFolder>(context);
+  results.add<ShapeCastCreateMaskFolderTrailingOneDim>(context);
+  results.add<ShapeCastBroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 17a9e381b61708..8b709eb643d918 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2219,3 +2219,66 @@ func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
   %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_shape_cast_with_mask(
+// CHECK-SAME:      %[[ARG:.*]]: tensor<1x?xf32>) -> vector<1x4xi1> {
+func.func @fold_shape_cast_with_mask(%arg0: tensor<1x?xf32>) -> vector<1x4xi1> {
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<1x?xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM]] : vector<1x4xi1>
+// CHECK-NOT:       vector.shape_cast
+// CHECK:           return %[[MASK]] : vector<1x4xi1>
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %mask = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x4x1x1xi1>
+  %cast = vector.shape_cast %mask : vector<1x4x1x1xi1> to vector<1x4xi1>
+  return %cast : vector<1x4xi1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_shape_cast_with_mask_scalable(
+// CHECK-SAME:      %[[ARG:.*]]: tensor<1x?xf32>) -> vector<1x[4]xi1> {
+func.func @fold_shape_cast_with_mask_scalable(%arg0: tensor<1x?xf32>) -> vector<1x[4]xi1> {
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<1x?xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM]] : vector<1x[4]xi1>
+// CHECK-NOT:       vector.shape_cast
+// CHECK:           return %[[MASK]] : vector<1x[4]xi1>
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %mask = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
+  %cast = vector.shape_cast %mask : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
+  return %cast : vector<1x[4]xi1>
+}
+
+// -----
+
+// Check that a scalable unit dim (i.e. [1]) is kept; only the fixed 1 is dropped
+// CHECK-LABEL:   func.func @fold_shape_cast_with_mask_scalable_one(
+// CHECK-SAME:      %[[ARG:.*]]: tensor<1x?xf32>) -> vector<1x[1]xi1> {
+func.func @fold_shape_cast_with_mask_scalable_one(%arg0: tensor<1x?xf32>) -> vector<1x[1]xi1> {
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<1x?xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM]] : vector<1x[1]xi1>
+// CHECK:           return %[[MASK]] : vector<1x[1]xi1>
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %mask = vector.create_mask %c1, %dim, %c1 : vector<1x[1]x1xi1>
+  %cast = vector.shape_cast %mask : vector<1x[1]x1xi1> to vector<1x[1]xi1>
+  return %cast : vector<1x[1]xi1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1> {
+func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1> {
+// CHECK:           %[[MASK:.*]] = vector.constant_mask [1] : vector<4xi1>
+// CHECK-NOT:       vector.shape_cast
+// CHECK:           return %[[MASK]] : vector<4xi1>
+  %mask = vector.constant_mask [1, 1, 1] : vector<4x1x1xi1>
+  %cast = vector.shape_cast %mask : vector<4x1x1xi1> to vector<4xi1>
+  return %cast : vector<4xi1>
+}


        


More information about the Mlir-commits mailing list