Skip to content

Commit

Permalink
Update API to reflect ArrayFire 3.7.0 release
Browse files Browse the repository at this point in the history
  • Loading branch information
9prady9 committed Mar 14, 2020
1 parent 0557ab4 commit 32bc2f8
Show file tree
Hide file tree
Showing 19 changed files with 1,468 additions and 14 deletions.
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,25 @@ indexing = []
graphics = []
image = []
lapack = []
machine_learning = []
macros = []
random = []
signal = []
sparse = []
statistics = []
vision = []
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
"macros", "random", "signal", "sparse", "statistics", "vision"]
"machine_learning", "macros", "random", "signal", "sparse", "statistics", "vision"]

[dependencies]
libc = "0.2"
num = "0.2"
lazy_static = "1.0"
half = "1.5.0"

[dev-dependencies]
float-cmp = "0.6.0"
half = "1.5.0"

[build-dependencies]
serde_json = "1.0"
Expand Down Expand Up @@ -85,3 +88,7 @@ path = "examples/conway.rs"
[[example]]
name = "fft"
path = "examples/fft.rs"

[[example]]
name = "using_half"
path = "examples/using_half.rs"
2 changes: 1 addition & 1 deletion examples/conway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn main() {
}

fn normalise(a: &Array<f32>) -> Array<f32> {
(a / (max_all(&abs(a)).0 as f32))
a / (max_all(&abs(a)).0 as f32)
}

