[Mlir-commits] [mlir] 586cebe - [mlir][scf] Implement structural conversion for 1:N type conversions.
Ingo Müller
llvmlistbot at llvm.org
Tue Mar 28 01:33:05 PDT 2023
Author: Ingo Müller
Date: 2023-03-28T08:33:00Z
New Revision: 586cebef271f627e80c3535e7cd201915f88b349
URL: https://github.com/llvm/llvm-project/commit/586cebef271f627e80c3535e7cd201915f88b349
DIFF: https://github.com/llvm/llvm-project/commit/586cebef271f627e80c3535e7cd201915f88b349.diff
LOG: [mlir][scf] Implement structural conversion for 1:N type conversions.
This patch implements patterns for the newly introduced 1:N type
conversion utils for several ops of the SCF dialect. It also adds an
option to the existing test pass as well as test cases that apply the
patterns through the test pass.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D146959
Added:
mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index bfeab9da7632f..fbe73a260b409 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -120,6 +120,12 @@ void populateSCFStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
+/// Populates the provided pattern set with patterns that do 1:N type
+/// conversions on (some) SCF ops, namely `scf.if`, `scf.while`,
+/// `scf.condition`, and `scf.yield`: their result types, region block
+/// arguments, and terminator operands are rewritten according to
+/// `typeConverter`. This is intended to be used with
+/// applyPartialOneToNConversion.
+void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
/// Options to dictate how loops should be pipelined.
struct PipeliningOption {
/// Lambda returning all the operation in the forOp, with their stage, in the
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 3dd90997b6ea3..20abf2b583bbf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
LoopPipelining.cpp
LoopRangeFolding.cpp
LoopSpecialization.cpp
+ OneToNTypeConversion.cpp
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
new file mode 100644
index 0000000000000..74207e6fbb647
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -0,0 +1,161 @@
+//===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The patterns in this file are heavily inspired by (and copied from)
+// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
+// type conversions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+
+/// Converts the result types of an `scf.if` op according to the 1:N result
+/// type mapping. A new `scf.if` with the converted result types is created,
+/// and the then/else regions of the original op are moved into it unchanged.
+/// The `scf.yield` terminators inside those regions are handled by a separate
+/// pattern in this file.
+class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
+public:
+  using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(IfOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping & /*operandMapping*/,
+                  const OneToNTypeMapping &resultMapping,
+                  const ValueRange /*convertedOperands*/) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Create new IfOp. An else region is requested so that the original else
+    // region can be moved into it below.
+    TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
+    auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
+                                       op.getCondition(),
+                                       /*withElseRegion=*/true);
+    newOp->setAttrs(op->getAttrs());
+
+    // We do not need the empty blocks created by rewriter.
+    rewriter.eraseBlock(newOp.elseBlock());
+    rewriter.eraseBlock(newOp.thenBlock());
+
+    // Inline the blocks of the original operation.
+    rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
+                                newOp.getThenRegion().end());
+    rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
+                                newOp.getElseRegion().end());
+
+    // Replace the original results with the (1:N) converted results.
+    rewriter.replaceOp(op, SmallVector<Value>(newOp->getResults()),
+                       resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
+/// Converts the operand and result types of an `scf.while` op according to
+/// the 1:N type mappings. The entry-block signature of region 0 is converted
+/// with the operand mapping and that of region 1 with the result mapping;
+/// both regions are then moved into a newly created `scf.while`. The
+/// `scf.condition`/`scf.yield` terminators are handled by separate patterns
+/// in this file.
+class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
+public:
+  using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const ValueRange convertedOperands) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!operandMapping.hasNonIdentityConversion() &&
+        !resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Create new WhileOp with the converted operands and result types. The
+    // regions of the original op are moved into it below.
+    TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
+
+    auto newOp =
+        rewriter.create<WhileOp>(loc, convertedResultTypes, convertedOperands);
+    newOp->setAttrs(op->getAttrs());
+
+    // Update block signatures: region i is converted with blockMappings[i].
+    // The mappings are copied into a local array so they can be indexed by
+    // region number.
+    std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping,
+                                                      resultMapping};
+    for (unsigned int i : {0u, 1u}) {
+      Region *region = &op.getRegion(i);
+      Block *block = &region->front();
+
+      rewriter.applySignatureConversion(block, blockMappings[i]);
+
+      // Move updated region to new WhileOp.
+      Region &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+    }
+
+    // Replace the original results with the (1:N) converted results.
+    rewriter.replaceOp(op, SmallVector<Value>(newOp->getResults()),
+                       resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
+/// Converts the operands of an `scf.yield` op according to the 1:N operand
+/// type mapping by updating the op in place.
+///
+/// Only yields whose parent op is also converted by a pattern in this file
+/// (`scf.if`, `scf.while`) are rewritten. Rewriting the yield of any other
+/// parent (e.g. `scf.for`, `scf.execute_region`) would make its operand types
+/// disagree with the parent's unconverted result/argument types and produce
+/// invalid IR. Such yields are left untouched by this pattern. Extend the
+/// list of parents when patterns for further SCF ops are added.
+class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
+public:
+  using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping & /*resultMapping*/,
+                  const ValueRange convertedOperands) const override {
+    // Only convert terminators of ops whose signatures are converted by the
+    // patterns in this file.
+    if (!isa<IfOp, WhileOp>(op->getParentOp()))
+      return failure();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
+/// Converts the operands of an `scf.condition` op according to the 1:N
+/// operand type mapping by updating the op in place. The enclosing
+/// `scf.while` is converted by a separate pattern in this file.
+class ConvertTypesInSCFConditionOp
+    : public OneToNOpConversionPattern<ConditionOp> {
+public:
+  using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping & /*resultMapping*/,
+                  const ValueRange convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands (the condition plus all forwarded values).
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the 1:N structural type conversion patterns for scf.condition,
+/// scf.if, scf.while, and scf.yield. Every pattern is constructed with the
+/// given type converter and the context of `patterns`.
+void mlir::scf::populateSCFStructuralOneToNTypeConversions(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ConvertTypesInSCFConditionOp, ConvertTypesInSCFIfOp,
+               ConvertTypesInSCFWhileOp, ConvertTypesInSCFYieldOp>(
+      typeConverter, ctx);
+}
diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
new file mode 100644
index 0000000000000..dd2013c9a7368
--- /dev/null
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s -split-input-file \
+// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \
+// RUN: | FileCheck %s
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield.
+//
+// The argument of type tuple<tuple<>, i1, tuple<i2>> is decomposed into
+// (i1, i2); the empty nested tuple contributes no values. The results of
+// scf.if and the operands of both scf.yield ops follow this decomposition.

// CHECK-LABEL: func.func @if_result(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i2,
// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) {
// CHECK-NEXT: %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) {
// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+ %0 = scf.if %arg1 -> (tuple<tuple<>, i1, tuple<i2>>) {
+ scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+ } else {
+ scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+ }
+ return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield and unconverted ops inside have proper materializations.
+//
+// test.op and test.source are not converted by any pattern. The decomposed
+// i1 argument is therefore re-packed with test.make_tuple before test.op, and
+// the tuple results of both ops are unpacked with test.get_tuple_element
+// before being yielded.

// CHECK-LABEL: func.func @if_tuple_ops(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 {
// CHECK-NEXT: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) {
// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V5]] : i1
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V8]] : i1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V2]] : i1
+func.func @if_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+ %0 = scf.if %arg1 -> (tuple<tuple<>, i1>) {
+ %1 = "test.op"(%arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+ scf.yield %1 : tuple<tuple<>, i1>
+ } else {
+ %1 = "test.source"() : () -> tuple<tuple<>, i1>
+ scf.yield %1 : tuple<tuple<>, i1>
+ }
+ return %0 : tuple<tuple<>, i1>
+}
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and scf.yield.
+//
+// The loop-carried tuple<tuple<>, i1, tuple<i2>> value is decomposed into
+// (i1, i2) in the init operands, the block arguments of both regions, the
+// terminator operands, and the results of scf.while.

