[Mlir-commits] [mlir] afc3756 - [mlir][vector] Masking support for reductions in Linalg vectorizer

Diego Caballero llvmlistbot at llvm.org
Fri Jan 13 12:45:38 PST 2023


Author: Diego Caballero
Date: 2023-01-13T20:45:04Z
New Revision: afc3756e6c6da68f066703f384aca6c2ffb54286

URL: https://github.com/llvm/llvm-project/commit/afc3756e6c6da68f066703f384aca6c2ffb54286
DIFF: https://github.com/llvm/llvm-project/commit/afc3756e6c6da68f066703f384aca6c2ffb54286.diff

LOG: [mlir][vector] Masking support for reductions in Linalg vectorizer

This patch enables vectorization of reductions in the Linalg vectorizer
using the vector.mask operation. It also introduces the logic to slice
and propagate the vector mask of a masked multi-reduction to its
respective lowering operations.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141571

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 0028abee51c27..deb86df396d1c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -203,6 +203,20 @@ inline bool isReductionIterator(Attribute attr) {
   return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Populates a vector.mask region: moves `maskableOp` into the builder's
+/// insertion block and terminates it with a vector.yield of its results.
+void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if `mask` is non-null. Otherwise, returns the
+/// maskable operation itself, unchanged.
+Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
+                         Value mask);
+
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5a14f0da52b16..8c5d44ab3d31a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -340,6 +340,7 @@ def Vector_MultiDimReductionOp :
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>,
+     DeclareOpInterfaceMethods<MaskableOpInterface>,
      DeclareOpInterfaceMethods<VectorUnrollOpInterface,
                                ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
@@ -2338,16 +2339,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
 
   let skipDefaultBuilders = 1;
   let builders = [
-    OpBuilder<(ins "Value":$mask,
-                   CArg<"function_ref<void(OpBuilder &, Location)>",
-                        "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
-                   CArg<"function_ref<void(OpBuilder &, Location)>",
-                        "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
-                   "Value":$passthru,
-                   CArg<"function_ref<void(OpBuilder &, Location)>",
-                        "buildTerminatedBody">:$maskRegion)>
+    OpBuilder<(ins "Value":$mask, "Operation *":$maskableOp,
+                   CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Operation *":$maskableOp,
+                   CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru,
+                   "Operation *":$maskableOp,
+                   CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>
   ];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1e8335012504d..5f367d1a240a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -292,25 +292,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
 
   // Wrap the operation with a new `vector.mask` and update D-U chain.
   assert(opToMask && "Expected a valid operation to mask");
-  auto opResults = opToMask->getResultTypes();
-  auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
-    Block *insBlock = builder.getInsertionBlock();
-    // Create a block, put an op in that block. Look for a utility.
-    // Maybe in conversion pattern rewriter. Way to avoid splice.
-    // Set insertion point.
-    insBlock->getOperations().splice(
-        insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
-    builder.create<vector::YieldOp>(loc, opToMask->getResults());
-  };
-  // TODO: Allow multiple results in vector.mask.
-  auto maskOp =
-      opResults.empty()
-          ? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
-                                            createRegionMask)
-          : rewriter.create<vector::MaskOp>(opToMask->getLoc(),
-                                            opToMask->getResultTypes().front(),
-                                            mask, createRegionMask);
-
+  auto maskOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, opToMask, mask));
   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
 
   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
@@ -440,17 +423,16 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
 /// initial value.buildMultiDimReduce
 // Note: this is a true builder that notifies the OpBuilder listener.
 // TODO: Consider moving as a static helper on the ReduceOp.
-static Operation *buildMultiDimReduce(OpBuilder &b,
-                                      Operation *reduceOp, Value valueToReduce,
-                                      Value acc,
-                                      const SmallVector<bool> &reductionMask) {
+static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+                                      Value valueToReduce, Value acc,
+                                      ArrayRef<bool> dimsToReduce) {
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
-      reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
+      reduceOp->getLoc(), valueToReduce, acc, dimsToReduce, *maybeKind);
 }
 
