[Mlir-commits] [mlir] 83154c5 - [mlir] Add bufferization for std.select op.

Sean Silva llvmlistbot at llvm.org
Tue Oct 27 11:47:58 PDT 2020


Author: Sean Silva
Date: 2020-10-27T11:46:33-07:00
New Revision: 83154c541806468802d687a8b3c8f1a65e92199c

URL: https://github.com/llvm/llvm-project/commit/83154c541806468802d687a8b3c8f1a65e92199c
DIFF: https://github.com/llvm/llvm-project/commit/83154c541806468802d687a8b3c8f1a65e92199c.diff

LOG: [mlir] Add bufferization for std.select op.

Differential Revision: https://reviews.llvm.org/D90204

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index a1b1f0a64992..9056fbc25e14 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -88,6 +88,24 @@ class BufferizeExtractElementOp : public OpConversionPattern<ExtractElementOp> {
 };
 } // namespace
 
+namespace {
+class BufferizeSelectOp : public OpConversionPattern<SelectOp> { // Bufferizes std.select with a scalar condition.
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.condition().getType().isa<IntegerType>()) // Non-scalar conditions are not bufferized.
+      return rewriter.notifyMatchFailure(op, "requires scalar condition");
+
+    SelectOp::Adaptor adaptor(operands); // Named access to the type-converted operands.
+    rewriter.replaceOpWithNewOp<SelectOp>( // Only a new select is emitted; no data copy.
+        op, adaptor.condition(), adaptor.true_value(), adaptor.false_value());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
 public:
@@ -128,10 +146,15 @@ class BufferizeTensorFromElementsOp
 void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns
-      .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp,
-              BufferizeTensorCastOp, BufferizeTensorFromElementsOp>(
-          typeConverter, context);
+  patterns.insert< // Every pattern is built from (typeConverter, context).
+      // clang-format off
+      BufferizeDynamicTensorFromElementsOp,
+      BufferizeExtractElementOp,
+      BufferizeSelectOp, // Matches only selects with a scalar condition.
+      BufferizeTensorCastOp,
+      BufferizeTensorFromElementsOp
+      // clang-format on
+      >(typeConverter, context);
 }
 
 namespace {
@@ -148,6 +171,13 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
     populateStdBufferizePatterns(context, typeConverter, patterns);
     target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
                         TensorCastOp, TensorFromElementsOp>();
+    // We only bufferize the case of tensor selected type and scalar condition,
+    // as that boils down to a select over memref descriptors (don't need to
+    // touch the data).
+    target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
+      return typeConverter.isLegal(op.getType()) ||
+             !op.condition().getType().isa<IntegerType>();
+    });
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();

diff  --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 61259985c286..b2cefe32120e 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -61,6 +61,20 @@ func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL:   func @select(
+// CHECK-SAME:                 %[[PRED:.*]]: i1,
+// CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
+// CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[TRUE_VAL_MEMREF:.*]] = tensor_to_memref %[[TRUE_VAL]] : memref<f32>
+// CHECK:           %[[FALSE_VAL_MEMREF:.*]] = tensor_to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK:           %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
+// CHECK:           %[[RET:.*]] = tensor_load %[[RET_MEMREF]] : memref<f32>
+// CHECK:           return %[[RET]] : tensor<f32>
+func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  %0 = select %arg0, %arg1, %arg2 : tensor<f32> // Scalar i1 condition, so this op gets bufferized.
+  return %0 : tensor<f32>
+}
+
 // CHECK-LABEL:   func @tensor_cast(
 // CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]


        


More information about the Mlir-commits mailing list