[Mlir-commits] [mlir] 1ac874c - [mlir][Vector] Add support for masked vector gather ops

Diego Caballero llvmlistbot at llvm.org
Tue Feb 14 22:15:28 PST 2023


Author: Diego Caballero
Date: 2023-02-15T06:10:22Z
New Revision: 1ac874c9aa1859fe67fad110c278588a5a670d78

URL: https://github.com/llvm/llvm-project/commit/1ac874c9aa1859fe67fad110c278588a5a670d78
DIFF: https://github.com/llvm/llvm-project/commit/1ac874c9aa1859fe67fad110c278588a5a670d78.diff

LOG: [mlir][Vector] Add support for masked vector gather ops

This patch adds support for masked vector.gather ops using the
vector.mask representation. It includes the implementation of the
MaskableOpInterface, Linalg vectorizer support and lowering to LLVM.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143939

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/test/Dialect/Vector/lower-vector-mask.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6f6d80c11c96a..94d4b64465126 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1846,7 +1846,7 @@ def Vector_MaskedStoreOp :
 }
 
 def Vector_GatherOp :
-  Vector_Op<"gather">,
+  Vector_Op<"gather", [DeclareOpInterfaceMethods<MaskableOpInterface>]>,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOf<[AnyInteger, Index]>:$index_vec,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a20d2360beca..fc36477151d52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -416,8 +416,7 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
-                                                    value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -532,14 +531,16 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
-vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
+static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                VectorizationState &state,
+                                                Operation *op,
+                                                LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
   // Compute a one-dimensional index vector for the index op dimension.
   SmallVector<int64_t> constantSeq =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
@@ -597,32 +598,33 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
 ///
 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
 ///  offset = ( ( 1 ) * 80 +  2 ) * 15  + 3
-static Value
-calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
-                      const IRMapping &bvm,
-                      const SmallVectorImpl<int64_t> &targetShape) {
+static Value calculateGatherOffset(RewriterBase &rewriter,
+                                   tensor::ExtractOp extractOp,
+                                   const IRMapping &bvm,
+                                   const ArrayRef<int64_t> targetShape) {
   // The vector of indices for GatherOp should be shaped as the output vector
-  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
   auto loc = extractOp.getLoc();
 
-  Value offset = b.create<vector::BroadcastOp>(
-      loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
+  Value offset = broadcastIfNeeded(
+      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
     auto dimSize = broadcastIfNeeded(
-        b,
-        b.create<arith::ConstantIndexOp>(
+        rewriter,
+        rewriter.create<arith::ConstantIndexOp>(
             loc,
             extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
         indexVecType.getShape());
 
-    offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
 
-    auto extractOpIndex = broadcastIfNeeded(
-        b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
+    auto extractOpIndex =
+        broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
+                          indexVecType.getShape());
 
-    offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
+    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
   }
 
   return offset;
@@ -632,17 +634,16 @@ calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
-                                                  Operation *op,
-                                                  LinalgOp linalgOp,
-                                                  const IRMapping &bvm) {
+static VectorizationResult
+vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
+                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = state.getCanonicalVecShape();
 
   auto resultType =
       VectorType::get(targetShape, extractOp.getResult().getType());
@@ -662,9 +663,10 @@ static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
   Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
 
   // Generate the gather load
-  auto gatherOp = rewriter.create<vector::GatherOp>(
+  Operation *gatherOp = rewriter.create<vector::GatherOp>(
       loc, resultType, extractOp.getTensor(), baseIndices, offset,
       maskConstantOp, passThruConstantOp);
+  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
 }
@@ -904,14 +906,14 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgIndex(rewriter, op, linalgOp);
+    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeTensorExtract(rewriter, op, linalgOp, bvm);
+    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
 
@@ -1007,8 +1009,10 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
     return failure();
 
   if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
+  }
 
   SmallVector<CustomVectorizationPrecondition> customPreconditions;
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e145f1b66c91..64125efeb21cb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4597,6 +4597,16 @@ LogicalResult GatherOp::verify() {
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type GatherOp::getExpectedMaskType() {
+  auto vecType = this->getIndexVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index eaba09753f7f7..7c66e65fdef8b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -109,6 +110,29 @@ struct MaskedTransferWriteOpPattern
   }
 };
 
+/// Lowers a masked `vector.gather` operation.
+struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
+public:
+  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
+
+  LogicalResult
+  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    Value passthru = maskingOp.hasPassthru()
+                         ? maskingOp.getPassthru()
+                         : rewriter.create<arith::ConstantOp>(
+                               gatherOp.getLoc(),
+                               rewriter.getZeroAttr(gatherOp.getVectorType()));
+
+    // Replace the `vector.mask` operation.
+    rewriter.replaceOpWithNewOp<GatherOp>(
+        maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
+        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+        passthru);
+    return success();
+  }
+};
+
 struct LowerVectorMaskPass
     : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
   using Base::Base;
@@ -136,8 +160,8 @@ struct LowerVectorMaskPass
 /// not its nested `MaskableOpInterface`.
 void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
     RewritePatternSet &patterns) {
-  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
-      patterns.getContext());
+  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
+               MaskedGatherOpPattern>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {

diff  --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index 360e35d40908b..8f8fae095cac3 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -48,3 +48,32 @@ func.func @vector_transfer_write_on_tensor(%val: vector<16xf32>, %t0: tensor<?xf
 // CHECK:           return %[[VAL_4]] : tensor<?xf32>
 // CHECK:         }
 
+// -----
+
+func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %c3 = arith.constant 3 : index
+  %0 = vector.create_mask %c3 : vector<4xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %cst_1 = arith.constant dense<true> : vector<4xi1>
+  %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %c0_3 = arith.constant 0 : index
+  %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %c0_4 = arith.constant 0 : index
+  %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32>
+  return %3 : tensor<3xf32>
+}
+
+// CHECK-LABEL:   func.func @vector_gather(
+// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<64xf32>,
+// CHECK-SAME:                             %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1>
+// CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
+


        


More information about the Mlir-commits mailing list