読者です 読者をやめる 読者になる 読者になる

C++と色々

主にC++やプログラムに関する記事を投稿します。

式テンプレートっぽい何か

C++

boost::numeric::ublas::vectorのExpression Templatesがどのように実装されているのか気になったので、ヘッダを眺めて大事そうなところだけ真似(劣化コピー)してみました。まずは、式自体を定義しているヘッダです。

//ファンクタを適用する式
#ifndef NEK_MATH_VECTOR_EXPRESSION_HPP
#define NEK_MATH_VECTOR_EXPRESSION_HPP

#include <type_traits>//is_same, declval, ...

namespace nek
{
    namespace math
    {
        /// Base class of every vector expression (uses CRTP).
        ///
        /// A concrete expression type `Expr` derives from
        /// `vector_expression<Expr>`; calling the base object recovers the
        /// concrete expression through a static downcast.
        ///
        /// @tparam Expr  Concrete expression type that derives from this base.
        template <class Expr>
        struct vector_expression
        {
            typedef Expr expr_type;

            /// Downcasts to the concrete expression (read-only access).
            const expr_type& operator()() const
            {
                return *static_cast<const expr_type*>(this);
            }

            /// Downcasts to the concrete expression (mutable access).
            /// Returns a reference: returning by value would copy the whole
            /// expression on every call and hide any modification made
            /// through the result.
            expr_type& operator()()
            {
                return *static_cast<expr_type*>(this);
            }
        };

        // Forward declaration of the identity functor. unary_expression only
        // needs the name here, to compare it against its own functor type.
        template<class T>
        struct scalar_identity;

        /// Expression that applies a unary functor to every element of a
        /// wrapped expression when that element is accessed.
        ///
        /// @tparam T  Wrapped expression type; must provide value_type,
        ///            reference, size_type, const_closure_type and operator[].
        /// @tparam F  Functor type; must provide result_type and a static apply().
        template <class T, class F>
        class unary_expression
            : public vector_expression<unary_expression<T, F>>
        {
        public:
            typedef F function_type;
            typedef typename F::result_type value_type;
            typedef value_type const_reference;
            // With the identity functor the operand's own reference type is
            // used; any other functor yields a computed value.
            // value_type is an unqualified name, so it must not be prefixed
            // with `typename`.
            typedef typename std::conditional<std::is_same<F, scalar_identity<typename T::value_type>>::value,
                typename T::reference,
                value_type>::type reference;
            typedef typename T::size_type size_type;
            typedef T expr_type;
            typedef typename T::const_closure_type expr_closure_type;
            typedef unary_expression<T, F> self_type;
            typedef const self_type const_closure_type;
            typedef self_type closure_type;

            /// Wraps `expr`. Only a const closure of the operand is stored, so
            /// a const reference is sufficient (and lets const operands bind).
            explicit unary_expression(const expr_type& expr)
                : expr_(expr)
            {
            }

            /// Returns the functor applied to the i-th element of the operand.
            const_reference operator[](size_type i) const
            {
                return function_type::apply(expr_[i]);
            }

            /// Non-const element access.
            // NOTE(review): when F is scalar_identity, `reference` is
            // T::reference while apply() returns by value, so this overload
            // presumably cannot be instantiated in that case - confirm intent.
            reference operator[](size_type i)
            {
                return function_type::apply(expr_[i]);
            }

        private:
            expr_closure_type expr_;
        };

        /// Expression that combines two operand expressions element by element
        /// with a binary functor; the functor runs when an element is accessed.
        ///
        /// @tparam L  Left operand expression type.
        /// @tparam F  Functor type; must provide result_type and a static apply().
        /// @tparam R  Right operand expression type.
        template <class L, class F, class R>
        class binary_expression
            : public vector_expression<binary_expression<L, F, R>>
        {
        public:
            using function_type = F;
            using value_type = typename F::result_type;
            using const_reference = value_type;
            // Index type: whatever adding the operands' size types yields.
            using size_type = decltype(std::declval<typename L::size_type>() + std::declval<typename R::size_type>());
            using first_expr_type = L;
            using second_expr_type = R;
            using first_closure_type = typename L::const_closure_type;
            using second_closure_type = typename R::const_closure_type;
            using self_type = binary_expression<L, F, R>;
            using const_closure_type = const self_type;

            /// Stores const closures of both operands.
            binary_expression(const first_expr_type& left, const second_expr_type& right)
                : left_(left), right_(right)
            {
            }

            /// Applies the functor to the elements at `index` of both operands.
            const_reference operator[](size_type index) const
            {
                return function_type::apply(left_[index], right_[index]);
            }

        private:
            first_closure_type left_;
            second_closure_type right_;
        };

