[Mlir-commits] [mlir] 1ca1197 - [mlir][scf] support 1:N type conversion for scf.if/while/condition

Peiming Liu llvmlistbot at llvm.org
Wed Nov 2 09:53:41 PDT 2022

Author: Peiming Liu
Date: 2022-11-02T16:53:36Z
New Revision: 1ca119728ee1566ecc53bed350cf6c8db6bc88e5

URL: https://github.com/llvm/llvm-project/commit/1ca119728ee1566ecc53bed350cf6c8db6bc88e5
DIFF: https://github.com/llvm/llvm-project/commit/1ca119728ee1566ecc53bed350cf6c8db6bc88e5.diff

LOG: [mlir][scf] support 1:N type conversion for scf.if/while/condition

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D137100




diff  --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index a441b6c80b75b..ac3d76d569228 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -155,44 +155,57 @@ class ConvertForOpTypes
 } // namespace
 namespace {
-class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+class ConvertIfOpTypes
+    : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(IfOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // TODO: Generalize this to any type conversion, not just 1:1.
-    //
-    // We need to implement something more sophisticated here that tracks
-    // which types convert to which other types and does the appropriate
-    // materialization logic.
-    // For example, it's possible that one result type converts to 0 types and
-    // another to 2 types, so newResultTypes would at least be the right size
-    // to not crash in the llvm::zip call below, but then we would set the the
-    // wrong type on the SSA values! These edge cases are also why we cannot
-    // safely use the TypeConverter::convertTypes helper here.
-    SmallVector<Type, 6> newResultTypes;
-    for (auto type : op.getResultTypes()) {
-      Type newType = typeConverter->convertType(type);
-      if (!newType)
-        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
-      newResultTypes.push_back(newType);
-    }
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
-    // See comments in the ForOp pattern for why we clone without regions and
-    // then inline.
-    IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+  Optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 TypeRange dstTypes) const {
+    IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
+                                       adaptor.getCondition(), true);
+    newOp->setAttrs(op->getAttrs());
+    // We do not need the empty blocks created by the rewriter.
+    rewriter.eraseBlock(newOp.elseBlock());
+    rewriter.eraseBlock(newOp.thenBlock());
+    // Inline the blocks from the original operation.
     rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
     rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
-    // Update the operands and types.
-    newOp->setOperands(adaptor.getOperands());
-    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
-      std::get<0>(t).setType(std::get<1>(t));
-    rewriter.replaceOp(op, newOp.getResults());
-    return success();
+    return newOp;
+  }
+} // namespace
+namespace {
+class ConvertWhileOpTypes
+    : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+  Optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter,
+                                    TypeRange dstTypes) const {
+    // Unpack the iteration arguments into a flat list.
+    SmallVector<Value> flatArgs;
+    for (Value arg : adaptor.getOperands())
+      unpackUnrealizedConversionCast(arg, flatArgs);
+    auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+    for (auto i : {0u, 1u}) {
+      if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
+        return llvm::None;
+      auto &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+    }
+    return newOp;
 } // namespace
@@ -217,34 +230,6 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
 } // namespace
-namespace {
-class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
-  using OpConversionPattern<WhileOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(WhileOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto *converter = getTypeConverter();
-    assert(converter);
-    SmallVector<Type> newResultTypes;
-    if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
-      return failure();
-    auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
-                                          adaptor.getOperands());
-    for (auto i : {0u, 1u}) {
-      auto &dstRegion = newOp.getRegion(i);
-      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
-      if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
-        return rewriter.notifyMatchFailure(op, "could not convert body types");
-    }
-    rewriter.replaceOp(op, newOp.getResults());
-    return success();
-  }
-} // namespace
 namespace {
 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
@@ -252,8 +237,11 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
   matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.updateRootInPlace(
-        op, [&]() { op->setOperands(adaptor.getOperands()); });
+    SmallVector<Value> unpackedYield;
+    for (Value operand : adaptor.getOperands())
+      unpackUnrealizedConversionCast(operand, unpackedYield);
+    rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); });
     return success();

diff  --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
index 334d58c623936..207e46b3d45ae 100644
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -30,3 +30,68 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>,
   return %1 : tensor<1024xf32, #SparseVector>
