[Mlir-commits] [mlir] 66c2b76 - [MLIR] Extend vector.gather to accept tensor as base

Hanhan Wang llvmlistbot at llvm.org
Tue Aug 9 11:21:10 PDT 2022


Author: Jerry Wu
Date: 2022-08-09T11:19:16-07:00
New Revision: 66c2b76846246fcd1df31e1928b0df44a8fe43e1

URL: https://github.com/llvm/llvm-project/commit/66c2b76846246fcd1df31e1928b0df44a8fe43e1
DIFF: https://github.com/llvm/llvm-project/commit/66c2b76846246fcd1df31e1928b0df44a8fe43e1.diff

LOG: [MLIR] Extend vector.gather to accept tensor as base

In addition to memref, accept ranked tensor as the base operand of vector.gather, similar to vector.transfer_read.

This will allow us to vectorize non-contiguous tensor.extract ops into vector.gather. The full discussion can be found here: https://github.com/iree-org/iree/issues/9198
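
For illustration, the tensor-backed form accepted after this change is the one exercised by the new test in ops.mlir below:

    %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

During bufferization the tensor base is replaced by its memref counterpart (via the new GatherOpInterface external model), so the VectorToLLVM lowering still only ever sees a memref base.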

Reviewed By: hanchung, dcaballe

Differential Revision: https://reviews.llvm.org/D130097

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Vector/bufferize.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7437af610325a..34460f9c55706 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1758,21 +1758,24 @@ def Vector_MaskedStoreOp :
 
 def Vector_GatherOp :
   Vector_Op<"gather">,
-    Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
+    Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
 
-  let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
+  let summary = [{
+    gathers elements from memory or ranked tensor into a vector as defined by an
+    index vector and mask
+  }];
 
   let description = [{
-    The gather operation gathers elements from memory into a 1-D vector as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit is set in a 1-D mask vector. Otherwise, the
-    element is taken from a 1-D pass-through vector. Informally the semantics
-    are:
+    The gather operation gathers elements from memory or ranked tensor into a
+    1-D vector as defined by a base with indices and an additional 1-D index
+    vector, but only if the corresponding bit is set in a 1-D mask vector.
+    Otherwise, the element is taken from a 1-D pass-through vector. Informally
+    the semantics are:
     ```
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1797,8 +1800,8 @@ def Vector_GatherOp :
     ```
   }];
   let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return getBase().getType().cast<MemRefType>();
+    ShapedType getBaseType() {
+      return getBase().getType().cast<ShapedType>();
     }
     VectorType getIndexVectorType() {
       return getIndexVec().getType().cast<VectorType>();

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d579c4f080c7b..4191e922992cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -260,7 +260,8 @@ class VectorGatherOpConversion
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = gather->getLoc();
-    MemRefType memRefType = gather.getMemRefType();
+    MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
+    assert(memRefType && "The base should be bufferized");
 
     // Resolve alignment.
     unsigned align;

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c78cc1a459ab9..ce494c3d14742 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4065,12 +4065,15 @@ LogicalResult GatherOp::verify() {
   VectorType indVType = getIndexVectorType();
   VectorType maskVType = getMaskVectorType();
   VectorType resVType = getVectorType();
-  MemRefType memType = getMemRefType();
+  ShapedType baseType = getBaseType();
 
-  if (resVType.getElementType() != memType.getElementType())
+  if (!baseType.isa<MemRefType, RankedTensorType>())
+    return emitOpError("requires base to be a memref or ranked tensor type");
+
+  if (resVType.getElementType() != baseType.getElementType())
     return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
+  if (llvm::size(getIndices()) != baseType.getRank())
+    return emitOpError("requires ") << baseType.getRank() << " indices";
   if (resVType.getDimSize(0) != indVType.getDimSize(0))
     return emitOpError("expected result dim to match indices dim");
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))

diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 78c5c3bcd2b82..3ec95ed644dc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -113,6 +113,46 @@ struct TransferWriteOpInterface
   }
 };
 
+/// Bufferization of vector.gather. Replaced with a new vector.gather that
+/// operates on a memref.
+struct GatherOpInterface
+    : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
+                                                    vector::GatherOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    assert(opOperand.get().getType().isa<RankedTensorType>() &&
+           "only tensor types expected");
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    assert(opOperand.get().getType().isa<RankedTensorType>() &&
+           "only tensor types expected");
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto gatherOp = cast<vector::GatherOp>(op);
+    assert(gatherOp.getBaseType().isa<TensorType>() &&
+           "only tensor types expected");
+    FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
+    if (failed(buffer))
+      return failure();
+    replaceOpWithNewBufferizedOp<vector::GatherOp>(
+        rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
+        gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+        gatherOp.getPassThru());
+    return success();
+  }
+};
+
 } // namespace
 } // namespace vector
 } // namespace mlir
@@ -122,5 +162,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
+    GatherOp::attachInterface<GatherOpInterface>(*ctx);
   });
 }

diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index ab271a642cbcc..67c84c01d1c78 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -29,3 +29,17 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
       : vector<5x6xf32>, tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @gather(
+//  CHECK-SAME:     %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
+//  CHECK-SAME:     %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
+//       CHECK:   %[[m:.*]] = bufferization.to_memref %[[base]] : memref<?x?xf32>
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[out:.*]] = vector.gather %[[m]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[pass_thru]] : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+func.func @gather(%base: tensor<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}

diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index dc69bb0a78a6f..6ccf5295eccd6 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -669,6 +669,14 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
   return
 }
 
+// CHECK-LABEL: @gather_on_tensor
+func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index


        