        /// Expression that merely refers to another expression (typically a
        /// concrete vector) and forwards element access to it.
        ///
        /// The referred object is held by reference, not copied, so it must
        /// outlive this object and every expression that stores it.
        ///
        /// @tparam Expr  Referred type (may be const-qualified).
        template <class Expr>
        class vector_reference
            // CRTP: the base must be parameterised on this class itself, so
            // that the base's downcast yields a vector_reference and not Expr.
            : public vector_expression<vector_reference<Expr>>
        {
        public:
            typedef vector_reference<Expr> self_type;
            typedef typename Expr::size_type size_type;
            typedef typename Expr::value_type value_type;
            typedef typename Expr::const_reference const_reference;
            // A const referred type only ever exposes const element access.
            typedef typename std::conditional<std::is_const<Expr>::value,
                typename Expr::const_reference,
                typename Expr::reference>::type reference;
            typedef Expr referred_type;
            typedef const self_type const_closure_type;
            typedef self_type closure_type;

            /// Binds to `expr` without copying it.
            explicit vector_reference(referred_type& expr)
                : expr_(expr)
            {
            }

            /// Read-only access to the i-th element of the referred object.
            const_reference operator[](size_type i) const
            {
                return expr_[i];
            }

            /// Access to the i-th element of the referred object.
            reference operator[](size_type i)
            {
                return expr_[i];
            }

        private:
            // Stored as a reference: a by-value member would copy the whole
            // referred object into every expression node.
            referred_type& expr_;
        };
    }
}

#endif

次に、実際に計算を行うファンクタを定義し、演算子をオーバーロードしているヘッダです。

//演算を行うファンクタを定義
#ifndef NEK_MATH_FUNCTIONAL_HPP
#define NEK_MATH_FUNCTIONAL_HPP

#include "vector_expression.hpp"

namespace nek
{
    namespace math
    {
        /// Common type definitions for functors acting on a single element.
        ///
        /// @tparam T  Element type.
        template <class T>
        struct scalar_unary_function
        {
            using value_type = T;
            using arg_type = const T&;
            using result_type = value_type;
        };

        /// Unary functor that returns the element unchanged.
        template <class T>
        struct scalar_identity
            : scalar_unary_function<T>
        {
            using arg_type = typename scalar_unary_function<T>::arg_type;
            using result_type = typename scalar_unary_function<T>::result_type;

            /// @return A copy of `element`.
            static result_type apply(arg_type element)
            {
                return element;
            }
        };

        /// Unary functor that flips the sign of the element.
        template <class T>
        struct scalar_negate
            : scalar_unary_function<T>
        {
            using arg_type = typename scalar_unary_function<T>::arg_type;
            using result_type = typename scalar_unary_function<T>::result_type;

            /// @return `-element`.
            static result_type apply(arg_type element)
            {
                return -element;
            }
        };

        /// Common type definitions for functors acting on a pair of elements.
        ///
        /// @tparam L  Left element type.
        /// @tparam R  Right element type.
        template <class L, class R>
        struct scalar_binary_function
        {
            using first_type = const L&;
            using second_type = const R&;
            // The result type is the type of adding a left and a right element.
            using result_type = decltype(std::declval<first_type>() + std::declval<second_type>());
        };

