<?php

namespace MathPHP\Tests\Statistics\Multivariate;

use MathPHP\Functions\Map\Multi;
use MathPHP\LinearAlgebra\Matrix;
use MathPHP\LinearAlgebra\MatrixFactory;
use MathPHP\SampleData;
use MathPHP\Statistics\Multivariate\PCA;
use MathPHP\Exception;

/**
 * PCA of the mtcars sample data (categorical columns vs and am removed) with center = false and scale = true,
 * checked against reference values produced by the R mdatools package.
 */
class PCACenterFalseScaleTrueTest extends \PHPUnit\Framework\TestCase
{
    /** @var PCA model built once in setUpBeforeClass and shared by every test */
    private static $pca;

    /** @var Matrix the mtcars data the model was built from */
    private static $matrix;

    /**
     * R code for expected values:
     *   library(mdatools)
     *   data = mtcars[,c(1:7,10,11)]
     *   model = pca(data, center=FALSE, scale=TRUE)
     * @throws Exception\MathException
     */
    public static function setUpBeforeClass()
    {
        $mtCars = new SampleData\MtCars();

        // Remove the categorical variables vs (column 7) and am (column 8) to match mtcars[,c(1:7,10,11)].
        // The higher index is removed first so the lower index does not shift.
        self::$matrix = MatrixFactory::create($mtCars->getData())->columnExclude(8)->columnExclude(7);
        self::$pca = new PCA(self::$matrix, false, true);
    }

    /**
     * Build a diagonal ±1 matrix that aligns the column signs of $actual with those of $expected.
     *
     * Each principal component is only defined up to a sign, so R and MathPHP may legitimately disagree on the
     * sign of any column of the loadings or scores. For each column the sign is read from the row where the
     * expected value has the largest magnitude, which keeps the choice away from near-zero entries whose sign is
     * numerically fragile. The sign of the product is used instead of a quotient so a zero actual entry cannot
     * cause a division by zero; a zero product leaves the column unflipped so the assertion still reports it.
     *
     * Multiplying a matrix on the right by the returned matrix changes the signs of its columns.
     *
     * @param array $expected reference values, one inner array per row
     * @param array $actual   computed values with the same shape as $expected
     *
     * @return Matrix diagonal matrix with 1 or -1 on the diagonal, one entry per column
     *
     * @throws Exception\MathException
     */
    private static function columnSignAdjustment(array $expected, array $actual): Matrix
    {
        $signs = [];
        foreach (\array_keys($expected[0]) as $column) {
            // Locate the most reliable row (largest |expected|) for this column.
            $pivotRow = 0;
            foreach ($expected as $row => $values) {
                if (\abs($values[$column]) > \abs($expected[$pivotRow][$column])) {
                    $pivotRow = $row;
                }
            }

            $product = $expected[$pivotRow][$column] * $actual[$pivotRow][$column];
            $signs[] = $product < 0 ? -1 : 1;
        }

        return MatrixFactory::diagonal($signs);
    }

    /**
     * @test The class returns the correct R-squared values
     *
     * R code for expected values:
     *   model$calres$expvar / 100
     */
    public function testRsq()
    {
        // Given
        $expected = [9.627597e-01, 2.399382e-02, 8.974667e-03, 1.647348e-03, 1.007722e-03, 6.644581e-04, 4.551423e-04, 4.007877e-04, 9.637733e-05];

        // When
        $R2 = self::$pca->getR2();

        // Then
        $this->assertEquals($expected, $R2, '', .00001);
    }

    /**
     * @test The class returns the correct cumulative R-squared values
     *
     * R code for expected values:
     *   model$calres$cumexpvar / 100
     */
    public function testCumRsq()
    {
        // Given
        $expected = [0.9627597, 0.9867535, 0.9957282, 0.9973755, 0.9983832, 0.9990477, 0.9995028, 0.9999036, 1.0000000];

        // When
        $cumR2 = self::$pca->getCumR2();

        // Then
        $this->assertEquals($expected, $cumR2, '', .00001);
    }

