[Mlir-commits] [mlir] 6e8f7d5 - [mlir][vector] Fix patterns for dropping leading unit dims from masks (#73525)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 09:35:37 PST 2023


Author: Quinn Dawkins
Date: 2023-11-27T12:35:32-05:00
New Revision: 6e8f7d5966f51f3c3ddf76bb00eb60f4fb9ceebe

URL: https://github.com/llvm/llvm-project/commit/6e8f7d5966f51f3c3ddf76bb00eb60f4fb9ceebe
DIFF: https://github.com/llvm/llvm-project/commit/6e8f7d5966f51f3c3ddf76bb00eb60f4fb9ceebe.diff

LOG: [mlir][vector] Fix patterns for dropping leading unit dims from masks (#73525)

Previously the pattern only worked when the permutation map was a minor
identity. Infer the new mask type from the new transfer map after
dropping leading unit dims.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 9ab20e20d975429..59d585a77b1e29f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -160,6 +160,18 @@ getAsConstantIndexOps(ArrayRef<Value> values);
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
 
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask of a transfer op applies to the
+/// tensor/buffer side of the transfer, so its type should match the vector
+/// shape *before* any permutation or broadcasting. For example,
+///
+/// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)>
+///
+/// Has inferred mask type:
+///
+/// maskType = vector<2x1xi1>
+VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
+
 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
 /// as masked operation.
 void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..c462b23e1133fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
   p << " : " << getShapedType() << ", " << getVectorType();
 }
 
-/// Infers the mask type for a transfer op given its vector type and
-/// permutation map. The mask in a transfer op operation applies to the
-/// tensor/buffer part of it and its type should match the vector shape
-/// *before* any permutation or broadcasting.
-static VectorType inferTransferOpMaskType(VectorType vecType,
-                                          AffineMap permMap) {
+VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
+                                                 AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 75f32b23e57b0d6..84294e4552a607e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
   }
 };
 
+static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
+                                  VectorType newType, AffineMap newMap,
+                                  VectorType oldMaskType) {
+  // Infer the new mask type from the new vector type and the new map.
+  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
+
+  // If the new mask type is broadcastable to the old mask type, a
+  // `vector.extract` over the `dropDim` leading dims yields the new mask.
+  // Otherwise fall back to a `vector.shape_cast` to the inferred mask type.
+  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+      BroadcastableToResult::Success) {
+    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+    return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
+  }
+  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+}
+
 // Turns vector.transfer_read on vector with leading 1 dimensions into
 // vector.shape_cast followed by vector.transfer_read on vector without leading
 // 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim
 
     Value mask = Value();
     if (read.getMask()) {
-      // The mask shape must always match the shape of the written vector, so we
-      // can safely use the same extraction indices.
-      int64_t dropDim = oldType.getRank() - newType.getRank();
-      mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
-                                                splatZero(dropDim));
+      VectorType maskType = read.getMaskType();
+      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
+                                  newType, newMap, maskType);
     }
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
         write.getLoc(), write.getVector(), splatZero(dropDim));
 
     if (write.getMask()) {
-      // The mask shape must always match the shape of the written vector, so we
-      // can safely use the same extraction indices.
-      auto newMask = rewriter.create<vector::ExtractOp>(
-          write.getLoc(), write.getMask(), splatZero(dropDim));
+      VectorType maskType = write.getMaskType();
+      Value newMask = dropUnitDimsFromMask(
+          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
           write, newVector, write.getSource(), write.getIndices(),
           AffineMapAttr::get(newMap), newMask, inBoundsAttr);

diff  --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 5de30206927db2f..71dffca8f14da59 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
   return %0: vector<1x1xf16>
 }
 
+// -----
+
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
+func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  %f0 = arith.constant 0. : f16
+  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
+  // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
+                            permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x1x4xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
 func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
   return
 }
 
+// -----
+
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
+func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
+  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
+  // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
+
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
+                        permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
 func.func @cast_away_elementwise_leading_one_dims(
   %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,


        


More information about the Mlir-commits mailing list