provide a trial class (together with a list of trials) to collect loss values for varied parameter sets
This commit is contained in:
parent
52d1381aba
commit
cff6311103
2 changed files with 27 additions and 9 deletions
20
decons.lisp
20
decons.lisp
|
@ -2,12 +2,14 @@
|
||||||
|
|
||||||
(defpackage :decons
|
(defpackage :decons
|
||||||
(:use :common-lisp)
|
(:use :common-lisp)
|
||||||
(:local-nicknames (:util :scopes/util))
|
(:local-nicknames (:shape :scopes/shape)
|
||||||
|
(:util :scopes/util))
|
||||||
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
|
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
|
||||||
#:combine #:default-deviation #:l2-loss
|
#:combine #:default-deviation #:l2-loss
|
||||||
|
#:*trials* #:trial #:try
|
||||||
#:line
|
#:line
|
||||||
#:lgx
|
#:lgx
|
||||||
#:obj
|
#:*obj*
|
||||||
))
|
))
|
||||||
|
|
||||||
(in-package :decons)
|
(in-package :decons)
|
||||||
|
@ -69,6 +71,18 @@
|
||||||
(calculated (funcall objective theta)))
|
(calculated (funcall objective theta)))
|
||||||
(funcall deviation (cadr dataset) calculated)))))
|
(funcall deviation (cadr dataset) calculated)))))
|
||||||
|
|
||||||
|
(defvar *trials* nil)
|
||||||
|
|
||||||
|
(defclass trial ()
|
||||||
|
((theta :reader theta :initarg :theta)
|
||||||
|
(loss :reader loss :initarg :loss)))
|
||||||
|
|
||||||
|
(defmethod print-object ((tr trial) stream)
|
||||||
|
(shape:print-slots tr stream 'theta 'loss))
|
||||||
|
|
||||||
|
(defun try (obj theta)
|
||||||
|
(push (make-instance 'trial :theta theta :loss (funcall obj theta)) *trials*))
|
||||||
|
|
||||||
;;;; logging (printing) functions for debugging purposes
|
;;;; logging (printing) functions for debugging purposes
|
||||||
|
|
||||||
(defun lgx (op)
|
(defun lgx (op)
|
||||||
|
@ -84,4 +98,4 @@
|
||||||
|
|
||||||
;;;; working area
|
;;;; working area
|
||||||
|
|
||||||
(defparameter obj nil)
|
(defvar *obj* nil)
|
||||||
|
|
|
@ -51,20 +51,24 @@
|
||||||
(== (decons:rreduce #'+ '(1 2 (3 4))) 10)
|
(== (decons:rreduce #'+ '(1 2 (3 4))) 10)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
(defvar *ds1* '((2.0 1.0 4.0 3.0)
|
||||||
|
(1.8 1.2 4.2 3.3)))
|
||||||
|
|
||||||
(deftest test-line ()
|
(deftest test-line ()
|
||||||
(let ((p1 (decons:line 0.0))
|
(let ((p1 (decons:line 0.0))
|
||||||
(p2 (decons:line 1.0))
|
(p2 (decons:line 1.0))
|
||||||
(ds1 '((2.0 1.0 4.0 3.0)
|
|
||||||
(1.8 1.2 4.2 3.3)))
|
|
||||||
ps1 objective)
|
ps1 objective)
|
||||||
(== (funcall p1 '(0.5 2.0)) 2.0)
|
(== (funcall p1 '(0.5 2.0)) 2.0)
|
||||||
(== (funcall p2 '(0.5 2.0)) 2.5)
|
(== (funcall p2 '(0.5 2.0)) 2.5)
|
||||||
(setf ps1 (decons:line (car ds1)))
|
(setf ps1 (decons:line (car *ds1*)))
|
||||||
(== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5))
|
(== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5))
|
||||||
(== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0))
|
(== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0))
|
||||||
(== (decons:default-deviation (cadr ds1) (funcall ps1 '(1.0 0.0)))
|
(== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0)))
|
||||||
0.20999993) ;0.899999861)
|
0.20999993) ;0.899999861)
|
||||||
(setf objective (funcall (decons:l2-loss #'decons:line) ds1))
|
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
|
||||||
(setf decons:obj objective)
|
(setf decons:*obj* objective)
|
||||||
(== (funcall objective '(0.0 0.0)) 33.21)
|
(== (funcall objective '(0.0 0.0)) 33.21)
|
||||||
|
(setf decons:*trials* nil)
|
||||||
|
(decons:try objective '(0.0 0.0))
|
||||||
|
(decons:try objective '(0.0099 0.0))
|
||||||
))
|
))
|
||||||
|
|
Loading…
Add table
Reference in a new issue