Class OptimizerBase
Defined in file optimizers.h
Inheritance Relationships
Base Types
- public marian::TrainingObserver (Class TrainingObserver)
- public marian::ExponentialSmoothing (Class ExponentialSmoothing)
Derived Types
- public marian::Adagrad (Class Adagrad)
- public marian::Adam (Class Adam)
- public marian::Sgd (Class Sgd)
Class Documentation
-
class OptimizerBase : public marian::TrainingObserver, public marian::ExponentialSmoothing
Base class for optimizers.
Subclassed by marian::Adagrad, marian::Adam, marian::Sgd
Public Types
-
typedef std::function<void(const io::Item&, const ScatterStateSetFunc&)> ScatterStateFunc
-
typedef std::function<io::Item(const GatherStateGetFunc&)> GatherStateFunc
Public Functions
-
virtual ~OptimizerBase()
-
float update(Ptr<ExpressionGraph> graph, size_t mbSize, float costScaleFactor = 1.f)
-
float update(Tensor params, Tensor grads, size_t mbSize, float costScaleFactor = 1.f)
-
virtual void init(TrainingState &state)
-
virtual void actAfterLoaded(TrainingState &state)
-
virtual void actAfterEpoch(TrainingState &state)
-
virtual void actAfterBatches(TrainingState &state)
-
virtual void actAfterStalled(TrainingState &state)
-
void load(std::vector<io::Item> &items, const std::vector<Ptr<OptimizerBase>> &opts, const std::vector<Ptr<Backend>> &backends, const ScatterStateFunc &scatterFn, bool isMainProcess)
-
void save(std::vector<io::Item> &items, const std::vector<Ptr<OptimizerBase>> &opts, const GatherStateFunc &gatherFn, bool isMainProcess)
-
void swapWithSmoothed(Tensor params)
Protected Functions
-
virtual void updateImpl(Tensor params, Tensor grads, size_t actualMBSize) = 0
-
virtual void resetStats() = 0
-
typedef std::function<void(const io::Item&, const ScatterStateSetFunc&)>