    /**
     * @test The class returns the correct loadings
     *
     * R code for expected values:
     *   model$loadings
     *
     * @throws \Exception
     */
    public function testLoadings()
    {
        // Given
        $expected = [
            [-0.2283453, -0.3650884,  0.01991352, -0.3796295,  0.34604431, -0.23172234, 0.42903263, -0.54776506, -0.11178708],
            [-0.2341349,  0.4385442, -0.00311039, -0.4212048,  0.16151512, -0.47125470, 0.14271710,  0.53758994, -0.13378157],
            [-0.1253369,  0.4195141, -0.10415777, -0.3369536, -0.01895432,  0.42132831, 0.19228361, -0.17248911,  0.66280446],
            [-0.1445912,  0.3937394,  0.25409277, -0.1758046,  0.16061996,  0.01235594, -0.66477897, -0.44409230, -0.25197564],
            [-0.4598548, -0.2444309,  0.31302618, -0.2679742, -0.74400176,  0.02312682, -0.04868431,  0.04005242, -0.03072032],
            [-0.2230294,  0.4121847, -0.19849703,  0.2611516, -0.17864380,  0.31401959, 0.43383293, -0.15883735, -0.57340928],
            [-0.6814843, -0.1267230, -0.53408337,  0.3099947,  0.13377984, -0.11701792, -0.26734937,  0.02090626,  0.18621398],
            [-0.3417161, -0.1578309,  0.53025464,  0.1283565,  0.46982102,  0.47378775, 0.06393284,  0.33178008, -0.03608434],
            [-0.1184231,  0.2697307,  0.46944681,  0.5305098, -0.06266656, -0.45743019, 0.22475430, -0.21594317,  0.31793865]
        ];

        // When
        $loadings = self::$pca->getLoadings();

        // And since each column could be multiplied by -1, align the column signs with the expected values.
        $signChange   = self::columnSignAdjustment($expected, $loadings->getMatrix());
        $signAdjusted = $loadings->multiply($signChange);

        // Then
        $this->assertEquals($expected, $signAdjusted->getMatrix(), '', .00001);
    }

