[Mlir-commits] [mlir] a36e9ee - [mlir][SCF] populateSCFStructuralTypeConversionsAndLegality WhileOp support
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 14 02:46:21 PDT 2021
Author: Butygin
Date: 2021-07-14T12:43:04+03:00
New Revision: a36e9ee09d2ea46d752b6eea30168ec1fe73d17f
URL: https://github.com/llvm/llvm-project/commit/a36e9ee09d2ea46d752b6eea30168ec1fe73d17f
DIFF: https://github.com/llvm/llvm-project/commit/a36e9ee09d2ea46d752b6eea30168ec1fe73d17f.diff
LOG: [mlir][SCF] populateSCFStructuralTypeConversionsAndLegality WhileOp support
Differential Revision: https://reviews.llvm.org/D105923
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/test/Dialect/SCF/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 107c32779e926..c34660bb7f19b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -133,10 +133,53 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
};
} // namespace
+namespace {
+class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> { // Pattern that rebuilds a WhileOp with its result types run through the type converter.
+public:
+ using OpConversionPattern<WhileOp>::OpConversionPattern;
+ // matchAndRewrite: creates a replacement WhileOp, inlines both regions of `op` into it, and runs convertRegionTypes on each; fails if any type conversion fails.
+ LogicalResult
+ matchAndRewrite(WhileOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto *converter = getTypeConverter();
+ assert(converter); // Dereferenced below (convertTypes, convertRegionTypes), so it must be non-null.
+ SmallVector<Type> newResultTypes;
+ if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
+ return failure(); // At least one result type could not be converted.
+ // Build the replacement op from the `operands` handed to this pattern and the converted result types.
+ WhileOp::Adaptor adaptor(operands);
+ auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
+ adaptor.getOperands());
+ for (auto i : {0u, 1u}) { // Region 0 and region 1 in turn (presumably the "before" and "after" regions of scf.while; confirm against the op definition).
+ auto &dstRegion = newOp.getRegion(i);
+ rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); // Inline the old region's contents at the end of the new op's region.
+ if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
+ return rewriter.notifyMatchFailure(op, "could not convert body types");
+ }
+ rewriter.replaceOp(op, newOp.getResults()); // The new op's results stand in for the old op's results.
+ return success();
+ }
+};
+} // namespace
+
+namespace {
+class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> { // Pattern that swaps a ConditionOp's operands for the ones supplied to matchAndRewrite.
+public:
+ using OpConversionPattern<ConditionOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ConditionOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); // Modify the existing op inside updateRootInPlace instead of creating a new op.
+ return success(); // Unconditional: this pattern has no failure path.
+ }
+};
+} // namespace
+
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
- patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
+ patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
+ ConvertWhileOpTypes, ConvertConditionOpTypes>(
typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
@@ -144,8 +187,10 @@ void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
- if (!isa<ForOp, IfOp>(op->getParentOp()))
+ if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
return true;
return typeConverter.isLegal(op.getOperandTypes());
});
+ target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op); });
}
diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
index 39a0fefea1a5b..c0b5569b0efff 100644
--- a/mlir/test/Dialect/SCF/bufferize.mlir
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -79,3 +79,25 @@ func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: in
}
return %ret : tensor<f32>
}
+
+// CHECK-LABEL: func @bufferize_while(
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64, %[[ARG2:.*]]: tensor<f32>
+// CHECK: %[[M:.*]] = memref.buffer_cast %[[ARG2]] : memref<f32>
+// CHECK: %[[RES1:.*]]:3 = scf.while (%{{.*}} = %[[ARG0]], %{{.*}} = %[[M]]) : (i64, memref<f32>) -> (i64, i64, memref<f32>)
+// CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}}, %{{.*}} : i64, i64, memref<f32>
+// CHECK: ^bb0(%{{.*}}: i64, %{{.*}}: i64, %{{.*}}: memref<f32>):
+// CHECK: scf.yield %{{.*}}, %{{.*}} : i64, memref<f32>
+// CHECK: %[[RES2:.*]] = memref.tensor_load %[[RES1]]#2 : memref<f32>
+// CHECK: return %[[RES1]]#1, %[[RES2]] : i64, tensor<f32>
+func @bufferize_while(%arg0: i64, %arg1: i64, %arg2: tensor<f32>) -> (i64, tensor<f32>) { // Test input: an scf.while loop that carries an i64 and a tensor<f32>.
+ %c2_i64 = constant 2 : i64
+ %0:3 = scf.while (%arg3 = %arg0, %arg4 = %arg2) : (i64, tensor<f32>) -> (i64, i64, tensor<f32>) { // Two loop operands but three results, so the operand and result type lists differ.
+ %1 = cmpi slt, %arg3, %arg1 : i64 // Condition value: %arg3 < %arg1 (signed).
+ scf.condition(%1) %arg3, %arg3, %arg4 : i64, i64, tensor<f32> // Forwards %arg3 twice plus the tensor, matching the three result types.
+ } do {
+ ^bb0(%arg5: i64, %arg6: i64, %arg7: tensor<f32>):
+ %1 = muli %arg6, %c2_i64 : i64 // Multiplies the second forwarded i64 by 2.
+ scf.yield %1, %arg7 : i64, tensor<f32> // Yields the product and the tensor, matching the loop's two operand types.
+ }
+ return %0#1, %0#2 : i64, tensor<f32> // Returns the second i64 result and the tensor result; %0#0 is unused.
+}
More information about the Mlir-commits
mailing list