[Mlir-commits] [mlir] b83c67d - [mlir][linalg][bufferize] Support scf.execute_region bufferization

Matthias Springer llvmlistbot at llvm.org
Wed Jan 19 01:17:26 PST 2022


Author: Matthias Springer
Date: 2022-01-19T18:17:09+09:00
New Revision: b83c67d978942320b1fb7d814ae43c98f95ebe44

URL: https://github.com/llvm/llvm-project/commit/b83c67d978942320b1fb7d814ae43c98f95ebe44
DIFF: https://github.com/llvm/llvm-project/commit/b83c67d978942320b1fb7d814ae43c98f95ebe44.diff

LOG: [mlir][linalg][bufferize] Support scf.execute_region bufferization

This op is needed for unit testing in a subsequent revision. (This is the first op that has a block that yields equivalent values via the op's results.)

Note: Bufferization of scf.execute_region ops with multiple blocks is not yet supported.

Differential Revision: https://reviews.llvm.org/D117424

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index be55eca51986d..2fd89de1e4aa0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -44,6 +44,7 @@ struct ExecuteRegionOpInterface
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
     size_t resultNum = std::distance(op->getOpResults().begin(),
                                      llvm::find(op->getOpResults(), opResult));
+    // TODO: Support multiple blocks.
     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
            "expected exactly 1 block");
     auto yieldOp = dyn_cast<scf::YieldOp>(
@@ -66,13 +67,59 @@ struct ExecuteRegionOpInterface
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
-    // TODO: Add bufferization support when needed. scf.execute_region should be
-    // bufferized similar to scf.if.
-    bool hasTensorReturnType = any_of(
-        op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); });
-    if (hasTensorReturnType)
-      return op->emitError(
-          "scf.execute_region with tensor result not supported");
+    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
+
+    // Compute new result types.
+    SmallVector<Type> newResultTypes;
+    for (Type type : executeRegionOp->getResultTypes()) {
+      if (auto rankedTensorType = type.dyn_cast<RankedTensorType>()) {
+        newResultTypes.push_back(getDynamicMemRefType(rankedTensorType));
+      } else if (auto tensorType = type.dyn_cast<TensorType>()) {
+        newResultTypes.push_back(
+            getUnrankedMemRefType(tensorType.getElementType()));
+      } else {
+        newResultTypes.push_back(type);
+      }
+    }
+
+    // Create new op and move over region.
+    auto newOp =
+        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
+    newOp.getRegion().takeBody(executeRegionOp.getRegion());
+
+    // Update terminator.
+    assert(newOp.getRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    Block *newBlock = &newOp.getRegion().front();
+    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> newYieldValues;
+    for (auto it : llvm::enumerate(yieldOp.getResults())) {
+      Value val = it.value();
+      if (val.getType().isa<TensorType>()) {
+        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
+            yieldOp.getLoc(), newResultTypes[it.index()], val));
+      } else {
+        newYieldValues.push_back(val);
+      }
+    }
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
+
+    // Update all uses of the old op.
+    rewriter.setInsertionPointAfter(newOp);
+    SmallVector<Value> newResults;
+    for (auto it : llvm::enumerate(executeRegionOp->getResultTypes())) {
+      if (it.value().isa<TensorType>()) {
+        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
+            executeRegionOp.getLoc(), newOp->getResult(it.index())));
+      } else {
+        newResults.push_back(newOp->getResult(it.index()));
+      }
+    }
+
+    // Replace old op.
+    rewriter.replaceOp(executeRegionOp, newResults);
+
     return success();
   }
 

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 5cf7612b58a68..f9a809ea15784 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -159,8 +159,8 @@ func @mini_test_case1() -> tensor<10x20xf32> {
 
 // -----
 
+// expected-error @+1 {{memref return type is unsupported}}
 func @main() -> tensor<4xi32> {
-  // expected-error @+1 {{scf.execute_region with tensor result not supported}}
   %r = scf.execute_region -> tensor<4xi32> {
     %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
     scf.yield %A: tensor<4xi32>

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b37c3066a0d65..e6c521ddfdc24 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -446,6 +446,59 @@ func @main() {
 
 // -----
 
+// CHECK-LABEL: func @execute_region_test(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
+func @execute_region_test(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
+    -> (f32, tensor<?xf32>, f32)
+{
+  %f1 = arith.constant 0.0 : f32
+  %f2 = arith.constant 1.0 : f32
+  %idx = arith.constant 7 : index
+
+  // scf.execute_region is canonicalized away after bufferization. So just the
+  // memref.store is left over.
+
+  // CHECK: memref.store %{{.*}}, %[[m1]][%{{.*}}]
+  %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
+    %t2 = tensor.insert %f2 into %t1[%idx] : tensor<?xf32>
+    scf.yield %f1, %t2, %f2 : f32, tensor<?xf32>, f32
+  }
+
+  // CHECK: return %{{.*}}, %{{.*}} : f32, f32
+  return %0, %1, %2 : f32, tensor<?xf32>, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @execute_region_with_conflict(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
+func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
+    -> (f32, tensor<?xf32>, f32)
+{
+  %f1 = arith.constant 0.0 : f32
+  %idx = arith.constant 7 : index
+
+  // scf.execute_region is canonicalized away after bufferization. %t1 is read
+  // again below, so the insert is out-of-place: alloc, copy %t1, then store.
+
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK: memref.copy %[[m1]], %[[alloc]]
+  // CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
+  %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
+    %t2 = tensor.insert %f1 into %t1[%idx] : tensor<?xf32>
+    scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
+  }
+
+  // CHECK: %[[load:.*]] = memref.load %[[m1]]
+  %3 = tensor.extract %t1[%idx] : tensor<?xf32>
+
+  // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref<?xf32, #{{.*}}>, f32
+  return %0, %1, %3 : f32, tensor<?xf32>, f32
+}
+
+// -----
+
 //      CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
 //      CHECK:  func private @some_external_func(memref<?xf32, #[[$DYN_1D_MAP]]>)


        


More information about the Mlir-commits mailing list