[Mlir-commits] [mlir] [mlir][vector] Fix patterns for dropping leading unit dims from masks (PR #73525)
llvmlistbot at llvm.org
Mon Nov 27 07:02:54 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
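Previously, the patterns for dropping leading unit dims from masked `vector.transfer_read`/`vector.transfer_write` ops only worked when the permutation map was a minor identity. This change infers the new mask type from the new transfer map after the leading unit dims are dropped.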
---
Full diff: https://github.com/llvm/llvm-project/pull/73525.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+6)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-6)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+23-9)
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+40)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 9ab20e20d975429..e9dab8f1e44ae68 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -160,6 +160,12 @@ getAsConstantIndexOps(ArrayRef<Value> values);
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer op operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
+/// *before* any permutation or broadcasting.
+VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
+
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..c462b23e1133fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
p << " : " << getShapedType() << ", " << getVectorType();
}
-/// Infers the mask type for a transfer op given its vector type and
-/// permutation map. The mask in a transfer op operation applies to the
-/// tensor/buffer part of it and its type should match the vector shape
-/// *before* any permutation or broadcasting.
-static VectorType inferTransferOpMaskType(VectorType vecType,
- AffineMap permMap) {
+VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
+ AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
assert(invPermMap && "Inversed permutation map couldn't be computed");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 75f32b23e57b0d6..3c85606da5ec522 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
}
};
+static Value processTransferMask(OpBuilder &b, Location loc, Value mask,
+ VectorType newType, AffineMap newMap,
+ VectorType oldMaskType) {
+ // Infer the type of the new mask from the new map.
+ auto newMaskType = inferTransferOpMaskType(newType, newMap);
+
+ // If the new mask is broadcastable to the old result type, we can safely
+ // use a `vector.extract` to get the new mask. Otherwise the best we can
+ // do is shape cast.
+ if (mlir::vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+ BroadcastableToResult::Success) {
+ int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+ return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
+ }
+ return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+}
+
// Turns vector.transfer_read on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_read on vector without leading
// 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim
Value mask = Value();
if (read.getMask()) {
- // The mask shape must always match the shape of the written vector, so we
- // can safely use the same extraction indices.
- int64_t dropDim = oldType.getRank() - newType.getRank();
- mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
- splatZero(dropDim));
+ VectorType maskType = read.getMaskType();
+ mask = processTransferMask(rewriter, read.getLoc(), read.getMask(),
+ newType, newMap, maskType);
}
auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
write.getLoc(), write.getVector(), splatZero(dropDim));
if (write.getMask()) {
- // The mask shape must always match the shape of the written vector, so we
- // can safely use the same extraction indices.
- auto newMask = rewriter.create<vector::ExtractOp>(
- write.getLoc(), write.getMask(), splatZero(dropDim));
+ VectorType maskType = write.getMaskType();
+ Value newMask = processTransferMask(
+ rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), newMask, inBoundsAttr);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 5de30206927db2f..71dffca8f14da59 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
return %0: vector<1x1xf16>
}
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
+func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+ %f0 = arith.constant 0. : f16
+ // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+ // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
+ // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
+ // CHECK: return %[[CAST]]
+ return %0: vector<1x1x4xf16>
+}
+
+// -----
+
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
return
}
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
+func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+ // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
+ // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
+
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
func.func @cast_away_elementwise_leading_one_dims(
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/73525
More information about the Mlir-commits mailing list