[Mlir-commits] [mlir] 83154c5 - [mlir] Add bufferization for std.select op.
Sean Silva
llvmlistbot at llvm.org
Tue Oct 27 11:47:58 PDT 2020
Author: Sean Silva
Date: 2020-10-27T11:46:33-07:00
New Revision: 83154c541806468802d687a8b3c8f1a65e92199c
URL: https://github.com/llvm/llvm-project/commit/83154c541806468802d687a8b3c8f1a65e92199c
DIFF: https://github.com/llvm/llvm-project/commit/83154c541806468802d687a8b3c8f1a65e92199c.diff
LOG: [mlir] Add bufferization for std.select op.
Differential Revision: https://reviews.llvm.org/D90204
Added:
Modified:
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index a1b1f0a64992..9056fbc25e14 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -88,6 +88,24 @@ class BufferizeExtractElementOp : public OpConversionPattern<ExtractElementOp> {
};
} // namespace
+namespace {
+/// Conversion pattern for `SelectOp` whose condition is integer-typed (a
+/// single scalar predicate). Such a select is rebuilt as a new `SelectOp`
+/// over the type-converted operands, which amounts to choosing between the
+/// converted buffers without touching their data. Any other condition type
+/// is reported as a match failure.
+class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Non-integer (e.g. shaped) conditions are left for other handling.
+    bool conditionIsScalar = op.condition().getType().isa<IntegerType>();
+    if (conditionIsScalar) {
+      // `operands` holds the already-converted values; view them by role.
+      SelectOp::Adaptor converted(operands);
+      rewriter.replaceOpWithNewOp<SelectOp>(op, converted.condition(),
+                                            converted.true_value(),
+                                            converted.false_value());
+      return success();
+    }
+    return rewriter.notifyMatchFailure(op, "requires scalar condition");
+  }
+};
+} // namespace
+
namespace {
class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
public:
@@ -128,10 +146,15 @@ class BufferizeTensorFromElementsOp
void mlir::populateStdBufferizePatterns(MLIRContext *context,
BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns
- .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp,
- BufferizeTensorCastOp, BufferizeTensorFromElementsOp>(
- typeConverter, context);
+  // Register every std-dialect bufferization pattern; each one is constructed
+  // with the caller-provided `typeConverter` and `context`.
+  patterns.insert<BufferizeDynamicTensorFromElementsOp,
+                  BufferizeExtractElementOp, BufferizeSelectOp,
+                  BufferizeTensorCastOp, BufferizeTensorFromElementsOp>(
+      typeConverter, context);
}
namespace {
@@ -148,6 +171,13 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
populateStdBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
TensorCastOp, TensorFromElementsOp>();
+ // We only bufferize the case of tensor selected type and scalar condition,
+ // as that boils down to a select over memref descriptors (don't need to
+ // touch the data).
+ target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
+ return typeConverter.isLegal(op.getType()) ||
+ !op.condition().getType().isa<IntegerType>();
+ });
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 61259985c286..b2cefe32120e 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -61,6 +61,20 @@ func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
return %0 : f32
}
+// A select over tensors with a single i1 condition becomes a select over the
+// converted memref operands, bracketed by tensor_to_memref / tensor_load.
+// CHECK-LABEL: func @select(
+// CHECK-SAME: %[[PRED:.*]]: i1,
+// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
+// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[TRUE_VAL_MEMREF:.*]] = tensor_to_memref %[[TRUE_VAL]] : memref<f32>
+// CHECK: %[[FALSE_VAL_MEMREF:.*]] = tensor_to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK: %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
+// CHECK: %[[RET:.*]] = tensor_load %[[RET_MEMREF]] : memref<f32>
+// CHECK: return %[[RET]] : tensor<f32>
+func @select(%pred: i1, %on_true: tensor<f32>, %on_false: tensor<f32>) -> tensor<f32> {
+  %chosen = select %pred, %on_true, %on_false : tensor<f32>
+  return %chosen : tensor<f32>
+}
+
// CHECK-LABEL: func @tensor_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
More information about the Mlir-commits mailing list