Skip to content

Commit

Permalink
feat: Add array repeat and scan ops (#1633)
Browse files Browse the repository at this point in the history
Closes #1627
  • Loading branch information
mark-koch authored Nov 21, 2024
1 parent 6bd094f commit 649589c
Show file tree
Hide file tree
Showing 4 changed files with 659 additions and 1 deletion.
1 change: 1 addition & 0 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ lazy_static! {
NoopDef.add_to_extension(&mut prelude).unwrap();
LiftDef.add_to_extension(&mut prelude).unwrap();
array::ArrayOpDef::load_all_ops(&mut prelude).unwrap();
array::ArrayScanDef.add_to_extension(&mut prelude).unwrap();
prelude
};
/// An extension registry containing only the prelude
Expand Down
293 changes: 292 additions & 1 deletion hugr-core/src/extension/prelude/array.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::str::FromStr;

use itertools::Itertools;
use strum_macros::EnumIter;
use strum_macros::EnumString;
use strum_macros::IntoStaticStr;
Expand All @@ -17,8 +20,10 @@ use crate::ops::ExtensionOp;
use crate::ops::NamedOp;
use crate::ops::OpName;
use crate::type_row;
use crate::types::FuncTypeBase;
use crate::types::FuncValueType;

use crate::types::RowVariable;
use crate::types::TypeBound;

use crate::types::Type;
Expand All @@ -28,6 +33,7 @@ use crate::extension::SignatureError;
use crate::types::PolyFuncTypeRV;

use crate::types::type_param::TypeArg;
use crate::types::TypeRV;
use crate::Extension;

use super::PRELUDE_ID;
Expand All @@ -46,6 +52,7 @@ pub enum ArrayOpDef {
pop_left,
pop_right,
discard_empty,
repeat,
}

/// Static parameters for array operations. Includes array size. Type is part of the type scheme.
Expand Down Expand Up @@ -118,6 +125,14 @@ impl ArrayOpDef {
let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()];

match self {
repeat => {
let func =
Type::new_function(FuncValueType::new(type_row![], elem_ty_var.clone()));
PolyFuncTypeRV::new(
standard_params,
FuncValueType::new(vec![func], array_ty.clone()),
)
}
get => {
let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()];
let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable);
Expand Down Expand Up @@ -179,6 +194,10 @@ impl MakeOpDef for ArrayOpDef {
fn description(&self) -> String {
match self {
ArrayOpDef::new_array => "Create a new array from elements",
ArrayOpDef::repeat => {
"Creates a new array whose elements are initialised by calling \
the given function n times"
}
ArrayOpDef::get => "Get an element from an array",
ArrayOpDef::set => "Set an element in an array",
ArrayOpDef::swap => "Swap two elements in an array",
Expand Down Expand Up @@ -246,7 +265,7 @@ impl MakeExtensionOp for ArrayOp {
);
vec![ty_arg]
}
new_array | pop_left | pop_right | get | set | swap => {
new_array | repeat | pop_left | pop_right | get | set | swap => {
vec![TypeArg::BoundedNat { n: self.size }, ty_arg]
}
}
Expand Down Expand Up @@ -312,6 +331,192 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
op.to_extension_op().unwrap()
}

/// Name of the operation for the combined map/fold operation
pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan");

/// Definition of the array scan op.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct ArrayScanDef;

impl NamedOp for ArrayScanDef {
fn name(&self) -> OpName {
ARRAY_SCAN_OP_ID
}
}

impl FromStr for ArrayScanDef {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s == ArrayScanDef.name() {
Ok(Self)
} else {
Err(())
}
}
}

impl ArrayScanDef {
/// To avoid recursion when defining the extension, take the type definition as an argument.
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
// array<N, T1>, (T1, *A -> T2, *A), -> array<N, T2>, *A
let params = vec![
TypeParam::max_nat(),
TypeBound::Any.into(),
TypeBound::Any.into(),
TypeParam::new_list(TypeBound::Any),
];
let n = TypeArg::new_var_use(0, TypeParam::max_nat());
let t1 = Type::new_var_use(1, TypeBound::Any);
let t2 = Type::new_var_use(2, TypeBound::Any);
let s = TypeRV::new_row_var_use(3, TypeBound::Any);
PolyFuncTypeRV::new(
params,
FuncTypeBase::<RowVariable>::new(
vec![
instantiate(array_def, n.clone(), t1.clone()).into(),
Type::new_function(FuncTypeBase::<RowVariable>::new(
vec![t1.into(), s.clone()],
vec![t2.clone().into(), s.clone()],
))
.into(),
s.clone(),
],
vec![instantiate(array_def, n, t2).into(), s],
),
)
.into()
}
}

impl MakeOpDef for ArrayScanDef {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
where
Self: Sized,
{
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension())
}

fn signature(&self) -> SignatureFunc {
self.signature_from_def(array_type_def())
}

fn extension(&self) -> ExtensionId {
PRELUDE_ID
}

