[Mlir-commits] [mlir] eb7e299 - Reland "[mlir][Vector] Re-define masking semantics in vector.transfer ops""

Diego Caballero llvmlistbot at llvm.org
Mon Nov 28 19:42:09 PST 2022


Author: Diego Caballero
Date: 2022-11-29T03:36:54Z
New Revision: eb7e2998d135ac30198cb7e6709db3bdc155f2d0

URL: https://github.com/llvm/llvm-project/commit/eb7e2998d135ac30198cb7e6709db3bdc155f2d0
DIFF: https://github.com/llvm/llvm-project/commit/eb7e2998d135ac30198cb7e6709db3bdc155f2d0.diff

LOG: Reland "[mlir][Vector] Re-define masking semantics in vector.transfer ops""

This relands commit 847b5f82a4a34218bf16d6f83f1b7c32df3117ba.

Differential Revision: https://reviews.llvm.org/D138079

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/Interfaces/VectorInterfaces.h
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
    mlir/lib/Interfaces/VectorInterfaces.cpp
    mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index edaf78b38b274..01582570ad2fe 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1207,9 +1207,11 @@ def Vector_TransferReadOp :
     provided to specify a fallback value in the case of out-of-bounds accesses
     and/or masking.
 
-    An optional SSA value `mask` of the same shape as the vector type may be
-    specified to mask out elements. Such elements will be replaces with
-    `padding`. Elements whose corresponding mask element is `0` are masked out.
+    An optional SSA value `mask` may be specified to mask out elements read from
+    the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
+    matches how elements are read from the MemRef/Tensor, *before* any
+    permutation or broadcasting. Elements whose corresponding mask element is
+    `0` are masked out and replaced with `padding`.
 
     An optional boolean array attribute `in_bounds` specifies for every vector
     dimension if the transfer is guaranteed to be within the source bounds.
@@ -1419,6 +1421,12 @@ def Vector_TransferWriteOp :
 
     The size of the slice is specified by the size of the vector.
 
+    An optional SSA value `mask` may be specified to mask out elements written
+    to the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
+    matches how elements are written into the MemRef/Tensor, *after* applying
+    any permutation. Elements whose corresponding mask element is `0` are
+    masked out.
+
     An optional SSA value `mask` of the same shape as the vector type may be
     specified to mask out elements. Elements whose corresponding mask element
     is `0` are masked out.

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.h b/mlir/include/mlir/Interfaces/VectorInterfaces.h
index 620ada282144d..7ae4ee35de337 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.h
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.h
@@ -17,18 +17,6 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-namespace vector {
-namespace detail {
-
-/// Given the vector type and the permutation map of a vector transfer op,
-/// compute the expected mask type.
-VectorType transferMaskType(VectorType vecType, AffineMap map);
-
-} // namespace detail
-} // namespace vector
-} // namespace mlir
-
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/VectorInterfaces.h.inc"
 

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 832528bfa78ad..0da3309d165f4 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -169,16 +169,25 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
         }]
     >,
     InterfaceMethod<
-      /*desc=*/"Return the mask type if the op has a mask.",
+      /*desc=*/"Return the mask operand if the op has a mask. Otherwise, "
+               "return an empty value.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getMask",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getMask();
+        }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the mask type if the op has a mask. Otherwise, return "
+               "an empty VectorType.",
       /*retTy=*/"::mlir::VectorType",
       /*methodName=*/"getMaskType",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return $_op.getMask()
-            ? ::mlir::vector::detail::transferMaskType(
-                $_op.getVectorType(), $_op.getPermutationMap())
-            : ::mlir::VectorType();
+        return $_op.getMask() ? $_op.getMask().getType() : ::mlir::VectorType();
       }]
     >,
     InterfaceMethod<

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8802930b46767..c4af2d8a19441 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3128,7 +3128,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
 static LogicalResult
 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                  VectorType vectorType, VectorType maskType,
