[Mlir-commits] [mlir] dbd1bbc - [mlir][linalg][bufferize] Support arith.index_cast bufferization

Matthias Springer llvmlistbot at llvm.org
Thu Jan 27 02:50:40 PST 2022


Author: Matthias Springer
Date: 2022-01-27T19:50:31+09:00
New Revision: dbd1bbced9896d5caece9ee60a7953d2c80d939c

URL: https://github.com/llvm/llvm-project/commit/dbd1bbced9896d5caece9ee60a7953d2c80d939c
DIFF: https://github.com/llvm/llvm-project/commit/dbd1bbced9896d5caece9ee60a7953d2c80d939c.diff

LOG: [mlir][linalg][bufferize] Support arith.index_cast bufferization

This is in preparation for switching `-tensor-constant-bufferize` and `-arith-bufferize` to BufferizableOpInterface-based implementations.

Differential Revision: https://reviews.llvm.org/D118324

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index de3fbcd8b121c..2d09331ede237 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -57,6 +57,49 @@ struct ConstantOpInterface
   }
 };
 
+/// BufferizableOpInterface external model for arith::IndexCastOp.
+/// During bufferization, a tensor-typed index_cast is replaced by an
+/// arith.index_cast on the buffer of its operand (see `bufferize` below).
+struct IndexCastOpInterface
+    : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
+                                                    arith::IndexCastOp> {
+  /// Declared as not reading from the operand's buffer.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  /// Declared as not writing to the operand's buffer.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  /// The operand is reported as aliasing the op's first (and only used)
+  /// result.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return op->getResult(0);
+  }
+
+  /// The result buffer is reported as equivalent to the operand buffer.
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Rewrites the op into an arith.index_cast on the operand's buffer.
+  /// The new op's result type is a memref derived from the op's tensor
+  /// result type. It reuses the source buffer's layout and memory space.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto castOp = cast<arith::IndexCastOp>(op);
+
+    // NOTE(review): the value returned by getBuffer is dereferenced without
+    // a failure check. If it is a FailureOr<Value>, consider propagating
+    // failure() instead. Confirm against the BufferizationState API.
+    Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
+    auto sourceType = source.getType().cast<BaseMemRefType>();
+
+    // The result memref must keep the source buffer's layout and memory
+    // space. A layout is only available when the source is a ranked memref;
+    // for an unranked source it stays empty.
+    MemRefLayoutAttrInterface layout = {};
+    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
+      layout = rankedMemRefType.getLayout();
+    Type resultType =
+        getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
+                      layout, sourceType.getMemorySpace());
+
+    // Replace the tensor op with the memref-typed cast. The helper takes
+    // care of reconnecting tensor users of the old result.
+    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
+                                                     resultType);
+    return success();
+  }
+};
 } // namespace arith_ext
 } // namespace comprehensive_bufferize
 } // namespace linalg
@@ -65,4 +108,6 @@ struct ConstantOpInterface
 void mlir::linalg::comprehensive_bufferize::arith_ext::
     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  // Attach a BufferizableOpInterface external model to each supported arith
+  // op: arith.constant and arith.index_cast.
   registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
+  registry
+      .addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
 }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index 1a3b266ee4b80..75fb2be2c6533 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -96,3 +96,19 @@ func @rank_reducing(
   }
   return %5: tensor<?x1x6x8xf32>
 }
+
+// -----
+
+// Bufferization of arith.index_cast: the tensor cast becomes a memref cast
+// that keeps the source layout, and the scalar cast is left unchanged.
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-LABEL:   func @index_cast(
+// CHECK-SAME:  %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
+func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
+  %index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
+  %index_scalar = arith.index_cast %scalar : i32 to index
+  return %index_tensor, %index_scalar : tensor<index>, index
+}
+// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
+// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
+// CHECK-SAME:   memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
+// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
+// CHECK: %[[INDEX_SCALAR:.*]] = arith.index_cast %[[SCALAR]] : i32 to index
+// CHECK: return %[[INDEX_TENSOR]], %[[INDEX_SCALAR]]


        


More information about the Mlir-commits mailing list