Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions src/api/c/transform_coordinates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,40 @@ Array<T> multiplyIndexed(const Array<T> &lhs, const Array<T> &rhs, std::vector<a
}

template<typename T>
static af_array transform_coordinates(const af_array& tf, const float d0, const float d1)
static af_array transform_coordinates(const af_array& tf_, const float d0_, const float d1_)
{
af::dim4 h_dims(4, 3);
T h_in[4*3] = { (T)0, (T)0, (T)d1, (T)d1,
(T)0, (T)d0, (T)d0, (T)0,
T h_in[4*3] = { (T)0, (T)0, (T)d1_, (T)d1_,
(T)0, (T)d0_, (T)d0_, (T)0,
(T)1, (T)1, (T)1, (T)1 };

const Array<T> TF = getArray<T>(tf);
Array<T> IN = createHostDataArray<T>(h_dims, h_in);
const Array<T> tf = getArray<T>(tf_);
Array<T> in = createHostDataArray<T>(h_dims, h_in);

std::vector<af_seq> idx(2);
idx[0] = af_make_seq(0, 2, 1);

// w = 1.0 / matmul(TF, IN(span, 2));
// iw = matmul(TF, IN(span, 2));
// w = 1.0 / matmul(tf, in(span, 2));
// iw = matmul(tf, in(span, 2));
idx[1] = af_make_seq(2, 2, 1);
Array<T> IW = multiplyIndexed(IN, TF, idx);
Array<T> iw = multiplyIndexed(in, tf, idx);

// xt = w * matmul(TF, IN(span, 0));
// xt = matmul(TF, IN(span, 0)) / iw;
// xt = w * matmul(tf, in(span, 0));
// xt = matmul(tf, in(span, 0)) / iw;
idx[1] = af_make_seq(0, 0, 1);
Array<T> XT = arithOp<T, af_div_t>(multiplyIndexed(IN, TF, idx), IW, IW.dims());
Array<T> xt = arithOp<T, af_div_t>(multiplyIndexed(in, tf, idx), iw, iw.dims());

// yt = w * matmul(TF, IN(span, 1));
// yt = matmul(TF, IN(span, 1)) / iw;
// yt = w * matmul(tf, in(span, 1));
// yt = matmul(tf, in(span, 1)) / iw;
idx[1] = af_make_seq(1, 1, 1);
Array<T> YT = arithOp<T, af_div_t>(multiplyIndexed(IN, TF, idx), IW, IW.dims());
Array<T> yw = arithOp<T, af_div_t>(multiplyIndexed(in, tf, idx), iw, iw.dims());

// return join(1, xt, yt)
Array<T> R = join(1, XT, YT);
return getHandle(R);
Array<T> r = join(1, xt, yw);
return getHandle(r);
}

af_err af_transform_coordinates(af_array *out, const af_array tf, const float d0, const float d1)
af_err af_transform_coordinates(af_array *out, const af_array tf, const float d0_, const float d1_)
{
try {
const ArrayInfo& tfInfo = getInfo(tf);
Expand All @@ -72,8 +72,8 @@ af_err af_transform_coordinates(af_array *out, const af_array tf, const float d0
af_array output;
af_dtype type = tfInfo.getType();
switch(type) {
case f32: output = transform_coordinates<float >(tf, d0, d1); break;
case f64: output = transform_coordinates<double>(tf, d0, d1); break;
case f32: output = transform_coordinates<float >(tf, d0_, d1_); break;
case f64: output = transform_coordinates<double>(tf, d0_, d1_); break;
default : TYPE_ERROR(1, type);
}
std::swap(*out, output);
Expand Down