[Mlir-commits] [mlir] d05d421 - [mlir] Add partial lowering of shape.cstr_broadcastable.

Tres Popp llvmlistbot at llvm.org
Tue Nov 3 00:57:40 PST 2020


Author: Tres Popp
Date: 2020-11-03T09:57:23+01:00
New Revision: d05d42199f77852e1f77b94acbe7b28f39ede64f

URL: https://github.com/llvm/llvm-project/commit/d05d42199f77852e1f77b94acbe7b28f39ede64f
DIFF: https://github.com/llvm/llvm-project/commit/d05d42199f77852e1f77b94acbe7b28f39ede64f.diff

LOG: [mlir] Add partial lowering of shape.cstr_broadcastable.

Because cstr operations allow more instruction reordering than asserts, we only
partially lower cstr_broadcastable: the broadcastability check becomes std ops
whose result feeds a shape.cstr_require. This ensures that the more drastic
lowering to asserts happens only when the user explicitly requests it.

Differential Revision: https://reviews.llvm.org/D89325

Added: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td

Modified: 
    mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
index eaea3de6c869..25c835d97723 100644
--- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ShapeToStandard.td)
+mlir_tablegen(ShapeToStandard.cpp.inc -gen-rewriters)
+add_public_tablegen_target(ShapeToStandardIncGen)
+
 add_mlir_conversion_library(MLIRShapeToStandard
   ConvertShapeConstraints.cpp
   ShapeToStandard.cpp
@@ -7,6 +11,7 @@ add_mlir_conversion_library(MLIRShapeToStandard
 
   DEPENDS
   MLIRConversionPassIncGen
+  ShapeToStandardIncGen
 
   LINK_COMPONENTS
   Core

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 704b0cdb0324..a7ada39261c5 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -566,6 +566,11 @@ class ToExtentTensorOpConversion
 };
 } // namespace
 
+namespace {
+/// Import the Shape Ops to Std Patterns.
+#include "ShapeToStandard.cpp.inc"
+} // namespace
+
 namespace {
 /// Conversion pass.
 class ConvertShapeToStandardPass
@@ -580,7 +585,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
   MLIRContext &ctx = getContext();
   ConversionTarget target(ctx);
   target.addLegalDialect<StandardOpsDialect, SCFDialect>();
-  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
+  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
 
   // Setup conversion patterns.
   OwningRewritePatternList patterns;
@@ -595,6 +600,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
 void mlir::populateShapeToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // clang-format off
+  populateWithGenerated(ctx, patterns);
   patterns.insert<
       AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
new file mode 100644
index 000000000000..a5eaa7a2a889
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
@@ -0,0 +1,27 @@
+//==-- ShapeToStandard.td - Shape to Standard Patterns -------*- tablegen -*==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines Patterns to lower Shape ops to Std.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_TD
+#define MLIR_CONVERSION_SHAPETOSTANDARD_TD
+
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+
+def BroadcastableStringAttr : NativeCodeCall<[{
+  $_builder.getStringAttr("required broadcastable shapes")
+}]>;
+
+def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS),
+            (Shape_CstrRequireOp
+              (Shape_IsBroadcastableOp $LHS, $RHS),
+              (BroadcastableStringAttr))>;
+
+#endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 56594d529e4d..bff2956b347f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -420,3 +420,42 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
 // CHECK:           }
 // CHECK:           return %[[ALL_RESULT]] : i1
 // CHECK:         }
+
+// -----
+
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
+  %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
+  return %0 : !shape.witness
+}
+
+// CHECK-LABEL:   func @broadcast(
+// CHECK-SAME:                    %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME:                    %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
+// CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
+// CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK:           %[[TRUE:.*]] = constant true
+// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK:             %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
+// CHECK:             %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
+// CHECK:             %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
+// CHECK:             %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
+// CHECK:             %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
+// CHECK:             %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
+// CHECK:             %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
+// CHECK:             %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
+// CHECK:             %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
+// CHECK:             scf.yield %[[NEW_ALL_SO_FAR]] : i1
+// CHECK:           }
+// CHECK:           %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
+// CHECK:           return %[[RESULT]] : !shape.witness
+// CHECK:         }


        


More information about the Mlir-commits mailing list