[Mlir-commits] [mlir] c908778 - [mlir] Fix masked vector transfer ops with broadcasts

Matthias Springer llvmlistbot at llvm.org
Wed May 12 19:38:33 PDT 2021


Author: Matthias Springer
Date: 2021-05-13T11:37:36+09:00
New Revision: c9087788f7e41285445729127dd07ff7f82e3fc0

URL: https://github.com/llvm/llvm-project/commit/c9087788f7e41285445729127dd07ff7f82e3fc0
DIFF: https://github.com/llvm/llvm-project/commit/c9087788f7e41285445729127dd07ff7f82e3fc0.diff

LOG: [mlir] Fix masked vector transfer ops with broadcasts

Broadcast dimensions of a vector transfer op have no corresponding dimension in the mask vector. E.g., a 2-D TransferReadOp where one dimension is a broadcast can have a 1-D `mask` operand.
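For illustration (taken from one of the new integration tests below): the following read produces a 2-D vector whose first result dimension is a broadcast, so the mask only covers the second dimension and is 1-D:

    %mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
    %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
        {permutation_map = affine_map<(d0, d1) -> (0, d1)>}
        : memref<?x?xf32>, vector<4x9xf32>

The broadcast result dimension (the constant 0 in the permutation map) has no counterpart in the mask, so the mask type is vector<9xi1>, not vector<4x9xi1>.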

This commit also adds transfer op integration tests for various combinations of broadcasts, masking, dimension transposes, etc.

Differential Revision: https://reviews.llvm.org/D101745

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/VectorInterfaces.h
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Interfaces/VectorInterfaces.cpp
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.h b/mlir/include/mlir/Interfaces/VectorInterfaces.h
index 7ae4ee35de33..620ada282144 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.h
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.h
@@ -17,6 +17,18 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 
+namespace mlir {
+namespace vector {
+namespace detail {
+
+/// Given the vector type and the permutation map of a vector transfer op,
+/// compute the expected mask type.
+VectorType transferMaskType(VectorType vecType, AffineMap map);
+
+} // namespace detail
+} // namespace vector
+} // namespace mlir
+
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/VectorInterfaces.h.inc"
 

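Note: the transferMaskType helper declared above derives the mask type from the vector type and the permutation map. Two mappings, both exercised by the integration tests further down:

    // vector<4x9xf32>, permutation_map (d0, d1) -> (0, d1)  =>  mask vector<9xi1>
    // vector<4x9xf32>, permutation_map (d0, d1) -> (d1, 0)  =>  mask vector<4xi1>

Only results that are AffineDimExprs contribute a mask dimension; constant (broadcast) results are skipped.
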
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 580c6985f5ad..49d35710c156 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -156,6 +156,19 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
         return $_op.vector().getType().template dyn_cast<VectorType>();
         }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the mask type if the op has a mask.",
+      /*retTy=*/"Optional<VectorType>",
+      /*methodName=*/"getMaskType",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.mask()
+            ? llvm::Optional<VectorType>(mlir::vector::detail::transferMaskType(
+                $_op.getVectorType(), $_op.permutation_map()))
+            : llvm::None;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{ Return the number of dimensions that participate in the
                   permutation map.}],

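Note: getMaskType wraps that computation at the interface level. A small illustration (op taken from the 2-D tests below):

    // %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
    //     {permutation_map = affine_map<(d0, d1) -> (0, d1)>}
    //     : memref<?x?xf32>, vector<4x9xf32>
    //
    // getMaskType() returns vector<9xi1> here; for an op without a mask it
    // returns llvm::None.
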
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 08aca49c7af4..06912943e069 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -79,13 +79,20 @@ static BufferAllocs allocBuffers(OpTy xferOp) {
 
   if (xferOp.mask()) {
     auto maskType = MemRefType::get({}, xferOp.mask().getType());
-    result.maskBuffer = memref_alloca(maskType).value;
-    memref_store(xferOp.mask(), result.maskBuffer);
+    auto maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), maskBuffer);
+    result.maskBuffer = memref_load(maskBuffer);
   }
 
   return result;
 }
 
+template <typename OpTy>
+static bool isOutermostDimBroadcast(OpTy xferOp) {
+  auto map = xferOp.permutation_map();
+  return map.getResult(0).template isa<AffineConstantExpr>();
+}
+
 /// Given a vector transfer op, calculate which dimension of the `source`
 /// memref should be unpacked in the next application of TransferOpConversion.
 /// A return value of None indicates a broadcast.
@@ -95,7 +102,7 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
     return expr.getPosition();
   }
-  assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+  assert(isOutermostDimBroadcast(xferOp) &&
          "Expected AffineDimExpr or AffineConstantExpr");
   return None;
 }
@@ -143,14 +150,17 @@ static void maybeYieldValue(
 }
 
 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
-/// is set to true. Does not return a Value if the transfer op is not 1D or
-/// if the transfer op does not have a mask.
+/// is set to true. Does not return a Value if the transfer op does not have a
+/// mask, if the transfer op's mask is not 1D or if the to-be-unpacked dim of
+/// the transfer op is a broadcast.
 template <typename OpTy>
 static Value maybeGenerateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
-  if (xferOp.getVectorType().getRank() != 1)
-    return Value();
   if (!xferOp.mask())
     return Value();
+  if (xferOp.getMaskType()->getRank() != 1)
+    return Value();
+  if (isOutermostDimBroadcast(xferOp))
+    return Value();
 
   auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
   return vector_extract_element(xferOp.mask(), ivI32).value;
@@ -488,8 +498,8 @@ struct PrepareTransferReadConversion
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
     if (xferOp.mask()) {
-      auto loadedMask = memref_load(buffers.maskBuffer);
-      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
+      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
+          buffers.maskBuffer);
     }
 
     memref_store(newXfer->getResult(0), buffers.dataBuffer);
@@ -541,9 +551,8 @@ struct PrepareTransferWriteConversion
     });
 
     if (xferOp.mask()) {
-      auto loadedMask = memref_load(buffers.maskBuffer);
       rewriter.updateRootInPlace(
-          xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
+          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
     }
 
     return success();
@@ -590,8 +599,18 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
       auto maskBuffer = getMaskBuffer(xferOp);
       auto maskBufferType =
           maskBuffer.getType().template dyn_cast<MemRefType>();
-      auto castedMaskType = unpackOneDim(maskBufferType);
-      castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+      if (isOutermostDimBroadcast(xferOp) ||
+          xferOp.getMaskType()->getRank() == 1) {
+        // Do not unpack a dimension of the mask, if:
+        // * To-be-unpacked transfer op dimension is a broadcast.
+        // * Mask is 1D, i.e., the mask cannot be further unpacked.
+        //   (That means that all remaining dimensions of the transfer op must
+        //   be broadcasts.)
+        castedMaskBuffer = maskBuffer;
+      } else {
+        auto castedMaskType = unpackOneDim(maskBufferType);
+        castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+      }
     }
 
     // Loop bounds and step.
@@ -616,13 +635,20 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
                 Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
 
             // If old transfer op has a mask: Set mask on new transfer op.
-            if (xferOp.mask()) {
+            // Special case: If the mask of the old transfer op is 1D and the
+            //               unpacked dim is not a broadcast, no mask is needed
+            //               on the new transfer op.
+            if (xferOp.mask() && (isOutermostDimBroadcast(xferOp) ||
+                                  xferOp.getMaskType()->getRank() > 1)) {
               OpBuilder::InsertionGuard guard(b);
               b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
               SmallVector<Value, 8> loadIndices;
               Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-              loadIndices.push_back(iv);
+              // In case of broadcast: Use same indices to load from memref as
+              // before.
+              if (!isOutermostDimBroadcast(xferOp))
+                loadIndices.push_back(iv);
 
               auto mask = memref_load(castedMaskBuffer, loadIndices);
               rewriter.updateRootInPlace(
@@ -661,7 +687,7 @@ static Optional<int64_t> get1dMemrefIndices(
     return dim;
   }
 
-  assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+  assert(isOutermostDimBroadcast(xferOp) &&
          "Expected AffineDimExpr or AffineConstantExpr");
   return None;
 }

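Note: a rough MLIR sketch of the mask handling in allocBuffers after this change, assuming a vector<9xi1> mask (names are placeholders, not the exact generated IR). The mask is stored into a zero-dimensional buffer and loaded back once, so the loaded vector can be reused instead of being re-loaded at every use site:

    %buf = memref.alloca() : memref<vector<9xi1>>
    memref.store %mask, %buf[] : memref<vector<9xi1>>
    %m = memref.load %buf[] : memref<vector<9xi1>>  // becomes result.maskBuffer

In TransferOpConversion, when the to-be-unpacked dimension is a broadcast or the mask is already 1-D, the mask buffer is not unpacked; in the broadcast case the load indices do not include the induction variable, so the same mask values are used on every iteration.
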
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 0088a3e6c9a4..0326127d9e30 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2491,10 +2491,11 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
-  auto attr = result.attributes.get(permutationAttrName);
-  if (!attr) {
+  Attribute mapAttr = result.attributes.get(permutationAttrName);
+  if (!mapAttr) {
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
+    mapAttr = AffineMapAttr::get(permMap);
+    result.attributes.set(permutationAttrName, mapAttr);
   }
   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -2502,7 +2503,10 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
                             result.operands))
     return failure();
   if (hasMask.succeeded()) {
-    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
+    // Instead of adding the mask type as an op type, compute it based on the
+    // vector type and the permutation map (to keep the type signature small).
+    auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }

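Note: as the new comment in the parser explains, the mask type is deliberately not part of the op's printed type signature. In the assembly form (example from the 2-D tests), only the source and result types are spelled out, and the parser recomputes the mask type from them:

    // %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
    //     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    //     : memref<?x?xf32>, vector<9x4xf32>
    //
    // %mask must have type vector<9x4xi1>, derived from vector<9x4xf32> and
    // the permutation map; it never appears after the ":".
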
diff --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
index c16fad25b642..77fcd301fa03 100644
--- a/mlir/lib/Interfaces/VectorInterfaces.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -10,6 +10,26 @@
 
 using namespace mlir;
 
+namespace mlir {
+namespace vector {
+namespace detail {
+
+VectorType transferMaskType(VectorType vecType, AffineMap map) {
+  auto i1Type = IntegerType::get(map.getContext(), 1);
+  SmallVector<int64_t, 8> shape;
+  for (int64_t i = 0; i < vecType.getRank(); ++i) {
+    // Only result dims have a corresponding dim in the mask.
+    if (auto expr = map.getResult(i).template isa<AffineDimExpr>()) {
+      shape.push_back(vecType.getDimSize(i));
+    }
+  }
+  return VectorType::get(shape, i1Type);
+}
+
+} // namespace detail
+} // namespace vector
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // VectorUnroll Interfaces
 //===----------------------------------------------------------------------===//

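Note: a worked example of the loop above, transferMaskType(vector<4x9xf32>, (d0, d1) -> (0, d1)):

    // result 0 is AffineConstantExpr 0  -> broadcast, contributes no mask dim
    // result 1 is AffineDimExpr d1      -> contributes vecType dim size 9
    // => returns vector<9xi1>
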
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 4773bcb1aa74..dce95f322bfb 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -5,6 +5,14 @@
 
 // Test for special cases of 1D vector transfer ops.
 
+memref.global "private" @gv : memref<5x6xf32> =
+    dense<[[0. , 1. , 2. , 3. , 4. , 5. ],
+           [10., 11., 12., 13., 14., 15.],
+           [20., 21., 22., 23., 24., 25.],
+           [30., 31., 32., 33., 34., 35.],
+           [40., 41., 42., 43., 44., 45.]]>
+
+// Non-contiguous, strided load
 func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -14,6 +22,7 @@ func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   return
 }
 
+// Broadcast
 func @transfer_read_1d_broadcast(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -24,6 +33,7 @@ func @transfer_read_1d_broadcast(
   return
 }
 
+// Non-contiguous, strided load
 func @transfer_read_1d_in_bounds(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -34,6 +44,7 @@ func @transfer_read_1d_in_bounds(
   return
 }
 
+// Non-contiguous, strided load
 func @transfer_read_1d_mask(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -45,6 +56,7 @@ func @transfer_read_1d_mask(
   return
 }
 
+// Non-contiguous, strided load
 func @transfer_read_1d_mask_in_bounds(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -56,6 +68,7 @@ func @transfer_read_1d_mask_in_bounds(
   return
 }
 
+// Non-contiguous, strided store
 func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<7xf32>
@@ -65,57 +78,68 @@ func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   return
 }
 
+// Non-contiguous, strided store
+func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fn1 = constant -2.0 : f32
+  %vf0 = splat %fn1 : vector<7xf32>
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+    {permutation_map = affine_map<(d0, d1) -> (d0)>}
+    : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
 func @entry() {
   %c0 = constant 0: index
   %c1 = constant 1: index
   %c2 = constant 2: index
   %c3 = constant 3: index
-  %f10 = constant 10.0: f32
-  // work with dims of 4, not of 3
-  %first = constant 5: index
-  %second = constant 6: index
-  %A = memref.alloc(%first, %second) : memref<?x?xf32>
-  scf.for %i = %c0 to %first step %c1 {
-    %i32 = index_cast %i : index to i32
-    %fi = sitofp %i32 : i32 to f32
-    %fi10 = mulf %fi, %f10 : f32
-    scf.for %j = %c0 to %second step %c1 {
-        %j32 = index_cast %j : index to i32
-        %fj = sitofp %j32 : i32 to f32
-        %fres = addf %fi10, %fj : f32
-        memref.store %fres, %A[%i, %j] : memref<?x?xf32>
-    }
-  }
-
-  // Read from 2D memref on first dimension. Cannot be lowered to an LLVM
-  // vector load. Instead, generates scalar loads.
+  %0 = memref.get_global @gv : memref<5x6xf32>
+  %A = memref.cast %0 : memref<5x6xf32> to memref<?x?xf32>
+
+  // 1. Read from 2D memref on first dimension. Cannot be lowered to an LLVM
+  //    vector load. Instead, generates scalar loads.
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Write to 2D memref on first dimension. Cannot be lowered to an LLVM
-  // vector store. Instead, generates scalar stores.
+  // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
+
+  // 2. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
+  //    vector store. Instead, generates scalar stores.
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // (Same as above.)
+
+  // 3. (Same as 1. To check if 2 works correctly.)
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Read a scalar from a 2D memref and broadcast the value to a 1D vector.
-  // Generates a loop with vector.insertelement.
+  // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
+
+  // 4. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
+  //    Generates a loop with vector.insertelement.
   call @transfer_read_1d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  //  Read from 2D memref on first dimension. Accesses are in-bounds, so no
-  // if-check is generated inside the generated loop.
+  // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
+
+  // 5. Read from 2D memref on first dimension. Accesses are in-bounds, so no
+  //    if-check is generated inside the generated loop.
   call @transfer_read_1d_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Optional mask attribute is specified and, in addition, there may be
-  // out-of-bounds accesses.
+  // CHECK: ( 12, 22, -1 )
+
+  // 6. Optional mask attribute is specified and, in addition, there may be
+  //    out-of-bounds accesses.
   call @transfer_read_1d_mask(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but accesses are in-bounds.
+  // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
+
+  // 7. Same as 6, but accesses are in-bounds.
   call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( 12, -42, -1 )
+
+  // 8. Write to 2D memref on first dimension with a mask.
+  call @transfer_write_1d_mask(%A, %c1, %c0)
+      : (memref<?x?xf32>, index, index) -> ()
+
+  // 9. (Same as 1. To check if 8 works correctly.)
+  call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )
+
   return
 }
-
-// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
-// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
-// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
-// CHECK: ( 12, 22, -1 )
-// CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
-// CHECK: ( 12, -42, -1 )

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index f4eef8b98b76..802b8beeb3fd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -3,6 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
+                                                       [10., 11., 12., 13.],
+                                                       [20., 21., 22., 23.]]>
+
+// Vector load
 func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -12,6 +17,7 @@ func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   return
 }
 
+// Vector load with mask
 func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
   %mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
@@ -25,6 +31,47 @@ func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index)
   return
 }
 
+// Vector load with mask + transpose
+func @transfer_read_2d_mask_transposed(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
+                          [1, 1, 1, 1], [0, 1, 1, 0],
+                          [1, 1, 1, 1], [1, 1, 1, 1],
+                          [1, 1, 1, 1], [0, 0, 0, 0],
+                          [1, 1, 1, 1]]> : vector<9x4xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+    memref<?x?xf32>, vector<9x4xf32>
+  vector.print %f: vector<9x4xf32>
+  return
+}
+
+// Vector load with mask + broadcast
+func @transfer_read_2d_mask_broadcast(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
+// Transpose + vector load with mask + broadcast
+func @transfer_read_2d_mask_transpose_broadcast_last_dim(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 1]> : vector<4xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
+// Load + transpose
 func @transfer_read_2d_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -35,6 +82,7 @@ func @transfer_read_2d_transposed(
   return
 }
 
+// Load 1D + broadcast to 2D
 func @transfer_read_2d_broadcast(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -45,6 +93,7 @@ func @transfer_read_2d_broadcast(
   return
 }
 
+// Vector store
 func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<1x4xf32>
@@ -54,55 +103,79 @@ func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   return
 }
 
+// Vector store with mask
+func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fn1 = constant -2.0 : f32
+  %mask = constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
+  %vf0 = splat %fn1 : vector<1x4xf32>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
+    vector<1x4xf32>, memref<?x?xf32>
+  return
+}
+
 func @entry() {
   %c0 = constant 0: index
   %c1 = constant 1: index
   %c2 = constant 2: index
   %c3 = constant 3: index
-  %c4 = constant 4: index
-  %c5 = constant 5: index
-  %c8 = constant 5: index
-  %f10 = constant 10.0: f32
-  // work with dims of 4, not of 3
-  %first = constant 3: index
-  %second = constant 4: index
-  %A = memref.alloc(%first, %second) : memref<?x?xf32>
-  scf.for %i = %c0 to %first step %c1 {
-    %i32 = index_cast %i : index to i32
-    %fi = sitofp %i32 : i32 to f32
-    %fi10 = mulf %fi, %f10 : f32
-    scf.for %j = %c0 to %second step %c1 {
-        %j32 = index_cast %j : index to i32
-        %fj = sitofp %j32 : i32 to f32
-        %fres = addf %fi10, %fj : f32
-        memref.store %fres, %A[%i, %j] : memref<?x?xf32>
-    }
-  }
-  // On input, memory contains [[ 0, 1, 2, ...], [10, 11, 12, ...], ...]
-  // Read shifted by 2 and pad with -42:
+  %0 = memref.get_global @gv : memref<3x4xf32>
+  %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
+
+  // 1. Read 2D vector from 2D memref.
   call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but transposed
+  // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 2. Read 2D vector from 2D memref at specified location and transpose the
+  //    result.
   call @transfer_read_2d_transposed(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Write into memory shifted by 3
-  call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
-  // Read shifted by 0 and pad with -42:
-  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but apply a mask
+  // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 3. Read 2D vector from 2D memref with a 2D mask. In addition, some
+  //    accesses are out-of-bounds.
   call @transfer_read_2d_mask(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but without mask and transposed
-  call @transfer_read_2d_transposed(%A, %c0, %c0)
+  // CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 4. Same as 3, but transpose the result.
+  call @transfer_read_2d_mask_transposed(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
-  // Second vector dimension is a broadcast
+  // CHECK: ( ( 0, -42, 20, -42 ), ( -42, -42, 21, -42 ), ( 2, 12, 22, -42 ), ( -42, 13, 23, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ) )
+
+  // 5. Read 1D vector from 2D memref at specified location and broadcast the
+  //    result to 2D.
   call @transfer_read_2d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 6. Read 1D vector from 2D memref at specified location with mask and
+  //    broadcast the result to 2D.
+  call @transfer_read_2d_mask_broadcast(%A, %c2, %c1)
+      : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ) )
+
+  // 7. Read 1D vector from 2D memref (second dimension) at specified location
+  //    with mask and broadcast the result to 2D. In this test case, mask
+  //    elements must be evaluated before lowering to an (N>1)-D transfer.
+  call @transfer_read_2d_mask_transpose_broadcast_last_dim(%A, %c0, %c1)
+      : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 1, 1, 1, 1, 1, 1, 1, 1, 1 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( 3, 3, 3, 3, 3, 3, 3, 3, 3 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 8. Write 2D vector into 2D memref at specified location.
+  call @transfer_write_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 9. Read memref to verify step 8.
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 10. Write 2D vector into 2D memref at specified location with mask.
+  call @transfer_write_2d_mask(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 11. Read memref to verify step 10.
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 0, 1, -2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
   return
 }
 
-// CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index ae7fee3c9110..ff64dbbc8e4c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -1,15 +1,8 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-// Test case is based on test-transfer-read-2d.
-
 func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
                        %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -29,6 +22,17 @@ func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
   return
 }
 
+func @transfer_read_3d_mask_broadcast(
+    %A : memref<?x?x?x?xf32>, %o: index, %a: index, %b: index, %c: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[0, 1]> : vector<2xi1>
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>}
+      : memref<?x?x?x?xf32>, vector<2x5x3xf32>
+  vector.print %f: vector<2x5x3xf32>
+  return
+}
+
 func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
                                   %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -80,20 +84,34 @@ func @entry() {
     }
   }
 
+  // 1. Read 3D vector from 4D memref.
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
+
+  // 2. Write 3D vector to 4D memref.
   call @transfer_write_3d(%A, %c0, %c0, %c1, %c1)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+
+  // 3. Read memref to verify step 2.
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
+
+  // 4. Read 3D vector from 4D memref and transpose vector.
   call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
+
+  // 5. Read 1D vector from 4D memref and broadcast vector to 3D.
   call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )
+
+  // 6. Read 1D vector from 4D memref with mask and broadcast vector to 3D.
+  call @transfer_read_3d_mask_broadcast(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ), ( ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ) ) )
+
   return
 }
-
-// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )


        