// CHECK-LABEL: func.func @while_operands_results(
// CHECK-SAME:    %[[ARG0:.*]]: i1,
// CHECK-SAME:    %[[ARG1:.*]]: i2,
// CHECK-SAME:    %[[ARG2:.*]]: i1) -> (i1, i2) {
// CHECK-NEXT:    %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) {
// CHECK-NEXT:      scf.condition(%[[ARG2]]) %[[ARG3]], %[[ARG4]] : i1, i2
// CHECK-NEXT:    } do {
// CHECK-NEXT:    ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2):
// CHECK-NEXT:      scf.yield %[[ARG5]], %[[ARG6]] : i1, i2
// CHECK-NEXT:    }
// CHECK-NEXT:    return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+  %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
+    scf.condition(%arg1) %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  } do {
+  ^bb0(%arg2: tuple<tuple<>, i1, tuple<i2>>):
+    scf.yield %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  }
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and scf.yield, and unconverted ops inside have proper
+// materializations.
+//
+// test.op consumes the decomposed loop-carried i1 value, so it is re-packed
+// with test.make_tuple first; the tuple results of test.op and test.source
+// are unpacked with test.get_tuple_element before reaching scf.condition and
+// scf.yield, respectively.

// CHECK-LABEL: func.func @while_tuple_ops(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 {
// CHECK-NEXT: %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 {
// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"() : () -> tuple<>
// CHECK-NEXT: %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.condition(%[[ARG1]]) %[[V5]] : i1
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: i1):
// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V8]] : i1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V0]] : i1
+func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+ %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
+ %1 = "test.op"(%arg2) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+ scf.condition(%arg1) %1 : tuple<tuple<>, i1>
+ } do {
+ ^bb0(%arg2: tuple<tuple<>, i1>):
+ %1 = "test.source"() : () -> tuple<tuple<>, i1>
+ scf.yield %1 : tuple<tuple<>, i1>
+ }
+ return %0 : tuple<tuple<>, i1>
+}
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
index 418978688c90d..b72302202f72b 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
@@ -7,6 +7,9 @@ add_mlir_library(MLIRTestOneToNTypeConversionPass
MLIRFuncDialect
MLIRFuncTransforms
MLIRIR
+ MLIRPass
+ MLIRSCFDialect
+ MLIRSCFTransforms
MLIRTestDialect
MLIRTransformUtils
)
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 220bcb58bf788..c60c323a58d4f 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
@@ -43,6 +44,10 @@ struct TestOneToNTypeConversionPass
llvm::cl::desc("Enable conversion on func ops"),
llvm::cl::init(false)};
+ Option<bool> convertSCFOps{*this, "convert-scf-ops",
+ llvm::cl::desc("Enable conversion on scf ops"),
+ llvm::cl::init(false)};
+
Option<bool> convertTupleOps{*this, "convert-tuple-ops",
llvm::cl::desc("Enable conversion on tuple ops"),
llvm::cl::init(false)};
@@ -237,6 +242,8 @@ void TestOneToNTypeConversionPass::runOnOperation() {
populateDecomposeTuplesTestPatterns(typeConverter, patterns);
if (convertFuncOps)
populateFuncTypeConversionPatterns(typeConverter, patterns);
+ if (convertSCFOps)
+ scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns);
// Run conversion.
if (failed(applyPartialOneToNConversion(module, typeConverter,
More information about the Mlir-commits
mailing list