C++と色々

主にC++やプログラムに関する記事を投稿します。

マルチメソッド3 力任せ改良編

マルチメソッド2 力任せ編の続きです。Modern C++ Designに載っているマルチメソッドを自分なりの解釈を加えながら、LokiライブラリではなくBoostライブラリを用いて実装してみようという試みです。 前回書いたように今回はstatic_dispatcherの改良として、自動で継承関係によってソートする機能と、対称性の機能(func(A, B)とfunc(B, A)と2つ定義しなくてもfunc(A, B)だけでも中で引数の順を変えてB, Aの順にも対応してくれる機能)を付けたいと思います。

復習

まず、比較用に前回作成したstatic_dispatcherのソースコードを貼ります。

template
<
    class Func,
    class BaseLhs,
    class TypesLhs,
    class BaseRhs = BaseLhs,
    class TypesRhs = TypesLhs,
    class Result = void
>
class static_dispatcher
{
    template <class SomeLhs>
    struct rhs_dispatcher
    {
        template <class Head, class Tail>
        struct dispatcher
        {
            static Result dispatch(SomeLhs& lhs, BaseRhs& rhs, Func functor)
            {
                if (Head* p2 = dynamic_cast<Head*>(&rhs))
                {
                    return functor(lhs, *p2);
                }
                return Tail::dispatch(lhs, rhs, functor);
            }
        };

        struct dispatcher_error
        {
            static Result dispatch(SomeLhs& lhs, BaseRhs& rhs, Func functor)
            {
                throw std::runtime_error("未定義の型");
            }
        };

        typedef typename mpl::reverse_fold<TypesRhs, dispatcher_error, dispatcher<mpl::_2, mpl::_1>>::type type;
    };

    struct lhs_dispatcher
    {
        template <class Head, class Tail>
        struct dispatcher
        {
            static Result dispatch(BaseLhs& lhs, BaseRhs& rhs, Func functor)
            {
                if (Head* p1 = dynamic_cast<Head*>(&lhs))
                {
                    return rhs_dispatcher<Head>::type::dispatch(*p1, rhs, functor);
                }
                return Tail::dispatch(lhs, rhs, functor);
            }
        };

        struct dispatcher_error
        {
            static Result dispatch(BaseLhs& lhs, BaseRhs& rhs, Func functor)
            {
                throw std::runtime_error("未定義の型");
            }
        };

        typedef typename mpl::reverse_fold<TypesLhs, dispatcher_error, dispatcher<mpl::_2, mpl::_1>>::type type;
    };

public:
    static Result go(BaseLhs& lhs, BaseRhs& rhs)
    {
        return lhs_dispatcher::type::dispatch(lhs, rhs, Func());
    }
};

型リストを継承関係によりソート

正しくディスパッチするには、より下位のクラスを先に検証しなければなりません。boostやSTLには、あるクラスが、あるクラスの基底クラスか検証できるメタ関数があります。それを述語とした型リストのソートを行えば、正しい順番にすることができます。
以下のように書きます。

typedef typename mpl::sort<TypesLhs, std::is_base_of<mpl::_2, mpl::_1>>::type lhs_type;
typedef typename mpl::sort<TypesRhs, std::is_base_of<mpl::_2, mpl::_1>>::type rhs_type;

そして今まで書いていたtypedef typename mpl::reverse_fold<TypesRhs, dispatcher_error, dispatcher<mpl::_2, mpl::_1>>::type type;typedef typename mpl::reverse_fold<TypesLhs, dispatcher_error, dispatcher<mpl::_2, mpl::_1>>::type type;の、TypesRhsをrhs_typeに、TypesLhsをlhs_typeに直します。
これで、使用側が、カタリストの方の順序を気にしなくてよくなりました。

対称性

例えば、四角と三角の衝突判定と、三角と四角の衝突判定を区別しない場合、それぞれ別の関数に分けることは無駄です。1つの関数を定義すれば、引数の順序を気にすることなく呼び出せるようにしたいと思います。また、同じ引数でも順序で処理を区別したい場合もあるので、この機能はユーザーが使用するか選べるようにします。
 そして、この機能は左辺の型リストと右辺の型リストが同一である必要があります。

