[Mlir-commits] [mlir] 586cebe - [mlir][scf] Implement structural conversion for 1:N type conversions.

Ingo Müller llvmlistbot at llvm.org
Tue Mar 28 01:33:05 PDT 2023


Author: Ingo Müller
Date: 2023-03-28T08:33:00Z
New Revision: 586cebef271f627e80c3535e7cd201915f88b349

URL: https://github.com/llvm/llvm-project/commit/586cebef271f627e80c3535e7cd201915f88b349
DIFF: https://github.com/llvm/llvm-project/commit/586cebef271f627e80c3535e7cd201915f88b349.diff

LOG: [mlir][scf] Implement structural conversion for 1:N type conversions.

This patch implements patterns for the newly introduced 1:N type
conversion utils for several ops of the SCF dialect. It also adds an
option to the existing test pass, as well as test cases that apply the
patterns through the test pass.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D146959

Added: 
    mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
    mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
    mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index bfeab9da7632f..fbe73a260b409 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -120,6 +120,12 @@ void populateSCFStructuralTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target);
 
+/// Populates the provided pattern set with patterns that do 1:N type
+/// conversions on (some) SCF ops, namely `scf.if`, `scf.while`,
+/// `scf.condition`, and `scf.yield`. The patterns rewrite these ops so that
+/// their operands, results, and (for `scf.while`) block arguments use the
+/// types produced by `typeConverter`. This is intended to be used with
+/// applyPartialOneToNConversion.
+void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
+                                                RewritePatternSet &patterns);
+
 /// Options to dictate how loops should be pipelined.
 struct PipeliningOption {
   /// Lambda returning all the operation in the forOp, with their stage, in the

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 3dd90997b6ea3..20abf2b583bbf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
+  OneToNTypeConversion.cpp
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp

diff  --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
new file mode 100644
index 0000000000000..74207e6fbb647
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -0,0 +1,161 @@
+//===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The patterns in this file are heavily inspired by (and copied from)
+// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
+// type conversions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+
+/// Converts the result types of an `scf.if` op according to the 1:N type
+/// mapping of its results. The then/else regions of the original op are moved
+/// unchanged into a new `scf.if` with the converted result types; the
+/// `scf.yield` terminators inside them are converted by the separate
+/// `ConvertTypesInSCFYieldOp` pattern.
+class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
+public:
+  using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(IfOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping & /*operandMapping*/,
+                  const OneToNTypeMapping &resultMapping,
+                  const ValueRange /*convertedOperands*/) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Create new IfOp. The original condition operand is reused as-is.
+    // NOTE(review): this presumes the type converter maps `i1` to itself.
+    TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
+    auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
+                                       op.getCondition(),
+                                       /*withElseRegion=*/true);
+    newOp->setAttrs(op->getAttrs());
+
+    // We do not need the empty blocks created by the builder.
+    rewriter.eraseBlock(newOp.elseBlock());
+    rewriter.eraseBlock(newOp.thenBlock());
+
+    // Move the blocks of the original op's regions into the new op.
+    rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
+                                newOp.getThenRegion().end());
+    rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
+                                newOp.getElseRegion().end());
+
+    // replaceOp takes a ValueRange, so the results are passed directly rather
+    // than being copied into a temporary SmallVector first.
+    rewriter.replaceOp(op, newOp->getResults(), resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
+/// Converts the operand, result, and block argument types of an `scf.while`
+/// op according to the given 1:N operand and result mappings. The regions of
+/// the original op are moved into a new `scf.while`; the `scf.condition` and
+/// `scf.yield` terminators inside them are converted by the separate
+/// `ConvertTypesInSCFConditionOp` and `ConvertTypesInSCFYieldOp` patterns.
+class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
+public:
+  using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const ValueRange convertedOperands) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!operandMapping.hasNonIdentityConversion() &&
+        !resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Create new WhileOp.
+    TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
+
+    auto newOp =
+        rewriter.create<WhileOp>(loc, convertedResultTypes, convertedOperands);
+    newOp->setAttrs(op->getAttrs());
+
+    // Update block signatures: the entry block of region 0 is converted with
+    // the operand mapping and the entry block of region 1 with the result
+    // mapping. The mappings are copied into non-const locals, presumably
+    // because applySignatureConversion takes a mutable mapping.
+    // NOTE(review): confirm against its declaration before switching to
+    // references.
+    std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping,
+                                                      resultMapping};
+    for (unsigned int i : {0u, 1u}) {
+      Region &srcRegion = op.getRegion(i);
+      rewriter.applySignatureConversion(&srcRegion.front(), blockMappings[i]);
+
+      // Move updated region to new WhileOp.
+      Region &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.end());
+    }
+
+    // replaceOp takes a ValueRange, so the results are passed directly rather
+    // than being copied into a temporary SmallVector first.
+    rewriter.replaceOp(op, newOp->getResults(), resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
+/// Replaces, in place, the operands of an `scf.yield` op with their
+/// 1:N-converted counterparts.
+///
+/// NOTE(review): this fires for every `scf.yield` with a non-identity operand
+/// conversion, regardless of its parent op. A parent that is not itself
+/// rewritten by a matching pattern may end up with a yield whose operands no
+/// longer match its result types.
+class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
+public:
+  using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping & /*resultMapping*/,
+                  const ValueRange convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
+/// Replaces, in place, the operands of an `scf.condition` op (the condition
+/// followed by the forwarded values) with their 1:N-converted counterparts.
+class ConvertTypesInSCFConditionOp
+    : public OneToNOpConversionPattern<ConditionOp> {
+public:
+  using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping & /*resultMapping*/,
+                  const ValueRange convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the 1:N structural type conversion patterns for `scf.condition`,
+/// `scf.if`, `scf.while`, and `scf.yield` to `patterns`; each pattern is
+/// constructed with `typeConverter`. The definition is namespace-qualified
+/// (rather than reopening `namespace mlir { namespace scf {`) so that any
+/// mismatch with the declaration in Transforms.h is a compile error instead of
+/// a silently introduced new overload.
+void mlir::scf::populateSCFStructuralOneToNTypeConversions(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertTypesInSCFConditionOp,
+      ConvertTypesInSCFIfOp,
+      ConvertTypesInSCFWhileOp,
+      ConvertTypesInSCFYieldOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}

diff  --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
new file mode 100644
index 0000000000000..dd2013c9a7368
--- /dev/null
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \
+// RUN: | FileCheck %s
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield. The nested tuple argument and result are flattened into their
+// leaf types (the empty tuple contributes no values), so the function, the
+// scf.if results, and both scf.yield ops end up with (i1, i2) instead of the
+// tuple.
+
+// CHECK-LABEL: func.func @if_result(
+// CHECK-SAME:                       %[[ARG0:.*]]: i1,
+// CHECK-SAME:                       %[[ARG1:.*]]: i2,
+// CHECK-SAME:                       %[[ARG2:.*]]: i1) -> (i1, i2) {
+// CHECK-NEXT:    %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) {
+// CHECK-NEXT:     scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+  %0 = scf.if %arg1 -> (tuple<tuple<>, i1, tuple<i2>>) {
+    scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+  } else {
+    scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+  }
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield, and unconverted ops inside have proper materializations: the
+// converted values are re-packed with "test.make_tuple" before they reach the
+// unconverted "test.op", and tuple results of unconverted ops are unpacked
+// with "test.get_tuple_element" before being yielded.
+
+// CHECK-LABEL: func.func @if_tuple_ops(
+// CHECK-SAME:                          %[[ARG0:.*]]: i1,
+// CHECK-SAME:                          %[[ARG1:.*]]: i1) -> i1 {
+// CHECK-NEXT:    %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:    %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:    %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) {
+// CHECK-NEXT:      %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V5]] : i1
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V8]] : i1
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V2]] : i1
+func.func @if_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+  %0 = scf.if %arg1 -> (tuple<tuple<>, i1>) {
+    %1 = "test.op"(%arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  } else {
+    %1 = "test.source"() : () -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+  return %0 : tuple<tuple<>, i1>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and scf.yield. The loop-carried tuple is flattened into
+// (i1, i2) in the op's operands, its results, and the block arguments of both
+// regions.
+
+// CHECK-LABEL: func.func @while_operands_results(
+// CHECK-SAME:                                    %[[ARG0:.*]]: i1,
+// CHECK-SAME:                                    %[[ARG1:.*]]: i2,
+// CHECK-SAME:                                    %[[ARG2:.*]]: i1) -> (i1, i2) {
+// CHECK-NEXT:    %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) {
+// CHECK-NEXT:      scf.condition(%[[ARG2]]) %[[ARG3]], %[[ARG4]] : i1, i2
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2):
+// CHECK-NEXT:      scf.yield %[[ARG5]], %[[ARG6]] : i1, i2
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+  %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
+    scf.condition(%arg1) %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  } do {
+  ^bb0(%arg2: tuple<tuple<>, i1, tuple<i2>>):
+    scf.yield %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  }
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and scf.yield, and unconverted ops inside have proper
+// materializations: converted values are re-packed with "test.make_tuple"
+// before reaching the unconverted "test.op", and tuple results of unconverted
+// ops are unpacked with "test.get_tuple_element" before reaching the
+// terminators.
+
+// CHECK-LABEL: func.func @while_tuple_ops(
+// CHECK-SAME:                             %[[ARG0:.*]]: i1,
+// CHECK-SAME:                             %[[ARG1:.*]]: i1) -> i1 {
+// CHECK-NEXT:    %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 {
+// CHECK-NEXT:      %[[V1:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:      %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.condition(%[[ARG1]]) %[[V5]] : i1
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG3:.*]]: i1):
+// CHECK-NEXT:      %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V8]] : i1
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]] : i1
+func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+  %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
+    %1 = "test.op"(%arg2) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.condition(%arg1) %1 : tuple<tuple<>, i1>
+  } do {
+  ^bb0(%arg2: tuple<tuple<>, i1>):
+    %1 = "test.source"() : () -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+  return %0 : tuple<tuple<>, i1>
+}

