[Mlir-commits] [mlir] f3676c3 - [mlir][memref] memref.reinterpret_cast folding
Ivan Butygin
llvmlistbot at llvm.org
Fri Mar 11 10:23:24 PST 2022
Author: Ivan Butygin
Date: 2022-03-11T21:22:43+03:00
New Revision: f3676c3273b9ef32af4b63b8bfa156a8dac31e63
URL: https://github.com/llvm/llvm-project/commit/f3676c3273b9ef32af4b63b8bfa156a8dac31e63
DIFF: https://github.com/llvm/llvm-project/commit/f3676c3273b9ef32af4b63b8bfa156a8dac31e63.diff
LOG: [mlir][memref] memref.reinterpret_cast folding
* reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x)
* reinterpret_cast(cast(x)) -> reinterpret_cast(x)
* reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets are 0
Differential Revision: https://reviews.llvm.org/D120242
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d94c01e40db3f..dc04c461ec785 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1090,6 +1090,8 @@ def MemRef_ReinterpretCastOp
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index bf5d0a8bd1bf8..78b11829bb4fb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -53,6 +53,9 @@ SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
/// If ofr is a constant integer or an IntegerAttr, return the integer.
Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
+/// Return true if `ofr` is a constant integer equal to `value`.
+bool isConstantIntValue(OpFoldResult ofr, int64_t value);
+
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
/// or the same SSA value.
/// Ignore integer bitwidth and type mismatch that come from the fact there is
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0bd293f1da06f..18d1b7db744a7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1508,6 +1508,36 @@ LogicalResult ReinterpretCastOp::verify() {
return success();
}
+/// Fold a reinterpret_cast whose source is produced by an op that can be
+/// looked through:
+///   * reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x)
+///   * reinterpret_cast(cast(x))             -> reinterpret_cast(x)
+///   * reinterpret_cast(subview(x))          -> reinterpret_cast(x), only when
+///     every subview offset is the constant 0.
+/// On a match the source operand of this op is replaced in place and the op's
+/// own result is returned; otherwise a null fold result is returned.
+OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  Value current = source();
+
+  // Stays null unless one of the patterns below matches.
+  Value replacement;
+  if (auto innerReinterpret = current.getDefiningOp<ReinterpretCastOp>()) {
+    replacement = innerReinterpret.source();
+  } else if (auto innerCast = current.getDefiningOp<CastOp>()) {
+    replacement = innerCast.source();
+  } else if (auto innerSubview = current.getDefiningOp<SubViewOp>()) {
+    // A subview is only looked through when all of its offsets are the
+    // constant 0.
+    bool allOffsetsZero =
+        llvm::all_of(innerSubview.getMixedOffsets(), [](OpFoldResult offset) {
+          return isConstantIntValue(offset, 0);
+        });
+    if (allOffsetsZero)
+      replacement = innerSubview.source();
+  }
+
+  if (!replacement)
+    return nullptr;
+
+  // Rewire the operand in place and hand back this op's own result.
+  sourceMutable().assign(replacement);
+  return getResult();
+}
+
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3e50fac6fd3a8..419aa46329b67 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -81,6 +81,12 @@ Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return llvm::None;
}
+/// Return true if `ofr` is a constant integer equal to `value`; return false
+/// when `ofr` has no constant integer value or the value differs.
+bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
+  if (Optional<int64_t> constant = getConstantIntValue(ofr))
+    return *constant == value;
+  return false;
+}
+
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
/// or the same SSA value.
/// Ignore integer bitwidth and type mismatch that come from the fact there is
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index be3ac32cbfbda..96fff29db7347 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -657,3 +657,39 @@ func @scopeInline(%arg : memref<index>) {
// CHECK: func @scopeInline
// CHECK-NOT: memref.alloca_scope
+
+// -----
+
+// A reinterpret_cast of a reinterpret_cast is expected to read directly from
+// the function argument, keeping the outer op's offset, sizes and strides.
+// CHECK-LABEL: func @reinterpret_of_reinterpret
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+// CHECK: return %[[RES]]
+func @reinterpret_of_reinterpret(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+ %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size1], strides: [1] : memref<?xi8> to memref<?xi8>
+ %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
+ return %1 : memref<?xi8>
+}
+
+// -----
+
+// A reinterpret_cast of a memref.cast is expected to read directly from the
+// function argument instead of the cast result.
+// CHECK-LABEL: func @reinterpret_of_cast
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE]]], strides: [1]
+// CHECK: return %[[RES]]
+func @reinterpret_of_cast(%arg : memref<?xi8>, %size: index) -> memref<?xi8> {
+ %0 = memref.cast %arg : memref<?xi8> to memref<5xi8>
+ %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size], strides: [1] : memref<5xi8> to memref<?xi8>
+ return %1 : memref<?xi8>
+}
+
+// -----
+
+// A reinterpret_cast of a zero-offset subview is expected to read directly
+// from the function argument instead of the subview result.
+// CHECK-LABEL: func @reinterpret_of_subview
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+// CHECK: return %[[RES]]
+func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+ %0 = memref.subview %arg[0] [%size1] [1] : memref<?xi8> to memref<?xi8>
+ %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
+ return %1 : memref<?xi8>
+}
More information about the Mlir-commits
mailing list