実装のアイデア

引数を交換しなければいけない時というのは、どういう時でしょうか。ここで型リストを数値(インデックス)で置き換えたもので考えてみます。
 list<0, 1, 2>とlist<0, 1, 2>があります。ここで起こりうる引数の順列は

  • 0, 0
  • 0, 1
  • 0, 2
  • 1, 0
  • 1, 1
  • 1, 2
  • 2, 0
  • 2, 1
  • 2, 2

の9通りです。そして、ここで行いたいことは1, 0の組み合わせは引数を入れ替えて0, 1を 受け取る関数で処理できるようにすることです。つまり6つの関数だけで上記の9通り受け入れられるようにしたいのです。6通りの関数の引数の通りを列挙してみます。

  • 0, 0
  • 0, 1
  • 0, 2
  • 1, 1
  • 1, 2
  • 2, 2

です。左辺の数値は右辺の数値以下であることがわかります。よって引数を交換しなければならない時というのは、 右辺の数値が左辺の数値より小さい時 、と言えます。更に言い換えるなら、 右辺の型のインデックスが左辺の型のインデックスより小さい時 引数を交換しなければならない、ということです。インデックスとは型リストにおけるその型が何番目にあるか、ということです。

実装

まず、引数の交換を行う型を用意します。今回はstatic_dispatcher内にネストした型として用意しました。

template <bool C, class SomeLhs, class SomeRhs>
struct call_traits
{
    static Result dispatch(SomeLhs& lhs, SomeRhs& rhs, Func functor)
    {
        return functor(lhs, rhs);
    }
};

template <class SomeLhs, class SomeRhs>
struct call_traits<true, SomeLhs, SomeRhs>
{
    static Result dispatch(SomeLhs& lhs, SomeRhs& rhs, Func functor)
    {
        return functor(rhs, lhs);
    }
};

今まで

template <class Head, class Tail>
struct dispatcher
{
    static Result dispatch(SomeLhs& lhs, BaseRhs& rhs, Func functor)
    {
        if(Head* p2 = dynamic_cast<Head*>(&rhs))
        {
            return functor(lhs, *p2);
        }
        return Tail::dispatch(lhs, rhs, functor);
    }
};

としていた部分を書き直します。

template <class Head, class Tail>
struct dispatcher
{
    static Result dispatch(SomeLhs& lhs, BaseRhs& rhs, Func functor)
    {
        if (Head* p2 = dynamic_cast<Head*>(&rhs))
        {
            using lhs_index = typename mpl::index_of<lhs_type, SomeLhs>::type;
            using rhs_index = typename mpl::index_of<rhs_type, Head>::type;
            return call_traits<Symmetric && rhs_index::value < lhs_index::value, SomeLhs, Head>::dispatch(lhs, *p2, functor);
        }
        return Tail::dispatch(lhs, rhs, functor);
    }
};

これで対称性が実装できました! 大分static_dispatcherの中が複雑になりましたが、使用側は定義しなければならない関数が最低限でよく、型リストの継承順に気を使う必要もなくなり、だいぶ使い勝手が良くなったと思います。しかし、根本的にこのディスパッチの計算量がO(n * m)*1なので、ここのオーダーを改善したディスパッチャーを次回以降考えて行きたいと思います。
メリット

  • クラスの数が少ない場合高速
  • 新しいクラスや関数が増えてもstatic_dispatcher自体と既存の継承階層のクラスの修正が要らない(非侵入性)
  • コンパイル時にオーバーロードの曖昧さを検出できる

デメリット

  • クラス数が増えると遅い
  • 引数の型リストの為に、全クラスの階層が見えていなければならない(依存性が高い)
  • コンパイル時間、実行時間、プログラムサイズが指数的に増加する。

おまけ

ここまでに書いたコード、VC++12.0で動く完全なコードで貼っておきます。(2013/12/06追記:gcc4.7、clang3.1で動作確認しました。これより新しいバージョンでは動くと思います。)

#include <iostream>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include <boost/mpl/index_of.hpp>
#include <boost/mpl/reverse_fold.hpp>
#include <boost/mpl/list.hpp>
#include <boost/mpl/sort.hpp>
namespace mpl = boost::mpl;