        /// Binary functor computing `lhs + rhs`.
        template <class L, class R>
        struct scalar_plus
            : scalar_binary_function<L, R>
        {
            using first_type = typename scalar_binary_function<L, R>::first_type;
            using second_type = typename scalar_binary_function<L, R>::second_type;
            using result_type = typename scalar_binary_function<L, R>::result_type;

            static result_type apply(first_type lhs, second_type rhs)
            {
                return lhs + rhs;
            }
        };

        /// Binary functor computing `lhs - rhs`.
        template <class L, class R>
        struct scalar_minus
            : scalar_binary_function<L, R>
        {
            using first_type = typename scalar_binary_function<L, R>::first_type;
            using second_type = typename scalar_binary_function<L, R>::second_type;
            using result_type = typename scalar_binary_function<L, R>::result_type;

            static result_type apply(first_type lhs, second_type rhs)
            {
                return lhs - rhs;
            }
        };

        /// Binary functor computing `lhs * rhs`.
        template <class L, class R>
        struct scalar_multiply
            : scalar_binary_function<L, R>
        {
            using first_type = typename scalar_binary_function<L, R>::first_type;
            using second_type = typename scalar_binary_function<L, R>::second_type;
            using result_type = typename scalar_binary_function<L, R>::result_type;

            static result_type apply(first_type lhs, second_type rhs)
            {
                return lhs * rhs;
            }
        };

        /// Binary functor computing `lhs / rhs`.
        template <class L, class R>
        struct scalar_devide
            : scalar_binary_function<L, R>
        {
            using first_type = typename scalar_binary_function<L, R>::first_type;
            using second_type = typename scalar_binary_function<L, R>::second_type;
            using result_type = typename scalar_binary_function<L, R>::result_type;

            static result_type apply(first_type lhs, second_type rhs)
            {
                return lhs / rhs;
            }
        };

        /// Common type definitions for functors that modify a left element
        /// using a right element.
        ///
        /// @tparam L  Left element type; const and reference qualifiers are
        ///            stripped and a mutable reference is formed.
        /// @tparam R  Right element type (taken by const reference).
        template <class L, class R>
        struct scalar_binary_assign_funciton
        {
            typedef typename std::remove_reference<
                typename std::remove_const<L>::type
            >::type& first_type;

            typedef const R& second_type;
        };

        /// Assignment functor: `l = r`.
        template <class L, class R>
        struct scalar_assign
            : public scalar_binary_assign_funciton<L, R>
        {
            typedef typename scalar_binary_assign_funciton<L, R>::first_type first_type;
            typedef typename scalar_binary_assign_funciton<L, R>::second_type second_type;

            static void apply(first_type l, second_type r)
            {
                l = r;
            }
        };

        /// Compound assignment functor: `l += r`.
        template <class L ,class R>
        struct scalar_plus_assign
            : public scalar_binary_assign_funciton<L, R>
        {
            typedef typename scalar_binary_assign_funciton<L, R>::first_type first_type;
            typedef typename scalar_binary_assign_funciton<L, R>::second_type second_type;

            static void apply(first_type l, second_type r)
            {
                l += r;
            }
        };

        /// Compound assignment functor: `l -= r`.
        template <class L ,class R>
        struct scalar_minus_assign
            : public scalar_binary_assign_funciton<L, R>
        {
            typedef typename scalar_binary_assign_funciton<L, R>::first_type first_type;
            typedef typename scalar_binary_assign_funciton<L, R>::second_type second_type;

            static void apply(first_type l, second_type r)
            {
                l -= r;
            }
        };

        /// Compound assignment functor: `l *= r`.
        template <class L ,class R>
        struct scalar_multiply_assign
            : public scalar_binary_assign_funciton<L, R>
        {
            typedef typename scalar_binary_assign_funciton<L, R>::first_type first_type;
            typedef typename scalar_binary_assign_funciton<L, R>::second_type second_type;

            static void apply(first_type l, second_type r)
            {
                l *= r;
            }
        };

        /// Compound assignment functor: `l /= r`.
        /// The name keeps the "devide" spelling used by scalar_devide and by
        /// vector::devide_assign, which refers to this functor.
        template <class L ,class R>
        struct scalar_devide_assign
            : public scalar_binary_assign_funciton<L, R>
        {
            typedef typename scalar_binary_assign_funciton<L, R>::first_type first_type;
            typedef typename scalar_binary_assign_funciton<L, R>::second_type second_type;

            static void apply(first_type l, second_type r)
            {
                l /= r;
            }
        };

