[Mlir-commits] [mlir] a36e9ee - [mlir][SCF] populateSCFStructuralTypeConversionsAndLegality WhileOp support

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 14 02:46:21 PDT 2021


Author: Butygin
Date: 2021-07-14T12:43:04+03:00
New Revision: a36e9ee09d2ea46d752b6eea30168ec1fe73d17f

URL: https://github.com/llvm/llvm-project/commit/a36e9ee09d2ea46d752b6eea30168ec1fe73d17f
DIFF: https://github.com/llvm/llvm-project/commit/a36e9ee09d2ea46d752b6eea30168ec1fe73d17f.diff

LOG: [mlir][SCF] populateSCFStructuralTypeConversionsAndLegality WhileOp support

Differential Revision: https://reviews.llvm.org/D105923

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
    mlir/test/Dialect/SCF/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 107c32779e926..c34660bb7f19b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -133,10 +133,53 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
 };
 } // namespace
 
+namespace {
+class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
+public:
+  using OpConversionPattern<WhileOp>::OpConversionPattern;
+  // Recreates scf.while with type-converted results and block arguments.
+  LogicalResult
+  matchAndRewrite(WhileOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = getTypeConverter();
+    assert(converter);
+    SmallVector<Type> newResultTypes;
+    if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
+      return failure();
+    // New op takes the remapped operands; both regions are moved into it.
+    WhileOp::Adaptor adaptor(operands);
+    auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
+                                          adaptor.getOperands());
+    for (auto i : {0u, 1u}) {
+      auto &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+      if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
+        return rewriter.notifyMatchFailure(op, "could not convert body types");
+    }
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+} // namespace
+// Rewrites scf.condition in place so it forwards the remapped operands.
+namespace {
+class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
+public:
+  using OpConversionPattern<ConditionOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ConditionOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
+    return success();
+  }
+};
+} // namespace
+
 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
-  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
+  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
+               ConvertWhileOpTypes, ConvertConditionOpTypes>(
       typeConverter, patterns.getContext());
   target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
     return typeConverter.isLegal(op->getResultTypes());
@@ -144,8 +187,10 @@ void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
     // We only have conversions for a subset of ops that use scf.yield
     // terminators.
-    if (!isa<ForOp, IfOp>(op->getParentOp()))
+    if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
       return true;
     return typeConverter.isLegal(op.getOperandTypes());
   });
+  target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op); });
 }

diff  --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
index 39a0fefea1a5b..c0b5569b0efff 100644
--- a/mlir/test/Dialect/SCF/bufferize.mlir
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -79,3 +79,25 @@ func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: in
   }
   return %ret : tensor<f32>
 }
+// The tensor value carried through scf.while must be bufferized to memref.
+// CHECK-LABEL:   func @bufferize_while(
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64, %[[ARG2:.*]]: tensor<f32>
+// CHECK: %[[M:.*]] = memref.buffer_cast %[[ARG2]] : memref<f32>
+// CHECK: %[[RES1:.*]]:3 = scf.while (%{{.*}} = %[[ARG0]], %{{.*}} = %[[M]]) : (i64, memref<f32>) -> (i64, i64, memref<f32>)
+// CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}}, %{{.*}} : i64, i64, memref<f32>
+// CHECK: ^bb0(%{{.*}}: i64, %{{.*}}: i64, %{{.*}}: memref<f32>):
+// CHECK: scf.yield %{{.*}}, %{{.*}} : i64, memref<f32>
+// CHECK:  %[[RES2:.*]] = memref.tensor_load %[[RES1]]#2 : memref<f32>
+// CHECK:  return %[[RES1]]#1, %[[RES2]] : i64, tensor<f32>
+func @bufferize_while(%arg0: i64, %arg1: i64, %arg2: tensor<f32>) -> (i64, tensor<f32>) {
+  %c2_i64 = constant 2 : i64
+  %0:3 = scf.while (%arg3 = %arg0, %arg4 = %arg2) : (i64, tensor<f32>) -> (i64, i64, tensor<f32>) {
+    %1 = cmpi slt, %arg3, %arg1 : i64
+    scf.condition(%1) %arg3, %arg3, %arg4 : i64, i64, tensor<f32>
+  } do {
+  ^bb0(%arg5: i64, %arg6: i64, %arg7: tensor<f32>):
+    %1 = muli %arg6, %c2_i64 : i64
+    scf.yield %1, %arg7 : i64, tensor<f32>
+  }
+  return %0#1, %0#2 : i64, tensor<f32>
+}


        


More information about the Mlir-commits mailing list