start experimenting with quadratic functions
This commit is contained in:
parent
0768949897
commit
0617177d72
2 changed files with 42 additions and 11 deletions
36
decons.lisp
36
decons.lisp
|
@ -4,20 +4,27 @@
|
||||||
(:use :common-lisp)
|
(:use :common-lisp)
|
||||||
(:local-nicknames (:shape :scopes/shape)
|
(:local-nicknames (:shape :scopes/shape)
|
||||||
(:util :scopes/util))
|
(:util :scopes/util))
|
||||||
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
|
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv #:rsqr
|
||||||
#:default-deviation #:l2-loss
|
#:default-deviation #:l2-loss
|
||||||
*revisions* *alpha* #:gradient-descent
|
*revisions* *alpha* #:gradient-descent
|
||||||
#:nabla-xp
|
#:nabla-xp
|
||||||
#:line
|
#:line #:quad
|
||||||
#:*obj*))
|
#:*obj*))
|
||||||
|
|
||||||
(in-package :decons)
|
(in-package :decons)
|
||||||
|
|
||||||
|
;;;; common (basic) stuff
|
||||||
|
|
||||||
|
(defun sqr (x) (* x x))
|
||||||
|
|
||||||
|
(defun sum (data) (reduce #'+ data))
|
||||||
|
|
||||||
;;;; rapply, rreduce - recursive application of operations
|
;;;; rapply, rreduce - recursive application of operations
|
||||||
|
|
||||||
(defun rapply (op arg1 &optional arg2)
|
(defun rapply (op arg1 &optional arg2)
|
||||||
(if arg2
|
(if arg2
|
||||||
(rcall (rcurry op arg1) arg2)
|
;(rcall (rcurry op arg1) arg2)
|
||||||
|
(rcall2 op arg1 arg2)
|
||||||
(rcall op arg1)))
|
(rcall op arg1)))
|
||||||
|
|
||||||
(defgeneric rcall (op arg)
|
(defgeneric rcall (op arg)
|
||||||
|
@ -25,6 +32,15 @@
|
||||||
(:method (op (arg list))
|
(:method (op (arg list))
|
||||||
(mapcar (lambda (i) (rcall op i)) arg)))
|
(mapcar (lambda (i) (rcall op i)) arg)))
|
||||||
|
|
||||||
|
(defgeneric rcall2 (op a1 a2)
|
||||||
|
(:method (op a1 a2) (funcall op a1 a2))
|
||||||
|
(:method (op (a1 list) a2)
|
||||||
|
(mapcar (lambda (i) (rapply op i a2)) a1))
|
||||||
|
(:method (op a1 (a2 list))
|
||||||
|
(mapcar (lambda (j) (rapply op a1 j)) a2))
|
||||||
|
(:method (op (a1 list) (a2 list))
|
||||||
|
(mapcar (lambda (i j) (rapply op i j)) a1 a2)))
|
||||||
|
|
||||||
(defgeneric rcurry (op arg)
|
(defgeneric rcurry (op arg)
|
||||||
(:method (op arg) (lambda (j) (funcall op arg j)))
|
(:method (op arg) (lambda (j) (funcall op arg j)))
|
||||||
(:method (op (arg list))
|
(:method (op (arg list))
|
||||||
|
@ -43,13 +59,10 @@
|
||||||
(defun rmul (a b) (rapply #'* a b))
|
(defun rmul (a b) (rapply #'* a b))
|
||||||
(defun rsub (a b) (rapply #'- a b))
|
(defun rsub (a b) (rapply #'- a b))
|
||||||
(defun rdiv (a b) (rapply #'/ a b))
|
(defun rdiv (a b) (rapply #'/ a b))
|
||||||
|
(defun rsqr (a) (rapply #'sqr a))
|
||||||
|
|
||||||
;;;; loss calculation, collect trial data (parameters, resulting loss)
|
;;;; loss calculation, collect trial data (parameters, resulting loss)
|
||||||
|
|
||||||
(defun sqr (x) (* x x))
|
|
||||||
|
|
||||||
(defun sum (data) (reduce #'+ data))
|
|
||||||
|
|
||||||
(defun default-deviation (observed calculated &key (norm #'sqr))
|
(defun default-deviation (observed calculated &key (norm #'sqr))
|
||||||
(sum (mapcar (lambda (a b) (funcall norm (- a b)))
|
(sum (mapcar (lambda (a b) (funcall norm (- a b)))
|
||||||
observed calculated)))
|
observed calculated)))
|
||||||
|
@ -101,7 +114,14 @@
|
||||||
;;;; parameterized target functions
|
;;;; parameterized target functions
|
||||||
|
|
||||||
(defun line (x)
|
(defun line (x)
|
||||||
#'(lambda (theta) (radd (cadr theta) (rmul (car theta) x))))
|
(lambda (theta)
|
||||||
|
(radd (rmul (first theta) x) (second theta))))
|
||||||
|
|
||||||
|
(defun quad (x)
|
||||||
|
(lambda (theta)
|
||||||
|
(radd (rmul (first theta) (rsqr x))
|
||||||
|
(radd (rmul (second theta) x)
|
||||||
|
(third theta)))))
|
||||||
|
|
||||||
;;;; working area
|
;;;; working area
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
(test-rapply)
|
(test-rapply)
|
||||||
(test-rreduce)
|
(test-rreduce)
|
||||||
(test-line)
|
(test-line)
|
||||||
|
(test-quad)
|
||||||
(t:show-result)))
|
(t:show-result)))
|
||||||
|
|
||||||
(deftest test-basic ()
|
(deftest test-basic ()
|
||||||
|
@ -43,8 +44,9 @@
|
||||||
(== (decons:rapply #'1+ '(2 3)) '(3 4))
|
(== (decons:rapply #'1+ '(2 3)) '(3 4))
|
||||||
(== (decons:radd 2 3) 5)
|
(== (decons:radd 2 3) 5)
|
||||||
(== (decons:radd 3 '(4 5)) '(7 8))
|
(== (decons:radd 3 '(4 5)) '(7 8))
|
||||||
(== (decons:radd '(2 3) '(4 5)) '((6 7) (7 8))) ; not '(6 8)
|
(== (decons:radd '(2 3) '(4 5)) '(6 8)) ; not '((6 7) (7 8))
|
||||||
(== (decons:rsub '(6 7) '(4 5)) '((2 3) (1 2)))
|
(== (decons:rsub '(6 7) '(4 5)) '(2 2)) ; not '((2 3) (1 2))
|
||||||
|
(== (decons:rsqr '(2 3 4)) '(4 9 16))
|
||||||
)
|
)
|
||||||
|
|
||||||
(deftest test-rreduce ()
|
(deftest test-rreduce ()
|
||||||
|
@ -66,8 +68,17 @@
|
||||||
(== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0)))
|
(== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0)))
|
||||||
0.20999993) ;0.899999861)
|
0.20999993) ;0.899999861)
|
||||||
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
|
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
|
||||||
(setf decons:*obj* objective)
|
(setf decons:*obj* objective) ; for interactive experiments
|
||||||
(== (funcall objective '(0.0 0.0)) 33.21)
|
(== (funcall objective '(0.0 0.0)) 33.21)
|
||||||
(== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001))
|
(== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001))
|
||||||
(== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6))
|
(== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6))
|
||||||
))
|
))
|
||||||
|
|
||||||
|
(defvar *ds2* '((-1.0 0.0 1.0 2.0 3.0)
|
||||||
|
(2.55 2.1 4.35 10.2 18.25)))
|
||||||
|
|
||||||
|
(deftest test-quad ()
|
||||||
|
(let (ps2 objective)
|
||||||
|
(setf ps2 (decons:quad (car *ds2*)))
|
||||||
|
(== (funcall ps2 '(1.0 0.0 0.0)) '(1.0 0.0 1.0 4.0 9.0))
|
||||||
|
))
|
||||||
|
|
Loading…
Add table
Reference in a new issue