[Mlir-commits] [mlir] a36348c - [mlir][bufferize] Fix bug in AllocTensorElimination

Matthias Springer llvmlistbot at llvm.org
Mon Aug 15 02:46:05 PDT 2022


Author: Matthias Springer
Date: 2022-08-15T11:45:58+02:00
New Revision: a36348c5868def196062fe74714b9a8ab754cdd0

URL: https://github.com/llvm/llvm-project/commit/a36348c5868def196062fe74714b9a8ab754cdd0
DIFF: https://github.com/llvm/llvm-project/commit/a36348c5868def196062fe74714b9a8ab754cdd0.diff

LOG: [mlir][bufferize] Fix bug in AllocTensorElimination

AllocTensorElimination currently does not support chains in which the type
changes. Previously, AllocTensorElimination generated invalid IR for such
inputs. With this commit, AllocTensorElimination no longer applies to such
inputs. (It can be extended to support such IR if needed.)

Differential Revision: https://reviews.llvm.org/D131880

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
index 719797ac23473..ff308e62e2742 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
@@ -140,6 +140,15 @@ LogicalResult mlir::bufferization::eliminateAllocTensors(
         return WalkResult::skip();
       Value allocTensor = maybeAllocTensor.front();
 
+      // Replace only if the types match.
+      // TODO: This could be extended to support IR such as:
+      // %0 = bufferization.alloc_tensor : tensor<128xf32>
+      // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
+      // %2 = tensor.expand_shape %1 ...
+      // %3 = tensor.insert_slice %2 into ...
+      if (allocTensor.getType() != operand.get().getType())
+        return WalkResult::skip();
+
       // Find a suitable insertion point.
       Operation *insertionPoint =
           findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir
index 2bea701e24b04..b062dd200cb51 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir
@@ -94,7 +94,7 @@ func.func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tens
 //      CHECK: func @insertion_point_outside_loop(
 // CHECK-SAME:     %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
 func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
-                                   %idx : index) -> (tensor<?xf32>) {
+                                        %idx : index) -> (tensor<?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c5 = arith.constant 5 : index
@@ -118,3 +118,21 @@ func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
 
   return %r : tensor<?xf32>
 }
+
+// -----
+
+// AllocTensorElimination currently does not apply to chains in which the type
+// changes. This test only ensures that we neither crash nor generate IR that
+// fails verification.
+
+// CHECK-LABEL: func @shape_mismatch
+func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
+  %cst = arith.constant 8.0 : f32
+  %0 = bufferization.alloc_tensor() : tensor<128xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
+  %2 = tensor.expand_shape %1 [[0, 1, 2]]
+      : tensor<128xf32> into tensor<1x1x128xf32>
+  %3 = tensor.insert_slice %2 into %t[2, 3, 0][1, 1, 128][1, 1, 1]
+      : tensor<1x1x128xf32> into tensor<5x6x128xf32>
+  return %3 : tensor<5x6x128xf32>
+}


        


More information about the Mlir-commits mailing list