[Mlir-commits] [mlir] a70f2eb - [MLIR][Shape] Merge `shape` to `std`/`scf` lowerings.

Frederik Gossen llvmlistbot at llvm.org
Mon Sep 7 07:40:01 PDT 2020


Author: Frederik Gossen
Date: 2020-09-07T14:39:37Z
New Revision: a70f2eb3e39a42a71ba077247f9deafbdf1e8092

URL: https://github.com/llvm/llvm-project/commit/a70f2eb3e39a42a71ba077247f9deafbdf1e8092
DIFF: https://github.com/llvm/llvm-project/commit/a70f2eb3e39a42a71ba077247f9deafbdf1e8092.diff

LOG: [MLIR][Shape] Merge `shape` to `std`/`scf` lowerings.

Merge the two lowering passes because neither is useful on its own. The new
pass lowers to `std`, and `scf` is considered an auxiliary dialect.

See also
https://llvm.discourse.group/t/conversions-with-multiple-target-dialects/1541/12

Differential Revision: https://reviews.llvm.org/D86779
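
For downstream pipelines, the practical effect is that the former two-pass
sequence collapses into a single pass. A minimal sketch of the migration,
assuming the usual MLIR pass-manager boilerplate (the pipeline function and
its name are illustrative, not part of this commit):

  #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
  #include "mlir/Pass/PassManager.h"

  // Hypothetical downstream pipeline, shown only to illustrate the change.
  void buildShapeLoweringPipeline(mlir::PassManager &pm) {
    // Before this commit, both passes were required:
    //   pm.addPass(mlir::createConvertShapeToStandardPass());
    //   pm.addNestedPass<mlir::FuncOp>(mlir::createConvertShapeToSCFPass());
    // Now the merged pass emits both `std` and `scf` ops:
    pm.addPass(mlir::createConvertShapeToStandardPass());
  }

The same holds when populating patterns directly:
populateShapeToStandardConversionPatterns now also inserts the patterns that
populateShapeToSCFConversionPatterns used to provide.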

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
    mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
    mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir


################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 5dd10932981b..b04498598b29 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,7 +23,6 @@
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
-#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1b27a7308c7a..d4b478dbf4ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -239,17 +239,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
   let summary = "Convert operations from the shape dialect into the standard "
                 "dialect";
   let constructor = "mlir::createConvertShapeToStandardPass()";
-  let dependentDialects = ["StandardOpsDialect"];
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeToSCF
-//===----------------------------------------------------------------------===//
-
-def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
-  let summary = "Convert operations from the shape dialect to the SCF dialect";
-  let constructor = "mlir::createConvertShapeToSCFPass()";
-  let dependentDialects = ["scf::SCFDialect"];
+  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
deleted file mode 100644
index f953f6e2ddf1..000000000000
--- a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
-#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
-
-#include <memory>
-
-namespace mlir {
-
-class MLIRContext;
-class FunctionPass;
-class OwningRewritePatternList;
-
-void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
-                                          MLIRContext *ctx);
-
-std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_

diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c2bb2130569d..fe2af07b2a6a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,7 +12,6 @@ add_subdirectory(OpenMPToLLVM)
 add_subdirectory(SCFToGPU)
 add_subdirectory(SCFToSPIRV)
 add_subdirectory(SCFToStandard)
-add_subdirectory(ShapeToSCF)
 add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)

diff --git a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
deleted file mode 100644
index 60dd2b8514da..000000000000
--- a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRShapeToSCF
-  ShapeToSCF.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRShape
-  MLIRPass
-  MLIRSCF
-  MLIRTransforms
-  )

diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
deleted file mode 100644
index ae326c5c513e..000000000000
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ /dev/null
@@ -1,337 +0,0 @@
-//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
-
-#include "../PassDetail.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::shape;
-using namespace mlir::scf;
-
-namespace {
-struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
-  using OpConversionPattern<BroadcastOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult BroadcastOpConverter::matchAndRewrite(
-    BroadcastOp op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
-  // on shapes.
-  if (op.getType().isa<ShapeType>())
-    return failure();
-
-  assert(!op.lhs().getType().isa<ShapeType>() &&
-         !op.rhs().getType().isa<ShapeType>());
-  auto loc = op.getLoc();
-  BroadcastOp::Adaptor transformed(operands);
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
-  Value lhsSmaller =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Type extentTensorTy = op.getType();
-  auto ifOp = rewriter.create<IfOp>(
-      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
-      lhsSmaller,
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
-                                               rhsRank, transformed.rhs()});
-      },
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
-                                               lhsRank, transformed.lhs()});
-      });
-  Value smallerRank = ifOp.getResult(0);
-  Value smallerOperand = ifOp.getResult(1);
-  Value greaterRank = ifOp.getResult(2);
-  Value greaterOperand = ifOp.getResult(3);
-
-  // Allocate stack memory for the broadcasted extent tensor.
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
-  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
-
-  // Copy extents from greater operand that are not challenged.
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
-  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
-                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-                           Value extent = b.create<ExtractElementOp>(
-                               loc, greaterOperand, ValueRange{iv});
-                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-                           b.create<scf::YieldOp>(loc);
-                         });
-
-  // Determine remaining broadcasted extents.
-  rewriter.create<ForOp>(
-      loc, rankDiff, greaterRank, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-        Value greaterOperandExtent =
-            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
-        Value greaterOperandExtentIsOne =
-            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
-        auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
-            [&](OpBuilder &b, Location loc) {
-              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-              Value smallerOperandExtent = b.create<ExtractElementOp>(
-                  loc, smallerOperand, ValueRange{ivShifted});
-              b.create<scf::YieldOp>(loc, smallerOperandExtent);
-            },
-            [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterOperandExtent);
-            });
-        Value extent = ifOp.getResult(0);
-        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-        b.create<scf::YieldOp>(loc);
-      });
-
-  // Load broadcasted shape as an extent tensor.
-  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
-  return success();
-}
-
-namespace {
-/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
-/// only defined on `tensor<?xindex>` operands. The test for equality first
-/// compares their size and, if equal, checks every extent for equality.
-///
-/// Example:
-///
-/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
-///
-/// becomes
-///
-/// %c0 = constant 0 : index
-/// %0 = dim %arg0, %c0 : tensor<?xindex>
-/// %1 = dim %arg1, %c0 : tensor<?xindex>
-/// %2 = cmpi "eq", %0, %1 : index
-/// %result = scf.if %2 -> (i1) {
-///   %c1 = constant 1 : index
-///   %true = constant true
-///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
-///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
-///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
-///     %7 = cmpi "eq", %5, %6 : index
-///     %8 = and %arg3, %7 : i1
-///     scf.yield %8 : i1
-///   }
-///   scf.yield %4 : i1
-/// } else {
-///   %false = constant false
-///   scf.yield %false : i1
-/// }
-///
-struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
-  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
-  // on shapes.
-  if (op.lhs().getType().isa<ShapeType>() ||
-      op.rhs().getType().isa<ShapeType>()) {
-    return failure();
-  }
-
-  ShapeEqOp::Adaptor transformed(operands);
-  auto loc = op.getLoc();
-  Type indexTy = rewriter.getIndexType();
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
-  Value eqRank =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
-  Type i1Ty = rewriter.getI1Type();
-  rewriter.replaceOpWithNewOp<IfOp>(
-      op, i1Ty, eqRank,
-      [&](OpBuilder &b, Location loc) {
-        Value one = b.create<ConstantIndexOp>(loc, 1);
-        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
-        auto loop = b.create<scf::ForOp>(
-            loc, zero, lhsRank, one, ValueRange{init},
-            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-              Value conj = args[0];
-              Value lhsExtent =
-                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
-              Value rhsExtent =
-                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
-              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                                lhsExtent, rhsExtent);
-              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
-              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
-            });
-        b.create<scf::YieldOp>(loc, loop.getResults());
-      },
-      [&](OpBuilder &b, Location loc) {
-        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
-        b.create<scf::YieldOp>(loc, result);
-      });
-  return success();
-}
-
-namespace {
-/// Converts `shape.reduce` to `scf.for`.
-struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
-};
-} // namespace
-
-LogicalResult
-ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
-                                   ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands.
-  if (op.shape().getType().isa<ShapeType>())
-    return failure();
-
-  auto loc = op.getLoc();
-  shape::ReduceOp::Adaptor transformed(operands);
-
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  Type indexTy = rewriter.getIndexType();
-  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
-
-  auto loop = rewriter.create<scf::ForOp>(
-      loc, zero, rank, one, op.initVals(),
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
-
-        SmallVector<Value, 2> mappedValues{iv, extent};
-        mappedValues.append(args.begin(), args.end());
-
-        BlockAndValueMapping mapping;
-        Block *reduceBody = op.getBody();
-        mapping.map(reduceBody->getArguments(), mappedValues);
-        for (auto &nested : reduceBody->without_terminator())
-          b.clone(nested, mapping);
-
-        SmallVector<Value, 2> mappedResults;
-        for (auto result : reduceBody->getTerminator()->getOperands())
-          mappedResults.push_back(mapping.lookup(result));
-        b.create<scf::YieldOp>(loc, mappedResults);
-      });
-
-  rewriter.replaceOp(op, loop.getResults());
-  return success();
-}
-
-namespace {
-/// Converts `shape_of` to for loop for unranked tensors.
-class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
-public:
-  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering supports only error-free arguments.
-  if (op.getType().isa<ShapeType>())
-    return failure();
-
-  // For ranked tensors `shape_of` lowers to `std` and the pattern can be
-  // found in the corresponding pass.
-  ShapeOfOp::Adaptor transformed(operands);
-  Value arg = transformed.arg();
-  Type argTy = arg.getType();
-  if (argTy.isa<RankedTensorType>())
-    return failure();
-
-  // Allocate stack memory.
-  auto loc = op.getLoc();
-  Value rank = rewriter.create<mlir::RankOp>(loc, arg);
-  Type indexTy = rewriter.getIndexType();
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
-  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
-
-  // Copy shape extents to stack-allocated memory.
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  rewriter.create<scf::ForOp>(
-      loc, zero, rank, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-        Value dim = rewriter.create<DimOp>(loc, arg, iv);
-        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
-        rewriter.create<scf::YieldOp>(loc);
-      });
-
-  // Load extents to tensor value.
-  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
-  return success();
-}
-
-namespace {
-struct ConvertShapeToSCFPass
-    : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
-  void runOnFunction() override;
-};
-} // namespace
-
-void ConvertShapeToSCFPass::runOnFunction() {
-  MLIRContext &ctx = getContext();
-
-  // Populate conversion patterns.
-  OwningRewritePatternList patterns;
-  populateShapeToSCFConversionPatterns(patterns, &ctx);
-
-  // Setup target legality.
-  ConversionTarget target(getContext());
-  target.addLegalDialect<SCFDialect, StandardOpsDialect>();
-
-  // Apply conversion.
-  if (failed(applyPartialConversion(getFunction(), target, patterns)))
-    signalPassFailure();
-}
-
-void mlir::populateShapeToSCFConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  // clang-format off
-  patterns.insert<
-      BroadcastOpConverter,
-      ShapeEqOpConverter,
-      ReduceOpConverter,
-      ShapeOfOpConverter>(ctx);
-  // clang-format on
-}
-
-std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
-  return std::make_unique<ConvertShapeToSCFPass>();
-}

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index e92bb83d4f42..8c917e08f942 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -12,10 +12,12 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace mlir::shape;
+using namespace mlir::scf;
 
 /// Conversion patterns.
 namespace {
@@ -63,67 +65,94 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
 } // namespace
 
 namespace {
-class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
-public:
-  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
-public:
-  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
+  using OpConversionPattern<BroadcastOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
-LogicalResult ShapeOfOpConversion::matchAndRewrite(
-    ShapeOfOp op, ArrayRef<Value> operands,
+LogicalResult BroadcastOpConverter::matchAndRewrite(
+    BroadcastOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-
-  // For now, only error-free types are supported by this lowering.
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
   if (op.getType().isa<ShapeType>())
     return failure();
 
-  // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
-  // found in the corresponding pass.
-  ShapeOfOp::Adaptor transformed(operands);
-  Value tensorVal = transformed.arg();
-  Type tensorTy = tensorVal.getType();
-  if (tensorTy.isa<UnrankedTensorType>())
-    return failure();
-
-  // Build values for individual dimensions.
-  SmallVector<Value, 8> dimValues;
-  RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
-  int64_t rank = rankedTensorTy.getRank();
+  assert(!op.lhs().getType().isa<ShapeType>() &&
+         !op.rhs().getType().isa<ShapeType>());
   auto loc = op.getLoc();
-  for (int64_t i = 0; i < rank; i++) {
-    if (rankedTensorTy.isDynamicDim(i)) {
-      Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
-      dimValues.push_back(dimVal);
-    } else {
-      int64_t dim = rankedTensorTy.getDimSize(i);
-      Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
-      dimValues.push_back(dimVal);
-    }
-  }
-
-  // Materialize extent tensor.
-  Value staticExtentTensor =
-      rewriter.create<TensorFromElementsOp>(loc, dimValues);
-  rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
-                                            op.getType());
+  BroadcastOp::Adaptor transformed(operands);
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+  // Find smaller and greater rank and extent tensor.
+  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+  Value lhsSmaller =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+  Type indexTy = rewriter.getIndexType();
+  Type extentTensorTy = op.getType();
+  auto ifOp = rewriter.create<IfOp>(
+      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+      lhsSmaller,
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
+                                               rhsRank, transformed.rhs()});
+      },
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
+                                               lhsRank, transformed.lhs()});
+      });
+  Value smallerRank = ifOp.getResult(0);
+  Value smallerOperand = ifOp.getResult(1);
+  Value greaterRank = ifOp.getResult(2);
+  Value greaterOperand = ifOp.getResult(3);
+
+  // Allocate stack memory for the broadcasted extent tensor.
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
+
+  // Copy extents from greater operand that are not challenged.
+  Value rankDiff =
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
+                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+                           Value extent = b.create<ExtractElementOp>(
+                               loc, greaterOperand, ValueRange{iv});
+                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+                           b.create<scf::YieldOp>(loc);
+                         });
+
+  // Determine remaining broadcasted extents.
+  rewriter.create<ForOp>(
+      loc, rankDiff, greaterRank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+        Value greaterOperandExtent =
+            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+        Value greaterOperandExtentIsOne =
+            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+        auto ifOp = b.create<IfOp>(
+            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+            [&](OpBuilder &b, Location loc) {
+              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+              Value smallerOperandExtent = b.create<ExtractElementOp>(
+                  loc, smallerOperand, ValueRange{ivShifted});
+              b.create<scf::YieldOp>(loc, smallerOperandExtent);
+            },
+            [&](OpBuilder &b, Location loc) {
+              b.create<scf::YieldOp>(loc, greaterOperandExtent);
+            });
+        Value extent = ifOp.getResult(0);
+        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load broadcasted shape as an extent tensor.
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
   return success();
 }
 
@@ -161,26 +190,23 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
 }
 
 namespace {
-class ToExtentTensorOpConversion
-    : public OpConversionPattern<ToExtentTensorOp> {
+class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
 public:
-  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    ToExtentTensorOpAdaptor adaptor(operands);
-
-    if (!adaptor.input().getType().isa<RankedTensorType>())
-      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
-
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
-                                              op.getType());
-    return success();
-  }
+  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
+LogicalResult ConstSizeOpConversion::matchAndRewrite(
+    ConstSizeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -239,6 +265,236 @@ RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
   return success();
 }
 
+namespace {
+/// Converts `shape.reduce` to `scf.for`.
+struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands.
+  if (op.shape().getType().isa<ShapeType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  shape::ReduceOp::Adaptor transformed(operands);
+
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  Type indexTy = rewriter.getIndexType();
+  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
+
+  auto loop = rewriter.create<scf::ForOp>(
+      loc, zero, rank, one, op.initVals(),
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
+
+        SmallVector<Value, 2> mappedValues{iv, extent};
+        mappedValues.append(args.begin(), args.end());
+
+        BlockAndValueMapping mapping;
+        Block *reduceBody = op.getBody();
+        mapping.map(reduceBody->getArguments(), mappedValues);
+        for (auto &nested : reduceBody->without_terminator())
+          b.clone(nested, mapping);
+
+        SmallVector<Value, 2> mappedResults;
+        for (auto result : reduceBody->getTerminator()->getOperands())
+          mappedResults.push_back(mapping.lookup(result));
+        b.create<scf::YieldOp>(loc, mappedResults);
+      });
+
+  rewriter.replaceOp(op, loop.getResults());
+  return success();
+}
+
+namespace {
+/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
+/// only defined on `tensor<?xindex>` operands. The test for equality first
+/// compares their size and, if equal, checks every extent for equality.
+///
+/// Example:
+///
+/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+///
+/// becomes
+///
+/// %c0 = constant 0 : index
+/// %0 = dim %arg0, %c0 : tensor<?xindex>
+/// %1 = dim %arg1, %c0 : tensor<?xindex>
+/// %2 = cmpi "eq", %0, %1 : index
+/// %result = scf.if %2 -> (i1) {
+///   %c1 = constant 1 : index
+///   %true = constant true
+///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
+///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
+///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
+///     %7 = cmpi "eq", %5, %6 : index
+///     %8 = and %arg3, %7 : i1
+///     scf.yield %8 : i1
+///   }
+///   scf.yield %4 : i1
+/// } else {
+///   %false = constant false
+///   scf.yield %false : i1
+/// }
+///
+struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
+  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
+  if (op.lhs().getType().isa<ShapeType>() ||
+      op.rhs().getType().isa<ShapeType>()) {
+    return failure();
+  }
+
+  ShapeEqOp::Adaptor transformed(operands);
+  auto loc = op.getLoc();
+  Type indexTy = rewriter.getIndexType();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
+  Value eqRank =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
+  Type i1Ty = rewriter.getI1Type();
+  rewriter.replaceOpWithNewOp<IfOp>(
+      op, i1Ty, eqRank,
+      [&](OpBuilder &b, Location loc) {
+        Value one = b.create<ConstantIndexOp>(loc, 1);
+        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+        auto loop = b.create<scf::ForOp>(
+            loc, zero, lhsRank, one, ValueRange{init},
+            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+              Value conj = args[0];
+              Value lhsExtent =
+                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
+              Value rhsExtent =
+                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
+              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                lhsExtent, rhsExtent);
+              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
+              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+            });
+        b.create<scf::YieldOp>(loc, loop.getResults());
+      },
+      [&](OpBuilder &b, Location loc) {
+        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+        b.create<scf::YieldOp>(loc, result);
+      });
+  return success();
+}
+
+namespace {
+class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ShapeOfOpConversion::matchAndRewrite(
+    ShapeOfOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+
+  // For now, only error-free types are supported by this lowering.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  // For ranked tensor arguments, lower to `tensor_from_elements`.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value tensor = transformed.arg();
+  Type tensorTy = tensor.getType();
+  if (tensorTy.isa<RankedTensorType>()) {
+
+    // Build values for individual extents.
+    SmallVector<Value, 8> extentValues;
+    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
+    int64_t rank = rankedTensorTy.getRank();
+    auto loc = op.getLoc();
+    for (int64_t i = 0; i < rank; i++) {
+      if (rankedTensorTy.isDynamicDim(i)) {
+        Value extent = rewriter.create<DimOp>(loc, tensor, i);
+        extentValues.push_back(extent);
+      } else {
+        Value extent =
+            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
+        extentValues.push_back(extent);
+      }
+    }
+
+    // Materialize extent tensor.
+    Value staticExtentTensor =
+        rewriter.create<TensorFromElementsOp>(loc, extentValues);
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+                                              op.getType());
+    return success();
+  }
+
+  // Allocate stack memory.
+  auto loc = op.getLoc();
+  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
+  Type indexTy = rewriter.getIndexType();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
+
+  // Copy shape extents to stack-allocated memory.
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  rewriter.create<scf::ForOp>(
+      loc, zero, rank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value dim = rewriter.create<DimOp>(loc, tensor, iv);
+        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
+        rewriter.create<scf::YieldOp>(loc);
+      });
+
+  // Load extents to tensor value.
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
+  return success();
+}
+
+namespace {
+class ToExtentTensorOpConversion
+    : public OpConversionPattern<ToExtentTensorOp> {
+public:
+  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ToExtentTensorOpAdaptor adaptor(operands);
+
+    if (!adaptor.input().getType().isa<RankedTensorType>())
+      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
+
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
+                                              op.getType());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 /// Conversion pass.
 class ConvertShapeToStandardPass
@@ -252,7 +508,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
   // Setup target legality.
   MLIRContext &ctx = getContext();
   ConversionTarget target(ctx);
-  target.addLegalDialect<StandardOpsDialect>();
+  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
   target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
 
   // Setup conversion patterns.
@@ -271,11 +527,14 @@ void mlir::populateShapeToStandardConversionPatterns(
   patterns.insert<
       AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,
-      ConstShapeOpConverter,
       BinaryOpConversion<MulOp, MulIOp>,
+      BroadcastOpConverter,
+      ConstShapeOpConverter,
       ConstSizeOpConversion,
       GetExtentOpConverter,
       RankOpConverter,
+      ReduceOpConverter,
+      ShapeEqOpConverter,
       ShapeOfOpConversion,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on

diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
deleted file mode 100644
index cc384496dff0..000000000000
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ /dev/null
@@ -1,132 +0,0 @@
-// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: @shape_reduce
-// CHECK-SAME:  (%[[SHAPE:.*]]: tensor<?xindex>) -> index
-func @shape_reduce(%shape : tensor<?xindex>) -> index {
-  %init = constant 1 : index
-  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
-    ^bb0(%index : index, %extent : index, %acc: index):
-      %new_acc = muli %acc, %extent : index
-      shape.yield %new_acc : index
-  }
-  return %num_elements : index
-}
-// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
-// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
-// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
-// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
-// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
-// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
-// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[RESULT]] : index
-
-// -----
-
-// Don't lower `shape_of` for result type of `shape.shape`.
-// CHECK-LABEL: @shape_of
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
-func @shape_of(%arg : tensor<*xf32>) {
-  // CHECK: shape.shape
-  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
-  return
-}
-
-// -----
-
-// Lower `shape_of` for unranked tensors.
-// CHECK-LABEL: @shape_of_unranked
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
-func @shape_of_unranked(%arg : tensor<*xf32>) {
-  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
-  // CHECK:   %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK:   store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
-  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
-  return
-}
-
-// -----
-
-// CHECK-LABEL:  @shape_eq
-// CHECK-SAME:   (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
-func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
-  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
-  // CHECK:   %[[C1:.*]] = constant 1 : index
-  // CHECK:   %[[INIT:.*]] = constant true
-  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
-  // CHECK:     %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
-  // CHECK:     %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
-  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
-  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
-  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
-  // CHECK:   }
-  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
-  // CHECK: } else {
-  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
-  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
-  // CHECK: }
-  // CHECK: return %[[SHAPE_EQ]] : i1
-  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
-  return %result : i1
-}
-
-// -----
-
-// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
-// CHECK-LABEL: @broadcast
-func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
-  // CHECK: shape.broadcast
-  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
-  return %c : !shape.shape
-}
-
-// -----
-
-// CHECK-LABEL: @broadcast
-// CHECK-SAME:  (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
-  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
-  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: } else {
-  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: }
-  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
-  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
-  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
-  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
-  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
-  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
-  // CHECK:   } else {
-  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
-  // CHECK:   }
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
-  %0 = shape.broadcast %a, %b
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return
-}
-

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index b0fb5bac9071..bf8e74e5143e 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -26,46 +26,6 @@ func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
 
 // -----
 
-// Don't lower `shape_of` with `shape.shape` type.
-// CHECK-LABEL: @shape_of
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
-func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
-  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
-  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
-  return
-}
-
-// -----
-
-// Lower `shape_of` for statically shaped tensor.
-// CHECK-LABEL: @shape_of_stat
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
-func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
-  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
-  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
-  return
-}
-
-// -----
-
-// Lower `shape_of` for dynamically shaped tensor.
-// CHECK-LABEL: @shape_of_dyn
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
-func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
-  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
-  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
-  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
-  return
-}
-
-// -----
-
 // Convert `rank` to `dim` of the first dimension.
 // CHECK-LABEL: @rank
 // CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
@@ -190,3 +150,174 @@ func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
   // CHECK: return %[[RES]]
   return %casted : tensor<3xindex>
 }
+
+// CHECK-LABEL: @shape_reduce
+// CHECK-SAME:  (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @shape_reduce(%shape : tensor<?xindex>) -> index {
+  %init = constant 1 : index
+  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+    ^bb0(%index : index, %extent : index, %acc: index):
+      %new_acc = muli %acc, %extent : index
+      shape.yield %new_acc : index
+  }
+  return %num_elements : index
+}
+// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
+// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
+// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
+// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT]] : index
+
+// -----
+
+// Don't lower `shape_of` for result type of `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of(%arg : tensor<*xf32>) {
+  // CHECK: shape.shape
+  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
+  return
+}
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK:   %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK:   store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
+  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+// Don't lower `shape_of` with `shape.shape` type.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
+  return
+}
+
+// -----
+
+// Lower `shape_of` for statically shaped tensor.
+// CHECK-LABEL: @shape_of_stat
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+// Lower `shape_of` for dynamically shaped tensor.
+// CHECK-LABEL: @shape_of_dyn
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
+func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+// CHECK-LABEL:  @shape_eq
+// CHECK-SAME:   (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: return %[[SHAPE_EQ]] : i1
+  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+// -----
+
+// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
+// CHECK-LABEL: @broadcast
+func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
+  // CHECK: shape.broadcast
+  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
+  return %c : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast
+// CHECK-SAME:  (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
+  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: }
+  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
+  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+  // CHECK:   } else {
+  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+  // CHECK:   }
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+  %0 = shape.broadcast %a, %b
+      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return
+}
+
