-
Notifications
You must be signed in to change notification settings - Fork 550
Expand file tree
/
Copy pathdeconvolution.cpp
More file actions
336 lines (294 loc) · 12.2 KB
/
deconvolution.cpp
File metadata and controls
336 lines (294 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
/*******************************************************
* Copyright (c) 2018, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <Array.hpp>
#include <arith.hpp>
#include <backend.hpp>
#include <common/cast.hpp>
#include <common/dispatch.hpp>
#include <common/err_common.hpp>
#include <complex.hpp>
#include <copy.hpp>
#include <fft_common.hpp>
#include <fftconvolve.hpp>
#include <handle.hpp>
#include <logic.hpp>
#include <math.hpp>
#include <reduce.hpp>
#include <select.hpp>
#include <shift.hpp>
#include <unary.hpp>
#include <af/image.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <type_traits>
#include <vector>
using af::dim4;
using arrayfire::common::cast;
using detail::arithOp;
using detail::Array;
using detail::cdouble;
using detail::cfloat;
using detail::createSubArray;
using detail::createValueArray;
using detail::logicOp;
using detail::padArrayBorders;
using detail::scalar;
using detail::schar;
using detail::select_scalar;
using detail::shift;
using detail::uchar;
using detail::uint;
using detail::ushort;
using std::array;
using std::vector;
const int BASE_DIM = 2;
#if defined(AF_CPU)
// CPU backend uses FFTW or MKL
// FFTW works with any data size, but is optimized for
// size decomposition with prime factors up to
// 13.
const dim_t GREATEST_PRIME_FACTOR = 13;
#else
// cuFFT/clFFT works with any data size, but is optimized
// for size decomposition with prime factors up to
// 7.
const dim_t GREATEST_PRIME_FACTOR = 7;
#endif
template<typename T, typename CT>
Array<T> complexNorm(const Array<CT>& input) {
auto mag = detail::abs<T, CT>(input);
return arithOp<T, af_mul_t>(mag, mag, input.dims());
}
std::vector<af_seq> calcPadInfo(dim4& inLPad, dim4& psfLPad, dim4& inUPad,
dim4& psfUPad, dim4& odims, dim_t nElems,
const dim4& idims, const dim4& fdims) {
vector<af_seq> index(4);
for (int d = 0; d < 4; ++d) {
if (d < BASE_DIM) {
dim_t pad = idims[d] + fdims[d];
while (greatestPrimeFactor(pad) > GREATEST_PRIME_FACTOR) { pad++; }
dim_t diffLen = pad - idims[d];
inLPad[d] = diffLen / 2;
inUPad[d] = diffLen / 2 + diffLen % 2;
psfLPad[d] = 0;
psfUPad[d] = pad - fdims[d];
odims[d] = pad;
index[d].begin = inLPad[d];
index[d].end = index[d].begin + idims[d] - 1;
index[d].step = 1;
nElems *= odims[d];
} else {
inLPad[d] = 0;
psfLPad[d] = 0;
inUPad[d] = 0;
psfUPad[d] = 0;
odims[d] = std::max(idims[d], fdims[d]);
index[d] = af_span;
}
}
return index;
}
template<typename T, typename CT>
void richardsonLucy(Array<T>& currentEstimate, const Array<T>& in,
const Array<CT>& P, const Array<CT>& Pc,
const unsigned iters, const float normFactor,
const dim4 odims) {
for (unsigned i = 0; i < iters; ++i) {
auto fft1 = fft_r2c<CT, T>(currentEstimate, BASE_DIM);
auto cmul1 = arithOp<CT, af_mul_t>(fft1, P, P.dims());
auto ifft1 = fft_c2r<CT, T>(cmul1, normFactor, odims, BASE_DIM);
auto div1 = arithOp<T, af_div_t>(in, ifft1, in.dims());
auto fft2 = fft_r2c<CT, T>(div1, BASE_DIM);
auto cmul2 = arithOp<CT, af_mul_t>(fft2, Pc, Pc.dims());
auto ifft2 = fft_c2r<CT, T>(cmul2, normFactor, odims, BASE_DIM);
currentEstimate =
arithOp<T, af_mul_t>(currentEstimate, ifft2, ifft2.dims());
}
}
template<typename T, typename CT>
void landweber(Array<T>& currentEstimate, const Array<T>& in,
const Array<CT>& P, const Array<CT>& Pc, const unsigned iters,
const float relaxFactor, const float normFactor,
const dim4 odims) {
const dim4& dims = P.dims();
auto I = fft_r2c<CT, T>(in, BASE_DIM);
auto Pn = complexNorm<T, CT>(P);
auto ONE = createValueArray(dims, scalar<T>(1.0));
auto alpha = createValueArray(dims, scalar<T>(relaxFactor));
auto alphaC = cast<CT>(alpha);
auto prod = arithOp<T, af_mul_t>(alpha, Pn, dims);
auto lhsFac = arithOp<T, af_sub_t>(ONE, prod, dims);
auto lhs = cast<CT>(lhsFac);
auto rhsFac = arithOp<CT, af_mul_t>(Pc, I, dims);
auto rhs = arithOp<CT, af_mul_t>(rhsFac, alphaC, dims);
auto iterTemp = I;
for (unsigned i = 0; i < iters; ++i) {
auto mul = arithOp<CT, af_mul_t>(iterTemp, lhs, dims);
iterTemp = arithOp<CT, af_add_t>(mul, rhs, dims);
}
currentEstimate = fft_c2r<CT, T>(iterTemp, normFactor, odims, BASE_DIM);
}
template<typename InputType, typename RealType = float>
af_array iterDeconv(const af_array in, const af_array ker, const uint iters,
const float rfactor, const af_iterative_deconv_algo algo) {
using T = RealType;
using CT = typename std::conditional<std::is_same<T, double>::value,
cdouble, cfloat>::type;
auto input = castArray<T>(in);
auto psf = castArray<T>(ker);
const dim4& idims = input.dims();
const dim4& fdims = psf.dims();
dim_t nElems = 1;
dim4 inUPad, psfUPad, inLPad, psfLPad, odims(1);
auto index = calcPadInfo(inLPad, psfLPad, inUPad, psfUPad, odims, nElems,
idims, fdims);
auto paddedIn =
padArrayBorders<T>(input, inLPad, inUPad, AF_PAD_CLAMP_TO_EDGE);
auto paddedPsf = padArrayBorders<T>(psf, psfLPad, psfUPad, AF_PAD_ZERO);
const std::array<int, 4> shiftDims = {-int(fdims[0] / 2),
-int(fdims[1] / 2), 0, 0};
auto shiftedPsf = shift(paddedPsf, shiftDims.data());
auto P = fft_r2c<CT, T>(shiftedPsf, BASE_DIM);
auto Pc = conj(P);
Array<T> currentEstimate = paddedIn;
const double normFactor = 1 / static_cast<double>(nElems);
switch (algo) {
case AF_ITERATIVE_DECONV_RICHARDSONLUCY:
richardsonLucy(currentEstimate, paddedIn, P, Pc, iters, normFactor,
odims);
break;
case AF_ITERATIVE_DECONV_LANDWEBER:
default:
landweber(currentEstimate, paddedIn, P, Pc, iters, rfactor,
normFactor, odims);
}
return getHandle(createSubArray<T>(currentEstimate, index));
}
af_err af_iterative_deconv(af_array* out, const af_array in, const af_array ker,
const unsigned iterations, const float relax_factor,
const af_iterative_deconv_algo algo) {
try {
const ArrayInfo& inputInfo = getInfo(in);
const dim4& inputDims = inputInfo.dims();
const ArrayInfo& kernelInfo = getInfo(ker);
const dim4& kernelDims = kernelInfo.dims();
DIM_ASSERT(2, (inputDims.ndims() == 2));
DIM_ASSERT(3, (kernelDims.ndims() == 2));
ARG_ASSERT(4, (iterations > 0));
ARG_ASSERT(5, std::isfinite(relax_factor));
ARG_ASSERT(5, (relax_factor > 0));
ARG_ASSERT(6, (algo == AF_ITERATIVE_DECONV_DEFAULT ||
algo == AF_ITERATIVE_DECONV_LANDWEBER ||
algo == AF_ITERATIVE_DECONV_RICHARDSONLUCY));
af_array res = 0;
unsigned iters = iterations;
float rfac = relax_factor;
af_dtype inputType = inputInfo.getType();
switch (inputType) {
case f32:
res = iterDeconv<float>(in, ker, iters, rfac, algo);
break;
case s16:
res = iterDeconv<short>(in, ker, iters, rfac, algo);
break;
case u16:
res = iterDeconv<ushort>(in, ker, iters, rfac, algo);
break;
case s8: res = iterDeconv<schar>(in, ker, iters, rfac, algo); break;
case u8: res = iterDeconv<uchar>(in, ker, iters, rfac, algo); break;
default: TYPE_ERROR(1, inputType);
}
std::swap(res, *out);
}
CATCHALL;
return AF_SUCCESS;
}
template<typename CT>
Array<CT> denominator(const Array<CT>& I, const Array<CT>& P, const float gamma,
const af_inverse_deconv_algo algo) {
using T = typename af::dtype_traits<CT>::base_type;
auto RCNST = createValueArray(I.dims(), scalar<T>(gamma));
if (algo == AF_INVERSE_DECONV_TIKHONOV) {
auto normP = complexNorm<T, CT>(P);
auto denom = arithOp<T, af_add_t>(normP, RCNST, normP.dims());
return cast<CT, T>(denom);
} else {
// TODO(pradeep) Wiener Filter code path is disabled.
// This code path doesn't is not exposed using current API
auto normI = complexNorm<T, CT>(I);
auto sRes = arithOp<T, af_sub_t>(normI, RCNST, normI.dims());
auto dRes = arithOp<T, af_div_t>(RCNST, sRes, RCNST.dims());
auto normP = complexNorm<T, CT>(P);
auto denom = arithOp<T, af_add_t>(normP, dRes, normP.dims());
return cast<CT, T>(denom);
}
}
template<typename InputType, typename RealType = float>
af_array invDeconv(const af_array in, const af_array ker, const float gamma,
const af_inverse_deconv_algo algo) {
using T = RealType;
using CT = typename std::conditional<std::is_same<T, double>::value,
cdouble, cfloat>::type;
auto input = castArray<T>(in);
auto psf = castArray<T>(ker);
const dim4& idims = input.dims();
const dim4& fdims = psf.dims();
dim_t nElems = 1;
dim4 inUPad, psfUPad, inLPad, psfLPad, odims(1);
auto index = calcPadInfo(inLPad, psfLPad, inUPad, psfUPad, odims, nElems,
idims, fdims);
auto paddedIn =
padArrayBorders<T>(input, inLPad, inUPad, AF_PAD_CLAMP_TO_EDGE);
auto paddedPsf = padArrayBorders<T>(psf, psfLPad, psfUPad, AF_PAD_ZERO);
const array<int, 4> shiftDims = {-int(fdims[0] / 2), -int(fdims[1] / 2), 0,
0};
auto shiftedPsf = shift(paddedPsf, shiftDims.data());
auto I = fft_r2c<CT, T>(paddedIn, BASE_DIM);
auto P = fft_r2c<CT, T>(shiftedPsf, BASE_DIM);
auto Pc = conj(P);
auto numer = arithOp<CT, af_mul_t>(I, Pc, I.dims());
auto denom = denominator(I, P, gamma, algo);
auto absVal = detail::abs<T, CT>(denom);
auto THRESH = createValueArray(I.dims(), scalar<T>(gamma));
auto cond = logicOp<T, af_ge_t>(absVal, THRESH, absVal.dims());
auto val = arithOp<CT, af_div_t>(numer, denom, numer.dims());
select_scalar<CT, false>(val, cond, val, scalar<CT>(0.0));
auto ival =
fft_c2r<CT, T>(val, 1 / static_cast<double>(nElems), odims, BASE_DIM);
return getHandle(createSubArray<T>(ival, index));
}
af_err af_inverse_deconv(af_array* out, const af_array in, const af_array psf,
const float gamma, const af_inverse_deconv_algo algo) {
try {
const ArrayInfo& inputInfo = getInfo(in);
const dim4& inputDims = inputInfo.dims();
const ArrayInfo& psfInfo = getInfo(psf);
const dim4& psfDims = psfInfo.dims();
DIM_ASSERT(2, (inputDims.ndims() == 2));
DIM_ASSERT(3, (psfDims.ndims() == 2));
ARG_ASSERT(4, std::isfinite(gamma));
ARG_ASSERT(4, (gamma > 0));
ARG_ASSERT(5, (algo == AF_INVERSE_DECONV_DEFAULT ||
algo == AF_INVERSE_DECONV_TIKHONOV));
af_array res = 0;
af_dtype inputType = inputInfo.getType();
switch (inputType) {
case f32: res = invDeconv<float>(in, psf, gamma, algo); break;
case s16: res = invDeconv<short>(in, psf, gamma, algo); break;
case u16: res = invDeconv<ushort>(in, psf, gamma, algo); break;
case s8: res = invDeconv<schar>(in, psf, gamma, algo); break;
case u8: res = invDeconv<uchar>(in, psf, gamma, algo); break;
default: TYPE_ERROR(1, inputType);
}
std::swap(res, *out);
}
CATCHALL;
return AF_SUCCESS;
}