struct shape
{
    virtual ~shape() {}
};

struct rectangle : shape {};
struct ellipse : shape {};
struct polygon : shape {};
struct round_rectangle : rectangle {};

void is_hit_rectangle_and_rectangle(rectangle&, rectangle&)
{
    std::cout << "rectangle and rectangle" << std::endl;
}

void is_hit_rectangle_and_ellipse(rectangle&, ellipse&)
{
    std::cout << "rectangle and ellipse" << std::endl;
}

void is_hit_rectangle_and_polygon(rectangle&, polygon&)
{
    std::cout << "rectangle and polygon" << std::endl;
}

void is_hit_rectangle_and_round_rectangle(rectangle&, round_rectangle&)
{
    std::cout << "rectangle and round_rect" << std::endl;
}

void is_hit_ellipse_and_rectangle(ellipse&, rectangle&)
{
    std::cout << "ellipse and rectangle" << std::endl;
}

void is_hit_ellipse_and_ellipse(ellipse&, ellipse&)
{
    std::cout << "ellipse and ellipse" << std::endl;
}

void is_hit_ellipse_and_polygon(ellipse&, polygon&)
{
    std::cout << "ellipse and polygon" << std::endl;
}

void is_hit_ellipse_and_round_rectangle(ellipse&, round_rectangle&)
{
    std::cout << "ellipse and round_rect" << std::endl;
}

void is_hit_polygon_and_rectangle(polygon&, rectangle&)
{
    std::cout << "polygon and rectangle" << std::endl;
}

void is_hit_polygon_and_ellipse(polygon&, ellipse&)
{
    std::cout << "polygon and ellipse" << std::endl;
}

void is_hit_polygon_and_polygon(polygon&, polygon&)
{
    std::cout << "polygon and polygon" << std::endl;
}

void is_hit_polygon_and_round_rectangle(polygon&, round_rectangle&)
{
    std::cout << "polygon and round_rect" << std::endl;
}

void is_hit_round_rectangle_and_rectangle(round_rectangle& lhs, rectangle& rhs)
{
    std::cout << "round_rect and rectangle" << std::endl;
}

void is_hit_round_rectangle_and_ellipse(round_rectangle& lhs, ellipse& rhs)
{
    std::cout << "round_rect and ellipse" << std::endl;
}

void is_hit_round_rectangle_and_polygon(round_rectangle& lhs, polygon& rhs)
{
    std::cout << "round_rect and polygon" << std::endl;
}

void is_hit_round_rectangle_and_round_rectangle(round_rectangle& lhs, round_rectangle& rhs)
{
    std::cout << "round_rect and round_rect" << std::endl;
}

