[Mlir-commits] [mlir] cf60d3f - [mlir][arith] Extend the `floordivsi` converter

Andrzej Warzynski llvmlistbot at llvm.org
Fri Mar 24 01:49:26 PDT 2023


Author: Andrzej Warzynski
Date: 2023-03-24T08:48:55Z
New Revision: cf60d3f1a688671c8eb7859bf0572c403c3c0cca

URL: https://github.com/llvm/llvm-project/commit/cf60d3f1a688671c8eb7859bf0572c403c3c0cca
DIFF: https://github.com/llvm/llvm-project/commit/cf60d3f1a688671c8eb7859bf0572c403c3c0cca.diff

LOG: [mlir][arith] Extend the `floordivsi` converter

This patch extends the `createConst` method so that it can generate
constant vectors (it can already generate scalars). This change is
required to be able to apply the converter for `arith.floordivsi`
(i.e. `FloorDivSIOpConverter`) to vectors.

While `arith.floordivsi` is my main motivation for this change, this
patch should also allow other Arith ops to be converted when they
operate on vectors. In my use case, the Linalg vectorizer rewrites
`arith.floordivsi` to operate on vectors, hence the need for this change.

Differential Revision: https://reviews.llvm.org/D146741

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arith/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index ee561e655965f..c5b80346bd52f 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -31,6 +31,7 @@ def ArithBufferize : Pass<"arith-bufferize", "ModuleOp"> {
 def ArithExpandOps : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let constructor = "mlir::arith::createArithExpandOpsPass()";
+  let dependentDialects = ["vector::VectorDialect"];
 }
 
 def ArithUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {

diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index b70110cbce913..8f34531937c5c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -24,8 +25,15 @@ using namespace mlir;
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
-  return rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(type, value));
+
+  auto elTy = getElementTypeOrSelf(type);
+  auto constantAttr = rewriter.getIntegerAttr(elTy, value);
+
+  if (auto vecTy = llvm::dyn_cast<ShapedType>(type))
+    return rewriter.create<arith::ConstantOp>(
+        loc, vecTy, DenseElementsAttr::get(vecTy, constantAttr));
+
+  return rewriter.create<arith::ConstantOp>(loc, constantAttr);
 }
 
 namespace {

diff  --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 3d55c2068f24e..7b7eb4003956a 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -114,6 +114,34 @@ func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
 
 // -----
 
+// Test floor divide with vector
+// CHECK-LABEL:   func.func @floordivi_vec(
+// CHECK-SAME:                             %[[VAL_0:.*]]: vector<4xi32>,
+// CHECK-SAME:                             %[[VAL_1:.*]]: vector<4xi32>) -> vector<4xi32> {
+func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
+  %res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
+  return %res : vector<4xi32>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32>
+// CHECK:           %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
+// CHECK:           %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32>
+// CHECK:           %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32>
+// CHECK:           %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
+// CHECK:           %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
+// CHECK:           %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
+// CHECK:           %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
+// CHECK:           %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
+// CHECK:           %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1>
+// CHECK:           %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1>
+// CHECK:           %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1>
+// CHECK:           %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32>
+}
+
+// -----
+
 // Test ceil divide with unsigned integer
 // CHECK-LABEL:       func @ceildivui
 // CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {


        


More information about the Mlir-commits mailing list