fn description(&self) -> String {
"A combination of map and foldl. Applies a function to each element \
of the array with an accumulator that is passed through from start to \
finish. Returns the resulting array and the final state of the \
accumulator."
.into()
}

/// Add an operation implemented as a [MakeOpDef], which can provide the data
/// required to define an [OpDef], to an extension.
//
// This method is re-defined here since we need to pass the array type def while
// computing the signature, to avoid recursive loops initializing the extension.
fn add_to_extension(
&self,
extension: &mut Extension,
) -> Result<(), crate::extension::ExtensionBuildError> {
let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap());
let def = extension.add_op(self.name(), self.description(), sig)?;

self.post_opdef(def);

Ok(())
}
}

/// Definition of the array scan op.
#[derive(Clone, Debug, PartialEq)]
pub struct ArrayScan {
/// The element type of the input array.
src_ty: Type,
/// The target element type of the output array.
tgt_ty: Type,
/// The accumulator types.
acc_tys: Vec<Type>,
/// Size of the array.
size: u64,
}

impl ArrayScan {
fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec<Type>, size: u64) -> Self {
ArrayScan {
src_ty,
tgt_ty,
acc_tys,
size,
}
}
}

impl NamedOp for ArrayScan {
fn name(&self) -> OpName {
ARRAY_SCAN_OP_ID
}
}

impl MakeExtensionOp for ArrayScan {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized,
{
let def = ArrayScanDef::from_def(ext_op.def())?;
def.instantiate(ext_op.args())
}

fn type_args(&self) -> Vec<TypeArg> {
vec![
TypeArg::BoundedNat { n: self.size },
self.src_ty.clone().into(),
self.tgt_ty.clone().into(),
TypeArg::Sequence {
elems: self.acc_tys.clone().into_iter().map_into().collect(),
},
]
}
}

impl MakeRegisteredOp for ArrayScan {
fn extension_id(&self) -> ExtensionId {
PRELUDE_ID
}

fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry {
&PRELUDE_REGISTRY
}
}

impl HasDef for ArrayScan {
type Def = ArrayScanDef;
}

impl HasConcrete for ArrayScanDef {
type Concrete = ArrayScan;

fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
match type_args {
[TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] =>
{
let acc_tys: Result<_, OpLoadError> = acc_tys
.iter()
.map(|acc_ty| match acc_ty {
TypeArg::Type { ty } => Ok(ty.clone()),
_ => Err(SignatureError::InvalidTypeArgs.into()),
})
.collect();
Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n))
}
_ => Err(SignatureError::InvalidTypeArgs.into()),
}
}
}

#[cfg(test)]
mod tests {
use strum::IntoEnumIterator;
Expand All @@ -320,6 +525,7 @@ mod tests {
builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::{BOOL_T, QB_T},
ops::{OpTrait, OpType},
types::Signature,
};

use super::*;
Expand Down Expand Up @@ -459,4 +665,89 @@ mod tests {
)
);
}

#[test]
fn test_repeat() {
let size = 2;
let element_ty = QB_T;
let op = ArrayOpDef::repeat.to_concrete(element_ty.clone(), size);

let optype: OpType = op.into();

let sig = optype.dataflow_signature().unwrap();

assert_eq!(
sig.io(),
(
&vec![Type::new_function(Signature::new(vec![], vec![QB_T]))].into(),
&vec![array_type(size, element_ty.clone())].into(),
)
);
}

#[test]
fn test_scan_def() {
let op = ArrayScan::new(BOOL_T, QB_T, vec![USIZE_T], 2);
let optype: OpType = op.clone().into();
let new_op: ArrayScan = optype.cast().unwrap();
assert_eq!(new_op, op);
}

#[test]
fn test_scan_map() {
let size = 2;
let src_ty = QB_T;
let tgt_ty = BOOL_T;

let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size);
let optype: OpType = op.into();
let sig = optype.dataflow_signature().unwrap();

assert_eq!(
sig.io(),
(
&vec![
array_type(size, src_ty.clone()),
Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()]))
]
.into(),
&vec![array_type(size, tgt_ty)].into(),
)
);
}

#[test]
fn test_scan_accs() {
let size = 2;
let src_ty = QB_T;
let tgt_ty = BOOL_T;
let acc_ty1 = USIZE_T;
let acc_ty2 = QB_T;

let op = ArrayScan::new(
src_ty.clone(),
tgt_ty.clone(),
vec![acc_ty1.clone(), acc_ty2.clone()],
size,
);
let optype: OpType = op.into();
let sig = optype.dataflow_signature().unwrap();

assert_eq!(
sig.io(),
(
&vec![
array_type(size, src_ty.clone()),
Type::new_function(Signature::new(
vec![src_ty, acc_ty1.clone(), acc_ty2.clone()],
vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()]
)),
acc_ty1.clone(),
acc_ty2.clone()
]
.into(),
&vec![array_type(size, tgt_ty), acc_ty1, acc_ty2].into(),
)
);
}
}
Loading

0 comments on commit 649589c

Please sign in to comment.