[Mlir-commits] [mlir] cae746d - [mlir][index] Add `convert-index-to-llvm` pass

Jeff Niu llvmlistbot at llvm.org
Fri Oct 21 09:46:29 PDT 2022


Author: Jeff Niu
Date: 2022-10-21T09:46:19-07:00
New Revision: cae746d9c4abe1662bf3379e59cd48f62dcbe3b7

URL: https://github.com/llvm/llvm-project/commit/cae746d9c4abe1662bf3379e59cd48f62dcbe3b7
DIFF: https://github.com/llvm/llvm-project/commit/cae746d9c4abe1662bf3379e59cd48f62dcbe3b7.diff

LOG: [mlir][index] Add `convert-index-to-llvm` pass

This patch adds a lowering pass to convert `index` dialect ops to LLVM.

Depends on D135694

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D135697

Added: 
    mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
    mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt
    mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
    mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h b/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
new file mode 100644
index 0000000000000..9341fce7c994e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
@@ -0,0 +1,28 @@
+//===- IndexToLLVM.h - Index to LLVM dialect conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_INDEXTOLLVM_INDEXTOLLVM_H
+#define MLIR_CONVERSION_INDEXTOLLVM_INDEXTOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTINDEXTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace index {
+/// Populate `patterns` with the conversion patterns that lower `index` dialect
+/// operations to LLVM dialect operations. `converter` supplies the LLVM
+/// integer type (and its bitwidth) that `index` values are lowered to, so it
+/// must be configured with the intended index bitwidth for the target.
+void populateIndexToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns);
+} // namespace index
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_INDEXTOLLVM_INDEXTOLLVM_H

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 0eed45997d45b..eeeb56ff3eb50 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -29,6 +29,7 @@
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8dc48cd0489a0..66ac9eedf1bfb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -439,6 +439,29 @@ def ConvertVulkanLaunchFuncToVulkanCalls
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertIndexToLLVMPass
+//===----------------------------------------------------------------------===//
+
+def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
+  let summary = "Lower the `index` dialect to the `llvm` dialect.";
+  let description = [{
+    This pass lowers Index dialect operations to LLVM dialect operations.
+    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
+    `ceildivu`, and `floordivs`, which expand to a series of LLVM operations.
+    Importantly, the index bitwidth should be correctly set to the target
+    pointer width via `index-bitwidth`.
+  }];
+
+  let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
+
+  // A value of 0 means "derive the bitwidth from the data layout"; any other
+  // value overrides the width used for lowered `index` values.
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 33f917dbde97a..a65814d36b5b4 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -17,6 +17,7 @@ add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
+add_subdirectory(IndexToLLVM)
 add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)

diff  --git a/mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt b/mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..a33a61852d748
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Conversion library providing the `convert-index-to-llvm` pass and
+# index::populateIndexToLLVMConversionPatterns.
+add_mlir_conversion_library(MLIRIndexToLLVM
+  IndexToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToLLVM
+
+  # IndexToLLVM.cpp includes the generated Passes.h.inc, so the pass
+  # declarations must be generated before this library is compiled.
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIndexDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  )

