[Mlir-commits] [mlir] e3f5073 - [mlir] Add some more std bufferize patterns.
Sean Silva
llvmlistbot at llvm.org
Mon Oct 19 15:54:09 PDT 2020
Author: Sean Silva
Date: 2020-10-19T15:51:45-07:00
New Revision: e3f5073a961076475c286a39a2cba2bf803eb32c
URL: https://github.com/llvm/llvm-project/commit/e3f5073a961076475c286a39a2cba2bf803eb32c
DIFF: https://github.com/llvm/llvm-project/commit/e3f5073a961076475c286a39a2cba2bf803eb32c.diff
LOG: [mlir] Add some more std bufferize patterns.
Add bufferizations for extract_element and tensor_from_elements.
Differential Revision: https://reviews.llvm.org/D89594
Added:
Modified:
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 95a8b75e2c2b..0ebc97b626c1 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -18,6 +18,21 @@
using namespace mlir;
+namespace {
+/// Bufferizes `extract_element` into a `load` from the buffer that the
+/// conversion supplies in place of the tensor operand.
+class BufferizeExtractElementOp : public OpConversionPattern<ExtractElementOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ExtractElementOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Read from the remapped operands (not the op's own operands) so the
+    // aggregate is the converted memref value rather than the tensor.
+    ExtractElementOp::Adaptor remapped(operands);
+    Value buffer = remapped.aggregate();
+    auto indices = remapped.indices();
+    rewriter.replaceOpWithNewOp<LoadOp>(op, buffer, indices);
+    return success();
+  }
+};
+} // namespace
+
namespace {
class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
public:
@@ -32,10 +47,34 @@ class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
};
} // namespace
+namespace {
+/// Bufferizes `tensor_from_elements` into a freshly allocated 1-D memref of
+/// the tensor's element type, populated with one `store` per element.
+class BufferizeTensorFromElementsOp
+    : public OpConversionPattern<TensorFromElementsOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorFromElementsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Consume the remapped operands rather than `op.elements()`, matching the
+    // other patterns in this file: values already rewritten by the conversion
+    // must be used instead of the original ones.
+    TensorFromElementsOp::Adaptor adaptor(operands);
+    auto elements = adaptor.elements();
+    Location loc = op.getLoc();
+
+    // Memref shape dimensions are int64_t; do not narrow the count to int.
+    int64_t numberOfElements = elements.size();
+    auto resultType = MemRefType::get(
+        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
+    Value result = rewriter.create<AllocOp>(loc, resultType);
+
+    // Store each element at its position: result[i] = elements[i].
+    for (auto element : llvm::enumerate(elements)) {
+      Value index = rewriter.create<ConstantIndexOp>(loc, element.index());
+      rewriter.create<StoreOp>(loc, element.value(), result, index);
+    }
+
+    // The tensor result is replaced by the memref; remaining tensor-typed uses
+    // get a materialization (`tensor_load` in the accompanying test).
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+} // namespace
+
void mlir::populateStdBufferizePatterns(MLIRContext *context,
BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeTensorCastOp>(typeConverter, context);
+ patterns.insert<BufferizeExtractElementOp, BufferizeTensorCastOp,
+ BufferizeTensorFromElementsOp>(typeConverter, context);
}
namespace {
@@ -49,9 +88,9 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
target.addLegalDialect<StandardOpsDialect>();
populateStdBufferizePatterns(context, typeConverter, patterns);
- target.addIllegalOp<TensorCastOp>();
+ target.addIllegalOp<ExtractElementOp, TensorCastOp, TensorFromElementsOp>();
- if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+ if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 981237d78cdd..d16a5dd6d9d4 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,17 @@
// RUN: mlir-opt %s -std-bufferize | FileCheck %s
+// extract_element on a tensor must become a load from the memref obtained
+// via tensor_to_memref.
+// CHECK-LABEL: func @extract_element(
+// CHECK-SAME:    %[[TENSOR:.*]]: tensor<?xf32>,
+// CHECK-SAME:    %[[IDX:.*]]: index) -> f32 {
+// CHECK:         %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
+// CHECK:         %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
+// CHECK:         return %[[RET]] : f32
+// CHECK:       }
+func @extract_element(%tensor: tensor<?xf32>, %idx: index) -> f32 {
+  %elem = extract_element %tensor[%idx] : tensor<?xf32>
+  return %elem : f32
+}
+
// CHECK-LABEL: func @tensor_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
@@ -10,3 +22,18 @@ func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
%0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
return %0 : tensor<2xindex>
}
+
+// tensor_from_elements must allocate a 1-D memref sized by the element count,
+// store each element at its index, and load the result back as a tensor.
+// CHECK-LABEL: func @tensor_from_elements(
+// CHECK-SAME:    %[[ELEM0:.*]]: index,
+// CHECK-SAME:    %[[ELEM1:.*]]: index) -> tensor<2xindex> {
+// CHECK:         %[[MEMREF:.*]] = alloc() : memref<2xindex>
+// CHECK:         %[[C0:.*]] = constant 0 : index
+// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
+// CHECK:         %[[C1:.*]] = constant 1 : index
+// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
+// CHECK:         %[[RET:.*]] = tensor_load %[[MEMREF]]
+// CHECK:         return %[[RET]] : tensor<2xindex>
+// CHECK:       }
+func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
+  %0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
More information about the Mlir-commits mailing list