use super::mesh::*;
use super::mesh_with_delays::*;
use super::pe::*;
use super::*;
// Expands to an array literal whose elements are one-element arrays, one per
// argument, in argument order.
//
// The first argument is emitted unchanged. Every following argument has
// `.shift_reg_fwd::<{ i + 1 }>()` applied, where `i` is the zero-based
// repetition index supplied by the `${index()}` metavariable expression. The
// k-th argument overall (counting from 0) therefore gets
// `shift_reg_fwd::<k>` for k >= 1, and no call at all for k == 0.
//
// NOTE(review): `shift_reg_fwd::<N>` presumably inserts an N-stage forward
// register chain (an N-cycle delay); its definition is not in this file.
// NOTE(review): `${index()}` needs the unstable `macro_metavar_expr` feature to
// be enabled for the crate.
macro_rules! shift_reg {
    // Plain form: `shift_reg!(a, b, c, ...)`, where each argument is one identifier.
    ($first: ident, $( $x:ident ),*) => {{
        [ [$first], $(
                [ $x.shift_reg_fwd::<{ ${index()} + 1 }>() ]
        ), *]
    }};
    // Pair form: `shift_reg!((a0, a1), (b0, b1), ...)`. Both members of a pair
    // get the same `shift_reg_fwd` parameter and are re-packed as a tuple.
    (($fx:ident, $fy: ident), $(($x:ident, $y:ident)), *) => {{
        [ [($fx, $fy)], $(
                [ ($x.shift_reg_fwd::<{ ${index()} + 1 }>(), $y.shift_reg_fwd::<{ ${index()} + 1 }>()) ]
        ), *]
    }};
}
// Counterpart of `shift_reg!` with the shift amounts in decreasing order.
//
// Arguments before the `;` are repeated: the one at zero-based repetition
// index `i` has `.shift_reg_fwd::<{ TOTAL_ROWS - 1 - i }>()` applied. The single
// argument after the `;` is emitted unchanged as the final array element.
//
// `TOTAL_ROWS` is not a macro parameter; it must be resolvable at the place
// where the macro is expanded.
//
// NOTE(review): the pair arm, which this file uses for column data, also
// counts from `TOTAL_ROWS`; confirm that is intended if the row and column
// counts can ever differ.
// NOTE(review): `${index()}` needs the unstable `macro_metavar_expr` feature to
// be enabled for the crate.
macro_rules! shift_reg_rev {
    // Plain form: `shift_reg_rev!(a, b, ...; last)`.
    ($($x:ident),* ; $last:ident) => {{
        [ $( [ $x.shift_reg_fwd::<{ TOTAL_ROWS - 1 - ${index()} }>() ]
        ),*, [ $last ] ]
    }};
    // Pair form: `shift_reg_rev!((a0, a1), ...; (l0, l1))`. Both members of a
    // pair get the same `shift_reg_fwd` parameter.
    ($(($x:ident, $y:ident)),* ; ($lx:ident, $ly:ident)) => {{
        [ $( [ ($x.shift_reg_fwd::<{ TOTAL_ROWS - 1 - ${index()} }>(), $y.shift_reg_fwd::<{ TOTAL_ROWS - 1 - ${index()} }>()) ]
        ),*, [ ($lx, $ly) ] ]
    }};
}
/// Applies a per-lane `shift_reg_fwd` to the row and column inputs.
///
/// Both inputs are unpacked into 16 single-element lanes and handed to
/// `shift_reg!`: lane 0 is forwarded untouched and lane `k` (1..=15) goes
/// through `shift_reg_fwd::<k>`. For the column input, the two members of each
/// lane's pair receive the same shift parameter.
///
/// NOTE(review): presumably this staggers lane `k` by `k` cycles before it
/// enters the mesh; confirm against the `shift_reg_fwd` definition.
///
/// Returns the re-packed `(rows, columns)` pair in the original lane order.
pub fn preprocess_shift((in_row, in_col): (MeshRowData, MeshColData)) -> (MeshRowData, MeshColData) {
    // Unpack the sixteen single-element row lanes.
    let [
        [row0], [row1], [row2], [row3], [row4], [row5], [row6], [row7],
        [row8], [row9], [row10], [row11], [row12], [row13], [row14], [row15],
    ] = in_row;

    // Unpack the sixteen column lanes; each carries a pair of values.
    let [
        [(dat0, ctl0)], [(dat1, ctl1)], [(dat2, ctl2)], [(dat3, ctl3)],
        [(dat4, ctl4)], [(dat5, ctl5)], [(dat6, ctl6)], [(dat7, ctl7)],
        [(dat8, ctl8)], [(dat9, ctl9)], [(dat10, ctl10)], [(dat11, ctl11)],
        [(dat12, ctl12)], [(dat13, ctl13)], [(dat14, ctl14)], [(dat15, ctl15)],
    ] = in_col;

    let shifted_rows = shift_reg!(
        row0, row1, row2, row3, row4, row5, row6, row7, row8, row9, row10, row11, row12, row13, row14, row15
    );

    let shifted_cols = shift_reg!(
        (dat0, ctl0),
        (dat1, ctl1),
        (dat2, ctl2),
        (dat3, ctl3),
        (dat4, ctl4),
        (dat5, ctl5),
        (dat6, ctl6),
        (dat7, ctl7),
        (dat8, ctl8),
        (dat9, ctl9),
        (dat10, ctl10),
        (dat11, ctl11),
        (dat12, ctl12),
        (dat13, ctl13),
        (dat14, ctl14),
        (dat15, ctl15)
    );

    (shifted_rows, shifted_cols)
}
/// Applies the decreasing per-lane `shift_reg_fwd` to the row and column outputs.
///
/// Both inputs are unpacked into 16 single-element lanes and handed to
/// `shift_reg_rev!`: lane `k` (0..=14) goes through
/// `shift_reg_fwd::<{ TOTAL_ROWS - 1 - k }>` and lane 15 is forwarded untouched.
/// For the column input, the two members of each lane's pair receive the same
/// shift parameter.
///
/// NOTE(review): presumably this re-aligns lanes that `preprocess_shift`
/// staggered, so that every lane sees the same total delay; that holds only if
/// `TOTAL_ROWS` is 16, which is not visible here.
///
/// Returns the re-packed `(rows, columns)` pair in the original lane order.
pub fn postprocess_shift((out_row, out_col): (MeshRowData, MeshColData)) -> (MeshRowData, MeshColData) {
    // Unpack the sixteen single-element row lanes.
    let [
        [row0], [row1], [row2], [row3], [row4], [row5], [row6], [row7],
        [row8], [row9], [row10], [row11], [row12], [row13], [row14], [row15],
    ] = out_row;

    // Unpack the sixteen column lanes; each carries a pair of values.
    let [
        [(dat0, ctl0)], [(dat1, ctl1)], [(dat2, ctl2)], [(dat3, ctl3)],
        [(dat4, ctl4)], [(dat5, ctl5)], [(dat6, ctl6)], [(dat7, ctl7)],
        [(dat8, ctl8)], [(dat9, ctl9)], [(dat10, ctl10)], [(dat11, ctl11)],
        [(dat12, ctl12)], [(dat13, ctl13)], [(dat14, ctl14)], [(dat15, ctl15)],
    ] = out_col;

    let aligned_rows = shift_reg_rev!(
        row0, row1, row2, row3, row4, row5, row6, row7, row8, row9, row10, row11, row12, row13, row14; row15
    );

    let aligned_cols = shift_reg_rev!(
        (dat0, ctl0),
        (dat1, ctl1),
        (dat2, ctl2),
        (dat3, ctl3),
        (dat4, ctl4),
        (dat5, ctl5),
        (dat6, ctl6),
        (dat7, ctl7),
        (dat8, ctl8),
        (dat9, ctl9),
        (dat10, ctl10),
        (dat11, ctl11),
        (dat12, ctl12),
        (dat13, ctl13),
        (dat14, ctl14);
        (dat15, ctl15)
    );

    (aligned_rows, aligned_cols)
}
/// Converts the `(A, B, D)` operand payload and the request payload into the
/// per-row and per-column payloads of the mesh.
///
/// Behaviour per evaluation of the closure:
/// - No request present: every row lane carries `None` and every column lane
///   carries `(None, None)`.
/// - Request present, operands present: each `a` element is wrapped in
///   `Some(PeRowData { a })`; each column takes element 0 of its `b` and `d`
///   entries, sign-extended to `OUTPUT_BITS`, paired with the column control.
/// - Request present, operands absent: rows and columns carry `Some` payloads
///   whose data fields are zero, and columns still carry the column control.
///
/// The column control is built from the request: `dataflow` and `shift` come
/// from `req`, `propagate` and `matmul_id` from `config`, and `last` from the
/// boolean that accompanies the request.
// The signature spells out nested tuple/generic types, hence the lint allowance.
#[allow(clippy::type_complexity)]
pub fn preprocess_type((data, req): (Valid<(A, B, D)>, Valid<(ReqExtended, bool)>)) -> (MeshRowData, MeshColData) {
    // NOTE(review): `fsm` is invoked inside `unsafe`; the closure keeps no state
    // (unit state in, unit state out) and returns unit values towards both
    // inputs. Confirm this meets the contract `fsm` expects.
    unsafe {
        (data, req).fsm::<(MeshRowData, MeshColData), ()>((), |(operands, request), _, ()| {
            // Payloads driven while no request is present.
            let idle_rows = None::<PeRowData>.repeat::<1>().repeat::<MESH_ROWS>();
            let idle_cols = (None::<PeColData>, None::<PeColControl>).repeat::<1>().repeat::<MESH_COLS>();

            // Split the operand payload into the row part (`a`) and the column part (`b`, `d`).
            let a_operand = operands.map(|(a, _, _)| a);
            let bd_operands = operands.map(|(_, b, d)| (b, d));

            // `Some` exactly when a request is present; the operands inside may still be absent.
            let issue = request.map(|meta| (bd_operands, meta));

            let in_left = issue.map_or(idle_rows, |_| match a_operand {
                Some(mesh_row) => mesh_row.map(|tile_row| tile_row.map(|a| Some(PeRowData { a }))),
                // Request without operands: feed zeros into every row.
                None => {
                    range::<MESH_ROWS>().map(|_| Some(PeRowData { a: S::from(0.into_u::<INPUT_BITS>()) }).repeat::<1>())
                }
            });

            let in_top = issue.map_or(idle_cols, |(bd, meta)| {
                let (ReqExtended { req, config }, last) = meta;
                let column_control = Some(PeColControl {
                    control: PeControl { dataflow: req.dataflow, propagate: config.propagate, shift: req.shift },
                    id: config.matmul_id,
                    last,
                });

                match bd {
                    Some((b, d)) => b.zip(d).map(|(b_col, d_col)| {
                        let column_data = Some(PeColData {
                            b: b_col[0].sext::<OUTPUT_BITS>(),
                            d: d_col[0].sext::<OUTPUT_BITS>(),
                        });
                        column_data.repeat::<1>().zip(column_control.repeat::<1>())
                    }),
                    // Request without operands: zero data, but the control still travels down the column.
                    None => range::<MESH_COLS>().map(|_| {
                        let column_data = Some(PeColData {
                            b: S::from(0.into_u::<OUTPUT_BITS>()),
                            d: S::from(0.into_u::<OUTPUT_BITS>()),
                        });
                        column_data.repeat::<1>().zip(column_control.repeat::<1>())
                    }),
                }
            });

            ((in_left, in_top), ((), ()), ())
        })
    }
}
/// Collapses the mesh column outputs into a single optional `(C, PeColControl)` payload.
///
/// The row outputs are ignored. Column lane `[0][0]` decides validity and
/// dataflow:
/// - the output is `Some` exactly when that lane's data is `Some`;
/// - the `d` fields are selected when that lane's control is present with
///   `Dataflow::OS`, otherwise the `b` fields are selected.
///
/// For every column, the chosen field of each element (zero when the data is
/// `None`) is converted to `U`, concatenated, converted back to `S` and
/// repeated `TILE_COLS` times. The emitted control is that of lane `[0][0]`.
pub fn postprocess_type((out_row, out_col): (MeshRowData, MeshColData)) -> Valid<(C, PeColControl)> {
    // NOTE(review): `fsm` is invoked inside `unsafe`; the closure keeps no state
    // and returns unit-valued arrays towards both inputs. Confirm this meets
    // the contract `fsm` expects.
    unsafe {
        (out_row, out_col).fsm::<Valid<(C, PeColControl)>, ()>((), |(_, col_data), _, ()| {
            // Lane [0][0] stands in for the whole mesh output.
            let head_ctrl = col_data[0][0].1;
            let has_output = col_data[0][0].0.is_some();
            let is_os = head_ctrl.is_some_and(|ctrl| matches!(ctrl.control.dataflow, Dataflow::OS));

            // Raw bits of the `b` and `d` fields per column; absent data reads as zero.
            let b_bits = col_data
                .map(|tile| tile.map(|(pe_data, _)| pe_data.map_or(0.into_u(), |pe| U::from(pe.b))).concat());
            let d_bits = col_data
                .map(|tile| tile.map(|(pe_data, _)| pe_data.map_or(0.into_u(), |pe| U::from(pe.d))).concat());

            let result_b = b_bits.map(|bits| S::from(bits).repeat::<TILE_COLS>());
            let result_d = d_bits.map(|bits| S::from(bits).repeat::<TILE_COLS>());
            let selected = if is_os { result_d } else { result_b };

            // NOTE(review): the `unwrap` assumes lane [0][0] never has data without control.
            let egress = if has_output { Some((selected, head_ctrl.unwrap())) } else { None };

            // Unit-valued arrays returned towards the row and column inputs.
            // NOTE(review): the row-side array is also sized by `MESH_COLS`.
            let row_side = ().repeat::<1>().repeat::<MESH_COLS>();
            let col_side = ((), ()).repeat::<1>().repeat::<MESH_COLS>();

            (egress, (row_side, col_side), ())
        })
    }
}