package com.xiaomi.ai.nlp.optimization;

import com.xiaomi.ai.nlp.utils.MLMath;

/* loaded from: classes4.dex */
/**
 * Limited-memory quasi-Newton search-direction calculator.
 *
 * <p>Keeps up to {@code m} of the most recent {@code (s, y)} pairs, newest at index 0,
 * together with {@code rho[k] = 1 / (y[k] . s[k])} and a scalar {@code gamma}, and uses them
 * in a two-loop recursion to turn an input vector into a search direction.
 *
 * <p>The vector arithmetic is delegated to {@code MLMath}. NOTE(review): this class presumes
 * {@code plusTo(a, ca, b, cb, out)} stores {@code ca*a + cb*b} into {@code out} and
 * {@code transformTo(a, c, out)} stores {@code c*a} into {@code out}; confirm against MLMath.
 *
 * <p>Instances hold mutable state and reuse internal buffers; no synchronization is performed.
 */
public class NewtonDirection {
    /** Length every input vector must have. */
    private final int dimension;
    /** Maximum number of (s, y) pairs retained. */
    private final int m;
    /** Per-pair coefficients written by the first loop and read by the second. */
    private final double[] alpha;
    /** Scaling applied between the two loops; 1.0 until the first update. */
    private double gamma;
    /** Work buffer that holds the direction being built; returned to callers as-is. */
    private final double[] q;
    /** rho[k] = 1 / (y[k] . s[k]), index-aligned with {@link #s} and {@link #y}. */
    private final double[] rho;
    /** Stored s vectors, newest first. */
    private final double[][] s;
    /** Number of valid pairs currently stored, at most {@link #m}. */
    private int size;
    /** Stored y vectors, newest first. */
    private final double[][] y;

    /**
     * Creates an empty history.
     *
     * @param m         maximum number of (s, y) pairs to retain
     * @param dimension length of every vector handled by this instance
     */
    public NewtonDirection(int m, int dimension) {
        this.m = m;
        this.dimension = dimension;
        this.s = new double[m][dimension];
        this.y = new double[m][dimension];
        this.rho = new double[m];
        this.alpha = new double[m];
        this.q = new double[dimension];
        this.gamma = 1.0d;
        this.size = 0;
    }

    /**
     * Runs the two-loop recursion on the given vector using the stored history.
     *
     * <p>With an empty history the result is simply a copy of the input scaled by the current
     * {@code gamma} (initially 1.0).
     *
     * @param negativeGradient input vector of length {@link #getDimension()}; it is not modified
     * @return the internal work buffer holding the direction; it is overwritten by the next
     *         call, so copy it if it must outlive that call
     * @throws IllegalArgumentException if the argument is null or has the wrong length
     */
    public double[] computeDirection(double[] negativeGradient) {
        if (negativeGradient == null) {
            throw new IllegalArgumentException("function's negative gradient is null");
        }
        if (negativeGradient.length != dimension) {
            throw new IllegalArgumentException("function's negative gradient dimension isn't valid");
        }
        System.arraycopy(negativeGradient, 0, q, 0, dimension);

        // First loop: newest pair to oldest.
        for (int k = 0; k < size; k++) {
            alpha[k] = rho[k] * MLMath.dotProd(s[k], q);
            MLMath.plusTo(q, 1.0d, y[k], -alpha[k], q);
        }

        MLMath.transformTo(q, gamma, q);

        // Second loop: oldest pair to newest.
        for (int k = size - 1; k >= 0; k--) {
            double beta = rho[k] * MLMath.dotProd(y[k], q);
            MLMath.plusTo(q, 1.0d, s[k], alpha[k] - beta, q);
        }
        return q;
    }

    /** @return the vector length this instance was created with */
    public int getDimension() {
        return dimension;
    }

    /** @return the current scaling factor applied between the two loops */
    public double getGamma() {
        return gamma;
    }

    /** @return the live internal rho array (not a copy); only the first {@link #getSize()} entries are valid */
    public double[] getRho() {
        return rho;
    }

    /** @return the live internal s history (not a copy), newest first */
    public double[][] getS() {
        return s;
    }

    /** @return number of (s, y) pairs currently stored */
    public int getSize() {
        return size;
    }

    /** @return the live internal y history (not a copy), newest first */
    public double[][] getY() {
        return y;
    }

    /**
     * Pushes a new (s, y) pair to the front of the history, evicting the oldest pair once the
     * history is full, and refreshes {@code rho[0]} and {@code gamma}.
     *
     * <p>The inputs are copied; the caller's arrays are not retained. Row arrays inside the
     * history are rotated and reused, so a row obtained earlier from {@link #getS()} or
     * {@link #getY()} may later hold a different pair.
     *
     * @param sk1 new s vector of length {@link #getDimension()}
     * @param yk1 new y vector of length {@link #getDimension()}
     * @throws IllegalArgumentException if either argument is null or has the wrong length
     */
    public void updateSYRho(double[] sk1, double[] yk1) {
        if (sk1 == null || yk1 == null) {
            throw new IllegalArgumentException("input params sk1 or yk1 is null");
        }
        if (sk1.length != dimension || yk1.length != dimension) {
            throw new IllegalArgumentException("input params sk1 or yk1 dimension is invalid");
        }

        // Shift every stored pair one slot towards the old end. The previous ascending
        // element-by-element copy (s[i] -> s[i+1] for i = 0, 1, ...) overwrote each row before
        // it was read, leaving the whole history equal to the newest pair while rho was shifted
        // correctly. Rotating the row references with overlap-safe arraycopy keeps s, y and rho
        // aligned and avoids copying m * dimension doubles.
        int last = m - 1;
        double[] recycledS = s[last];
        double[] recycledY = y[last];
        System.arraycopy(s, 0, s, 1, last);
        System.arraycopy(y, 0, y, 1, last);
        System.arraycopy(rho, 0, rho, 1, last);
        s[0] = recycledS;
        y[0] = recycledY;

        System.arraycopy(sk1, 0, recycledS, 0, dimension);
        System.arraycopy(yk1, 0, recycledY, 0, dimension);

        // NOTE(review): there is no guard for y.s == 0 (or y.y == 0); rho[0] / gamma then become
        // infinite or NaN. Callers presumably ensure a positive curvature product; confirm.
        double ys = MLMath.dotProd(recycledY, recycledS);
        rho[0] = 1.0d / ys;
        gamma = ys / MLMath.dotProd(recycledY, recycledY);
        size = Math.min(size + 1, m);
    }
}
