gradient-descent working with line

This commit is contained in:
Helmut Merz 2025-05-28 09:36:22 +02:00
parent 4b20e4ef0c
commit 0768949897
2 changed files with 15 additions and 25 deletions

View file

@ -5,10 +5,11 @@
(:local-nicknames (:shape :scopes/shape) (:local-nicknames (:shape :scopes/shape)
(:util :scopes/util)) (:util :scopes/util))
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv (:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
#:default-deviation #:l2-loss #:trial #:gradient #:default-deviation #:l2-loss
#:nabla-xp #:diff *revisions* *alpha* #:gradient-descent
#:nabla-xp
#:line #:line
#:*obj* #:*trials* #:try)) #:*obj*))
(in-package :decons) (in-package :decons)
@ -60,22 +61,18 @@
(calculated (funcall objective theta))) (calculated (funcall objective theta)))
(funcall deviation (cadr dataset) calculated))))) (funcall deviation (cadr dataset) calculated)))))
(defclass trial ()
((theta :reader theta :initarg :theta)
(loss :reader loss :initarg :loss)))
(defmethod print-object ((tr trial) stream)
(shape:print-slots tr stream 'theta 'loss))
(defun gradient (target dataset)
(let ((obj (funcall (l2-loss target) dataset)))
(lambda (theta)
(nabla-xp obj theta))))
;;;; optimization by revision (= gradient descent) ;;;; optimization by revision (= gradient descent)
(defvar *revisions* 10) (defparameter *revisions* 1000)
(defvar *alpha* 0.01) (defparameter *alpha* 0.01)
(defun gradient-descent (obj theta)
(flet ((try-theta (th)
(mapcar (lambda (p g) (- p (* *alpha* g)))
th (nabla-xp obj th))))
(dotimes (ix *revisions*)
(setf theta (try-theta theta)))
theta))
;;;; experimental differentiation ;;;; experimental differentiation
@ -109,8 +106,3 @@
;;;; working area ;;;; working area
(defvar *obj* nil) (defvar *obj* nil)
(defvar *trials* nil)
(defun try (obj theta)
(push (make-instance 'trial :theta theta :loss (funcall obj theta)) *trials*))

View file

@ -68,8 +68,6 @@
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*)) (setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
(setf decons:*obj* objective) (setf decons:*obj* objective)
(== (funcall objective '(0.0 0.0)) 33.21) (== (funcall objective '(0.0 0.0)) 33.21)
(setf decons:*trials* nil)
(decons:try objective '(0.0 0.0))
(decons:try objective '(0.0099 0.0))
(== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001)) (== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001))
(== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6))
)) ))