[Mlir-commits] [mlir] 66c2b76 - [MLIR] Extend vector.gather to accept tensor as base
Hanhan Wang
llvmlistbot at llvm.org
Tue Aug 9 11:21:10 PDT 2022
Author: Jerry Wu
Date: 2022-08-09T11:19:16-07:00
New Revision: 66c2b76846246fcd1df31e1928b0df44a8fe43e1
URL: https://github.com/llvm/llvm-project/commit/66c2b76846246fcd1df31e1928b0df44a8fe43e1
DIFF: https://github.com/llvm/llvm-project/commit/66c2b76846246fcd1df31e1928b0df44a8fe43e1.diff
LOG: [MLIR] Extend vector.gather to accept tensor as base
In addition to memref, accept ranked tensor as the base operand of vector.gather, similar to vector.transfer_read.
This will allow us to vectorize noncontiguous tensor.extract into vector.gather. Full discussion can be found here: https://github.com/iree-org/iree/issues/9198
Reviewed By: hanchung, dcaballe
Differential Revision: https://reviews.llvm.org/D130097
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Vector/bufferize.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7437af610325a..34460f9c55706 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1758,21 +1758,24 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather">,
- Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
+ Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
- let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
+ let summary = [{
+ gathers elements from memory or ranked tensor into a vector as defined by an
+ index vector and mask
+ }];
let description = [{
- The gather operation gathers elements from memory into a 1-D vector as
- defined by a base with indices and an additional 1-D index vector, but
- only if the corresponding bit is set in a 1-D mask vector. Otherwise, the
- element is taken from a 1-D pass-through vector. Informally the semantics
- are:
+ The gather operation gathers elements from memory or ranked tensor into a
+ 1-D vector as defined by a base with indices and an additional 1-D index
+ vector, but only if the corresponding bit is set in a 1-D mask vector.
+ Otherwise, the element is taken from a 1-D pass-through vector. Informally
+ the semantics are:
```
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1797,8 +1800,8 @@ def Vector_GatherOp :
```
}];
let extraClassDeclaration = [{
- MemRefType getMemRefType() {
- return getBase().getType().cast<MemRefType>();
+ ShapedType getBaseType() {
+ return getBase().getType().cast<ShapedType>();
}
VectorType getIndexVectorType() {
return getIndexVec().getType().cast<VectorType>();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d579c4f080c7b..4191e922992cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -260,7 +260,8 @@ class VectorGatherOpConversion
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = gather->getLoc();
- MemRefType memRefType = gather.getMemRefType();
+ MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
+ assert(memRefType && "The base should be bufferized");
// Resolve alignment.
unsigned align;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c78cc1a459ab9..ce494c3d14742 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4065,12 +4065,15 @@ LogicalResult GatherOp::verify() {
VectorType indVType = getIndexVectorType();
VectorType maskVType = getMaskVectorType();
VectorType resVType = getVectorType();
- MemRefType memType = getMemRefType();
+ ShapedType baseType = getBaseType();
- if (resVType.getElementType() != memType.getElementType())
+ if (!baseType.isa<MemRefType, RankedTensorType>())
+ return emitOpError("requires base to be a memref or ranked tensor type");
+
+ if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (llvm::size(getIndices()) != baseType.getRank())
+ return emitOpError("requires ") << baseType.getRank() << " indices";
if (resVType.getDimSize(0) != indVType.getDimSize(0))
return emitOpError("expected result dim to match indices dim");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 78c5c3bcd2b82..3ec95ed644dc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -113,6 +113,46 @@ struct TransferWriteOpInterface
}
};
+/// Bufferization of vector.gather. Replaced with a new vector.gather that
+/// operates on a memref.
+struct GatherOpInterface
+ : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
+ vector::GatherOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ "only tensor types expected");
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ "only tensor types expected");
+ return false;
+ }
+
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ auto gatherOp = cast<vector::GatherOp>(op);
+ assert(gatherOp.getBaseType().isa<TensorType>() &&
+ "only tensor types expected");
+ FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
+ if (failed(buffer))
+ return failure();
+ replaceOpWithNewBufferizedOp<vector::GatherOp>(
+ rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
+ gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+ gatherOp.getPassThru());
+ return success();
+ }
+};
+
} // namespace
} // namespace vector
} // namespace mlir
@@ -122,5 +162,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
+ GatherOp::attachInterface<GatherOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index ab271a642cbcc..67c84c01d1c78 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -29,3 +29,17 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
: vector<5x6xf32>, tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @gather(
+// CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
+// CHECK: %[[m:.*]] = bufferization.to_memref %[[base]] : memref<?x?xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[out:.*]] = vector.gather %[[m]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[pass_thru]] : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+func.func @gather(%base: tensor<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %0 : vector<16xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index dc69bb0a78a6f..6ccf5295eccd6 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -669,6 +669,14 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
return
}
+// CHECK-LABEL: @gather_on_tensor
+func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0 : index
+ // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
// CHECK-LABEL: @expand_and_compress
func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list