/*
 * monotone.cpp: Reusable framework for constructing second-level
 * spigots which compute a continuous and locally monotonic function
 * of a real, given a means of computing that function for
 * rationals.
 *
 * The basic idea is pretty simple. We have a spigot-construction
 * function implementing some function f on rationals, which is
 * monotonic (one way or the other, it doesn't matter) on an
 * interval surrounding the target real x. We obtain a pair of
 * rational bounds x1 = n1/d and x2 = n2/d which bracket x; then we
 * construct spigots for f(x1) and f(x2), and extract binary output
 * intervals from both for as long as those intervals continue to
 * overlap. Since x lies between x1 and x2 and f is monotonic there,
 * f(x) lies between f(x1) and f(x2). So the smallest interval
 * containing both output intervals also contains f(x), and we can
 * safely output it. Once the two output intervals stop overlapping,
 * we fetch a tighter bracket on x and start again. The new output
 * spigots start at the output precision we have already reached.
 *
 * The vital criterion is that f should be monotonic across the whole
 * region covered by the brackets on x. (Otherwise f(x) need not lie
 * between f(x1) and f(x2), and an interval containing both of those
 * might not contain f(x).) If this isn't the case at the point you
 * want to call this function, narrow down x further until it is!
 *
 * This file also provides spigot_monotone_invert. Given a monotonic f,
 * a target value, and a starting interval, it finds the point in that
 * interval at which f takes the target value, by repeated trisection.
 */

#include <assert.h>
#include <stdlib.h>

#include "spigot.h"
#include "funcs.h"
#include "cr.h"

/*
 * MonotoneHelper: spigot for f(x), where x is a real given as a
 * spigot and f is available only on rationals, via the
 * MonotoneConstructor 'fcons' (fcons->construct(n, d) yields a spigot
 * for f(n/d)). f must be monotonic across every bracket this class
 * obtains for x.
 *
 * Output is produced by gen_bin_interval as a sequence of intervals
 * [*ret_lo / 2^*ret_bits, *ret_hi / 2^*ret_bits].
 *
 * The destructor deletes fcons, xorig, and whichever of fx1/fx2 are
 * currently allocated.
 */
class MonotoneHelper : public BinaryIntervalSource {
    // Builds spigots for f at rational points; owned.
    MonotoneConstructor *fcons;
    // xorig is the caller's x (owned; kept so clone() can copy it).
    // xcloned is a clone of it, passed to the BracketingGenerator 'x'.
    // NOTE(review): xcloned is never deleted here -- presumably 'x'
    // takes ownership of it; confirm against BracketingGenerator.
    Spigot *xorig, *xcloned;
    // Supplies successively tighter rational brackets (n1/d, n2/d) on x.
    BracketingGenerator x;
    // Bracketing generators for f(n1/d) and f(n2/d) in the current
    // round. NULL until gen_bin_interval first creates them.
    BracketingGenerator *fx1, *fx2;
    // Resume point for the crBegin/crReturnV coroutine macros used in
    // gen_bin_interval; set to -1 by the constructor before any call.
    int crState;
    // inbits: shift passed to x.set_denominator_lower_bound_shift for
    // the next input bracket.
    // outbits: binary exponent of the most recently returned interval,
    // used to seed the next round's output generators.
    unsigned inbits, outbits;
    // n1/d and n2/d: the current bracket on x. Kept as members, not
    // locals, so they survive across coroutine suspensions.
    // NOTE(review): t1 and t2 appear unused in this file.
    bigint n1, n2, d, t1, t2;

  public:
    /*
     * Takes ownership of afcons and ax (both are deleted in the
     * destructor). A clone of ax drives the input bracketing, while ax
     * itself is retained for clone().
     */
    MonotoneHelper(MonotoneConstructor *afcons, Spigot *ax)
        : xorig(ax), xcloned(ax->clone()), x(xcloned)
    {
        fcons = afcons;
        fx1 = fx2 = NULL;
        crState = -1;
        dprint("hello MonotoneHelper %p", xcloned);
    }

    ~MonotoneHelper()
    {
        delete xorig;
        if (fx1)
            delete fx1;
        if (fx2)
            delete fx2;
        delete fcons;
    }

    /*
     * Returns an independent copy built from fresh clones of fcons and
     * the original x, so the copy starts from scratch rather than
     * sharing this instance's progress.
     */
    virtual MonotoneHelper *clone()
    {
        return new MonotoneHelper(fcons->clone(), xorig->clone());
    }

