[Mlir-commits] [mlir] 55088ef - [mlir][memref] Fix MemrefToLLVM lowering pattern for memref.transpose
Alex Zinenko
llvmlistbot at llvm.org
Thu Sep 14 04:13:03 PDT 2023
Author: Felix Schneider
Date: 2023-09-14T13:12:55+02:00
New Revision: 55088efe061aed2d26cd26cdb8141a79bf9a8528
URL: https://github.com/llvm/llvm-project/commit/55088efe061aed2d26cd26cdb8141a79bf9a8528
DIFF: https://github.com/llvm/llvm-project/commit/55088efe061aed2d26cd26cdb8141a79bf9a8528.diff
LOG: [mlir][memref] Fix MemrefToLLVM lowering pattern for memref.transpose
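The LLVM lowering pattern for memref.transpose has a bug: instead of
permuting sizes and strides from the source descriptor into the destination
descriptor ((source) -> (dest)), it applies the permutation in the opposite
direction ((dest) -> (source)), i.e. it applies the inverse permutation.
This patch fixes the bug and updates the test.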
Fixes https://github.com/llvm/llvm-project/issues/65145
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D159290
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 97faefe2cd4d631..159fa1da935700e 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1428,11 +1428,14 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
// Copy the offset pointer from the old descriptor to the new one.
targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
- // Iterate over the dimensions and apply size/stride permutation.
+ // Iterate over the dimensions and apply size/stride permutation:
+ // When enumerating the results of the permutation map, the enumeration index
+ // is the index into the target dimensions and the DimExpr points to the
+ // dimension of the source memref.
for (const auto &en :
llvm::enumerate(transposeOp.getPermutation().getResults())) {
- int sourcePos = en.index();
- int targetPos = en.value().cast<AffineDimExpr>().getPosition();
+ int targetPos = en.index();
+ int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
targetMemRef.setSize(rewriter, loc, targetPos,
viewMemRef.size(rewriter, loc, sourcePos));
targetMemRef.setStride(rewriter, loc, targetPos,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index f28f400db7f3c2d..2ece4acc05f5d92 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -230,12 +230,18 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
%0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
return
More information about the Mlir-commits mailing list