[Mlir-commits] [mlir] 73cb58d - [mlir][Shape] Lower cstr_eq to shape_eq + assert

Benjamin Kramer llvmlistbot at llvm.org
Wed Mar 3 08:26:17 PST 2021


Author: Benjamin Kramer
Date: 2021-03-03T17:22:28+01:00
New Revision: 73cb58dc48cac18520c1af5f910dfcc718e2e380

URL: https://github.com/llvm/llvm-project/commit/73cb58dc48cac18520c1af5f910dfcc718e2e380
DIFF: https://github.com/llvm/llvm-project/commit/73cb58dc48cac18520c1af5f910dfcc718e2e380.diff

LOG: [mlir][Shape] Lower cstr_eq to shape_eq + assert

Differential Revision: https://reviews.llvm.org/D97860

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
    mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index e9d31ac93438..af976056acb0 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -39,6 +39,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
 void mlir::populateConvertShapeConstraintsConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   patterns.insert<CstrBroadcastableToRequire>(ctx);
+  patterns.insert<CstrEqToRequire>(ctx);
   patterns.insert<ConvertCstrRequireOp>(ctx);
 }
 

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
index aac3789c3b58..573f0eb2f5d8 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
@@ -24,4 +24,11 @@ def CstrBroadcastableToRequire : Pat<(Shape_CstrBroadcastableOp $shapes),
               (Shape_IsBroadcastableOp $shapes),
               (BroadcastableStringAttr))>;
 
+def EqStringAttr : NativeCodeCall<[{
+  $_builder.getStringAttr("required equal shapes")
+}]>;
+
+def CstrEqToRequire : Pat<(Shape_CstrEqOp $shapes),
+  (Shape_CstrRequireOp (Shape_ShapeEqOp $shapes), (EqStringAttr))>;
+
 #endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD

diff  --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
index 5b47d9453261..eff6c149bea0 100644
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -14,6 +14,19 @@ func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !sha
   return %witness : !shape.witness
 }
 
+// CHECK-LABEL:   func @cstr_eq(
+// CHECK-SAME:                             %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME:                             %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK:           %[[RET:.*]] = shape.const_witness true
+// CHECK:           %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
+// CHECK:           assert %[[EQUAL_IS_VALID]], "required equal shapes"
+// CHECK:           return %[[RET]] : !shape.witness
+// CHECK:         }
+func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
+  %witness = shape.cstr_eq %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  return %witness : !shape.witness
+}
+
 // CHECK-LABEL: func @cstr_require
 func @cstr_require(%arg0: i1) -> !shape.witness {
   // CHECK: %[[RET:.*]] = shape.const_witness true


        


More information about the Mlir-commits mailing list