[Mlir-commits] [mlir] d3f5f33 - [mlir][scf] support 1:N type conversion for scf.for.

Peiming Liu llvmlistbot at llvm.org
Fri Oct 21 14:12:01 PDT 2022

Author: Peiming Liu
Date: 2022-10-21T21:11:55Z
New Revision: d3f5f330671e718a0e28598c412d09e9a3b54273

URL: https://github.com/llvm/llvm-project/commit/d3f5f330671e718a0e28598c412d09e9a3b54273
DIFF: https://github.com/llvm/llvm-project/commit/d3f5f330671e718a0e28598c412d09e9a3b54273.diff

LOG: [mlir][scf] support 1:N type conversion for scf.for.

scf.for used to support only 1:1 type conversion; this patch adds support for 1:N type conversion.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136314




diff  --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index fb15fcb3579e2..c4c219617b782 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -15,58 +15,102 @@ using namespace mlir;
 using namespace mlir::scf;
 namespace {
+// Unpacks the single unrealized_conversion_cast using the list of inputs
+// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
+static void unpackUnrealizedConversionCast(Value v,
+                                           SmallVectorImpl<Value> &unpacked) {
+  if (auto cast =
+          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
+    if (cast.getInputs().size() != 1) {
+      // 1 : N type conversion.
+      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
+      return;
+    }
+  }
+  // 1 : 1 type conversion.
+  unpacked.push_back(v);
 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
   using OpConversionPattern::OpConversionPattern;
   matchAndRewrite(ForOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type, 6> newResultTypes;
-    for (auto type : op.getResultTypes()) {
-      Type newType = typeConverter->convertType(type);
-      if (!newType)
-        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
-      newResultTypes.push_back(newType);
+    SmallVector<Type> newResultTypes;
+    SmallVector<unsigned> offsets;
+    offsets.push_back(0);
+    // Do the type conversion and record the offsets.
+    for (Type type : op.getResultTypes()) {
+      if (failed(typeConverter->convertTypes(type, newResultTypes)))
+        return rewriter.notifyMatchFailure(op, "could not convert result");
+      offsets.push_back(newResultTypes.size());
-    // Clone the op without the regions and inline the regions from the old op.
+    // Create a empty new op and inline the regions from the old op.
     // This is a little bit tricky. We have two concerns here:
     // 1. We cannot update the op in place because the dialect conversion
     // framework does not track type changes for ops updated in place, so it
     // won't insert appropriate materializations on the changed result types.
-    // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
-    // clone the op.
+    // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
+    // to clone the op.
-    // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
-    // inefficient to recursively clone the regions, there is a correctness
-    // issue: if we clone with the regions, then the dialect conversion
-    // framework thinks that we just inserted all the cloned child ops. But what
-    // we want is to "take" the child regions and let the dialect conversion
-    // framework continue recursively into ops inside those regions (which are
-    // already in its worklist; inlining them into the new op's regions doesn't
-    // remove the child ops from the worklist).
-    ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
-    // Take the region from the old op and put it in the new op.
+    // 2. We need to resue the original region instead of cloning it, otherwise
+    // the dialect conversion framework thinks that we just inserted all the
+    // cloned child ops. But what we want is to "take" the child regions and let
+    // the dialect conversion framework continue recursively into ops inside
+    // those regions (which are already in its worklist; inlining them into the
+    // new op's regions doesn't remove the child ops from the worklist).
+    // convertRegionTypes already takes care of 1:N conversion.
+    if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
+      return failure();
+    // Unpacked the iteration arguments.
+    SmallVector<Value> flatArgs;
+    for (Value arg : adaptor.getInitArgs())
+      unpackUnrealizedConversionCast(arg, flatArgs);
+    // We can not do clone as the number of result types after conversion might
+    // be different.
+    ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
+                                         adaptor.getUpperBound(),
+                                         adaptor.getStep(), flatArgs);
+    // Reserve whatever attributes in the original op.
+    newOp->setAttrs(op->getAttrs());
+    // We do not need the empty block created by rewriter.
+    rewriter.eraseBlock(newOp.getBody(0));
+    // Inline the type converted region from the original operation.
     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
-    // Now, update all the types.
-    // Convert the type of the entry block of the ForOp's body.
-    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
-                                           *getTypeConverter()))) {
-      return rewriter.notifyMatchFailure(op, "could not convert body types");
+    // Pack the return value.
+    SmallVector<Value, 6> packedRets;
+    for (unsigned i = 1, e = offsets.size(); i < e; i++) {
+      unsigned start = offsets[i - 1], end = offsets[i];
+      unsigned len = end - start;
+      ValueRange mappedValue = newOp.getResults().slice(start, len);
+      if (len != 1) {
+        // 1 : N type conversion.
+        Type origType = op.getResultTypes()[i - 1];
+        Value mat = typeConverter->materializeSourceConversion(
+            rewriter, op.getLoc(), origType, mappedValue);
+        if (!mat)
+          return rewriter.notifyMatchFailure(
+              op, "Failed to materialize 1:N type conversion");
+        packedRets.push_back(mat);
+      } else {
+        // 1 : 1 type conversion.
+        packedRets.push_back(mappedValue.front());
+      }
-    // Change the clone to use the updated operands. We could have cloned with
-    // a BlockAndValueMapping, but this seems a bit more direct.
-    newOp->setOperands(adaptor.getOperands());
-    // Update the result types to the new converted types.
-    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
-      std::get<0>(t).setType(std::get<1>(t));
-    rewriter.replaceOp(op, newOp.getResults());
+    rewriter.replaceOp(op, packedRets);
     return success();
@@ -81,12 +125,12 @@ class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // TODO: Generalize this to any type conversion, not just 1:1.
-    // We need to implement something more sophisticated here that tracks which
-    // types convert to which other types and does the appropriate
+    // We need to implement something more sophisticated here that tracks
+    // which types convert to which other types and does the appropriate
     // materialization logic.
     // For example, it's possible that one result type converts to 0 types and
-    // another to 2 types, so newResultTypes would at least be the right size to
-    // not crash in the llvm::zip call below, but then we would set the the
+    // another to 2 types, so newResultTypes would at least be the right size
+    // to not crash in the llvm::zip call below, but then we would set the the
     // wrong type on the SSA values! These edge cases are also why we cannot
     // safely use the TypeConverter::convertTypes helper here.
     SmallVector<Type, 6> newResultTypes;
@@ -125,7 +169,11 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
   matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
+    SmallVector<Value> unpackedYield;
+    for (Value operand : adaptor.getOperands())
+      unpackUnrealizedConversionCast(operand, unpackedYield);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
     return success();

diff  --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
new file mode 100644
index 0000000000000..13765eb12d626
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s
+#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+// CHECK-LABEL:  func @for(
+// CHECK-SAME:             %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME:             %[[MEM_SIZE:.*1]]: memref<3xindex>,
+// CHECK-SAME:             %[[POINTER:.*2]]: memref<?xindex>,
+// CHECK-SAME:             %[[INDICES:.*3]]: memref<?xindex>,
+// CHECK-SAME:             %[[VALUE:.*4]]: memref<?xf32>,
+// CHECK-SAME:             %[[TMP_arg5:.*5]]: index,
+// CHECK-SAME:             %[[TMP_arg6:.*6]]: index,
+// CHECK-SAME:             %[[TMP_arg7:.*7]]: index
+// CHECK:          %[[TMP_0:.*]]:5 = scf.for %[[TMP_arg8:.*]] = %[[TMP_arg5]] to %[[TMP_arg6]] step %[[TMP_arg7]] iter_args(
+// CHECK-SAME:       %[[TMP_arg9:.*]] = %[[DIM_SIZE]],
+// CHECK-SAME:       %[[TMP_arg10:.*]] = %[[MEM_SIZE]],
+// CHECK-SAME:       %[[TMP_arg11:.*]] = %[[POINTER]],
+// CHECK-SAME:       %[[TMP_arg12:.*]] = %[[INDICES]],
+// CHECK-SAME:       %[[TMP_arg13:.*]] = %[[VALUE]]) 
+// CHECK:              scf.yield %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]], %[[TMP_arg13]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:            }
+// CHECK:          return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2, %[[TMP_0]]#3, %[[TMP_0]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @for(%in: tensor<1024xf32, #SparseVector>,
+               %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> {
+  %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
+     -> tensor<1024xf32, #SparseVector> {
+    scf.yield %vin : tensor<1024xf32, #SparseVector>
+  }
+  return %1 : tensor<1024xf32, #SparseVector>


More information about the Mlir-commits mailing list