mmm.im

// Matrix-matrix multiply for Iota+

uses io.print, io.printi, conv.stoi, conv.itos

// Matrix-matrix multiply: returns C = A * B.
// A is m-by-n, B is n-by-p, result is m-by-p.
// Arrays are stored as arrays of rows, i.e., A is an array of m n-vectors.
// n is taken from length(B) and p from length(B[0]).
// Returns null if A or B has no rows, if any row of A does not have
// exactly n entries, or if any row of B does not have exactly p entries.
mmm(A: array[array[int]], B: array[array[int]]): array[array[int]] = (
	i, j, k: int;

	m: int = length(A);
	n: int = length(B);

	if (m <= 0) return null;
	if (n <= 0) return null;

	p: int = length(B[0]);

	// Check that A is the correct size: every row must have n entries.
	i = 0;
	while (i < m) (
	    if (n != length(A[i])) return null;
	    i++
	);

	// Check that B is the correct size: every row must have p entries.
	// i must be reset here; otherwise this loop would start at row m
	// (left over from the check above) and skip rows of B.
	i = 0;
	while (i < n) (
	    if (p != length(B[i])) return null;
	    i++
	);

	// Allocate the m-by-p result with every entry initialized to 0.
	// NOTE(review): relies on new T[len](init) giving each row its own
	// int array rather than sharing one — confirm against the Iota+ spec.
	C: array[array[int]] = new array[int][m](new int[p](0));

	// Accumulate C[i][k] += A[i][j] * B[j][k] in i-j-k loop order.
	i = 0;
	while (i < m) (
		j = 0;
		while (j < n) (
			k = 0;
			while (k < p) (
				C[i][k] = C[i][k] + A[i][j] * B[j][k];
				k++
			);
			j++
		);
		i++
	);

	C
)

// Entry point: builds an m-by-n matrix A with ones on its main diagonal
// (all other entries 0) and an n-by-p matrix B with B[i][j] = i * j,
// multiplies them with mmm, and prints A, B and the product C.
// Expects at least three arguments <m> <n> <p>. Prints a usage message
// and returns 1 if fewer are given, or if m or n is not positive or p is
// negative; otherwise returns 0.
main(args: array[string]): int = (
    if (length args < 3) (
	print("Usage: mmm <m> <n> <p>\n");
	return 1;
    );

    // NOTE(review): the second stoi argument looks like a fallback value
    // used when parsing fails — confirm against conv.stoi.
    m: int = stoi(args[0], 1);
    n: int = stoi(args[1], 1);
    p: int = stoi(args[2], 1);

    // mmm returns null when A or B has no rows (m <= 0 or n <= 0), and
    // printMatrix cannot print a null matrix; negative sizes cannot be
    // allocated. Reject those inputs up front.
    if ((m <= 0) | (n <= 0) | (p < 0)) (
	print("Usage: mmm <m> <n> <p>\n");
	return 1;
    );

    A: array[array[int]] = new array[int][m](new int[n](0));
    B: array[array[int]] = new array[int][n](new int[p](0));

    i, j: int;

    // Put ones on the main diagonal of A (as far as it fits in m-by-n).
    i = 0;
    while (i < m & i < n) (
	A[i][i] = 1;
	i++
    );

    // Fill B with B[i][j] = i * j.
    i = 0;
    while (i < n) (
	j = 0;
	while (j < p) (
		B[i][j] = i * j;
		j++
	);
	i++
    );

    C: array[array[int]] = mmm(A, B);

    print("A:\n");
    printMatrix(A);
    print("B:\n");
    printMatrix(B);
    print("C:\n");
    printMatrix(C);

    0
)

// Prints matrix M with io.print, one row per line: each row starts with
// a tab, and each entry is converted with itos and followed by one space.
// NOTE(review): M must not be null; length M is evaluated unguarded.
printMatrix(M:array[array[int]]) = (
	i:int = 0;
	while (i < length M) (
		j:int = 0;
		print("\t");
		while (j < length M[i]) (
			print(itos(M[i][j]) + " ");
			j++
		);
		// End the row with "\n", the newline escape used elsewhere in this file.
		print("\n");
		i++
	);
)