diff --git a/decons.lisp b/decons.lisp index 0a443c2..9ce2118 100644 --- a/decons.lisp +++ b/decons.lisp @@ -4,7 +4,7 @@ (:use :common-lisp) (:local-nicknames (:shape :scopes/shape) (:util :scopes/util)) - (:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv #:rsqr + (:export #:rapply #:rreduce #:rreduce-1 #:radd #:rmul #:rsub #:rdiv #:rsqr #:default-deviation #:l2-loss *revisions* *alpha* #:gradient-descent #:nabla-xp @@ -35,11 +35,11 @@ (defgeneric rcall2 (op a1 a2) (:method (op a1 a2) (funcall op a1 a2)) (:method (op (a1 list) a2) - (mapcar (lambda (i) (rapply op i a2)) a1)) + (mapcar (lambda (i) (rcall2 op i a2)) a1)) (:method (op a1 (a2 list)) - (mapcar (lambda (j) (rapply op a1 j)) a2)) + (mapcar (lambda (j) (rcall2 op a1 j)) a2)) (:method (op (a1 list) (a2 list)) - (mapcar (lambda (i j) (rapply op i j)) a1 a2))) + (mapcar (lambda (i j) (rcall2 op i j)) a1 a2))) (defgeneric rcurry (op arg) (:method (op arg) (lambda (j) (funcall op arg j))) @@ -55,6 +55,13 @@ (:method (op (v list) &key (initial-value 0)) (rreduce op v :initial-value initial-value))) +(defgeneric rreduce-1 (op arg &key initial-value) + (:method (op arg &key (initial-value 0)) arg) + (:method (op (arg list) &key (initial-value 0)) + (if (some #'consp arg) + (mapcar (lambda (x) (rreduce-1 op x :initial-value initial-value)) arg) + (reduce op arg :initial-value initial-value)))) + (defun radd (a b) (rapply #'+ a b)) (defun rmul (a b) (rapply #'* a b)) (defun rsub (a b) (rapply #'- a b)) diff --git a/test-decons.lisp b/test-decons.lisp index ebb54c0..4169c00 100644 --- a/test-decons.lisp +++ b/test-decons.lisp @@ -51,6 +51,7 @@ (deftest test-rreduce () (== (decons:rreduce #'+ '(1 2 (3 4))) 10) + (== (decons:rreduce-1 #'+ '(1 2 (3 4))) '(1 2 7)) ) (defvar *ds1* '((2.0 1.0 4.0 3.0)