void double_dispatch(shape& lhs, shape& rhs)
{
    if (round_rectangle* p1 = dynamic_cast<round_rectangle*>(&lhs))
    {
        if (round_rectangle* p2 = dynamic_cast<round_rectangle*>(&rhs))
            is_hit_round_rectangle_and_round_rectangle(*p1, *p2);
        else if (rectangle* p2 = dynamic_cast<rectangle*>(&rhs))
            is_hit_rectangle_and_round_rectangle(*p2, *p1);
        else if (ellipse* p2 = dynamic_cast<ellipse*>(&rhs))
            is_hit_ellipse_and_round_rectangle(*p2, *p1);
        else if (polygon* p2 = dynamic_cast<polygon*>(&rhs))
            is_hit_polygon_and_round_rectangle(*p2, *p1);
        else
            std::runtime_error("未定義の型");
    }
    else if (rectangle* p1 = dynamic_cast<rectangle*>(&lhs))
    {
        if (round_rectangle* p2 = dynamic_cast<round_rectangle*>(&rhs))
            is_hit_rectangle_and_round_rectangle(*p1, *p2);
        else if(rectangle* p2 = dynamic_cast<rectangle*>(&rhs))
            is_hit_rectangle_and_rectangle(*p1, *p2);
        else if (ellipse* p2 = dynamic_cast<ellipse*>(&rhs))
            is_hit_rectangle_and_ellipse(*p1, *p2);
        else if (polygon* p2 = dynamic_cast<polygon*>(&rhs))
            is_hit_rectangle_and_polygon(*p1, *p2);
        else
            std::runtime_error("未定義の型");
    }
    else if (ellipse* p1 = dynamic_cast<ellipse*>(&lhs))
    {
        if (round_rectangle* p2 = dynamic_cast<round_rectangle*>(&rhs))
            is_hit_ellipse_and_round_rectangle(*p1, *p2);
        else if (rectangle* p2 = dynamic_cast<rectangle*>(&rhs))
            is_hit_rectangle_and_ellipse(*p2, *p1);
        else if (ellipse* p2 = dynamic_cast<ellipse*>(&rhs))
            is_hit_ellipse_and_ellipse(*p1, *p2);
        else if (polygon* p2 = dynamic_cast<polygon*>(&rhs))
            is_hit_ellipse_and_polygon(*p1, *p2);
        else
            std::runtime_error("未定義の型");
    }
    else if (polygon* p1 = dynamic_cast<polygon*>(&lhs))
    {
        if (round_rectangle* p2 = dynamic_cast<round_rectangle*>(&rhs))
            is_hit_polygon_and_round_rectangle(*p1, *p2);
        else if (rectangle* p2 = dynamic_cast<rectangle*>(&rhs))
            is_hit_rectangle_and_polygon(*p2, *p1);
        else if (ellipse* p2 = dynamic_cast<ellipse*>(&rhs))
            is_hit_ellipse_and_polygon(*p2, *p1);
        else if (polygon* p2 = dynamic_cast<polygon*>(&rhs))
            is_hit_polygon_and_polygon(*p1, *p2);
        else
            std::runtime_error("未定義の型");
    }
    else
    {
        std::runtime_error("未定義の型");
    }
}

template
<
    class Func,
    class BaseLhs,
    class TypesLhs,
    bool Symmetric = true,
    class BaseRhs = BaseLhs,
    class TypesRhs = TypesLhs,
    class Result = void
>
class static_dispatcher
{
    typedef typename mpl::sort<TypesLhs, std::is_base_of<mpl::_2, mpl::_1>>::type lhs_type;
    typedef typename mpl::sort<TypesRhs, std::is_base_of<mpl::_2, mpl::_1>>::type rhs_type;

    template <bool C, class SomeLhs, class SomeRhs>
    struct call_traits
    {
        static Result dispatch(SomeLhs& lhs, SomeRhs& rhs, Func functor)
        {
            return functor(lhs, rhs);
        }
    };

    template <class SomeLhs, class SomeRhs>
    struct call_traits<true, SomeLhs, SomeRhs>
    {
        static Result dispatch(SomeLhs& lhs, SomeRhs& rhs, Func functor)
        {
            return functor(rhs, lhs);
        }
    };

    template <class SomeLhs>
    struct rhs_dispatcher
    {
        template <class Head, class Tail>
        struct dispatcher
        {
            static Result dispatch(SomeLhs& lhs, BaseRhs& rhs, Func functor)
            {
                if (Head* p2 = dynamic_cast<Head*>(&rhs))
                {
                    using lhs_index = typename mpl::index_of<lhs_type, SomeLhs>::type;
                    using rhs_index = typename mpl::index_of<rhs_type, Head>::type;
                    return call_traits<Symmetric && rhs_index::value < lhs_index::value, SomeLhs, Head>::dispatch(lhs, *p2, functor);
                }
                return Tail::dispatch(lhs, rhs, functor);
            }
        };

        struct dispatcher_error
        {
            static Result dispatch(SomeLhs& lhs, BaseRhs& rhs, Func functor)
            {
                throw std::runtime_error("未定義の型");
            }
        };

        typedef typename mpl::reverse_fold<rhs_type, dispatcher_error, dispatcher<mpl::_2, mpl::_1>>::type type;
    };

    struct lhs_dispatcher
    {
        template <class Head, class Tail>
        struct dispatcher
        {
            static Result dispatch(BaseLhs& lhs, BaseRhs& rhs, Func functor)
            {
                if (Head* p1 = dynamic_cast<Head*>(&lhs))
                {
                    return rhs_dispatcher<Head>::type::dispatch(*p1, rhs, functor);
                }
                return Tail::dispatch(lhs, rhs, functor);
            }
        };

