[Mlir-commits] [mlir] 2fe40c3 - [mlir][bufferize] Fix op filter

Matthias Springer llvmlistbot at llvm.org
Thu May 12 00:33:15 PDT 2022


Author: Matthias Springer
Date: 2022-05-12T09:33:07+02:00
New Revision: 2fe40c34eaea00db8a1b11b13c174f4f1a0eb92f

URL: https://github.com/llvm/llvm-project/commit/2fe40c34eaea00db8a1b11b13c174f4f1a0eb92f
DIFF: https://github.com/llvm/llvm-project/commit/2fe40c34eaea00db8a1b11b13c174f4f1a0eb92f.diff

LOG: [mlir][bufferize] Fix op filter

Bufferization has an optional filter that excludes certain ops from analysis and bufferization. A few remaining places in the codebase did not check this filter: they performed a plain dyn_cast to BufferizableOpInterface instead of going through BufferizationOptions::dynCastBufferizableOp, which returns a null interface for filtered-out ops.

Differential Revision: https://reviews.llvm.org/D125356
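
For context, here is a minimal sketch (not the actual implementation) of the pattern this change enforces: consult the options' op filter before treating an op as bufferizable, rather than dyn_cast'ing directly. The filter query `isOpAllowed` reflects the BufferizationOptions API of this era; treat the exact spelling as an assumption.

    #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

    using namespace mlir;
    using namespace mlir::bufferization;

    // Sketch only: return the interface iff the op implements it AND the
    // options' filter allows it. A plain dyn_cast would ignore the filter.
    static BufferizableOpInterface
    dynCastBufferizableOpSketch(const BufferizationOptions &options,
                                Operation *op) {
      auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
      if (!bufferizableOp)
        return nullptr;
      if (!options.isOpAllowed(op))
        return nullptr; // Filtered out: treated as not bufferizable.
      return bufferizableOp;
    }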

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 29d983bdffdf1..f5b6203c2395f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -117,7 +117,7 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
 SmallVector<OpOperand *>
 AnalysisState::getAliasingOpOperand(OpResult result) const {
   if (Operation *op = result.getDefiningOp())
-    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
       return bufferizableOp.getAliasingOpOperand(result, *this);
   return {};
 }
@@ -127,7 +127,7 @@ AnalysisState::getAliasingOpOperand(OpResult result) const {
 SmallVector<OpResult>
 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.getAliasingOpResult(opOperand, *this);
   return {};
 }
@@ -136,7 +136,7 @@ AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
 /// op is not bufferizable.
 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
 
   // Unknown op that returns a tensor. The inplace analysis does not support it.
@@ -148,7 +148,7 @@ bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
 /// `true` if the op is not bufferizable.
 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
 
   // Unknown op that returns a tensor. The inplace analysis does not support it.
@@ -160,7 +160,7 @@ bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
 /// alias. Return false if the op is not bufferizable.
 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
 
   // Unknown op that returns a tensor. The inplace analysis does not support it.

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index dd0ba1ca68a02..c15d3cd86bf9c 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,31 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
   // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
   // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// This is a regression test. The linalg-bufferize pass should ignore all func
+// dialect ops.
+
+// CHECK-LABEL: func private @csum(tensor<6xi64>) -> tensor<6xi64>
+func.func private @csum(%arg0: tensor<6xi64>) -> tensor<6xi64>
+
+// CHECK: func public @main(%[[arg0:.*]]: tensor<2x3xi1>)
+// CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[arg0]]
+// CHECK:   %[[collapse_m:.*]] = bufferization.to_memref %[[collapse]]
+// CHECK:   %[[alloc:.*]] = memref.alloc()
+// CHECK:   linalg.generic {{.*}} ins(%[[collapse_m]] : memref<6xi1>) outs(%[[alloc]] : memref<6xi64>)
+// CHECK:   %[[generic_t:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK:   %[[call:.*]] = call @csum(%[[generic_t]])
+// CHECK:   return %[[call]]
+func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xi1> into tensor<6xi1>
+  %1 = linalg.init_tensor [6] : tensor<6xi64>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%0 : tensor<6xi1>) outs(%1 : tensor<6xi64>) {
+  ^bb0(%arg1: i1, %arg2: i64):
+    %4 = arith.extui %arg1 : i1 to i64
+    linalg.yield %4 : i64
+  } -> tensor<6xi64>
+  %3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
+  return %3 : tensor<6xi64>
+}
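
The regression test above depends on the pass wiring up the op filter so that only linalg ops are considered: func dialect ops fail the filter check, and dynCastBufferizableOp() returns a null interface for them. A hypothetical sketch of such a filter setup, assuming the allowDialectInFilter API of this era of MLIR (the exact method name is an assumption):

    #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
    #include "mlir/Dialect/Linalg/IR/Linalg.h"

    using namespace mlir;
    using namespace mlir::bufferization;

    // Hypothetical sketch: restrict bufferization to linalg dialect ops.
    // With this filter in place, func.func and func.call (as in the test
    // above) are invisible to the analysis and left untouched.
    void configureLinalgOnlyBufferization(BufferizationOptions &options) {
      options.allowDialectInFilter<linalg::LinalgDialect>();
    }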
