gradient-descent working with line
This commit is contained in:
parent
4b20e4ef0c
commit
0768949897
2 changed files with 15 additions and 25 deletions
36
decons.lisp
36
decons.lisp
|
@ -5,10 +5,11 @@
|
|||
(:local-nicknames (:shape :scopes/shape)
|
||||
(:util :scopes/util))
|
||||
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
|
||||
#:default-deviation #:l2-loss #:trial #:gradient
|
||||
#:nabla-xp #:diff
|
||||
#:default-deviation #:l2-loss
|
||||
*revisions* *alpha* #:gradient-descent
|
||||
#:nabla-xp
|
||||
#:line
|
||||
#:*obj* #:*trials* #:try))
|
||||
#:*obj*))
|
||||
|
||||
(in-package :decons)
|
||||
|
||||
|
@ -60,22 +61,18 @@
|
|||
(calculated (funcall objective theta)))
|
||||
(funcall deviation (cadr dataset) calculated)))))
|
||||
|
||||
(defclass trial ()
|
||||
((theta :reader theta :initarg :theta)
|
||||
(loss :reader loss :initarg :loss)))
|
||||
|
||||
(defmethod print-object ((tr trial) stream)
|
||||
(shape:print-slots tr stream 'theta 'loss))
|
||||
|
||||
(defun gradient (target dataset)
|
||||
(let ((obj (funcall (l2-loss target) dataset)))
|
||||
(lambda (theta)
|
||||
(nabla-xp obj theta))))
|
||||
|
||||
;;;; optimization by revision (= gradient descent)
|
||||
|
||||
(defvar *revisions* 10)
|
||||
(defvar *alpha* 0.01)
|
||||
(defparameter *revisions* 1000)
|
||||
(defparameter *alpha* 0.01)
|
||||
|
||||
(defun gradient-descent (obj theta)
|
||||
(flet ((try-theta (th)
|
||||
(mapcar (lambda (p g) (- p (* *alpha* g)))
|
||||
th (nabla-xp obj th))))
|
||||
(dotimes (ix *revisions*)
|
||||
(setf theta (try-theta theta)))
|
||||
theta))
|
||||
|
||||
;;;; experimental differentiation
|
||||
|
||||
|
@ -109,8 +106,3 @@
|
|||
;;;; working area
|
||||
|
||||
(defvar *obj* nil)
|
||||
|
||||
(defvar *trials* nil)
|
||||
|
||||
(defun try (obj theta)
|
||||
(push (make-instance 'trial :theta theta :loss (funcall obj theta)) *trials*))
|
||||
|
|
|
@ -68,8 +68,6 @@
|
|||
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
|
||||
(setf decons:*obj* objective)
|
||||
(== (funcall objective '(0.0 0.0)) 33.21)
|
||||
(setf decons:*trials* nil)
|
||||
(decons:try objective '(0.0 0.0))
|
||||
(decons:try objective '(0.0099 0.0))
|
||||
(== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001))
|
||||
(== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6))
|
||||
))
|
||||
|
|
Loading…
Add table
Reference in a new issue