[Mlir-commits] [mlir] 272cf8f - [mlir] Implement one-to-n structural conversion for ForOp

Ingo Müller llvmlistbot at llvm.org
Fri Jul 7 08:50:59 PDT 2023


Author: Spenser Bauman
Date: 2023-07-07T15:50:54Z
New Revision: 272cf8f7b2a6e79884f7b680fff3cd15a47eebb2

URL: https://github.com/llvm/llvm-project/commit/272cf8f7b2a6e79884f7b680fff3cd15a47eebb2
DIFF: https://github.com/llvm/llvm-project/commit/272cf8f7b2a6e79884f7b680fff3cd15a47eebb2.diff

LOG: [mlir] Implement one-to-n structural conversion for ForOp

Add the missing one-to-n structural type conversion pattern for the
scf.for operation.

Reviewed By: ingomueller-net

Differential Revision: https://reviews.llvm.org/D154299

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
    mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
index d47543a8ad6bee..8c2c544a89f7de 100644
--- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -138,6 +138,63 @@ class ConvertTypesInSCFConditionOp
   }
 };
 
+/// Converts the types of the iter_args, body block arguments, and results of
+/// an `scf.for` op 1:N. The bounds and step must each map to a single value.
+class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> {
+public:
+  using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion() &&
+        !resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // If the lower-bound, upper-bound, or step were expanded, abort the
+    // conversion. This conversion does not know what to do in such cases.
+    ValueRange lbs = adaptor.getLowerBound();
+    ValueRange ubs = adaptor.getUpperBound();
+    ValueRange steps = adaptor.getStep();
+    if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1)
+      return rewriter.notifyMatchFailure(
+          forOp, "index operands converted to multiple values");
+
+    // Compute the body signature conversion *before* creating any IR: a
+    // pattern must not modify the IR and then return failure.
+    Block *block = forOp.getBody();
+    OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+                                                   bodyTypeMapping)))
+      return rewriter.notifyMatchFailure(forOp, "could not convert body types");
+
+    // Construct the new for-op. The flat operands start with the (unexpanded)
+    // lower bound, upper bound, and step; the rest are the converted inits.
+    ValueRange newInits = adaptor.getFlatOperands().drop_front(3);
+    auto newOp = rewriter.create<ForOp>(forOp.getLoc(), lbs[0], ubs[0],
+                                        steps[0], newInits);
+    newOp->setAttrs(forOp->getAttrs());
+
+    // Drop the body block created by the builder; the old body is spliced in.
+    rewriter.eraseBlock(newOp.getBody());
+
+    // Perform signature conversion on the body block.
+    rewriter.applySignatureConversion(block, bodyTypeMapping);
+
+    // Splice the old body region into the new for-op.
+    Region &dstRegion = newOp.getBodyRegion();
+    rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end());
+
+    rewriter.replaceOp(forOp, newOp.getResults(), resultMapping);
+
+    return success();
+  }
+};
+
 namespace mlir {
 namespace scf {
 
@@ -146,6 +203,7 @@ void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
   patterns.add<
       // clang-format off
       ConvertTypesInSCFConditionOp,
+      ConvertTypesInSCFForOp,
       ConvertTypesInSCFIfOp,
       ConvertTypesInSCFWhileOp,
       ConvertTypesInSCFYieldOp

diff  --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
index 263711674a6ec8..535ab68e8d893c 100644
--- a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
@@ -116,3 +116,68 @@ func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<
   }
   return %0 : tuple<tuple<>, i1>
 }
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield.
+
+// CHECK-LABEL: func.func @for_operands_results(
+// CHECK-SAME:                                    %[[ARG0:.*]]: i1,
+// CHECK-SAME:                                    %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-NEXT:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT:    %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT:    %[[C10:.+]] = arith.constant 10 : index
+// CHECK-NEXT:    %[[OUT:.+]]:2 = scf.for %arg2 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER0:.+]] = %[[ARG0]], %[[ITER1:.+]] = %[[ARG1]]) -> (i1, i2) {
+// CHECK-NEXT:     scf.yield %[[ITER0]], %[[ITER1]] : i1, i2
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[OUT]]#0, %[[OUT]]#1 : i1, i2
+
+func.func @for_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // The tuple-typed iter_arg is expected to become i1 and i2 iter_args (tuple<> contributes none).
+  %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1, tuple<i2>> {
+    scf.yield %acc : tuple<tuple<>, i1, tuple<i2>>
+  }
+  // The two converted loop results are expected to be returned directly.
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield, with materializations around unconverted ops.
+
+// CHECK-LABEL: func.func @for_tuple_ops(
+// CHECK-SAME:                             %[[ARG0:.+]]: i1) -> i1 {
+// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-NEXT: %[[FOR:.+]] = scf.for %arg1 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER:.+]] = %[[ARG0]]) -> (i1) {
+// CHECK-NEXT:   %[[V1:.+]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:   %[[V2:.+]] = "test.make_tuple"(%[[V1]], %[[ITER]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:   %[[V3:.+]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:   %[[V4:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:   %[[V5:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:   scf.yield %[[V5]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[V6:.+]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT: %[[V7:.+]] = "test.make_tuple"(%[[V6]], %[[FOR]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT: %[[V8:.+]] = "test.op"(%[[V7]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT: %[[V9:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT: %[[V10:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT: return %[[V10]] : i1
+
+func.func @for_tuple_ops(%arg0: tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // "test.op" is not converted, so the loop body is expected to rebuild the tuple around it.
+  %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1> {
+    %1 = "test.op"(%acc) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+  // The same tuple materializations are expected around this use of the loop result.
+  %1 = "test.op"(%0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+  return %1 : tuple<tuple<>, i1>
+}


        


More information about the Mlir-commits mailing list