-                 AffineMap permutationMap, ArrayAttr inBounds) {
+                 VectorType inferredMaskType, AffineMap permutationMap,
+                 ArrayAttr inBounds) {
   if (op->hasAttr("masked")) {
     return op->emitOpError("masked attribute has been removed. "
                            "Use in_bounds instead.");
@@ -3181,13 +3182,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
     if (permutationMap.getNumResults() != vectorType.getRank())
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
-
-    VectorType expectedMaskType =
-        vector::detail::transferMaskType(vectorType, permutationMap);
-    if (maskType && expectedMaskType != maskType)
-      return op->emitOpError("expects mask type consistent with permutation "
-                             "map: ")
-             << maskType;
   }
 
   if (permutationMap.getNumSymbols() != 0)
@@ -3197,6 +3191,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
     return op->emitOpError("requires a permutation_map with input dims of the "
                            "same rank as the source type");
 
+  if (maskType && maskType != inferredMaskType)
+    return op->emitOpError("inferred mask type (")
+           << inferredMaskType << ") and mask operand type (" << maskType
+           << ") don't match";
+
   if (inBounds) {
     if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
       return op->emitOpError("expects the optional in_bounds attr of same rank "
@@ -3239,6 +3238,19 @@ void TransferReadOp::print(OpAsmPrinter &p) {
   p << " : " << getShapedType() << ", " << getVectorType();
 }
 
+/// Infers the mask type for a transfer read given its vector type and
+/// permutation map. The mask in a transfer read operation applies to the
+/// tensor/buffer reading part of it and its type should match the shape read
+/// *before* any permutation or broadcasting.
+static VectorType inferTransferReadMaskType(VectorType vecType,
+                                            AffineMap permMap) {
+  auto i1Type = IntegerType::get(permMap.getContext(), 1);
+  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
+  assert(invPermMap && "Inversed permutation map couldn't be computed");
+  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
+  return VectorType::get(maskShape, i1Type);
+}
+
 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   auto &builder = parser.getBuilder();
   SMLoc typesLoc;
@@ -3269,13 +3281,14 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   VectorType vectorType = types[1].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
-  auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
-  Attribute mapAttr = result.attributes.get(permutationAttrName);
-  if (!mapAttr) {
-    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    // Update `mapAttr` that is used later to determine mask type.
-    mapAttr = AffineMapAttr::get(permMap);
-    result.attributes.set(permutationAttrName, mapAttr);
+  auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
+  Attribute permMapAttr = result.attributes.get(permMapAttrName);
+  AffineMap permMap;
+  if (!permMapAttr) {
+    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
+  } else {
+    permMap = permMapAttr.cast<AffineMapAttr>().getValue();
   }
   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -3286,10 +3299,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
     if (shapedType.getElementType().dyn_cast<VectorType>())
       return parser.emitError(
           maskInfo.location, "does not support masks with vector element type");
-    auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
     // Instead of adding the mask type as an op type, compute it based on the
     // vector type and the permutation map (to keep the type signature small).
-    auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
+    auto maskType = inferTransferReadMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3307,13 +3319,17 @@ LogicalResult TransferReadOp::verify() {
   VectorType maskType = getMaskType();
   auto paddingType = getPadding().getType();
   auto permutationMap = getPermutationMap();
+  VectorType inferredMaskType =
+      maskType ? inferTransferReadMaskType(vectorType, permutationMap)
+               : VectorType();
   auto sourceElementType = shapedType.getElementType();
 
   if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
     return emitOpError("requires ") << shapedType.getRank() << " indices";
 
   if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
-                              shapedType, vectorType, maskType, permutationMap,
+                              shapedType, vectorType, maskType,
+                              inferredMaskType, permutationMap,
                               getInBounds() ? *getInBounds() : ArrayAttr())))
     return failure();
 
@@ -3677,6 +3693,18 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, vector, dest, indices, permutationMap, inBounds);
 }
 
+/// Infers the mask type for a transfer write given its vector type and
+/// permutation map. The mask in a transfer write operation applies to the
+/// tensor/buffer writing part of it and its type should match the shape written
+/// *after* any permutation.
+static VectorType inferTransferWriteMaskType(VectorType vecType,
+                                             AffineMap permMap) {
+  auto i1Type = IntegerType::get(permMap.getContext(), 1);
+  SmallVector<int64_t, 8> maskShape =
+      compressUnusedDims(permMap).compose(vecType.getShape());
+  return VectorType::get(maskShape, i1Type);
+}
+
 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
   auto &builder = parser.getBuilder();
