[Mlir-commits] [mlir] [mlir][index][spirv] Add conversion for index to spirv (PR #68085)

Finn Plummer llvmlistbot at llvm.org
Tue Oct 3 03:25:27 PDT 2023


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/68085

Lowering from SCF to SPIR-V ran into an issue because there was no conversion from the Index dialect to SPIR-V. This patch adds a conversion pass from the Index dialect to the SPIR-V dialect, and also adds the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713

>From b917dd0ed98eccbaf99d7fda0cb763451f80aecf Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sun, 27 Aug 2023 16:42:25 +0200
Subject: [PATCH] [mlir][index][spirv] Add conversion for index to spirv

Lowering from SCF to SPIR-V ran into an issue because there was no
conversion from the Index dialect to SPIR-V. This patch adds a
conversion pass from the Index dialect to the SPIR-V dialect, and
also adds the new conversion patterns to the scf-to-spirv
conversion.

Fixes #63713
---
 .../Conversion/IndexToSPIRV/IndexToSPIRV.h    |  28 ++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  22 +
 .../SPIRV/Transforms/SPIRVConversion.h        |  18 +-
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../Conversion/IndexToSPIRV/CMakeLists.txt    |  15 +
 .../Conversion/IndexToSPIRV/IndexToSPIRV.cpp  | 418 ++++++++++++++++++
 .../Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp  |   2 +
 .../SPIRV/Transforms/SPIRVConversion.cpp      |   2 +
 .../IndexToSPRIV/index-to-spirv.mlir          | 218 +++++++++
 .../Conversion/SCFToSPIRV/use-indices.mlir    |  28 ++
 11 files changed, 750 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
 create mode 100644 mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
 create mode 100644 mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
 create mode 100644 mlir/test/Conversion/SCFToSPIRV/use-indices.mlir

diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
new file mode 100644
index 000000000000000..d1a3c87249508b7
--- /dev/null
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -0,0 +1,28 @@
+//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class SPIRVTypeConverter;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace index {
+/// Appends to `patterns` the conversion patterns that lower Index dialect
+/// operations to SPIR-V dialect operations. `typeConverter` supplies the
+/// SPIR-V integer type that `index` is converted to (i32 or i64).
+void populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                  RewritePatternSet &patterns);
+} // namespace index
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index fc5e9adba114405..9660d89ec23e3be 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 11008baa0160efe..1e45abb66880c12 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -629,6 +629,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertIndexToSPIRVPass
+//===----------------------------------------------------------------------===//
+
+def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
+  let summary = "Lower the `index` dialect to the `spirv` dialect.";
+  let description = [{
+    This pass lowers Index dialect operations to SPIR-V dialect operations.
+    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
+    `ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
+    specified by use-64bit-index.
+  }];
+
+  let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
+
+  // The pass copies this option into SPIRVConversionOptions::use64bitIndex;
+  // when false (the default) `index` is converted to a 32-bit integer.
+  let options = [
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgToStandard
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 89ded981d38f9f4..4a4e58464a80df7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -55,13 +55,20 @@ struct SPIRVConversionOptions {
   /// values will be packed into one 32-bit value to be memory efficient.
   bool emulateLT32BitScalarTypes{true};
 
-  /// Use 64-bit integers to convert index types.
-  bool use64bitIndex{false};
-
   /// Whether to enable fast math mode during conversion. If true, various
   /// patterns would assume no NaN/infinity numbers as inputs, and thus there
   /// will be no special guards emitted to check and handle such cases.
   bool enableFastMathMode{false};
+
+  /// Use 64-bit integers when converting index types.
+  bool use64bitIndex{false};
+
+  /// Whether the SPIR-V type converter should send integer types through its
+  /// scalar-type conversion path, which consults the target environment and
+  /// the options above (e.g. emulateLT32BitScalarTypes). When false, the
+  /// converter returns integer types unchanged. The convert-index-to-spirv
+  /// pass turns this off; there `index` itself is converted to i32 or i64
+  /// according to use64bitIndex.
+  bool convertIntAsScalar{true};
 };
 
 /// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +84,11 @@ class SPIRVTypeConverter : public TypeConverter {
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
+  /// Returns the width in bits used for the converted `index` type: 64 when
+  /// `use64bitIndex` is set in the options, 32 otherwise.
+  unsigned getIndexTypeBitwidth() const {
+    if (options.use64bitIndex)
+      return 64;
+    return 32;
+  }
+
   const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
 
   /// Returns the options controlling the SPIR-V type converter.
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 275e095245e89ce..8dad4c5fa25916a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -23,6 +23,7 @@ add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
+add_subdirectory(IndexToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
diff --git a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000000..1da0e0253501fec
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRIndexToSPIRV
+  IndexToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # IndexToSPIRV.cpp directly uses the SPIR-V dialect, the SPIR-V type
+  # converter / conversion target, the dialect conversion driver and the
+  # pass infrastructure, so link them explicitly (needed for shared builds).
+  LINK_LIBS PUBLIC
+  MLIRIndexDialect
+  MLIRPass
+  MLIRSPIRVConversion
+  MLIRSPIRVDialect
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
new file mode 100644
index 000000000000000..4290a81ae813824
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -0,0 +1,418 @@
+//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace index;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Trivial Conversions
+//===----------------------------------------------------------------------===//
+
+// Index ops that map one-to-one onto a SPIR-V op taking the same operands.
+using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
+using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
+using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
+using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
+using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
+using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
+using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
+using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
+using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
+using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
+using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
+
+using ConvertIndexShl =
+    spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
+using ConvertIndexShrS =
+    spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
+using ConvertIndexShrU =
+    spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
+
+/// SPIR-V distinguishes bitwise operations on booleans (`spirv.Logical*`) from
+/// those on integers (`spirv.Bitwise*`). `index.and`, `index.or` and
+/// `index.xor` always operate on `index` values, which are converted to
+/// integers, so they map directly onto Bitwise[And|Or|Xor]Op.
+using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
+using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
+using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertConstantBool
+//===----------------------------------------------------------------------===//
+
+// Converts index.bool.constant to a spirv.Constant of the same (boolean)
+// result type, reusing the op's typed `value` attribute.
+struct ConvertIndexConstantBoolOpPattern final
+    : OpConversionPattern<BoolConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use the ODS accessor rather than looking the attribute up by name.
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
+                                                   op.getValueAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+
+// Lowers index.constant to a spirv.Constant of the converted index type. The
+// stored value is truncated to the converted index bitwidth, so a 64-bit
+// value becomes 32 bits when the index type is i32.
+struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto *converter = getTypeConverter<SPIRVTypeConverter>();
+    Type targetType = converter->getIndexType();
+    unsigned width = converter->getIndexTypeBitwidth();
+    IntegerAttr valueAttr =
+        IntegerAttr::get(targetType, op.getValue().trunc(width));
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, targetType, valueAttr);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivS
+//===----------------------------------------------------------------------===//
+
+/// Lowers `ceildivs(n, m)`. With `x = m > 0 ? -1 : 1`, the result is
+/// `(n + x) / m + 1` when `n` and `m` have the same sign and `n != 0`, and
+/// `-(-n / m)` otherwise.
+struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    Type intTy = lhs.getType();
+
+    // Materializes an integer constant of the operand type.
+    auto makeConst = [&](int64_t v) -> Value {
+      return rewriter.create<spirv::ConstantOp>(loc, intTy,
+                                                IntegerAttr::get(intTy, v));
+    };
+    Value zero = makeConst(0);
+    Value one = makeConst(1);
+    Value minusOne = makeConst(-1);
+
+    // x = m > 0 ? -1 : 1
+    Value rhsIsPos = rewriter.create<spirv::SGreaterThanOp>(loc, rhs, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, rhsIsPos, minusOne, one);
+
+    // Same-sign case: (n + x) / m + 1.
+    Value adjusted = rewriter.create<spirv::IAddOp>(loc, lhs, x);
+    Value adjustedQuot = rewriter.create<spirv::SDivOp>(loc, adjusted, rhs);
+    Value sameSignRes = rewriter.create<spirv::IAddOp>(loc, adjustedQuot, one);
+
+    // Otherwise: -(-n / m).
+    Value negLhs = rewriter.create<spirv::ISubOp>(loc, zero, lhs);
+    Value negQuot = rewriter.create<spirv::SDivOp>(loc, negLhs, rhs);
+    Value otherRes = rewriter.create<spirv::ISubOp>(loc, zero, negQuot);
+
+    // Condition: (n > 0) == (m > 0) && n != 0.
+    Value lhsIsPos = rewriter.create<spirv::SGreaterThanOp>(loc, lhs, zero);
+    Value signsMatch =
+        rewriter.create<spirv::LogicalEqualOp>(loc, lhsIsPos, rhsIsPos);
+    Value lhsNonZero = rewriter.create<spirv::INotEqualOp>(loc, lhs, zero);
+    Value useSameSign =
+        rewriter.create<spirv::LogicalAndOp>(loc, signsMatch, lhsNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, useSameSign, sameSignRes,
+                                                 otherRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivU
+//===----------------------------------------------------------------------===//
+
+/// Lowers `ceildivu(n, m)` to `n == 0 ? 0 : (n - 1) / m + 1`. The `n == 0`
+/// case is selected separately because `n - 1` would wrap around there.
+struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value dividend = adaptor.getLhs();
+    Value divisor = adaptor.getRhs();
+    Type intTy = dividend.getType();
+
+    // Materializes an integer constant of the operand type.
+    auto makeConst = [&](int64_t v) -> Value {
+      return rewriter.create<spirv::ConstantOp>(loc, intTy,
+                                                IntegerAttr::get(intTy, v));
+    };
+    Value zero = makeConst(0);
+    Value one = makeConst(1);
+
+    // (n - 1) / m + 1, valid whenever n != 0.
+    Value decremented = rewriter.create<spirv::ISubOp>(loc, dividend, one);
+    Value quot = rewriter.create<spirv::UDivOp>(loc, decremented, divisor);
+    Value rounded = rewriter.create<spirv::IAddOp>(loc, quot, one);
+
+    Value isZero = rewriter.create<spirv::IEqualOp>(loc, dividend, zero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, isZero, zero, rounded);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexFloorDivS
+//===----------------------------------------------------------------------===//
+
+/// Lowers `floordivs(n, m)`. With `x = m < 0 ? 1 : -1`, the result is
+/// `-1 - (x - n) / m` when `n` and `m` have different signs and `n != 0`
+/// (where truncating `SDiv` would round toward zero instead of down), and
+/// `n / m` otherwise. Example: floordivs(-7, 2): x = -1, -1 - (6 / 2) = -4.
+struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type nType = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants.
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, nType, IntegerAttr::get(nType, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, nType, IntegerAttr::get(nType, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, nType, IntegerAttr::get(nType, -1));
+
+    // Compute `x`.
+    Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
+
+    // Compute the negative result: `-1 - (x - n) / m`.
+    Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
+    Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
+
+    // Compute the positive result: `n / m`.
+    Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
+
+    // Pick the negative result if `n` and `m` have different signs and `n` is
+    // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
+    Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
+    Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
+    // `cmp` is true exactly when the floor-corrected (negative) result is
+    // needed, so it must be the "true" operand of the select.
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, negRes, posRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Converts `index.casts` / `index.castu` to `ConvertOp` (`spirv.SConvert` /
+/// `spirv.UConvert`). The result type is mapped through the SPIR-V type
+/// converter, so `index` becomes the converted index type and other integer
+/// types follow the converter's integer rules. If the source and converted
+/// result types are identical, the cast is folded away. Signed casts sign
+/// extend and unsigned casts zero extend when the result is wider.
+template <typename CastOp, typename ConvertOp>
+struct ConvertIndexCast final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto *typeConverter =
+        this->template getTypeConverter<SPIRVTypeConverter>();
+
+    Type srcType = adaptor.getInput().getType();
+    if (isa<IndexType>(srcType))
+      srcType = typeConverter->getIndexType();
+
+    // Convert the result type rather than using it verbatim, so that a type
+    // converter which remaps integer widths yields a consistent result type.
+    Type dstType = typeConverter->convertType(op.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(op, "unsupported result type");
+
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, adaptor.getInput());
+      return success();
+    }
+    rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
+                                                    adaptor.getOperands());
+    return success();
+  }
+};
+
+using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
+using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCmp
+//===----------------------------------------------------------------------===//
+
+// Replaces `op` with the SPIR-V integer comparison `ICmpOp` applied to the
+// converted operands.
+template <typename ICmpOp>
+static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
+  return success();
+}
+
+/// Converts `index.cmp` to the SPIR-V integer comparison matching its
+/// predicate.
+struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // We must convert the predicates to the corresponding int comparisons.
+    switch (op.getPred()) {
+    case IndexCmpPredicate::EQ:
+      return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::NE:
+      return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGE:
+      return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGT:
+      return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLE:
+      return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLT:
+      return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGE:
+      return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGT:
+      return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULE:
+      return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULT:
+      return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
+    }
+    // Every IndexCmpPredicate is handled above; this avoids a "control reaches
+    // end of non-void function" warning.
+    llvm_unreachable("unknown IndexCmpPredicate");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexSizeOf
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.sizeof` to a spirv.Constant, of the converted index type,
+/// holding the converted index bitwidth (32 or 64).
+struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto *converter = getTypeConverter<SPIRVTypeConverter>();
+    Type resultType = converter->getIndexType();
+    auto widthAttr =
+        IntegerAttr::get(resultType, converter->getIndexTypeBitwidth());
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, resultType, widthAttr);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Registers every Index-to-SPIR-V conversion pattern defined above, each
+/// constructed with `typeConverter`.
+void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                         RewritePatternSet &patterns) {
+  // Keep one pattern per line; re-enable formatting afterwards so the rest of
+  // the file is still formatted.
+  // clang-format off
+  patterns.add<
+    ConvertIndexAdd,
+    ConvertIndexSub,
+    ConvertIndexMul,
+    ConvertIndexDivS,
+    ConvertIndexDivU,
+    ConvertIndexRemS,
+    ConvertIndexRemU,
+    ConvertIndexMaxS,
+    ConvertIndexMaxU,
+    ConvertIndexMinS,
+    ConvertIndexMinU,
+    ConvertIndexShl,
+    ConvertIndexShrS,
+    ConvertIndexShrU,
+    ConvertIndexAnd,
+    ConvertIndexOr,
+    ConvertIndexXor,
+    ConvertIndexConstantBoolOpPattern,
+    ConvertIndexConstantOpPattern,
+    ConvertIndexCeilDivSPattern,
+    ConvertIndexCeilDivUPattern,
+    ConvertIndexFloorDivSPattern,
+    ConvertIndexCastS,
+    ConvertIndexCastU,
+    ConvertIndexCmpPattern,
+    ConvertIndexSizeOf
+  >(typeConverter, patterns.getContext());
+  // clang-format on
+}
+
+//===----------------------------------------------------------------------===//
+// ODS-Generated Definitions
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Standalone pass that applies populateIndexToSPIRVPatterns to the root op.
+struct ConvertIndexToSPIRVPass
+    : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertIndexToSPIRVPass::runOnOperation() {
+  Operation *root = getOperation();
+  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(root);
+
+  // `index` becomes i32 or i64 per the pass option; with convertIntAsScalar
+  // disabled, the converter returns other integer types unchanged.
+  SPIRVConversionOptions options;
+  options.use64bitIndex = use64bitIndex;
+  options.convertIntAsScalar = false;
+  SPIRVTypeConverter typeConverter(targetEnv, options);
+
+  std::unique_ptr<SPIRVConversionTarget> target =
+      SPIRVConversionTarget::get(targetEnv);
+  // UnrealizedConversionCast bridges to values of other dialects, so their
+  // conversion patterns need not be pulled in.
+  target->addLegalOp<UnrealizedConversionCastOp>();
+  // SPIR-V ops produced by the patterns are legal; any leftover index op
+  // makes the conversion fail.
+  target->addLegalDialect<spirv::SPIRVDialect>();
+  target->addIllegalDialect<index::IndexDialect>();
+
+  RewritePatternSet patterns(&getContext());
+  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+
+  if (failed(applyPartialConversion(root, *target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 1e8fe4423a422c5..08ffe2980345c43 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -52,6 +53,7 @@ void SCFToSPIRVPass::runOnOperation() {
   populateFuncToSPIRVPatterns(typeConverter, patterns);
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+  mlir::index::populateIndexToSPIRVPatterns(typeConverter, patterns);
 
   if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..122a0a3612fb2c1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -693,6 +693,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
 
   addConversion([this](IntegerType intType) -> std::optional<Type> {
+    if (!this->options.convertIntAsScalar)
+      return intType;
     if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
       return convertScalarType(this->targetEnv, this->options, scalarType);
     if (intType.getWidth() < 8)
diff --git a/mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir b/mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir
new file mode 100644
index 000000000000000..5e6755100b03e41
--- /dev/null
+++ b/mlir/test/Conversion/IndexToSPIRV/index-to-spirv.mlir
@@ -0,0 +1,218 @@
+// RUN: mlir-opt %s -convert-index-to-spirv | FileCheck %s
+// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=false | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=true | FileCheck %s --check-prefix=INDEX64
+
+// CHECK-LABEL: @trivial_ops
+func.func @trivial_ops(%lhs: index, %rhs: index) {
+  // CHECK: spirv.IAdd
+  %add = index.add %lhs, %rhs
+  // CHECK: spirv.ISub
+  %sub = index.sub %lhs, %rhs
+  // CHECK: spirv.IMul
+  %mul = index.mul %lhs, %rhs
+  // CHECK: spirv.SDiv
+  %divs = index.divs %lhs, %rhs
+  // CHECK: spirv.UDiv
+  %divu = index.divu %lhs, %rhs
+  // CHECK: spirv.SRem
+  %rems = index.rems %lhs, %rhs
+  // CHECK: spirv.UMod
+  %remu = index.remu %lhs, %rhs
+  // CHECK: spirv.GL.SMax
+  %maxs = index.maxs %lhs, %rhs
+  // CHECK: spirv.GL.UMax
+  %maxu = index.maxu %lhs, %rhs
+  // CHECK: spirv.GL.SMin
+  %mins = index.mins %lhs, %rhs
+  // CHECK: spirv.GL.UMin
+  %minu = index.minu %lhs, %rhs
+  // CHECK: spirv.ShiftLeftLogical
+  %shl = index.shl %lhs, %rhs
+  // CHECK: spirv.ShiftRightArithmetic
+  %shrs = index.shrs %lhs, %rhs
+  // CHECK: spirv.ShiftRightLogical
+  %shru = index.shru %lhs, %rhs
+  return
+}
+
+// CHECK-LABEL: @bitwise_ops
+func.func @bitwise_ops(%lhs: index, %rhs: index) {
+  // CHECK: spirv.BitwiseAnd
+  %and = index.and %lhs, %rhs
+  // CHECK: spirv.BitwiseOr
+  %or = index.or %lhs, %rhs
+  // CHECK: spirv.BitwiseXor
+  %xor = index.xor %lhs, %rhs
+  return
+}
+
+// INDEX32-LABEL: @constant_ops
+// INDEX64-LABEL: @constant_ops
+func.func @constant_ops() {
+  // INDEX32: spirv.Constant 42 : i32
+  // INDEX64: spirv.Constant 42 : i64
+  %c42 = index.constant 42
+  // INDEX32: spirv.Constant true
+  // INDEX64: spirv.Constant true
+  %true = index.bool.constant true
+  // INDEX32: spirv.Constant false
+  // INDEX64: spirv.Constant false
+  %false = index.bool.constant false
+  return
+}
+
+// CHECK-LABEL: @ceildivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
+  // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
+
+  // CHECK: %[[M_POS:.*]] = spirv.SGreaterThan %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = spirv.Select %[[M_POS]], %[[NEG_ONE]], %[[POS_ONE]]
+  // Positive result: (n + x) / m + 1, where x = (m > 0) ? -1 : 1.
+  // CHECK: %[[N_PLUS_X:.*]] = spirv.IAdd %[[N]], %[[X]]
+  // CHECK: %[[N_PLUS_X_DIV_M:.*]] = spirv.SDiv %[[N_PLUS_X]], %[[M]]
+  // CHECK: %[[POS_RES:.*]] = spirv.IAdd %[[N_PLUS_X_DIV_M]], %[[POS_ONE]]
+  // Negative result: -(-n / m).
+  // CHECK: %[[NEG_N:.*]] = spirv.ISub %[[ZERO]], %[[N]]
+  // CHECK: %[[NEG_N_DIV_M:.*]] = spirv.SDiv %[[NEG_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[ZERO]], %[[NEG_N_DIV_M]]
+  // Take POS_RES when (n > 0) == (m > 0) and n != 0, NEG_RES otherwise.
+  // CHECK: %[[N_POS:.*]] = spirv.SGreaterThan %[[N]], %[[ZERO]]
+  // CHECK: %[[SAME_SIGN:.*]] = spirv.LogicalEqual %[[N_POS]], %[[M_POS]]
+  // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]]
+  // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[SAME_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[POS_RES]], %[[NEG_RES]]
+  %result = index.ceildivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @ceildivu
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivu(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[ONE:.*]] = spirv.Constant 1
+  // Result: n == 0 ? 0 : (n - 1) / m + 1 (unsigned division).
+  // CHECK: %[[N_MINUS_ONE:.*]] = spirv.ISub %[[N]], %[[ONE]]
+  // CHECK: %[[N_MINUS_ONE_DIV_M:.*]] = spirv.UDiv %[[N_MINUS_ONE]], %[[M]]
+  // CHECK: %[[N_MINUS_ONE_DIV_M_PLUS_ONE:.*]] = spirv.IAdd %[[N_MINUS_ONE_DIV_M]], %[[ONE]]
+
+  // CHECK: %[[CMP:.*]] = spirv.IEqual %[[N]], %[[ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[ZERO]], %[[N_MINUS_ONE_DIV_M_PLUS_ONE]]
+  %result = index.ceildivu %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @floordivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @floordivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
+  // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
+
+  // CHECK: %[[M_NEG:.*]] = spirv.SLessThan %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = spirv.Select %[[M_NEG]], %[[POS_ONE]], %[[NEG_ONE]]
+  // Negative result: -1 - (x - n) / m, where x = (m < 0) ? 1 : -1.
+  // CHECK: %[[X_MINUS_N:.*]] = spirv.ISub %[[X]], %[[N]]
+  // CHECK: %[[X_MINUS_N_DIV_M:.*]] = spirv.SDiv %[[X_MINUS_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[NEG_ONE]], %[[X_MINUS_N_DIV_M]]
+  // Positive result: n / m.
+  // CHECK: %[[POS_RES:.*]] = spirv.SDiv %[[N]], %[[M]]
+
+  // CHECK: %[[N_NEG:.*]] = spirv.SLessThan %[[N]], %[[ZERO]]
+  // CHECK: %[[DIFF_SIGN:.*]] = spirv.LogicalNotEqual %[[N_NEG]], %[[M_NEG]]
+  // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]]
+  // Signs differ and n != 0: floor rounds toward -inf, so NEG_RES is selected.
+  // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[DIFF_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[NEG_RES]], %[[POS_RES]]
+  %result = index.floordivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @index_cmp
+func.func @index_cmp(%lhs: index, %rhs: index) {
+  // CHECK: spirv.IEqual
+  %eq = index.cmp eq(%lhs, %rhs)
+  // CHECK: spirv.INotEqual
+  %ne = index.cmp ne(%lhs, %rhs)
+
+  // CHECK: spirv.SLessThan
+  %slt = index.cmp slt(%lhs, %rhs)
+  // CHECK: spirv.SLessThanEqual
+  %sle = index.cmp sle(%lhs, %rhs)
+  // CHECK: spirv.SGreaterThan
+  %sgt = index.cmp sgt(%lhs, %rhs)
+  // CHECK: spirv.SGreaterThanEqual
+  %sge = index.cmp sge(%lhs, %rhs)
+
+  // CHECK: spirv.ULessThan
+  %ult = index.cmp ult(%lhs, %rhs)
+  // CHECK: spirv.ULessThanEqual
+  %ule = index.cmp ule(%lhs, %rhs)
+  // CHECK: spirv.UGreaterThan
+  %ugt = index.cmp ugt(%lhs, %rhs)
+  // CHECK: spirv.UGreaterThanEqual
+  %uge = index.cmp uge(%lhs, %rhs)
+  return
+}
+
+// CHECK-LABEL: @index_sizeof
+func.func @index_sizeof() {
+  // CHECK: spirv.Constant 32 : i32
+  %bits = index.sizeof
+  return
+}
+
+// INDEX32-LABEL: @index_cast_from
+// INDEX64-LABEL: @index_cast_from
+// INDEX32-SAME: %[[AI:.*]]: index
+// INDEX64-SAME: %[[AI:.*]]: index
+func.func @index_cast_from(%idx: index) -> (i64, i32, i64, i32) {
+  // INDEX32: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i32
+  // INDEX64: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i64
+
+  // INDEX32: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64
+  %s64 = index.casts %idx : index to i64
+  // INDEX64: %[[V1:.*]] = spirv.SConvert %[[A]] : i64 to i32
+  %s32 = index.casts %idx : index to i32
+  // INDEX32: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64
+  %u64 = index.castu %idx : index to i64
+  // INDEX64: %[[V3:.*]] = spirv.UConvert %[[A]] : i64 to i32
+  %u32 = index.castu %idx : index to i32
+
+  // INDEX32: return %[[V0]], %[[A]], %[[V2]], %[[A]]
+  // INDEX64: return %[[A]], %[[V1]], %[[A]], %[[V3]]
+  return %s64, %s32, %u64, %u32 : i64, i32, i64, i32
+}
+
+// INDEX32-LABEL: @index_cast_to
+// INDEX64-LABEL: @index_cast_to
+// INDEX32-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+// INDEX64-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+func.func @index_cast_to(%narrow: i32, %wide: i64) -> (index, index, index, index) {
+  // INDEX64: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64
+  %s_from32 = index.casts %narrow : i32 to index
+  // INDEX32: %[[V1:.*]] = spirv.SConvert %[[B]] : i64 to i32
+  %s_from64 = index.casts %wide : i64 to index
+  // INDEX64: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64
+  %u_from32 = index.castu %narrow : i32 to index
+  // INDEX32: %[[V3:.*]] = spirv.UConvert %[[B]] : i64 to i32
+  %u_from64 = index.castu %wide : i64 to index
+  return %s_from32, %s_from64, %u_from32, %u_from64 : index, index, index, index
+}
diff --git a/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir
new file mode 100644
index 000000000000000..68a825fbd93ebde
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -convert-scf-to-spirv | FileCheck %s
+
+// CHECK-LABEL: @forward
+func.func @forward() {
+  // CHECK: %[[LB:.*]] = spirv.Constant 0 : i32
+  %lb = arith.constant 0 : index
+  // CHECK: %[[UB:.*]] = spirv.Constant 32 : i32
+  %ub = arith.constant 32 : index
+  // CHECK: %[[STEP:.*]] = spirv.Constant 1 : i32
+  %step = arith.constant 1 : index
+
+  // CHECK:      spirv.mlir.loop {
+  // CHECK-NEXT:   spirv.Branch ^[[HEADER:.*]](%[[LB]] : i32)
+  // CHECK:      ^[[HEADER]](%[[INDVAR:.*]]: i32):
+  // CHECK:        %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32
+  // CHECK:        spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+  // CHECK:      ^[[BODY]]:
+  // CHECK:        %[[X:.*]] = spirv.IAdd %[[INDVAR]], %[[INDVAR]] : i32
+  // CHECK:        %[[INDNEXT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
+  // CHECK:        spirv.Branch ^[[HEADER]](%[[INDNEXT]] : i32)
+  // CHECK:      ^[[MERGE]]:
+  // CHECK:        spirv.mlir.merge
+  // CHECK:      }
+  scf.for %iv = %lb to %ub step %step {
+    %twice = index.add %iv, %iv
+  }
+  return
+}



More information about the Mlir-commits mailing list