gradient calculation basically working

This commit is contained in:
Helmut Merz 2025-05-27 15:33:12 +02:00
parent 7498dce57a
commit 11e8ff59e4

View file

@ -68,13 +68,13 @@
(shape:print-slots tr stream 'theta 'loss)) (shape:print-slots tr stream 'theta 'loss))
(defun gradient (target dataset) (defun gradient (target dataset)
(let ((expect (funcall (l2-loss target) dataset))) (let ((obj (funcall (l2-loss target) dataset)))
(lambda (theta) (lambda (theta)
))) (nabla-xp obj theta))))
(defun nabla-xp (fn args) (defun nabla-xp (fn args)
"Determine gradients by experiment: vary args and record changes." "Determine gradients by experiment: vary args and record changes."
(let ((base (apply fn args)) (let ((base (funcall fn args))
(vargs (apply #'vector args)) (vargs (apply #'vector args))
(res nil)) (res nil))
(dotimes (ix (length vargs)) (dotimes (ix (length vargs))
@ -85,12 +85,10 @@
(let* ((vx 0.01) argsx r+ r-) (let* ((vx 0.01) argsx r+ r-)
(setf argsx (copy-seq vargs)) (setf argsx (copy-seq vargs))
(setf (svref argsx ix) (+ (svref vargs ix) vx)) (setf (svref argsx ix) (+ (svref vargs ix) vx))
(print argsx) (setf r+ (/ (- (funcall fn (map 'list #'identity argsx)) base) vx))
(setf r+ (/ (- (apply fn (map 'list #'identity argsx)) base) vx))
(setf (svref argsx ix) (- (svref vargs ix) vx)) (setf (svref argsx ix) (- (svref vargs ix) vx))
(print argsx) (setf r- (/ (- base (funcall fn (map 'list #'identity argsx))) vx))
(setf r- (/ (- base (apply fn (map 'list #'identity argsx))) vx)) ;(util:lgi base r+ r-)
(format t "~a ~a ~a~&" base r+ r-)
(/ (+ r+ r-) 2))) (/ (+ r+ r-) 2)))
;;;; parameterized target functions ;;;; parameterized target functions