[Mlir-commits] [mlir] fb9fc79 - One-shot-bufferize: allow non-tensor arguments in scf.while/for.
Johannes Reifferscheid
llvmlistbot at llvm.org
Wed Sep 7 06:54:41 PDT 2022
Author: Johannes Reifferscheid
Date: 2022-09-07T15:54:25+02:00
New Revision: fb9fc79809d5c2c3ab8809f7ff98fbfa8c3a9e7e
URL: https://github.com/llvm/llvm-project/commit/fb9fc79809d5c2c3ab8809f7ff98fbfa8c3a9e7e
DIFF: https://github.com/llvm/llvm-project/commit/fb9fc79809d5c2c3ab8809f7ff98fbfa8c3a9e7e.diff
LOG: One-shot-bufferize: allow non-tensor arguments in scf.while/for.
Currently, one-shot-bufferize crashes as soon as an scf.while
or scf.for op has a mixture of tensor and non-tensor arguments.
There is no reason for this: non-tensor arguments need no
bufferization and can simply be passed through unchanged.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D133419
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index fd0ff88657900..fb8c8dd3e2b8e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -601,8 +601,13 @@ struct ForOpInterface
SmallVector<Value> castedInitArgs;
for (const auto &it : llvm::enumerate(initArgs)) {
Value initArg = it.value();
- auto targetType =
- bufferization::getBufferType(forOp->getResult(it.index()), options);
+ Value result = forOp->getResult(it.index());
+ // If the type is not a tensor, bufferization doesn't need to touch it.
+ if (!result.getType().isa<TensorType>()) {
+ castedInitArgs.push_back(initArg);
+ continue;
+ }
+ auto targetType = bufferization::getBufferType(result, options);
if (failed(targetType))
return failure();
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -846,8 +851,13 @@ struct WhileOpInterface
SmallVector<Value> castedInitArgs;
for (const auto &it : llvm::enumerate(initArgs)) {
Value initArg = it.value();
- auto targetType = bufferization::getBufferType(
- whileOp.getBeforeArguments()[it.index()], options);
+ Value beforeArg = whileOp.getBeforeArguments()[it.index()];
+ // If the type is not a tensor, bufferization doesn't need to touch it.
+ if (!beforeArg.getType().isa<TensorType>()) {
+ castedInitArgs.push_back(initArg);
+ continue;
+ }
+ auto targetType = bufferization::getBufferType(beforeArg, options);
if (failed(targetType))
return failure();
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -856,6 +866,8 @@ struct WhileOpInterface
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
+ if (!bbArg.getType().isa<TensorType>())
+ return bbArg.getType();
// TODO: error handling
return bufferization::getBufferType(bbArg, options)->cast<Type>();
}));
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index a72c6d3714aba..dab4331a2586f 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -344,13 +344,14 @@ func.func @scf_for_swapping_yields(
// CHECK-SAME: %[[arg0:.*]]: memref<?xi1, #{{.*}}>
func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
// CHECK: scf.while : () -> () {
- %res = scf.while (%arg1 = %arg0) : (tensor<?xi1>) -> tensor<?xi1> {
+ %res:2 = scf.while (%arg1 = %arg0, %i = %idx) :
+ (tensor<?xi1>, index) -> (tensor<?xi1>, index) {
// CHECK: %[[condition:.*]] = memref.load %[[arg0]]
// CHECK: scf.condition(%[[condition]])
%condition = tensor.extract %arg1[%idx] : tensor<?xi1>
- scf.condition(%condition) %arg1 : tensor<?xi1>
+ scf.condition(%condition) %arg1, %idx : tensor<?xi1>, index
} do {
- ^bb0(%arg2: tensor<?xi1>):
+ ^bb0(%arg2: tensor<?xi1>, %i: index):
// CHECK: } do {
// CHECK: memref.store %{{.*}}, %[[arg0]]
// CHECK: scf.yield
@@ -358,11 +359,11 @@ func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
%pos = "dummy.some_op"() : () -> (index)
%val = "dummy.another_op"() : () -> (i1)
%1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
- scf.yield %1 : tensor<?xi1>
+ scf.yield %1, %i : tensor<?xi1>, index
}
// CHECK: return
- return %res : tensor<?xi1>
+ return %res#0 : tensor<?xi1>
}
// -----
@@ -853,3 +854,19 @@ func.func @scf_while_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
%x = tensor.extract %r[%c1] : tensor<?xf32>
return %x : f32
}
+
+// -----
+
+// CHECK-LABEL: func @non_tensor_for_arg
+func.func @non_tensor_for_arg(%A : tensor<?xf32> {bufferization.writable = true})
+ -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2.0 : f32
+ %c10 = arith.constant 10 : index
+ %r1:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%idx = %c1, %t = %A) -> (index, tensor<?xf32>) {
+ %t2 = tensor.insert %c2 into %t[%idx] : tensor<?xf32>
+ scf.yield %idx, %t2 : index, tensor<?xf32>
+ }
+ return %r1#1 : tensor<?xf32>
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list