[Mlir-commits] [mlir] 864236d - [mlir][arith] Support wide integer constant emulation

Jakub Kuderski llvmlistbot at llvm.org
Thu Sep 8 21:07:48 PDT 2022


Author: Jakub Kuderski
Date: 2022-09-09T00:04:06-04:00
New Revision: 864236d1c1fa61d1cfa95e0d59effc59a96cb06d

URL: https://github.com/llvm/llvm-project/commit/864236d1c1fa61d1cfa95e0d59effc59a96cb06d
DIFF: https://github.com/llvm/llvm-project/commit/864236d1c1fa61d1cfa95e0d59effc59a96cb06d.diff

LOG: [mlir][arith] Support wide integer constant emulation

Reviewed By: antiagainst, Mogball

Differential Revision: https://reviews.llvm.org/D133136

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
    mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index 94e321ba1ad7..2d24513e2734 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
@@ -23,7 +24,94 @@ namespace mlir::arith {
 
 using namespace mlir;
 
+// Splits `value`, treated as a 2*N-bit integer where N = `newBitWidth`, into
+// its N low-order bits and its N high-order bits. The low bits are returned
+// in the first pair element and the high bits in the second one.
+// `value` must be at least 2*N bits wide (APInt::extractBits asserts this).
+static std::pair<APInt, APInt> getHalves(const APInt &value,
+                                         unsigned newBitWidth) {
+  APInt low = value.extractBits(newBitWidth, 0);
+  APInt high = value.extractBits(newBitWidth, newBitWidth);
+  return {std::move(low), std::move(high)};
+}
+
 namespace {
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+
+// Emulates a wide integer `arith.constant` by splitting each `i2N` value into
+// a [low, high] pair of `iN` values. The pairs are laid out along the
+// innermost dimension of the converted type, e.g., `i64 -> vector<2xi32>` and
+// `vector<3xi64> -> vector<3x2xi32>`.
+struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type oldType = op.getType();
+    // The type converter yields a null type for types it cannot emulate. Fail
+    // the match in that case instead of asserting inside `cast`.
+    auto newType =
+        getTypeConverter()->convertType(oldType).dyn_cast_or_null<VectorType>();
+    if (!newType)
+      return rewriter.notifyMatchFailure(op.getLoc(), "unsupported type");
+
+    unsigned newBitWidth = newType.getElementTypeBitWidth();
+    Attribute oldValue = op.getValueAttr();
+
+    // Scalar constant: becomes a single [low, high] pair.
+    if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
+      auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
+      auto newAttr = DenseElementsAttr::get(newType, {low, high});
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+      return success();
+    }
+
+    // Splat vector constant: split the splat value once and repeat the pair.
+    // Checked before the general dense case so the halves are computed once.
+    if (auto splatAttr = oldValue.dyn_cast<SplatElementsAttr>()) {
+      auto [low, high] =
+          getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
+      int64_t numSplatElems = splatAttr.getNumElements();
+      SmallVector<APInt> values;
+      values.reserve(numSplatElems * 2);
+      for (int64_t i = 0; i < numSplatElems; ++i) {
+        values.push_back(low);
+        values.push_back(high);
+      }
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    // Remaining dense vector constants: split each element in iteration order.
+    if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
+      int64_t numElems = elemsAttr.getNumElements();
+      SmallVector<APInt> values;
+      values.reserve(numElems * 2);
+      for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
+        auto [low, high] = getHalves(origVal, newBitWidth);
+        values.push_back(std::move(low));
+        values.push_back(std::move(high));
+      }
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "unhandled constant attribute");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
 struct EmulateWideIntPass final
     : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
   using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
@@ -42,7 +116,11 @@ struct EmulateWideIntPass final
     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
       return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
     });
-    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+    target.addDynamicallyLegalOp<
+        // `func.*` ops
+        func::CallOp, func::ReturnOp,
+        // `arith.*` ops
+        arith::ConstantOp>(
         [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
 
     RewritePatternSet patterns(ctx);
@@ -54,6 +132,10 @@ struct EmulateWideIntPass final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
 arith::WideIntEmulationConverter::WideIntEmulationConverter(
     unsigned widestIntSupportedByTarget)
     : maxIntWidth(widestIntSupportedByTarget) {
@@ -117,4 +199,7 @@ void arith::populateWideIntEmulationPatterns(
                                                                  typeConverter);
   populateCallOpTypeConversionPattern(patterns, typeConverter);
   populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+  // Populate `arith.*` conversion patterns.
+  patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index aafeb5b3a1cd..3a1a6c7b2d1f 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -49,3 +49,29 @@ func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
     %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
     return %res : vector<4xi64>
 }
+
+// CHECK-LABEL: func @constant_scalar
+// CHECK-SAME:     () -> vector<2xi32>
+// CHECK-NEXT:     [[C0:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:     [[C1:%.+]] = arith.constant dense<[0, 1]> : vector<2xi32>
+// CHECK-NEXT:     [[C2:%.+]] = arith.constant dense<[-7, -1]> : vector<2xi32>
+// CHECK-NEXT:     return [[C0]] : vector<2xi32>
+func.func @constant_scalar() -> i64 { // i64 -> vector<2xi32> [low, high]
+    %c0 = arith.constant 0 : i64 // low = 0, high = 0
+    %c1 = arith.constant 4294967296 : i64 // 2^32: low = 0, high = 1
+    %c2 = arith.constant -7 : i64 // low = -7, high = -1 (sign bits)
+    return %c0 : i64
+}
+
+// CHECK-LABEL: func @constant_vector
+// CHECK-SAME:     () -> vector<3x2xi32>
+// CHECK-NEXT:     [[C0:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}:                             <[[0, 1], [0, 1], [0, 1]]> : vector<3x2xi32>
+// CHECK-NEXT:     [[C1:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}:                             <[[0, 0], [1, 0], [-2, -1]]> : vector<3x2xi32>
+// CHECK-NEXT:     return [[C0]] : vector<3x2xi32>
+func.func @constant_vector() -> vector<3xi64> { // one [low, high] row per element
+    %c0 = arith.constant dense<4294967296> : vector<3xi64> // splat 2^32 -> [0, 1]
+    %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64> // non-splat; -2 -> [-2, -1]
+    return %c0 : vector<3xi64>
+}


        


More information about the Mlir-commits mailing list