[Mlir-commits] [mlir] fb9fc79 - One-shot-bufferize: allow non-tensor arguments in scf.while/for.

Johannes Reifferscheid llvmlistbot at llvm.org
Wed Sep 7 06:54:41 PDT 2022


Author: Johannes Reifferscheid
Date: 2022-09-07T15:54:25+02:00
New Revision: fb9fc79809d5c2c3ab8809f7ff98fbfa8c3a9e7e

URL: https://github.com/llvm/llvm-project/commit/fb9fc79809d5c2c3ab8809f7ff98fbfa8c3a9e7e
DIFF: https://github.com/llvm/llvm-project/commit/fb9fc79809d5c2c3ab8809f7ff98fbfa8c3a9e7e.diff

LOG: One-shot-bufferize: allow non-tensor arguments in scf.while/for.

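Currently, one-shot-bufferize crashes as soon as an scf.while or
scf.for op has a mixture of tensor and non-tensor arguments. There
is no reason for this: non-tensor arguments do not need to be
bufferized and can simply be passed through unchanged.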

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D133419

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index fd0ff88657900..fb8c8dd3e2b8e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -601,8 +601,13 @@ struct ForOpInterface
     SmallVector<Value> castedInitArgs;
     for (const auto &it : llvm::enumerate(initArgs)) {
       Value initArg = it.value();
-      auto targetType =
-          bufferization::getBufferType(forOp->getResult(it.index()), options);
+      Value result = forOp->getResult(it.index());
+      // If the type is not a tensor, bufferization doesn't need to touch it.
+      if (!result.getType().isa<TensorType>()) {
+        castedInitArgs.push_back(initArg);
+        continue;
+      }
+      auto targetType = bufferization::getBufferType(result, options);
       if (failed(targetType))
         return failure();
       castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -846,8 +851,13 @@ struct WhileOpInterface
     SmallVector<Value> castedInitArgs;
     for (const auto &it : llvm::enumerate(initArgs)) {
       Value initArg = it.value();
-      auto targetType = bufferization::getBufferType(
-          whileOp.getBeforeArguments()[it.index()], options);
+      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
+      // If the type is not a tensor, bufferization doesn't need to touch it.
+      if (!beforeArg.getType().isa<TensorType>()) {
+        castedInitArgs.push_back(initArg);
+        continue;
+      }
+      auto targetType = bufferization::getBufferType(beforeArg, options);
       if (failed(targetType))
         return failure();
       castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -856,6 +866,8 @@ struct WhileOpInterface
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
+          if (!bbArg.getType().isa<TensorType>())
+            return bbArg.getType();
           // TODO: error handling
           return bufferization::getBufferType(bbArg, options)->cast<Type>();
         }));

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index a72c6d3714aba..dab4331a2586f 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -344,13 +344,14 @@ func.func @scf_for_swapping_yields(
 //  CHECK-SAME:     %[[arg0:.*]]: memref<?xi1, #{{.*}}>
 func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
   // CHECK: scf.while : () -> () {
-  %res = scf.while (%arg1 = %arg0) : (tensor<?xi1>) -> tensor<?xi1> {
+  %res:2 = scf.while (%arg1 = %arg0, %i = %idx) :
+      (tensor<?xi1>, index) -> (tensor<?xi1>, index) {
     // CHECK: %[[condition:.*]] = memref.load %[[arg0]]
     // CHECK: scf.condition(%[[condition]])
     %condition = tensor.extract %arg1[%idx] : tensor<?xi1>
-    scf.condition(%condition) %arg1 : tensor<?xi1>
+    scf.condition(%condition) %arg1, %idx : tensor<?xi1>, index
   } do {
-  ^bb0(%arg2: tensor<?xi1>):
+  ^bb0(%arg2: tensor<?xi1>, %i: index):
     // CHECK: } do {
     // CHECK: memref.store %{{.*}}, %[[arg0]]
     // CHECK: scf.yield
@@ -358,11 +359,11 @@ func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
     %pos = "dummy.some_op"() : () -> (index)
     %val = "dummy.another_op"() : () -> (i1)
     %1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
-    scf.yield %1 : tensor<?xi1>
+    scf.yield %1, %i : tensor<?xi1>, index
   }
 
   // CHECK: return
-  return %res : tensor<?xi1>
+  return %res#0 : tensor<?xi1>
 }
 
 // -----
@@ -853,3 +854,19 @@ func.func @scf_while_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
   %x = tensor.extract %r[%c1] : tensor<?xf32>
   return %x : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @non_tensor_for_arg
+func.func @non_tensor_for_arg(%A : tensor<?xf32> {bufferization.writable = true})
+    -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2.0 : f32
+  %c10 = arith.constant 10 : index
+  %r1:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%idx = %c1, %t = %A) -> (index, tensor<?xf32>) {
+    %t2 = tensor.insert %c2 into %t[%idx] : tensor<?xf32>
+    scf.yield %idx, %t2 : index, tensor<?xf32>
+  }
+  return %r1#1 : tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list