[Mlir-commits] [mlir] 6c0fd4d - [mlir][MemRef] Fix DimOp folding of OffsetSizeAndStrideInterface.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jul 8 01:34:30 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-08T08:30:24Z
New Revision: 6c0fd4db79f2def432f761627bb8c7d4171a3237
URL: https://github.com/llvm/llvm-project/commit/6c0fd4db79f2def432f761627bb8c7d4171a3237
DIFF: https://github.com/llvm/llvm-project/commit/6c0fd4db79f2def432f761627bb8c7d4171a3237.diff
LOG: [mlir][MemRef] Fix DimOp folding of OffsetSizeAndStrideInterface.
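DimOp folding through OffsetSizeAndStrideOpInterface used the dim index of the (possibly rank-reduced) result directly as an index into the op's sizes; it now maps that index through the result type's dynamic dims. This addresses the issue reported in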
https://llvm.discourse.group/t/rank-reducing-memref-subview-offsetsizeandstrideopinterface-interface-issues/3805
Differential Revision: https://reviews.llvm.org/D105558
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 8b30fa94f9936..44e751ab0edeb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -110,6 +110,10 @@ class ShapedType : public Type {
/// size. Otherwise, abort.
int64_t getNumDynamicDims() const;
+ /// If `dim` is a dynamic dim, return the number of dynamic dims preceding it,
+ /// i.e. its zero-based index among the dynamic dims. Otherwise, abort.
+ int64_t getRelativeIndexOfDynamicDim(unsigned dim) const;
+
/// If this is ranked type, return the size of the specified dimension.
/// Otherwise, abort.
int64_t getDimSize(unsigned idx) const;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 518539376c9f4..a4cbb23bf74dc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -175,9 +175,9 @@ struct SimplifyDeadAlloc : public OpRewritePattern<T> {
LogicalResult matchAndRewrite(T alloc,
PatternRewriter &rewriter) const override {
if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
- if (auto storeOp = dyn_cast<StoreOp>(op))
- return storeOp.value() == alloc;
- return !isa<DeallocOp>(op);
+ if (auto storeOp = dyn_cast<StoreOp>(op))
+ return storeOp.value() == alloc;
+ return !isa<DeallocOp>(op);
}))
return failure();
@@ -677,9 +677,9 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
if (auto sizeInterface =
dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
- assert(sizeInterface.isDynamicSize(unsignedIndex) &&
- "Expected dynamic subview size");
- return sizeInterface.getDynamicSize(unsignedIndex);
+ int64_t nthDynamicIndex =
+ memrefType.getRelativeIndexOfDynamicDim(unsignedIndex);
+ return sizeInterface.sizes()[nthDynamicIndex];
}
// dim(memrefcast) -> dim
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index dbd47c2d1fcd0..b794c11d5948f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -271,13 +271,21 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return Value{*dynExtents};
}
+ // dim(insert_slice.result()) -> dim(insert_slice.dest())
+ if (auto insertSliceOp =
+ dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) {
+ this->sourceMutable().assign(insertSliceOp.dest());
+ return getResult();
+ }
+
// The size at the given index is now known to be a dynamic size.
unsigned unsignedIndex = index.getValue().getZExtValue();
- if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
- assert(sliceOp.isDynamicSize(unsignedIndex) &&
- "Expected dynamic slice size");
- return sliceOp.getDynamicSize(unsignedIndex);
+ if (auto sizeInterface =
+ dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
+ int64_t nthDynamicIndex =
+ tensorType.getRelativeIndexOfDynamicDim(unsignedIndex);
+ return sizeInterface.sizes()[nthDynamicIndex];
}
// dim(cast) -> dim
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index f350596384a90..0c715d2d528f5 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -427,6 +427,15 @@ int64_t ShapedType::getNumDynamicDims() const {
return llvm::count_if(getShape(), isDynamic);
}
+/// Returns the number of dynamic dims that precede `dim`, i.e. the position of
+/// `dim` among the dynamic dims. `dim` must itself be dynamic.
+int64_t ShapedType::getRelativeIndexOfDynamicDim(unsigned dim) const {
+  assert(isDynamicDim(dim) && "expected a dynamic dim");
+  // Counting only the dims strictly before `dim` yields a zero-based index;
+  // this mirrors getNumDynamicDims(), which counts over the whole shape.
+  return llvm::count_if(getShape().take_front(dim), isDynamic);
+}
+
bool ShapedType::hasStaticShape() const {
return hasRank() && llvm::none_of(getShape(), isDynamic);
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 2ae2c06dea92e..302477f04421e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -387,11 +387,32 @@ func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
}
// -----
+// The alloc is itself stored as a value, so SimplifyDeadAlloc must keep it.
// CHECK-LABEL: func @allocator
// CHECK: %[[alloc:.+]] = memref.alloc
-// CHECK: memref.store %[[alloc:.+]], %arg0
+// CHECK: memref.store %[[alloc]], %arg0
func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
%0 = memref.alloc(%arg1) : memref<?xi32>
memref.store %0, %arg0[] : memref<memref<?xi32>>
- return
+ return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func @rank_reducing_subview_dim
+// CHECK-SAME: %[[IDX_0:[0-9a-zA-Z]*]]: index
+// CHECK-SAME: %[[IDX_1:[0-9a-zA-Z]*]]: index
+func @rank_reducing_subview_dim(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> index
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
+ %1 = memref.dim %0, %c1 : memref<?x?xf32, #map0>
+ // Result dim 1 is the third size, %arg2 (bound as IDX_1): the unit dim is dropped.
+ // CHECK-NEXT: return %[[IDX_1]] : index
+ return %1 : index
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f0259952da380..977357077df37 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -517,3 +517,42 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%2 = tensor.dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index
}
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_extract_slice_dim
+// CHECK-SAME: %[[IDX_0:[0-9a-zA-Z]*]]: index
+// CHECK-SAME: %[[IDX_1:[0-9a-zA-Z]*]]: index
+func @rank_reducing_extract_slice_dim(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> index
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+ %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ // Result dim 1 is the third size, %arg2 (bound as IDX_1): the unit dim is dropped.
+ // CHECK-NEXT: return %[[IDX_1]] : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_insert_slice_dim
+// CHECK-SAME: %[[OUT:[0-9a-zA-Z]*]]: tensor<?x?x?xf32>
+func @rank_reducing_insert_slice_dim(%out : tensor<?x?x?xf32>, %in : tensor<?x?xf32>, %arg1 : index,
+ %arg2 : index) -> index
+{
+ // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = tensor.insert_slice %in into %out[%c0, %arg1, %c1] [1, 1, 1] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+ // dim of the insert_slice result is rewritten to the same dim of its dest, %out.
+ // CHECK-NEXT: %[[D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x?xf32>
+ %1 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
+
+ // CHECK-NEXT: return %[[D1]] : index
+ return %1 : index
+}
More information about the Mlir-commits
mailing list