238 lines
6.9 KiB
Java
238 lines
6.9 KiB
Java
/*
|
|
* ====================================================================
|
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
* contributor license agreements. See the NOTICE file distributed with
|
|
* this work for additional information regarding copyright ownership.
|
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
* (the "License"); you may not use this file except in compliance with
|
|
* the License. You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* ====================================================================
|
|
*/
|
|
|
|
package org.apache.poi.ss.formula.functions;
|
|
|
|
import org.apache.poi.ss.formula.TwoDEval;
|
|
import org.apache.poi.ss.formula.eval.ErrorEval;
|
|
import org.apache.poi.ss.formula.eval.EvaluationException;
|
|
import org.apache.poi.ss.formula.eval.NumberEval;
|
|
import org.apache.poi.ss.formula.eval.RefEval;
|
|
import org.apache.poi.ss.formula.eval.ValueEval;
|
|
import org.apache.poi.ss.formula.functions.LookupUtils.ValueVector;
|
|
|
|
/**
|
|
* Base class for linear regression functions.
|
|
*
|
|
* Calculates the linear regression line that is used to predict y values from x values<br/>
|
|
* (http://introcs.cs.princeton.edu/java/97data/LinearRegression.java.html)
|
|
* <b>Syntax</b>:<br/>
|
|
* <b>INTERCEPT</b>(<b>arrayX</b>, <b>arrayY</b>)<p/>
|
|
* or
|
|
* <b>SLOPE</b>(<b>arrayX</b>, <b>arrayY</b>)<p/>
|
|
*
|
|
*
|
|
* @author Johan Karlsteen
|
|
*/
|
|
public final class LinearRegressionFunction extends Fixed2ArgFunction {
|
|
|
|
private static abstract class ValueArray implements ValueVector {
|
|
private final int _size;
|
|
protected ValueArray(int size) {
|
|
_size = size;
|
|
}
|
|
|
|
public ValueEval getItem(int index) {
|
|
if (index < 0 || index > _size) {
|
|
throw new IllegalArgumentException("Specified index " + index
|
|
+ " is outside range (0.." + (_size - 1) + ")");
|
|
}
|
|
return getItemInternal(index);
|
|
}
|
|
protected abstract ValueEval getItemInternal(int index);
|
|
|
|
public final int getSize() {
|
|
return _size;
|
|
}
|
|
}
|
|
|
|
private static final class SingleCellValueArray extends ValueArray {
|
|
private final ValueEval _value;
|
|
public SingleCellValueArray(ValueEval value) {
|
|
super(1);
|
|
_value = value;
|
|
}
|
|
|
|
protected ValueEval getItemInternal(int index) {
|
|
return _value;
|
|
}
|
|
}
|
|
|
|
private static final class RefValueArray extends ValueArray {
|
|
private final RefEval _ref;
|
|
private final int _width;
|
|
public RefValueArray(RefEval ref) {
|
|
super(ref.getNumberOfSheets());
|
|
_ref = ref;
|
|
_width = ref.getNumberOfSheets();
|
|
}
|
|
|
|
protected ValueEval getItemInternal(int index) {
|
|
int sIx = (index % _width) + _ref.getFirstSheetIndex();
|
|
return _ref.getInnerValueEval(sIx);
|
|
}
|
|
}
|
|
|
|
private static final class AreaValueArray extends ValueArray {
|
|
private final TwoDEval _ae;
|
|
private final int _width;
|
|
|
|
public AreaValueArray(TwoDEval ae) {
|
|
super(ae.getWidth() * ae.getHeight());
|
|
_ae = ae;
|
|
_width = ae.getWidth();
|
|
}
|
|
|
|
protected ValueEval getItemInternal(int index) {
|
|
int rowIx = index / _width;
|
|
int colIx = index % _width;
|
|
return _ae.getValue(rowIx, colIx);
|
|
}
|
|
}
|
|
|
|
public enum FUNCTION {INTERCEPT, SLOPE};
|
|
public FUNCTION function;
|
|
|
|
public LinearRegressionFunction(FUNCTION function) {
|
|
this.function = function;
|
|
}
|
|
|
|
public ValueEval evaluate(int srcRowIndex, int srcColumnIndex,
|
|
ValueEval arg0, ValueEval arg1) {
|
|
double result;
|
|
try {
|
|
ValueVector vvY = createValueVector(arg0);
|
|
ValueVector vvX = createValueVector(arg1);
|
|
int size = vvX.getSize();
|
|
if (size == 0 || vvY.getSize() != size) {
|
|
return ErrorEval.NA;
|
|
}
|
|
result = evaluateInternal(vvX, vvY, size);
|
|
} catch (EvaluationException e) {
|
|
return e.getErrorEval();
|
|
}
|
|
if (Double.isNaN(result) || Double.isInfinite(result)) {
|
|
return ErrorEval.NUM_ERROR;
|
|
}
|
|
return new NumberEval(result);
|
|
}
|
|
|
|
private double evaluateInternal(ValueVector x, ValueVector y, int size)
|
|
throws EvaluationException {
|
|
|
|
// error handling is as if the x is fully evaluated before y
|
|
ErrorEval firstXerr = null;
|
|
ErrorEval firstYerr = null;
|
|
boolean accumlatedSome = false;
|
|
// first pass: read in data, compute xbar and ybar
|
|
double sumx = 0.0, sumy = 0.0;
|
|
|
|
for (int i = 0; i < size; i++) {
|
|
ValueEval vx = x.getItem(i);
|
|
ValueEval vy = y.getItem(i);
|
|
if (vx instanceof ErrorEval) {
|
|
if (firstXerr == null) {
|
|
firstXerr = (ErrorEval) vx;
|
|
continue;
|
|
}
|
|
}
|
|
if (vy instanceof ErrorEval) {
|
|
if (firstYerr == null) {
|
|
firstYerr = (ErrorEval) vy;
|
|
continue;
|
|
}
|
|
}
|
|
// only count pairs if both elements are numbers
|
|
if (vx instanceof NumberEval && vy instanceof NumberEval) {
|
|
accumlatedSome = true;
|
|
NumberEval nx = (NumberEval) vx;
|
|
NumberEval ny = (NumberEval) vy;
|
|
sumx += nx.getNumberValue();
|
|
sumy += ny.getNumberValue();
|
|
} else {
|
|
// all other combinations of value types are silently ignored
|
|
}
|
|
}
|
|
double xbar = sumx / size;
|
|
double ybar = sumy / size;
|
|
|
|
// second pass: compute summary statistics
|
|
double xxbar = 0.0, xybar = 0.0;
|
|
for (int i = 0; i < size; i++) {
|
|
ValueEval vx = x.getItem(i);
|
|
ValueEval vy = y.getItem(i);
|
|
|
|
if (vx instanceof ErrorEval) {
|
|
if (firstXerr == null) {
|
|
firstXerr = (ErrorEval) vx;
|
|
continue;
|
|
}
|
|
}
|
|
if (vy instanceof ErrorEval) {
|
|
if (firstYerr == null) {
|
|
firstYerr = (ErrorEval) vy;
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// only count pairs if both elements are numbers
|
|
if (vx instanceof NumberEval && vy instanceof NumberEval) {
|
|
NumberEval nx = (NumberEval) vx;
|
|
NumberEval ny = (NumberEval) vy;
|
|
xxbar += (nx.getNumberValue() - xbar) * (nx.getNumberValue() - xbar);
|
|
xybar += (nx.getNumberValue() - xbar) * (ny.getNumberValue() - ybar);
|
|
} else {
|
|
// all other combinations of value types are silently ignored
|
|
}
|
|
}
|
|
double beta1 = xybar / xxbar;
|
|
double beta0 = ybar - beta1 * xbar;
|
|
|
|
if (firstXerr != null) {
|
|
throw new EvaluationException(firstXerr);
|
|
}
|
|
if (firstYerr != null) {
|
|
throw new EvaluationException(firstYerr);
|
|
}
|
|
if (!accumlatedSome) {
|
|
throw new EvaluationException(ErrorEval.DIV_ZERO);
|
|
}
|
|
|
|
if(function == FUNCTION.INTERCEPT) {
|
|
return beta0;
|
|
} else {
|
|
return beta1;
|
|
}
|
|
}
|
|
|
|
private static ValueVector createValueVector(ValueEval arg) throws EvaluationException {
|
|
if (arg instanceof ErrorEval) {
|
|
throw new EvaluationException((ErrorEval) arg);
|
|
}
|
|
if (arg instanceof TwoDEval) {
|
|
return new AreaValueArray((TwoDEval) arg);
|
|
}
|
|
if (arg instanceof RefEval) {
|
|
return new RefValueArray((RefEval) arg);
|
|
}
|
|
return new SingleCellValueArray(arg);
|
|
}
|
|
}
|
|
|