// Source path: root/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
// Blob: fb29c315bea25163ff44b6efdd56e6f9879cc13f
// RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s

// Encoding alias shared by every function in this file: a single level, typed
// "compressed". It is attached to the tensor<1024xf32, ...> types below.
#SparseVector = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>

// CHECK-LABEL:   func.func @for(
// CHECK-SAME:                   %[[VAL_1:.*0]]: memref<?xindex>,
// CHECK-SAME:                   %[[VAL_2:.*1]]: memref<?xindex>,
// CHECK-SAME:                   %[[VAL_3:.*2]]: memref<?xf32>,
// CHECK-SAME:                   %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                   %[[VAL_5:.*4]]: index,
// CHECK-SAME:                   %[[VAL_6:.*5]]: index,
// CHECK-SAME:                   %[[VAL_7:.*6]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_8:.*]]:4 = scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args(
// CHECK-SAME:        %[[VAL_11:.*]] = %[[VAL_1]],
// CHECK-SAME:        %[[VAL_12:.*]] = %[[VAL_2]],
// CHECK-SAME:        %[[VAL_13:.*]] = %[[VAL_3]],
// CHECK-SAME:        %[[VAL_14:.*]] = %[[VAL_4]])
// CHECK:             scf.yield %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] :
// CHECK:           }
// CHECK:           return %[[VAL_8]]#0, %[[VAL_8]]#1, %[[VAL_8]]#2, %[[VAL_8]]#3
// Test input for scf.for: the sparse tensor argument is carried through the
// loop as its only iter_arg and handed back unchanged. The expectations above
// show the single tensor becoming four values (three memrefs plus a storage
// specifier) in the signature, the iter_args, the yield and the return.
func.func @for(
    %in: tensor<1024xf32, #SparseVector>,
    %lb: index,
    %ub: index,
    %step: index) -> tensor<1024xf32, #SparseVector> {
  // The body does no work; it only forwards the loop-carried tensor.
  %result = scf.for %iv = %lb to %ub step %step
      iter_args(%carried = %in) -> tensor<1024xf32, #SparseVector> {
    scf.yield %carried : tensor<1024xf32, #SparseVector>
  }
  return %result : tensor<1024xf32, #SparseVector>
}

// CHECK-LABEL:   func.func @if(
// CHECK-SAME:                  %[[VAL_1:.*0]]: memref<?xindex>,
// CHECK-SAME:                  %[[VAL_2:.*1]]: memref<?xindex>,
// CHECK-SAME:                  %[[VAL_3:.*2]]: memref<?xf32>,
// CHECK-SAME:                  %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                  %[[VAL_6:.*4]]: memref<?xindex>,
// CHECK-SAME:                  %[[VAL_7:.*5]]: memref<?xindex>,
// CHECK-SAME:                  %[[VAL_8:.*6]]: memref<?xf32>,
// CHECK-SAME:                  %[[VAL_9:.*7]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                  %[[VAL_10:.*]]: i1)
// CHECK:           %[[VAL_11:.*]]:4 = scf.if %[[VAL_10]]
// CHECK:             scf.yield %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]]
// CHECK:           } else {
// CHECK:             scf.yield %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]]
// CHECK:           }
// CHECK:           return %[[VAL_11]]#0, %[[VAL_11]]#1, %[[VAL_11]]#2, %[[VAL_11]]#3 :
// CHECK-SAME:        memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
// Test input for scf.if: selects one of two sparse tensors based on %c.
// The expectations above show each tensor argument expanded into four values,
// with both branches yielding the four values of the tensor they picked.
func.func @if(
    %t: tensor<1024xf32, #SparseVector>,
    %f: tensor<1024xf32, #SparseVector>,
    %c: i1) -> tensor<1024xf32, #SparseVector> {
  // Then-branch forwards %t, else-branch forwards %f.
  %selected = scf.if %c -> tensor<1024xf32, #SparseVector> {
    scf.yield %t : tensor<1024xf32, #SparseVector>
  } else {
    scf.yield %f : tensor<1024xf32, #SparseVector>
  }
  return %selected : tensor<1024xf32, #SparseVector>
}


// CHECK-LABEL:   func.func @while(
// CHECK-SAME:                     %[[VAL_1:.*0]]: memref<?xindex>,
// CHECK-SAME:                     %[[VAL_2:.*1]]: memref<?xindex>,
// CHECK-SAME:                     %[[VAL_3:.*2]]: memref<?xf32>,
// CHECK-SAME:                     %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                     %[[VAL_5:.*4]]: i1)
// CHECK:           %[[VAL_6:.*]]:4 = scf.while (
// CHECK-SAME:        %[[VAL_8:.*]] = %[[VAL_1]],
// CHECK-SAME:        %[[VAL_9:.*]] = %[[VAL_2]],
// CHECK-SAME:        %[[VAL_10:.*]] = %[[VAL_3]],
// CHECK-SAME:        %[[VAL_11:.*]] = %[[VAL_4]])
// CHECK:             scf.condition(%[[VAL_5]]) %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]]
// CHECK:           } do {
// CHECK:           ^bb0(%[[VAL_13:.*5]]: memref<?xindex>,
// CHECK-SAME:           %[[VAL_14:.*6]]: memref<?xindex>,
// CHECK-SAME:           %[[VAL_15:.*7]]: memref<?xf32>,
// CHECK-SAME:           %[[VAL_16:.*8]]: !sparse_tensor.storage_specifier
// CHECK:             scf.yield %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]]
// CHECK:           }
// CHECK:           return %[[VAL_6]]#0, %[[VAL_6]]#1, %[[VAL_6]]#2, %[[VAL_6]]#3 :
// CHECK-SAME:        memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
// Test input for scf.while: a sparse tensor is the only loop-carried value.
// The "before" region passes it to scf.condition guarded by %c, and the "do"
// region yields its block argument straight back. The expectations above show
// every occurrence of the tensor replaced by four values.
func.func @while(
    %arg0: tensor<1024xf32, #SparseVector>,
    %c: i1) -> tensor<1024xf32, #SparseVector> {
  %final = scf.while (%state = %arg0)
      : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
    // Continue while %c holds, forwarding the carried tensor.
    scf.condition(%c) %state : tensor<1024xf32, #SparseVector>
  } do {
  ^bb0(%next: tensor<1024xf32, #SparseVector>):
    scf.yield %next : tensor<1024xf32, #SparseVector>
  }
  return %final : tensor<1024xf32, #SparseVector>
}