        struct dispatcher_error
        {
            static Result dispatch(BaseLhs& lhs, BaseRhs& rhs, Func functor)
            {
                throw std::runtime_error("未定義の型");
            }
        };

        typedef typename mpl::reverse_fold<lhs_type, dispatcher_error, dispatcher<mpl::_2, mpl::_1>>::type type;
    };

public:
    static Result go(BaseLhs& lhs, BaseRhs& rhs)
    {
        return lhs_dispatcher::type::dispatch(lhs, rhs, Func());
    }
};

class are_hit_shapes
{
public:

    void operator()(rectangle& lhs, rectangle& rhs)
    {
        is_hit_rectangle_and_rectangle(lhs, rhs);
    }

    void operator()(rectangle& lhs, ellipse& rhs)
    {
        is_hit_rectangle_and_ellipse(lhs, rhs);
    }

    void operator()(rectangle& lhs, polygon& rhs)
    {
        is_hit_rectangle_and_polygon(lhs, rhs);
    }

    void operator()(rectangle& lhs, round_rectangle& rhs)
    {
        is_hit_rectangle_and_round_rectangle(lhs, rhs);
    }

    void operator()(ellipse& lhs, rectangle& rhs)
    {
        is_hit_ellipse_and_rectangle(lhs, rhs);
    }

    void operator()(ellipse& lhs, polygon& rhs)
    {
        is_hit_ellipse_and_polygon(lhs, rhs);
    }

    void operator()(ellipse& lhs, ellipse& rhs)
    {
        is_hit_ellipse_and_ellipse(lhs, rhs);
    }

    void operator()(ellipse& lhs, round_rectangle& rhs)
    {
        is_hit_ellipse_and_round_rectangle(lhs, rhs);
    }

    void operator()(polygon& lhs, rectangle& rhs)
    {
        is_hit_polygon_and_rectangle(lhs, rhs);
    }

    void operator()(polygon& lhs, ellipse& rhs)
    {
        is_hit_polygon_and_ellipse(lhs, rhs);
    }

    void operator()(polygon& lhs, polygon& rhs)
    {
        is_hit_polygon_and_polygon(lhs, rhs);
    }

    void operator()(polygon& lhs, round_rectangle& rhs)
    {
        is_hit_polygon_and_round_rectangle(lhs, rhs);
    }

    void operator()(round_rectangle& lhs, round_rectangle& rhs)
    {
        is_hit_round_rectangle_and_round_rectangle(lhs, rhs);
    }

    void operator()(round_rectangle& lhs, rectangle& rhs)
    {
        is_hit_round_rectangle_and_rectangle(lhs, rhs);
    }

    void operator()(round_rectangle& lhs, ellipse& rhs)
    {
        is_hit_round_rectangle_and_ellipse(lhs, rhs);
    }

    void operator()(round_rectangle& lhs, polygon& rhs)
    {
        is_hit_round_rectangle_and_polygon(lhs, rhs);
    }
};

void double_dispatch_test(std::vector<shape*>& v)
{
    for (auto& i : v)
    {
        for (auto& j : v)
        {
            double_dispatch(*i, *j);
        }
    }
}

template <bool S = true>
void static_dispatch_test(std::vector<shape*>& v)
{
    typedef mpl::list<rectangle, ellipse, polygon, round_rectangle> TypeList;
    typedef static_dispatcher<are_hit_shapes, shape, TypeList, S> dispatcher;

    for (auto& i : v)
    {
        for (auto& j : v)
        {
            dispatcher::go(*i, *j);
        }
    }
}

int main()
{
    std::vector<shape*> v = {new round_rectangle(), new rectangle(), new ellipse(), new polygon()};
    
    double_dispatch_test(v);
    std::cout << "-------------------------" << std::endl;
    static_dispatch_test<>(v);
    std::cout << "-------------------------" << std::endl;
    static_dispatch_test<false>(v);
    
    for (auto& it : v)
    {
        delete it;
    }
}

*1:nとmはそれぞれTypesLhsとTypesRhsの長さ