[Mlir-commits] [mlir] [mlir][spirv] Lower `arith` overflow flags to corresponding SPIR-V op decorations (PR #77714)

Ivan Butygin llvmlistbot at llvm.org
Thu Jan 11 05:41:40 PST 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/77714

>From 84f6bb9e98c93748ec35d3990fbc48bf92330a28 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 11 Jan 2024 01:57:42 +0100
Subject: [PATCH 1/3] [mlir][spirv] Lower `arith` overflow flags to
 corresponding op decorations.

---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 59 ++++++++++++++++++-
 .../ArithToSPIRV/arith-to-spirv.mlir          | 40 +++++++++++++
 2 files changed, 96 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index aba6a21deccb0c..3b851604f597af 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -158,8 +158,61 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
   return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
 }
 
+// TODO: Move to some common place? Maps e.g. NoSignedWrap -> no_signed_wrap.
+static std::string getDecorationString(spirv::Decoration decoration) {
+  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
+}
+
 namespace {
 
+/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
+/// operations. Op can potentially support overflow flags.
+template <typename Op, typename SPIRVOp>
+struct ElementwiseArithOpPattern : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() <= 3);
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
+    }
+
+    if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
+        !getElementTypeOrSelf(op.getType()).isIndex() &&
+        dstType != op.getType()) {
+      return op.emitError("bitwidth emulation is not implemented yet on "
+                          "unsigned op pattern version");
+    }
+
+    auto converter = this->getTypeConverter<SPIRVTypeConverter>();
+    auto overflowFlags = arith::IntegerOverflowFlags::none;
+    if (auto overflowIface =
+            dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
+      // Flags only hold for the original type; conservatively drop them when
+      // conversion changed a non-index type (e.g. bitwidth emulation).
+      bool typeChanged = !getElementTypeOrSelf(op.getType()).isIndex() &&
+                         dstType != op.getType();
+      if (!typeChanged && converter->getTargetEnv().allows(
+              spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
+        overflowFlags = overflowIface.getOverflowAttr().getValue();
+    }
+    auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+        op, dstType, adaptor.getOperands());
+    if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
+      newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
+                     rewriter.getUnitAttr());
+    if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
+      newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
+                     rewriter.getUnitAttr());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -1154,9 +1207,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
-    spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
-    spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
-    spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
+    ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+    ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+    ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397..8bf90ed0aec8ee 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel], [SPV_KHR_no_integer_wrap_decoration]>, #spirv.resource_limits<>>
+} {
+// With SPV_KHR_no_integer_wrap_decoration, nsw/nuw lower to wrap decorations.
+// CHECK-LABEL: @ops_flags
+// CHECK-SAME: (%[[LHS:.+]]: i64, %[[RHS:.+]]: i64)
+func.func @ops_flags(%lhs: i64, %rhs: i64) {
+  // CHECK: %{{.*}} = spirv.IAdd %[[LHS]], %[[RHS]] {no_signed_wrap} : i64
+  %add = arith.addi %lhs, %rhs overflow<nsw> : i64
+  // CHECK: %{{.*}} = spirv.ISub %[[LHS]], %[[RHS]] {no_unsigned_wrap} : i64
+  %sub = arith.subi %lhs, %rhs overflow<nuw> : i64
+  // CHECK: %{{.*}} = spirv.IMul %[[LHS]], %[[RHS]] {no_signed_wrap, no_unsigned_wrap} : i64
+  %mul = arith.muli %lhs, %rhs overflow<nsw, nuw> : i64
+  return
+}
+} // end module
+
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
+} {
+// No decorations should be generated if SPV_KHR_no_integer_wrap_decoration is
+// unavailable. Operands are matched exactly so a trailing `{...}` would fail.
+// CHECK-LABEL: @ops_flags
+// CHECK-SAME: (%[[LHS:.+]]: i64, %[[RHS:.+]]: i64)
+func.func @ops_flags(%lhs: i64, %rhs: i64) {
+  // CHECK: %{{.*}} = spirv.IAdd %[[LHS]], %[[RHS]] : i64
+  %add = arith.addi %lhs, %rhs overflow<nsw> : i64
+  // CHECK: %{{.*}} = spirv.ISub %[[LHS]], %[[RHS]] : i64
+  %sub = arith.subi %lhs, %rhs overflow<nuw> : i64
+  // CHECK: %{{.*}} = spirv.IMul %[[LHS]], %[[RHS]] : i64
+  %mul = arith.muli %lhs, %rhs overflow<nsw, nuw> : i64
+  return
+}
+} // end module

>From 90d08201ab1cd352a020fe0ec7546154f14c755f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 11 Jan 2024 02:24:37 +0100
Subject: [PATCH 2/3] Add missing `template` keyword to getTypeConverter call

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 3b851604f597af..8ed7e584048992 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -189,7 +189,7 @@ struct ElementwiseArithOpPattern : public OpConversionPattern<Op> {
                           "unsigned op pattern version");
     }
 
-    auto converter = this->getTypeConverter<SPIRVTypeConverter>();
+    auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
     auto overflowFlags = arith::IntegerOverflowFlags::none;
     if (auto overflowIface =
             dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {

>From d77b6bbad282ba6e3cb4022666015781670cd264 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 11 Jan 2024 14:41:20 +0100
Subject: [PATCH 3/3] Address review comments

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 8ed7e584048992..1abad1e9fa4d85 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -168,14 +168,15 @@ namespace {
 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
 /// operations. Op can potentially support overflow flags.
 template <typename Op, typename SPIRVOp>
-struct ElementwiseArithOpPattern : public OpConversionPattern<Op> {
+struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
   using OpConversionPattern<Op>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() <= 3);
-    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
     if (!dstType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -189,7 +190,6 @@ struct ElementwiseArithOpPattern : public OpConversionPattern<Op> {
                           "unsigned op pattern version");
     }
 
-    auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
     auto overflowFlags = arith::IntegerOverflowFlags::none;
     if (auto overflowIface =
             dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {



More information about the Mlir-commits mailing list