@@ -3704,11 +3732,14 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   ShapedType shapedType = types[1].dyn_cast<ShapedType>();
   if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
-  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
-  auto attr = result.attributes.get(permutationAttrName);
-  if (!attr) {
-    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
+  auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
+  auto permMapAttr = result.attributes.get(permMapAttrName);
+  AffineMap permMap;
+  if (!permMapAttr) {
+    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
+  } else {
+    permMap = permMapAttr.cast<AffineMapAttr>().getValue();
   }
   if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
       parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
@@ -3718,7 +3749,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
     if (shapedType.getElementType().dyn_cast<VectorType>())
       return parser.emitError(
           maskInfo.location, "does not support masks with vector element type");
-    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    auto maskType = inferTransferWriteMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3744,6 +3775,9 @@ LogicalResult TransferWriteOp::verify() {
   VectorType vectorType = getVectorType();
   VectorType maskType = getMaskType();
   auto permutationMap = getPermutationMap();
+  VectorType inferredMaskType =
+      maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
+               : VectorType();
 
   if (llvm::size(getIndices()) != shapedType.getRank())
     return emitOpError("requires ") << shapedType.getRank() << " indices";
@@ -3754,7 +3788,8 @@ LogicalResult TransferWriteOp::verify() {
     return emitOpError("should not have broadcast dimensions");
 
   if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
-                              shapedType, vectorType, maskType, permutationMap,
+                              shapedType, vectorType, maskType,
+                              inferredMaskType, permutationMap,
                               getInBounds() ? *getInBounds() : ArrayAttr())))
     return failure();
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index de72c6d2cfeac..d7ec87e95f4ef 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -83,26 +83,6 @@ struct TransferReadPermutationLowering
       newVectorShape[pos.value()] = originalShape[pos.index()];
     }
 
-    // Transpose mask operand.
-    Value newMask;
-    if (op.getMask()) {
-      // Remove unused dims from the permutation map. E.g.:
-      // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
-      // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
-      auto comp = compressUnusedDims(map);
-      // Get positions of remaining result dims.
-      // E.g.:  (d0, d1, d2) -> (d2, 0, d1, 0 d0)
-      // maskTransposeIndices = [ 2,     1,    0]
-      SmallVector<int64_t> maskTransposeIndices;
-      for (unsigned i = 0; i < comp.getNumResults(); ++i) {
-        if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
-          maskTransposeIndices.push_back(expr.getPosition());
-      }
-
-      newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
-                                                     maskTransposeIndices);
-    }
-
     // Transpose in_bounds attribute.
     ArrayAttr newInBoundsAttr =
         op.getInBounds() ? transposeInBoundsAttr(
@@ -114,7 +94,8 @@ struct TransferReadPermutationLowering
         VectorType::get(newVectorShape, op.getVectorType().getElementType());
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
-        AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
+        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
+        newInBoundsAttr);
 
     // Transpose result of transfer_read.
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
@@ -168,11 +149,6 @@ struct TransferWritePermutationLowering
                       return expr.dyn_cast<AffineDimExpr>().getPosition();
                     });
 
-    // Transpose mask operand.
-    Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
-                                       op.getLoc(), op.getMask(), indices)
-                                 : Value();
-
     // Transpose in_bounds attribute.
     ArrayAttr newInBoundsAttr =
         op.getInBounds() ? transposeInBoundsAttr(
@@ -186,7 +162,7 @@ struct TransferWritePermutationLowering
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        newMask, newInBoundsAttr);
+        op.getMask(), newInBoundsAttr);
 
     return success();
   }

diff  --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
index 3115a5d983c45..c16fad25b642d 100644
--- a/mlir/lib/Interfaces/VectorInterfaces.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -10,19 +10,6 @@
 
 using namespace mlir;
 
-VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
-                                                  AffineMap map) {
-  auto i1Type = IntegerType::get(map.getContext(), 1);
-  SmallVector<int64_t, 8> shape;
-  for (int64_t i = 0; i < vecType.getRank(); ++i) {
-    // Only result dims have a corresponding dim in the mask.
-    if (map.getResult(i).template isa<AffineDimExpr>()) {
-      shape.push_back(vecType.getDimSize(i));
-    }
-  }
-  return VectorType::get(shape, i1Type);
-}
-
 //===----------------------------------------------------------------------===//
 // VectorUnroll Interfaces
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir
index 118df27b64e6a..812c8d95f371c 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir
@@ -5,10 +5,9 @@
 
 // CHECK-LABEL: func @transfer_read_2d_mask_transposed(
 //   CHECK-DAG:   %[[PADDING:.*]] = arith.constant dense<-4.200000e+01> : vector<9xf32>
-//   CHECK-DAG:   %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<9x4xi1>
+//   CHECK-DAG:   %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<4x9xi1>
 //       CHECK:   %[[MASK_MEM:.*]] = memref.alloca() : memref<vector<4x9xi1>>
-//       CHECK:   %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1>
-//       CHECK:   memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
+//       CHECK:   memref.store %[[MASK]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
 //       CHECK:   %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
 //       CHECK:   scf.for {{.*}} {
 //       CHECK:     scf.if {{.*}} {
@@ -25,11 +24,10 @@
 func.func @transfer_read_2d_mask_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) -> (vector<9x4xf32>) {
   %fm42 = arith.constant -42.0: f32
-  %mask = arith.constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
-                          [1, 1, 1, 1], [0, 1, 1, 0],
-                          [1, 1, 1, 1], [1, 1, 1, 1],
-                          [1, 1, 1, 1], [0, 0, 0, 0],
-                          [1, 1, 1, 1]]> : vector<9x4xi1>
+  %mask = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
+                                [0, 0, 1, 1, 1, 1, 1, 0, 1],
+                                [1, 1, 1, 1, 1, 1, 1, 0, 1],
+                                [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
   %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
     memref<?x?xf32>, vector<9x4xf32>

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 88b1abbd209a7..06d0903a5284c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -49,7 +49,7 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %v0 = vector.splat %c0 : vector<4x3xi32>
   %vi0 = vector.splat %i0 : vector<4x3xindex>
   %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
-  %m2 = vector.splat %i1 : vector<5x4xi1>
+  %m2 = vector.splat %i1 : vector<4x5xi1>
   //
   // CHECK: vector.transfer_read
   %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index b22a1e4a829d1..0da64debdb6af 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -282,19 +282,19 @@ func.func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x
   %c0 = arith.constant 0 : index
 
 // CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
-  %mask0 = vector.splat %m : vector<7x14xi1>
+  %mask0 = vector.splat %m : vector<14x7xi1>
   %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
 // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
 
 // CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1>
-  %mask1 = vector.splat %m : vector<14x16xi1>
+  %mask1 = vector.splat %m : vector<16x14xi1>
   %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
 // CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
-  %mask2 = vector.splat %m : vector<7x14xi1>
+  %mask2 = vector.splat %m : vector<14x7xi1>
   %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
 // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
@@ -338,7 +338,7 @@ func.func @transfer_write_permutations(
   %c0 = arith.constant 0 : index
 
   // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1>
-  %mask0 = vector.splat %m : vector<7x14x8x16xi1>
+  %mask0 = vector.splat %m : vector<8x14x16x7xi1>
   %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
   // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
   // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index fd087e4200c08..2f38ef674dd3e 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -40,11 +40,10 @@ func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: in
 func.func @transfer_read_2d_mask_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = arith.constant -42.0: f32
-  %mask = arith.constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
-                          [1, 1, 1, 1], [0, 1, 1, 0],
-                          [1, 1, 1, 1], [1, 1, 1, 1],
-                          [1, 1, 1, 1], [0, 0, 0, 0],
-                          [1, 1, 1, 1]]> : vector<9x4xi1>
+  %mask = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 1, 1, 1, 1, 0, 1],
+                          [1, 1, 1, 1, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
   %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
     memref<?x?xf32>, vector<9x4xf32>


        


More information about the Mlir-commits mailing list