start experimenting with quadratic functions
This commit is contained in:
parent
0768949897
commit
0617177d72
2 changed files with 42 additions and 11 deletions
36
decons.lisp
36
decons.lisp
|
@ -4,20 +4,27 @@
|
|||
(:use :common-lisp)
|
||||
(:local-nicknames (:shape :scopes/shape)
|
||||
(:util :scopes/util))
|
||||
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
|
||||
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv #:rsqr
|
||||
#:default-deviation #:l2-loss
|
||||
*revisions* *alpha* #:gradient-descent
|
||||
#:nabla-xp
|
||||
#:line
|
||||
#:line #:quad
|
||||
#:*obj*))
|
||||
|
||||
(in-package :decons)
|
||||
|
||||
;;;; common (basic) stuff
|
||||
|
||||
(defun sqr (x) (* x x))
|
||||
|
||||
(defun sum (data) (reduce #'+ data))
|
||||
|
||||
;;;; rapply, rreduce - recursive application of operations
|
||||
|
||||
(defun rapply (op arg1 &optional arg2)
|
||||
(if arg2
|
||||
(rcall (rcurry op arg1) arg2)
|
||||
;(rcall (rcurry op arg1) arg2)
|
||||
(rcall2 op arg1 arg2)
|
||||
(rcall op arg1)))
|
||||
|
||||
(defgeneric rcall (op arg)
|
||||
|
@ -25,6 +32,15 @@
|
|||
(:method (op (arg list))
|
||||
(mapcar (lambda (i) (rcall op i)) arg)))
|
||||
|
||||
(defgeneric rcall2 (op a1 a2)
|
||||
(:method (op a1 a2) (funcall op a1 a2))
|
||||
(:method (op (a1 list) a2)
|
||||
(mapcar (lambda (i) (rapply op i a2)) a1))
|
||||
(:method (op a1 (a2 list))
|
||||
(mapcar (lambda (j) (rapply op a1 j)) a2))
|
||||
(:method (op (a1 list) (a2 list))
|
||||
(mapcar (lambda (i j) (rapply op i j)) a1 a2)))
|
||||
|
||||
(defgeneric rcurry (op arg)
|
||||
(:method (op arg) (lambda (j) (funcall op arg j)))
|
||||
(:method (op (arg list))
|
||||
|
@ -43,13 +59,10 @@
|
|||
(defun rmul (a b) (rapply #'* a b))
|
||||
(defun rsub (a b) (rapply #'- a b))
|
||||
(defun rdiv (a b) (rapply #'/ a b))
|
||||
(defun rsqr (a) (rapply #'sqr a))
|
||||
|
||||
;;;; loss calculation, collect trial data (parameters, resulting loss)
|
||||
|
||||
(defun sqr (x) (* x x))
|
||||
|
||||
(defun sum (data) (reduce #'+ data))
|
||||
|
||||
(defun default-deviation (observed calculated &key (norm #'sqr))
|
||||
(sum (mapcar (lambda (a b) (funcall norm (- a b)))
|
||||
observed calculated)))
|
||||
|
@ -101,7 +114,14 @@
|
|||
;;;; parameterized target functions
|
||||
|
||||
(defun line (x)
|
||||
#'(lambda (theta) (radd (cadr theta) (rmul (car theta) x))))
|
||||
(lambda (theta)
|
||||
(radd (rmul (first theta) x) (second theta))))
|
||||
|
||||
(defun quad (x)
|
||||
(lambda (theta)
|
||||
(radd (rmul (first theta) (rsqr x))
|
||||
(radd (rmul (second theta) x)
|
||||
(third theta)))))
|
||||
|
||||
;;;; working area
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
(test-rapply)
|
||||
(test-rreduce)
|
||||
(test-line)
|
||||
(test-quad)
|
||||
(t:show-result)))
|
||||
|
||||
(deftest test-basic ()
|
||||
|
@ -43,8 +44,9 @@
|
|||
(== (decons:rapply #'1+ '(2 3)) '(3 4))
|
||||
(== (decons:radd 2 3) 5)
|
||||
(== (decons:radd 3 '(4 5)) '(7 8))
|
||||
(== (decons:radd '(2 3) '(4 5)) '((6 7) (7 8))) ; not '(6 8)
|
||||
(== (decons:rsub '(6 7) '(4 5)) '((2 3) (1 2)))
|
||||
(== (decons:radd '(2 3) '(4 5)) '(6 8)) ; not '((6 7) (7 8))
|
||||
(== (decons:rsub '(6 7) '(4 5)) '(2 2)) ; not '((2 3) (1 2))
|
||||
(== (decons:rsqr '(2 3 4)) '(4 9 16))
|
||||
)
|
||||
|
||||
(deftest test-rreduce ()
|
||||
|
@ -66,8 +68,17 @@
|
|||
(== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0)))
|
||||
0.20999993) ;0.899999861)
|
||||
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
|
||||
(setf decons:*obj* objective)
|
||||
(setf decons:*obj* objective) ; for interactive experiments
|
||||
(== (funcall objective '(0.0 0.0)) 33.21)
|
||||
(== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001))
|
||||
(== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6))
|
||||
))
|
||||
|
||||
(defvar *ds2* '((-1.0 0.0 1.0 2.0 3.0)
|
||||
(2.55 2.1 4.35 10.2 18.25)))
|
||||
|
||||
(deftest test-quad ()
|
||||
(let (ps2 objective)
|
||||
(setf ps2 (decons:quad (car *ds2*)))
|
||||
(== (funcall ps2 '(1.0 0.0 0.0)) '(1.0 0.0 1.0 4.0 9.0))
|
||||
))
|
||||
|
|
Loading…
Add table
Reference in a new issue