Skip to content

Commit

Permalink
Add test case that accumulates a stream as a memref using memref.real…
Browse files Browse the repository at this point in the history
…loc.
  • Loading branch information
ingomueller-net committed Apr 11, 2023
1 parent d49ed91 commit db91196
Showing 1 changed file with 53 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
// RUN: -convert-iterators-to-llvm \
// RUN: -decompose-iterator-states \
// RUN: -decompose-tuples \
// RUN: -convert-tabular-to-llvm \
// RUN: -inline -canonicalize \
// RUN: -arith-bufferize \
// RUN: -expand-strided-metadata \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-func-to-llvm \
// RUN: -convert-scf-to-cf -convert-cf-to-llvm \
// RUN: -convert-scf-to-cf \
// RUN: -convert-cf-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: | FileCheck %s

Expand Down Expand Up @@ -72,8 +79,53 @@ func.func @test_accumulate_avg_tuple() {
return
}

!memref_i32 = memref<?xi32>

func.func private @accumulate_realloc(
%acc : !memref_i32, %val : tuple<i32>) -> !memref_i32 {
%zero = arith.constant 0 : index
%one = arith.constant 1 : index
%dim = memref.dim %acc, %zero : !memref_i32
%new_dim = arith.addi %one, %dim : index
%realloced = memref.realloc %acc (%new_dim) : !memref_i32 to !memref_i32
%vali = tuple.to_elements %val : tuple<i32>
memref.store %vali, %realloced[%dim] : !memref_i32
return %realloced : !memref_i32
}

// CHECK-LABEL: test_accumulate_realloc
// CHECK-NEXT: (9)
// CHECK-NEXT: (8)
// CHECK-NEXT: (7)
// CHECK-NEXT: -
func.func @test_accumulate_realloc() {
iterators.print("test_accumulate_realloc")
%tensor = arith.constant dense<[9, 8, 7]> : tensor<3xi32>
%memref = bufferization.to_memref %tensor : memref<3xi32>
%view = "tabular.view_as_tabular"(%memref)
: (memref<3xi32>) -> !tabular.tabular_view<i32>
%stream = iterators.tabular_view_to_stream %view
to !iterators.stream<tuple<i32>>
%zero = arith.constant 0 : index
%alloced = memref.alloc (%zero) : !memref_i32
%accumulated = iterators.accumulate(%stream, %alloced)
with @accumulate_realloc
: (!iterators.stream<tuple<i32>>) -> !iterators.stream<!memref_i32>
%result:2 = iterators.stream_to_value %accumulated :
!iterators.stream<!memref_i32>
scf.if %result#1 {
%result_view = "tabular.view_as_tabular"(%result#0)
: (memref<?xi32>) -> !tabular.tabular_view<i32>
%result_stream = iterators.tabular_view_to_stream %result_view
to !iterators.stream<tuple<i32>>
"iterators.sink"(%result_stream) : (!iterators.stream<tuple<i32>>) -> ()
}
return
}

func.func @main() {
call @test_accumulate_sum_tuple() : () -> ()
call @test_accumulate_avg_tuple() : () -> ()
call @test_accumulate_realloc() : () -> ()
return
}

0 comments on commit db91196

Please sign in to comment.