    /*
     * Coroutine: each call resumes after the previous crReturnV and
     * stores the next output interval in *ret_lo, *ret_hi, *ret_bits
     * (representing [lo/2^bits, hi/2^bits]).
     *
     * Each round takes a bracket (n1/d, n2/d) on x and evaluates f at
     * both ends. It keeps returning the union of the two output
     * brackets for as long as they overlap, then moves on to a tighter
     * input bracket.
     */
    virtual void gen_bin_interval(bigint *ret_lo, bigint *ret_hi,
                                  unsigned *ret_bits)
    {
        crBegin;

        /*
         * No point starting with really small d. Get at least a few
         * bits to be going on with.
         */
        inbits = 16;

        /*
         * But on the output side, we take whatever we get.
         */
        outbits = 0;

        while (1) {
            x.set_denominator_lower_bound_shift(inbits);
            // Ask for roughly 25% more input precision next round.
            inbits = inbits * 5 / 4;

            /*
             * Get a pair of bounds on x.
             */
            x.get_bracket(&n1, &n2, &d);
            dprint("input bracket (%b,%b) / %b", &n1, &n2, &d);
            fx1 = new BracketingGenerator(fcons->construct(n1, d));
            fx2 = new BracketingGenerator(fcons->construct(n2, d));
            // Start the output side at the precision already returned.
            fx1->set_denominator_lower_bound_shift(outbits);
            fx2->set_denominator_lower_bound_shift(outbits);

            while (1) {
                /*
                 * These locals are confined to an inner scope that
                 * closes before crReturnV. Presumably crReturnV resumes
                 * via a case label, and C++ forbids jumping past the
                 * initialization of a variable still in scope.
                 */
                {
                    bigint nlo1, nhi1, nlo2, nhi2;
                    unsigned dbits1, dbits2, dbits;

                    fx1->get_bracket_shift(&nlo1, &nhi1, &dbits1);
                    fx2->get_bracket_shift(&nlo2, &nhi2, &dbits2);
                    dprint("output bracket 1 (%b,%b) / 2^%d",
                           &nlo1, &nhi1, (int)dbits1);
                    dprint("output bracket 2 (%b,%b) / 2^%d",
                           &nlo2, &nhi2, (int)dbits2);

                    // Rescale the coarser bracket so both share the
                    // denominator 2^dbits.
                    if (dbits1 < dbits2) {
                        nlo1 <<= dbits2-dbits1;
                        nhi1 <<= dbits2-dbits1;
                        dbits = dbits2;
                    } else if (dbits2 < dbits1) {
                        nlo2 <<= dbits1-dbits2;
                        nhi2 <<= dbits1-dbits2;
                        dbits = dbits1;
                    } else {
                        dbits = dbits1;     // equality
                    }

                    dprint("evened output brackets (%b,%b),(%b,%b) / 2^%d",
                           &nlo1, &nhi1, &nlo2, &nhi2, (int)dbits);

                    if (nlo2 > nhi1 || nlo1 > nhi2) {
                        /*
                         * The output intervals have stopped overlapping,
                         * which is as good a moment as any to decide it's
                         * time to go and get more input detail.
                         */
                        dprint("going round again");
                        break;
                    }

                    // Return the smallest interval containing both
                    // output brackets.
                    *ret_lo = (nlo1 < nlo2 ? nlo1 : nlo2);
                    *ret_hi = (nhi1 > nhi2 ? nhi1 : nhi2);
                    *ret_bits = outbits = dbits;

                    dprint("returning (%b,%b) / 2^%d",
                           ret_lo, ret_hi, (int)*ret_bits);
                }

                crReturnV;
            }

            // Done with this input bracket: discard both output
            // generators before fetching a tighter one.
            delete fx1;
            delete fx2;
            fx1 = fx2 = NULL;
        }

        crEnd;
    }
};

/*
 * Construct a spigot for f(x), where f (built on rationals by the
 * constructor 'f') is monotonic in the neighbourhood of x.
 *
 * Both f and x are handed over to this function. If x is exactly
 * rational, f is simply evaluated at that rational and both inputs are
 * freed. Otherwise they are passed on to a MonotoneHelper.
 */
Spigot *spigot_monotone(MonotoneConstructor *f, Spigot *x)
{
    bigint num, den;
    if (!x->is_rational(&num, &den))
        return new MonotoneHelper(f, x);

    /* Exact rational input: no bracketing of x is needed. */
    delete x;
    Spigot *exact = f->construct(num, den);
    delete f;
    return exact;
}