    /**
     * @test The class returns the correct scores, for both the calibration data and new data
     *
     * R code for expected values:
     *   model$calres$scores
     *   new = matrix(c(1:9), 1, 9)
     *   result = predict(model, new)
     *   result$scores
     *
     * @throws \Exception
     */
    public function testScores()
    {
        // Given
        $expected = [
            [-14.35066, -0.65965959,  1.2016258,  0.15487538, -0.299362630, -0.45941646,  0.40319925,  0.28569240,  0.159949082],
            [-14.62235, -0.59195135,  0.9825213,  0.32008269, -0.303995255, -0.41425005,  0.43247921,  0.25084881,  0.068866965],
            [-14.55682, -2.29042326, -0.2911804,  0.03077448, -0.007060303,  0.42965952, -0.41191416,  0.17374023, -0.127024712],
            [-14.34867, -0.22505382, -1.9614721, -0.20932669,  0.443716831,  0.02614755, -0.02775485, -0.05880444,  0.143964740],
            [-14.01070,  1.54954215, -0.8094368, -1.02132719,  0.247787154, -0.08881788,  0.06248158,  0.03235842,  0.159602626],
            [-14.25782,  0.02861675, -2.4333452,  0.46215743,  0.706518179,  0.05364997, -0.24431736,  0.26481478,  0.003159298],
            [-13.76947,  2.66307496,  0.3781640, -0.46684763, -0.114018898, -0.35185110, -0.42243004, -0.31894716,  0.175042030],
            [-15.25532, -1.92631374, -0.8282702,  0.78624311,  0.135325909,  0.39218068,  0.39409779, -0.09539406,  0.005222661],
            [-16.55300, -1.98750454, -1.4302883,  1.19558483,  0.026039043,  0.24679166, -0.52147213, -0.09786046,  0.194498654],
            [-15.22340, -0.24443784,  0.5328860,  0.74229318, -0.413210410, -0.21849364,  0.24727610,  0.24441794, -0.103751531],
            [-15.39917, -0.20218107,  0.3489316,  0.93456390, -0.448673945, -0.20395783,  0.05784849,  0.37867798, -0.015259641],
            [-14.14199,  1.87458933, -0.7253186, -0.05774572,  0.126266949, -0.39508414,  0.08855680,  0.08868195, -0.394565209],
            [-14.17486,  1.66265942, -0.7131462, -0.17048646,  0.254990715, -0.55190140, -0.02804934,  0.06441817, -0.191165130],
            [-14.25924,  1.78256558, -0.8497805,  0.04452587,  0.155233971, -0.48130854, -0.21521520,  0.25184178, -0.139833047],
            [-14.60887,  3.73267398, -1.0214334,  0.53684683, -0.205676916,  0.55854838,  0.31175310, -0.12993899,  0.247897004],
            [-14.65667,  3.80212820, -0.9207854,  0.52743129, -0.321566390,  0.58890438,  0.29088998, -0.20288424,  0.024306771],
            [-14.85825,  3.45003695, -0.5639484,  0.06674281, -0.372044072,  0.36908124,  0.42439661, -0.63764289, -0.226133832],
            [-15.33241, -3.34286781, -0.4329398, -0.42313731,  0.251613808, -0.12879000,  0.28498558, -0.43633349, -0.215820012],
            [-15.53280, -3.71292186,  0.7019771, -0.67160778, -1.081438617, -0.43658913,  0.21820869, -0.14587649,  0.248520179],
            [-15.58065, -3.71345479, -0.3978123, -0.58738381,  0.240581290, -0.35172158,  0.15075281, -0.48083802, -0.029942225],
            [-14.50305, -1.90345352, -1.5451798,  0.25228324, -0.423902509, -0.17229477, -0.74262863, -0.21893809,  0.065538523],
            [-13.41987,  1.68028135, -1.0770843, -0.45068058,  0.528758276, -0.09444063,  0.10534783,  0.49961873,  0.046117853],
            [-13.87438,  1.40647490, -0.9492607, -0.53727629,  0.018703256, -0.16906811, -0.07526003,  0.59441255,  0.049027104],
            [-14.06626,  2.59631600,  0.7614411, -0.63982126, -0.974969958, -0.21009585, -0.37243065, -0.22405212, -0.092804106],
            [-14.11365,  1.85513174, -0.9735096, -1.01318751,  0.296084716,  0.15292354,  0.34158820, -0.13939186,  0.134046728],
            [-14.86170, -3.10412644, -0.2259207, -0.27232178, -0.035544826,  0.02059226, -0.10981384,  0.06311741, -0.023721124],
            [-14.95215, -2.70656972,  1.6579182, -0.36675494, -0.149104693,  0.79943145,  0.23329711,  0.24505429, -0.092191003],
            [-14.50548, -2.90865277,  1.4561922, -0.43367260,  1.206753297,  0.30568948,  0.04627864, -0.20758638,  0.036786117],
            [-15.05003,  1.68781721,  2.9714326, -1.08305855, -0.141138310,  0.85068735, -0.40936761,  0.45847020, -0.031808329],
            [-14.46037,  0.12093755,  2.8433751,  0.95556921,  0.630111488, -0.27406595,  0.25773517,  0.12944592,  0.037880250],
            [-14.95663,  3.11502916,  3.9273674,  0.69992324,  0.712383681, -0.31613525, -0.45139048, -0.50859225,  0.057149291],
            [-14.94864, -1.82710616,  0.1050937,  0.36176865, -0.537268693,  0.40688978, -0.32561122, -0.00975516, -0.179051914],
        ];

        // When
        $scores = self::$pca->getScores();

        // And since each column could be multiplied by -1, align the column signs with the expected values.
        $signChange   = self::columnSignAdjustment($expected, $scores->getMatrix());
        $signAdjusted = $scores->multiply($signChange);

        // Then
        $this->assertEquals($expected, $signAdjusted->getMatrix(), '', .00001);

        // And Given
        // New data uses the same components, so the sign change found above applies to the R reference too.
        $expectedNew = MatrixFactory::create([[-13.01415, 0.0006325086, 7.995322, 4.104522, -2.536586, 3.716645, 3.300201, 2.357735, -1.860735]])
            ->multiply($signChange);

        // When
        $newScores = self::$pca->getScores(MatrixFactory::create([[1,2,3,4,5,6,7,8,9]]));

        // Then
        $this->assertEquals($expectedNew->getMatrix(), $newScores->getMatrix(), '', .00001);
    }

    /**
     * @test The class returns the correct eigenvalues
     *
     * R code for expected values:
     *   model$eigenvals
     */
    public function testEigenvalues()
    {
        // Given
        $expected = [4.030377e-01, 5.538161e+00, 2.072459e+00, 3.803137e-01, 2.326922e-01, 1.534309e-01, 1.051069e-01, 9.254194e-02, 2.225659e-02];

        // When
        $eigenvalues = self::$pca->getEigenvalues()->getVector();

        // Then
        $this->assertEquals($expected, $eigenvalues, '', .00001);
    }