+// CHECK-LABEL:  func @if(
+//  CHECK-SAME:          %[[DIM_SIZE:.*0]]: memref<1xindex>,
+//  CHECK-SAME:          %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+//  CHECK-SAME:          %[[MEM_SIZE:.*2]]: memref<3xindex>,
+//  CHECK-SAME:          %[[POINTER:.*3]]: memref<?xindex>,
+//  CHECK-SAME:          %[[INDICES:.*4]]: memref<?xindex>,
+//  CHECK-SAME:          %[[VALUE:.*5]]: memref<?xf32>,
+//  CHECK-SAME:          %[[DIM_SIZE_1:.*6]]: memref<1xindex>,
+//  CHECK-SAME:          %[[DIM_CURSOR_1:.*7]]: memref<1xindex>,
+//  CHECK-SAME:          %[[MEM_SIZE_1:.*8]]: memref<3xindex>,
+//  CHECK-SAME:          %[[POINTER_1:.*9]]: memref<?xindex>,
+//  CHECK-SAME:          %[[INDICES_1:.*10]]: memref<?xindex>,
+//  CHECK-SAME:          %[[VALUE_1:.*11]]: memref<?xf32>,
+//  CHECK-SAME:          %[[TMP_arg12:.*12]]: i1) ->
+//  CHECK-SAME:          (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+//       CHECK:  %[[SV:.*]]:6 = scf.if %[[TMP_arg12]] -> (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+//       CHECK:    scf.yield %[[DIM_SIZE]], %[[DIM_CURSOR]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+//       CHECK:  } else {
+//       CHECK:    scf.yield %[[DIM_SIZE_1]], %[[DIM_CURSOR_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+//       CHECK:  }
+//       CHECK:  return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @if(%t: tensor<1024xf32, #SparseVector>,
+              %f: tensor<1024xf32, #SparseVector>,
+              %c: i1) -> tensor<1024xf32, #SparseVector> {
+  %1 = scf.if %c -> tensor<1024xf32, #SparseVector> {
+    scf.yield %t : tensor<1024xf32, #SparseVector>
+  } else {
+    scf.yield %f : tensor<1024xf32, #SparseVector>
+  }
+  return %1 : tensor<1024xf32, #SparseVector>
+// CHECK-LABEL:  func @while(
+//  CHECK-SAME:              %[[DIM_SIZE:.*0]]: memref<1xindex>,
+//  CHECK-SAME:              %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+//  CHECK-SAME:              %[[MEM_SIZE:.*2]]: memref<3xindex>,
+//  CHECK-SAME:              %[[POINTER:.*3]]: memref<?xindex>,
+//  CHECK-SAME:              %[[INDICES:.*4]]: memref<?xindex>,
+//  CHECK-SAME:              %[[VALUE:.*5]]: memref<?xf32>,
+//  CHECK-SAME:              %[[TMP_arg6:.*6]]: i1) ->
+//  CHECK-SAME:              (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+//       CHECK:  %[[SV:.*]]:6 = scf.while (
+//  CHECK-SAME:              %[[TMP_arg7:.*]] = %[[DIM_SIZE]],
+//  CHECK-SAME:              %[[TMP_arg8:.*]] = %[[DIM_CURSOR]],
+//  CHECK-SAME:              %[[TMP_arg9:.*]] = %[[MEM_SIZE]],
+//  CHECK-SAME:              %[[TMP_arg10:.*]] = %[[POINTER]],
+//  CHECK-SAME:              %[[TMP_arg11:.*]] = %[[INDICES]],
+//  CHECK-SAME:              %[[TMP_arg12:.*]] = %[[VALUE]]) 
+//       CHECK:    scf.condition(%[[TMP_arg6]]) %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+//       CHECK:  } do {
+//       CHECK:  ^bb0(%[[TMP_arg7]]: memref<1xindex>, %[[TMP_arg8]]: memref<1xindex>, %[[TMP_arg9]]: memref<3xindex>, %[[TMP_arg10]]: memref<?xindex>, %[[TMP_arg11]]: memref<?xindex>, %[[TMP_arg12]]: memref<?xf32>):
+//       CHECK:    scf.yield %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+//       CHECK:  }
+//       CHECK:  return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
+  %0 = scf.while (%arg4 = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
+    scf.condition(%c) %arg4 : tensor<1024xf32, #SparseVector>
+  } do {
+  ^bb0(%arg7: tensor<1024xf32, #SparseVector>):
+    scf.yield %arg7 : tensor<1024xf32, #SparseVector>
+  }
+  return %0: tensor<1024xf32, #SparseVector>


More information about the Mlir-commits mailing list