[Mlir-commits] [mlir] 4f865b7 - [mlir] support creating memref descriptors from static shape with non-zero offset

Alex Zinenko llvmlistbot at llvm.org
Wed Feb 12 13:41:27 PST 2020


Author: Tobias Gysi
Date: 2020-02-12T22:40:49+01:00
New Revision: 4f865b77941db364eaf0a6c265d183274c503ecb

URL: https://github.com/llvm/llvm-project/commit/4f865b77941db364eaf0a6c265d183274c503ecb
DIFF: https://github.com/llvm/llvm-project/commit/4f865b77941db364eaf0a6c265d183274c503ecb.diff

LOG: [mlir] support creating memref descriptors from static shape with non-zero offset

This patch adapts the method MemRefDescriptor::fromStaticShape to
support static non-zero offsets. The updated method uses
getStridesAndOffset to extract the strides and the offset from the
memref type. The patch also adapts the test cases, since sizes and
strides are now set in forward order instead of reverse order.

Differential Revision: https://reviews.llvm.org/D74474

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
    mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
    mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 8d97ff1bd8ab..57ebe42ac688 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -430,7 +430,17 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
                                   LLVMTypeConverter &typeConverter,
                                   MemRefType type, Value memory) {
   assert(type.hasStaticShape() && "unexpected dynamic shape");
-  assert(type.getAffineMaps().empty() && "unexpected layout map");
+
+  // Extract all strides and offsets and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto result = getStridesAndOffset(type, strides, offset);
+  (void)result;
+  assert(succeeded(result) && "unexpected failure in stride computation");
+  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+         "expected static offset");
+  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
+         "expected static strides");
 
   auto convertedType = typeConverter.convertType(type);
   assert(convertedType && "unexpected failure in memref type conversion");
@@ -438,16 +448,12 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
   descr.setAllocatedPtr(builder, loc, memory);
   descr.setAlignedPtr(builder, loc, memory);
-  descr.setConstantOffset(builder, loc, 0);
-
-  // Fill in sizes and strides, in reverse order to simplify stride
-  // calculation.
-  uint64_t runningStride = 1;
-  for (unsigned i = type.getRank(); i > 0; --i) {
-    unsigned dim = i - 1;
-    descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
-    descr.setConstantStride(builder, loc, dim, runningStride);
-    runningStride *= type.getDimSize(dim);
+  descr.setConstantOffset(builder, loc, offset);
+
+  // Fill in sizes and strides.
+  for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
+    descr.setConstantSize(builder, loc, i, type.getDimSize(i));
+    descr.setConstantStride(builder, loc, i, strides[i]);
   }
   return descr;
 }
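
For reference, the insertvalue positions used above map onto the five
fields of the lowered descriptor. A conceptual C++ analogue (a sketch;
the actual type is an LLVM dialect struct produced by the type
converter, visible in the CHECK lines of the tests below):

  // Positions in the comments are the llvm.insertvalue indices used by
  // setAllocatedPtr/setAlignedPtr/setConstantOffset/Size/Stride.
  template <typename T, int Rank>
  struct MemRefDescriptorLayout {
    T *allocated;          // [0] allocated pointer
    T *aligned;            // [1] aligned pointer
    int64_t offset;        // [2] offset, in elements, into the aligned pointer
    int64_t sizes[Rank];   // [3, i] size of dimension i
    int64_t strides[Rank]; // [4, i] stride of dimension i
  };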

diff --git a/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir b/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
index 115c71d12800..c6d080fde2aa 100644
--- a/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
@@ -92,18 +92,18 @@ gpu.module @kernel {
     // CHECK: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
     // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
     // CHECK: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
-    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
-    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c6]], %[[descr4]][3, 2]
-    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 2]
+    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
+    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
+    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
+    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0]
     // CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
     // CHECK: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1]
     // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
     // CHECK: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1]
-    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
-    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c4]], %[[descr8]][3, 0]
-    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
-    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c12]], %[[descr9]][4, 0]
+    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
+    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2]
+    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2]
 
     %c0 = constant 0 : index
     store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, 3>
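
The reordered CHECK lines reflect the forward traversal: for the
row-major memref<4x2x6xf32, 3>, the innermost stride is 1, stride(d1) =
6 * 1 = 6, and stride(d0) = 2 * 6 = 12, so the (size, stride) constants
are now inserted as (4, 12) at [3, 0]/[4, 0], (2, 6) at [3, 1]/[4, 1],
and (6, 1) at [3, 2]/[4, 2], rather than in reverse dimension order.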

diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index cc8cfc3b6917..b47d355f77f6 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -24,20 +24,48 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
 // BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   return %static : memref<32x18xf32>
 }
 
 // -----
 
+// CHECK-LABEL: func @check_static_return_with_offset
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-LABEL: func @check_static_return_with_offset
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> {
+// CHECK:  llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(22 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  return %static : memref<32x18xf32, offset:7, strides:[22,1]>
+}
+
+// -----
+
 // CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
 // ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
 // BAREPTR-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
@@ -302,7 +330,7 @@ func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f
 // BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
 func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
 // CHECK:        llvm.mlir.constant(42 : index) : !llvm.i64
-// BAREPTR:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// BAREPTR:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
 // BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
   %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
 // CHECK-NEXT:  llvm.mlir.constant(32 : index) : !llvm.i64
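
The new @check_static_return_with_offset test pins down the non-zero
offset path: the layout offset: 7, strides: [22, 1] linearizes an access
(d0, d1) to 22 * d0 + d1 + 7, and the descriptor encodes exactly those
constants, with 7 at position [2], sizes 32 and 18 at [3, 0] and [3, 1],
and strides 22 and 1 at [4, 0] and [4, 1]. The @static_memref_dim update
only tracks the reordering: the last stride inserted is now the
innermost one, at [4, 4].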