fn conways_game_of_life() {
Expand Down
15 changes: 15 additions & 0 deletions examples/using_half.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use arrayfire::*;
use half::f16;

fn main() {
set_device(0);
info();

let values: Vec<_> = (1u8..101).map(f32::from).collect();

let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();

let hvals = Array::new(&half_values, Dim4::new(&[10, 10, 1, 1]));

print(&hvals);
}
260 changes: 258 additions & 2 deletions src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::array::Array;
use crate::defines::{AfError, BinaryOp};
use crate::error::HANDLE_ERROR;
use crate::util::{AfArray, MutAfArray, MutDouble, MutUint};
use crate::util::{HasAfEnum, RealNumber, Scanable};
use crate::util::{HasAfEnum, RealNumber, ReduceByKeyInput, Scanable};

#[allow(dead_code)]
extern "C" {
Expand Down Expand Up @@ -59,6 +59,71 @@ extern "C" {
op: c_uint,
inclusive: c_int,
) -> c_int;
fn af_all_true_by_key(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
) -> c_int;
fn af_any_true_by_key(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
) -> c_int;
fn af_count_by_key(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
) -> c_int;
fn af_max_by_key(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
) -> c_int;
fn af_min_by_key(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
) -> c_int;
fn af_product_by_key(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
) -> c_int;
fn af_product_by_key_nan(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
nan_val: c_double,
) -> c_int;
fn af_sum_by_key(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
) -> c_int;
fn af_sum_by_key_nan(
keys_out: MutAfArray,
vals_out: MutAfArray,
keys: AfArray,
vals: AfArray,
dim: c_int,
nan_val: c_double,
) -> c_int;
}

macro_rules! dim_reduce_func_def {
Expand Down Expand Up @@ -527,7 +592,8 @@ all_reduce_func_def!(
let dims = Dim4::new(&[5, 5, 1, 1]);
let a = randu::<f32>(dims);
print(&a);
println!(\"Result : {:?}\", product_all(&a));
let res = product_all(&a);
println!(\"Result : {:?}\", res);
```
",
product_all,
Expand Down Expand Up @@ -1137,3 +1203,193 @@ where
}
temp.into()
}

macro_rules! dim_reduce_by_key_func_def {
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
#[doc=$brief_str]
/// # Parameters
///
/// - `keys` - key Array
/// - `vals` - value Array
/// - `dim` - Dimension along which the input Array is reduced
///
/// # Return Values
///
/// Tuple of Arrays, with output keys and values after reduction
///
#[doc=$ex_str]
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
dim: i32
) -> (Array<KeyType>, Array<$out_type>)
where
KeyType: ReduceByKeyInput,
ValueType: HasAfEnum,
$out_type: HasAfEnum,
{
let mut out_keys: i64 = 0;
let mut out_vals: i64 = 0;
unsafe {
let err_val = $ffi_name(
&mut out_keys as MutAfArray,
&mut out_vals as MutAfArray,
keys.get() as AfArray,
vals.get() as AfArray,
dim as c_int,
);
HANDLE_ERROR(AfError::from(err_val));
}
(out_keys.into(), out_vals.into())
}
};
}

dim_reduce_by_key_func_def!(
"
Key based AND of elements along a given dimension
All positive non-zero values are considered true, while negative and zero
values are considered as false.
",
"
# Examples
```rust
use arrayfire::{Dim4, print, randu, all_true_by_key};
let dims = Dim4::new(&[5, 3, 1, 1]);
let vals = randu::<f32>(dims);
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
print(&vals);
print(&keys);
let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
print(&out_keys);
print(&out_vals);
```
",
all_true_by_key,
af_all_true_by_key,
ValueType::AggregateOutType
);

dim_reduce_by_key_func_def!(
"
Key based OR of elements along a given dimension
All positive non-zero values are considered true, while negative and zero
values are considered as false.
",
"
# Examples
```rust
use arrayfire::{Dim4, print, randu, any_true_by_key};
let dims = Dim4::new(&[5, 3, 1, 1]);
let vals = randu::<f32>(dims);
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
print(&vals);
print(&keys);
let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
print(&out_keys);
print(&out_vals);
```
",
any_true_by_key,
af_any_true_by_key,
ValueType::AggregateOutType
);

dim_reduce_by_key_func_def!(
"Find total count of elements with similar keys along a given dimension",
"",
count_by_key,
af_count_by_key,
ValueType::AggregateOutType
);

dim_reduce_by_key_func_def!(
"Find maximum among values of similar keys along a given dimension",
"",
max_by_key,
af_max_by_key,
ValueType::AggregateOutType
);

dim_reduce_by_key_func_def!(
"Find minimum among values of similar keys along a given dimension",
"",
min_by_key,
af_min_by_key,
ValueType::AggregateOutType
);

dim_reduce_by_key_func_def!(
"Find product of all values with similar keys along a given dimension",
"",
product_by_key,
af_product_by_key,
ValueType::ProductOutType
);

dim_reduce_by_key_func_def!(
"Find sum of all values with similar keys along a given dimension",
"",
sum_by_key,
af_sum_by_key,
ValueType::AggregateOutType
);

macro_rules! dim_reduce_by_key_nan_func_def {
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
#[doc=$brief_str]
///
/// This version of sum by key can replaced all NaN values in the input
/// with a user provided value before performing the reduction operation.
/// # Parameters
///
/// - `keys` - key Array
/// - `vals` - value Array
/// - `dim` - Dimension along which the input Array is reduced
///
/// # Return Values
///
/// Tuple of Arrays, with output keys and values after reduction
///
#[doc=$ex_str]
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
dim: i32, replace_value: f64
) -> (Array<KeyType>, Array<$out_type>)
where
KeyType: ReduceByKeyInput,
ValueType: HasAfEnum,
$out_type: HasAfEnum,
{
let mut out_keys: i64 = 0;
let mut out_vals: i64 = 0;
unsafe {
let err_val = $ffi_name(
&mut out_keys as MutAfArray,
&mut out_vals as MutAfArray,
keys.get() as AfArray,
vals.get() as AfArray,
dim as c_int,
replace_value as c_double,
);
HANDLE_ERROR(AfError::from(err_val));
}
(out_keys.into(), out_vals.into())
}
};
}

dim_reduce_by_key_nan_func_def!(
"Compute sum of all values with similar keys along a given dimension",
"",
sum_by_key_nan,
af_sum_by_key_nan,
ValueType::AggregateOutType
);

dim_reduce_by_key_nan_func_def!(
"Compute product of all values with similar keys along a given dimension",
"",
product_by_key_nan,
af_product_by_key_nan,
ValueType::ProductOutType
);
7 changes: 7 additions & 0 deletions src/arith/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ extern "C" {
fn af_log10(out: MutAfArray, arr: AfArray) -> c_int;
fn af_log2(out: MutAfArray, arr: AfArray) -> c_int;
fn af_sqrt(out: MutAfArray, arr: AfArray) -> c_int;
fn af_rsqrt(out: MutAfArray, arr: AfArray) -> c_int;
fn af_cbrt(out: MutAfArray, arr: AfArray) -> c_int;
fn af_factorial(out: MutAfArray, arr: AfArray) -> c_int;
fn af_tgamma(out: MutAfArray, arr: AfArray) -> c_int;
Expand Down Expand Up @@ -199,6 +200,12 @@ unary_func!("Compute the natural logarithm", log, af_log, UnaryOutType);
unary_func!("Compute sin", sin, af_sin, UnaryOutType);
unary_func!("Compute sinh", sinh, af_sinh, UnaryOutType);
unary_func!("Compute the square root", sqrt, af_sqrt, UnaryOutType);
unary_func!(
"Compute the reciprocal square root",
rsqrt,
af_rsqrt,
UnaryOutType
);
unary_func!("Compute tan", tan, af_tan, UnaryOutType);
unary_func!("Compute tanh", tanh, af_tanh, UnaryOutType);

Expand Down
Loading

0 comments on commit 32bc2f8

Please sign in to comment.