/*  -*-c++-*-
 *  Copyright (C) 2008 Cedric Pinson <cedric.pinson@plopbyte.net>
 *
 * This library is open source and may be redistributed and/or modified under
 * the terms of the OpenSceneGraph Public License (OSGPL) version 0.0 or
 * (at your option) any later version.  The full license is in LICENSE file
 * included with this distribution, and on the openscenegraph.org website.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * OpenSceneGraph Public License for more details.
 */

#ifndef OSGANIMATION_EASE_MOTION
#define OSGANIMATION_EASE_MOTION 1

#include <osg/Referenced>
#include <osg/ref_ptr>
#include <osg/Notify>
#include <osg/Math>
#include <vector>

namespace osgAnimation
{
    /// Bounce ease-out: a piecewise curve built from four parabolic arcs.
    struct OutBounceFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            // All arcs share the same steepness; every arc after the first is
            // re-centred on 'centre' and lifted by 'lift'.
            const double steepness = 7.5625;
            const double span = 2.75;

            if (t < (1 / span))
            {
                result = steepness * t * t;
                return;
            }

            double centre;
            double lift;
            if (t < (2 / span))
            {
                centre = 1.5 / span;
                lift = .75;
            }
            else if (t < (2.5 / span))
            {
                centre = 2.25 / span;
                lift = .9375;
            }
            else
            {
                centre = 2.625 / span;
                lift = .984375;
            }

