[Mlir-commits] [mlir] 8b09141 - [mlir][arith][bufferize] Fix tensors with different layouts after bufferization

Matthias Springer llvmlistbot at llvm.org
Fri Apr 8 02:11:18 PDT 2022


Author: Matthias Springer
Date: 2022-04-08T18:11:10+09:00
New Revision: 8b09141909329d93b0de987ee18ee9cfaa7223ba

URL: https://github.com/llvm/llvm-project/commit/8b09141909329d93b0de987ee18ee9cfaa7223ba
DIFF: https://github.com/llvm/llvm-project/commit/8b09141909329d93b0de987ee18ee9cfaa7223ba.diff

LOG: [mlir][arith][bufferize] Fix tensors with different layouts after bufferization

Insert a cast if the two tensors with identical type (that are passed to `arith.select`) have different layout maps after bufferization.

Differential Revision: https://reviews.llvm.org/D123321

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 12726a1656bbc..4f1add5a899b9 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -129,6 +129,7 @@ struct SelectOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto selectOp = cast<arith::SelectOp>(op);
+    Location loc = selectOp.getLoc();
 
     // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
     // TODO: It would be more efficient to copy the result of the `select` op
@@ -139,6 +140,26 @@ struct SelectOpInterface
         *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
     Value falseBuffer =
         *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
+
+    // The "true" and the "false" operands must have the same type. If the
+    // buffers have different types, they differ only in their layout map. Cast
+    // both of them to the most dynamic MemRef type.
+    if (trueBuffer.getType() != falseBuffer.getType()) {
+      auto trueType = trueBuffer.getType().cast<MemRefType>();
+      auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>();
+      int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
+      SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+                                          ShapedType::kDynamicStrideOrOffset);
+      AffineMap stridedLayout = makeStridedLinearLayoutMap(
+          dynamicStrides, dynamicOffset, op->getContext());
+      BaseMemRefType castedType = bufferization::getMemRefType(
+          tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout),
+          trueType.getMemorySpace());
+      trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
+      falseBuffer =
+          rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
+    }
+
     replaceOpWithNewBufferizedOp<arith::SelectOp>(
         rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
     return success();

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e71139262534a..ac2249da4282c 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -105,4 +105,18 @@ func @copy_deallocated() -> tensor<10xf32> {
   return %0 : tensor<10xf32>
 }
 
+// -----
 
+// CHECK-LABEL: func @select_different_tensors(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+func @select_different_tensors(%t: tensor<?xf32>, %sz: index, %c: i1) -> tensor<?xf32> {
+  // CHECK-DAG: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?xf32, #{{.*}}>
+  // CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<?xf32>
+  %0 = linalg.init_tensor [%sz] : tensor<?xf32>
+
+  // A cast must be inserted because %t and %0 have different memref types.
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<?xf32> to memref<?xf32, #{{.*}}>
+  // CHECK: arith.select %{{.*}}, %[[casted]], %[[m]]
+  %1 = arith.select %c, %0, %t : tensor<?xf32>
+  return %1 : tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list