[Mlir-commits] [mlir] 3810f76 - [mlir][tensor|memref] Harden the checks on dim op

Quentin Colombet llvmlistbot at llvm.org
Thu Feb 2 02:39:04 PST 2023


Author: Quentin Colombet
Date: 2023-02-02T11:34:03+01:00
New Revision: 3810f76c50923dd0ef16ace6a550e5a4ab6c16a5

URL: https://github.com/llvm/llvm-project/commit/3810f76c50923dd0ef16ace6a550e5a4ab6c16a5
DIFF: https://github.com/llvm/llvm-project/commit/3810f76c50923dd0ef16ace6a550e5a4ab6c16a5.diff

LOG: [mlir][tensor|memref] Harden the checks on dim op

Prior to this patch it was possible to use the dim operation on a 0-D
memref/tensor.
Unless we want to change the semantics of 0-D shapes, this doesn't make
sense: per the dim op semantics, the result is guaranteed to be
undefined, because the requested index is necessarily greater than or
equal to the rank, which is zero.

Harden the type requirements for the dim op by disallowing 0-D shaped
types.

This "fixes" llvm.org/PR60195 by rejecting dim ops on 0-D shapes instead of
crashing during LLVM conversion.
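
For illustration, here is a small MLIR sketch of the behavior after this
patch (the function names are made up for the example; the diagnostic text
matches the new invalid.mlir tests below):

    // Now rejected by the verifier:
    func.func @dim_of_0d(%t : tensor<f32>, %i : index) -> index {
      // error: 'tensor.dim' op operand #0 must be unranked.tensor of any
      // type values or non-0-ranked.tensor of any type values
      %d = tensor.dim %t, %i : tensor<f32>
      return %d : index
    }

    // Still accepted: non-0-D ranked and unranked shapes.
    func.func @dim_of_ranked(%t : tensor<?x4xf32>, %u : tensor<*xf32>,
                             %i : index) -> index {
      %d0 = tensor.dim %t, %i : tensor<?x4xf32>
      %d1 = tensor.dim %u, %i : tensor<*xf32>
      %s = arith.addi %d0, %d1 : index
      return %s : index
    }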

Differential Revision: https://reviews.llvm.org/D142445

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c2c0d0afe4f1..f5dab426cb9d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -572,7 +572,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     ```
   }];
 
-  let arguments = (ins AnyRankedOrUnrankedMemRef:$source,
+  let arguments = (ins AnyNon0RankedOrUnrankedMemRef:$source,
                        Index:$index);
   let results = (outs Index:$result);
 

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9328b1f4b2b8..e702189e7847 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -115,7 +115,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [
     ```
   }];
 
-  let arguments = (ins AnyTensor:$source,
+  let arguments = (ins AnyNon0RankedOrUnrankedTensor:$source,
                        Index:$index);
   let results = (outs Index:$result);
 

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7fb583febfba..d307bebecbe0 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -548,6 +548,12 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
                          == }]
                       # rank>)>]>;
 
+// Whether a shaped type has a rank greater than or equal to the specified rank.
+class HasRankGreaterOrEqualPred<int rank> : And<[
+    HasRankPred,
+    CPred<[{$_self.cast<::mlir::ShapedType>().getRank() >= }] # rank>
+]>;
+
 // Vector types.
 
 class VectorOf<list<Type> allowedTypes> :
@@ -748,7 +754,16 @@ class RankedTensorOf<
     string summary = "ranked tensor">
   : TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>;
 
+class Non0RankedTensorOf<list<Type> allowedTypes>
+  : TensorOf<allowedTypes, [HasRankGreaterOrEqualPred<1>],
+      "non-0-ranked.tensor">;
+
 def AnyRankedTensor : RankedTensorOf<[AnyType]>;
+def AnyNon0RankedTensor  : Non0RankedTensorOf<[AnyType]>;
+def AnyUnrankedTensor  : UnrankedTensorOf<[AnyType]>;
+
+def AnyNon0RankedOrUnrankedTensor:
+    AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor]>;
 
 // Ranked tensor type with one of the specified types and ranks.
 class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
@@ -782,13 +797,20 @@ def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
 class MemRefOf<list<Type> allowedTypes> :
     ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
                         "::mlir::MemRefType">;
+class Non0RankedMemRefOf<list<Type> allowedTypes> :
+    ConfinedType<MemRefOf<allowedTypes>, [HasRankGreaterOrEqualPred<1>],
+         "non-0-ranked." # MemRefOf<allowedTypes>.summary,
+         "::mlir::MemRefType">;
 
 def AnyMemRef : MemRefOf<[AnyType]>;
+def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>;
 
 class RankedOrUnrankedMemRefOf<list<Type> allowedTypes>:
     AnyTypeOf<[UnrankedMemRefOf<allowedTypes>, MemRefOf<allowedTypes>]>;
 
 def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
+def AnyNon0RankedOrUnrankedMemRef:
+    AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>;
 
 // Memref declarations handle any memref, independent of rank, size, (static or
 // dynamic), layout, or memory space.

diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index ccbf929dbd20..19874f08cd2f 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1040,3 +1040,11 @@ func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
   %0 = memref.realloc %src : memref<256xf32> to memref<?xi32>
   return %0 : memref<?xi32>
 }
+
+// -----
+
+// Asking for the dimension of a 0-D shape doesn't make sense.
+func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
+  memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
+  return
+}

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index cbcc1e3d339b..fe665a32d709 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL:   func @dim(
-// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME:              %[[TENSOR:.*]]: tensor<*xf32>,
 // CHECK-SAME:              %[[INDEX:.*]]: index) -> index {
-// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
-// CHECK:           %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
+// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
+// CHECK:           %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<*xf32>
 // CHECK:           return %[[EXTENT]] : index
-func.func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
-  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
+func.func @dim(%arg0: tensor<*xf32>, %arg1: index) -> index {
+  %0 = tensor.dim %arg0, %arg1 : tensor<*xf32>
   return %0 : index
 }
 

diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 36c4dfe6e67a..d15819f497a7 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -8,6 +8,14 @@ func.func @dim(%arg : tensor<1x?xf32>) {
 
 // -----
 
+// Asking for the dimension of a 0-D shape doesn't make sense.
+func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
+  tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}}
+  return
+}
+
+// -----
+
 func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
   // expected-error at +1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}}
   %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>
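
Beyond the dim ops themselves, the new ODS constraints added to OpBase.td
(HasRankGreaterOrEqualPred, AnyNon0RankedTensor, AnyNon0RankedMemRef, and
their *OrUnranked variants) can be reused by other op definitions that need
to exclude 0-D operands. A minimal TableGen sketch, assuming a hypothetical
downstream op (MyDialect_ExtentOp and MyDialect_Op are made up; only the
constraint name comes from this patch):

    // Hypothetical op: verification rejects 0-D memrefs, exactly like
    // memref.dim does after this change.
    def MyDialect_ExtentOp : MyDialect_Op<"extent"> {
      let arguments = (ins AnyNon0RankedOrUnrankedMemRef:$source,
                           Index:$index);
      let results = (outs Index:$result);
    }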


        

