[Mlir-commits] [mlir] d4217e6 - [mlir][memref] Missing type conversion in memref.reshape llvm lowering

Ivan Butygin llvmlistbot at llvm.org
Thu Jul 21 02:16:49 PDT 2022


Author: Ivan Butygin
Date: 2022-07-21T11:15:35+02:00
New Revision: d4217e6cc86b8bd9d36879eef6c29658fc3423f0

URL: https://github.com/llvm/llvm-project/commit/d4217e6cc86b8bd9d36879eef6c29658fc3423f0
DIFF: https://github.com/llvm/llvm-project/commit/d4217e6cc86b8bd9d36879eef6c29658fc3423f0.diff

LOG: [mlir][memref] Missing type conversion in memref.reshape llvm lowering

The shape operand can be a memref with `index` element type, so the result of the `memref::LoadOp` must be converted to the corresponding LLVM type.

Differential Revision: https://reviews.llvm.org/D129965

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 6e363997ac16d..2e574903b73c0 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1161,6 +1161,11 @@ struct MemRefReshapeOpLowering
           Value shapeOp = reshapeOp.getShape();
           Value index = createIndexConstant(rewriter, loc, i);
           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
+          Type indexType = getIndexType();
+          if (dimSize.getType() != indexType)
+            dimSize = typeConverter->materializeTargetConversion(
+                rewriter, loc, indexType, dimSize);
+          assert(dimSize && "Invalid memref element type");
         }
 
         desc.setSize(rewriter, loc, i, dimSize);

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 1296f8e4881cd..2520294ed03dc 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -306,3 +306,35 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
   return %0 : memref<?x?x12x32xf32>
   // CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @memref.reshape_index
+// CHECK-SAME:    %[[arg:.*]]: memref<?x?xi32>, %[[shape:.*]]: memref<1xindex>
+func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>) ->  memref<?xi32> {
+  // CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?xi32> to !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<1xindex> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef:.*]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0:.*]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+  // Result offset: the constant 0 is inserted at descriptor position [2].
+  // CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1:.*]][2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+  // Index constants lower to i64; the 1 is later stored as the stride at [4, 0], and the 0 presumably indexes into %shape (TODO confirm).
+  // CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
+  // The shape element type is index (i64 after lowering): it is loaded and stored as the size at [3, 0]. NOTE(review): the operand captures below are rebound with :.* patterns instead of being reused, so operand dataflow is only loosely verified.
+  // CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast:.*]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
+  // CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0:.*]] : !llvm.ptr<i64>
+  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[shape_load0:.*]], %[[insert2:.*]][3, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0:.*]], %[[insert3:.*]][4, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+  // The populated LLVM descriptor is cast back to memref<?xi32> for the return.
+  // CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert4:.*]] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)> to memref<?xi32>
+  // CHECK: return %[[result_cast:.*]] : memref<?xi32>
+
+  %1 = memref.reshape %arg0(%shape) : (memref<?x?xi32>, memref<1xindex>) -> memref<?xi32>
+  return %1 : memref<?xi32>
+}


        


More information about the Mlir-commits mailing list