    /**
     * @test The class returns the correct critical T² distances
     *
     * R code for expected values:
     *   model$T2lim
     */
    public function testCriticalT2()
    {
        // Given
        $expected = [4.159615, 6.852714, 9.40913, 12.01948, 14.76453, 17.69939, 20.87304, 24.33584, 28.14389];

        // When
        $criticalT2 = self::$pca->getCriticalT2();

        // Then
        $this->assertEquals($expected, $criticalT2, '', .00001);
    }

    /**
     * @test The class returns the correct critical Q distances
     *
     * R code for expected values:
     *   model$Qlim
     */
    public function testCriticalQ()
    {
        // Given
        $expected = [25.824597, 9.343233, 2.3843107, 1.4845012, 0.9577895, 0.6182419, 0.3849731, 0.08339018, 0];

        // When
        $criticalQ = self::$pca->getCriticalQ();

        // Then
        $this->assertEquals($expected, $criticalQ, '', .00001);
    }

    /**
     * @test The class returns the correct T² distances for the calibration data
     *
     * R code for expected values:
     *   model$calres$T2
     *
     * @throws \Exception
     */
    public function testGetT²Distances()
    {
        // Given
        $expected = [
            [0.9262786, 1.0048123, 1.701496, 1.764548,  2.149645,  3.525146,  5.071853,  5.953711,  7.103198],
            [0.9616837, 1.0249232, 1.490704, 1.760015,  2.157123,  3.275461,  5.054965,  5.734834,  5.947924],
            [0.9530838, 1.8998610, 1.940770, 1.943260,  1.943474,  3.146561,  4.760852,  5.086990,  5.811956],
            [0.9260216, 0.9351625, 2.791520, 2.906701,  3.752732,  3.757188,  3.764517,  3.801878,  4.733100],
            [0.8829120, 1.3162464, 1.632375, 4.374328,  4.638164,  4.689574,  4.726716,  4.738029,  5.882543],
            [0.9143322, 0.9144800, 3.771444, 4.332893,  6.477865,  6.496623,  7.064529,  7.822209,  7.822658],
            [0.8527711, 2.1326919, 2.201693, 2.774596,  2.830459,  3.637259,  5.335026,  6.434130,  7.810786],
            [1.0467445, 1.7164291, 2.047439, 3.672405,  3.751098,  4.753450,  6.231117,  6.329437,  6.330663],
            [1.2323981, 1.9453046, 2.932367, 6.689799,  6.692713,  7.089638,  9.676842,  9.780312, 11.480019],
            [1.0423675, 1.0531508, 1.190165, 2.638541,  3.372239,  3.683358,  4.265103,  4.910560,  5.394208],
            [1.0665781, 1.0739554, 1.132701, 3.428577,  4.293618,  4.564717,  4.596556,  6.145877,  6.156339],
            [0.8995364, 1.5337399, 1.787577, 1.796342,  1.864852,  2.882101,  2.956714,  3.041685, 10.036535],
            [0.9037232, 1.4026340, 1.648023, 1.724426,  2.003824,  3.988872,  3.996357,  4.041193,  5.683136],
            [0.9145147, 1.4879803, 1.836407, 1.841618,  1.945168,  3.454883,  3.895554,  4.580817,  5.459354],
            [0.9599115, 3.4744393, 3.977844, 4.735429,  4.917209,  6.950360,  7.875037,  8.057460, 10.818568],
            [0.9662034, 3.5751779, 3.984264, 4.715507,  5.159848,  7.420000,  8.225056,  8.669787,  8.696333],
            [0.9929637, 3.1411098, 3.294563, 3.306273,  3.901063,  4.788815,  6.502426, 10.895377, 13.192963],
            [1.0573498, 3.0741121, 3.164550, 3.635195,  3.907242,  4.015339,  4.788044,  6.845064,  8.937847],
            [1.0851687, 3.5731548, 3.810917, 4.996582, 10.022074, 11.264281, 11.717296, 11.947213, 14.722220],
            [1.0918644, 3.5805647, 3.656923, 4.563854,  4.812567,  5.618773,  5.834995,  8.333031,  8.373313],
            [0.9460550, 1.5999393, 2.751947, 2.919251,  3.691411,  3.884871,  9.131880,  9.649778,  9.842768],
            [0.8100177, 1.3195601, 1.879313, 2.413223,  3.614628,  3.672753,  3.778342,  6.475328,  6.570889],
            [0.8658151, 1.2228251, 1.657604, 2.416401,  2.417904,  2.604186,  2.658075,  6.475559,  6.583556],
            [0.8899282, 2.1064823, 2.386232, 3.462320,  7.546993,  7.834656,  9.154307,  9.696682, 10.083650],
            [0.8959346, 1.5170407, 1.974316, 4.672739,  5.049449,  5.201853,  6.311984,  6.521915,  7.329249],
            [0.9934248, 2.7324070, 2.757034, 2.951971,  2.957400,  2.960164,  3.074895,  3.117938,  3.143220],
            [1.0055527, 2.3276237, 3.653868, 4.007444,  4.102978,  8.267935,  8.785765,  9.434587,  9.816459],
            [0.9463721, 2.4732351, 3.496374, 3.990747, 10.248405, 10.857394, 10.877770, 11.343355, 11.404156],
            [1.0187610, 1.5328841, 5.793076, 8.876507,  8.962105, 13.678259, 15.272652, 17.543686, 17.589145],
            [0.9404946, 0.9431342, 4.844042, 7.244282,  8.950402,  9.439908, 10.071907, 10.252948, 10.317419],
            [1.0061556, 2.7573750,10.199559,11.487308, 13.668042, 14.319362, 16.257895, 19.052631, 19.199375],
            [1.0050816, 1.6075633, 1.612892, 1.956919,  3.197308,  4.276259,  5.284971,  5.285999,  6.726451],
        ];

        // When
        $t2Distances = self::$pca->getT2Distances()->getMatrix();

        // Then
        $this->assertEquals($expected, $t2Distances, '', .00001);
    }

