[Mlir-commits] [mlir] Inbelic/conv index to spirv (PR #69790)

Finn Plummer llvmlistbot at llvm.org
Fri Oct 20 15:30:50 PDT 2023


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/69790

Lowering from SCF to SPIR-V failed on Index dialect operations because there was no Index-to-SPIR-V conversion. This patch adds a conversion pass from the Index dialect to the SPIR-V dialect and also adds the new conversion patterns to the scf-to-spirv conversion.

Fixes https://github.com/llvm/llvm-project/issues/63713

>From 25925c1d29ac95cfd8829edf0474c779300f0cc2 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sun, 27 Aug 2023 16:42:25 +0200
Subject: [PATCH 1/2] [mlir][index][spirv] Add conversion for index to spirv

Lowering from SCF to SPIR-V failed on Index dialect operations
because there was no Index-to-SPIR-V conversion. This patch adds a
conversion pass from the Index dialect to the SPIR-V dialect and also
adds the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713
---
 .../Conversion/IndexToSPIRV/IndexToSPIRV.h    |  30 ++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  22 +
 .../SPIRV/Transforms/SPIRVConversion.h        |  11 +-
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../Conversion/IndexToSPIRV/CMakeLists.txt    |  15 +
 .../Conversion/IndexToSPIRV/IndexToSPIRV.cpp  | 418 ++++++++++++++++++
 .../Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp  |   2 +
 .../IndexToSPRIV/index-to-spirv.mlir          | 222 ++++++++++
 .../Conversion/SCFToSPIRV/use-indices.mlir    |  28 ++
 10 files changed, 747 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
 create mode 100644 mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
 create mode 100644 mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
 create mode 100644 mlir/test/Conversion/SCFToSPIRV/use-indices.mlir

diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
new file mode 100644
index 000000000000000..58a1c5246eef999
--- /dev/null
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -0,0 +1,30 @@
+//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class SPIRVTypeConverter;
+class Pass;
+
+// Pulls in the tablegen-generated declarations for ConvertIndexToSPIRVPass
+// (defined in Passes.td).
+#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace index {
+/// Appends to `patterns` the patterns that lower Index dialect operations to
+/// SPIR-V operations. `converter` determines the SPIR-V type used for `index`.
+void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
+                                  RewritePatternSet &patterns);
+/// Creates a pass that lowers Index dialect operations to SPIR-V.
+/// NOTE(review): make sure this factory is defined in IndexToSPIRV.cpp; it is
+/// distinct from any factory generated from Passes.td.
+std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
+} // namespace index
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index e714f5070f23db8..c13c457fd97492a 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -35,6 +35,7 @@
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..9979faed4251787 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -644,6 +644,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertIndexToSPIRVPass
+//===----------------------------------------------------------------------===//
+
+// Implemented by ConvertIndexToSPIRVPass in
+// lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp. The `use64bitIndex` option is
+// forwarded there to SPIRVConversionOptions::use64bitIndex.
+def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
+  let summary = "Lower the `index` dialect to the `spirv` dialect.";
+  let description = [{
+    This pass lowers Index dialect operations to SPIR-V dialect operations.
+    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
+    `ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
+    specified by use-64bit-index.
+  }];
+
+  let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
+
+  let options = [
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgToStandard
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 89ded981d38f9f4..933d62e35fce8cd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -55,13 +55,13 @@ struct SPIRVConversionOptions {
   /// values will be packed into one 32-bit value to be memory efficient.
   bool emulateLT32BitScalarTypes{true};
 
-  /// Use 64-bit integers to convert index types.
-  bool use64bitIndex{false};
-
   /// Whether to enable fast math mode during conversion. If true, various
   /// patterns would assume no NaN/infinity numbers as inputs, and thus there
   /// will be no special guards emitted to check and handle such cases.
   bool enableFastMathMode{false};
+
+  /// Use 64-bit integers when converting index types.
+  bool use64bitIndex{false};
 };
 
 /// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +77,11 @@ class SPIRVTypeConverter : public TypeConverter {
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
+  /// Returns the bitwidth (64 or 32) that `index` takes once converted to
+  /// SPIR-V, as selected by `SPIRVConversionOptions::use64bitIndex`.
+  unsigned getIndexTypeBitwidth() const {
+    if (options.use64bitIndex)
+      return 64;
+    return 32;
+  }
+
   const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
 
   /// Returns the options controlling the SPIR-V type converter.
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 35790254be137be..7e1c7bcf9a8678a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@ add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
+add_subdirectory(IndexToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
diff --git a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000000..1da0e0253501fec
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Index-to-SPIR-V conversion patterns and the convert-index-to-spirv pass.
+add_mlir_conversion_library(MLIRIndexToSPIRV
+  IndexToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIndexDialect
+  )
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
new file mode 100644
index 000000000000000..b58efc096e2eafb
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -0,0 +1,418 @@
+//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace index;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Trivial Conversions
+//===----------------------------------------------------------------------===//
+
+/// Shorthand for `spirv::ElementwiseOpPattern` (from ../SPIRVCommon/Pattern.h),
+/// which rewrites the Index op `SrcOp` into the SPIR-V op `DstOp`.
+template <typename SrcOp, typename DstOp>
+using IndexElementwise = spirv::ElementwiseOpPattern<SrcOp, DstOp>;
+
+// Integer arithmetic.
+using ConvertIndexAdd = IndexElementwise<AddOp, spirv::IAddOp>;
+using ConvertIndexSub = IndexElementwise<SubOp, spirv::ISubOp>;
+using ConvertIndexMul = IndexElementwise<MulOp, spirv::IMulOp>;
+using ConvertIndexDivS = IndexElementwise<DivSOp, spirv::SDivOp>;
+using ConvertIndexDivU = IndexElementwise<DivUOp, spirv::UDivOp>;
+using ConvertIndexRemS = IndexElementwise<RemSOp, spirv::SRemOp>;
+using ConvertIndexRemU = IndexElementwise<RemUOp, spirv::UModOp>;
+
+// Signed/unsigned min and max, lowered to the spirv.GL.* ops.
+using ConvertIndexMaxS = IndexElementwise<MaxSOp, spirv::GLSMaxOp>;
+using ConvertIndexMaxU = IndexElementwise<MaxUOp, spirv::GLUMaxOp>;
+using ConvertIndexMinS = IndexElementwise<MinSOp, spirv::GLSMinOp>;
+using ConvertIndexMinU = IndexElementwise<MinUOp, spirv::GLUMinOp>;
+
+// Shifts.
+using ConvertIndexShl = IndexElementwise<ShlOp, spirv::ShiftLeftLogicalOp>;
+using ConvertIndexShrS =
+    IndexElementwise<ShrSOp, spirv::ShiftRightArithmeticOp>;
+using ConvertIndexShrU =
+    IndexElementwise<ShrUOp, spirv::ShiftRightLogicalOp>;
+
+// Bitwise ops. SPIR-V has separate logical ops (for boolean operands) and
+// bitwise ops (for integer operands). index.and, index.or and index.xor always
+// operate on index-typed integers, never on booleans, so they map directly
+// onto spirv.BitwiseAnd / spirv.BitwiseOr / spirv.BitwiseXor.
+using ConvertIndexAnd = IndexElementwise<AndOp, spirv::BitwiseAndOp>;
+using ConvertIndexOr = IndexElementwise<OrOp, spirv::BitwiseOrOp>;
+using ConvertIndexXor = IndexElementwise<XOrOp, spirv::BitwiseXorOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertConstantBool
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.bool.constant` to a `spirv.Constant` that carries the same
+/// result type and the same boolean value attribute.
+struct ConvertIndexConstantBoolOpPattern final
+    : OpConversionPattern<BoolConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = op.getType();
+    auto valueAttr = op.getValueAttr();
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, resultType, valueAttr);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.constant` to a `spirv.Constant` of the converted index type.
+/// The stored value is truncated to the index bitwidth chosen by the type
+/// converter (i.e. from i64 down to i32 when 32-bit indices are in use).
+struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &converter = *getTypeConverter<SPIRVTypeConverter>();
+    Type spirvIndexType = converter.getIndexType();
+    unsigned width = converter.getIndexTypeBitwidth();
+    IntegerAttr valueAttr =
+        IntegerAttr::get(spirvIndexType, op.getValue().trunc(width));
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, spirvIndexType,
+                                                   valueAttr);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivS
+//===----------------------------------------------------------------------===//
+
+/// Lowers `ceildivs(n, m)`, signed division rounding towards +infinity.
+/// With `x = m > 0 ? -1 : 1`, the result is `(n + x) / m + 1` when `n` and `m`
+/// have the same sign and `n != 0`, and `-(-n / m)` otherwise. The formula
+/// mirrors the equivalent lowering in IndexToLLVM.
+struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = adaptor.getLhs();
+    Type intTy = lhs.getType();
+    Value rhs = adaptor.getRhs();
+
+    // Materializes an integer constant of the operand type.
+    auto makeConst = [&](int64_t v) -> Value {
+      return rewriter.create<spirv::ConstantOp>(loc, intTy,
+                                                IntegerAttr::get(intTy, v));
+    };
+    Value zero = makeConst(0);
+    Value one = makeConst(1);
+    Value minusOne = makeConst(-1);
+
+    // Rounding bias `x`, chosen from the sign of the divisor.
+    Value rhsIsPos = rewriter.create<spirv::SGreaterThanOp>(loc, rhs, zero);
+    Value bias =
+        rewriter.create<spirv::SelectOp>(loc, rhsIsPos, minusOne, one);
+
+    // Same-sign candidate: (n + x) / m + 1.
+    Value biased = rewriter.create<spirv::IAddOp>(loc, lhs, bias);
+    Value biasedQuot = rewriter.create<spirv::SDivOp>(loc, biased, rhs);
+    Value sameSignRes = rewriter.create<spirv::IAddOp>(loc, biasedQuot, one);
+
+    // Other candidate: -(-n / m).
+    Value negLhs = rewriter.create<spirv::ISubOp>(loc, zero, lhs);
+    Value negQuot = rewriter.create<spirv::SDivOp>(loc, negLhs, rhs);
+    Value otherRes = rewriter.create<spirv::ISubOp>(loc, zero, negQuot);
+
+    // Take the same-sign candidate iff `(n > 0) == (m > 0) && n != 0`.
+    Value lhsIsPos = rewriter.create<spirv::SGreaterThanOp>(loc, lhs, zero);
+    Value signsAgree =
+        rewriter.create<spirv::LogicalEqualOp>(loc, lhsIsPos, rhsIsPos);
+    Value lhsNonZero = rewriter.create<spirv::INotEqualOp>(loc, lhs, zero);
+    Value pickSameSign =
+        rewriter.create<spirv::LogicalAndOp>(loc, signsAgree, lhsNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, pickSameSign, sameSignRes,
+                                                 otherRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivU
+//===----------------------------------------------------------------------===//
+
+/// Lowers `ceildivu(n, m)`, unsigned division rounding up, to
+/// `n == 0 ? 0 : (n - 1) / m + 1`. The zero case is selected separately
+/// because `n - 1` would wrap around when `n` is 0. The formula mirrors the
+/// equivalent lowering in IndexToLLVM.
+struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value dividend = adaptor.getLhs();
+    Type intTy = dividend.getType();
+    Value divisor = adaptor.getRhs();
+
+    // Materializes an integer constant of the operand type.
+    auto makeConst = [&](int64_t v) -> Value {
+      return rewriter.create<spirv::ConstantOp>(loc, intTy,
+                                                IntegerAttr::get(intTy, v));
+    };
+    Value zero = makeConst(0);
+    Value one = makeConst(1);
+
+    // (n - 1) / m + 1, valid whenever n != 0.
+    Value decremented = rewriter.create<spirv::ISubOp>(loc, dividend, one);
+    Value quot = rewriter.create<spirv::UDivOp>(loc, decremented, divisor);
+    Value roundedUp = rewriter.create<spirv::IAddOp>(loc, quot, one);
+
+    // Select 0 for a zero dividend, the rounded-up quotient otherwise.
+    Value isZero = rewriter.create<spirv::IEqualOp>(loc, dividend, zero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, isZero, zero, roundedUp);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexFloorDivS
+//===----------------------------------------------------------------------===//
+
+/// Lowers `floordivs(n, m)`, signed division rounding towards -infinity.
+/// With `x = m < 0 ? 1 : -1`, the result is `-1 - (x - n) / m` when `n` and
+/// `m` have different signs and `n != 0`, and `n / m` otherwise. The formula
+/// mirrors the equivalent lowering in IndexToLLVM.
+struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type nType = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants.
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, nType, IntegerAttr::get(nType, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, nType, IntegerAttr::get(nType, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, nType, IntegerAttr::get(nType, -1));
+
+    // Compute `x`.
+    Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
+
+    // Compute the negative result: -1 - (x - n) / m.
+    Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
+    Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
+
+    // Compute the positive result. Truncating division already equals the
+    // floor whenever the exact quotient is non-negative.
+    Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
+
+    // Pick the negative result if `n` and `m` have different signs and `n` is
+    // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
+    Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
+    Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
+    // spirv.Select yields its first value operand when `cmp` holds, so the
+    // negative (rounded-down) result must come first. For example,
+    // floordivs(-7, 2) must be -4, whereas `posRes` would give -3.
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, negRes, posRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.casts` / `index.castu`. The source and result types are both
+/// taken in their SPIR-V form. If they are equal, the cast folds away.
+/// Otherwise it becomes `ConvertOp`: spirv.SConvert for signed casts (sign
+/// extends when widening) or spirv.UConvert for unsigned casts (zero extends
+/// when widening).
+template <typename CastOp, typename ConvertOp>
+struct ConvertIndexCast final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+
+    Type srcType = adaptor.getInput().getType();
+    // Defensive: treat a still-unconverted `index` operand as the converted
+    // index type.
+    if (isa<IndexType>(srcType))
+      srcType = typeConverter->getIndexType();
+
+    // Run the result type through the converter, not just `index`, so that
+    // integer types the target emulates with wider ones are handled. Types
+    // the target cannot represent make the pattern fail instead of emitting
+    // an illegal SPIR-V op.
+    Type dstType = typeConverter->convertType(op.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(op, "unsupported result type");
+
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, adaptor.getInput());
+    } else {
+      rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
+                                                      adaptor.getOperands());
+    }
+    return success();
+  }
+};
+
+using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
+using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCmp
+//===----------------------------------------------------------------------===//
+
+// Replaces `op` with the SPIR-V comparison `ICmpOp` applied to the converted
+// operands.
+template <typename ICmpOp>
+static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
+  return success();
+}
+
+/// Lowers `index.cmp` to the SPIR-V integer comparison matching its predicate.
+struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Map each predicate onto the corresponding SPIR-V integer comparison.
+    switch (op.getPred()) {
+    case IndexCmpPredicate::EQ:
+      return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::NE:
+      return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGE:
+      return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGT:
+      return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLE:
+      return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLT:
+      return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGE:
+      return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGT:
+      return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULE:
+      return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULT:
+      return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
+    }
+    // The switch covers every predicate. Without this, some compilers (e.g.
+    // GCC) warn that control reaches the end of a non-void function. It also
+    // aborts in assertion-enabled builds if an unexpected value shows up.
+    llvm_unreachable("unknown index::IndexCmpPredicate");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexSizeOf
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.sizeof` to a `spirv.Constant` of the converted index type.
+/// The constant's value is that type's bitwidth (32 or 64).
+struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &converter = *getTypeConverter<SPIRVTypeConverter>();
+    Type resultType = converter.getIndexType();
+    IntegerAttr widthAttr =
+        IntegerAttr::get(resultType, converter.getIndexTypeBitwidth());
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, resultType, widthAttr);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Adds to `patterns` every pattern that lowers Index dialect ops to SPIR-V.
+/// The patterns use `typeConverter` to choose the SPIR-V type for `index`.
+void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                         RewritePatternSet &patterns) {
+  // Keep the one-pattern-per-line list readable; re-enable formatting right
+  // after it so the rest of the file stays under clang-format.
+  // clang-format off
+  patterns.add<
+    ConvertIndexAdd,
+    ConvertIndexSub,
+    ConvertIndexMul,
+    ConvertIndexDivS,
+    ConvertIndexDivU,
+    ConvertIndexRemS,
+    ConvertIndexRemU,
+    ConvertIndexMaxS,
+    ConvertIndexMaxU,
+    ConvertIndexMinS,
+    ConvertIndexMinU,
+    ConvertIndexShl,
+    ConvertIndexShrS,
+    ConvertIndexShrU,
+    ConvertIndexAnd,
+    ConvertIndexOr,
+    ConvertIndexXor,
+    ConvertIndexConstantBoolOpPattern,
+    ConvertIndexConstantOpPattern,
+    ConvertIndexCeilDivSPattern,
+    ConvertIndexCeilDivUPattern,
+    ConvertIndexFloorDivSPattern,
+    ConvertIndexCastS,
+    ConvertIndexCastU,
+    ConvertIndexCmpPattern,
+    ConvertIndexSizeOf
+  >(typeConverter, patterns.getContext());
+  // clang-format on
+}
+
+//===----------------------------------------------------------------------===//
+// ODS-Generated Definitions
+//===----------------------------------------------------------------------===//
+
+// Pulls in the tablegen-generated pass base class
+// `impl::ConvertIndexToSPIRVPassBase`, which the pass definition below derives
+// from.
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Lowers the Index dialect ops nested under the pass's anchor op to SPIR-V.
+/// `index` becomes i32, or i64 when the `use-64bit-index` option is set.
+struct ConvertIndexToSPIRVPass
+    : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<SPIRVConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
+
+    SPIRVConversionOptions options;
+    options.use64bitIndex = this->use64bitIndex;
+    SPIRVTypeConverter typeConverter(targetAttr, options);
+
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    target->addLegalOp<UnrealizedConversionCastOp>();
+
+    // Allow the spirv operations we are converting to.
+    // NOTE(review): marking the whole SPIR-V dialect legal may override any
+    // target-environment legality that SPIRVConversionTarget sets up for it;
+    // confirm this is intended.
+    target->addLegalDialect<spirv::SPIRVDialect>();
+    // Fail hard when there are any remaining 'index' ops.
+    target->addIllegalDialect<index::IndexDialect>();
+
+    RewritePatternSet patterns(&getContext());
+    index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Factory declared in IndexToSPIRV.h. Without this definition, any caller of
+/// `index::createConvertIndexToSPIRVPass()` fails to link.
+std::unique_ptr<OperationPass<>> index::createConvertIndexToSPIRVPass() {
+  return std::make_unique<ConvertIndexToSPIRVPass>();
+}
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 1e8fe4423a422c5..3ef1d84ee264771 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -52,6 +53,7 @@ void SCFToSPIRVPass::runOnOperation() {
   populateFuncToSPIRVPatterns(typeConverter, patterns);
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
 
   if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
diff --git a/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir b/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
new file mode 100644
index 000000000000000..53dc896e98c7d67
--- /dev/null
+++ b/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt %s -convert-index-to-spirv | FileCheck %s
+// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=false | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=true | FileCheck %s --check-prefix=INDEX64
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: @trivial_ops
+func.func @trivial_ops(%a: index, %b: index) {
+  // CHECK: spirv.IAdd
+  %0 = index.add %a, %b
+  // CHECK: spirv.ISub
+  %1 = index.sub %a, %b
+  // CHECK: spirv.IMul
+  %2 = index.mul %a, %b
+  // CHECK: spirv.SDiv
+  %3 = index.divs %a, %b
+  // CHECK: spirv.UDiv
+  %4 = index.divu %a, %b
+  // CHECK: spirv.SRem
+  %5 = index.rems %a, %b
+  // CHECK: spirv.UMod
+  %6 = index.remu %a, %b
+  // CHECK: spirv.GL.SMax
+  %7 = index.maxs %a, %b
+  // CHECK: spirv.GL.UMax
+  %8 = index.maxu %a, %b
+  // CHECK: spirv.GL.SMin
+  %9 = index.mins %a, %b
+  // CHECK: spirv.GL.UMin
+  %10 = index.minu %a, %b
+  // CHECK: spirv.ShiftLeftLogical
+  %11 = index.shl %a, %b
+  // CHECK: spirv.ShiftRightArithmetic
+  %12 = index.shrs %a, %b
+  // CHECK: spirv.ShiftRightLogical
+  %13 = index.shru %a, %b
+  return
+}
+
+// CHECK-LABEL: @bitwise_ops
+func.func @bitwise_ops(%a: index, %b: index) {
+  // CHECK: spirv.BitwiseAnd
+  %0 = index.and %a, %b
+  // CHECK: spirv.BitwiseOr
+  %1 = index.or %a, %b
+  // CHECK: spirv.BitwiseXor
+  %2 = index.xor %a, %b
+  return
+}
+
+// INDEX32-LABEL: @constant_ops
+// INDEX64-LABEL: @constant_ops
+func.func @constant_ops() {
+  // INDEX32: spirv.Constant 42 : i32
+  // INDEX64: spirv.Constant 42 : i64
+  %0 = index.constant 42
+  // INDEX32: spirv.Constant true
+  // INDEX64: spirv.Constant true
+  %1 = index.bool.constant true
+  // INDEX32: spirv.Constant false
+  // INDEX64: spirv.Constant false
+  %2 = index.bool.constant false
+  return
+}
+
+// CHECK-LABEL: @ceildivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
+  // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
+
+  // CHECK: %[[M_POS:.*]] = spirv.SGreaterThan %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = spirv.Select %[[M_POS]], %[[NEG_ONE]], %[[POS_ONE]]
+
+  // CHECK: %[[N_PLUS_X:.*]] = spirv.IAdd %[[N]], %[[X]]
+  // CHECK: %[[N_PLUS_X_DIV_M:.*]] = spirv.SDiv %[[N_PLUS_X]], %[[M]]
+  // CHECK: %[[POS_RES:.*]] = spirv.IAdd %[[N_PLUS_X_DIV_M]], %[[POS_ONE]]
+
+  // CHECK: %[[NEG_N:.*]] = spirv.ISub %[[ZERO]], %[[N]]
+  // CHECK: %[[NEG_N_DIV_M:.*]] = spirv.SDiv %[[NEG_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[ZERO]], %[[NEG_N_DIV_M]]
+
+  // CHECK: %[[N_POS:.*]] = spirv.SGreaterThan %[[N]], %[[ZERO]]
+  // CHECK: %[[SAME_SIGN:.*]] = spirv.LogicalEqual %[[N_POS]], %[[M_POS]]
+  // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]]
+  // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[SAME_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[POS_RES]], %[[NEG_RES]]
+  %result = index.ceildivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @ceildivu
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivu(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[ONE:.*]] = spirv.Constant 1
+
+  // CHECK: %[[N_MINUS_ONE:.*]] = spirv.ISub %[[N]], %[[ONE]]
+  // CHECK: %[[N_MINUS_ONE_DIV_M:.*]] = spirv.UDiv %[[N_MINUS_ONE]], %[[M]]
+  // CHECK: %[[N_MINUS_ONE_DIV_M_PLUS_ONE:.*]] = spirv.IAdd %[[N_MINUS_ONE_DIV_M]], %[[ONE]]
+
+  // CHECK: %[[CMP:.*]] = spirv.IEqual %[[N]], %[[ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[ZERO]], %[[N_MINUS_ONE_DIV_M_PLUS_ONE]]
+  %result = index.ceildivu %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @floordivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @floordivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
+  // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
+
+  // CHECK: %[[M_NEG:.*]] = spirv.SLessThan %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = spirv.Select %[[M_NEG]], %[[POS_ONE]], %[[NEG_ONE]]
+
+  // CHECK: %[[X_MINUS_N:.*]] = spirv.ISub %[[X]], %[[N]]
+  // CHECK: %[[X_MINUS_N_DIV_M:.*]] = spirv.SDiv %[[X_MINUS_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[NEG_ONE]], %[[X_MINUS_N_DIV_M]]
+
+  // CHECK: %[[POS_RES:.*]] = spirv.SDiv %[[N]], %[[M]]
+
+  // CHECK: %[[N_NEG:.*]] = spirv.SLessThan %[[N]], %[[ZERO]]
+  // CHECK: %[[DIFF_SIGN:.*]] = spirv.LogicalNotEqual %[[N_NEG]], %[[M_NEG]]
+  // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]]
+
+  // When the signs differ (and n is non-zero) the rounded-down negative
+  // result must be selected.
+  // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[DIFF_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[NEG_RES]], %[[POS_RES]]
+  %result = index.floordivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @index_cmp
+func.func @index_cmp(%a : index, %b : index) {
+  // CHECK: spirv.IEqual
+  %0 = index.cmp eq(%a, %b)
+  // CHECK: spirv.INotEqual
+  %1 = index.cmp ne(%a, %b)
+
+  // CHECK: spirv.SLessThan
+  %2 = index.cmp slt(%a, %b)
+  // CHECK: spirv.SLessThanEqual
+  %3 = index.cmp sle(%a, %b)
+  // CHECK: spirv.SGreaterThan
+  %4 = index.cmp sgt(%a, %b)
+  // CHECK: spirv.SGreaterThanEqual
+  %5 = index.cmp sge(%a, %b)
+
+  // CHECK: spirv.ULessThan
+  %6 = index.cmp ult(%a, %b)
+  // CHECK: spirv.ULessThanEqual
+  %7 = index.cmp ule(%a, %b)
+  // CHECK: spirv.UGreaterThan
+  %8 = index.cmp ugt(%a, %b)
+  // CHECK: spirv.UGreaterThanEqual
+  %9 = index.cmp uge(%a, %b)
+  return
+}
+
+// CHECK-LABEL: @index_sizeof
+// INDEX32-LABEL: @index_sizeof
+// INDEX64-LABEL: @index_sizeof
+func.func @index_sizeof() {
+  // CHECK: spirv.Constant 32 : i32
+  // INDEX32: spirv.Constant 32 : i32
+  // INDEX64: spirv.Constant 64 : i64
+  %0 = index.sizeof
+  return
+}
+
+// INDEX32-LABEL: @index_cast_from
+// INDEX64-LABEL: @index_cast_from
+// INDEX32-SAME: %[[AI:.*]]: index
+// INDEX64-SAME: %[[AI:.*]]: index
+func.func @index_cast_from(%a: index) -> (i64, i32, i64, i32) {
+  // INDEX32: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i32
+  // INDEX64: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i64
+
+  // INDEX32: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64
+  %0 = index.casts %a : index to i64
+  // INDEX64: %[[V1:.*]] = spirv.SConvert %[[A]] : i64 to i32
+  %1 = index.casts %a : index to i32
+  // INDEX32: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64
+  %2 = index.castu %a : index to i64
+  // INDEX64: %[[V3:.*]] = spirv.UConvert %[[A]] : i64 to i32
+  %3 = index.castu %a : index to i32
+
+  // INDEX32: return %[[V0]], %[[A]], %[[V2]], %[[A]]
+  // INDEX64: return %[[A]], %[[V1]], %[[A]], %[[V3]]
+  return %0, %1, %2, %3 : i64, i32, i64, i32
+}
+
+// INDEX32-LABEL: @index_cast_to
+// INDEX64-LABEL: @index_cast_to
+// INDEX32-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+// INDEX64-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+func.func @index_cast_to(%a: i32, %b: i64) -> (index, index, index, index) {
+  // INDEX64: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64
+  %0 = index.casts %a : i32 to index
+  // INDEX32: %[[V1:.*]] = spirv.SConvert %[[B]] : i64 to i32
+  %1 = index.casts %b : i64 to index
+  // INDEX64: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64
+  %2 = index.castu %a : i32 to index
+  // INDEX32: %[[V3:.*]] = spirv.UConvert %[[B]] : i64 to i32
+  %3 = index.castu %b : i64 to index
+  return %0, %1, %2, %3 : index, index, index, index
+}
+}
diff --git a/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir
new file mode 100644
index 000000000000000..68a825fbd93ebde
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+
+// Regression test: index dialect ops inside an scf.for body (whose induction
+// variable is an `index`) must be lowered by -convert-scf-to-spirv, which now
+// also runs the Index-to-SPIR-V patterns.
+// CHECK-LABEL: @forward
+func.func @forward() {
+  // CHECK: %[[LB:.*]] = spirv.Constant 0 : i32
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[UB:.*]] = spirv.Constant 32 : i32
+  %c32 = arith.constant 32 : index
+  // CHECK: %[[STEP:.*]] = spirv.Constant 1 : i32
+  %c1 = arith.constant 1 : index
+
+  // CHECK:      spirv.mlir.loop {
+  // CHECK-NEXT:   spirv.Branch ^[[HEADER:.*]](%[[LB]] : i32)
+  // CHECK:      ^[[HEADER]](%[[INDVAR:.*]]: i32):
+  // CHECK:        %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32
+  // CHECK:        spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+  // CHECK:      ^[[BODY]]:
+  // CHECK:        %[[X:.*]] = spirv.IAdd %[[INDVAR]], %[[INDVAR]] : i32
+  // CHECK:        %[[INDNEXT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
+  // CHECK:        spirv.Branch ^[[HEADER]](%[[INDNEXT]] : i32)
+  // CHECK:      ^[[MERGE]]:
+  // CHECK:        spirv.mlir.merge
+  // CHECK:      }
+  scf.for %arg2 = %c0 to %c32 step %c1 {
+      %1 = index.add %arg2, %arg2
+  }
+  return
+}

>From 782e8eb34ff95e7643ca19bf35a4213fab520069 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sat, 21 Oct 2023 00:22:14 +0200
Subject: [PATCH 2/2] add SPIR-V Dialect to linked libs for build

---
 mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
index 1da0e0253501fec..e3b279d915a15dd 100644
--- a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
@@ -12,4 +12,6 @@ add_mlir_conversion_library(MLIRIndexToSPIRV
 
   LINK_LIBS PUBLIC
   MLIRIndexDialect
+  MLIRSPIRVConversion
+  MLIRSPIRVDialect
   )



More information about the Mlir-commits mailing list