[Mlir-commits] [mlir] 54d81e4 - [mlir] Allow negative strides and offset in StridedLayoutAttr

Ivan Butygin llvmlistbot at llvm.org
Wed Sep 21 04:53:38 PDT 2022


Author: Ivan Butygin
Date: 2022-09-21T13:21:53+02:00
New Revision: 54d81e49e3b72f6a305891fe169ecd7c6f559223

URL: https://github.com/llvm/llvm-project/commit/54d81e49e3b72f6a305891fe169ecd7c6f559223
DIFF: https://github.com/llvm/llvm-project/commit/54d81e49e3b72f6a305891fe169ecd7c6f559223.diff

LOG: [mlir] Allow negative strides and offset in StridedLayoutAttr

Negative strides are useful for creating reverse views of arrays. We don't have a specific use case for negative offsets yet, but support for them is added for consistency.
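As a minimal illustration (mirroring the MemRefToLLVM test added below; the function name @reverse_view is hypothetical), a reverse view of a 1-D memref can now be expressed with a stride of -1 starting from the last element:

    func.func @reverse_view(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
      // Offset 6 addresses the last element; stride -1 walks backwards over all 7 elements.
      %0 = memref.subview %arg0[6] [7] [-1]
          : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
      return %0 : memref<7xf32, strided<[-1], offset: 6>>
    }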

Differential Revision: https://reviews.llvm.org/D134147

Added: 
    

Modified: 
    mlir/lib/AsmParser/AttributeParser.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Dialect/Builtin/types.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/IR/invalid-builtin-types.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 580dbb0eefde7..819c86c997f3c 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -1174,17 +1174,23 @@ Attribute Parser::parseStridedLayoutAttr() {
 
     SMLoc loc = getToken().getLoc();
     auto emitWrongTokenError = [&] {
-      emitError(loc, "expected a non-negative 64-bit signed integer or '?'");
+      emitError(loc, "expected a 64-bit signed integer or '?'");
       return llvm::None;
     };
 
+    bool negative = consumeIf(Token::minus);
+
     if (getToken().is(Token::integer)) {
       Optional<uint64_t> value = getToken().getUInt64IntegerValue();
       if (!value ||
           *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
         return emitWrongTokenError();
       consumeToken();
-      return static_cast<int64_t>(*value);
+      auto result = static_cast<int64_t>(*value);
+      if (negative)
+        result = -result;
+
+      return result;
     }
 
     return emitWrongTokenError();

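With this parser change, an optional leading minus sign is accepted in front of a stride or offset literal. Two now-valid types, taken from the round-trip tests added below:

    memref<?xf32, strided<[-1], offset: ?>>
    memref<f32, strided<[], offset: -1>>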
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 22eff2dc34b9c..70c2b47f41721 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -269,14 +269,9 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
 LogicalResult
 StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           int64_t offset, ArrayRef<int64_t> strides) {
-  if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset)
-    return emitError() << "offset must be non-negative or dynamic";
+  if (llvm::any_of(strides, [&](int64_t stride) { return stride == 0; }))
+    return emitError() << "strides must not be zero";
 
-  if (llvm::any_of(strides, [&](int64_t stride) {
-        return stride <= 0 && stride != ShapedType::kDynamicStrideOrOffset;
-      })) {
-    return emitError() << "strides must be positive or dynamic";
-  }
   return success();
 }
 
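With the relaxed verifier, only zero strides are diagnosed; negative and dynamic strides and offsets are accepted. For example (the first type comes from a test deleted below, the second from one that is kept):

    memref<?x?xf32, strided<[-2, -2]>>   // now verifies
    memref<?x?xf32, strided<[0, 0]>>     // error: strides must not be zero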

diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 14d2b9709b7cc..ce6eb620671c2 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -540,6 +540,31 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @subview_negative_stride
+// CHECK-SAME: (%[[ARG:.*]]: memref<7xf32>)
+func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
+  // CHECK: %[[ORIG:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<7xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[NEW1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR1:.*]] = llvm.extractvalue %[[ORIG]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR2:.*]] = llvm.bitcast %[[PTR1]] : !llvm.ptr<f32> to !llvm.ptr<f32>
+  // CHECK: %[[NEW2:.*]] = llvm.insertvalue %[[PTR2]], %[[NEW1]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR3:.*]] = llvm.extractvalue %[[ORIG]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR4:.*]] = llvm.bitcast %[[PTR3]] : !llvm.ptr<f32> to !llvm.ptr<f32>
+  // CHECK: %[[NEW3:.*]] = llvm.insertvalue %[[PTR4]], %[[NEW2]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(6 : index) : i64
+  // CHECK: %[[NEW4:.*]] = llvm.insertvalue %[[OFFSET]], %[[NEW3]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(7 : i64) : i64
+  // CHECK: %[[STRIDE:.*]] = llvm.mlir.constant(-1 : i64) : i64
+  // CHECK: %[[NEW5:.*]] = llvm.insertvalue %[[SIZE]], %[[NEW4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[NEW6:.*]] = llvm.insertvalue %[[STRIDE]], %[[NEW5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[NEW6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> to memref<7xf32, strided<[-1], offset: 6>>
+  // CHECK: return %[[RES]] : memref<7xf32, strided<[-1], offset: 6>>
+  %0 = memref.subview %arg0[6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
+  return %0 : memref<7xf32, strided<[-1], offset: 6>>
+}
+
+// -----
+
 // CHECK-LABEL: func @assume_alignment
 func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr<f16>, ptr<f16>, i64, array<2 x i64>, array<2 x i64>)>

diff --git a/mlir/test/Dialect/Builtin/types.mlir b/mlir/test/Dialect/Builtin/types.mlir
index d01b819dc7fbb..80840ec32424e 100644
--- a/mlir/test/Dialect/Builtin/types.mlir
+++ b/mlir/test/Dialect/Builtin/types.mlir
@@ -16,3 +16,7 @@ func.func private @f6() -> memref<?x?xf32, strided<[42, 1], offset: 0>>
 func.func private @f7() -> memref<f32, strided<[]>>
 // CHECK: memref<f32, strided<[], offset: ?>>
 func.func private @f8() -> memref<f32, strided<[], offset: ?>>
+// CHECK: memref<?xf32, strided<[-1], offset: ?>>
+func.func private @f9() -> memref<?xf32, strided<[-1], offset: ?>>
+// CHECK: memref<f32, strided<[], offset: -1>>
+func.func private @f10() -> memref<f32, strided<[], offset: -1>>

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 421a04f89cffc..3835cb3221c93 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -127,6 +127,43 @@ func.func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, strided<
 //       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
 //  CHECK-SAME:       : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
 
+// -----
+
+func.func @subview_negative_stride1(%arg0 : memref<?xf32>) -> memref<?xf32, strided<[?], offset: ?>>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant -1 : index
+  %1 = memref.dim %arg0, %c0 : memref<?xf32>
+  %2 = arith.addi %1, %c1 : index
+  %3 = memref.subview %arg0[%2] [%1] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
+  return %3 : memref<?xf32, strided<[?], offset: ?>>
+}
+//       CHECK: func @subview_negative_stride1
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>)
+//       CHECK:   %[[C1:.*]] = arith.constant 0
+//       CHECK:   %[[C2:.*]] = arith.constant -1
+//       CHECK:   %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<?xf32>
+//       CHECK:   %[[DIM2:.*]] = arith.addi %[[DIM1]], %[[C2]] : index
+//       CHECK:   %[[RES1:.*]] = memref.subview %[[ARG0]][%[[DIM2]]] [%[[DIM1]]] [-1] : memref<?xf32> to memref<?xf32, strided<[-1], offset: ?>>
+//       CHECK:   %[[RES2:.*]] = memref.cast %[[RES1]] : memref<?xf32, strided<[-1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>
+
+// -----
+
+func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, strided<[?], offset: ?>>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant -1 : index
+  %1 = memref.dim %arg0, %c0 : memref<7xf32>
+  %2 = arith.addi %1, %c1 : index
+  %3 = memref.subview %arg0[%2] [%1] [%c1] : memref<7xf32> to memref<?xf32, strided<[?], offset: ?>>
+  return %3 : memref<?xf32, strided<[?], offset: ?>>
+}
+//       CHECK: func @subview_negative_stride2
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<7xf32>)
+//       CHECK:   %[[RES1:.*]] = memref.subview %[[ARG0]][6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
+//       CHECK:   %[[RES2:.*]] = memref.cast %[[RES1]] : memref<7xf32, strided<[-1], offset: 6>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>
 
 // -----
 

diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 4f168954785db..9884212e916c1 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -74,7 +74,7 @@ func.func private @memref_unfinished_strided() -> memref<?x?xf32, strided<>>
 
 // -----
 
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+// expected-error @below {{expected a 64-bit signed integer or '?'}}
 func.func private @memref_unfinished_stride_list() -> memref<?x?xf32, strided<[>>
 
 // -----
@@ -89,7 +89,7 @@ func.func private @memref_missing_offset_colon() -> memref<?x?xf32, strided<[],
 
 // -----
 
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+// expected-error @below {{expected a 64-bit signed integer or '?'}}
 func.func private @memref_missing_offset_value() -> memref<?x?xf32, strided<[], offset: >>
 
 // -----
@@ -99,21 +99,11 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
 
 // -----
 
-// expected-error @below {{strides must be positive or dynamic}}
+// expected-error @below {{strides must not be zero}}
 func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
 
 // -----
 
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
-func.func private @memref_negative_stride() -> memref<?x?xf32, strided<[-2, -2]>>
-
-// -----
-
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
-func.func private @memref_negative_offset() -> memref<?x?xf32, strided<[2, 1], offset: -2>>
-
-// -----
-
 // expected-error @below {{expected the number of strides to match the rank}}
 func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
 

