WIP: Optimize counter bitwidth in Foreach control #293

Open
wants to merge 3 commits into
base: master
Changes from 2 commits
28 changes: 28 additions & 0 deletions apps/src/TestForeachCounter.scala
@@ -0,0 +1,28 @@
import spatial.dsl._

@spatial object TestForeachCounter extends SpatialApp {

  def main(args: Array[String]): Unit = {
    // Loop upper bound
    val N = 128

    // The DRAM
    val d = DRAM[Int](N)

    // DRAM content
    val data = Array.fill[Int](N)(0)
    setMem(d, data)

    Accel {
      val s = SRAM[Int](N)

      s load d(0::N)

      Foreach(N by 1) { i => s(i) = s(i) + i }

      d(0::N) store s
    }

    printArray(getMem(d), "Result: ")
  }
}
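
As a cross-check for the test app above, here is a minimal plain-Scala reference model of what the Accel block computes. It is a sketch, not part of the PR: `ForeachCounterReference` and `expected` are hypothetical names, and it assumes the load, Foreach, and store behave as a simple element-wise update on a zero-filled array. The exact output format of `printArray` is not assumed.

```scala
object ForeachCounterReference {
  /** Software model of the Accel block: load, add the index to each element, store. */
  def expected(n: Int): Array[Int] = {
    val s = Array.fill(n)(0)              // mirrors the zero-filled DRAM contents
    for (i <- 0 until n) s(i) = s(i) + i  // mirrors Foreach(N by 1) { i => s(i) = s(i) + i }
    s
  }

  def main(args: Array[String]): Unit = {
    // With N = 128, element i should simply equal i.
    println(expected(128).mkString("Result: ", ", ", ""))
  }
}
```

Since every element starts at 0, a narrowed counter that wraps or truncates would show up as a result that differs from 0, 1, ..., N - 1.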
3 changes: 3 additions & 0 deletions src/spatial/Spatial.scala
@@ -109,6 +109,7 @@ trait Spatial extends Compiler with ParamLoader {
lazy val retiming = RetimingTransformer(state)
lazy val accumTransformer = AccumTransformer(state)
lazy val regReadCSE = RegReadCSE(state)
lazy val counterBitwidth = CounterBitwidthTransformer(state)

// --- Codegen
lazy val chiselCodegen = ChiselGen(state)
@@ -156,6 +157,8 @@ trait Spatial extends Compiler with ParamLoader {
/** Dead code elimination */
useAnalyzer ==>
transientCleanup ==> printer ==> transformerChecks ==>
// Counter bitwidth improvement
counterBitwidth ==> printer ==>
/** Stream controller rewrites */
(spatialConfig.distributeStreamCtr ? streamTransformer) ==> printer ==>
/** Memory analysis */
82 changes: 82 additions & 0 deletions src/spatial/transform/CounterBitwidthTransformer.scala
@@ -0,0 +1,82 @@
package spatial.transform

import argon._
import argon.node._
import argon.transform.MutateTransformer

import spatial.lang._
import spatial.node._
import spatial.util.shouldMotionFromConditional
import spatial.traversal.AccelTraversal
import spatial.metadata.control._
import spatial.metadata.memory._
import spatial.metadata.blackbox._

import utils.math.log2Up

import emul.FixedPoint

case class CounterBitwidthTransformer(IR: State) extends MutateTransformer
  with AccelTraversal {

  /** Calculate the least bitwidth required for an integer. */
  private def getBitwidth(x: Int): Int = log2Up(x.abs)

  /** Extract the content from Const and cast it to Int. */
  private def constToInt(x: Sym[_]): Int =
    if (x.isConst)
      x.c.get.asInstanceOf[FixedPoint].toInt
    else
      throw new Exception(s"$x is not a Const.")

  /** Create a new CounterNew object with compact bitwidth. */
  private def getOptimizedCounterNew(ctr: CounterNew[_]): CounterNew[_] = ctr match {
    case CounterNew(start, stop, step, par) =>
      // We take the larger magnitude of start and stop to decide the bitwidth bound
      val begin = constToInt(start)
      val end = constToInt(stop)
      val bitwidth = math.max(getBitwidth(begin), getBitwidth(end))

      // TODO: Find a better way to map a bitwidth to the exact Fix type
      if (bitwidth <= 7) {
Contributor Author

@mattfel1 Hi there, I feel it would be tedious to implement every bitwidth-to-type cast by hand, and I suspect there is a better way that I'm not aware of. Have you run into this scenario before, and do you have a good way to deal with it? Thanks!

Member

Unfortunately, I don't think there is a non-tedious way to do this.
Since each bitwidth is its own trait (argon/lang/types/CustomBitWidths.scala), it's painful to work with. Some people have used quasiquotes for this problem before, but there isn't a nice way that I know of.

Contributor Author

Thanks @mattfel1! In my latest update I manually added all the mappings for the different bitwidth values. I hope it looks fine.

        type T = Fix[TRUE,_8,_0]
        CounterNew[T](begin.to[T], end.to[T], constToInt(step).to[T], par)
      } else if (bitwidth <= 15) {
        type T = Fix[TRUE,_16,_0]
        CounterNew[T](begin.to[T], end.to[T], constToInt(step).to[T], par)
      } else {
        type T = Fix[TRUE,_32,_0]
        CounterNew[T](begin.to[T], end.to[T], constToInt(step).to[T], par)
      }
    case _ => ctr
  }

  /** Optimize a list of Counters. */
  private def getOptimizedCounters(ctrs: Seq[Counter[_]]): Seq[Counter[_]] = {
    ctrs.map {
      case Op(ctr: CounterNew[_]) => stage(getOptimizedCounterNew(ctr))
      // Leave any other kind of counter untouched instead of failing with a MatchError
      case other => other
    }
  }

  override def transform[A:Type](lhs: Sym[A], rhs: Op[A])(implicit ctx: SrcCtx): Sym[A] = rhs match {
    case AccelScope(_) =>
      inAccel { super.transform(lhs, rhs) }

    case OpForeach(ens, cchain, blk, iters, stopWhen) if inHw =>
      val newctrs = getOptimizedCounters(cchain.counters)
      val newcchain = stageWithFlow(CounterChainNew(newctrs)){ lhs2 => transferData(lhs, lhs2) }

      stageWithFlow(
        OpForeach(
          ens,
          newcchain,
          stageBlock{ blk.stms.foreach(visit) },
          iters,
          stopWhen)
      ){ lhs2 => transferData(lhs, lhs2) }

    case _ =>
      dbgs(s"visiting $lhs = $rhs")
      super.transform(lhs, rhs)
  }
}
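
For readers following the bitwidth thread, here is a minimal plain-Scala sketch of the width selection. It is not the PR's code and has no Spatial or Argon dependencies. It assumes counters are signed (as the `Fix[TRUE, ...]` types suggest) and that both the start and stop constants must be representable in the chosen type, since the transformer casts both with `.to[T]`. The names `signedBits`, `bucket`, and `counterWidth` are hypothetical. If `log2Up` is the usual ceiling of log2, then `getBitwidth(128)` in the PR gives 7 and lands in the 8-bit bucket, whereas the sketch asks for 9 bits for 128, so the two can disagree at power-of-two bounds such as the test app's `N = 128`.

```scala
object CounterWidthSketch {
  /** Smallest two's-complement width, in bits, that can hold `x`. */
  def signedBits(x: Int): Int =
    if (x >= 0) 32 - Integer.numberOfLeadingZeros(x) + 1  // magnitude bits plus a sign bit
    else 32 - Integer.numberOfLeadingZeros(~x) + 1        // e.g. -128 needs exactly 8 bits

  /** Round a required width up to one of the widths used in the PR (8, 16, 32). */
  def bucket(bits: Int): Int =
    if (bits <= 8) 8 else if (bits <= 16) 16 else 32

  /** Width for a counter whose start and stop constants must both be representable. */
  def counterWidth(start: Int, stop: Int): Int =
    bucket(math.max(signedBits(start), signedBits(stop)))

  def main(args: Array[String]): Unit = {
    println(counterWidth(0, 127))    // prints 8
    println(counterWidth(0, 128))    // prints 16, because 128 does not fit in signed 8 bits
    println(counterWidth(-128, 0))   // prints 8
    println(counterWidth(0, 40000))  // prints 32
  }
}
```

Whether the stop value really has to fit in the counter type depends on how the hardware comparison is generated, which this sketch does not model; treating it as required is the conservative choice.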