C++ Question (asked by erip)

How can I use templates to deduce the parameter types of a std::function?

I'm working on a problem to rotate an NxN matrix of type `T` by 90 degrees. In the spirit of DRY, I'd like the function signature of my rotate function to look like this:

```cpp
template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, std::function<void(T&, T&, T&, T&)> swap_direction);
```


This will allow me to switch between clockwise and counterclockwise rotation with the same function, simply by passing a different `std::function<void(T&, T&, T&, T&)>`.

I currently have the following code:

```cpp
#include <iostream>
#include <array>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
    T temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
}

template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, std::function<void(T&, T&, T&, T&)> swap_direction) {
    for(std::size_t i = 0; i < N/2; ++i) {
        for(std::size_t j = 0; j < (N+1)/2; ++j) {
            swap_direction(
                m[i][j],
                m[N-j-1][i],
                m[j][N-i-1],
                m[N-i-1][N-j-1]
            );
        }
    }
}

int main() {
    constexpr std::size_t N = 5;
    Matrix<int, N> m {{
        {{1,2,3,4,5}},
        {{6,7,8,9,10}},
        {{11,12,13,14,15}},
        {{16,17,18,19,20}},
        {{21,22,23,24,25}}
    }};

    std::function<void(int&, int&, int&, int&)> swap_clockwise(four_way_swap_clockwise);

    rotate_90(m, swap_clockwise);
}
```


This currently doesn't compile, failing with the following error:

```text
error: no matching function for call to 'std::function<void(int&, int&, int&, int&)>::function(<unresolved overloaded function type>)'
     std::function<void(int&, int&, int&, int&)> swap_clockwise(four_way_swap_clockwise);
```


However, even if it did compile, having to spell out the parameter types of the swap function (i.e., `int` in `std::function<void(int&, int&, int&, int&)> swap_clockwise(four_way_swap_clockwise);`) defeats the purpose of template programming.

How can I pass the `std::function` with the template type deduced?

Answer
```cpp
template<class T> struct tag_t{using type=T;};
template<class T> using block_deduction=typename tag_t<T>::type;
```

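This construct blocks C++ from trying to deduce template arguments from a function argument. `block_deduction<X>` is the same type as `X`, but because it is spelled `typename tag_t<X>::type`, `X` sits inside a nested-name-specifier, which is a non-deduced context.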

```cpp
template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, block_deduction<std::function<void(T&, T&, T&, T&)>> swap_direction) {
    // ... same body as in the question ...
}
```

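Now `T` and `N` are deduced from the first argument only, and the type of the second parameter follows from them. Any argument convertible to that `std::function` (a lambda, for instance) is accepted.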

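The next problem is that `std::function` doesn't disambiguate overloaded function names. An overloaded function name (and the name of a function template such as `four_way_swap_clockwise`) isn't a C++ value; it names a set of candidates from which, in the right context, a single function is picked. Constructing a `std::function` is not one of those contexts: its constructor is itself a template that accepts any callable, so there is no target type to match the candidates against.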

We can extend std::function with an additional constructor like this:

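```cpp
#include <functional>

// std::function plus one extra constructor (and assignment) taking a plain
// function pointer. Initializing a Sig* from an overloaded or template name
// lets the compiler pick the overload that matches Sig.
template<class Sig, class F=std::function<Sig>>
struct my_func:F {
  using F::F;          // inherit all std::function constructors
  using F::operator=;  // and its assignment operators
  my_func( Sig* ptr ):F(ptr) {}
  my_func& operator=( Sig* ptr ) {
    F::operator=(ptr);
    return *this;
  }
  my_func()=default;
  my_func(my_func&&)=default;
  my_func(my_func const&)=default;
  my_func& operator=(my_func&&)=default;
  my_func& operator=(my_func const&)=default;
};
```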


An alternative approach is to wrap your overload set into a lambda:

```cpp
auto overloads = [](auto&&... args) { return four_way_swap_clockwise(decltype(args)(args)...); };
```

Then pass `overloads` to your function. This lambda represents all of the overloads of `four_way_swap_clockwise` at once; `decltype(args)(args)` forwards each argument with its original value category.

We can also disambiguate manually by writing `four_way_swap_clockwise<int>`.

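Both of these still require the `block_deduction` technique above, since neither a lambda nor a function pointer is a `std::function` from which `T` could be deduced.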

An alternative to consider would be:

```cpp
template <typename T, std::size_t N, class F>
void rotate_90(Matrix<T, N>& m, F&& swap_direction) {
    // ... same body as in the question ...
}
```

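where we leave `swap_direction` completely unconstrained and let any failures show up where it is called inside the algorithm. This can also give a slight performance boost, since the callable is invoked directly rather than through `std::function`'s type erasure, which makes inlining easier for the compiler. You still have to disambiguate `four_way_swap_clockwise` with either `<int>` or the lambda-wrapper technique.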

Another approach would be to make `four_way_swap_clockwise` a lambda itself:

```cpp
auto four_way_swap_clockwise = [](auto& top_left, auto& top_right, auto& bottom_left, auto& bottom_right) {
  auto temp = top_left;
  top_left = top_right;
  top_right = bottom_right;
  bottom_right = bottom_left;
  bottom_left = temp;
};
```

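Now it is an object whose `operator()` is a template. Combined with `block_deduction`, this solves your problem: once `T` has been deduced from the matrix, the object converts to `std::function<void(T&, T&, T&, T&)>`.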

In short, there are lots of ways around your problem.