now with function l2-loss
This commit is contained in:
parent
0172a2fd46
commit
25197a0520
2 changed files with 18 additions and 7 deletions
19
decons.lisp
19
decons.lisp
|
@ -7,7 +7,7 @@
|
||||||
#:absv #:double #:remainder
|
#:absv #:double #:remainder
|
||||||
#:scalar-p #:tensor #:at
|
#:scalar-p #:tensor #:at
|
||||||
#:rapply #:rreduce
|
#:rapply #:rreduce
|
||||||
#:combine #:cost
|
#:combine #:cost #:l2-loss
|
||||||
#:line
|
#:line
|
||||||
#:lgx
|
#:lgx
|
||||||
))
|
))
|
||||||
|
@ -97,12 +97,21 @@
|
||||||
(:method (a b) (list a b))
|
(:method (a b) (list a b))
|
||||||
(:method (a (b list)) (cons a b)))
|
(:method (a (b list)) (cons a b)))
|
||||||
|
|
||||||
;;;; cost calculation
|
;;;; cost / loss calculation
|
||||||
|
|
||||||
(defun cost (have want)
|
(defun cost (measured calculated)
|
||||||
(reduce #'+
|
(reduce #'+
|
||||||
(mapcar (lambda (p) (abs (apply #'- p)))
|
(mapcar (lambda (p) (sqr (apply #'- p)))
|
||||||
(combine have want))))
|
(combine measured calculated))))
|
||||||
|
|
||||||
|
(defun l2-loss (target &key (norm #'sqr))
|
||||||
|
(lambda (dataset)
|
||||||
|
(lambda (theta)
|
||||||
|
(let ((calculated (funcall (funcall target (car dataset)) theta)))
|
||||||
|
(reduce #'+ (mapcar (lambda (p) (funcall norm (apply #'- p)))
|
||||||
|
(combine (cadr dataset) calculated)))))))
|
||||||
|
|
||||||
|
(defun sqr (x) (* x x))
|
||||||
|
|
||||||
;;;; logging (printing) functions for debugging purposes
|
;;;; logging (printing) functions for debugging purposes
|
||||||
|
|
||||||
|
|
|
@ -55,11 +55,13 @@
|
||||||
(p2 (decons:line 1.0))
|
(p2 (decons:line 1.0))
|
||||||
(ds1 '((2.0 1.0 4.0 3.0)
|
(ds1 '((2.0 1.0 4.0 3.0)
|
||||||
(1.8 1.2 4.2 3.3)))
|
(1.8 1.2 4.2 3.3)))
|
||||||
ps1)
|
ps1 objective)
|
||||||
(== (funcall p1 '(0.5 2.0)) 2.0)
|
(== (funcall p1 '(0.5 2.0)) 2.0)
|
||||||
(== (funcall p2 '(0.5 2.0)) 2.5)
|
(== (funcall p2 '(0.5 2.0)) 2.5)
|
||||||
(setf ps1 (decons:line (car ds1)))
|
(setf ps1 (decons:line (car ds1)))
|
||||||
(== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5))
|
(== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5))
|
||||||
(== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0))
|
(== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0))
|
||||||
(== (decons:cost (funcall ps1 '(1.0 0.0)) (cadr ds1)) 0.899999861)
|
(== (decons:cost (cadr ds1) (funcall ps1 '(1.0 0.0))) 0.20999993) ;0.899999861)
|
||||||
|
(setf objective (funcall (decons:l2-loss #'decons:line) ds1))
|
||||||
|
(== (funcall objective '(0.0 0.0)) 33.21)
|
||||||
))
|
))
|
||||||
|
|
Loading…
Add table
Reference in a new issue