[Mlir-commits] [mlir] e3f5073 - [mlir] Add some more std bufferize patterns.

Sean Silva llvmlistbot at llvm.org
Mon Oct 19 15:54:09 PDT 2020


Author: Sean Silva
Date: 2020-10-19T15:51:45-07:00
New Revision: e3f5073a961076475c286a39a2cba2bf803eb32c

URL: https://github.com/llvm/llvm-project/commit/e3f5073a961076475c286a39a2cba2bf803eb32c
DIFF: https://github.com/llvm/llvm-project/commit/e3f5073a961076475c286a39a2cba2bf803eb32c.diff

LOG: [mlir] Add some more std bufferize patterns.

Add bufferizations for extract_element and tensor_from_elements.

Differential Revision: https://reviews.llvm.org/D89594

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 95a8b75e2c2b..0ebc97b626c1 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -18,6 +18,21 @@
 
 using namespace mlir;
 
+namespace {
+/// Rewrites `extract_element` into a `load` from the bufferized aggregate.
+struct BufferizeExtractElementOp : OpConversionPattern<ExtractElementOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractElementOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ExtractElementOp::Adaptor transformed(operands);
+    Value memref = transformed.aggregate();
+    rewriter.replaceOpWithNewOp<LoadOp>(op, memref, transformed.indices());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
 public:
@@ -32,10 +47,34 @@ class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
 };
 } // namespace
 
+namespace {
+class BufferizeTensorFromElementsOp
+    : public OpConversionPattern<TensorFromElementsOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorFromElementsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    int64_t numberOfElements = operands.size(); // 1-D buffer, one slot each
+    auto resultType = MemRefType::get(
+        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
+    Value result = rewriter.create<AllocOp>(loc, resultType);
+    for (auto element : llvm::enumerate(operands)) { // use remapped operands
+      Value index = rewriter.create<ConstantIndexOp>(loc, element.index());
+      rewriter.create<StoreOp>(loc, element.value(), result, index);
+    }
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+} // namespace
+
 void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeTensorCastOp>(typeConverter, context);
+  patterns.insert<BufferizeExtractElementOp, BufferizeTensorCastOp,
+                  BufferizeTensorFromElementsOp>(typeConverter, context);
 }
 
 namespace {
@@ -49,9 +88,9 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
     target.addLegalDialect<StandardOpsDialect>();
 
     populateStdBufferizePatterns(context, typeConverter, patterns);
-    target.addIllegalOp<TensorCastOp>();
+    target.addIllegalOp<ExtractElementOp, TensorCastOp, TensorFromElementsOp>();
 
-    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }
 };

diff  --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 981237d78cdd..d16a5dd6d9d4 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,17 @@
 // RUN: mlir-opt %s -std-bufferize | FileCheck %s
 
+// CHECK-LABEL:   func @extract_element(
+// CHECK-SAME:                          %[[TENSOR:.*]]: tensor<?xf32>,
+// CHECK-SAME:                          %[[IDX:.*]]: index) -> f32 {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
+// CHECK:           %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
+// CHECK:           return %[[RET]] : f32
+// CHECK:         }
+func @extract_element(%tensor: tensor<?xf32>, %idx: index) -> f32 {
+  %elem = extract_element %tensor[%idx] : tensor<?xf32>
+  return %elem : f32
+}
+
 // CHECK-LABEL:   func @tensor_cast(
 // CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
@@ -10,3 +22,18 @@ func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
   %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
   return %0 : tensor<2xindex>
 }
+
+// CHECK-LABEL:   func @tensor_from_elements(
+// CHECK-SAME:                               %[[ELEM0:.*]]: index,
+// CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
+// CHECK:           %[[MEMREF:.*]] = alloc() : memref<2xindex>
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]] : memref<2xindex>
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]] : memref<2xindex>
+// CHECK:           %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<2xindex>
+// CHECK:           return %[[RET]] : tensor<2xindex>
+func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
+  %0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
+  return %0 : tensor<2xindex>
+}


        


More information about the Mlir-commits mailing list