13 #include <hipipe/core/stream/stream_t.hpp>
15 #include <range/v3/core.hpp>
16 #include <range/v3/view/all.hpp>
17 #include <range/v3/view/view.hpp>
21 namespace hipipe::stream {
// Short alias for range-v3's views namespace, used below for rgv::all,
// rgv::all_t, rgv::view and rgv::view_access.
23 namespace rgv = ranges::views;
// rebatch_view: a single-pass range-v3 view (built on ranges::view_facade)
// that re-slices the batches produced by an underlying range Rng into new
// batch_t objects holding n_ elements each; only the final batch can be
// smaller, when the underlying range runs out.
//
// NOTE(review): this section looks like a lossy extraction of the original
// header. Original line numbers are fused onto the start of many lines and a
// number of lines are missing -- e.g. the declarations of the view's own
// `rng_` and `n_` members, the `struct cursor` header, the signatures of
// `find_next()` / `fill_batch()`, several return statements and most closing
// braces. It cannot compile as-is; restore the dropped lines from upstream
// before editing the logic.
25 template <
typename Rng>
26 struct rebatch_view : ranges::view_facade<rebatch_view<Rng>> {
// Grants range-v3 access to this view's cursor hooks (begin_cursor below).
29 friend ranges::range_access;
// ---- cursor state ----
// Back-pointer to the owning view; used to reach the underlying range
// (rng_->rng_) and the requested batch size (rng_->n_).
36 rebatch_view<Rng>* rng_ =
nullptr;
// Current position in the underlying range.
37 ranges::iterator_t<Rng> it_ = {};
// The output batch being accumulated. It is held through a shared_ptr so
// that the const read() can still move out of it (constness of the pointer
// does not propagate to the pointee); copies of a cursor share this batch.
41 std::shared_ptr<batch_t> batch_;
// The current input batch, constructed from *it_. Elements are moved out of
// it with take(); find_next() treats batch_size() == 0 as "exhausted".
44 std::shared_ptr<batch_t> subbatch_;
// ---- find_next(): skip to the next non-empty input batch ----
// While subbatch_ is exhausted, advance it_ and load the element there.
// The short-circuit `||` avoids incrementing an iterator that is already at
// the end. The body of the end-of-range branch is missing from this excerpt;
// callers use find_next() as a bool (presumably false once the underlying
// range is exhausted -- confirm against upstream).
52 while (subbatch_->batch_size() == 0) {
53 if (it_ == ranges::end(rng_->rng_) || ++it_ == ranges::end(rng_->rng_)) {
56 subbatch_ = std::make_shared<batch_t>(*it_);
// ---- fill_batch(): top up batch_ to n_ elements ----
// Each pass moves min(space left in batch_, elements left in subbatch_)
// elements (the line declaring `to_take` is missing here). It continues while
// batch_ is not yet full and another non-empty input batch can be found.
// The assert documents that every pass starts with room left in batch_.
65 assert(batch_->batch_size() < rng_->n_);
67 std::min(rng_->n_ - batch_->batch_size(), subbatch_->batch_size());
68 batch_->push_back(subbatch_->take(to_take));
69 }
while (batch_->batch_size() < rng_->n_ && find_next());
// Single-pass: read() hands the accumulated batch out as an rvalue, so the
// same position cannot be meaningfully dereferenced twice.
73 using single_pass = std::true_type;
// Start at the beginning of the view's underlying range and, if it is not
// empty, load its first element as subbatch_ (the empty-range branch body is
// not present in this excerpt).
77 explicit cursor(rebatch_view<Rng>& rng)
79 , it_{ranges::begin(rng_->rng_)}
82 if (it_ == ranges::end(rng_->rng_)) {
85 subbatch_ = std::make_shared<batch_t>(*it_);
// Hand the accumulated batch to the consumer as an rvalue so it can be moved
// from; if it is, *batch_ stays moved-from until the next advance replaces
// batch_ with a fresh batch.
90 batch_t&& read()
const
92 return std::move(*batch_);
// End-of-range test against the default sentinel (body not present in this
// excerpt).
95 bool equal(ranges::default_sentinel_t)
const
// Two cursors compare equal when they belong to the same view, sit at the
// same underlying position and have the same number of elements left in
// their current input batch.
// NOTE(review): both subbatch_ pointers are dereferenced unconditionally; if
// a cursor over an empty range leaves subbatch_ null, this would crash --
// confirm against the missing constructor branch.
100 bool equal(
const cursor& that)
const
102 assert(rng_ == that.rng_);
103 return it_ == that.it_ && subbatch_->batch_size() == that.subbatch_->batch_size();
// Advance: start a fresh, empty output batch and fill it only if another
// non-empty input batch exists (the enclosing function header is missing).
108 batch_ = std::make_shared<batch_t>();
109 if (find_next()) fill_batch();
// view_facade hook: iteration starts with a cursor over this view.
114 cursor begin_cursor() {
return cursor{*
this}; }
// Defaulted default constructor.
117 rebatch_view() =
default;
// Build a view that re-batches `rng` into batches of `n` elements.
// Throws std::invalid_argument when the new batch size is not strictly
// positive. The member initializers and the `if` guarding the throw are not
// present in this excerpt. Assuming n_ is initialized from n (a std::size_t),
// the error case can only be n == 0.
118 rebatch_view(Rng rng, std::size_t n)
123 throw std::invalid_argument{
"hipipe::stream::rebatch:"
124 " The new batch size " +
std::to_string(n_) +
" is not strictly positive."};
// rebatch_fn members (the `class rebatch_fn` header is not present in this
// excerpt).
// Lets range-v3's rgv::view_access reach bind() below -- presumably how
// rgv::view<rebatch_fn> implements the partially-applied `rebatch(n)` form;
// confirm against range-v3.
132 friend rgv::view_access;
// Partial application: bind the batch size `n` and leave the range as the
// open first argument (std::placeholders::_1). The result is wrapped with
// ranges::make_pipeable so it can be applied with the pipe operator,
// e.g. `rng | rebatch(n)`.
135 static auto bind(rebatch_fn
rebatch, std::size_t n)
137 return ranges::make_pipeable(std::bind(
rebatch, std::placeholders::_1, n));
// Direct call form: rebatch(rng, n).
// Accepts any ranges::input_range, wraps it with rgv::all (perfect-forwarding
// the argument so rvalues are moved and lvalues are referenced) and builds a
// rebatch_view over the result. The rebatch_view constructor throws
// std::invalid_argument when the batch size is not strictly positive.
141 CPP_template(
class Rng)(requires ranges::input_range<Rng>)
142 rebatch_view<rgv::all_t<Rng>>
operator()(Rng&& rng, std::size_t n)
const
144 return {rgv::all(std::forward<Rng>(rng)), n};
/// \brief Accumulate the stream and yield batches of a different size.
///
/// Usable as `rebatch(rng, n)` or, through rgv::view, as `rng | rebatch(n)`.
/// The input batches are re-sliced into batches of `n` elements each; only
/// the final batch may be smaller. The resulting view is single-pass.
/// Throws std::invalid_argument if `n` is not strictly positive.
169 inline rgv::view<rebatch_fn>
rebatch{};