        /// Addition operator: builds the expression object for `left + right`.
        /// No element is computed here; the returned binary_expression applies
        /// scalar_plus when its elements are accessed.
        template<class L, class R>
        binary_expression<L, scalar_plus<typename L::value_type, typename R::value_type>, R>
            operator+(const vector_expression<L>& left, const vector_expression<R>& right)
        {
            using functor_t = scalar_plus<typename L::value_type, typename R::value_type>;
            const L& lhs = left();
            const R& rhs = right();
            return binary_expression<L, functor_t, R>(lhs, rhs);
        }
        
        /// Subtraction operator: builds the expression object for `left - right`.
        /// No element is computed here; the returned binary_expression applies
        /// scalar_minus when its elements are accessed.
        template<class L, class R>
        binary_expression<L, scalar_minus<typename L::value_type, typename R::value_type>, R>
            operator-(const vector_expression<L>& left, const vector_expression<R>& right)
        {
            using functor_t = scalar_minus<typename L::value_type, typename R::value_type>;
            const L& lhs = left();
            const R& rhs = right();
            return binary_expression<L, functor_t, R>(lhs, rhs);
        }

        /// Multiplication operator: builds the expression object for the
        /// element-wise product `left * right`.
        /// No element is computed here; the returned binary_expression applies
        /// scalar_multiply when its elements are accessed.
        template<class L, class R>
        binary_expression<L, scalar_multiply<typename L::value_type, typename R::value_type>, R>
            operator*(const vector_expression<L>& left, const vector_expression<R>& right)
        {
            using functor_t = scalar_multiply<typename L::value_type, typename R::value_type>;
            const L& lhs = left();
            const R& rhs = right();
            return binary_expression<L, functor_t, R>(lhs, rhs);
        }

        /// Division operator: builds the expression object for the
        /// element-wise quotient `left / right`.
        /// No element is computed here; the returned binary_expression applies
        /// scalar_devide when its elements are accessed.
        template<class L, class R>
        binary_expression<L, scalar_devide<typename L::value_type, typename R::value_type>, R>
            operator/(const vector_expression<L>& left, const vector_expression<R>& right)
        {
            using functor_t = scalar_devide<typename L::value_type, typename R::value_type>;
            const L& lhs = left();
            const R& rhs = right();
            return binary_expression<L, functor_t, R>(lhs, rhs);
        }
    }
}

#endif

そしてvector自体のヘッダです。vector自身も式から継承することで、vector自体も式の一部として扱えます。Compositeパターンですかね。

//固定長ベクトルクラス
#ifndef NEK_MATH_VECTOR_HPP
#define NEK_MATH_VECTOR_HPP

#include "vector_expression.hpp"
#include "functional.hpp"
#include <algorithm>//equal lexicographical_compare
#include <array>//array
#include <cassert>//assert
#include <initializer_list>//initializer_list
#include <iterator>//iterator random_access_iterator_tag
#include <stdexcept>//invalid_argument
#include <type_traits>//forward

