[Mlir-commits] [mlir] d4217e6 - [mlir][memref] Missing type conversion in memref.reshape llvm lowering
Ivan Butygin
llvmlistbot at llvm.org
Thu Jul 21 02:16:49 PDT 2022
Author: Ivan Butygin
Date: 2022-07-21T11:15:35+02:00
New Revision: d4217e6cc86b8bd9d36879eef6c29658fc3423f0
URL: https://github.com/llvm/llvm-project/commit/d4217e6cc86b8bd9d36879eef6c29658fc3423f0
DIFF: https://github.com/llvm/llvm-project/commit/d4217e6cc86b8bd9d36879eef6c29658fc3423f0.diff
LOG: [mlir][memref] Missing type conversion in memref.reshape llvm lowering
The shape operand can be a memref of index type, so the result of the generated memref::LoadOp needs to be converted to the corresponding LLVM type.
Differential Revision: https://reviews.llvm.org/D129965
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 6e363997ac16d..2e574903b73c0 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1161,6 +1161,11 @@ struct MemRefReshapeOpLowering
Value shapeOp = reshapeOp.getShape();
Value index = createIndexConstant(rewriter, loc, i);
dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
+ Type indexType = getIndexType();
+ if (dimSize.getType() != indexType)
+ dimSize = typeConverter->materializeTargetConversion(
+ rewriter, loc, indexType, dimSize);
+ assert(dimSize && "Invalid memref element type");
}
desc.setSize(rewriter, loc, i, dimSize);
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 1296f8e4881cd..2520294ed03dc 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -306,3 +306,35 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
return %0 : memref<?x?x12x32xf32>
// CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @memref.reshape_index
+// CHECK-SAME: %[[arg:.*]]: memref<?x?xi32>, %[[shape:.*]]: memref<1xindex>
+func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>) -> memref<?xi32> {
+ // CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?xi32> to !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<1xindex> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef:.*]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0:.*]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+
+ // CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1:.*]][2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+
+ // CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
+
+ // CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast:.*]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
+ // CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0:.*]] : !llvm.ptr<i64>
+ // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[shape_load0:.*]], %[[insert2:.*]][3, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0:.*]], %[[insert3:.*]][4, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+
+ // CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert4:.*]] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)> to memref<?xi32>
+ // CHECK: return %[[result_cast:.*]] : memref<?xi32>
+
+ %1 = memref.reshape %arg0(%shape) : (memref<?x?xi32>, memref<1xindex>) -> memref<?xi32>
+ return %1 : memref<?xi32>
+}
More information about the Mlir-commits mailing list