[Mlir-commits] [mlir] 56e1971 - [MLIR][Shape] Generalize `shape.concat` to extent tensors

Jacques Pienaar llvmlistbot at llvm.org
Thu Jun 9 08:23:32 PDT 2022


Author: Yuanqiang Liu
Date: 2022-06-09T08:23:26-07:00
New Revision: 56e19717f56ac7b96d1eb91107fe666429363ef7

URL: https://github.com/llvm/llvm-project/commit/56e19717f56ac7b96d1eb91107fe666429363ef7
DIFF: https://github.com/llvm/llvm-project/commit/56e19717f56ac7b96d1eb91107fe666429363ef7.diff

LOG: [MLIR][Shape] Generalize `shape.concat` to extent tensors

The `shape.concat` operation previously accepted only operands of type `!shape.shape`.
It now also accepts extent tensors for its operands and result.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D127321

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 36f7c7c8d03e5..2773883f57a2a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -766,10 +766,13 @@ def Shape_ConcatOp : Shape_Op<"concat", [NoSideEffect]> {
     concat([], [4,5,6]) -> [4,5,6]
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
-  let results = (outs Shape_ShapeType:$result);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs, Shape_ShapeOrExtentTensorType:$rhs);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
   let hasFolder = 1;
 }
 

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 456f27fbb126a..700dc478d330f 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -179,12 +179,24 @@ func.func @f() -> !shape.shape {
   // CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape
   %lhs = shape.const_shape [0, 1] : !shape.shape
   %rhs = shape.const_shape [2, 3] : !shape.shape
-  %0 = shape.concat %lhs, %rhs
+  %0 = shape.concat %lhs, %rhs : !shape.shape , !shape.shape -> !shape.shape
   return %0 : !shape.shape
 }
 
 // -----
 
+// Basic case.
+// CHECK-LABEL: func @f
+func.func @f() -> tensor<4xindex> {
+  // CHECK: shape.const_shape [0, 1, 2, 3] : tensor<4xindex>
+  %lhs = shape.const_shape [0, 1] : tensor<2xindex>
+  %rhs = shape.const_shape [2, 3] : tensor<2xindex>
+  %0 = shape.concat %lhs, %rhs : tensor<2xindex>, tensor<2xindex> -> tensor<4xindex>
+  return %0 : tensor<4xindex>
+}
+
+// -----
+
 // Basic case.
 // CHECK-LABEL: func @f
 func.func @f() -> tensor<2xindex> {


        


More information about the Mlir-commits mailing list