[Mlir-commits] [mlir] 1fee821 - [mlir][memref] Make result normalization aware of the number of symbols

Kai Sasaki llvmlistbot at llvm.org
Wed Jun 28 18:17:56 PDT 2023


Author: Kai Sasaki
Date: 2023-06-29T10:04:53+09:00
New Revision: 1fee821d22252ec946bf730bc1149951eb6281c7

URL: https://github.com/llvm/llvm-project/commit/1fee821d22252ec946bf730bc1149951eb6281c7
DIFF: https://github.com/llvm/llvm-project/commit/1fee821d22252ec946bf730bc1149951eb6281c7.diff

LOG: [mlir][memref] Make result normalization aware of the number of symbols

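Memref normalization failed to account for the symbols introduced by the memref type's own layout map when that layout carries strided/offset information: callers always passed zero symbolic operands, even though the layout map of a type such as `memref<128x512xf32, strided<[?, ?], offset: ?>>` has a non-zero number of symbols. This caused a crash on such types. The fix derives the symbol count directly from the layout map (`getNumSymbols()`) instead of taking it as a parameter. The original issue is reported at https://github.com/llvm/llvm-project/issues/61345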

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D150250

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Utils.h
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
    mlir/test/Transforms/normalize-memrefs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ec86f16642e7de..b3ccbff3002fb1 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -249,8 +249,7 @@ LogicalResult normalizeMemRef(memref::AllocOp *op);
 /// transformed to an identity map with a new shape being computed for the
 /// normalized memref type and returns it. The old memref type is simplify
 /// returned if the normalization failed.
-MemRefType normalizeMemRefType(MemRefType memrefType,
-                               unsigned numSymbolicOperands);
+MemRefType normalizeMemRefType(MemRefType memrefType);
 
 /// Given an operation, inserts one or more single result affine apply
 /// operations, results of which are exclusively used by this operation.

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index d567093188e0c0..1ba9c82e5af958 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1720,8 +1720,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
 
   // Fetch a new memref type after normalizing the old memref to have an
   // identity map layout.
-  MemRefType newMemRefType =
-      normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size());
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
   if (newMemRefType == memrefType)
     // Either memrefType already had an identity map or the map couldn't be
     // transformed to an identity map.
@@ -1772,8 +1771,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
   return success();
 }
 
-MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType,
-                                             unsigned numSymbolicOperands) {
+MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   unsigned rank = memrefType.getRank();
   if (rank == 0)
     return memrefType;
@@ -1784,6 +1782,7 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType,
     return memrefType;
   }
   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
+  unsigned numSymbolicOperands = layoutMap.getNumSymbols();
 
   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index aa21497fad8f8e..33772ccb7dd9d3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -367,8 +367,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
     }
     // Fetch a new memref type after normalizing the old memref to have an
     // identity map layout.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
-                                                   /*numSymbolicOperands=*/0);
+    MemRefType newMemRefType = normalizeMemRefType(memrefType);
     if (newMemRefType == memrefType || funcOp.isExternal()) {
       // Either memrefType already had an identity map or the map couldn't be
       // transformed to an identity map.
@@ -475,8 +474,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
       }
       // Computing a new memref type after normalizing the old memref to have an
       // identity map layout.
-      MemRefType newMemRefType = normalizeMemRefType(memrefType,
-                                                     /*numSymbolicOperands=*/0);
+      MemRefType newMemRefType = normalizeMemRefType(memrefType);
       resultTypes.push_back(newMemRefType);
     }
 
@@ -513,9 +511,9 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
       resultTypes.push_back(resultType);
       continue;
     }
+
     // Fetch a new memref type after normalizing the old memref.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
-                                                   /*numSymbolicOperands=*/0);
+    MemRefType newMemRefType = normalizeMemRefType(memrefType);
     if (newMemRefType == memrefType) {
       // Either memrefType already had an identity map or the map couldn't
       // be transformed to an identity map.

diff  --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
index 892d5e53f690c8..c7af033a22a2c6 100644
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -352,3 +352,14 @@ func.func @neg_map() -> memref<2x3xf32, #neg> {
   %0 = memref.alloc() : memref<2x3xf32, #neg>
   return %0 : memref<2x3xf32, #neg>
 }
+
+// CHECK-LABEL: func @memref_with_strided_offset
+func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = bufferization.to_memref %arg0 : memref<128x512xf32, strided<[?, ?], offset: ?>>
+  %subview = memref.subview %0[%arg2, 0] [%arg1, 512] [1, 1] : memref<128x512xf32, strided<[?, ?], offset: ?>> to memref<?x512xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %{{.*}} = memref.cast %{{.*}} : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %cast = memref.cast %subview : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
+  return %1 : tensor<16x512xf32>
+}


        


More information about the Mlir-commits mailing list