namespace nek
{
    namespace math
    {
        /// Read-only random-access iterator over a block of N elements.
        ///
        /// Holds a pointer to the first element plus the current index. Range
        /// and compatibility checks are done with assert().
        ///
        /// @tparam T  Element type.
        /// @tparam N  Number of elements in the iterated block.
        template <class T, std::size_t N>
        class vector_const_iterator
            : public std::iterator<std::random_access_iterator_tag, T>
        {
        public:
            typedef vector_const_iterator<T, N> self_type;
            typedef std::size_t size_type;
            typedef std::iterator<std::random_access_iterator_tag, T> base_type;
            typedef typename base_type::iterator_category iterator_category;
            typedef typename base_type::value_type value_type;
            typedef typename base_type::difference_type difference_type;
            typedef const value_type* pointer;
            typedef const value_type& reference;

            /// Iterator positioned at the first element of the block at `ptr`.
            vector_const_iterator(pointer ptr)
                : vector_const_iterator(ptr, 0)
            {
            }

            /// Iterator positioned at index `id` of the block at `ptr`.
            vector_const_iterator(pointer ptr, size_type id)
                : ptr_(ptr),
                id_(id)
            {
            }

            /// Element at the current position (must not be the end position).
            reference operator*() const
            {
                check_nullptr();
                check_over(id_);
                return ptr_[id_];
            }

            pointer operator->() const
            {
                return &**this;
            }

            /// Advances by one; the iterator must not already be at the end.
            self_type& operator++()
            {
                check_over(id_);
                ++id_;
                return *this;
            }

            self_type operator++(int)
            {
                self_type temp_iter(*this);
                ++*this;
                return temp_iter;
            }

            /// Steps back by one; the iterator must not be at the beginning.
            self_type& operator--()
            {
                // id_ is unsigned: verify the target index in signed
                // arithmetic, before the decrement could wrap around.
                check_under(static_cast<difference_type>(id_) - 1);
                --id_;
                return *this;
            }

            self_type operator--(int)
            {
                self_type temp_iter(*this);
                --*this;
                return temp_iter;
            }

            /// Moves by `offset` (may be negative); the target may be any
            /// position from the first element up to and including the end.
            self_type& operator+=(difference_type offset)
            {
                const difference_type target = static_cast<difference_type>(id_) + offset;
                check_position(target);
                id_ = static_cast<size_type>(target);
                return *this;
            }

            /// Returns a moved copy; *this is left untouched.
            self_type operator+(difference_type offset) const
            {
                self_type moved(*this);
                moved += offset;
                return moved;
            }

            self_type& operator-=(difference_type offset)
            {
                return *this += -offset;
            }

            /// Returns a moved copy; *this is left untouched.
            self_type operator-(difference_type offset) const
            {
                self_type moved(*this);
                moved -= offset;
                return moved;
            }

            /// Element `offset` positions away from the current position.
            reference operator[](difference_type offset) const
            {
                const difference_type target = static_cast<difference_type>(id_) + offset;
                check_nullptr();
                check_under(target);
                check_over(target);
                return ptr_[target];
            }

            /// Position comparison; both iterators must refer to the same block.
            bool operator==(const self_type& right) const
            {
                check_compatible(right);
                return id_ == right.id_;
            }

            bool operator!=(const self_type& right) const
            {
                return !(*this == right);
            }

            bool operator<(const self_type& right) const
            {
                check_compatible(right);
                return id_ < right.id_;
            }

            bool operator>(const self_type& right) const
            {
                return right < *this;
            }

            bool operator<=(const self_type& right) const
            {
                return !(right < *this);
            }

            bool operator>=(const self_type& right) const
            {
                return !(*this < right);
            }

        private:
            void check_nullptr() const
            {
                assert(ptr_ != nullptr);
            }

            // `off` must index an existing element. A negative value turns
            // into a huge unsigned number and is rejected as well.
            void check_over(difference_type off) const
            {
                assert(static_cast<size_type>(off) < N);
            }

            void check_under(difference_type off) const
            {
                assert(0 <= off);
            }

            // `off` must be a valid iterator position, the end included.
            void check_position(difference_type off) const
            {
                assert(0 <= off && static_cast<size_type>(off) <= N);
            }

            void check_compatible(const self_type& right) const
            {
                assert(ptr_ == right.ptr_);
            }

            pointer ptr_;
            size_type id_;
        };

