start experimenting with quadratic functions

This commit is contained in:
Helmut Merz 2025-05-28 11:34:12 +02:00
parent 0768949897
commit 0617177d72
2 changed files with 42 additions and 11 deletions

View file

@ -4,20 +4,27 @@
(:use :common-lisp) (:use :common-lisp)
(:local-nicknames (:shape :scopes/shape) (:local-nicknames (:shape :scopes/shape)
(:util :scopes/util)) (:util :scopes/util))
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv (:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv #:rsqr
#:default-deviation #:l2-loss #:default-deviation #:l2-loss
*revisions* *alpha* #:gradient-descent *revisions* *alpha* #:gradient-descent
#:nabla-xp #:nabla-xp
#:line #:line #:quad
#:*obj*)) #:*obj*))
(in-package :decons) (in-package :decons)
;;;; common (basic) stuff
(defun sqr (x) (* x x))
(defun sum (data) (reduce #'+ data))
;;;; rapply, rreduce - recursive application of operations ;;;; rapply, rreduce - recursive application of operations
(defun rapply (op arg1 &optional arg2) (defun rapply (op arg1 &optional arg2)
(if arg2 (if arg2
(rcall (rcurry op arg1) arg2) ;(rcall (rcurry op arg1) arg2)
(rcall2 op arg1 arg2)
(rcall op arg1))) (rcall op arg1)))
(defgeneric rcall (op arg) (defgeneric rcall (op arg)
@ -25,6 +32,15 @@
(:method (op (arg list)) (:method (op (arg list))
(mapcar (lambda (i) (rcall op i)) arg))) (mapcar (lambda (i) (rcall op i)) arg)))
(defgeneric rcall2 (op a1 a2)
(:method (op a1 a2) (funcall op a1 a2))
(:method (op (a1 list) a2)
(mapcar (lambda (i) (rapply op i a2)) a1))
(:method (op a1 (a2 list))
(mapcar (lambda (j) (rapply op a1 j)) a2))
(:method (op (a1 list) (a2 list))
(mapcar (lambda (i j) (rapply op i j)) a1 a2)))
(defgeneric rcurry (op arg) (defgeneric rcurry (op arg)
(:method (op arg) (lambda (j) (funcall op arg j))) (:method (op arg) (lambda (j) (funcall op arg j)))
(:method (op (arg list)) (:method (op (arg list))
@ -43,13 +59,10 @@
(defun rmul (a b) (rapply #'* a b)) (defun rmul (a b) (rapply #'* a b))
(defun rsub (a b) (rapply #'- a b)) (defun rsub (a b) (rapply #'- a b))
(defun rdiv (a b) (rapply #'/ a b)) (defun rdiv (a b) (rapply #'/ a b))
(defun rsqr (a) (rapply #'sqr a))
;;;; loss calculation, collect trial data (parameters, resulting loss) ;;;; loss calculation, collect trial data (parameters, resulting loss)
(defun sqr (x) (* x x))
(defun sum (data) (reduce #'+ data))
(defun default-deviation (observed calculated &key (norm #'sqr)) (defun default-deviation (observed calculated &key (norm #'sqr))
(sum (mapcar (lambda (a b) (funcall norm (- a b))) (sum (mapcar (lambda (a b) (funcall norm (- a b)))
observed calculated))) observed calculated)))
@ -101,7 +114,14 @@
;;;; parameterized target functions ;;;; parameterized target functions
(defun line (x) (defun line (x)
#'(lambda (theta) (radd (cadr theta) (rmul (car theta) x)))) (lambda (theta)
(radd (rmul (first theta) x) (second theta))))
(defun quad (x)
(lambda (theta)
(radd (rmul (first theta) (rsqr x))
(radd (rmul (second theta) x)
(third theta)))))
;;;; working area ;;;; working area

View file

@ -16,6 +16,7 @@
(test-rapply) (test-rapply)
(test-rreduce) (test-rreduce)
(test-line) (test-line)
(test-quad)
(t:show-result))) (t:show-result)))
(deftest test-basic () (deftest test-basic ()
@ -43,8 +44,9 @@
(== (decons:rapply #'1+ '(2 3)) '(3 4)) (== (decons:rapply #'1+ '(2 3)) '(3 4))
(== (decons:radd 2 3) 5) (== (decons:radd 2 3) 5)
(== (decons:radd 3 '(4 5)) '(7 8)) (== (decons:radd 3 '(4 5)) '(7 8))
(== (decons:radd '(2 3) '(4 5)) '((6 7) (7 8))) ; not '(6 8) (== (decons:radd '(2 3) '(4 5)) '(6 8)) ; not '((6 7) (7 8))
(== (decons:rsub '(6 7) '(4 5)) '((2 3) (1 2))) (== (decons:rsub '(6 7) '(4 5)) '(2 2)) ; not '((2 3) (1 2))
(== (decons:rsqr '(2 3 4)) '(4 9 16))
) )
(deftest test-rreduce () (deftest test-rreduce ()
@ -66,8 +68,17 @@
(== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0))) (== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0)))
0.20999993) ;0.899999861) 0.20999993) ;0.899999861)
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*)) (setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
(setf decons:*obj* objective) (setf decons:*obj* objective) ; for interactive experiments
(== (funcall objective '(0.0 0.0)) 33.21) (== (funcall objective '(0.0 0.0)) 33.21)
(== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001)) (== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001))
(== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6)) (== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6))
)) ))
(defvar *ds2* '((-1.0 0.0 1.0 2.0 3.0)
(2.55 2.1 4.35 10.2 18.25)))
(deftest test-quad ()
(let (ps2 objective)
(setf ps2 (decons:quad (car *ds2*)))
(== (funcall ps2 '(1.0 0.0 0.0)) '(1.0 0.0 1.0 4.0 9.0))
))