[Mlir-commits] [mlir] 73cb58d - [mlir][Shape] Lower cstr_eq to shape_eq + assert
Benjamin Kramer
llvmlistbot at llvm.org
Wed Mar 3 08:26:17 PST 2021
Author: Benjamin Kramer
Date: 2021-03-03T17:22:28+01:00
New Revision: 73cb58dc48cac18520c1af5f910dfcc718e2e380
URL: https://github.com/llvm/llvm-project/commit/73cb58dc48cac18520c1af5f910dfcc718e2e380
DIFF: https://github.com/llvm/llvm-project/commit/73cb58dc48cac18520c1af5f910dfcc718e2e380.diff
LOG: [mlir][Shape] Lower cstr_eq to shape_eq + assert
Differential Revision: https://reviews.llvm.org/D97860
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index e9d31ac93438..af976056acb0 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -39,6 +39,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
void mlir::populateConvertShapeConstraintsConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<CstrBroadcastableToRequire>(ctx);
+ patterns.insert<CstrEqToRequire>(ctx);
patterns.insert<ConvertCstrRequireOp>(ctx);
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
index aac3789c3b58..573f0eb2f5d8 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
@@ -24,4 +24,11 @@ def CstrBroadcastableToRequire : Pat<(Shape_CstrBroadcastableOp $shapes),
(Shape_IsBroadcastableOp $shapes),
(BroadcastableStringAttr))>;
+def EqStringAttr : NativeCodeCall<[{
+ $_builder.getStringAttr("required equal shapes")
+}]>;
+
+def CstrEqToRequire : Pat<(Shape_CstrEqOp $shapes),
+ (Shape_CstrRequireOp (Shape_ShapeEqOp $shapes), (EqStringAttr))>;
+
#endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD
diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
index 5b47d9453261..eff6c149bea0 100644
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -14,6 +14,19 @@ func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !sha
return %witness : !shape.witness
}
+// CHECK-LABEL: func @cstr_eq(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK: %[[RET:.*]] = shape.const_witness true
+// CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
+// CHECK: assert %[[EQUAL_IS_VALID]], "required equal shapes"
+// CHECK: return %[[RET]] : !shape.witness
+// CHECK: }
+func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
+ %witness = shape.cstr_eq %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+ return %witness : !shape.witness
+}
+
// CHECK-LABEL: func @cstr_require
func @cstr_require(%arg0: i1) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true
More information about the Mlir-commits mailing list