[Mlir-commits] [mlir] 3c07a21 - [mlir][index][spirv] Add conversion for index to spirv (#68085)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 20 08:46:59 PDT 2023


Author: Finn Plummer
Date: 2023-10-20T11:46:54-04:00
New Revision: 3c07a21683c3c0a30aef4a47dde31cb85ad9830c

URL: https://github.com/llvm/llvm-project/commit/3c07a21683c3c0a30aef4a47dde31cb85ad9830c
DIFF: https://github.com/llvm/llvm-project/commit/3c07a21683c3c0a30aef4a47dde31cb85ad9830c.diff

LOG: [mlir][index][spirv] Add conversion for index to spirv (#68085)

Due to an issue when lowering from scf to spirv as there was no
conversion pass for index to spirv, we are motivated to add a conversion
pass from the Index dialect to the SPIR-V dialect. Furthermore, we add
the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713

Added: 
    mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
    mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
    mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
    mlir/test/Conversion/SCFToSPIRV/use-indices.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
new file mode 100644
index 000000000000000..58a1c5246eef999
--- /dev/null
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -0,0 +1,30 @@
+//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class SPIRVTypeConverter;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace index {
+/// Adds the Index dialect to SPIR-V dialect conversion patterns to `patterns`.
+/// Each pattern is constructed with `converter` as its type converter.
+void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
+                                  RewritePatternSet &patterns);
+// NOTE(review): no definition of this function is visible in this patch; the
+// pass constructor presumably comes from the GEN_PASS_DECL code included
+// above -- confirm whether this declaration is still needed.
+std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
+} // namespace index
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index e714f5070f23db8..c13c457fd97492a 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -35,6 +35,7 @@
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a269fb4a83af41f..274784fe4a7b29c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -644,6 +644,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertIndexToSPIRVPass
+//===----------------------------------------------------------------------===//
+
+def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
+  let summary = "Lower the `index` dialect to the `spirv` dialect.";
+  let description = [{
+    This pass lowers Index dialect operations to SPIR-V dialect operations.
+    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
+    `ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
+    specified by use-64bit-index.
+  }];
+
+  let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
+
+  let options = [
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgToStandard
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 89ded981d38f9f4..933d62e35fce8cd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -55,13 +55,13 @@ struct SPIRVConversionOptions {
   /// values will be packed into one 32-bit value to be memory efficient.
   bool emulateLT32BitScalarTypes{true};
 
-  /// Use 64-bit integers to convert index types.
-  bool use64bitIndex{false};
-
   /// Whether to enable fast math mode during conversion. If true, various
   /// patterns would assume no NaN/infinity numbers as inputs, and thus there
   /// will be no special guards emitted to check and handle such cases.
   bool enableFastMathMode{false};
+
+  /// Use 64-bit integers when converting index types.
+  bool use64bitIndex{false};
 };
 
 /// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +77,11 @@ class SPIRVTypeConverter : public TypeConverter {
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
+  /// Gets the bitwidth of the index type when converted to SPIR-V.
+  unsigned getIndexTypeBitwidth() const {
+    return options.use64bitIndex ? 64 : 32;
+  }
+
   const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
 
   /// Returns the options controlling the SPIR-V type converter.

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 35790254be137be..7e1c7bcf9a8678a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@ add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
+add_subdirectory(IndexToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)

diff  --git a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000000..1da0e0253501fec
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRIndexToSPIRV
+  IndexToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIndexDialect
+  )

diff  --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
new file mode 100644
index 000000000000000..b58efc096e2eafb
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -0,0 +1,418 @@
+//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace index;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Trivial Conversions
+//===----------------------------------------------------------------------===//
+
+using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
+using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
+using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
+using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
+using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
+using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
+using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
+using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
+using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
+using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
+using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
+
+using ConvertIndexShl =
+    spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
+using ConvertIndexShrS =
+    spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
+using ConvertIndexShrU =
+    spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
+
+/// When converting bitwise operations to SPIR-V we must take into account
+/// that SPIR-V uses `SPIRVLogicalOp` when the operands are boolean values
+/// and `SPIRVBitwiseOp` for non-boolean operands. However, `index.and`,
+/// `index.or`, and `index.xor` are never boolean operations, so we can
+/// directly convert them to the Bitwise[And|Or|Xor]Op.
+///
+using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
+using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
+using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertConstantBool
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.bool.constant` to a `spirv.Constant` that reuses the original
+/// op's result type and value attribute.
+struct ConvertIndexConstantBoolOpPattern final
+    : OpConversionPattern<BoolConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = op.getType();
+    auto boolValue = op.getValueAttr();
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, resultType, boolValue);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+
+/// Lowers `index.constant` to a `spirv.Constant` of the converted index type.
+/// The constant's value is truncated to the converted index bitwidth, i.e.
+/// from i64 to i32 when required.
+struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type spirvIndexType = converter->getIndexType();
+    unsigned indexBitwidth = converter->getIndexTypeBitwidth();
+
+    // Narrow the stored value to the width of the converted index type.
+    APInt narrowed = op.getValue().trunc(indexBitwidth);
+    auto valueAttr = IntegerAttr::get(spirvIndexType, narrowed);
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, spirvIndexType,
+                                                   valueAttr);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
+/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
+/// conversion in IndexToLLVM.
+///
+/// Both candidate results are always materialized and a final `spirv.Select`
+/// picks between them. The operands come from the adaptor, so all new ops use
+/// the converted type of `n`.
+struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants, all typed like the converted `n`.
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, -1));
+
+    // Compute `x`: -1 when `m` is positive, +1 otherwise.
+    Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
+
+    // Compute the positive result: `(n + x) / m + 1`.
+    Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
+    Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
+    Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
+
+    // Compute the negative result: `-(-n / m)`, with negation done as `0 - v`.
+    Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
+    Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
+
+    // Pick the positive result if `n` and `m` have the same sign and `n` is
+    // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
+    Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
+    Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivU
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
+/// from the equivalent conversion in IndexToLLVM.
+///
+/// The `(n-1)/m + 1` arm is always materialized; the `n == 0` case is handled
+/// by the final `spirv.Select`, which returns the zero constant instead.
+struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants, both typed like the converted `n`.
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
+                                                   IntegerAttr::get(n_type, 1));
+
+    // Compute the non-zero result: `(n - 1) / m + 1` with unsigned division.
+    Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
+    Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
+    Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
+
+    // Pick the result: zero when `n == 0`, the computed value otherwise.
+    Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexFloorDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
+/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
+/// in IndexToLLVM.
+///
+/// Both candidate results are always materialized and a final `spirv.Select`
+/// picks between them. The operands come from the adaptor, so all new ops use
+/// the converted type of `n`.
+struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, -1));
+
+    // Compute `x`.
+    Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
+
+    // Compute the negative result: `-1 - (x - n) / m`.
+    Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
+    Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
+
+    // Compute the positive result.
+    Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
+
+    // Pick the negative result if `n` and `m` have different signs and `n` is
+    // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
+    Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
+    Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+
+    // `cmp` is true exactly when the negative formula applies, so `negRes` is
+    // the true operand of the select and `posRes` the false operand.
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, negRes, posRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Lowers an index cast. `index` on either side of the cast is first replaced
+/// by the converted SPIR-V index type. If source and destination then agree,
+/// the cast is folded away and the input is forwarded; otherwise `ConvertOp`
+/// is emitted. Signed casts sign extend when the result bitwidth is larger,
+/// unsigned casts zero extend.
+template <typename CastOp, typename ConvertOp>
+struct ConvertIndexCast final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type spirvIndexType = converter->getIndexType();
+
+    // Substitute the converted index type for `index`; keep any other type.
+    auto materialize = [&](Type type) -> Type {
+      return isa<IndexType>(type) ? spirvIndexType : type;
+    };
+    Type fromType = materialize(adaptor.getInput().getType());
+    Type toType = materialize(op.getType());
+
+    // Identical types: the cast is a no-op, forward the converted input.
+    if (fromType == toType) {
+      rewriter.replaceOp(op, adaptor.getInput());
+      return success();
+    }
+
+    rewriter.template replaceOpWithNewOp<ConvertOp>(op, toType,
+                                                    adaptor.getOperands());
+    return success();
+  }
+};
+
+using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
+using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCmp
+//===----------------------------------------------------------------------===//
+
+/// Replaces `op` with an `ICmpOp` built from the adaptor's converted left-
+/// and right-hand operands. Always succeeds.
+template <typename ICmpOp>
+static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) {
+  Value lhs = adaptor.getLhs();
+  Value rhs = adaptor.getRhs();
+  rewriter.replaceOpWithNewOp<ICmpOp>(op, lhs, rhs);
+  return success();
+}
+
+/// Lowers `index.cmp` to the SPIR-V integer comparison op that matches its
+/// predicate, using the adaptor's converted operands.
+struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // We must convert the predicates to the corresponding int comparisons.
+    switch (op.getPred()) {
+    case IndexCmpPredicate::EQ:
+      return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::NE:
+      return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGE:
+      return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGT:
+      return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLE:
+      return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLT:
+      return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGE:
+      return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGT:
+      return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULE:
+      return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULT:
+      return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
+    }
+    // Every case above returns. Without this, control could fall off the end
+    // of a non-void function if the predicate held an unlisted value.
+    llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexSizeOf
+//===----------------------------------------------------------------------===//
+
+/// Replaces `index.sizeof` with a `spirv.Constant` of the converted index type
+/// whose value is the converted index bitwidth.
+struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    unsigned indexBitwidth = converter->getIndexTypeBitwidth();
+    Type spirvIndexType = converter->getIndexType();
+
+    auto sizeAttr = IntegerAttr::get(spirvIndexType, indexBitwidth);
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, spirvIndexType,
+                                                   sizeAttr);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Registers the Index-to-SPIR-V conversion patterns defined in this file with
+/// `patterns`. Every pattern is constructed with `typeConverter` and the
+/// context owned by `patterns`.
+void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                         RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+
+  // Arithmetic, min/max, shift and bitwise ops with a direct SPIR-V op.
+  patterns.add<
+      // clang-format off
+    ConvertIndexAdd,
+    ConvertIndexSub,
+    ConvertIndexMul,
+    ConvertIndexDivS,
+    ConvertIndexDivU,
+    ConvertIndexRemS,
+    ConvertIndexRemU,
+    ConvertIndexMaxS,
+    ConvertIndexMaxU,
+    ConvertIndexMinS,
+    ConvertIndexMinU,
+    ConvertIndexShl,
+    ConvertIndexShrS,
+    ConvertIndexShrU,
+    ConvertIndexAnd,
+    ConvertIndexOr,
+    ConvertIndexXor
+      // clang-format on
+  >(typeConverter, context);
+
+  // Constants, exotic divides, casts, comparisons and sizeof.
+  patterns.add<
+      // clang-format off
+    ConvertIndexConstantBoolOpPattern,
+    ConvertIndexConstantOpPattern,
+    ConvertIndexCeilDivSPattern,
+    ConvertIndexCeilDivUPattern,
+    ConvertIndexFloorDivSPattern,
+    ConvertIndexCastS,
+    ConvertIndexCastU,
+    ConvertIndexCmpPattern,
+    ConvertIndexSizeOf
+      // clang-format on
+  >(typeConverter, context);
+}
+
+//===----------------------------------------------------------------------===//
+// ODS-Generated Definitions
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass that converts `index` dialect ops under the root operation to SPIR-V
+/// dialect ops through a partial dialect conversion. The pass fails if any
+/// `index` op is left unconverted.
+struct ConvertIndexToSPIRVPass
+    : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(root);
+
+    // Forward the pass's index-width option to the type converter.
+    SPIRVConversionOptions conversionOptions;
+    conversionOptions.use64bitIndex = this->use64bitIndex;
+    SPIRVTypeConverter typeConverter(targetEnv, conversionOptions);
+
+    std::unique_ptr<SPIRVConversionTarget> target =
+        SPIRVConversionTarget::get(targetEnv);
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    target->addLegalOp<UnrealizedConversionCastOp>();
+    // The SPIR-V ops produced by the patterns are legal.
+    target->addLegalDialect<spirv::SPIRVDialect>();
+    // Fail hard when there are any remaining 'index' ops.
+    target->addIllegalDialect<index::IndexDialect>();
+
+    RewritePatternSet patterns(&getContext());
+    index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(root, *target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 1e8fe4423a422c5..3ef1d84ee264771 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -52,6 +53,7 @@ void SCFToSPIRVPass::runOnOperation() {
   populateFuncToSPIRVPatterns(typeConverter, patterns);
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
 
   if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();

diff  --git a/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir b/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
new file mode 100644
index 000000000000000..53dc896e98c7d67
--- /dev/null
+++ b/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt %s -convert-index-to-spirv | FileCheck %s
+// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=false | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=true | FileCheck %s --check-prefix=INDEX64
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: @trivial_ops
+func.func @trivial_ops(%a: index, %b: index) {
+  // CHECK: spirv.IAdd
+  %0 = index.add %a, %b
+  // CHECK: spirv.ISub
+  %1 = index.sub %a, %b
+  // CHECK: spirv.IMul
+  %2 = index.mul %a, %b
+  // CHECK: spirv.SDiv
+  %3 = index.divs %a, %b
+  // CHECK: spirv.UDiv
+  %4 = index.divu %a, %b
+  // CHECK: spirv.SRem
+  %5 = index.rems %a, %b
+  // CHECK: spirv.UMod
+  %6 = index.remu %a, %b
+  // CHECK: spirv.GL.SMax
+  %7 = index.maxs %a, %b
+  // CHECK: spirv.GL.UMax
+  %8 = index.maxu %a, %b
+  // CHECK: spirv.GL.SMin
+  %9 = index.mins %a, %b
+  // CHECK: spirv.GL.UMin
+  %10 = index.minu %a, %b
+  // CHECK: spirv.ShiftLeftLogical
+  %11 = index.shl %a, %b
+  // CHECK: spirv.ShiftRightArithmetic
+  %12 = index.shrs %a, %b
+  // CHECK: spirv.ShiftRightLogical
+  %13 = index.shru %a, %b
+  return
+}
+
+// CHECK-LABEL: @bitwise_ops
+func.func @bitwise_ops(%a: index, %b: index) {
+  // CHECK: spirv.BitwiseAnd
+  %0 = index.and %a, %b
+  // CHECK: spirv.BitwiseOr
+  %1 = index.or %a, %b
+  // CHECK: spirv.BitwiseXor
+  %2 = index.xor %a, %b
+  return
+}
+
+// INDEX32-LABEL: @constant_ops
+// INDEX64-LABEL: @constant_ops
+func.func @constant_ops() {
+  // INDEX32: spirv.Constant 42 : i32
+  // INDEX64: spirv.Constant 42 : i64
+  %0 = index.constant 42
+  // INDEX32: spirv.Constant true
+  // INDEX64: spirv.Constant true
+  %1 = index.bool.constant true
+  // INDEX32: spirv.Constant false
+  // INDEX64: spirv.Constant false
+  %2 = index.bool.constant false
+  return
+}
+
+// CHECK-LABEL: @ceildivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
+  // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
+
+  // CHECK: %[[M_POS:.*]] = spirv.SGreaterThan %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = spirv.Select %[[M_POS]], %[[NEG_ONE]], %[[POS_ONE]]
+
+  // CHECK: %[[N_PLUS_X:.*]] = spirv.IAdd %[[N]], %[[X]]
+  // CHECK: %[[N_PLUS_X_DIV_M:.*]] = spirv.SDiv %[[N_PLUS_X]], %[[M]]
+  // CHECK: %[[POS_RES:.*]] = spirv.IAdd %[[N_PLUS_X_DIV_M]], %[[POS_ONE]]
+
+  // CHECK: %[[NEG_N:.*]] = spirv.ISub %[[ZERO]], %[[N]]
+  // CHECK: %[[NEG_N_DIV_M:.*]] = spirv.SDiv %[[NEG_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[ZERO]], %[[NEG_N_DIV_M]]
+
+  // CHECK: %[[N_POS:.*]] = spirv.SGreaterThan %[[N]], %[[ZERO]]
+  // CHECK: %[[SAME_SIGN:.*]] = spirv.LogicalEqual %[[N_POS]], %[[M_POS]]
+  // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]]
+  // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[SAME_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[POS_RES]], %[[NEG_RES]]
+  %result = index.ceildivs %n, %m
+
+  // %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @ceildivu
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivu(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[ONE:.*]] = spirv.Constant 1
+
+  // CHECK: %[[N_MINUS_ONE:.*]] = spirv.ISub %[[N]], %[[ONE]]
+  // CHECK: %[[N_MINUS_ONE_DIV_M:.*]] = spirv.UDiv %[[N_MINUS_ONE]], %[[M]]
+  // CHECK: %[[N_MINUS_ONE_DIV_M_PLUS_ONE:.*]] = spirv.IAdd %[[N_MINUS_ONE_DIV_M]], %[[ONE]]
+
+  // CHECK: %[[CMP:.*]] = spirv.IEqual %[[N]], %[[ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[ZERO]], %[[N_MINUS_ONE_DIV_M_PLUS_ONE]]
+  %result = index.ceildivu %n, %m
+
+  // %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @floordivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @floordivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0
+  // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1
+  // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1
+
+  // CHECK: %[[M_NEG:.*]] = spirv.SLessThan %[[M]], %[[ZERO]]
+  // CHECK: %[[X:.*]] = spirv.Select %[[M_NEG]], %[[POS_ONE]], %[[NEG_ONE]]
+
+  // CHECK: %[[X_MINUS_N:.*]] = spirv.ISub %[[X]], %[[N]]
+  // CHECK: %[[X_MINUS_N_DIV_M:.*]] = spirv.SDiv %[[X_MINUS_N]], %[[M]]
+  // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[NEG_ONE]], %[[X_MINUS_N_DIV_M]]
+
+  // CHECK: %[[POS_RES:.*]] = spirv.SDiv %[[N]], %[[M]]
+
+  // CHECK: %[[N_NEG:.*]] = spirv.SLessThan %[[N]], %[[ZERO]]
+  // CHECK: %[[DIFF_SIGN:.*]] = spirv.LogicalNotEqual %[[N_NEG]], %[[M_NEG]]
+  // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]]
+
+  // The negative result is selected when the signs differ and `n` is non-zero.
+  // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[DIFF_SIGN]], %[[N_NON_ZERO]]
+  // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[NEG_RES]], %[[POS_RES]]
+  %result = index.floordivs %n, %m
+
+  // %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @index_cmp
+func.func @index_cmp(%a : index, %b : index) {
+  // CHECK: spirv.IEqual
+  %0 = index.cmp eq(%a, %b)
+  // CHECK: spirv.INotEqual
+  %1 = index.cmp ne(%a, %b)
+
+  // CHECK: spirv.SLessThan
+  %2 = index.cmp slt(%a, %b)
+  // CHECK: spirv.SLessThanEqual
+  %3 = index.cmp sle(%a, %b)
+  // CHECK: spirv.SGreaterThan
+  %4 = index.cmp sgt(%a, %b)
+  // CHECK: spirv.SGreaterThanEqual
+  %5 = index.cmp sge(%a, %b)
+
+  // CHECK: spirv.ULessThan
+  %6 = index.cmp ult(%a, %b)
+  // CHECK: spirv.ULessThanEqual
+  %7 = index.cmp ule(%a, %b)
+  // CHECK: spirv.UGreaterThan
+  %8 = index.cmp ugt(%a, %b)
+  // CHECK: spirv.UGreaterThanEqual
+  %9 = index.cmp uge(%a, %b)
+  return
+}
+
+// CHECK-LABEL: @index_sizeof
+func.func @index_sizeof() {
+  // CHECK: spirv.Constant 32 : i32
+  %0 = index.sizeof
+  return
+}
+
+// INDEX32-LABEL: @index_cast_from
+// INDEX64-LABEL: @index_cast_from
+// INDEX32-SAME: %[[AI:.*]]: index
+// INDEX64-SAME: %[[AI:.*]]: index
+func.func @index_cast_from(%a: index) -> (i64, i32, i64, i32) {
+  // INDEX32: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i32
+  // INDEX64: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i64
+
+  // INDEX32: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64
+  %0 = index.casts %a : index to i64
+  // INDEX64: %[[V1:.*]] = spirv.SConvert %[[A]] : i64 to i32
+  %1 = index.casts %a : index to i32
+  // INDEX32: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64
+  %2 = index.castu %a : index to i64
+  // INDEX64: %[[V3:.*]] = spirv.UConvert %[[A]] : i64 to i32
+  %3 = index.castu %a : index to i32
+
+  // INDEX32: return %[[V0]], %[[A]], %[[V2]], %[[A]]
+  // INDEX64: return %[[A]], %[[V1]], %[[A]], %[[V3]]
+  return %0, %1, %2, %3 : i64, i32, i64, i32
+}
+
+// INDEX32-LABEL: @index_cast_to
+// INDEX64-LABEL: @index_cast_to
+// INDEX32-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+// INDEX64-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+func.func @index_cast_to(%a: i32, %b: i64) -> (index, index, index, index) {
+  // INDEX64: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64
+  %0 = index.casts %a : i32 to index
+  // INDEX32: %[[V1:.*]] = spirv.SConvert %[[B]] : i64 to i32
+  %1 = index.casts %b : i64 to index
+  // INDEX64: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64
+  %2 = index.castu %a : i32 to index
+  // INDEX32: %[[V3:.*]] = spirv.UConvert %[[B]] : i64 to i32
+  %3 = index.castu %b : i64 to index
+  return %0, %1, %2, %3 : index, index, index, index
+}
+}

