[Mlir-commits] [mlir] 55088ef - [mlir][memref] Fix MemrefToLLVM lowering pattern for memref.transpose

Alex Zinenko llvmlistbot at llvm.org
Thu Sep 14 04:13:03 PDT 2023


Author: Felix Schneider
Date: 2023-09-14T13:12:55+02:00
New Revision: 55088efe061aed2d26cd26cdb8141a79bf9a8528

URL: https://github.com/llvm/llvm-project/commit/55088efe061aed2d26cd26cdb8141a79bf9a8528
DIFF: https://github.com/llvm/llvm-project/commit/55088efe061aed2d26cd26cdb8141a79bf9a8528.diff

LOG: [mlir][memref] Fix MemrefToLLVM lowering pattern for memref.transpose

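The lowering pattern to LLVM for memref.transpose has a bug where,
instead of transposing from (source) -> (dest), it actually transposes
(dest) -> (source). It treated the enumeration index of the permutation
map results as the source dimension and the DimExpr position as the
target dimension, so it effectively applied the inverse permutation.
The two agree only for self-inverse permutations (e.g. a plain 2-D
swap), but differ for the 3-D rotation (i, j, k) -> (k, i, j) exercised
by the test. This patch fixes the bug and updates the test.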

Fix https://github.com/llvm/llvm-project/issues/65145

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D159290

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 97faefe2cd4d631..159fa1da935700e 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1428,11 +1428,14 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
     // Copy the offset pointer from the old descriptor to the new one.
     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
 
-    // Iterate over the dimensions and apply size/stride permutation.
+    // Iterate over the dimensions and apply size/stride permutation:
+    // When enumerating the results of the permutation map, the enumeration index
+    // is the index into the target dimensions and the DimExpr points to the
+    // dimension of the source memref.
     for (const auto &en :
          llvm::enumerate(transposeOp.getPermutation().getResults())) {
-      int sourcePos = en.index();
-      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
+      int targetPos = en.index();
+      int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
       targetMemRef.setSize(rewriter, loc, targetPos,
                            viewMemRef.size(rewriter, loc, sourcePos));
       targetMemRef.setStride(rewriter, loc, targetPos,

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index f28f400db7f3c2d..2ece4acc05f5d92 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -230,12 +230,18 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 //       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:   llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:   llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:   llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
   %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return


        


More information about the Mlir-commits mailing list