From 076894989703bbac47fec55eaf1dd319b33f1572 Mon Sep 17 00:00:00 2001 From: Helmut Merz Date: Wed, 28 May 2025 09:36:22 +0200 Subject: [PATCH] gradient-descent working with line --- decons.lisp | 36 ++++++++++++++---------------------- test-decons.lisp | 4 +--- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/decons.lisp b/decons.lisp index 0915117..e345a76 100644 --- a/decons.lisp +++ b/decons.lisp @@ -5,10 +5,11 @@ (:local-nicknames (:shape :scopes/shape) (:util :scopes/util)) (:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv - #:default-deviation #:l2-loss #:trial #:gradient - #:nabla-xp #:diff + #:default-deviation #:l2-loss + *revisions* *alpha* #:gradient-descent + #:nabla-xp #:line - #:*obj* #:*trials* #:try)) + #:*obj*)) (in-package :decons) @@ -60,22 +61,18 @@ (calculated (funcall objective theta))) (funcall deviation (cadr dataset) calculated))))) -(defclass trial () - ((theta :reader theta :initarg :theta) - (loss :reader loss :initarg :loss))) - -(defmethod print-object ((tr trial) stream) - (shape:print-slots tr stream 'theta 'loss)) - -(defun gradient (target dataset) - (let ((obj (funcall (l2-loss target) dataset))) - (lambda (theta) - (nabla-xp obj theta)))) - ;;;; optimization by revision (= gradient descent) -(defvar *revisions* 10) -(defvar *alpha* 0.01) +(defparameter *revisions* 1000) +(defparameter *alpha* 0.01) + +(defun gradient-descent (obj theta) + (flet ((try-theta (th) + (mapcar (lambda (p g) (- p (* *alpha* g))) + th (nabla-xp obj th)))) + (dotimes (ix *revisions*) + (setf theta (try-theta theta))) + theta)) ;;;; experimental differentiation @@ -109,8 +106,3 @@ ;;;; working area (defvar *obj* nil) - -(defvar *trials* nil) - -(defun try (obj theta) - (push (make-instance 'trial :theta theta :loss (funcall obj theta)) *trials*)) diff --git a/test-decons.lisp b/test-decons.lisp index ca52b5e..2e39246 100644 --- a/test-decons.lisp +++ b/test-decons.lisp @@ -68,8 +68,6 @@ (setf objective (funcall (decons:l2-loss #'decons:line) *ds1*)) (setf decons:*obj* objective) (== (funcall objective '(0.0 0.0)) 33.21) - (setf decons:*trials* nil) - (decons:try objective '(0.0 0.0)) - (decons:try objective '(0.0099 0.0)) (== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001)) + (== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6)) ))