[Mlir-commits] [mlir] 1cd434d - [mlir][Vector] Add canonicalization pattern for vector.transpose(vector.constant_mask)
Diego Caballero
llvmlistbot at llvm.org
Wed Mar 29 12:54:35 PDT 2023
Author: Diego Caballero
Date: 2023-03-29T19:53:29Z
New Revision: 1cd434d007584d84f2c2dd3d5bbfdc1f12b9a7b2
URL: https://github.com/llvm/llvm-project/commit/1cd434d007584d84f2c2dd3d5bbfdc1f12b9a7b2
DIFF: https://github.com/llvm/llvm-project/commit/1cd434d007584d84f2c2dd3d5bbfdc1f12b9a7b2.diff
LOG: [mlir][Vector] Add canonicalization pattern for vector.transpose(vector.constant_mask)
We already folded vector.transpose(vector.create_mask) into a single
vector.create_mask with permuted operands. This patch adds the analogous fold
for vector.constant_mask, permuting its mask dim sizes instead.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D147099
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index abd8962c9227..8ee59659c3a4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5269,23 +5269,37 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(TransposeOp transpOp,
PatternRewriter &rewriter) const override {
- auto createMaskOp =
- transposeOp.getVector().getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp)
+ Value transposeSrc = transpOp.getVector();
+ auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
+ auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
+ if (!createMaskOp && !constantMaskOp)
return failure();
- // Get the transpose permutation and apply it to the vector.create_mask
- // operands.
- auto maskOperands = createMaskOp.getOperands();
+ // Get the transpose permutation; it is applied below either to the
+ // vector.create_mask operands or to the vector.constant_mask dim sizes.
SmallVector<int64_t> permutation;
- transposeOp.getTransp(permutation);
- SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
- applyPermutationToVector(newOperands, permutation);
+ transpOp.getTransp(permutation);
+ // vector.create_mask case: permute the per-dimension size operands.
+ if (createMaskOp) {
+ auto maskOperands = createMaskOp.getOperands();
+ SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
+ applyPermutationToVector(newOperands, permutation);
+
+ rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+ transpOp, transpOp.getResultVectorType(), newOperands);
+ return success();
+ }
+
+ // vector.constant_mask case: permute the mask_dim_sizes attribute entries.
+ auto maskDimSizes = constantMaskOp.getMaskDimSizes();
+ SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
+ applyPermutationToVector(newMaskDimSizes, permutation);
- rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
- transposeOp, transposeOp.getResultVectorType(), newOperands);
+ rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+ transpOp, transpOp.getResultVectorType(),
+ ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f82540c28f82..88c91ff46a8b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -58,8 +58,9 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
func.func @create_mask_transpose_to_transposed_create_mask(
%dim0: index, %dim1: index, %dim2: index) -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
- // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
- // CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+ // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
+ // CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+ // CHECK-NOT: vector.transpose
%0 = vector.create_mask %dim0, %dim1, %dim2 : vector<2x3x4xi1>
%1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
@@ -67,6 +68,18 @@ func.func @create_mask_transpose_to_transposed_create_mask(
// -----
+// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
+func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
+ // CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
+ // CHECK: vector.constant_mask [3, 1, 2] : vector<4x2x3xi1>
+ // CHECK-NOT: vector.transpose
+ %0 = vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1> // result dim i takes source dim perm[i]: sizes [1, 2, 3] become [3, 1, 2]
+ return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
+}
+
+// -----
+
func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%0 = vector.constant_mask [2, 2] : vector<4x3xi1>
%1 = vector.extract_strided_slice %0
More information about the Mlir-commits mailing list