scalanlp/breeze

Type class for derivation of functions

Open

#207 opened on March 28, 2014

1 comment · 0 reactions · 0 assignees · Scala · 3,453 stars · 690 forks
Labels: enhancement, help wanted, project

Description

Hi all,

In one of our core algorithms we use differentiable functions that can be optimized efficiently with gradient methods. We have computed the derivatives using SageMath, but I'm wondering whether something similar could be achieved in Breeze directly. Here's a first attempt:

import breeze.linalg.operators.{OpAdd, OpDiv, OpMulScalar, OpSub}

// A function T => T that can also evaluate its own derivative.
trait DerivableFunction[T] extends (T => T) {
  def derivation(t: T): T
}

// Product rule: (f * g)' = f * g' + f' * g
case class Product[T](factor1: DerivableFunction[T], factor2: DerivableFunction[T])(implicit multImpl: OpMulScalar.Impl2[T, T, T], addImpl: OpAdd.Impl2[T, T, T]) extends DerivableFunction[T] {
  def apply(p: T) = multImpl(factor1(p), factor2(p))
  def derivation(p: T) = addImpl(multImpl(factor1(p), factor2.derivation(p)), multImpl(factor1.derivation(p), factor2(p)))
}

// Sum rule: (f + g)' = f' + g'
case class Sum[T](summand1: DerivableFunction[T], summand2: DerivableFunction[T])(implicit addImpl: OpAdd.Impl2[T, T, T]) extends DerivableFunction[T] {
  def apply(p: T) = addImpl(summand1(p), summand2(p))
  def derivation(p: T) = addImpl(summand1.derivation(p), summand2.derivation(p))
}

// Constant rule: c' = 0, obtained here as p - p (a zero with the shape of p)
case class Const[T](const: T)(implicit diffImpl: OpSub.Impl2[T, T, T]) extends DerivableFunction[T] {
  def apply(p: T) = const
  def derivation(p: T) = diffImpl(p, p)
}

// Identity rule: x' = 1, obtained here as p / p
case class Var[T]()(implicit divImpl: OpDiv.Impl2[T, T, T]) extends DerivableFunction[T] {
  def apply(p: T) = p
  def derivation(p: T) = divImpl(p, p)
}

// myFunc(x) = (1 + x) * 3
val myFunc = Product(Sum(Const(1.0), Var[Double]()), Const(3.0))
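Written out, the four classes encode the sum, product, constant and identity rules. Applying them by hand to `myFunc` shows what `myFunc.derivation(p)` should return. This assumes the neutral elements come out as exactly 0 and 1. With the `p / p` workaround, `p = 0` would instead give NaN, because the identity rule's 0 / 0 propagates through the product:

```latex
\begin{aligned}
(f + g)'(x) &= f'(x) + g'(x), \qquad (f\,g)'(x) = f(x)\,g'(x) + f'(x)\,g(x),\\
c' &= 0, \qquad x' = 1,\\
\text{myFunc}(x) &= (1 + x)\cdot 3,\\
\text{myFunc}'(x) &= (1 + x)\cdot 0 + (0 + 1)\cdot 3 = 3.
\end{aligned}
```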

As you can see, I have not yet figured out how to create the neutral elements in Const#derivation and Var#derivation properly. p - p and p / p are only workarounds, and p / p is NaN at p = 0. I'm guessing I have to summon a Semiring, but I have no idea how to do that for DenseVectors (the zero or one vector would need to know the correct size).
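One possible way around both problems is sketched below. It rests on the assumption that a small hypothetical type class is acceptable. A Semiring has a single zero and one, which cannot know a vector's size. Instead, the type class asks for neutral elements *shaped like the argument* `p`. `NeutralElements`, `zeroLike` and `oneLike` are made-up names. `Vector[Double]` stands in for `DenseVector[Double]` so the sketch runs without Breeze. A Breeze instance could presumably be built from `DenseVector.zeros` and `DenseVector.ones` with `p.length`. The block is self-contained and redefines `DerivableFunction`, `Const` and `Var`, so it should not be compiled together with the code above.

```scala
// Hypothetical type class: neutral elements with the same shape as a given argument.
trait NeutralElements[T] {
  def zeroLike(p: T): T // additive identity shaped like p
  def oneLike(p: T): T  // multiplicative identity shaped like p
}

object NeutralElements {
  implicit val doubleNeutral: NeutralElements[Double] = new NeutralElements[Double] {
    def zeroLike(p: Double): Double = 0.0
    def oneLike(p: Double): Double = 1.0
  }

  // Stand-in for DenseVector[Double]: the size is read off the argument at call time.
  implicit val vectorNeutral: NeutralElements[Vector[Double]] = new NeutralElements[Vector[Double]] {
    def zeroLike(p: Vector[Double]): Vector[Double] = Vector.fill(p.length)(0.0)
    def oneLike(p: Vector[Double]): Vector[Double] = Vector.fill(p.length)(1.0)
  }
}

trait DerivableFunction[T] extends (T => T) {
  def derivation(t: T): T
}

// d/dp const = 0, as a zero of the same shape as p
case class Const[T](const: T)(implicit n: NeutralElements[T]) extends DerivableFunction[T] {
  def apply(p: T): T = const
  def derivation(p: T): T = n.zeroLike(p)
}

// d/dp p = 1, as a one of the same shape as p (no division, so p = 0 is fine)
case class Var[T]()(implicit n: NeutralElements[T]) extends DerivableFunction[T] {
  def apply(p: T): T = p
  def derivation(p: T): T = n.oneLike(p)
}
```

Because the size comes from `p`, `Var[Vector[Double]]().derivation(Vector(0.0, 2.0))` yields `Vector(1.0, 1.0)`, with no NaN at zero.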

Moreover, this approach requires building and using a dedicated hierarchy of DerivableFunction classes. I think type classes would be much preferable, as they would allow attaching derivatives to existing UFuncs. If someone can give me a hint on how to start, I could try to create a PR with tests and a more complete implementation :)
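The sketch below shows one way the type-class direction could work. It assumes derivatives are looked up per function object instead of being built into a class hierarchy. `Square`, `Exp`, `Derivative`, `at` and `Optimize.gradientDescent` are all hypothetical names. In Breeze the function objects would presumably be UFunc singletons such as `breeze.numerics.exp`, with instances keyed on their singleton types.

```scala
// Hypothetical stand-ins for existing function objects (e.g. UFuncs);
// they know nothing about derivatives.
object Square { def apply(x: Double): Double = x * x }
object Exp { def apply(x: Double): Double = math.exp(x) }

// Hypothetical type class: the derivative of function object F on domain T.
trait Derivative[F, T] {
  def at(f: F, x: T): T
}

object Derivative {
  // Instances are attached from the outside, without modifying Square or Exp.
  implicit val squareDerivative: Derivative[Square.type, Double] =
    new Derivative[Square.type, Double] {
      def at(f: Square.type, x: Double): Double = 2.0 * x
    }

  implicit val expDerivative: Derivative[Exp.type, Double] =
    new Derivative[Exp.type, Double] {
      def at(f: Exp.type, x: Double): Double = math.exp(x)
    }
}

object Optimize {
  // Plain fixed-step gradient descent for any function object with a Derivative instance.
  def gradientDescent[F](f: F, x0: Double, stepSize: Double, iterations: Int)(
      implicit d: Derivative[F, Double]): Double = {
    var x = x0
    for (_ <- 1 to iterations) x -= stepSize * d.at(f, x)
    x
  }
}
```

Here `Optimize.gradientDescent(Square, 1.0, 0.1, 100)` multiplies x by 0.8 on each step and ends very close to 0. Passing a function object that has no `Derivative` instance fails at compile time rather than at run time.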

Best,
Martin
