-
Notifications
You must be signed in to change notification settings - Fork 550
Expand file tree
/
Copy pathclamp.cpp
More file actions
83 lines (75 loc) · 2.92 KB
/
clamp.cpp
File metadata and controls
83 lines (75 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <arith.hpp>
#include <backend.hpp>
#include <common/ArrayInfo.hpp>
#include <common/err_common.hpp>
#include <common/half.hpp>
#include <handle.hpp>
#include <implicit.hpp>
#include <logic.hpp>
#include <optypes.hpp>
#include <af/arith.h>
#include <af/array.h>
#include <af/data.h>
#include <af/defines.h>
using af::dim4;
using arrayfire::common::half;
using detail::arithOp;
using detail::Array;
using detail::cdouble;
using detail::cfloat;
using detail::intl;
using detail::schar;
using detail::uchar;
using detail::uint;
using detail::uintl;
using detail::ushort;
/// Clamp `in` element-wise using the bounds `lo` and `hi`, after converting
/// all three inputs to the common element type T.
///
/// Computed as arithOp<af_min_t>(arithOp<af_max_t>(in, lo), hi): first
/// combined with `lo` via af_max_t, then with `hi` via af_min_t.
///
/// @tparam T    element type every operand is cast to before computing
/// @param in    handle of the array being clamped
/// @param lo    handle of the lower bounds
/// @param hi    handle of the upper bounds
/// @param odims output dimensions, passed to both arithOp calls
/// @return a newly created handle holding the clamped values
template<typename T>
static inline af_array clampOp(const af_array in, const af_array lo,
                               const af_array hi, const dim4& odims) {
    const Array<T> lower = castArray<T>(lo);
    const Array<T> upper = castArray<T>(hi);
    const Array<T> input = castArray<T>(in);

    // Raise values below the lower bound ...
    const Array<T> atLeastLower = arithOp<T, af_max_t>(input, lower, odims);
    // ... then cap values above the upper bound, and wrap the result.
    return getHandle(arithOp<T, af_min_t>(atLeastLower, upper, odims));
}
/// C API entry point: clamp each element of `in` using the bounds `lo` and
/// `hi`.
///
/// `lo` and `hi` must agree in both dims and type. The computation runs in
/// the type given by implicit(type of `in`, type of `lo`). The result dims
/// come from getOutDims(dims of `in`, dims of `lo`, batch).
///
/// @param[out] out   receives the new result handle; written only after all
///                   checks and the computation have succeeded
/// @param[in]  in    array to clamp
/// @param[in]  lo    lower bounds
/// @param[in]  hi    upper bounds
/// @param[in]  batch forwarded unchanged to getOutDims
/// @return AF_SUCCESS when the try block completes; anything thrown inside
///         it is handled by CATCHALL (presumably converted to an af_err —
///         confirm in err_common.hpp)
af_err af_clamp(af_array* out, const af_array in, const af_array lo,
                const af_array hi, const bool batch) {
    try {
        // The names `linfo`/`hinfo` are kept on purpose: the assertion
        // macros below may stringify their condition into the error text.
        const ArrayInfo& linfo     = getInfo(lo);
        const ArrayInfo& hinfo     = getInfo(hi);
        const ArrayInfo& inputInfo = getInfo(in);

        DIM_ASSERT(2, linfo.dims() == hinfo.dims());
        TYPE_ASSERT(linfo.getType() == hinfo.getType());

        const dim4 outDims = getOutDims(inputInfo.dims(), linfo.dims(), batch);
        const af_dtype outType =
            implicit(inputInfo.getType(), linfo.getType());

        // Select the clampOp instantiation matching the promoted type; it is
        // invoked exactly once after the switch.
        af_array (*clampFn)(const af_array, const af_array, const af_array,
                            const dim4&) = nullptr;
        switch (outType) {
            case f32: clampFn = clampOp<float>; break;
            case f64: clampFn = clampOp<double>; break;
            case c32: clampFn = clampOp<cfloat>; break;
            case c64: clampFn = clampOp<cdouble>; break;
            case s32: clampFn = clampOp<int>; break;
            case u32: clampFn = clampOp<uint>; break;
            case s8: clampFn = clampOp<schar>; break;
            case u8: clampFn = clampOp<uchar>; break;
            case b8: clampFn = clampOp<char>; break;  // b8 is handled as char
            case s64: clampFn = clampOp<intl>; break;
            case u64: clampFn = clampOp<uintl>; break;
            case s16: clampFn = clampOp<short>; break;
            case u16: clampFn = clampOp<ushort>; break;
            case f16: clampFn = clampOp<half>; break;
            default: TYPE_ERROR(0, outType);
        }

        af_array result = clampFn(in, lo, hi, outDims);
        std::swap(*out, result);
    }
    CATCHALL;
    return AF_SUCCESS;
}