-static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
+static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
   return llvm::to_vector(
       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
 }
@@ -701,8 +683,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
   if (!reduceType ||
       (outputType && reduceType.getShape() == outputType.getShape()))
     return nullptr;
-  SmallVector<bool> reductionMask = getReductionMask(linalgOp);
-  return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
+  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
+  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
 }
 
 /// Generic vectorization for a single operation `op`, given already vectorized
@@ -972,11 +954,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  // TODO: Masking only supports dynamic generic ops without reductions for now.
-  if (!isElementwise(op) &&
-      llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
-        return itType != utils::IteratorType::parallel;
-      }))
+  // TODO: Masking only supports dynamic generic ops for now.
+  if (!isa<linalg::GenericOp>(op))
     return failure();
 
   // TODO: 0-d vectors are not supported yet.

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f00d8494e3151..933945233c885 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -342,6 +342,13 @@ LogicalResult MultiDimReductionOp::verify() {
   return success();
 }
 
+/// Returns the mask type expected by this operation.
+Type MultiDimReductionOp::getExpectedMaskType() {
+  auto vecType = getSourceVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 namespace {
 // Only unit dimensions that are being reduced are folded. If the dimension is
 // unit, but not reduced, it is not folded, thereby keeping the output type the
@@ -5276,7 +5283,8 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void MaskOp::build(
     OpBuilder &builder, OperationState &result, Value mask,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+    Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
   assert(maskRegionBuilder &&
          "builder callback for 'maskRegion' must be present");
 
@@ -5284,21 +5292,22 @@ void MaskOp::build(
   OpBuilder::InsertionGuard guard(builder);
   Region *maskRegion = result.addRegion();
   builder.createBlock(maskRegion);
-  maskRegionBuilder(builder, result.location);
+  maskRegionBuilder(builder, maskableOp);
 }
 
 void MaskOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTypes,
-    Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, resultTypes, mask, /*passthru=*/Value(),
+    Value mask, Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
         maskRegionBuilder);
 }
 
 void MaskOp::build(
-    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
-    Value mask, Value passthru,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, mask, maskRegionBuilder);
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value mask,
+    Value passthru, Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+  build(builder, result, mask, maskableOp, maskRegionBuilder);
   if (passthru)
     result.addOperands(passthru);
   result.addTypes(resultTypes);
@@ -5738,6 +5747,34 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
   llvm_unreachable("unknown CombiningKind");
 }
 
+//===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Populates a vector.mask region: moves `maskableOp` into the builder's
+/// insertion block and terminates it with a vector.yield of its results.
+void mlir::vector::createMaskOpRegion(OpBuilder &builder,
+                                      Operation *maskableOp) {
+  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
+  Block *insBlock = builder.getInsertionBlock();
+  // Splice the op out of its current block to the front of the insertion block.
+  insBlock->getOperations().splice(
+      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
+  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
+}
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if `mask` is non-null. Otherwise, returns
+/// the maskable operation itself, unchanged.
+Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
+                                       Operation *maskableOp, Value mask) {
+  if (!mask)
+    return maskableOp;
+  return rewriter.create<MaskOp>(maskableOp->getLoc(),
+                                 maskableOp->getResultTypes(), mask, maskableOp,
+                                 createMaskOpRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 31a24522a5f10..e89059cb9390a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -12,9 +12,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 
 #define DEBUG_TYPE "vector-multi-reduction"
@@ -40,6 +38,18 @@ class InnerOuterDimReductionConversion
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto src = multiReductionOp.getSource();
     auto loc = multiReductionOp.getLoc();
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
@@ -79,6 +89,15 @@ class InnerOuterDimReductionConversion
       indices.append(reductionDims.begin(), reductionDims.end());
       indices.append(parallelDims.begin(), parallelDims.end());
     }
+
+    // If masked, transpose the original mask.
+    Value transposedMask;
+    if (maskableOp.isMasked()) {
+      transposedMask = rewriter.create<vector::TransposeOp>(
+          loc, maskableOp.getMaskingOp().getMask(), indices);
+    }
+
+    // Transpose reduction source.
     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
     SmallVector<bool> reductionMask(srcRank, false);
     for (int i = 0; i < reductionSize; ++i) {
@@ -87,9 +106,14 @@ class InnerOuterDimReductionConversion
       else
         reductionMask[i] = true;
     }
-    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
-        multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
-        reductionMask, multiReductionOp.getKind());
+
+    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
+        multiReductionOp.getLoc(), transposeOp.getResult(),
+        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
+    newMultiRedOp =
+        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
+
+    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
     return success();
   }
 
