gradient calculation basically working
This commit is contained in:
parent
7498dce57a
commit
11e8ff59e4
1 changed files with 6 additions and 8 deletions
14
decons.lisp
14
decons.lisp
|
@ -68,13 +68,13 @@
|
||||||
(shape:print-slots tr stream 'theta 'loss))
|
(shape:print-slots tr stream 'theta 'loss))
|
||||||
|
|
||||||
(defun gradient (target dataset)
|
(defun gradient (target dataset)
|
||||||
(let ((expect (funcall (l2-loss target) dataset)))
|
(let ((obj (funcall (l2-loss target) dataset)))
|
||||||
(lambda (theta)
|
(lambda (theta)
|
||||||
)))
|
(nabla-xp obj theta))))
|
||||||
|
|
||||||
(defun nabla-xp (fn args)
|
(defun nabla-xp (fn args)
|
||||||
"Determine gradients by experiment: vary args and record changes."
|
"Determine gradients by experiment: vary args and record changes."
|
||||||
(let ((base (apply fn args))
|
(let ((base (funcall fn args))
|
||||||
(vargs (apply #'vector args))
|
(vargs (apply #'vector args))
|
||||||
(res nil))
|
(res nil))
|
||||||
(dotimes (ix (length vargs))
|
(dotimes (ix (length vargs))
|
||||||
|
@ -85,12 +85,10 @@
|
||||||
(let* ((vx 0.01) argsx r+ r-)
|
(let* ((vx 0.01) argsx r+ r-)
|
||||||
(setf argsx (copy-seq vargs))
|
(setf argsx (copy-seq vargs))
|
||||||
(setf (svref argsx ix) (+ (svref vargs ix) vx))
|
(setf (svref argsx ix) (+ (svref vargs ix) vx))
|
||||||
(print argsx)
|
(setf r+ (/ (- (funcall fn (map 'list #'identity argsx)) base) vx))
|
||||||
(setf r+ (/ (- (apply fn (map 'list #'identity argsx)) base) vx))
|
|
||||||
(setf (svref argsx ix) (- (svref vargs ix) vx))
|
(setf (svref argsx ix) (- (svref vargs ix) vx))
|
||||||
(print argsx)
|
(setf r- (/ (- base (funcall fn (map 'list #'identity argsx))) vx))
|
||||||
(setf r- (/ (- base (apply fn (map 'list #'identity argsx))) vx))
|
;(util:lgi base r+ r-)
|
||||||
(format t "~a ~a ~a~&" base r+ r-)
|
|
||||||
(/ (+ r+ r-) 2)))
|
(/ (+ r+ r-) 2)))
|
||||||
|
|
||||||
;;;; parameterized target functions
|
;;;; parameterized target functions
|
||||||
|
|
Loading…
Add table
Reference in a new issue