[Mlir-commits] [mlir] f0292ed - [mlir] Add structural type conversions for SCF dialect.

Sean Silva llvmlistbot at llvm.org
Wed Oct 21 11:59:01 PDT 2020


Author: Sean Silva
Date: 2020-10-21T11:58:27-07:00
New Revision: f0292ede9bbf8a24607c926b0439db20c203607a

URL: https://github.com/llvm/llvm-project/commit/f0292ede9bbf8a24607c926b0439db20c203607a
DIFF: https://github.com/llvm/llvm-project/commit/f0292ede9bbf8a24607c926b0439db20c203607a.diff

LOG: [mlir] Add structural type conversions for SCF dialect.

A "structural" type conversion is one where the underlying ops are
completely agnostic to the actual types involved and simply need to update
their types. An example of this is scf.if -- the scf.if op and the
corresponding scf.yield ops need to update their types according to the
TypeConverter, but otherwise don't care what type conversions are happening.

To test the structural type conversions, it is convenient to define a
bufferize pass for a dialect, which exercises them nicely.

Differential Revision: https://reviews.llvm.org/D89757

Added: 
    mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
    mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
    mlir/test/Dialect/SCF/bufferize.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Passes.h
    mlir/include/mlir/Dialect/SCF/Passes.td
    mlir/include/mlir/Dialect/SCF/Transforms.h
    mlir/include/mlir/Transforms/Bufferize.h
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Bufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index 7edb2444e87c0..f3dda9bec335c 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -17,6 +17,9 @@
 
 namespace mlir {
 
+/// Creates a pass that bufferizes scf.if/scf.for ops and their scf.yields.
+std::unique_ptr<Pass> createSCFBufferizePass();
+
 /// Creates a pass that specializes for loop for unrolling and
 /// vectorization.
 std::unique_ptr<Pass> createForLoopSpecializationPass();

diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index 6f3cf0e126423..611869466214a 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -11,6 +11,16 @@
 
 include "mlir/Pass/PassBase.td"
 
+def SCFBufferize : FunctionPass<"scf-bufferize"> {
+  let summary = "Bufferize the scf dialect.";
+  let description = [{
+    Converts tensor-typed results, iter_args and yielded values of `scf.if`
+    and `scf.for` ops to memrefs, inserting `tensor_to_memref` and
+    `tensor_load` ops where converted values meet unconverted ones.
+  }];
+  let constructor = "mlir::createSCFBufferizePass()";
+}
+
 def SCFForLoopSpecialization
     : FunctionPass<"for-loop-specialization"> {
   let summary = "Specialize `for` loops for vectorization";

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
index 222ad6bf5584b..3164d337b4775 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -17,7 +17,11 @@
 
 namespace mlir {
 
+class ConversionTarget;
+class MLIRContext;
+class OwningRewritePatternList;
 class Region;
+class TypeConverter;
 
 namespace scf {
 
@@ -42,6 +46,19 @@ void naivelyFuseParallelOps(Region &region);
 /// The old loop is replaced with the new one.
 void tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
 
+/// Populates patterns for SCF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is scf.if -- the scf.if op and the
+/// corresponding scf.yield ops need to update their types according to the
+/// TypeConverter, but otherwise don't care what type conversions are happening.
+void populateSCFStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target);
+
 } // namespace scf
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index ddc00893cc472..5bee53ef01ce6 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -143,6 +143,15 @@ class BufferizeTypeConverter : public TypeConverter {
   SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
 };
 
+/// Marks ops used by bufferization for type conversion materializations as
+/// "legal" in the given ConversionTarget.
+///
+/// This function should be called by all bufferization passes using
+/// BufferizeTypeConverter so that materializations work properly. One exception
+/// is bufferization passes doing "full" conversions, where it can be desirable
+/// for even the materializations to remain illegal so that they are eliminated.
+void populateBufferizeMaterializationLegality(ConversionTarget &target);
+
 /// Helper conversion pattern that encapsulates a BufferizeTypeConverter
 /// instance.
 template <typename SourceOp>

diff  --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..23cf72f6ed2a0
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -0,0 +1,47 @@
+//===- Bufferize.cpp - scf bufferize pass ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+/// Bufferizes the scf ops of a function by applying the SCF structural type
+/// conversions with a BufferizeTypeConverter, converting tensor-typed values
+/// flowing through scf.if / scf.for into memrefs.
+struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
+  void runOnFunction() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    // Keep the tensor<->memref materialization ops legal; this is a partial
+    // conversion, so they stay in the IR where converted values meet
+    // unconverted ones (e.g. function arguments and return values).
+    populateBufferizeMaterializationLegality(target);
+    populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
+                                                    patterns, target);
+    if (failed(applyPartialConversion(func, target, patterns)))
+      return signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createSCFBufferizePass() {
+  return std::make_unique<SCFBufferizePass>();
+}

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index b3b20027896e1..6b516debac4a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,7 +1,9 @@
 add_mlir_dialect_library(MLIRSCFTransforms
+  Bufferize.cpp
   LoopSpecialization.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
+  StructuralTypeConversions.cpp
   Utils.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000000000..30a2272f39a24
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,137 @@
+//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+/// Converts each result type of `op` with `typeConverter`, appending the
+/// converted types to `newResultTypes`. Fails if any result type does not
+/// have a 1:1 conversion.
+///
+/// TODO: Generalize this to any type conversion, not just 1:1.
+///
+/// We need to implement something more sophisticated here that tracks which
+/// types convert to which other types and does the appropriate
+/// materialization logic. For example, it's possible that one result type
+/// converts to 0 types and another to 2 types, so newResultTypes would at
+/// least be the right size to not crash in the llvm::zip calls in the
+/// patterns below, but then we would set the wrong type on the SSA values!
+/// These edge cases are also why we cannot safely use the
+/// TypeConverter::convertTypes helper here.
+static LogicalResult convertResultTypes(Operation *op,
+                                        TypeConverter &typeConverter,
+                                        SmallVectorImpl<Type> &newResultTypes) {
+  for (Type type : op->getResultTypes()) {
+    Type newType = typeConverter.convertType(type);
+    if (!newType)
+      return failure();
+    newResultTypes.push_back(newType);
+  }
+  return success();
+}
+
+namespace {
+/// Converts the result types of an scf.for op, together with the types of its
+/// iter_args block arguments, which mirror the result types one-to-one. The
+/// scf.yield terminator of the body is converted by ConvertYieldOpTypes.
+class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ForOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 6> newResultTypes;
+    if (failed(convertResultTypes(op, *typeConverter, newResultTypes)))
+      return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+
+    // Clone and replace.
+    ForOp newOp = cast<ForOp>(rewriter.clone(*op.getOperation()));
+    newOp.getOperation()->setOperands(operands);
+    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    // Block argument 0 is the induction variable and keeps its type; the
+    // remaining block arguments are the iter_args.
+    auto bodyArgs = newOp.getBody()->getArguments();
+    for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, newOp.getResults());
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Converts the result types of an scf.if op. The scf.yield terminators of its
+/// regions are converted by ConvertYieldOpTypes.
+class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 6> newResultTypes;
+    if (failed(convertResultTypes(op, *typeConverter, newResultTypes)))
+      return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+
+    // TODO: Write this with updateRootInPlace once the conversion infra
+    // supports source materializations on ops updated in place.
+    IfOp newOp = cast<IfOp>(rewriter.clone(*op.getOperation()));
+    newOp.getOperation()->setOperands(operands);
+    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// When the result types of a ForOp/IfOp get changed, the operand types of the
+// corresponding yield op need to be changed. In order to trigger the
+// appropriate type conversions / materializations, we need a dummy pattern.
+class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(scf::YieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Rebuild the yield from the operands handed over by the conversion
+    // driver so that it picks up their (possibly converted) types.
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
+      typeConverter, context);
+  // The legality callbacks (and the patterns) refer to `typeConverter`, so it
+  // must outlive both `target` and `patterns`.
+  target.addDynamicallyLegalOp<ForOp, IfOp>([&typeConverter](Operation *op) {
+    return typeConverter.isLegal(op->getResultTypes());
+  });
+  target.addDynamicallyLegalOp<scf::YieldOp>(
+      [&typeConverter](scf::YieldOp op) {
+        // We only have conversions for a subset of ops that use scf.yield
+        // terminators.
+        if (!isa<ForOp, IfOp>(op.getParentOp()))
+          return true;
+        return typeConverter.isLegal(op.getOperandTypes());
+      });
+}

diff  --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 682fd9ff6719f..26eabe2b89473 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -72,6 +72,10 @@ BufferizeTypeConverter::getResultConversionKind(Type origin, Type converted) {
   return KeepAsFunctionResult;
 }
 
+void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
+  target.addLegalOp<TensorLoadOp, TensorToMemrefOp>(); // Materialization ops.
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeFuncOpConverter
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
new file mode 100644
index 0000000000000..01b353da83ed8
--- /dev/null
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -scf-bufferize | FileCheck %s
+
+// CHECK-LABEL:   func @if(
+// CHECK-SAME:             %[[PRED:.*]]: i1,
+// CHECK-SAME:             %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
+// CHECK-SAME:             %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
+// CHECK:             %[[TRUE_MEMREF:.*]] = tensor_to_memref %[[TRUE_TENSOR]] : memref<?xf32>
+// CHECK:             scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
+// CHECK:           } else {
+// CHECK:             %[[FALSE_MEMREF:.*]] = tensor_to_memref %[[FALSE_TENSOR]] : memref<?xf32>
+// CHECK:             scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
+// CHECK:           }
+// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF]] : memref<?xf32>
+// CHECK:           return %[[RESULT_TENSOR]] : tensor<?xf32>
+// CHECK:         }
+func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = scf.if %pred -> (tensor<?xf32>) {
+    scf.yield %true_val : tensor<?xf32>
+  } else {
+    scf.yield %false_val : tensor<?xf32>
+  }
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func @for(
+// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME:              %[[LB:.*]]: index, %[[UB:.*]]: index,
+// CHECK-SAME:              %[[STEP:.*]]: index) -> tensor<f32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           %[[RESULT_MEMREF:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
+// CHECK:             scf.yield %[[ITER]] : memref<f32>
+// CHECK:           }
+// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF]] : memref<f32>
+// CHECK:           return %[[RESULT_TENSOR]] : tensor<f32>
+// CHECK:         }
+func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
+  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
+    scf.yield %iter : tensor<f32>
+  }
+  return %ret : tensor<f32>
+}


        


More information about the Mlir-commits mailing list