@@ -113,6 +137,18 @@ class ReduceMultiDimReductionRank
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
     auto loc = multiReductionOp.getLoc();
@@ -186,10 +222,22 @@ class ReduceMultiDimReductionRank
       std::swap(mask.front(), mask.back());
       std::swap(vectorShape.front(), vectorShape.back());
     }
+
+    Value newVectorMask;
+    if (maskableOp.isMasked()) {
+      Value vectorMask = maskableOp.getMaskingOp().getMask();
+      auto maskCastedType = VectorType::get(
+          vectorShape,
+          vectorMask.getType().cast<VectorType>().getElementType());
+      newVectorMask =
+          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
+    }
+
     auto castedType = VectorType::get(
         vectorShape, multiReductionOp.getSourceVectorType().getElementType());
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
+
     Value acc = multiReductionOp.getAcc();
     if (flattenedParallelDim) {
       auto accType = VectorType::get(
@@ -197,24 +245,26 @@ class ReduceMultiDimReductionRank
           multiReductionOp.getSourceVectorType().getElementType());
       acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
     }
-    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
+    // 6. Creates the flattened form of vector.multi_reduction with inner/outer
     // most dim as reduction.
-    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
         loc, cast, acc, mask, multiReductionOp.getKind());
+    newMultiDimRedOp =
+        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
 
-    // 6. If there are no parallel shapes, the result is a scalar.
+    // 7. If there are no parallel shapes, the result is a scalar.
     // TODO: support 0-d vectors when available.
     if (parallelShapes.empty()) {
-      rewriter.replaceOp(multiReductionOp, newOp.getDest());
+      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
       return success();
     }
 
-    // 7. Creates shape cast for the output n-D -> 2-D
+    // 8. Creates shape cast for the output n-D -> 2-D.
     VectorType outputCastedType = VectorType::get(
         parallelShapes,
         multiReductionOp.getSourceVectorType().getElementType());
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        multiReductionOp, outputCastedType, newOp.getDest());
+        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
     return success();
   }
 
@@ -230,6 +280,12 @@ struct TwoDimMultiReductionToElementWise
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: Support masking.
+      return failure();
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-2 ["parallel", "reduce"] or bail.
     if (srcRank != 2)
@@ -274,6 +330,18 @@ struct TwoDimMultiReductionToReduction
     if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
       return failure();
 
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto loc = multiReductionOp.getLoc();
     Value result = rewriter.create<arith::ConstantOp>(
         loc, multiReductionOp.getDestType(),
@@ -285,13 +353,22 @@ struct TwoDimMultiReductionToReduction
           loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
       auto acc = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
-      auto reducedValue = rewriter.create<vector::ReductionOp>(
+      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
           loc, multiReductionOp.getKind(), v, acc);
+
+      // If masked, slice the mask and mask the new reduction operation.
+      if (maskableOp.isMasked()) {
+        Value mask = rewriter.create<vector::ExtractOp>(
+            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
+        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+      }
+
       result = rewriter.create<vector::InsertElementOp>(
-          loc, reducedValue, result,
+          loc, reductionOp->getResult(0), result,
           rewriter.create<arith::ConstantIndexOp>(loc, i));
     }
-    rewriter.replaceOp(multiReductionOp, result);
+
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -307,6 +384,12 @@ struct OneDimMultiReductionToTwoDim
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: Support masking.
+      return failure();
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-1 or bail.
     if (srcRank != 1)

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 0ccd6c4b96733..d25ffe74841da 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1824,6 +1824,82 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>,
+                                       %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_reduction(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x8xi1>
+// CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+// CHECK:           return %[[VAL_15]] : tensor<?xf32>
+// CHECK:         }
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>,
+                                                 %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d2, d1)>],
+                        iterator_types = ["reduction", "parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?x?xf32>)
+    outs(%arg1 : tensor<?x?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8, 16]
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_transpose_reduction(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]], %[[VAL_7]] : vector<4x8x16xi1>
+// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_5]] : vector<16x8xi1>
+// CHECK:           %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_10]] { vector.multi_reduction <add>, %[[VAL_11]], %[[VAL_14]] [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_13]] { vector.transfer_write %[[VAL_15]], %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+
+// -----
+
 // This is a regression test. This IR cannot be vectorized, but
 // structured.vectorize should nevertheless succeed.
 