diff  --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
new file mode 100644
index 0000000000000..844c57a74a198
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
@@ -0,0 +1,347 @@
+//===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Index/IR/IndexAttrs.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace index;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
+/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
+///
+/// Both candidate results are materialized unconditionally and combined with
+/// a final `llvm.select`, so every `llvm.sdiv` below executes for all inputs.
+///
+/// NOTE(review): this expansion misbehaves at the edges of the signed range
+/// (w = bitwidth of the lowered index type):
+///  - `n = INT_MIN, m = 2`: `0 - n` wraps back to `INT_MIN`, so the selected
+///    opposite-sign branch yields `+2^(w-2)` instead of `-2^(w-2)`.
+///  - `n = INT_MAX, m = -1`: `n + x` wraps to `INT_MIN` and the (unselected)
+///    same-sign branch evaluates `sdiv(INT_MIN, -1)`, which is immediate UB in
+///    LLVM IR even though the true result `-INT_MAX` is representable.
+/// An overflow-safe expansion divides only once:
+/// `q = n / m; (q * m != n && (n < 0) == (m < 0)) ? q + 1 : q`.
+/// Adopting it changes the emitted op sequence, so the FileCheck expectations
+/// for this lowering must be updated together with it.
+struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Value m = adaptor.getRhs();
+    // Constants of the same (already converted) integer type as `n`.
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
+    Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
+    Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
+
+    // Compute `x`, the bias added to `n` in the same-sign branch.
+    Value mPos =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
+    Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
+
+    // Compute the positive result: `(n + x) / m + 1`.
+    Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
+    Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
+    Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
+
+    // Compute the negative result: `-(-n / m)`.
+    Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
+    Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
+    Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
+
+    // Pick the positive result if `n` and `m` have the same sign and `n` is
+    // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
+    Value nPos =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
+    Value sameSign =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
+    Value nNonZero =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
+    Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivU
+//===----------------------------------------------------------------------===//
+
+/// Lowers `ceildivu(n, m)` as `n == 0 ? 0 : (n - 1) / m + 1`. The zero check
+/// is required because `n - 1` wraps around to the largest unsigned value when
+/// `n == 0`, which would make the rounded-up quotient non-zero.
+struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value dividend = adaptor.getLhs();
+    Value divisor = adaptor.getRhs();
+    Type intTy = dividend.getType();
+
+    // Shared constants.
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, intTy, 0);
+    Value one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+
+    // Rounded-up quotient, valid whenever the dividend is non-zero.
+    Value decremented = rewriter.create<LLVM::SubOp>(loc, dividend, one);
+    Value truncQuot = rewriter.create<LLVM::UDivOp>(loc, decremented, divisor);
+    Value roundedUp = rewriter.create<LLVM::AddOp>(loc, truncQuot, one);
+
+    // A zero dividend yields zero; everything else takes the rounded quotient.
+    Value isZero = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
+                                                 dividend, zero);
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, isZero, zero, roundedUp);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexFloorDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
+/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
+///
+/// Both candidate results are materialized unconditionally and combined with
+/// a final `llvm.select`, so every `llvm.sdiv` below executes for all inputs.
+///
+/// NOTE(review): `llvm.sdiv` is immediate undefined behavior on overflow, and
+/// `x - n` wraps for `n = INT_MIN + 1, m = -1`: the opposite-sign branch then
+/// evaluates `sdiv(INT_MIN, -1)` even though the selected result
+/// (`n / m == INT_MAX`) is representable. An overflow-safe expansion divides
+/// only once: `q = n / m; (q * m != n && (n < 0) != (m < 0)) ? q - 1 : q`.
+/// Adopting it changes the emitted op sequence, so the FileCheck expectations
+/// for this lowering must be updated together with it.
+struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Value m = adaptor.getRhs();
+    // Constants of the same (already converted) integer type as `n`.
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
+    Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
+    Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
+
+    // Compute `x`, the bias used by the opposite-sign branch.
+    Value mNeg =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
+    Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
+
+    // Compute the negative result: `-1 - (x - n) / m`.
+    Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
+    Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
+    Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
+
+    // Compute the positive result: plain truncating `n / m`.
+    Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
+
+    // Pick the negative result if `n` and `m` have different signs and `n` is
+    // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
+    Value nNeg =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
+    Value diffSign =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
+    Value nNonZero =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
+    Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.casts` / `index.castu`. When the converted source and result
+/// types are identical the cast is a no-op and is folded away. A narrower
+/// result is produced with `llvm.trunc`; a wider one with `ExtOp`, which is
+/// sign extension for signed casts and zero extension for unsigned casts.
+template <typename CastOp, typename ExtOp>
+struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> {
+  using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value source = adaptor.getInput();
+    Type srcType = source.getType();
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+
+    // Same integer type on both sides: forward the operand unchanged.
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+
+    // Narrowing cast: keep the low bits.
+    if (srcType.getIntOrFloatBitWidth() > dstType.getIntOrFloatBitWidth()) {
+      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, dstType, source);
+      return success();
+    }
+
+    // Widening cast: extend according to the signedness encoded in ExtOp.
+    rewriter.replaceOpWithNewOp<ExtOp>(op, dstType, source);
+    return success();
+  }
+};
+
+using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
+using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCmp
+//===----------------------------------------------------------------------===//
+
+/// Returns true when `lhs` and `rhs` have the same underlying integer value.
+/// Evaluated at compile time to prove the two predicate enums line up.
+static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs,
+                                      IndexCmpPredicate rhs) {
+  const int llvmValue = static_cast<int>(lhs);
+  const int indexValue = static_cast<int>(rhs);
+  return llvmValue == indexValue;
+}
+
+// Both enums must share the same maximum enumerator and agree case by case,
+// so that an index predicate can be reinterpreted as an LLVM predicate by
+// value alone.
+static_assert(
+    LLVM::getMaxEnumValForICmpPredicate() ==
+            getMaxEnumValForIndexCmpPredicate() &&
+        checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
+        checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
+        checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
+        checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
+        checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
+        checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
+        checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
+        checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
+        checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
+        checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
+    "LLVM ICmpPredicate mismatches IndexCmpPredicate");
+
+/// Lowers `index.cmp` to `llvm.icmp` carrying the equivalent predicate.
+struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The static_assert above guarantees the index predicates map onto LLVM
+    // predicates with the same value, so the symbolized value is present.
+    auto llvmPred =
+        LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred()));
+    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(op, *llvmPred, adaptor.getLhs(),
+                                              adaptor.getRhs());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexSizeOf
+//===----------------------------------------------------------------------===//
+
+/// Replaces `index.sizeof` with an LLVM constant of the lowered index type
+/// whose value is the index bitwidth configured on the type converter.
+struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = getTypeConverter();
+    Type resultType = converter->getIndexType();
+    unsigned bitwidth = converter->getIndexTypeBitwidth();
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType, bitwidth);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexConstant
+//===----------------------------------------------------------------------===//
+
+/// Replaces `index.constant` with an LLVM constant of the lowered index type.
+/// Only the low bits that fit the index bitwidth are kept, so values outside
+/// the range of a narrower index type wrap around.
+struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type indexType = getTypeConverter()->getIndexType();
+    unsigned targetWidth = indexType.getIntOrFloatBitWidth();
+    APInt lowBits = op.getValue().trunc(targetWidth);
+    auto valueAttr = IntegerAttr::get(indexType, lowBits);
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, indexType, valueAttr);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Trivial Conversions
+//===----------------------------------------------------------------------===//
+
+/// Shorthand for an `index` op that converts one-to-one into a single LLVM
+/// dialect op.
+template <typename IndexOpT, typename LLVMOpT>
+using IndexOneToOnePattern =
+    mlir::OneToOneConvertToLLVMPattern<IndexOpT, LLVMOpT>;
+
+// Arithmetic.
+using ConvertIndexAdd = IndexOneToOnePattern<AddOp, LLVM::AddOp>;
+using ConvertIndexSub = IndexOneToOnePattern<SubOp, LLVM::SubOp>;
+using ConvertIndexMul = IndexOneToOnePattern<MulOp, LLVM::MulOp>;
+
+// Division and remainder.
+using ConvertIndexDivS = IndexOneToOnePattern<DivSOp, LLVM::SDivOp>;
+using ConvertIndexDivU = IndexOneToOnePattern<DivUOp, LLVM::UDivOp>;
+using ConvertIndexRemS = IndexOneToOnePattern<RemSOp, LLVM::SRemOp>;
+using ConvertIndexRemU = IndexOneToOnePattern<RemUOp, LLVM::URemOp>;
+
+// Signed / unsigned maximum.
+using ConvertIndexMaxS = IndexOneToOnePattern<MaxSOp, LLVM::SMaxOp>;
+using ConvertIndexMaxU = IndexOneToOnePattern<MaxUOp, LLVM::UMaxOp>;
+
+// Boolean constants.
+using ConvertIndexBoolConstant =
+    IndexOneToOnePattern<BoolConstantOp, LLVM::ConstantOp>;
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Adds one conversion pattern per supported `index` op to `patterns`. Every
+/// pattern is constructed with the same `typeConverter`, which decides the
+/// integer type that `index` values become.
+void index::populateIndexToLLVMConversionPatterns(
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<
+      // One-to-one mappings.
+      ConvertIndexAdd, ConvertIndexSub, ConvertIndexMul,
+      ConvertIndexDivS, ConvertIndexDivU,
+      ConvertIndexRemS, ConvertIndexRemU,
+      ConvertIndexMaxS, ConvertIndexMaxU,
+      // Multi-op expansions of the rounding divisions.
+      ConvertIndexCeilDivS, ConvertIndexCeilDivU, ConvertIndexFloorDivS,
+      // Casts, comparison, sizeof, and constants.
+      ConvertIndexCastS, ConvertIndexCastU,
+      ConvertIndexCmp,
+      ConvertIndexSizeOf,
+      ConvertIndexConstant, ConvertIndexBoolConstant>(typeConverter);
+  // clang-format on
+}
+
+//===----------------------------------------------------------------------===//
+// ODS-Generated Definitions
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Driver for `convert-index-to-llvm`: rewrites the `index` dialect ops in the
+/// operation the pass runs on into LLVM dialect ops.
+struct ConvertIndexToLLVMPass
+    : public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertIndexToLLVMPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+
+  // Every index op must be rewritten; LLVM dialect ops are the goal.
+  ConversionTarget target(*context);
+  target.addIllegalDialect<IndexDialect>();
+  target.addLegalDialect<LLVM::LLVMDialect>();
+
+  // The option's default (kDeriveIndexBitwidthFromDataLayout, i.e. 0) keeps
+  // the data-layout-derived width; any other value overrides it.
+  LowerToLLVMOptions loweringOptions(context);
+  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+    loweringOptions.overrideIndexBitwidth(indexBitwidth);
+  LLVMTypeConverter converter(context, loweringOptions);
+
+  RewritePatternSet patterns(context);
+  populateIndexToLLVMConversionPatterns(converter, patterns);
+
+  LogicalResult result =
+      applyPartialConversion(getOperation(), target, std::move(patterns));
+  if (failed(result))
+    signalPassFailure();
+}

diff  --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
new file mode 100644
index 0000000000000..ee8e6629aa719
--- /dev/null
+++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt %s -convert-index-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=32 | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=64 | FileCheck %s --check-prefix=INDEX64
+
+// The default-prefix expectations run with the index bitwidth left at its
+// default; the INDEX32 and INDEX64 expectations pin the behavior for an
+// explicit 32- or 64-bit index type.
+
+// Ops with a direct LLVM dialect counterpart lower to a single op each.
+// CHECK-LABEL: @trivial_ops
+func.func @trivial_ops(%a: index, %b: index) {
+  // CHECK: llvm.add
+  %0 = index.add %a, %b
+  // CHECK: llvm.sub
+  %1 = index.sub %a, %b
+  // CHECK: llvm.mul
+  %2 = index.mul %a, %b
+  // CHECK: llvm.sdiv
+  %3 = index.divs %a, %b
+  // CHECK: llvm.udiv
+  %4 = index.divu %a, %b
+  // CHECK: llvm.srem
+  %5 = index.rems %a, %b
+  // CHECK: llvm.urem
+  %6 = index.remu %a, %b
+  // CHECK: llvm.intr.smax
+  %7 = index.maxs %a, %b
+  // CHECK: llvm.intr.umax
+  %8 = index.maxu %a, %b
+  // CHECK: llvm.mlir.constant(true
+  %9 = index.bool.constant true
+  return
+}
+
+// Signed ceiling division: both candidate quotients are computed and the final
+// llvm.select picks one based on the signs of `n` and `m`.
+// CHECK-LABEL: @ceildivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
+  // CHECK: %[[POS_ONE:.*]] = llvm.mlir.constant(1 :
+  // CHECK: %[[NEG_ONE:.*]] = llvm.mlir.constant(-1 :
+
+  // CHECK: %[[M_POS:.*]] = llvm.icmp "sgt" %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = llvm.select %[[M_POS]], %[[NEG_ONE]], %[[POS_ONE]]
+
+  // CHECK: %[[N_PLUS_X:.*]] = llvm.add %[[N]], %[[X]]
+  // CHECK: %[[N_PLUS_X_DIV_M:.*]] = llvm.sdiv %[[N_PLUS_X]], %[[M]]
+  // CHECK: %[[POS_RES:.*]] = llvm.add %[[N_PLUS_X_DIV_M]], %[[POS_ONE]]
+
+  // CHECK: %[[NEG_N:.*]] = llvm.sub %[[ZERO]], %[[N]]
+  // CHECK: %[[NEG_N_DIV_M:.*]] = llvm.sdiv %[[NEG_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = llvm.sub %[[ZERO]], %[[NEG_N_DIV_M]]
+
+  // CHECK: %[[N_POS:.*]] = llvm.icmp "sgt" %[[N]], %[[ZERO]]
+  // CHECK: %[[SAME_SIGN:.*]] = llvm.icmp "eq" %[[N_POS]], %[[M_POS]]
+  // CHECK: %[[N_NON_ZERO:.*]] = llvm.icmp "ne" %[[N]], %[[ZERO]]
+  // CHECK: %[[CMP:.*]] = llvm.and %[[SAME_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[POS_RES]], %[[NEG_RES]]
+  %result = index.ceildivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// Unsigned ceiling division: `(n - 1) / m + 1`, with 0 selected when `n == 0`.
+// CHECK-LABEL: @ceildivu
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivu(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 :
+
+  // CHECK: %[[MINUS_ONE:.*]] = llvm.sub %[[N]], %[[ONE]]
+  // CHECK: %[[QUOTIENT:.*]] = llvm.udiv %[[MINUS_ONE]], %[[M]]
+  // CHECK: %[[PLUS_ONE:.*]] = llvm.add %[[QUOTIENT]], %[[ONE]]
+
+  // CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[N]], %[[ZERO]]
+  // CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[ZERO]], %[[PLUS_ONE]]
+  %result = index.ceildivu %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// Signed floor division: both candidate quotients are computed and the final
+// llvm.select picks one based on the signs of `n` and `m`.
+// CHECK-LABEL: @floordivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @floordivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
+  // CHECK: %[[POS_ONE:.*]] = llvm.mlir.constant(1 :
+  // CHECK: %[[NEG_ONE:.*]] = llvm.mlir.constant(-1 :
+
+  // CHECK: %[[M_NEG:.*]] = llvm.icmp "slt" %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = llvm.select %[[M_NEG]], %[[POS_ONE]], %[[NEG_ONE]]
+
+  // CHECK: %[[X_MINUS_N:.*]] = llvm.sub %[[X]], %[[N]]
+  // CHECK: %[[X_MINUS_N_DIV_M:.*]] = llvm.sdiv %[[X_MINUS_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = llvm.sub %[[NEG_ONE]], %[[X_MINUS_N_DIV_M]]
+
+  // CHECK: %[[POS_RES:.*]] = llvm.sdiv %[[N]], %[[M]]
+
+  // CHECK: %[[N_NEG:.*]] = llvm.icmp "slt" %[[N]], %[[ZERO]]
+  // CHECK: %[[DIFF_SIGN:.*]] = llvm.icmp "ne" %[[N_NEG]], %[[M_NEG]]
+  // CHECK: %[[N_NON_ZERO:.*]] = llvm.icmp "ne" %[[N]], %[[ZERO]]
+  // CHECK: %[[CMP:.*]] = llvm.and %[[DIFF_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[NEG_RES]], %[[POS_RES]]
+  %result = index.floordivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// Casts out of `index`. With a 32-bit index, widening to i64 extends (sext for
+// casts, zext for castu) and casting to i32 folds away; with a 64-bit index,
+// narrowing to i32 truncates and casting to i64 folds away.
+// INDEX32-LABEL: @index_cast_from
+// INDEX64-LABEL: @index_cast_from
+// INDEX32-SAME: %[[AI:.*]]: index
+// INDEX64-SAME: %[[AI:.*]]: index
+func.func @index_cast_from(%a: index) -> (i64, i32, i64, i32) {
+  // INDEX32: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i32
+  // INDEX64: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i64
+
+  // INDEX32: %[[V0:.*]] = llvm.sext %[[A]] : i32 to i64
+  %0 = index.casts %a : index to i64
+  // INDEX64: %[[V1:.*]] = llvm.trunc %[[A]] : i64 to i32
+  %1 = index.casts %a : index to i32
+  // INDEX32: %[[V2:.*]] = llvm.zext %[[A]] : i32 to i64
+  %2 = index.castu %a : index to i64
+  // INDEX64: %[[V3:.*]] = llvm.trunc %[[A]] : i64 to i32
+  %3 = index.castu %a : index to i32
+
+  // INDEX32: return %[[V0]], %[[A]], %[[V2]], %[[A]]
+  // INDEX64: return %[[A]], %[[V1]], %[[A]], %[[V3]]
+  return %0, %1, %2, %3 : i64, i32, i64, i32
+}
+
+// Casts into `index`: the mirror image of the test above. i32 is extended for
+// a 64-bit index and i64 is truncated for a 32-bit index.
+// INDEX32-LABEL: @index_cast_to
+// INDEX64-LABEL: @index_cast_to
+// INDEX32-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+// INDEX64-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+func.func @index_cast_to(%a: i32, %b: i64) -> (index, index, index, index) {
+  // INDEX64: %[[V0:.*]] = llvm.sext %[[A]] : i32 to i64
+  %0 = index.casts %a : i32 to index
+  // INDEX32: %[[V1:.*]] = llvm.trunc %[[B]] : i64 to i32
+  %1 = index.casts %b : i64 to index
+  // INDEX64: %[[V2:.*]] = llvm.zext %[[A]] : i32 to i64
+  %2 = index.castu %a : i32 to index
+  // INDEX32: %[[V3:.*]] = llvm.trunc %[[B]] : i64 to i32
+  %3 = index.castu %b : i64 to index
+  return %0, %1, %2, %3 : index, index, index, index
+}
+
+// `index.sizeof` becomes a constant equal to the configured index bitwidth.
+// INDEX32-LABEL: @index_sizeof
+// INDEX64-LABEL: @index_sizeof
+func.func @index_sizeof() {
+  // INDEX32-NEXT: llvm.mlir.constant(32 : i32)
+  // INDEX64-NEXT: llvm.mlir.constant(64 : i64)
+  %0 = index.sizeof
+  return
+}
+
+// Constants are truncated to the index bitwidth. Values that fit in 32 bits
+// are unchanged; with a 32-bit index, -3000000000 wraps (mod 2^32) to
+// 1294967296 and 3000000000 wraps to -1294967296.
+// INDEX32-LABEL: @index_constant
+// INDEX64-LABEL: @index_constant
+func.func @index_constant() {
+  // INDEX32: llvm.mlir.constant(-2100000000 : i32) : i32
+  // INDEX64: llvm.mlir.constant(-2100000000 : i64) : i64
+  %0 = index.constant -2100000000
+  // INDEX32: llvm.mlir.constant(2100000000 : i32) : i32
+  // INDEX64: llvm.mlir.constant(2100000000 : i64) : i64
+  %1 = index.constant 2100000000
+  // INDEX32: llvm.mlir.constant(1294967296 : i32) : i32
+  // INDEX64: llvm.mlir.constant(-3000000000 : i64) : i64
+  %2 = index.constant -3000000000
+  // INDEX32: llvm.mlir.constant(-1294967296 : i32) : i32
+  // INDEX64: llvm.mlir.constant(3000000000 : i64) : i64
+  %3 = index.constant 3000000000
+  return
+}


        


More information about the Mlir-commits mailing list