        /// Mutable random-access iterator over a block of N elements.
        ///
        /// Implemented on top of vector_const_iterator: positioning and
        /// comparison are delegated to it, and constness is cast away on
        /// element access.
        ///
        /// @tparam T  Element type.
        /// @tparam N  Number of elements in the iterated block.
        template <class T, std::size_t N>
        class vector_iterator
            : public std::iterator<std::random_access_iterator_tag, T>
        {
        public:
            typedef vector_iterator<T, N> self_type;
            typedef std::size_t size_type;
            typedef std::iterator<std::random_access_iterator_tag, T> base_type;
            typedef typename base_type::iterator_category iterator_category;
            typedef typename base_type::value_type value_type;
            typedef typename base_type::difference_type difference_type;
            typedef typename base_type::pointer pointer;
            typedef typename base_type::reference reference;

            /// Iterator positioned at the first element of the block at `ptr`.
            vector_iterator(pointer ptr)
                : vector_iterator(ptr, 0)
            {
            }

            /// Iterator positioned at index `id` of the block at `ptr`.
            vector_iterator(pointer ptr, size_type id)
                : iter_(ptr, id)
            {
            }

            reference operator*() const
            {
                return const_cast<reference>(*iter_);
            }

            pointer operator->() const
            {
                return const_cast<pointer>(&*iter_);
            }

            self_type& operator++()
            {
                ++iter_;
                return *this;
            }

            self_type operator++(int)
            {
                self_type temp_iter(*this);
                ++*this;
                return temp_iter;
            }

            self_type& operator--()
            {
                --iter_;
                return *this;
            }

            self_type operator--(int)
            {
                self_type temp_iter(*this);
                --*this;
                return temp_iter;
            }

            self_type& operator+=(difference_type offset)
            {
                iter_ += offset;
                return *this;
            }

            /// Returns a moved copy; *this is left untouched.
            self_type operator+(difference_type offset) const
            {
                self_type moved(*this);
                moved += offset;
                return moved;
            }

            self_type& operator-=(difference_type offset)
            {
                return *this += -offset;
            }

            /// Returns a moved copy; *this is left untouched.
            self_type operator-(difference_type offset) const
            {
                self_type moved(*this);
                moved -= offset;
                return moved;
            }

            /// Element access through the underlying const iterator's
            /// operator[], with constness cast away as in operator*.
            reference operator[](difference_type offset) const
            {
                return const_cast<reference>(iter_[offset]);
            }

            bool operator==(const self_type& right) const
            {
                return iter_ == right.iter_;
            }

            bool operator!=(const self_type& right) const
            {
                return !(*this == right);
            }

            bool operator<(const self_type& right) const
            {
                return iter_ < right.iter_;
            }

            bool operator>(const self_type& right) const
            {
                return right < *this;
            }

            bool operator<=(const self_type& right) const
            {
                return !(right < *this);
            }

            bool operator>=(const self_type& right) const
            {
                return !(*this < right);
            }

        private:
            vector_const_iterator<T, N> iter_;
        };

