[Mlir-commits] [mlir] 7f600da - [MLIR][Shape] Allow `shape.any` to operate on extent tensors

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 04:03:32 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T11:03:10Z
New Revision: 7f600da82882d38dc6e5d0d05e8ab85170ba020a

URL: https://github.com/llvm/llvm-project/commit/7f600da82882d38dc6e5d0d05e8ab85170ba020a
DIFF: https://github.com/llvm/llvm-project/commit/7f600da82882d38dc6e5d0d05e8ab85170ba020a.diff

LOG: [MLIR][Shape] Allow `shape.any` to operate on extent tensors

Differential Revision: https://reviews.llvm.org/D84433

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 014b72cd1339..64dba487c507 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -509,11 +509,14 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
 //===----------------------------------------------------------------------===//
 
 // TODO: Move the code below and witnesses to a different file.
-def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
+def Shape_AnyOp : Shape_Op<"any", [Commutative,
+                                   NoSideEffect,
+                                   SameOperandsAndResultType]> {
   let summary = "Return any combination of the input shapes";
   let description = [{
-    This operation takes multiple input shapes and returns some combination of
-    their dimensions. This can be best seen with examples below.
+    This operation takes multiple input shapes or extent tensors and returns
+    some combination of their dimensions. This can be best seen with examples
+    below.
 
     The result is undefined, but still side-effect free, in cases where the
     inputs have differing ranks or differ in extents of shared dimensions.
@@ -525,11 +528,10 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
     ```
   }];
 
-  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
-  let results = (outs Shape_ShapeType:$result);
-
-  let assemblyFormat = "$inputs attr-dict";
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
+  let assemblyFormat = "$inputs `:` type($result)  attr-dict";
   let hasFolder = 1;
 }
 

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 934a28a6ed44..d8c0cbd5f9de 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -165,11 +165,12 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_three
 // CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
-func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
-    -> !shape.shape {
+func @any_of_three(%a : tensor<?xindex>,
+                   %b : tensor<?xindex>,
+                   %c : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a, %b, %c
-  return %result : !shape.shape
+  %result = shape.any %a, %b, %c : tensor<?xindex>
+  return %result : tensor<?xindex>
 }
 
 // -----
@@ -177,9 +178,9 @@ func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_one
 // CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
-func @any_of_one(%a : !shape.shape) -> !shape.shape {
+func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a
-  return %result : !shape.shape
+  %result = shape.any %a : tensor<?xindex>
+  return %result : tensor<?xindex>
 }
 

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index e5b77a870a85..4d8fca8d1318 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -364,14 +364,25 @@ func @f() {
 
 // any can be replaced with a constant input if it has one.
 // CHECK-LABEL: func @f
-func @f(%arg0 : !shape.shape) -> !shape.shape {
+func @f(%arg : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
   %0 = shape.const_shape [2, 3, 4] : !shape.shape
-  %1 = shape.any %0, %arg0
+  %1 = shape.any %0, %arg : !shape.shape
   return %1 : !shape.shape
 }
 
+// -----
+
+// any can be replaced with a constant input if it has one.
+// CHECK-LABEL: func @f
+func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
+  // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
+  %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
+  %1 = shape.any %0, %arg : tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
 
 // -----
 
@@ -380,7 +391,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.any
   // CHECK-NEXT: return %[[CS]]
-  %1 = shape.any %arg0, %arg1
+  %1 = shape.any %arg0, %arg1 : !shape.shape
   return %1 : !shape.shape
 }
 

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index f023c02c510c..3b44af99b4fe 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -1,4 +1,3 @@
-// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s
 // Verify the printed output can be parsed.
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 // Verify the generic form can be parsed.
@@ -99,7 +98,7 @@ func @test_constraints() {
   %w3 = shape.const_witness false
   %w4 = shape.assuming_all %w0, %w1, %w2, %w3
   shape.assuming %w4 -> !shape.shape {
-    %2 = shape.any %0, %1
+    %2 = shape.any %0, %1 : !shape.shape
     shape.assuming_yield %2 : !shape.shape
   }
   return
@@ -173,3 +172,14 @@ func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
   %result = shape.get_extent %arg, %c0 : tensor<?xindex>
   return %result : !shape.size
 }
+
+func @any() {
+  %0 = shape.const_shape [1, 2, 3] : !shape.shape
+  %1 = shape.const_shape [4, 5, 6] : !shape.shape
+  %2 = shape.any %0, %1 : !shape.shape
+  %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %5 = shape.any %3, %4 : tensor<?xindex>
+  return
+}
+


        


More information about the Mlir-commits mailing list