[Mlir-commits] [mlir] f3676c3 - [mlir][memref] memref.reinterpret_cast folding

Ivan Butygin llvmlistbot at llvm.org
Fri Mar 11 10:23:24 PST 2022


Author: Ivan Butygin
Date: 2022-03-11T21:22:43+03:00
New Revision: f3676c3273b9ef32af4b63b8bfa156a8dac31e63

URL: https://github.com/llvm/llvm-project/commit/f3676c3273b9ef32af4b63b8bfa156a8dac31e63
DIFF: https://github.com/llvm/llvm-project/commit/f3676c3273b9ef32af4b63b8bfa156a8dac31e63.diff

LOG: [mlir][memref] memref.reinterpret_cast folding

* reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x)
* reinterpret_cast(cast(x)) -> reinterpret_cast(x)
* reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets are 0

Differential Revision: https://reviews.llvm.org/D120242

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d94c01e40db3f..dc04c461ec785 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1090,6 +1090,8 @@ def MemRef_ReinterpretCastOp
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index bf5d0a8bd1bf8..78b11829bb4fb 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -53,6 +53,9 @@ SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 
+/// Return true if `ofr` is a constant integer equal to `value`.
+bool isConstantIntValue(OpFoldResult ofr, int64_t value);
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
 /// Ignore integer bitwitdh and type mismatch that come from the fact there is

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0bd293f1da06f..18d1b7db744a7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1508,6 +1508,36 @@ LogicalResult ReinterpretCastOp::verify() {
   return success();
 }
 
+OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  // Returns the source of the op producing `current` when that producer can be
+  // bypassed by this reinterpret_cast, and a null Value otherwise.
+  auto lookThrough = [](Value current) -> Value {
+    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
+    if (auto innerReinterpret = current.getDefiningOp<ReinterpretCastOp>())
+      return innerReinterpret.source();
+    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
+    if (auto innerCast = current.getDefiningOp<CastOp>())
+      return innerCast.source();
+    // reinterpret_cast(subview(x)) -> reinterpret_cast(x), but only when every
+    // subview offset is the constant 0.
+    auto innerSubView = current.getDefiningOp<SubViewOp>();
+    if (!innerSubView)
+      return Value();
+    auto isZero = [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); };
+    if (!llvm::all_of(innerSubView.getMixedOffsets(), isZero))
+      return Value();
+    return innerSubView.source();
+  };
+
+  Value bypassed = lookThrough(source());
+  if (!bypassed)
+    return nullptr;
+
+  // In-place fold: rewire the operand and return this op's own result.
+  sourceMutable().assign(bypassed);
+  return getResult();
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3e50fac6fd3a8..419aa46329b67 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -81,6 +81,12 @@ Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return llvm::None;
 }
 
+/// Return true if `ofr` is a constant integer equal to `value`.
+bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
+  Optional<int64_t> constant = getConstantIntValue(ofr);
+  return constant.hasValue() && constant.getValue() == value;
+}
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
 /// Ignore integer bitwidth and type mismatch that come from the fact there is

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index be3ac32cbfbda..96fff29db7347 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -657,3 +657,54 @@ func @scopeInline(%arg : memref<index>) {
 
 // CHECK:   func @scopeInline
 // CHECK-NOT:  memref.alloca_scope
+
+// -----
+
+// CHECK-LABEL: func @reinterpret_of_reinterpret
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+//       CHECK: return %[[RES]]
+func @reinterpret_of_reinterpret(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size1], strides: [1] : memref<?xi8> to memref<?xi8>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
+  return %1 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @reinterpret_of_cast
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE]]], strides: [1]
+//       CHECK: return %[[RES]]
+func @reinterpret_of_cast(%arg : memref<?xi8>, %size: index) -> memref<?xi8> {
+  %0 = memref.cast %arg : memref<?xi8> to memref<5xi8>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size], strides: [1] : memref<5xi8> to memref<?xi8>
+  return %1 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @reinterpret_of_subview
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+//       CHECK: return %[[RES]]
+func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+  %0 = memref.subview %arg[0] [%size1] [1] : memref<?xi8> to memref<?xi8>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
+  return %1 : memref<?xi8>
+}
+
+// -----
+
+// Negative case: the subview has a non-zero offset, so the reinterpret_cast
+// must keep reading from the subview result instead of its source.
+// CHECK-LABEL: func @no_fold_reinterpret_of_offset_subview
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
+//       CHECK: %[[SUB:.*]] = memref.subview %[[ARG]][1] [%[[SIZE1]]] [1]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[SUB]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
+//       CHECK: return %[[RES]]
+func @no_fold_reinterpret_of_offset_subview(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
+  %0 = memref.subview %arg[1] [%size1] [1] : memref<?xi8> to memref<?xi8, offset: 1, strides: [1]>
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8, offset: 1, strides: [1]> to memref<?xi8>
+  return %1 : memref<?xi8>
+}


        


More information about the Mlir-commits mailing list