[Mlir-commits] [mlir] dbd1bbc - [mlir][linalg][bufferize] Support arith.index_cast bufferization
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 27 02:50:40 PST 2022
Author: Matthias Springer
Date: 2022-01-27T19:50:31+09:00
New Revision: dbd1bbced9896d5caece9ee60a7953d2c80d939c
URL: https://github.com/llvm/llvm-project/commit/dbd1bbced9896d5caece9ee60a7953d2c80d939c
DIFF: https://github.com/llvm/llvm-project/commit/dbd1bbced9896d5caece9ee60a7953d2c80d939c.diff
LOG: [mlir][linalg][bufferize] Support arith.index_cast bufferization
This is in preparation for switching `-tensor-constant-bufferize` and `-arith-bufferize` to BufferizableOpInterface-based implementations.
Differential Revision: https://reviews.llvm.org/D118324
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index de3fbcd8b121c..2d09331ede237 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -57,6 +57,49 @@ struct ConstantOpInterface
}
};
+/// Bufferization model for arith.index_cast on tensors: the op is rewritten
+/// into an arith.index_cast on the buffer (memref) of its operand. The op
+/// itself neither reads nor writes its operand's memory; its single result
+/// aliases, and is equivalent to, the operand buffer.
+struct IndexCastOpInterface
+    : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
+                                                    arith::IndexCastOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  /// The only result of the op aliases the operand buffer.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return op->getResult(0);
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Replace the tensor index_cast with an index_cast on the operand buffer.
+  /// Returns failure if the operand buffer cannot be obtained.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto castOp = cast<arith::IndexCastOp>(op);
+
+    // getBuffer may fail; propagate that instead of dereferencing an invalid
+    // FailureOr.
+    FailureOr<Value> buffer =
+        state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
+    if (failed(buffer))
+      return failure();
+    Value source = *buffer;
+    auto sourceType = source.getType().cast<BaseMemRefType>();
+
+    // The result type keeps the layout (for ranked memrefs) and the memory
+    // space of the source buffer type.
+    MemRefLayoutAttrInterface layout = {};
+    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
+      layout = rankedMemRefType.getLayout();
+    Type resultType =
+        getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
+                      layout, sourceType.getMemorySpace());
+
+    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
+                                                     resultType);
+    return success();
+  }
+};
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
@@ -65,4 +108,6 @@ struct ConstantOpInterface
void mlir::linalg::comprehensive_bufferize::arith_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
+ registry
+ .addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index 1a3b266ee4b80..75fb2be2c6533 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -96,3 +96,19 @@ func @rank_reducing(
}
return %5: tensor<?x1x6x8xf32>
}
+
+// -----
+
+// A tensor-typed arith.index_cast is expected to become an arith.index_cast
+// on the operand's buffer (with the same layout map), bracketed by
+// bufferization.to_memref / bufferization.to_tensor. The scalar index_cast
+// is part of the input but is not matched by any directive below.
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-LABEL: func @index_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
+ %index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
+ %index_scalar = arith.index_cast %scalar : i32 to index
+ return %index_tensor, %index_scalar : tensor<index>, index
+}
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
+// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
+// CHECK-SAME: memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
+// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
+// CHECK: return %[[INDEX_TENSOR]]
More information about the Mlir-commits mailing list