@@ -1892,4 +1968,3 @@ transform.sequence failures(propagate) {
 // CHECK-LABEL: @wrong_reduction_detection
 // CHECK:         vector.broadcast
 // CHECK:         vector.transfer_write
-

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 6b372c3ef1c3e..ee4ab7a1c5c8f 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s
 
 func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
@@ -19,6 +19,8 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
+// -----
+
 func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
     return %0 : f32
@@ -31,6 +33,8 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -
 //       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
 //       CHECK:   return %[[RES]]
 
+// -----
+
 func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
@@ -72,6 +76,7 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
 //       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
 //       CHECK:       return %[[RESULT]]
 
+// -----
 
 func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
@@ -85,6 +90,8 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: v
 //       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
 //       CHECK:       return %[[RESULT]]
 
+// -----
+
 func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
     return %0 : vector<2x4xf32>
@@ -135,3 +142,95 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vecto
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
 //       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
 //       CHECK:       return %[[RESHAPED_VEC]]
+
+// -----
+
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %c1 = arith.constant 1 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %c0_1 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+  %cst_2 = arith.constant 0.000000e+00 : f32
+  %2 = vector.create_mask %dim : vector<4xi1>
+  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+  %c0_3 = arith.constant 0 : index
+  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+  return %5 : tensor<?xf32>
+}
+
+// Verify that the original 2-D mask is sliced and propagated properly to the
+// vector.reduction instances.
+
+// CHECK-LABEL:   func.func @vectorize_dynamic_reduction
+// CHECK:           %[[VAL_8:.*]] = tensor.dim
+// CHECK:           %[[VAL_9:.*]] = tensor.dim
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1>
+
+// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<4x8xi1>
+// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_18:.*]] = vector.insertelement
+
+// CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<4x8xi1>
+// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_23:.*]] = vector.insertelement
+
+// CHECK:           %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<4x8xi1>
+// CHECK:           %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_28:.*]] = vector.insertelement
+
+// CHECK:           %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<4x8xi1>
+// CHECK:           %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_33:.*]] = vector.insertelement
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %c0_2 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+  %cst_3 = arith.constant 0.000000e+00 : f32
+  %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1>
+  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+  %c0_4 = arith.constant 0 : index
+  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_dynamic_transpose_reduction
+// CHECK:           %[[VAL_6:.*]] = tensor.dim
+// CHECK:           %[[VAL_7:.*]] = tensor.dim
+// CHECK:           %[[VAL_8:.*]] = tensor.dim
+// CHECK:           %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1>
+// CHECK:           %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>
+
+// Just checking a few instances to make sure the vector mask is properly propagated:
+
+// CHECK:           %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction <add>
+// CHECK:           %[[VAL_145:.*]] = vector.insertelement %[[VAL_144]]
+
+// CHECK:           %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction <add>
+// CHECK:           %[[VAL_150:.*]] = vector.insertelement %[[VAL_149]]
+
+// CHECK:           %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction <add>
+// CHECK:           %[[VAL_155:.*]] = vector.insertelement %[[VAL_154]]
+
+// CHECK:           %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
+// CHECK:           %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
+


        


More information about the Mlir-commits mailing list