HiPipe  0.6.0
C++17 data pipeline with Python bindings.
groups.hpp
1 /****************************************************************************
2  * hipipe library
3  * Copyright (c) 2017, Cognexa Solutions s.r.o.
4  * Copyright (c) 2018, Iterait a.s.
5  * Author(s) Filip Matzner
6  *
7  * This file is distributed under the MIT License.
8  * See the accompanying file LICENSE.txt for the complete license agreement.
9  ****************************************************************************/
11 
12 #ifndef HIPIPE_CORE_GROUPS_HPP
13 #define HIPIPE_CORE_GROUPS_HPP
14 
#include <hipipe/core/utility/random.hpp>

#include <range/v3/action/insert.hpp>
#include <range/v3/action/shuffle.hpp>
#include <range/v3/algorithm/all_of.hpp>
#include <range/v3/algorithm/copy.hpp>
#include <range/v3/numeric/accumulate.hpp>
#include <range/v3/view/concat.hpp>
#include <range/v3/view/drop.hpp>
#include <range/v3/view/filter.hpp>
#include <range/v3/view/iota.hpp>
#include <range/v3/view/repeat_n.hpp>
#include <range/v3/view/take.hpp>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>
30 
31 namespace hipipe {
32 
48 template<typename Prng = std::mt19937&>
49 std::vector<std::size_t> generate_groups(std::size_t size, std::vector<double> ratio,
50  Prng&& gen = utility::random_generator)
51 {
52  namespace view = ranges::view;
53 
54  // check all ratios non-negative
55  assert(ranges::all_of(ratio, [](double d) { return d >= 0; }));
56 
57  // check positive ratio sum
58  double ratio_sum = ranges::accumulate(ratio, 0.);
59  assert(ratio_sum > 0);
60 
61  // remove trailing zeros
62  ratio.erase(std::find_if(ratio.rbegin(), ratio.rend(), [](double r) { return r > 0; }).base(),
63  ratio.end());
64 
65  // scale to [0, 1]
66  for (double& r : ratio) r /= ratio_sum;
67 
68  std::vector<std::size_t> groups;
69  groups.reserve(size);
70 
71  for (std::size_t i = 0; i < ratio.size(); ++i) {
72  std::size_t count = std::lround(ratio[i] * size);
73  // take all the remaining elements if this is the last non-zero group
74  if (i + 1 == ratio.size()) count = size - groups.size();
75  ranges::action::insert(groups, groups.end(), view::repeat_n(i, count));
76  }
77 
78  ranges::action::shuffle(groups, gen);
79  return groups;
80 }
81 
107 template<typename Prng = std::mt19937&>
108 std::vector<std::vector<std::size_t>>
109 generate_groups(std::size_t n, std::size_t size,
110  const std::vector<double>& volatile_ratio,
111  const std::vector<double>& fixed_ratio,
112  Prng&& gen = utility::random_generator)
113 {
114  namespace view = ranges::view;
115 
116  std::size_t volatile_size = volatile_ratio.size();
117  auto full_ratio = view::concat(volatile_ratio, fixed_ratio);
118 
119  std::vector<std::vector<std::size_t>> all_groups;
120  std::vector<std::size_t> initial_groups = generate_groups(size, full_ratio, gen);
121 
122  for (std::size_t i = 0; i < n; ++i) {
123  auto groups = initial_groups;
124  // select those groups, which are volatile (those will be replaced)
125  auto groups_volatile =
126  view::filter(groups, [volatile_size](std::size_t l) { return l < volatile_size; });
127  // count the number of volatile groups
128  std::size_t volatile_count = ranges::distance(groups_volatile);
129  // generate the replacement
130  auto groups_volatile_new = generate_groups(volatile_count, volatile_ratio, gen);
131  // replace
132  ranges::copy(groups_volatile_new, groups_volatile.begin());
133  // store
134  all_groups.emplace_back(std::move(groups));
135  }
136 
137  return all_groups;
138 }
139 
140 } // end namespace hipipe
141 #endif
std::vector< std::vector< std::size_t > > generate_groups(std::size_t n, std::size_t size, const std::vector< double > &volatile_ratio, const std::vector< double > &fixed_ratio, Prng &&gen=utility::random_generator)
Randomly group data into multiple clusters with a given ratio.
Definition: groups.hpp:109
auto copy(from_t< FromColumns... > from_cols, to_t< ToColumns... > to_cols)
Copy the data from FromColumns to the respective ToColumns.
Definition: copy.hpp:38
auto filter(from_t< FromColumns... > f, by_t< ByColumns... > b, Fun fun, dim_t< Dim > d=dim_t< 1 >{})
Filter stream data.
Definition: filter.hpp:141
static thread_local std::mt19937 random_generator
Thread local pseudo-random number generator seeded by std::random_device.
Definition: random.hpp:21