        /// Fixed-length vector backed by std::array<T, N>.
        ///
        /// The class derives from vector_expression, so a vector can itself be
        /// used as an operand of the expression operators, and it can be
        /// constructed or assigned from any vector expression.
        ///
        /// @tparam T  Element type.
        /// @tparam N  Number of elements.
        template <class T, std::size_t N>
        class vector
            : public vector_expression<vector<T, N>>
        {
        public:
            typedef vector<T, N> self_type;
            typedef const vector_reference<const self_type> const_closure_type;
            typedef vector_reference<self_type> closure_type;
            typedef std::array<T, N> array_type;
            typedef typename array_type::value_type value_type;
            typedef typename array_type::size_type size_type;
            typedef typename array_type::difference_type difference_type;
            typedef typename array_type::pointer pointer;
            typedef typename array_type::const_pointer const_pointer;
            typedef typename array_type::reference reference;
            typedef typename array_type::const_reference const_reference;
            typedef vector_const_iterator<T, N> const_iterator;
            typedef vector_iterator<T, N> iterator;
            typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
            typedef std::reverse_iterator<iterator> reverse_iterator;
            typedef value_type expr_type;

            /// Default construction; the elements are default-initialized.
            vector()
            {
            }

            /// Sets every element to `val`.
            vector(const T& val)
            {
                // std::array has no assign(); fill() is the standard member.
                data_.fill(val);
            }

            /// Copies the listed values into the leading elements.
            /// Elements beyond the list are value-initialized.
            /// @throws std::invalid_argument if the list has more than N items.
            vector(std::initializer_list<T> list)
                : data_()
            {
                if(N < list.size())
                {
                    throw std::invalid_argument("vector<T, N>::vector : initializer list size must be under N");
                }
                std::copy(list.begin(), list.end(), data_.begin());
            }

            /// Evaluates `expr` element by element into this vector.
            template <class Expr>
            vector(const vector_expression<Expr>& expr)
            {
                assign<scalar_assign>(expr);
            }

            vector(const vector& right)
                : data_(right.data_)
            {
            }

            /// Copy assignment (the parameter is a copy that is swapped in).
            vector& operator=(vector right)
            {
                swap(right);
                return *this;
            }

            /// Evaluates `expr` into a temporary first and then swaps, so the
            /// expression may refer to this vector.
            template <class Expr>
            vector& operator=(const vector_expression<Expr>& expr)
            {
                self_type temp(expr);
                swap(temp);
                return *this;
            }

            const_reference operator[](size_type i) const
            {
                return data_[i];
            }

            reference operator[](size_type i)
            {
                return data_[i];
            }

            /// Checked access; forwards to std::array::at.
            const_reference at(size_type i) const
            {
                return data_.at(i);
            }

            /// Checked access; forwards to std::array::at.
            reference at(size_type i)
            {
                return data_.at(i);
            }

            size_type size() const noexcept
            {
                return data_.size();
            }

            bool empty() const noexcept
            {
                return data_.empty();
            }

            void swap(vector& right) noexcept
            {
                using std::swap;
                swap(data_, right.data_);
            }

            const_pointer data() const noexcept
            {
                return data_.data();
            }

            pointer data() noexcept
            {
                return data_.data();
            }

            // The iterators are built from data_.data() instead of the address
            // of data_[0], which would be invalid for N == 0.
            iterator begin() noexcept
            {
                return iterator(data_.data(), 0);
            }

            const_iterator begin() const noexcept
            {
                return const_iterator(data_.data(), 0);
            }

            iterator end() noexcept
            {
                return iterator(data_.data(), N);
            }

            const_iterator end() const noexcept
            {
                return const_iterator(data_.data(), N);
            }

            reverse_iterator rbegin() noexcept
            {
                return reverse_iterator(end());
            }

            const_reverse_iterator rbegin() const noexcept
            {
                return const_reverse_iterator(end());
            }

            reverse_iterator rend() noexcept
            {
                return reverse_iterator(begin());
            }

            const_reverse_iterator rend() const noexcept
            {
                return const_reverse_iterator(begin());
            }

            // The c* accessors are const member functions, so the plain calls
            // below already select the const overloads.
            const_iterator cbegin() const noexcept
            {
                return begin();
            }

            const_iterator cend() const noexcept
            {
                return end();
            }

            const_reverse_iterator crbegin() const noexcept
            {
                return rbegin();
            }

            const_reverse_iterator crend() const noexcept
            {
                return rend();
            }

            /// `*this = *this + expr`, evaluated through a temporary.
            template <class Expr>
            vector& operator+=(const vector_expression<Expr>& expr)
            {
                self_type temp(*this + expr);
                swap(temp);
                return *this;
            }

            /// In-place `+=` of every element, without a temporary vector.
            template <class Expr>
            vector& plus_assign(const vector_expression<Expr>& expr)
            {
                assign<scalar_plus_assign>(expr);
                return *this;
            }

            /// `*this = *this - expr`, evaluated through a temporary.
            template <class Expr>
            vector& operator-=(const vector_expression<Expr>& expr)
            {
                self_type temp(*this - expr);
                swap(temp);
                return *this;
            }

            /// In-place `-=` of every element, without a temporary vector.
            template <class Expr>
            vector& minus_assign(const vector_expression<Expr>& expr)
            {
                assign<scalar_minus_assign>(expr);
                return *this;
            }

            /// `*this = *this * expr`, evaluated through a temporary.
            template <class Expr>
            vector& operator*=(const vector_expression<Expr>& expr)
            {
                self_type temp(*this * expr);
                swap(temp);
                return *this;
            }

            /// In-place `*=` of every element, without a temporary vector.
            template <class Expr>
            vector& multiply_assign(const vector_expression<Expr>& expr)
            {
                assign<scalar_multiply_assign>(expr);
                return *this;
            }

            /// `*this = *this / expr`, evaluated through a temporary.
            template <class Expr>
            vector& operator/=(const vector_expression<Expr>& expr)
            {
                self_type temp(*this / expr);
                swap(temp);
                return *this;
            }

            /// In-place `/=` of every element, without a temporary vector.
            template <class Expr>
            vector& devide_assign(const vector_expression<Expr>& expr)
            {
                assign<scalar_devide_assign>(expr);
                return *this;
            }

        private:
            /// Applies the assignment functor F to each element of this vector
            /// paired with the element of `expr` at the same index.
            template <template <class, class> class F, class E>
            void assign(const vector_expression<E>& expr)
            {
                typedef F<reference, typename E::value_type> functor_type;

                // Downcast once, outside the loop.
                const E& source = expr();
                const size_type last = data_.size();
                for(size_type i = 0; i < last; ++i)
                {
                    functor_type::apply(data_[i], source[i]);
                }
            }

            array_type data_;
        };

