;;; -*- Mode:Lisp; Package:Weyli; Base:10; Lowercase:T; Syntax:Common-Lisp -*-
;;; ===========================================================================
;;;				    Matrices
;;; ===========================================================================
;;; (c) Copyright 1989, 1991 Cornell University

;;; $Id: matrix.lisp,v 2.25 1992/12/07 22:03:14 rz Exp $

(in-package "WEYLI")

;; This is a very general matrix implementation.  At some point it will
;; be worthwhile to implement some more specialized matrix spaces.

;; Domain creator for Mat(RING): matrices whose entries lie in RING.
;; The :PREDICATE recognizes an existing domain whose class is exactly
;; MATRIX-SPACE and whose coefficient domain is EQL to RING.  Presumably
;; DEFINE-DOMAIN-CREATOR uses it to reuse such a domain rather than make a
;; new one (GET-MATRIX-SPACE, used below, looks like the generated getter)
;; -- confirm against the macro definition.
(define-domain-creator matrix-space ((ring ring))
  (make-instance 'matrix-space :coefficient-domain ring)
  :predicate (lambda (d)
	       (and (eql (class-name (class-of d)) 'matrix-space)
		    (eql (coefficient-domain d) ring))))

(defmethod print-object ((domain matrix-space) stream)
  "Print a matrix space as Mat(<coefficient domain>)."
  (let ((coefs (coefficient-domain domain)))
    (format stream "Mat(~S)" coefs)))

(defmethod make-element ((domain matrix-space) (value array) &rest ignore)
  "Wrap the 2-D array VALUE (used as-is, not copied) as an element of DOMAIN."
  (declare (ignore ignore))
  (let ((rows (array-dimension value 0))
        (cols (array-dimension value 1)))
    (make-instance 'matrix-space-element
                   :domain domain
                   :dimension1 rows
                   :dimension2 cols
                   :value value)))

(defmethod weyl::make-element ((domain matrix-space) (value array)
			       &rest ignore)
  "Build an element of DOMAIN from the 2-D array VALUE.  Every entry is
COERCEd into the coefficient domain of DOMAIN and stored in a fresh array,
so VALUE itself is left untouched."
  (declare (ignore ignore))
  ;; ARRAY-DIMENSIONS returns a list, not multiple values, so it must be
  ;; destructured.  This also signals an error if VALUE is not rank 2.
  (destructuring-bind (x-dim y-dim) (array-dimensions value)
    (let ((coef-domain (coefficient-domain domain))
          (array (make-array (list x-dim y-dim))))
      (loop for i fixnum below x-dim do
        (loop for j fixnum below y-dim do
          (setf (aref array i j) (coerce (aref value i j) coef-domain))))
      (make-instance 'matrix-space-element
                     :domain domain
                     :dimension1 x-dim
                     :dimension2 y-dim
                     :value array))))

(defmethod make-element ((domain matrix-space) (value list) &rest values)
  "Build an element of DOMAIN from rows given either as one list of rows
(VALUE) or spread across VALUE and VALUES.  All rows must have the same
length."
  (let* ((rows (if values (cons value values) value))
         (width (length (first rows))))
    (unless (every (lambda (row) (eql (length row) width)) (rest rows))
      (error "All rows not the same length: ~S" rows))
    (make-element domain
                  (make-array (list (length rows) width)
                              :initial-contents rows))))

(defmethod weyl::make-element ((domain matrix-space) (value list) &rest values)
  "Build an element of DOMAIN from rows given either as one list of rows
(VALUE) or spread across VALUE and VALUES.  The rows are copied into a
fresh 2-D array, which is then handed to MAKE-ELEMENT."
  (let* ((rows (if values (cons value values) value))
         (width (length (first rows))))
    (unless (every (lambda (row) (eql (length row) width)) (rest rows))
      (error "All rows not the same length: ~S" rows))
    (let ((grid (make-array (list (length rows) width))))
      (loop for row in rows
            for i fixnum from 0
            do (loop for entry in row
                     for j fixnum from 0
                     do (setf (aref grid i j) entry)))
      (make-element domain grid))))

(defmethod matrix-dimensions ((m matrix-space-element))
  "Return the row count and the column count of M as two values."
  (values (slot-value m 'dimension1)
          (slot-value m 'dimension2)))

(defmethod dimensions ((m matrix-space-element))
  "Return the dimensions of M as the list (rows columns)."
  (let ((rows (slot-value m 'dimension1))
        (cols (slot-value m 'dimension2)))
    (list rows cols)))

;; Genera variant of WITH-MATRIX-DIMENSIONS: evaluate BODY with DIM1 and
;; DIM2 bound to the two values of (MATRIX-DIMENSIONS matrix) and, when
;; ARRAY is supplied, ARRAY bound to (MATRIX-VALUE matrix).
;; SCL:ONCE-ONLY is used so the MATRIX form is evaluated only once even
;; though it appears twice in the expansion.
#+Genera
(defmacro with-matrix-dimensions ((dim1 dim2 &optional array) matrix &body body
				  &environment env)
  (scl:once-only (matrix &environment env)
    `(multiple-value-bind (,dim1 ,dim2) (matrix-dimensions ,matrix)
       ,(if array `(let ((,array (matrix-value ,matrix)))
		     ,@body)
	    `(progn ,@body)))))

#-Genera
(defmacro with-matrix-dimensions ((dim1 dim2 &optional array) matrix &body body)
  "Evaluate BODY with DIM1 and DIM2 bound (and declared FIXNUM) to the two
values of (MATRIX-DIMENSIONS matrix) and, when ARRAY is supplied, ARRAY
bound to (MATRIX-VALUE matrix).  The MATRIX form is evaluated exactly once,
matching the Genera variant's use of ONCE-ONLY."
  (let ((m (gensym "MATRIX")))
    `(let ((,m ,matrix))
       (multiple-value-bind (,dim1 ,dim2) (matrix-dimensions ,m)
         (declare (fixnum ,dim1 ,dim2))
         ,(if array
              `(let ((,array (matrix-value ,m)))
                 ,@body)
              `(progn ,@body))))))

#-Genera
(defmethod print-object ((matrix matrix-space-element) stream)
  "Print MATRIX as Mat<<a, b>, <c, d>>, with entries separated by \",  \".
Each row and the matrix as a whole are always closed, so matrices with
zero rows or zero columns print with balanced brackets (e.g. Mat<>)."
  (with-matrix-dimensions (dim1 dim2 array) matrix
    (princ "Mat<" stream)
    (loop for i fixnum below dim1
          do (unless (eql i 0)
               (princ ",  " stream))
             (princ "<" stream)
             (loop for j fixnum below dim2
                   do (unless (eql j 0)
                        (princ ",  " stream))
                      (print-object (aref array i j) stream))
             (princ ">" stream))
    (princ ">" stream)))

;; Genera variant: lay the matrix out as a two-dimensional table using the
;; DW: (presumably Dynamic Windows) table-formatting macros.  Each entry is
;; PRINCed into its own cell, centred horizontally, one row per matrix row.
#+Genera
(defmethod print-object ((matrix matrix-space-element) stream)
  (with-matrix-dimensions (dim1 dim2 array) matrix
    (dw:formatting-table (stream)
      (loop for i below dim1 do
	(dw:formatting-row (stream)
	  (loop for j below dim2 do
	    (dw:formatting-cell (stream :align-x :center)
	      (princ (aref array i j) stream))))))))

(defmethod ref ((matrix matrix-element) &rest args)
  "Index MATRIX.  (REF M i j) returns entry (i, j).  (REF M i :*) returns
row i as a 1 x cols matrix and (REF M :* j) returns column j as a
rows x 1 matrix, both built with MAKE-ELEMENT in the domain of M.  Any
other combination of arguments signals an error."
  (let ((row (first args))
        (col (second args)))
    (cond ((and (numberp row) (numberp col))
           (aref (matrix-value matrix) row col))
          ((and (numberp row) (eql col :*))
           (with-matrix-dimensions (nrows ncols array) matrix
             (let ((slice (make-array (list 1 ncols))))
               (dotimes (j ncols)
                 (setf (aref slice 0 j) (aref array row j)))
               (make-element (domain-of matrix) slice))))
          ((and (eql row :*) (numberp col))
           (with-matrix-dimensions (nrows ncols array) matrix
             (let ((slice (make-array (list nrows 1))))
               (dotimes (i nrows)
                 (setf (aref slice i 0) (aref array i col)))
               (make-element (domain-of matrix) slice))))
          (t (error "Unknown argument to REF(~S ~S)"
                    row col)))))

(defmethod set-ref ((matrix matrix-element) new-value &rest args)
  "Store NEW-VALUE at entry (i, j) of MATRIX, where ARGS is (i j).
Returns NEW-VALUE.  The matrix's value array is modified in place."
  (let ((i (first args))
        (j (second args)))
    (setf (aref (matrix-value matrix) i j) new-value)))

(defmethod zero-matrix ((domain matrix-space) &optional rank)
  "Return the RANK x RANK matrix of DOMAIN whose entries are all the
coefficient domain's ZERO.  RANK must be supplied as a number."
  (if (not (numberp rank))
      (error "Must specify rank to ZERO-MATRIX (~D)" domain)
      (let ((zero (zero (coefficient-domain domain))))
        (make-element domain
                      (make-array (list rank rank) :initial-element zero)))))

(defmethod one-matrix ((domain matrix-space) &optional rank)
  "Return the RANK x RANK identity matrix of DOMAIN: the coefficient
domain's ONE on the diagonal and ZERO elsewhere.  RANK must be supplied."
  (unless (numberp rank)
    (error "Must specify rank to ONE-MATRIX (~D)" domain))
  (let* ((coefs (coefficient-domain domain))
         (unit (make-array (list rank rank) :initial-element (zero coefs)))
         (one (one coefs)))
    (dotimes (i rank)
      (setf (aref unit i i) one))
    (make-element domain unit)))

(defmethod plus ((m1 matrix-space-element) (m2 matrix-space-element))
  "Entrywise sum of two matrices from the same (EQL) domain with equal
dimensions."
  (let ((domain (domain-of m1)))
    (unless (eql domain (domain-of m2))
      (error "Can't add these matrices"))
    (with-matrix-dimensions (rows1 cols1 a1) m1
      (with-matrix-dimensions (rows2 cols2 a2) m2
        (unless (and (eql rows1 rows2) (eql cols1 cols2))
          (error "Trying to add matrices of different dimensions: (~D ~D) and (~D ~D)"
                 rows1 cols1 rows2 cols2))
        (let ((sum (make-array (list rows1 cols1))))
          (dotimes (i rows1)
            (dotimes (j cols1)
              (setf (aref sum i j) (+ (aref a1 i j) (aref a2 i j)))))
          (make-element domain sum))))))

(defmethod difference ((m1 matrix-space-element) (m2 matrix-space-element))
  "Entrywise difference M1 - M2 of two matrices from the same (EQL)
domain with equal dimensions."
  (let ((domain (domain-of m1)))
    (unless (eql domain (domain-of m2))
      (error "Can't subtract these matrices"))
    (with-matrix-dimensions (rows1 cols1 a1) m1
      (with-matrix-dimensions (rows2 cols2 a2) m2
        (unless (and (eql rows1 rows2) (eql cols1 cols2))
          (error "Trying to subtract matrices of different dimensions: (~D ~D) and (~D ~D)"
                 rows1 cols1 rows2 cols2))
        (let ((diff (make-array (list rows1 cols1))))
          (dotimes (i rows1)
            (dotimes (j cols1)
              (setf (aref diff i j) (- (aref a1 i j) (aref a2 i j)))))
          (make-element domain diff))))))


(defmethod-sd times ((m1 matrix-element) (m2 matrix-element))
  ;; Matrix product M1 * M2.  DOMAIN is not bound in this body; it is
  ;; presumably supplied by DEFMETHOD-SD as the shared domain of M1 and M2.
  (with-matrix-dimensions (rows1 inner1 a) m1
    (with-matrix-dimensions (inner2 cols2 b) m2
      (unless (eql inner1 inner2)
        (error "Trying to multiply matrices of incompatible dimensions: (~D ~D) and (~D ~D)"
               rows1 inner1 inner2 cols2))
      (let ((product (make-array (list rows1 cols2))))
        (dotimes (i rows1)
          (dotimes (j cols2)
            ;; Dot product of row I of A with column J of B; the first term
            ;; seeds the accumulator so no explicit ZERO is needed.
            (setf (aref product i j)
                  (loop for k fixnum below inner1
                        for term = (* (aref a i k) (aref b k j))
                        for acc = term then (+ acc term)
                        finally (return acc)))))
        (make-element domain product)))))

(defmethod times ((m matrix-space-element) (v free-module-element))
  "Matrix-vector product M * V; delegates to MATRIX-FME-TIMES."
  (funcall #'matrix-fme-times m v))

(defun matrix-fme-times (m v)
  "Return M * V for a matrix M and a free-module element V.  M and V must
share the same coefficient domain, and V's dimension must equal M's
column count.  The result lies in V's module when M is square; otherwise
it lies in the free module of rank (rows of M) over the same coefficients."
  (let ((elt-domain (coefficient-domain (domain-of m)))
        (vector-space (domain-of v)))
    (unless (eql elt-domain (coefficient-domain vector-space))
      (error "Incompatible arguments: ~S and ~S" m v))
    (with-matrix-dimensions (rows cols array) m
      (unless (eql cols (dimension vector-space))
        (error "Trying to multiply a matrix and vector of incompatible dimensions: (~D ~D) and ~D"
               rows cols (dimension vector-space)))
      (flet ((row-dot-v (i)
               ;; Sum over K of M[i,k] * V[k], seeded with the first term.
               (loop for k fixnum below cols
                     for term = (* (aref array i k) (ref v k))
                     for acc = term then (+ acc term)
                     finally (return acc))))
        (%apply #'make-element
                (if (lisp:= rows cols)
                    vector-space
                    (get-free-module elt-domain rows))
                (loop for i fixnum below rows
                      collect (row-dot-v i)))))))

(defmethod times ((v free-module-element) (m matrix-space-element))
  "Vector-matrix product V * M; delegates to FME-MATRIX-TIMES."
  (funcall #'fme-matrix-times v m))

(defun fme-matrix-times (v m)
  "Return V * M for a free-module element V (treated as a row vector) and
a matrix M.  Both must share the same coefficient domain, and V's
dimension must equal M's row count.  The result lies in V's module when M
is square; otherwise it lies in the free module of rank (columns of M)."
  (let ((elt-domain (coefficient-domain (domain-of m)))
        (vector-space (domain-of v)))
    (unless (eql elt-domain (coefficient-domain vector-space))
      (error "Incompatible arguments: ~S and ~S" v m))
    (with-matrix-dimensions (rows cols array) m
      (unless (eql (dimension vector-space) rows)
        (error "Trying to multiply a vector and matrix of incompatible dimensions:  ~D and (~D ~D)"
               (dimension vector-space) rows cols))
      (flet ((v-dot-column (i)
               ;; Sum over K of V[k] * M[k,i], seeded with the first term.
               (loop for k fixnum below rows
                     for term = (* (ref v k) (aref array k i))
                     for acc = term then (+ acc term)
                     finally (return acc))))
        (%apply #'make-element
                (if (lisp:= rows cols)
                    vector-space
                    (get-free-module elt-domain cols))
                (loop for i fixnum below cols
                      collect (v-dot-column i)))))))

(defmethod transpose ((m matrix-element))
  "Return a new element of M's domain holding the transpose of M."
  (with-matrix-dimensions (rows cols array) m
    (let ((flipped (make-array (list cols rows))))
      (dotimes (i rows)
        (dotimes (j cols)
          (setf (aref flipped j i) (aref array i j))))
      (make-element (domain-of m) flipped))))

(defmethod-sd direct-sum ((x matrix-element) (y matrix-element))
  ;; Place Y's columns to the right of X's (both must have the same number
  ;; of rows).  DOMAIN is presumably bound by DEFMETHOD-SD.
  (with-matrix-dimensions (x-rows x-cols x-array) x
    (with-matrix-dimensions (y-rows y-cols y-array) y
      (unless (eql x-rows y-rows)
        (error "Incompatable dimensions (~D, ~D) and (~D, ~D)"
               x-rows x-cols y-rows y-cols))
      (let ((joined (make-array (list x-rows (lisp:+ x-cols y-cols)))))
        (dotimes (i x-rows)
          (dotimes (k x-cols)
            (setf (aref joined i k) (aref x-array i k)))
          (dotimes (k y-cols)
            (setf (aref joined i (lisp:+ x-cols k)) (aref y-array i k))))
        (make-element domain joined)))))


(defmethod recip ((m matrix-element))
  "Return the inverse of the square matrix M as a new element of M's
domain.  M itself is not modified.  Signals an error for non-square M."
  (let ((domain (domain-of m)))
    (with-matrix-dimensions (dim1 dim2 array) m
      (cond ((eql dim1 dim2)
             ;; INVERT-ARRAY destructively row-reduces the array it is given,
             ;; so hand it a private copy to keep M's own entries intact.
             (let ((work (make-array (list dim1 dim2))))
               (loop for i fixnum below dim1 do
                 (loop for j fixnum below dim2 do
                   (setf (aref work i j) (aref array i j))))
               (make-element domain (invert-array (coefficient-domain domain)
                                                  work))))
            (t (error "Can't invert a non-square matrix"))))))

;; Invert a square array whose entries are elements of DOMAIN.
(defmethod invert-array ((domain domain) array &optional into-array)
  "Invert the square 2-D ARRAY, whose entries belong to DOMAIN, by
Gauss-Jordan elimination and return the result array.

ARRAY is destructively row-reduced: on return it no longer holds its
original entries, so pass a copy if they are still needed.

If INTO-ARRAY is supplied, it must have the same dimensions as ARRAY.  It
is used as-is as the right-hand side (it is NOT reset to the identity),
so the result, computed in place, is ARRAY^-1 * INTO-ARRAY.  Otherwise a
fresh identity array is used and the plain inverse is returned.

Signals an error if ARRAY is not square, or if every candidate pivot in
some column is zero (the matrix is singular)."
  (let ((dimensions (array-dimensions array)))
    (unless (and (null (rest (rest dimensions)))
                 (eql (first dimensions) (second dimensions)))
      (error "Wrong dimensions for recip: ~S" array))
    (cond (into-array
           ;; EQUAL, not EQL: the two dimension lists are distinct conses.
           (unless (equal dimensions (array-dimensions into-array))
             (error "Wrong dimensions for ~S, expected ~S"
                    into-array dimensions)))
          (t (setq into-array (make-array dimensions))
             (loop with zero = (zero domain) and one = (one domain)
                   for i fixnum below (first dimensions) do
               (loop for j fixnum below (second dimensions) do
                 (setf (aref into-array i j) (if (eql i j) one zero))))))
    (let ((dimension (first dimensions)))
      (flet ((exchange-rows (j k)
               ;; Swap rows J and K in both ARRAY and INTO-ARRAY.
               (loop for i fixnum below dimension do
                 (rotatef (aref array j i) (aref array k i))
                 (rotatef (aref into-array j i) (aref into-array k i))))
             (find-pivot (i)
               ;; Return the pivot row (>= I) for column I and the pivot
               ;; value.  A zero entry is never chosen while a nonzero one
               ;; exists; among nonzero entries the largest (by >) wins.
               (let ((row i)
                     (max (aref array i i)))
                 (loop for j fixnum upfrom (1+ i) below dimension
                       for elt = (aref array j i)
                       do (when (and (not (0? elt))
                                     (or (0? max) (> elt max)))
                            (setq max elt
                                  row j)))
                 (when (0? max)
                   (error "Can't invert a singular matrix"))
                 (values row max)))
             (subtract-rows1 (row1 row2)
               ;; Eliminate column ROW1 from row ROW2 (ROW2 > ROW1).  In
               ;; ARRAY both rows are already zero left of ROW1, but the
               ;; rows of INTO-ARRAY are generally full, so every column
               ;; of INTO-ARRAY must be updated.
               (unless (0? (aref array row2 row1))
                 (let ((mult (/ (aref array row2 row1) (aref array row1 row1))))
                   (loop for j fixnum upfrom row1 below dimension do
                     (setf (aref array row2 j)
                           (- (aref array row2 j) (* mult (aref array row1 j)))))
                   (loop for j fixnum below dimension do
                     (setf (aref into-array row2 j)
                           (- (aref into-array row2 j)
                              (* mult (aref into-array row1 j))))))))
             (subtract-rows2 (row1 row2)
               ;; Back-substitution step: ROW1 has a unit pivot, so clear
               ;; ARRAY[ROW2, ROW1] by subtracting that multiple of ROW1.
               (unless (0? (aref array row2 row1))
                 (let ((mult (aref array row2 row1)))
                   (loop for j fixnum upfrom row1 below dimension do
                     (setf (aref array row2 j)
                           (- (aref array row2 j) (* mult (aref array row1 j)))))
                   (loop for j fixnum below dimension do
                     (setf (aref into-array row2 j)
                           (- (aref into-array row2 j)
                              (* mult (aref into-array row1 j)))))))))
        ;; Triangulate: make ARRAY upper triangular with a unit diagonal.
        (loop for i fixnum below dimension do
          (multiple-value-bind (row pivot) (find-pivot i)
            (unless (eql i row)
              (exchange-rows i row))
            (loop for j fixnum upfrom (1+ i) below dimension do
              (subtract-rows1 i j))
            (loop for j fixnum upfrom i below dimension do
              (setf (aref array i j) (/ (aref array i j) pivot)))
            (loop for j fixnum below dimension do
              (setf (aref into-array i j) (/ (aref into-array i j) pivot)))))
        ;; Backsolve: clear the entries above the diagonal.
        (loop for i fixnum downfrom (1- dimension) above -1 do
          (loop for j fixnum downfrom (1- i) above -1 do
            (subtract-rows2 i j)))))
    into-array))

(defmethod substitute ((values list) (variables list) (m matrix-space-element)
		       &rest ignore)
  "Substitute VALUES for VARIABLES in every entry of M.  The result lives
in the matrix space over the domain of the substituted (0, 0) entry."
  (declare (ignore ignore))
  (with-matrix-dimensions (rows cols array) m
    (let ((result (make-array (list rows cols))))
      (dotimes (i rows)
        (dotimes (j cols)
          (setf (aref result i j)
                (substitute values variables (aref array i j)))))
      (make-element (get-matrix-space (domain-of (aref result 0 0)))
                    result))))

(defmethod jacobian ((function-list list) (var-list list))
  "Return the Jacobian matrix: entry (i, j) is the partial derivative of
the i-th function in FUNCTION-LIST with respect to the j-th variable in
VAR-LIST.  The matrix space is built over the domain of the first function."
  (let* ((ring (domain-of (first function-list)))
         (rows (length function-list))
         (cols (length var-list))
         (grid (make-array (list rows cols))))
    (loop for poly in function-list
          for i fixnum from 0
          do (loop for var in var-list
                   for j fixnum from 0
                   do (setf (aref grid i j) (partial-deriv poly var))))
    (make-element (get-matrix-space ring) grid)))

;; Matrix Groups

;;; ==========================================================================
;;; The Groups GL(n), SL(n), PSL(n), O(n), SO(n) with the following
;;; hierarchy:
;;;  
;;;  det<>0  GL(n)
;;;            |
;;;            |
;;;  det=+-1 PSL(n) -------------> O(n) M*M^t = In
;;;            |                    |
;;;            V                    |
;;;  det=1   SL(n)                  |
;;;            \                   /
;;;              \               /
;;;                \           /
;;;                  \       /
;;;                    \   /
;;;                    SO(n)
;;;
;;;
;;; ==========================================================================


;; The coefficient domain of GL-n must be a field; otherwise GL-n would
;; not be a group.  This is not necessary for the other matrix groups,
;; because their determinants are required to be units.

;; Domain creator for GL-n: DIMENSION x DIMENSION matrices over the field
;; DOMAIN.  The :PREDICATE recognizes an existing domain whose class is
;; exactly GL-N, with the same (EQL) coefficient domain and dimension.
;; Presumably DEFINE-DOMAIN-CREATOR uses it to reuse such a domain; confirm
;; against the macro definition.
(define-domain-creator GL-n ((domain field) dimension)
  (make-instance 'GL-n :coefficient-domain domain :dimension dimension)
  :predicate (lambda (d)
		(and (eql (class-name (class-of d)) 'GL-n)
		     (eql (coefficient-domain d) domain)
		     (eql (dimension d) dimension))))

(defmethod print-object ((domain GL-n) stream)
  "Print a GL-n domain as GL^n(<coefficient domain>)."
  (format stream "GL^~D(~S)"
          (dimension domain)
          (coefficient-domain domain)))

(defmethod print-object ((matrix GL-n-element) stream)
  "Print MATRIX as <domain-class-name><<a, b>, <c, d>>, with entries
separated by \",  \".  Every row and the matrix as a whole are always
closed, so degenerate (zero-sized) matrices print with balanced brackets."
  (with-matrix-dimensions (dim1 dim2 array) matrix
    (format stream "~A<" (class-name (class-of (domain-of matrix))))
    (loop for i fixnum below dim1
          do (unless (eql i 0)
               (princ ",  " stream))
             (princ "<" stream)
             (loop for j fixnum below dim2
                   do (unless (eql j 0)
                        (princ ",  " stream))
                      (princ (aref array i j) stream))
             (princ ">" stream))
    (princ ">" stream)))

;; Associate GL-N-ELEMENT with GL-N domains.  MAKE-ELEMENT on a GL-n domain
;; instantiates (first (domain-element-classes domain)).
(define-domain-element-classes GL-n GL-n-element)

(defmethod matrix-dimensions ((m GL-n-element))
  "A GL-n element is always n x n, so return the domain's dimension twice."
  (let* ((domain (domain-of m))
         (n (dimension domain)))
    (values n n)))

(defmethod make-element ((domain GL-n) (value array) &rest ignore)
  "Wrap the array VALUE (used as-is) as an element of DOMAIN, using the
first element class registered for DOMAIN."
  (declare (ignore ignore))
  (let ((element-class (first (domain-element-classes domain))))
    (make-instance element-class :domain domain :value value)))

(defmethod weyl::make-element ((domain GL-n) (value array)
			       &rest ignore)
  "Build an element of the GL-n DOMAIN from the 2-D array VALUE.  Every
entry is COERCEd into the coefficient domain, and the coerced copy (not
VALUE itself) becomes the element's value."
  (declare (ignore ignore))
  ;; ARRAY-DIMENSIONS returns a list, not multiple values, so destructure it.
  (destructuring-bind (x-dim y-dim) (array-dimensions value)
    (let ((coef-domain (coefficient-domain domain))
          (array (make-array (list x-dim y-dim))))
      (loop for i fixnum below x-dim do
        (loop for j fixnum below y-dim do
          (setf (aref array i j) (coerce (aref value i j) coef-domain))))
      (make-element domain array))))

(defmethod make-element ((domain GL-n) (value list) &rest values)
  "Build an element of DOMAIN from rows given either as one list of rows
(VALUE) or spread across VALUE and VALUES.  All rows must have the same
length."
  (let* ((rows (if values (cons value values) value))
         (width (length (first rows))))
    (unless (every (lambda (row) (eql (length row) width)) (rest rows))
      (error "All rows not the same length: ~S" rows))
    (make-element domain
                  (make-array (list (length rows) width)
                              :initial-contents rows))))

(defmethod weyl::make-element ((domain GL-n) (value list) &rest values)
  "Build an element of DOMAIN from rows given either as one list of rows
(VALUE) or spread across VALUE and VALUES.  The rows are copied into a
fresh 2-D array, which is then handed to MAKE-ELEMENT."
  (let* ((rows (if values (cons value values) value))
         (width (length (first rows))))
    (unless (every (lambda (row) (eql (length row) width)) (rest rows))
      (error "All rows not the same length: ~S" rows))
    (let ((grid (make-array (list (length rows) width))))
      (loop for row in rows
            for i fixnum from 0
            do (loop for entry in row
                     for j fixnum from 0
                     do (setf (aref grid i j) entry)))
      (make-element domain grid))))

(defmethod one-matrix ((domain GL-n) &optional rank)
  "Return the identity element of the GL-n DOMAIN: an n x n matrix with
the coefficient domain's ONE on the diagonal and ZERO elsewhere.  RANK is
optional; if supplied, it must equal the dimension of DOMAIN."
  (let ((computed-rank (dimension domain)))
    ;; Compare numbers with EQL; EQ on integers is not guaranteed to be
    ;; true even for equal values.
    (when (and rank (not (eql rank computed-rank)))
      (error "rank argument conflicts with domain dimension"))
    (let* ((coefs (coefficient-domain domain))
           (zero (zero coefs))
           (one (one coefs))
           (array (make-array (list computed-rank computed-rank)
                              :initial-element zero)))
      (loop for i fixnum below computed-rank do
        (setf (aref array i i) one))
      (make-element domain array))))

(defmethod one ((domain GL-n))
  "The group identity of a GL-n domain is its identity matrix."
  (funcall #'one-matrix domain))

(defmethod times ((m GL-n-element) (v free-module-element))
  "Matrix-vector product M * V for a GL-n element; see MATRIX-FME-TIMES."
  (funcall #'matrix-fme-times m v))

(defmethod times ((v free-module-element) (m GL-n-element))
  "Vector-matrix product V * M for a GL-n element; see FME-MATRIX-TIMES."
  (funcall #'fme-matrix-times v m))

;;
;; PSL(n) : group of matrices with determinant +1 or -1
;;

(defmethod print-object ((domain PSL-n) stream)
  "Print a PSL-n domain as PSL^n(<coefficient domain>)."
  (format stream "PSL^~D(~S)"
          (dimension domain)
          (coefficient-domain domain)))

;; Associate PSL-N-ELEMENT with PSL-N domains.  How elements are built is
;; decided by MAKE-ELEMENT methods (the GL-n one uses the first registered
;; class); presumably inherited if PSL-N subclasses GL-N -- confirm.
(define-domain-element-classes PSL-n PSL-n-element)

;; Domain creator for PSL-n over the field DOMAIN.  The :PREDICATE
;; recognizes an existing domain whose class is exactly PSL-N, with the
;; same (EQL) coefficient domain and dimension (presumably so it can be
;; reused -- see DEFINE-DOMAIN-CREATOR).
(define-domain-creator PSL-n ((domain field) dimension)
  (make-instance 'PSL-n :coefficient-domain domain :dimension dimension)
  :predicate (lambda (d)
		(and (eql (class-name (class-of d)) 'PSL-n)
		     (eql (coefficient-domain d) domain)
		     (eql (dimension d) dimension))))
;;
;; SL(n) : group of matrices with determinant +1 
;;


(defmethod print-object ((domain SL-n) stream)
  "Print an SL-n domain as SL^n(<coefficient domain>)."
  (format stream "SL^~D(~S)"
          (dimension domain)
          (coefficient-domain domain)))

;; Associate SL-N-ELEMENT with SL-N domains (see
;; DEFINE-DOMAIN-ELEMENT-CLASSES for exactly what is registered).
(define-domain-element-classes SL-n SL-n-element)

;; Domain creator for SL-n over the field DOMAIN.  The :PREDICATE
;; recognizes an existing domain whose class is exactly SL-N, with the same
;; (EQL) coefficient domain and dimension (presumably so it can be
;; reused -- see DEFINE-DOMAIN-CREATOR).
(define-domain-creator SL-n ((domain field) dimension)
  (make-instance 'SL-n :coefficient-domain domain :dimension dimension)
  :predicate (lambda (d)
		(and (eql (class-name (class-of d)) 'SL-n)
		     (eql (coefficient-domain d) domain)
		     (eql (dimension d) dimension))))

(defmethod determinant ((m SL-n-element))
  "Elements of SL-n have determinant 1 by definition, so no computation
is needed: return the coefficient domain's ONE."
  (let ((coefs (coefficient-domain (domain-of m))))
    (one coefs)))

;;
;; O(n) : group of orthogonal matrices
;;


(defmethod print-object ((domain O-n) stream)
  "Print an O-n domain as O^n(<coefficient domain>)."
  (format stream "O^~D(~S)"
          (dimension domain)
          (coefficient-domain domain)))

;; Associate O-N-ELEMENT with O-N domains (see
;; DEFINE-DOMAIN-ELEMENT-CLASSES for exactly what is registered).
(define-domain-element-classes O-n O-n-element)

;; Domain creator for O-n over the field DOMAIN.  The :PREDICATE
;; recognizes an existing domain whose class is exactly O-N, with the same
;; (EQL) coefficient domain and dimension (presumably so it can be
;; reused -- see DEFINE-DOMAIN-CREATOR).
(define-domain-creator O-n ((domain field) dimension)
  (make-instance 'O-n :coefficient-domain domain :dimension dimension)
  :predicate (lambda (d)
		(and (eql (class-name (class-of d)) 'O-n)
		     (eql (coefficient-domain d) domain)
		     (eql (dimension d) dimension))))
;;
;; SO(n) : orthogonal matrices with unit determinant
;;

(defmethod print-object ((domain SO-n) stream)
  "Print an SO-n domain as SO^n(<coefficient domain>)."
  (format stream "SO^~D(~S)"
          (dimension domain)
          (coefficient-domain domain)))

;; Associate SO-N-ELEMENT with SO-N domains (see
;; DEFINE-DOMAIN-ELEMENT-CLASSES for exactly what is registered).
(define-domain-element-classes SO-n SO-n-element)

(defmethod recip ((m SO-n-element))
  "An orthogonal matrix's inverse is its transpose, so no elimination is
needed."
  (funcall #'transpose m))

;; Domain creator for SO-n over the field DOMAIN.  The :PREDICATE
;; recognizes an existing domain whose class is exactly SO-N, with the same
;; (EQL) coefficient domain and dimension (presumably so it can be
;; reused -- see DEFINE-DOMAIN-CREATOR).
(define-domain-creator SO-n ((domain field) dimension)
  (make-instance 'SO-n :coefficient-domain domain :dimension dimension)
  :predicate (lambda (d)
		(and (eql (class-name (class-of d)) 'SO-n)
		     (eql (coefficient-domain d) domain)
		     (eql (dimension d) dimension))))