diff  --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
index 418978688c90d..b72302202f72b 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
@@ -7,6 +7,9 @@ add_mlir_library(MLIRTestOneToNTypeConversionPass
   MLIRFuncDialect
   MLIRFuncTransforms
   MLIRIR
+  MLIRPass
+  MLIRSCFDialect
+  MLIRSCFTransforms
   MLIRTestDialect
   MLIRTransformUtils
  )

diff  --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 220bcb58bf788..c60c323a58d4f 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -8,6 +8,7 @@
 
 #include "TestDialect.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 
@@ -43,6 +44,10 @@ struct TestOneToNTypeConversionPass
                               llvm::cl::desc("Enable conversion on func ops"),
                               llvm::cl::init(false)};
 
+  Option<bool> convertSCFOps{*this, "convert-scf-ops",
+                             llvm::cl::desc("Enable conversion on scf ops"),
+                             llvm::cl::init(false)};
+
   Option<bool> convertTupleOps{*this, "convert-tuple-ops",
                                llvm::cl::desc("Enable conversion on tuple ops"),
                                llvm::cl::init(false)};
@@ -237,6 +242,8 @@ void TestOneToNTypeConversionPass::runOnOperation() {
     populateDecomposeTuplesTestPatterns(typeConverter, patterns);
   if (convertFuncOps)
     populateFuncTypeConversionPatterns(typeConverter, patterns);
+  if (convertSCFOps)
+    scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns);
 
   // Run conversion.
   if (failed(applyPartialOneToNConversion(module, typeConverter,


        


More information about the Mlir-commits mailing list