        /// Equality: true when std::equal reports all elements of `left` and
        /// `right` as pairwise equal.
        template <class T, std::size_t N>
        bool operator==(const vector<T, N>& left, const vector<T, N>& right)
        {
            const auto first = left.begin();
            const auto last = left.end();
            return std::equal(first, last, right.begin());
        }

        /// Inequality, defined as the negation of operator==.
        template <class T, std::size_t N>
        bool operator!=(const vector<T, N>& left, const vector<T, N>& right)
        {
            const bool same = (left == right);
            return !same;
        }

        /// Lexicographical "less than" over the elements of both vectors.
        template <class T, std::size_t N>
        bool operator<(const vector<T, N>& left, const vector<T, N>& right)
        {
            const auto left_first = left.begin();
            const auto left_last = left.end();
            const auto right_first = right.begin();
            const auto right_last = right.end();
            return std::lexicographical_compare(left_first, left_last, right_first, right_last);
        }

        /// "Greater than", expressed through operator< with swapped operands.
        template <class T, std::size_t N>
        bool operator>(const vector<T, N>& left, const vector<T, N>& right)
        {
            const bool right_is_less = (right < left);
            return right_is_less;
        }

        /// "Less than or equal": true unless `right` is less than `left`.
        template <class T, std::size_t N>
        bool operator<=(const vector<T, N>& left, const vector<T, N>& right)
        {
            if(right < left)
            {
                return false;
            }
            return true;
        }

        /// "Greater than or equal": true unless `left` is less than `right`.
        template <class T, std::size_t N>
        bool operator>=(const vector<T, N>& left, const vector<T, N>& right)
        {
            if(left < right)
            {
                return false;
            }
            return true;
        }

        /// Non-member swap; forwards to vector::swap so that unqualified
        /// `swap(a, b)` calls find it through argument-dependent lookup.
        /// Declared noexcept in place of the deprecated `throw()`.
        template <class T, std::size_t N>
        void swap(vector<T, N>& left, vector<T, N>& right) noexcept
        {
            left.swap(right);
        }

        template <std::size_t I, class T, std::size_t N>
        T& get(vector<T, N>& v) throw()
        {
            static_assert(0 <= I && I < N, "math::get<I, T, N> : out of range");
            return v[I];
        }

        /// Compile-time indexed access to element I of a const vector.
        /// An out-of-range I is rejected at compile time.
        template <std::size_t I, class T, std::size_t N>
        const T& get(const vector<T, N>& v) noexcept
        {
            // I is unsigned, so only the upper bound needs checking.
            static_assert(I < N, "math::get<I, T, N> : out of range");
            return v[I];
        }

        template <std::size_t I, class T, std::size_t N>
        T&& get(vector<T, N>&& v)
        {
            static_assert(0 <= I && I < N, "math::get<I, T, N> : out of range");
            return std::forward<T&&>(v[I]);
        }
    }
}

#endif

public継承してるわ、病的にinline書いてるわ、自分どうかしてますね