Package deepnetts.net.loss
Class CrossEntropyLoss
java.lang.Object
deepnetts.net.loss.CrossEntropyLoss
- All Implemented Interfaces:
LossFunction, Serializable
Average cross-entropy loss function, commonly used for multi-class classification problems.
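$$E = -\frac{1}{n}\sum_{i=1}^{n}\sum_{j} t_{ij}\ln(y_{ij})$$
where $n$ is the number of patterns, $t_{ij}$ is the target value and $y_{ij}$ the predicted value of output $j$ for pattern $i$.
Because the targets are one-hot encoded (a 1-of-n classification scheme), every target value except the one at the target index is zero, so the loss reduces to
$$E = -\frac{1}{n}\sum_{i=1}^{n}\ln(y_{i,k_i})$$
where $k_i$ is the target index of pattern $i$.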
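As an illustration of the two formulas above, the sketch below computes the average loss once with the full double sum and once with the one-hot shortcut, so the results can be compared. The class and method names (CrossEntropySketch, averageCrossEntropy, averageCrossEntropyOneHot) are hypothetical and not part of deepnetts; the inputs are assumed to be probabilities (for example softmax outputs) paired with one-hot targets.

```java
// Hypothetical illustration of the average cross-entropy formulas; not deepnetts code.
final class CrossEntropySketch {

    // Full form: E = -1/n * SUM_i SUM_j t_ij * ln(y_ij)
    static double averageCrossEntropy(float[][] predicted, float[][] target) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            for (int j = 0; j < predicted[i].length; j++) {
                // Skip zero targets: they contribute nothing, and 0 * ln(0) would be NaN.
                if (target[i][j] != 0f) {
                    sum += target[i][j] * Math.log(predicted[i][j]);
                }
            }
        }
        return -sum / predicted.length;
    }

    // One-hot shortcut: E = -1/n * SUM_i ln(y_i[k_i]), k_i = target index of pattern i.
    static double averageCrossEntropyOneHot(float[][] predicted, int[] targetIdx) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            sum += Math.log(predicted[i][targetIdx[i]]);
        }
        return -sum / predicted.length;
    }

    public static void main(String[] args) {
        float[][] y = {{0.7f, 0.2f, 0.1f}, {0.1f, 0.8f, 0.1f}};
        float[][] t = {{1f, 0f, 0f}, {0f, 1f, 0f}};
        System.out.println(averageCrossEntropy(y, t));
        System.out.println(averageCrossEntropyOneHot(y, new int[]{0, 1}));
    }
}
```

Both methods print the same value for one-hot targets, because every term with a zero target drops out of the inner sum.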
Constructor Summary
Constructors
- CrossEntropyLoss
Method Summary
- float[] addPatternError(float[] predictedOut, float[] targetOut): Calculates and returns the output error vector for the specified predicted and target outputs.
- addPatternError(TensorBase predictedOut, TensorBase targetOut)
- void addRegularizationSum(float regSum): Adds the specified regularization sum to the total loss.
- float getPatternLoss()
- float getTotal(): Returns the total error calculated by this loss function.
- void reset(): Resets the total error and pattern counter.
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface deepnetts.net.loss.LossFunction
valueFor
-
Constructor Details
-
CrossEntropyLoss
-
-
Method Details
-
addPatternError
public float[] addPatternError(float[] predictedOut, float[] targetOut)
Calculates and returns the output error vector for the specified predicted and target outputs.
- Specified by:
addPatternError in interface LossFunction
- Parameters:
predictedOut - predicted output of the neural network
targetOut - target (desired) output of the neural network
- Returns:
- error vector for the specified predicted and target outputs
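A minimal sketch of what the per-pattern error vector and loss could look like, assuming softmax outputs with one-hot targets. Under that assumption the derivative of the cross entropy with respect to the softmax input $z_j$ simplifies to $y_j - t_j$, which is why predicted minus target is a common choice for the error vector; this page does not state which sign convention deepnetts actually uses. PatternErrorSketch and its methods are hypothetical names.

```java
// Hypothetical per-pattern sketch; the error convention (predicted - target) is an assumption.
final class PatternErrorSketch {

    // Element-wise predicted - target: the gradient of softmax + cross entropy w.r.t. the softmax input.
    static float[] outputError(float[] predictedOut, float[] targetOut) {
        float[] error = new float[predictedOut.length];
        for (int i = 0; i < predictedOut.length; i++) {
            error[i] = predictedOut[i] - targetOut[i];
        }
        return error;
    }

    // Per-pattern loss -ln(y[k]), where k is the index of the largest target (the 1 in a one-hot vector).
    static float patternLoss(float[] predictedOut, float[] targetOut) {
        int targetIdx = 0;
        for (int i = 1; i < targetOut.length; i++) {
            if (targetOut[i] > targetOut[targetIdx]) {
                targetIdx = i;
            }
        }
        return (float) (-Math.log(predictedOut[targetIdx]));
    }

    public static void main(String[] args) {
        float[] y = {0.7f, 0.2f, 0.1f};
        float[] t = {1f, 0f, 0f};
        System.out.println(java.util.Arrays.toString(outputError(y, t)));
        System.out.println(patternLoss(y, t));
    }
}
```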
-
addPatternError
- Specified by:
addPatternError in interface LossFunction
-
getPatternLoss
public float getPatternLoss()
- Specified by:
getPatternLoss in interface LossFunction
-
addRegularizationSum
public void addRegularizationSum(float regSum)
Description copied from interface: LossFunction
Adds the specified regularization sum to the total loss.
- Specified by:
addRegularizationSum in interface LossFunction
- Parameters:
regSum - regularization sum
-
getTotal
public float getTotal()
Description copied from interface: LossFunction
Returns the total error calculated by this loss function.
- Specified by:
getTotal in interface LossFunction
- Returns:
- total error calculated by this loss function
-
reset
public void reset()
Description copied from interface: LossFunction
Resets the total error and pattern counter.
- Specified by:
reset in interface LossFunction
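The sketch below models how the accumulating methods on this page (addPatternError, getPatternLoss, addRegularizationSum, getTotal, reset) might fit together over one pass through a data set. It is hypothetical: the class name, its fields, averaging the pattern losses in getTotal, adding the regularization sum after averaging, and clearing that sum in reset are all assumptions, not the deepnetts implementation.

```java
// Hypothetical accumulator mirroring this page's method names; the bookkeeping is assumed.
final class CrossEntropyAccumulatorSketch {
    private float totalError;   // sum of -ln(y_target) over patterns since the last reset()
    private int patternCount;   // number of patterns since the last reset()
    private float regSum;       // regularization term added via addRegularizationSum
    private float patternLoss;  // loss of the most recent pattern

    float[] addPatternError(float[] predictedOut, float[] targetOut) {
        float[] error = new float[predictedOut.length];
        int targetIdx = 0;
        for (int i = 0; i < predictedOut.length; i++) {
            error[i] = predictedOut[i] - targetOut[i];                // assumed error convention
            if (targetOut[i] > targetOut[targetIdx]) targetIdx = i;  // position of the one-hot 1
        }
        patternLoss = (float) (-Math.log(predictedOut[targetIdx]));
        totalError += patternLoss;
        patternCount++;
        return error;
    }

    float getPatternLoss() { return patternLoss; }

    void addRegularizationSum(float regSum) { this.regSum += regSum; }

    // Average pattern loss (the -1/n * SUM term) plus the regularization sum.
    float getTotal() {
        return patternCount == 0 ? regSum : totalError / patternCount + regSum;
    }

    void reset() {
        totalError = 0f;
        patternCount = 0;
        regSum = 0f;
        patternLoss = 0f;
    }

    public static void main(String[] args) {
        CrossEntropyAccumulatorSketch loss = new CrossEntropyAccumulatorSketch();
        loss.addPatternError(new float[]{0.7f, 0.2f, 0.1f}, new float[]{1f, 0f, 0f});
        loss.addPatternError(new float[]{0.1f, 0.8f, 0.1f}, new float[]{0f, 1f, 0f});
        System.out.println(loss.getTotal());
        loss.reset();
        System.out.println(loss.getTotal());
    }
}
```

Under these assumptions, a training loop would call reset() at the start of each epoch so that getTotal() reports the loss of that epoch only.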
-