Class OptimizerBase
Defined in File optimizers.h
Inheritance Relationships
Base Types
- public marian::TrainingObserver (Class TrainingObserver)
- public marian::ExponentialSmoothing (Class ExponentialSmoothing)
Derived Types
- public marian::Adagrad (Class Adagrad)
- public marian::Adam (Class Adam)
- public marian::Sgd (Class Sgd)
Class Documentation
-
class OptimizerBase : public marian::TrainingObserver, public marian::ExponentialSmoothing
Base class for optimizers.
Subclassed by marian::Adagrad, marian::Adam, marian::Sgd
Public Types
-
typedef std::function<void(const io::Item&, const ScatterStateSetFunc&)> ScatterStateFunc
-
typedef std::function<io::Item(const GatherStateGetFunc&)> GatherStateFunc
Public Functions
-
virtual ~OptimizerBase()
-
float update(Ptr<ExpressionGraph> graph, size_t mbSize, float costScaleFactor = 1.f)
-
float update(Tensor params, Tensor grads, size_t mbSize, float costScaleFactor = 1.f)
-
virtual void init(TrainingState &state)
-
virtual void actAfterLoaded(TrainingState &state)
-
virtual void actAfterEpoch(TrainingState &state)
-
virtual void actAfterBatches(TrainingState &state)
-
virtual void actAfterStalled(TrainingState &state)
-
void load(std::vector<io::Item> &items, const std::vector<Ptr<OptimizerBase>> &opts, const std::vector<Ptr<Backend>> &backends, const ScatterStateFunc &scatterFn, bool isMainProcess)
-
void save(std::vector<io::Item> &items, const std::vector<Ptr<OptimizerBase>> &opts, const GatherStateFunc &gatherFn, bool isMainProcess)
-
void swapWithSmoothed(Tensor params)
Protected Functions
-
virtual void updateImpl(Tensor params, Tensor grads, size_t actualMBSize) = 0
-
virtual void resetStats() = 0
-
typedef std::function<void(const io::Item&, const ScatterStateSetFunc&)>