class MonotoneInverter : public Source {
    MonotoneConstructor *fcons;
    bool increasing;
    bigint nlo, nhi, d;
    Spigot *target;
    int crState;
    int slo, shi;

  public:
    MonotoneInverter(MonotoneConstructor *afcons, bool aincreasing,
                     bigint n1, bigint n2, bigint ad,
                     Spigot *atarget)
    {
        fcons = afcons;
        increasing = aincreasing;
        target = atarget;

        if (ad < 0) {
            n1 = -n1;
            n2 = -n2;
            ad = -ad;
        }

        assert(n1 < n2);

        nlo = n1;
        nhi = n2;
        d = ad;

        crState = -1;
    }

    ~MonotoneInverter()
    {
        delete target;
    }

    virtual MonotoneInverter *clone()
    {
        return new MonotoneInverter(fcons->clone(), increasing, nlo, nhi, d,
                                    target->clone());
    }

    virtual bool gen_interval(bigint *low, bigint *high) {
        *low = 0;
        *high = 1;
        return true; // the first matrix will probably expand this interval
    }

    virtual bool gen_matrix(bigint *matrix) {
        crBegin;

        /*
         * Start by converting our starting interval [0,1] into
         * [nlo/d, nhi/d].
         */
        matrix[0] = nhi - nlo;
        matrix[1] = nlo;
        matrix[2] = 0;
        matrix[3] = d;
        crReturn(false);

        /*
         * Work out what signs we expect to see at the two interval
         * ends.
         */
        if (increasing) {
            slo = -1;
            shi = +1;
        } else {
            slo = +1;
            shi = -1;
        }

        /*
         * Now repeatedly narrow the interval. To avoid exactness
         * hazards, we trisect rather than bisecting: pick _two_ trial
         * points inside the existing interval, begin evaluating
         * f(x)-target with x equal to each of those points, and
         * whichever one we find out the sign of first, narrow the
         * interval to exclude either the first or last third.
         */
        while (1) {
            {
                bigint ndiff = nhi - nlo;

                if (ndiff % 3U != 0) {
                    d *= 3;
                    nlo *= 3;
                    nhi *= 3;
                } else {
                    ndiff /= 3;
                }

                bigint n1 = nlo + ndiff;
                bigint n2 = n1 + ndiff;

                int s = parallel_sign_test
                    (spigot_sub(fcons->construct(n1, d), target->clone()),
                     spigot_sub(fcons->construct(n2, d), target->clone()));

                if (s == slo) {
                    /* New interval is [n1, nhi]. */
                    nlo = n1;
                    /* Narrow to the top 2/3 of prior interval */
                    matrix[0] = 2;
                    matrix[1] = 1;
                    matrix[2] = 0;
                    matrix[3] = 3;
                } else if (s == -slo) {
                    /* New interval is [nlo, n1]. */
                    nhi = n1;
                    /* Narrow to the bottom 1/3 of prior interval */
                    matrix[0] = 1;
                    matrix[1] = 0;
                    matrix[2] = 0;
                    matrix[3] = 3;
                } else if (s == 2*slo) {
                    /* New interval is [n2, nhi]. */
                    nlo = n2;
                    /* Narrow to the top 1/3 of prior interval */
                    matrix[0] = 1;
                    matrix[1] = 2;
                    matrix[2] = 0;
                    matrix[3] = 3;
                } else {
                    /* New interval is [nlo, n2]. */
                    nhi = n2;
                    /* Narrow to the bottom 2/3 of prior interval */
                    matrix[0] = 2;
                    matrix[1] = 0;
                    matrix[2] = 0;
                    matrix[3] = 3;
                }
            }

            crReturn(false);
        }

        crEnd;
    }
};

/*
 * Construct a spigot for the real y in [n1/d, n2/d] satisfying
 * f(y) = x, where f is monotonic on that interval (increasing if
 * 'increasing' is true, decreasing otherwise). Both f and x are handed
 * over to the returned spigot.
 */
Spigot *spigot_monotone_invert(MonotoneConstructor *f, bool increasing,
                               bigint n1, bigint n2, bigint d, Spigot *x)
{
    Spigot *inverter = new MonotoneInverter(f, increasing, n1, n2, d, x);
    return inverter;
}