    /**
     * @test The class returns the correct T² distances for new data
     *
     * R code for expected values:
     *   new = matrix(c(1:9), 1, 9)
     *   result = predict(model, new)
     *   result$T2
     *
     * @throws \Exception
     */
    public function testT2WithNewData()
    {
        // Given
        $expected = [[0.7617799, 0.76178, 31.60568, 75.89057, 103.5392, 193.5615, 297.1829, 357.2437, 512.808]];
        $newdata  = MatrixFactory::create([[1,2,3,4,5,6,7,8,9]]);

        // When
        $t2Distances = self::$pca->getT2Distances($newdata)->getMatrix();

        // Then (looser tolerance: the R reference values are printed with fewer decimals)
        $this->assertEquals($expected, $t2Distances, '', .0001);
    }

    /**
     * @test The class returns the correct Q residuals for the calibration data
     *
     * R code for expected values:
     *   model$calres$Q
     *
     * @throws \Exception
     */
    public function testGetQResiduals()
    {
        // Given
        $expected = [
            [ 2.473497, 2.0383460, 0.59444134, 0.57045496, 0.48083697, 0.26977349, 0.107203856, 2.558371e-02, 2.672266e-29],
            [ 1.936930, 1.5865233, 0.62117520, 0.51872227, 0.42630916, 0.25470606, 0.067667784, 4.742659e-03, 2.534216e-29],
            [ 5.732423, 0.4863844, 0.40159844, 0.40065137, 0.40060152, 0.21599422, 0.046320946, 1.613528e-02, 1.869847e-29],
            [ 4.164362, 4.1137130, 0.26634012, 0.22252246, 0.02563783, 0.02495414, 0.024183808, 2.072585e-02, 3.112303e-29],
            [ 4.199089, 1.7980082, 1.14282032, 0.09971110, 0.03831263, 0.03042401, 0.026520065, 2.547300e-02, 2.312349e-29],
            [ 6.767451, 6.7666322, 0.84546356, 0.63187407, 0.13270614, 0.12982782, 0.070136846, 9.981166e-06, 4.675233e-29],
            [ 7.900537, 0.8085684, 0.66556037, 0.44761365, 0.43461334, 0.31081415, 0.132367004, 3.063971e-02, 2.484912e-29],
            [ 5.351454, 1.6407690, 0.95473739, 0.33655915, 0.31824605, 0.16444037, 0.009127303, 2.727619e-05, 1.983246e-29],
            [ 7.806246, 3.8560715, 1.81034682, 0.38092374, 0.38024571, 0.31933958, 0.047406396, 3.782973e-02, 1.903127e-29],
            [ 1.244849, 1.1850989, 0.90113146, 0.35013230, 0.17938945, 0.13164998, 0.070504511, 1.076438e-02, 2.233462e-29],
            [ 1.425924, 1.3850464, 1.26329311, 0.38988343, 0.18857512, 0.14697632, 0.143629871, 2.328566e-04, 3.096279e-29],
            [ 4.386930, 0.8728449, 0.34675789, 0.34342332, 0.32747998, 0.17138850, 0.163546193, 1.556817e-01, 2.958228e-29],
            [ 3.713175, 0.9487391, 0.44016163, 0.41109599, 0.34607573, 0.04148057, 0.040693808, 3.654411e-02, 1.794659e-29],
            [ 4.286700, 1.1091602, 0.38703319, 0.38505064, 0.36095305, 0.12929515, 0.082977565, 1.955328e-02, 2.090481e-29],
            [15.794192, 1.8613370, 0.81801086, 0.52980634, 0.48750335, 0.17552706, 0.078337065, 6.145292e-02, 2.647614e-29],
            [16.158791, 1.7026127, 0.85476689, 0.57658313, 0.47317818, 0.12636981, 0.041752833, 5.908191e-04, 2.509564e-29],
            [13.137723, 1.2349676, 0.91692980, 0.91247520, 0.77405841, 0.63783744, 0.457724962, 5.113651e-02, 3.628760e-29],
            [11.939326, 0.7645604, 0.57712353, 0.39807835, 0.33476884, 0.31818197, 0.236965192, 4.657828e-02, 2.505866e-29],
            [16.220394, 2.4346057, 1.94183382, 1.49077681, 0.32126733, 0.13065726, 0.083042230, 6.176228e-02, 2.001735e-29],
            [14.729436, 0.9396899, 0.78143531, 0.43641557, 0.37853621, 0.25482814, 0.232101734, 8.965368e-04, 2.959461e-29],
            [ 6.887468, 3.2643328, 0.87675212, 0.81310529, 0.63341195, 0.60372646, 0.052229183, 4.295298e-03, 4.468157e-29],
            [ 4.737917, 1.9145719, 0.75446122, 0.55134824, 0.27176293, 0.26284389, 0.251745728, 2.126856e-03, 1.735494e-29],
            [ 3.558261, 1.5800895, 0.67899366, 0.39032785, 0.38997804, 0.36139401, 0.355729941, 2.403657e-03, 1.636886e-29],
            [ 8.922244, 2.1813870, 1.60159446, 1.19222322, 0.24165680, 0.19751654, 0.058811953, 8.612602e-03, 2.169367e-29],
            [ 5.680917, 2.2394028, 1.29168181, 0.26513288, 0.17746673, 0.15408112, 0.037398616, 1.796853e-02, 2.267975e-29],
            [ 9.779093, 0.1434924, 0.09245221, 0.01829305, 0.01702962, 0.01660558, 0.004546499, 5.626917e-04, 1.366948e-29],
            [10.993023, 3.6675032, 0.91881036, 0.78430118, 0.76206897, 0.12297833, 0.068550788, 8.499181e-03, 1.899429e-29],
            [12.365115, 3.9048544, 1.78435854, 1.59628661, 0.14003310, 0.04658704, 0.044445325, 1.353218e-03, 3.841999e-29],
            [13.973532,11.1248050, 2.29539336, 1.12237753, 1.10245751, 0.37878854, 0.211206692, 1.011770e-03, 4.279570e-29],
            [ 9.569292, 9.5546657, 1.46988373, 0.55677121, 0.15973073, 0.08461858, 0.018191160, 1.434913e-03, 2.997671e-29],
            [26.690631,16.9872246, 1.56301004, 1.07311749, 0.56562699, 0.46568549, 0.261932119, 3.266041e-03, 6.054507e-29],
            [ 4.072633, 0.7343156, 0.72327092, 0.59239436, 0.30373671, 0.13817742, 0.032154751, 3.205959e-02, 1.464323e-29],
        ];

        // When
        $qResiduals = self::$pca->getQResiduals()->getMatrix();

        // Then
        $this->assertEquals($expected, $qResiduals, '', .00001);
    }

    /**
     * @test The class returns the correct Q residuals for new data
     *
     * R code for expected values:
     *   new = matrix(c(1:9), 1, 9)
     *   result = predict(model, new)
     *   result$Q
     *
     * @throws \Exception
     */
    public function testQWithNewData()
    {
        // Given
        $expected = [[120.9326, 120.9326, 57.0074, 40.16029, 33.72602, 19.91258, 9.021248, 3.462335, 8.846567e-29]];
        $newData  = MatrixFactory::create([[1,2,3,4,5,6,7,8,9]]);

        // When
        $qResiduals = self::$pca->getQResiduals($newData)->getMatrix();

        // Then (looser tolerance: the R reference values are printed with fewer decimals)
        $this->assertEquals($expected, $qResiduals, '', .0001);
    }
}
