[Mlir-commits] [mlir] 272cf8f - [mlir] Implement one-to-n structural conversion for ForOp
Ingo Müller
llvmlistbot at llvm.org
Fri Jul 7 08:50:59 PDT 2023
Author: Spenser Bauman
Date: 2023-07-07T15:50:54Z
New Revision: 272cf8f7b2a6e79884f7b680fff3cd15a47eebb2
URL: https://github.com/llvm/llvm-project/commit/272cf8f7b2a6e79884f7b680fff3cd15a47eebb2
DIFF: https://github.com/llvm/llvm-project/commit/272cf8f7b2a6e79884f7b680fff3cd15a47eebb2.diff
LOG: [mlir] Implement one-to-n structural conversion for ForOp
Add the missing one-to-n structural type conversion pattern for the
scf.for operation.
Reviewed By: ingomueller-net
Differential Revision: https://reviews.llvm.org/D154299
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
index d47543a8ad6bee..8c2c544a89f7de 100644
--- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -138,6 +138,63 @@ class ConvertTypesInSCFConditionOp
}
};
+class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> {
+public:
+  using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern;
+
+  /// Rewrites `forOp` into a new scf.for whose iter_args, body block
+  /// arguments, and results use the 1:N-converted types. The lower bound,
+  /// upper bound, and step must each convert to exactly one value; otherwise
+  /// the pattern does not apply.
+  ///
+  /// Every check that can fail is performed before any IR is created or
+  /// modified, so a failed match leaves the IR untouched.
+  LogicalResult
+  matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion() &&
+        !resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // If the lower-bound, upper-bound, or step were expanded, abort the
+    // conversion. This conversion does not know what to do in such cases.
+    ValueRange lbs = adaptor.getLowerBound();
+    ValueRange ubs = adaptor.getUpperBound();
+    ValueRange steps = adaptor.getStep();
+    if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1)
+      return rewriter.notifyMatchFailure(
+          forOp, "index operands converted to multiple values");
+
+    Block *block = &forOp.getRegion().front();
+
+    // Convert the signature of the body block. This step can fail, so it must
+    // happen before any new op is created: a pattern must not report failure
+    // after it has already modified the IR.
+    OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+                                                    bodyTypeMapping)))
+      return rewriter.notifyMatchFailure(
+          forOp, "failed to convert body block signature");
+
+    Location loc = forOp.getLoc();
+
+    // Construct the new for-op with an empty body. The first three flat
+    // operands are the (single-valued) lower bound, upper bound, and step;
+    // the remaining ones are the converted iter_args.
+    ValueRange newInits = adaptor.getFlatOperands().drop_front(3);
+    auto newOp =
+        rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits);
+    newOp->setAttrs(forOp->getAttrs());
+
+    // We do not need the empty blocks created by rewriter.
+    rewriter.eraseBlock(newOp.getBody());
+
+    // Perform signature conversion on the body block.
+    rewriter.applySignatureConversion(block, bodyTypeMapping);
+
+    // Splice the old body region into the new for-op.
+    Region &dstRegion = newOp.getBodyRegion();
+    rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end());
+
+    rewriter.replaceOp(forOp, newOp.getResults(), resultMapping);
+
+    return success();
+  }
+};
+
namespace mlir {
namespace scf {
@@ -146,6 +203,7 @@ void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
patterns.add<
// clang-format off
ConvertTypesInSCFConditionOp,
+ ConvertTypesInSCFForOp,
ConvertTypesInSCFIfOp,
ConvertTypesInSCFWhileOp,
ConvertTypesInSCFYieldOp
diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
index 263711674a6ec8..535ab68e8d893c 100644
--- a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
@@ -116,3 +116,68 @@ func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<
}
return %0 : tuple<tuple<>, i1>
}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield.
+
+// CHECK-LABEL: func.func @for_operands_results(
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-NEXT: %[[OUT:.+]]:2 = scf.for %arg2 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER0:.+]] = %[[ARG0]], %[[ITER1:.+]] = %[[ARG1]]) -> (i1, i2) {
+// CHECK-NEXT: scf.yield %[[ITER0]], %[[ITER1]] : i1, i2
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[OUT]]#0, %[[OUT]]#1 : i1, i2
+
+func.func @for_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
+ // Index-typed loop bounds and step; the expected output keeps them as
+ // single values.
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+
+ // The loop carries the whole nested tuple as one iter_arg and yields it
+ // unchanged. Per the expected output above, it should become two iter_args
+ // (i1, i2): the empty tuple contributes no values.
+ %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1, tuple<i2>> {
+ scf.yield %acc : tuple<tuple<>, i1, tuple<i2>>
+ }
+
+ return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield, including uses of the tuple by other ops inside the loop body and on the loop result.
+
+// CHECK-LABEL: func.func @for_tuple_ops(
+// CHECK-SAME: %[[ARG0:.+]]: i1) -> i1 {
+// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-NEXT: %[[FOR:.+]] = scf.for %arg1 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER:.+]] = %[[ARG0]]) -> (i1) {
+// CHECK-NEXT: %[[V1:.+]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT: %[[V2:.+]] = "test.make_tuple"(%[[V1]], %[[ITER]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT: %[[V3:.+]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT: %[[V4:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT: %[[V5:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT: scf.yield %[[V5]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[V6:.+]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT: %[[V7:.+]] = "test.make_tuple"(%[[V6]], %[[FOR]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT: %[[V8:.+]] = "test.op"(%[[V7]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT: %[[V9:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT: %[[V10:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT: return %[[V10]] : i1
+
+func.func @for_tuple_ops(%arg0: tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+
+ // "test.op" keeps its tuple-typed operand and result in the expected output.
+ // Inside the converted loop body, the tuple is therefore rebuilt with
+ // test.make_tuple before the op, and the result is split with
+ // test.get_tuple_element before scf.yield.
+ %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1> {
+ %1 = "test.op"(%acc) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+ scf.yield %1 : tuple<tuple<>, i1>
+ }
+
+ // Use of the loop result after the loop: the expected output shows the same
+ // make_tuple / get_tuple_element sequence around this op.
+ %1 = "test.op"(%0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+ return %1 : tuple<tuple<>, i1>
+}
More information about the Mlir-commits
mailing list