[Mlir-commits] [mlir] 1fee821 - [mlir][memref] Make result normalization aware of the number symbols
Kai Sasaki
llvmlistbot at llvm.org
Wed Jun 28 18:17:56 PDT 2023
Author: Kai Sasaki
Date: 2023-06-29T10:04:53+09:00
New Revision: 1fee821d22252ec946bf730bc1149951eb6281c7
URL: https://github.com/llvm/llvm-project/commit/1fee821d22252ec946bf730bc1149951eb6281c7
DIFF: https://github.com/llvm/llvm-project/commit/1fee821d22252ec946bf730bc1149951eb6281c7.diff
LOG: [mlir][memref] Make result normalization aware of the number symbols
Memref normalization fails to recognize that a non-zero number of symbols may be used in the memref type itself, through its strided layout and offset information. This causes a crash with types such as `memref<128x512xf32, strided<[?, ?], offset: ?>>`. The original issue is reported here: https://github.com/llvm/llvm-project/issues/61345
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D150250
Added:
Modified:
mlir/include/mlir/Dialect/Affine/Utils.h
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
mlir/test/Transforms/normalize-memrefs.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ec86f16642e7de..b3ccbff3002fb1 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -249,8 +249,7 @@ LogicalResult normalizeMemRef(memref::AllocOp *op);
/// transformed to an identity map with a new shape being computed for the
/// normalized memref type and returns it. The old memref type is simply
/// returned if the normalization failed.
-MemRefType normalizeMemRefType(MemRefType memrefType,
- unsigned numSymbolicOperands);
+MemRefType normalizeMemRefType(MemRefType memrefType);
/// Given an operation, inserts one or more single result affine apply
/// operations, results of which are exclusively used by this operation.
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index d567093188e0c0..1ba9c82e5af958 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1720,8 +1720,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
// Fetch a new memref type after normalizing the old memref to have an
// identity map layout.
- MemRefType newMemRefType =
- normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size());
+ MemRefType newMemRefType = normalizeMemRefType(memrefType);
if (newMemRefType == memrefType)
// Either memrefType already had an identity map or the map couldn't be
// transformed to an identity map.
@@ -1772,8 +1771,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
return success();
}
-MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType,
- unsigned numSymbolicOperands) {
+MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
unsigned rank = memrefType.getRank();
if (rank == 0)
return memrefType;
@@ -1784,6 +1782,7 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType,
return memrefType;
}
AffineMap layoutMap = memrefType.getLayout().getAffineMap();
+ unsigned numSymbolicOperands = layoutMap.getNumSymbols();
// We don't do any checks for one-to-one'ness; we assume that it is
// one-to-one.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index aa21497fad8f8e..33772ccb7dd9d3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -367,8 +367,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
}
// Fetch a new memref type after normalizing the old memref to have an
// identity map layout.
- MemRefType newMemRefType = normalizeMemRefType(memrefType,
- /*numSymbolicOperands=*/0);
+ MemRefType newMemRefType = normalizeMemRefType(memrefType);
if (newMemRefType == memrefType || funcOp.isExternal()) {
// Either memrefType already had an identity map or the map couldn't be
// transformed to an identity map.
@@ -475,8 +474,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
}
// Computing a new memref type after normalizing the old memref to have an
// identity map layout.
- MemRefType newMemRefType = normalizeMemRefType(memrefType,
- /*numSymbolicOperands=*/0);
+ MemRefType newMemRefType = normalizeMemRefType(memrefType);
resultTypes.push_back(newMemRefType);
}
@@ -513,9 +511,9 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
resultTypes.push_back(resultType);
continue;
}
+
// Fetch a new memref type after normalizing the old memref.
- MemRefType newMemRefType = normalizeMemRefType(memrefType,
- /*numSymbolicOperands=*/0);
+ MemRefType newMemRefType = normalizeMemRefType(memrefType);
if (newMemRefType == memrefType) {
// Either memrefType already had an identity map or the map couldn't
// be transformed to an identity map.
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
index 892d5e53f690c8..c7af033a22a2c6 100644
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -352,3 +352,14 @@ func.func @neg_map() -> memref<2x3xf32, #neg> {
%0 = memref.alloc() : memref<2x3xf32, #neg>
return %0 : memref<2x3xf32, #neg>
}
+
+// CHECK-LABEL: func @memref_with_strided_offset
+func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = bufferization.to_memref %arg0 : memref<128x512xf32, strided<[?, ?], offset: ?>>
+ %subview = memref.subview %0[%arg2, 0] [%arg1, 512] [1, 1] : memref<128x512xf32, strided<[?, ?], offset: ?>> to memref<?x512xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %{{.*}} = memref.cast %{{.*}} : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+ %cast = memref.cast %subview : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+ %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
+ return %1 : tensor<16x512xf32>
+}
More information about the Mlir-commits mailing list