            t = t - centre;
            result = steepness * t * t + lift;
        }
    };

    /// Bounce ease-in: OutBounceFunction mirrored in both time and value.
    struct InBounceFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives 1 - OutBounce(1 - t)
        inline static void getValueAt(float t, float& result)
        {
            float mirrored = 0;
            OutBounceFunction::getValueAt(1 - t, mirrored);
            result = 1 - mirrored;
        }
    };

    /// Bounce ease-in-out: InBounceFunction over the first half of the range,
    /// OutBounceFunction over the second, each squeezed into half the output.
    struct InOutBounceFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            if (!(t < 0.5))
            {
                // second half: ease-out, shifted into [0.5, 1]
                OutBounceFunction::getValueAt(t * 2 - 1, result);
                result = result * 0.5 + 0.5;
                return;
            }

            // first half: ease-in, scaled into [0, 0.5]
            InBounceFunction::getValueAt(t * 2, result);
            result *= 0.5;
        }
    };

    /// Linear function: the output is the normalized time itself.
    struct LinearFunction
    {
        /// @param t       normalized time
        /// @param result  receives t unchanged
        inline static void getValueAt(float t, float& result)
        {
            result = t;
        }
    };

    /// Quad function
    /// Quadratic ease-out: result = -t * (t - 2).
    struct OutQuadFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const double offsetFromTwo = t - 2.0;
            result = -(t * offsetFromTwo);
        }
    };

    /// Quadratic ease-in: result = t^2.
    struct InQuadFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float squared = t * t;
            result = squared;
        }
    };

    /// Quadratic ease-in-out: accelerates over the first half of the range and
    /// decelerates over the second.
    struct InOutQuadFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            // stretch [0, 1] onto [0, 2] so each half spans one unit
            const float doubled = t * 2.0;

            if (!(doubled < 1.0))
            {
                const float tail = doubled - 1.0;
                result = - 0.5 * (tail * (tail - 2) - 1);
                return;
            }

            result = 0.5 * doubled * doubled;
        }
    };

    /// Cubic function
    /// Cubic ease-out: result = (t - 1)^3 + 1.
    struct OutCubicFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float shifted = t - 1.0;
            result = shifted * shifted * shifted + 1;
        }
    };

    /// Cubic ease-in: result = t^3.
    struct InCubicFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float cubed = t * t * t;
            result = cubed;
        }
    };

    /// Cubic ease-in-out: cubic acceleration over the first half of the range,
    /// cubic deceleration over the second.
    struct InOutCubicFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            // stretch [0, 1] onto [0, 2] so each half spans one unit
            const float doubled = t * 2.0f;

            if (doubled < 1.0f)
            {
                result = 0.5f * doubled * doubled * doubled;
                return;
            }

            const float tail = doubled - 2.0f;
            result = 0.5 * (tail * tail * tail + 2.0f);
        }
    };

    /// Quart function
    /// Quartic ease-in: result = t^4.
    struct InQuartFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            // Fourth power, matching OutQuartFunction and InOutQuartFunction
            // (a fifth factor of t would make this a quintic curve).
            result = t*t*t*t;
        }
    };

    /// Quartic ease-out: result = 1 - (t - 1)^4.
    struct OutQuartFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float shifted = t - 1;
            const float fourthPower = shifted * shifted * shifted * shifted;
            result = - (fourthPower - 1);
        }
    };

    /// Quartic ease-in-out: quartic acceleration over the first half of the
    /// range, quartic deceleration over the second.
    struct InOutQuartFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            // stretch [0, 1] onto [0, 2] so each half spans one unit
            const float doubled = t * 2.0;

            if (doubled < 1)
            {
                result = 0.5 * doubled * doubled * doubled * doubled;
                return;
            }

            const float tail = doubled - 2.0;
            result = -0.5 * (tail * tail * tail * tail - 2);
        }
    };

    /// Elastic function
    /// Elastic ease-out: a sine wave of period 0.3 whose amplitude decays as
    /// 2^(-10 t), offset by 1.
    struct OutElasticFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float period = 0.3f;
            // phase is shifted by a quarter period
            const double phase = (t - period / 4.0f) * (2.0f * osg::PI) / period;
            result = pow(2.0f, -10.0f * t) * sinf(phase) + 1.0f;
        }
    };

    /// Elastic ease-in: OutElasticFunction mirrored in both time and value.
    struct InElasticFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives 1 - OutElastic(1 - t)
        inline static void getValueAt(float t, float& result)
        {
            float mirrored = 0.0f;
            OutElasticFunction::getValueAt(1.0f - t, mirrored);
            result = 1.0f - mirrored;
        }
    };

    /// Elastic ease-in-out: a growing oscillation (period 0.45) over the first
    /// half of the range and a decaying one over the second.
    struct InOutElasticFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float period = 0.45f;
            const float doubled = t * 2.0f;
            const bool firstHalf = doubled < 1.0f;

            // both halves are expressed relative to the midpoint
            const float fromMiddle = doubled - 1.0f;
            const double phase = (fromMiddle - period / 4.0f) * (2.0f * osg::PI) / period;

            if (firstHalf)
            {
                result =  -0.5 * (1.0f * pow(2.0f, 10.0f * fromMiddle) * sinf(phase));
            }
            else
            {
                result = pow(2.0f, -10.0f * fromMiddle) * sinf(phase) * 0.5f + 1.0f;
            }
        }
    };

    // Sine function
    /// Sine ease-out: result = sin(t * PI / 2).
    struct OutSineFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const double quarterTurn = osg::PI / 2.0f;
            result = sinf(t * quarterTurn);
        }
    };

    /// Sine ease-in: result = 1 - cos(t * PI / 2).
    struct InSineFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const double quarterTurn = osg::PI / 2.0f;
            result = -cosf(t * quarterTurn) + 1.0f;
        }
    };

    /// Sine ease-in-out: result = -0.5 * (cos(PI * t) - 1).
    struct InOutSineFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const double angle = osg::PI * t;
            result = -0.5f * (cosf(angle) - 1.0f);
        }
    };

    // Back function
    /// Back ease-out: a cubic in (t - 1) that rises above 1 before coming back
    /// to it (result is about 1.0877 at t = 0.5 and 1 at t = 1).
    struct OutBackFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const double overshoot = 1.70158;
            const float shifted = t - 1.0f;
            result = shifted * shifted * ((overshoot + 1.0f) * shifted + overshoot) + 1.0f;
        }
    };

    /// Back ease-in: a cubic in t that dips below 0 before heading to 1
    /// (result is about -0.0877 at t = 0.5 and 1 at t = 1).
    struct InBackFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const double overshoot = 1.70158;
            result = t * t * ((overshoot + 1.0f) * t - overshoot);
        }
    };

    /// Back ease-in-out: dips below 0 in the first half of the range and rises
    /// above 1 in the second (about -0.0997 at t = 0.25, 1.0997 at t = 0.75).
    struct InOutBackFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float overshoot = 1.70158 * 1.525f;
            // stretch [0, 1] onto [0, 2] so each half spans one unit
            const float doubled = t * 2.0f;

            if (!(doubled < 1.0f))
            {
                const float tail = doubled - 2.0f;
                result = 0.5f * (tail * tail * ((overshoot + 1.0f) * tail + overshoot) + 2.0f);
                return;
            }

            result = 0.5f * (doubled * doubled * ((overshoot + 1.0f) * doubled - overshoot));
        }
    };

    // Circ function
    /// Circular ease-out: result = sqrt(1 - (t - 1)^2).
    struct OutCircFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float shifted = t - 1.0f;
            result = sqrt(1.0f - shifted * shifted);
        }
    };

    /// Circular ease-in: result = 1 - sqrt(1 - t^2).
    struct InCircFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            const float radicand = 1.0f - (t * t);
            result = -(sqrt(radicand) - 1.0f);
        }
    };

    /// Circular ease-in-out: circular ease-in over the first half of the range,
    /// circular ease-out over the second.
    struct InOutCircFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            // stretch [0, 1] onto [0, 2] so each half spans one unit
            const float doubled = t * 2.0f;

            if (!(doubled < 1.0f))
            {
                const float tail = doubled - 2.0f;
                result = 0.5f * (sqrt(1 - tail * tail) + 1.0f);
                return;
            }

            result = -0.5f * (sqrt(1.0f - doubled * doubled) - 1.0f);
        }
    };

    // Expo function
    /// Exponential ease-out: result = 1 - 2^(-10 t), pinned to exactly 1 at
    /// t == 1.
    struct OutExpoFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            if(t == 1.0f)
            {
                // The formula only reaches 1 - 2^-10 at t == 1, so the end point
                // is special-cased to land exactly on the final value.
                result = 1.0f;
            }
            else
            {
                result = -powf(2.0f, -10.0f * t) + 1.0f;
            }
        }
    };

    /// Exponential ease-in: result = 2^(10 (t - 1)), pinned to exactly 0 at
    /// t == 0.
    struct InExpoFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            // the formula alone would give 2^-10 at the start
            result = (t == 0.0f) ? 0.0f : powf(2.0f, 10.0f * (t - 1.0f));
        }
    };

    /// Exponential ease-in-out: exponential ease-in over the first half of the
    /// range and exponential ease-out over the second, pinned to exactly 0 at
    /// t == 0 and exactly 1 at t == 1.
    struct InOutExpoFunction
    {
        /// @param t       normalized time, nominally in [0, 1]
        /// @param result  receives the eased value
        inline static void getValueAt(float t, float& result)
        {
            if(t == 0.0f)
            {
                // start of the motion: exactly the initial value
                result = 0.0f;
            }
            else if(t == 1.0f)
            {
                // end of the motion: exactly the final value (not 0, which would
                // snap the motion back to its start)
                result = 1.0f;
            }
            else
            {
                // stretch [0, 1] onto [0, 2] so each half spans one unit
                t *= 2.0f;
                if(t < 1.0f)
                {
                    result = 0.5f * powf(2.0f, 10.0f * (t - 1.0f));
                }
                else
                {
                    result = 0.5f * (-powf(2.0f, -10.0f * (t - 1.0f)) + 2.0f);
                }
            }
        }
    };

    /// Time-driven scalar motion. A time is folded into [0, duration], divided
    /// by the duration, passed to the subclass' getValueInNormalizedRange() and
    /// the outcome is mapped to startValue + changeValue * outcome.
    class Motion : public osg::Referenced
    {
    public:
        typedef float value_type;

        /// How evaluateTime() folds a time that lies outside the motion.
        enum TimeBehaviour
        {
            CLAMP, ///< pin the time to 0 or to the duration
            LOOP   ///< wrap the time around the duration; times <= 0 map to 0
        };

        /// @param startValue   value added to the scaled easing result
        /// @param duration     length of the motion, in the same unit as the times passed in
        /// @param changeValue  factor applied to the easing result
        /// @param tb           how out-of-range times are handled
        Motion(float startValue = 0, float duration = 1, float changeValue = 1, TimeBehaviour tb = CLAMP) : _time(0), _startValue(startValue), _changeValue(changeValue), _duration(duration), _behaviour(tb) {}
        virtual ~Motion() {}

        /// Rewind the internal clock to 0.
        void reset() { setTime(0);}

        /// Current internal time, already folded by evaluateTime().
        float getTime() const { return _time; }

        /// Fold @p time into the motion's range according to the TimeBehaviour.
        float evaluateTime(float time) const
        {
            switch (_behaviour)
            {
            case CLAMP:
                if (time > _duration)
                    time = _duration;
                else if (time < 0.0)
                    time = 0.0;
                break;
            case LOOP:
                // fmodf(time, 0) is NaN, so a zero-length loop stays at 0.
                if (time <= 0 || _duration == 0.0f)
                    time = 0;
                else
                    time = fmodf(time, _duration);
                break;
            }
            return time;
        }

        /// Advance the internal clock by @p dt.
        void update(float dt)
        {
            _time = evaluateTime(_time + dt);
        }

        /// Set the internal clock; the time is folded by evaluateTime() first.
        void setTime(float time) { _time = evaluateTime(time);}

        /// Value at the current internal time.
        void getValue(value_type& result) const { getValueAt(_time, result); }
        value_type getValue() const
        {
            value_type result = 0;
            getValueAt(_time, result);
            return result;
        }

        /// Value at an arbitrary @p time (the internal clock is not modified).
        void getValueAt(float time, value_type& result) const
        {
            // A zero duration would turn the normalization into 0/0 (NaN); such
            // a motion is treated as already finished (normalized time 1).
            const float normalized = (_duration == 0.0f) ? 1.0f : evaluateTime(time) / _duration;
            getValueInNormalizedRange(normalized, result);
            result = result * _changeValue + _startValue;
        }
        value_type getValueAt(float time) const
        {
            value_type result = 0;
            // The out-parameter overload already folds the time into range.
            getValueAt(time, result);
            return result;
        }

        /// Easing curve supplied by subclasses; @p t is time / duration after
        /// the time has been folded by evaluateTime().
        virtual void getValueInNormalizedRange(float t, value_type& result) const = 0;

        float getDuration() const { return _duration;}
    protected:
        float _time;
        float _startValue;
        float _changeValue;
        float _duration;
        TimeBehaviour _behaviour;
    };



    /// Motion whose easing curve comes from a function type T, invoked as
    /// T::getValueAt(t, result).
    template <typename T>
    struct MathMotionTemplate : public Motion
    {
        MathMotionTemplate(float startValue = 0,
                           float duration = 1,
                           float changeValue = 1,
                           TimeBehaviour tb = CLAMP)
            : Motion(startValue, duration, changeValue, tb)
        {
        }

        /// Forwards the normalized time straight to T.
        virtual void getValueInNormalizedRange(float t, value_type& result) const
        {
            T::getValueAt(t, result);
        }
    };

    /// Motion whose easing curve is read from a sampler of type T. T must
    /// provide getKeyframeContainer(), getStartTime(), getEndTime() and
    /// getValueAt(time, result).
    template <class T>
    struct SamplerMotionTemplate : public Motion
    {
        T _sampler;

        SamplerMotionTemplate(float startValue = 0,
                              float duration = 1,
                              float changeValue = 1,
                              TimeBehaviour tb = CLAMP)
            : Motion(startValue, duration, changeValue, tb)
        {
        }

        T& getSampler() { return _sampler;}
        const T& getSampler() const { return _sampler;}

        /// Maps the normalized time onto the sampler's [start, end] time span
        /// and samples there; yields 0 when the sampler has no keyframe
        /// container.
        virtual void getValueInNormalizedRange(float t, value_type& result) const
        {
            if (_sampler.getKeyframeContainer())
            {
                const float span = _sampler.getEndTime() - _sampler.getStartTime();
                const float samplerTime = t * span + _sampler.getStartTime();
                _sampler.getValueAt(samplerTime, result);
            }
            else
            {
                result = 0;
            }
        }
    };

    /// Motion made of several sub-motions played one after another. Each
    /// sub-motion occupies a share of the composite's normalized range equal to
    /// its own duration divided by the composite's duration.
    struct CompositeMotion : public Motion
    {
        typedef std::vector<osg::ref_ptr<Motion> > MotionList;
        MotionList _motions;

        MotionList& getMotionList() { return _motions; }
        const MotionList& getMotionList() const { return _motions; }
        CompositeMotion(float startValue = 0, float duration = 1, float changeValue = 1, TimeBehaviour tb = CLAMP) : Motion(startValue, duration, changeValue, tb) {}

        /// Finds the sub-motion whose share contains @p t and evaluates it at
        /// the matching local time. Yields 0 (with a warning) when the list is
        /// empty or when @p t lies beyond the shares of all sub-motions.
        virtual void getValueInNormalizedRange(float t, value_type& result) const
        {
            if (_motions.empty())
            {
                result = 0;
                osg::notify(osg::WARN) << "CompositeMotion::getValueInNormalizedRange no Motion in the CompositeMotion, add motion to have result" << std::endl;
                return;
            }
            for (MotionList::const_iterator it = _motions.begin(); it != _motions.end(); ++it)
            {
                const Motion* motion = it->get();
                const float durationInRange = motion->getDuration() / getDuration();
                // The last sub-motion also owns its end point: a t that lands
                // exactly on the end of the composite (e.g. t == 1) must yield
                // the last sub-motion's final value instead of falling through.
                const bool isLast = (it + 1 == _motions.end());
                if (t < durationInRange || (isLast && durationInRange > 0.0f && t <= durationInRange))
                {
                    float tInRange = t/durationInRange * motion->getDuration();
                    motion->getValueAt( tInRange, result);
                    return;
                }
                // skip this sub-motion's share and keep looking
                t = t - durationInRange;
            }
            osg::notify(osg::WARN) << "CompositeMotion::getValueInNormalizedRange did not find the value in range, something wrong" << std::endl;
            result = 0;
        }
    };


    // Ready-to-use Motion types: one MathMotionTemplate instantiation per
    // easing function, grouped by family (ease-in, ease-out, ease-in-out).

    // linear
    typedef MathMotionTemplate<LinearFunction>       LinearMotion;

    // quad
    typedef MathMotionTemplate<InQuadFunction>       InQuadMotion;
    typedef MathMotionTemplate<OutQuadFunction>      OutQuadMotion;
    typedef MathMotionTemplate<InOutQuadFunction>    InOutQuadMotion;

    // cubic
    typedef MathMotionTemplate<InCubicFunction>      InCubicMotion;
    typedef MathMotionTemplate<OutCubicFunction>     OutCubicMotion;
    typedef MathMotionTemplate<InOutCubicFunction>   InOutCubicMotion;

    // quart
    typedef MathMotionTemplate<InQuartFunction>      InQuartMotion;
    typedef MathMotionTemplate<OutQuartFunction>     OutQuartMotion;
    typedef MathMotionTemplate<InOutQuartFunction>   InOutQuartMotion;

    // bounce
    typedef MathMotionTemplate<InBounceFunction>     InBounceMotion;
    typedef MathMotionTemplate<OutBounceFunction>    OutBounceMotion;
    typedef MathMotionTemplate<InOutBounceFunction>  InOutBounceMotion;

    // elastic
    typedef MathMotionTemplate<InElasticFunction>    InElasticMotion;
    typedef MathMotionTemplate<OutElasticFunction>   OutElasticMotion;
    typedef MathMotionTemplate<InOutElasticFunction> InOutElasticMotion;

    // sine
    typedef MathMotionTemplate<InSineFunction>       InSineMotion;
    typedef MathMotionTemplate<OutSineFunction>      OutSineMotion;
    typedef MathMotionTemplate<InOutSineFunction>    InOutSineMotion;

    // back
    typedef MathMotionTemplate<InBackFunction>       InBackMotion;
    typedef MathMotionTemplate<OutBackFunction>      OutBackMotion;
    typedef MathMotionTemplate<InOutBackFunction>    InOutBackMotion;

    // circ
    typedef MathMotionTemplate<InCircFunction>       InCircMotion;
    typedef MathMotionTemplate<OutCircFunction>      OutCircMotion;
    typedef MathMotionTemplate<InOutCircFunction>    InOutCircMotion;

    // expo
    typedef MathMotionTemplate<InExpoFunction>       InExpoMotion;
    typedef MathMotionTemplate<OutExpoFunction>      OutExpoMotion;
    typedef MathMotionTemplate<InOutExpoFunction>    InOutExpoMotion;
}

#endif