diff  --git a/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir
new file mode 100644
index 000000000000000..68a825fbd93ebde
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+
+// CHECK-LABEL: @forward
+func.func @forward() {
+  // CHECK: %[[LB:.*]] = spirv.Constant 0 : i32
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[UB:.*]] = spirv.Constant 32 : i32
+  %c32 = arith.constant 32 : index
+  // CHECK: %[[STEP:.*]] = spirv.Constant 1 : i32
+  %c1 = arith.constant 1 : index
+
+  // CHECK:      spirv.mlir.loop {
+  // CHECK-NEXT:   spirv.Branch ^[[HEADER:.*]](%[[LB]] : i32)
+  // CHECK:      ^[[HEADER]](%[[INDVAR:.*]]: i32):
+  // CHECK:        %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32
+  // CHECK:        spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+  // CHECK:      ^[[BODY]]:
+  // CHECK:        %[[X:.*]] = spirv.IAdd %[[INDVAR]], %[[INDVAR]] : i32
+  // CHECK:        %[[INDNEXT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
+  // CHECK:        spirv.Branch ^[[HEADER]](%[[INDNEXT]] : i32)
+  // CHECK:      ^[[MERGE]]:
+  // CHECK:        spirv.mlir.merge
+  // CHECK:      }
+  scf.for %arg2 = %c0 to %c32 step %c1 {
+      %1 = index.add %arg2, %arg2
+  }
+  return
+}


        


More information about the Mlir-commits mailing list