work in progess: experimental nabla

This commit is contained in:
Helmut Merz 2025-05-27 14:20:34 +02:00
parent 8025938b81
commit 7498dce57a

View file

@ -6,6 +6,7 @@
(:util :scopes/util))
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
#:default-deviation #:l2-loss #:trial #:gradient
#:nabla-xp #:diff
#:line
#:*obj* #:*trials* #:try))
@ -69,12 +70,28 @@
(defun gradient (target dataset)
(let ((expect (funcall (l2-loss target) dataset)))
(lambda (theta)
(let* ((loss0 (funcall expect theta))
(loss1 (funcall expect (vary theta))))
(- loss0 loss1)))))
)))
(defun vary (theta)
(mapcar (lambda (x) (- x 0.01)) theta))
(defun nabla-xp (fn args)
"Determine gradients by experiment: vary args and record changes."
(let ((base (apply fn args))
(vargs (apply #'vector args))
(res nil))
(dotimes (ix (length vargs))
(push (diff fn vargs ix base) res))
(reverse res)))
(defun diff (fn vargs ix base)
(let* ((vx 0.01) argsx r+ r-)
(setf argsx (copy-seq vargs))
(setf (svref argsx ix) (+ (svref vargs ix) vx))
(print argsx)
(setf r+ (/ (- (apply fn (map 'list #'identity argsx)) base) vx))
(setf (svref argsx ix) (- (svref vargs ix) vx))
(print argsx)
(setf r- (/ (- base (apply fn (map 'list #'identity argsx))) vx))
(format t "~a ~a ~a~&" base r+ r-)
(/ (+ r+ r-) 2)))
;;;; parameterized target functions