Document what compiler is doing for diffxi
This commit is contained in:
parent
76f5291131
commit
a33e1d37ba
|
@ -102,7 +102,10 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
|
||||||
prev: &[Float],
|
prev: &[Float],
|
||||||
fut: &mut [Float],
|
fut: &mut [Float],
|
||||||
) {
|
) {
|
||||||
|
use std::convert::TryInto;
|
||||||
#[inline(never)]
|
#[inline(never)]
|
||||||
|
/// This prevents code bloat, both start and end block gives
|
||||||
|
/// a matrix multiplication with the same matrix sizes
|
||||||
fn dedup_matmul<const M: usize, const N: usize>(
|
fn dedup_matmul<const M: usize, const N: usize>(
|
||||||
c: &mut ColVector<Float, M>,
|
c: &mut ColVector<Float, M>,
|
||||||
a: &Matrix<Float, M, N>,
|
a: &Matrix<Float, M, N>,
|
||||||
|
@ -129,7 +132,6 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
|
||||||
let (futb1, fut) = fut.split_at_mut(M);
|
let (futb1, fut) = fut.split_at_mut(M);
|
||||||
let (fut, futb2) = fut.split_at_mut(nx - 2 * M);
|
let (fut, futb2) = fut.split_at_mut(nx - 2 * M);
|
||||||
|
|
||||||
use std::convert::TryInto;
|
|
||||||
{
|
{
|
||||||
let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().next().unwrap());
|
let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().next().unwrap());
|
||||||
let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap());
|
let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap());
|
||||||
|
@ -138,19 +140,39 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
|
||||||
*fut *= idx;
|
*fut *= idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The window needs to be aligned to the diagonal elements,
|
|
||||||
// based on the block size
|
|
||||||
let window_elems_to_skip = M - ((D - 1) / 2);
|
|
||||||
|
|
||||||
for (window, f) in prev[window_elems_to_skip..]
|
|
||||||
.array_windows::<D>()
|
|
||||||
.zip(fut.array_chunks_mut::<1>())
|
|
||||||
{
|
{
|
||||||
let fut = ColVector::<_, 1>::map_to_col_mut(f);
|
// The window needs to be aligned to the diagonal elements,
|
||||||
let prev = ColVector::<_, D>::map_to_col(window);
|
// based on the block size
|
||||||
|
let window_elems_to_skip = M - ((D - 1) / 2);
|
||||||
|
|
||||||
fut.matmul_float_into(&matrix.diag, prev);
|
// The compiler is pretty clever right here. It is able to
|
||||||
*fut *= idx;
|
// inline the entire computation into a very tight loop.
|
||||||
|
// It seems the "optimal" way is to broadcast each element of diag
|
||||||
|
// into separate SIMD vectors.
|
||||||
|
// Then an unroll of SIMD::lanes items from fut and prev:
|
||||||
|
// f[0] = diag[0] * in[0] + diag[1] * in[1] + diag[2] * in[3] + ...
|
||||||
|
// f[1] = diag[0] * in[1] + diag[1] * in[2] + diag[2] * in[4] + ...
|
||||||
|
// \-- fma's along the vertical axis, then horizontal
|
||||||
|
//
|
||||||
|
// The resulting inner loop performs:
|
||||||
|
// one mul, D-1 fma,
|
||||||
|
// one mul (by idx),
|
||||||
|
// one store
|
||||||
|
// two integer adds (input and output offsets)
|
||||||
|
// one test + jump
|
||||||
|
//
|
||||||
|
// The compiler is clever enough to combine two such unrollings and merge
|
||||||
|
// these computations to prevent stalling
|
||||||
|
for (window, f) in prev[window_elems_to_skip..]
|
||||||
|
.array_windows::<D>()
|
||||||
|
.zip(fut.array_chunks_mut::<1>())
|
||||||
|
{
|
||||||
|
let fut = ColVector::<_, 1>::map_to_col_mut(f);
|
||||||
|
let prev = ColVector::<_, D>::map_to_col(window);
|
||||||
|
|
||||||
|
fut.matmul_float_into(&matrix.diag, prev);
|
||||||
|
*fut *= idx;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
|
@ -123,10 +123,15 @@ impl<const M: usize, const P: usize> Matrix<Float, M, P> {
|
||||||
) {
|
) {
|
||||||
for i in 0..M {
|
for i in 0..M {
|
||||||
for j in 0..P {
|
for j in 0..P {
|
||||||
let mut t = 0.0;
|
// Slightly cheaper to do first computation separately
|
||||||
for k in 0..N {
|
// rather than store zero and issue all ops as fma
|
||||||
|
let mut t = if N == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
lhs[(i, 0)] * rhs[(0, j)]
|
||||||
|
};
|
||||||
|
for k in 1..N {
|
||||||
t = Float::mul_add(lhs[(i, k)], rhs[(k, j)], t);
|
t = Float::mul_add(lhs[(i, k)], rhs[(k, j)], t);
|
||||||
// t = t + lhs[(i, k)] * rhs[(k, j)];
|
|
||||||
}
|
}
|
||||||
self[(i, j)] = t;
|
self[(i, j)] = t;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue