11 #ifndef HIPIPE_TENSORFLOW_RUN_GRAPH_HPP
12 #define HIPIPE_TENSORFLOW_RUN_GRAPH_HPP
14 #include <hipipe/core/utility/tuple.hpp>
15 #include <hipipe/tensorflow/utility/to_tf_type.hpp>
17 #include <tensorflow/core/framework/tensor.h>
18 #include <tensorflow/core/public/session.h>
21 #include <experimental/filesystem>
25 namespace hipipe::tensorflow {
41 template <
typename... OutTs,
typename... InTs>
42 std::tuple<std::tuple<std::vector<OutTs>...>, std::vector<std::vector<long>>>
44 const std::vector<std::string>& input_names,
45 const std::tuple<std::vector<InTs>...>& input_data,
46 const std::vector<std::vector<long>>& input_shapes,
47 const std::vector<std::string>& output_names)
49 assert(ranges::size(input_names) == ranges::size(input_shapes));
50 assert(ranges::size(input_names) ==
sizeof...(InTs));
51 assert(ranges::size(output_names) ==
sizeof...(OutTs));
52 std::tuple<InTs...> in_types;
53 std::tuple<OutTs...> out_types;
56 std::vector<std::pair<std::string, ::tensorflow::Tensor>> feed;
59 ::tensorflow::TensorShape
shape;
60 for (
long val : input_shapes[i])
shape.AddDim(val);
62 auto dtype = to_tf_type(std::tuple_element_t<i, decltype(in_types)>{});
64 ::tensorflow::Tensor tensor{dtype,
shape};
67 tensor.flat<std::tuple_element_t<i, decltype(in_types)>>().data());
69 feed.emplace_back(input_names[i], std::move(tensor));
73 std::vector<::tensorflow::Tensor> outputs;
74 ::tensorflow::Status status = session.Run(feed, output_names, {}, &outputs);
76 auto msg = std::string{
"Failed to run tensorflow graph: "} + status.ToString();
77 throw std::runtime_error(msg);
81 std::tuple<std::vector<OutTs>...> raw_outputs;
82 std::vector<std::vector<long>> output_shapes;
84 ::tensorflow::Tensor& output = outputs[i];
86 raw_output.resize(output.NumElements());
88 std::copy_n(output.flat<std::tuple_element_t<i, decltype(out_types)>>().data(),
92 std::vector<long> output_shape;
93 for (
long d = 0; d < output.dims(); ++d) output_shape.push_back(output.dim_size(d));
94 output_shapes.emplace_back(std::move(output_shape));
99 return {std::move(raw_outputs), std::move(output_shapes)};