Filters
Contents
Filters
Fabric filters allow a PE to selectively accept incoming wavelets. This example uses so-called range filters, which decide whether a wavelet is forwarded to the CE based on the upper 16 bits of its contents (here, the index packed into each wavelet). Specifically, PE #0 sends all 12 wavelets to the other PEs and computes its own share (the first three values) locally, while each recipient PE receives and processes only a quarter of the incoming wavelets, i.e. three of them. See the documentation for other possible filter configurations.
code.csl
// Top-level file of the filters example. It declares the colors and the
// memcpy resources used to route data between the host and the device,
// and (in the layout block) assigns send.csl / recv.csl to the PEs.
//
// color map
//
// color var color var color var color var
// 0 9 dataColor 18 27 reserved (memcpy)
// 1 D2H 10 19 28 reserved (memcpy)
// 2 LAUNCH 11 20 29 reserved
// 3 12 21 reserved (memcpy) 30 reserved (memcpy)
// 4 13 22 reserved (memcpy) 31 reserved
// 5 14 23 reserved (memcpy) 32
// 6 15 24 33
// 7 16 25 34
// 8 mainColor 17 26 35
//
// Color IDs supplied at compile time through cslc --params
// (commands.sh uses MEMCPYD2H_DATA_1_ID:1 and LAUNCH_ID:2).
param MEMCPYD2H_DATA_1_ID: i16;
param LAUNCH_ID: i16;
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const LAUNCH: color = @get_color(LAUNCH_ID);
// Kernel colors: mainColor is bound to the sender's mainTask, and dataColor
// carries the wavelets from PE #0 to the receivers (exchColor / recvColor).
const mainColor: color = @get_color(8);
const dataColor: color = @get_color(9);
// memcpy parameter generator, sized to match @set_rectangle(4, 1) below.
const memcpy = @import_module( "<memcpy_multi/get_params>", .{
.width = 4,
.height = 1
});
layout {
  // Four PEs in a single row: PE #0 sends, PEs #1..#3 receive.
  @set_rectangle(4, 1);

  // Sender: mainColor triggers its task, dataColor carries the wavelets.
  @set_tile_code(0, 0, "send.csl", .{
    .peId = 0,
    .mainColor = mainColor,
    .exchColor = dataColor,
    .memcpy_params = memcpy.get_params(0),
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // Parameters common to every receiver; each receiver adds its own peId
  // and the memcpy parameters for its column.
  const recvCommon = .{
    .recvColor = dataColor,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  };
  @set_tile_code(1, 0, "recv.csl", @concat_structs(recvCommon, .{ .peId = 1, .memcpy_params = memcpy.get_params(1) }));
  @set_tile_code(2, 0, "recv.csl", @concat_structs(recvCommon, .{ .peId = 2, .memcpy_params = memcpy.get_params(2) }));
  @set_tile_code(3, 0, "recv.csl", @concat_structs(recvCommon, .{ .peId = 3, .memcpy_params = memcpy.get_params(3) }));

  // Names the host looks up: the result buffer and the kernel entry point.
  @export_name("buf", [*]f16, true);
  @export_name("f_run", fn()void);
}
send.csl
// Not a complete program; the top-level source file is code.csl.
// Program for the sending PE (placed at PE #0 by code.csl).

// ID of this PE; code.csl passes 0. Not otherwise used in this file.
param peId: u16;
// Color bound to mainTask below; f_run() activates it.
param mainColor: color;
// Color on which the packed data wavelets are transmitted.
param exchColor: color;
// Per-PE memcpy parameters produced by memcpy.get_params() in code.csl.
param memcpy_params: comptime_struct;
// Colors handed to the memcpy module (RPC launch and device-to-host copy).
param LAUNCH: color;
param MEMCPYD2H_DATA_1: color;
// memcpy runtime module; this file uses its unblock_cmd_stream().
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
.MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
.LAUNCH = LAUNCH
}));
/// Combine a 16-bit index and a 16-bit float into one 32-bit wavelet:
/// the index occupies the upper half (what the receivers' range filters
/// inspect) and the raw bits of the float occupy the lower half.
fn pack(index: u16, data: f16) u32 {
  const upperHalf = @as(u32, index) << 16;
  const lowerHalf = @as(u32, @bitcast(u16, data));
  return upperHalf | lowerHalf;
}
// Number of wavelets PE #0 sends.
const size = 12;
// Wavelet i carries index i and value 10.0 + i, built at compile time.
const data = [size]u32 {
pack(0, 10.0), pack( 1, 11.0), pack( 2, 12.0),
pack(3, 13.0), pack( 4, 14.0), pack( 5, 15.0),
pack(6, 16.0), pack( 7, 17.0), pack( 8, 18.0),
pack(9, 19.0), pack(10, 20.0), pack(11, 21.0),
};
/// Push every entry of `data` onto exchColor. On this PE, exchColor is
/// routed from the CE (RAMP) to the east neighbor only, so the wavelets
/// travel east to the receiving PEs.
fn sendDataToEastTiles() void {
  // Fabric side: `size` 32-bit wavelets on exchColor.
  const fabricOut = @get_dsd(fabout_dsd, .{
    .fabric_color = exchColor,
    .extent = size,
  });
  // Memory side: walk `data` element by element.
  const dataIn = @get_dsd(mem1d_dsd, .{
    .tensor_access = |k|{size} -> data[k]
  });
  @mov32(fabricOut, dataIn);
}
// Number of results PE #0 computes itself (same count each receiver keeps).
const num_wvlts: i16 = 3;
// Results buffer copied back to the host.
var buf = @zeros([num_wvlts]f16);
// Pointer to buf, exported to the host under the name "buf".
var ptr_buf : [*]f16 = &buf;
/// Compute PE #0's share of the results locally: halve the first
/// `num_wvlts` values in `data` and store them in `buf`.
///
/// Despite the name, nothing is sent here. PE #0's exchColor route only
/// transmits east, so it never receives its own wavelets and has to handle
/// indices 0..num_wvlts-1 directly.
fn processAndSendSubset() void {
  var idx: u16 = 0;
  // Bound by the buffer length instead of a literal so the loop can never
  // write past the end of `buf` if num_wvlts changes.
  while (idx < @as(u16, num_wvlts)) : (idx += 1) {
    // The f16 payload sits in the lower 16 bits of each packed wavelet.
    const payload = @as(u16, data[idx] & 0xffff);
    const floatValue = @bitcast(f16, payload);
    buf[idx] = floatValue / 2.0;
  }
}
// Sender's task, bound to mainColor; f_run() activates it.
task mainTask() void {
// send all `size` wavelets east to the receiving PEs; PE #0 itself does not
// receive them (its exchColor route has no RAMP destination)
sendDataToEastTiles();
// compute PE #0's own share (first num_wvlts values) locally into "buf"
processAndSendSubset();
// WARNING: the user must unblock cmd color for every PE
// Done last so "buf" is filled first. Presumably the host's next command
// (the D2H copy in run.py) waits for this - NOTE(review): confirm.
sys_mod.unblock_cmd_stream();
}
comptime {
// Run mainTask when mainColor is activated.
@bind_task(mainTask, mainColor);
// exchColor on PE #0: wavelets enter from this PE's CE (RAMP) and leave only
// toward the east neighbor; none are delivered back to this PE's CE.
@set_local_color_config(exchColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ EAST } } });
}
// Host entry point on the sender: only this PE triggers the broadcast.
// It just activates mainColor and returns; the command stream is released
// by mainTask via sys_mod.unblock_cmd_stream().
fn f_run() void {
@activate(mainColor);
// terminate when the mainTask is done
}
// Symbols the host accesses by name (see @export_name in code.csl).
comptime{
// Expose the results buffer as "buf".
@export_symbol(ptr_buf, "buf");
// Expose f_run so the host can launch it.
@export_symbol(f_run);
// NOTE(review): presumably registers LAUNCH as the color that delivers host
// RPC launches; confirm against the memcpy documentation.
@rpc(LAUNCH);
}
recv.csl
// Not a complete program; the top-level source file is code.csl.
// Program for the receiving PEs (placed at PEs #1..#3 by code.csl).

// ID of this PE (1, 2 or 3); selects which wavelet indices the filter accepts.
param peId: u16;
// Color on which the sender's wavelets arrive.
param recvColor: color;
// Per-PE memcpy parameters produced by memcpy.get_params() in code.csl.
param memcpy_params: comptime_struct;
// Colors handed to the memcpy module (RPC launch and device-to-host copy).
param LAUNCH: color;
param MEMCPYD2H_DATA_1: color;
// memcpy runtime module; this file uses its unblock_cmd_stream().
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
.MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
.LAUNCH = LAUNCH
}));
// Number of wavelets this PE expects; matches the width of its filter range.
const num_wvlts: i16 = 3;
// Next free slot in buf (also the count of wavelets received so far).
var index: i16 = 0;
// Results buffer copied back to the host.
var buf = @zeros([num_wvlts]f16);
// Pointer to buf, exported to the host under the name "buf".
var ptr_buf : [*]f16 = &buf;
/// Data task for recvColor: store half of each accepted wavelet's value in
/// the next slot of `buf`. Once `num_wvlts` wavelets have arrived, release
/// the command stream so the host can continue.
task recvTask(data: f16) void {
  const slot = index;
  buf[slot] = data / 2.0;
  index = slot + 1;

  // All expected wavelets received: proceed to the next command.
  // WARNING: the user must unblock cmd color for every PE
  if (num_wvlts <= index) {
    sys_mod.unblock_cmd_stream();
  }
}
comptime {
// Run recvTask for each wavelet delivered to this PE's CE on recvColor.
@bind_task(recvTask, recvColor);
// Start blocked; f_run() unblocks this color to receive the broadcast values.
@block(recvColor);
// Every receiver takes recvColor in from its west neighbor.
const baseRoute = .{
.rx = .{ WEST }
};
// Range filter on the wavelet's upper 16 bits (the index set by pack() in
// send.csl): accept indices peId*3 .. peId*3+2, i.e. three wavelets per PE.
const filter = .{
// Each PE should only accept three wavelets starting with the one whose
// index field contains the value peId * 3.
.kind = .{ .range = true },
.min_idx = peId * 3,
.max_idx = peId * 3 + 2,
};
// NOTE(review): peId == 3 assumes the rectangle is exactly four PEs wide
// (@set_rectangle(4, 1) in code.csl); update both together.
if (peId == 3) {
// This is the last PE, don't forward the wavelet further to the east.
const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP } });
@set_local_color_config(recvColor, .{.routes = txRoute, .filter = filter});
} else {
// Otherwise, forward incoming wavelets to both CE and to the east neighbor.
const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP, EAST } });
@set_local_color_config(recvColor, .{.routes = txRoute, .filter = filter});
}
}
// Host entry point on a receiver. Only the sender triggers the broadcast;
// here f_run just unblocks recvColor so recvTask can start consuming data.
// It returns immediately; the command stream is released by recvTask once
// all num_wvlts wavelets have been received.
fn f_run() void {
// starts to receive the data from the sender
@unblock(recvColor);
// terminates only when all wavelets are received
}
// Symbols the host accesses by name (see @export_name in code.csl).
comptime{
// Expose the results buffer as "buf".
@export_symbol(ptr_buf, "buf");
// Expose f_run so the host can launch it.
@export_symbol(f_run);
// NOTE(review): presumably registers LAUNCH as the color that delivers host
// RPC launches; confirm against the memcpy documentation.
@rpc(LAUNCH);
}
run.py
#!/usr/bin/env cs_python
"""Host driver for the filters example.

Launches ``f_run`` on every PE, copies each PE's ``buf`` back to the host and
checks the results: every value packed by send.csl, halved, in PE order.
"""
import argparse
import json
import numpy as np
from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

# Geometry of the program; must match @set_rectangle(4, 1) in code.csl and
# num_wvlts in send.csl / recv.csl.
WIDTH = 4
HEIGHT = 1
NUM_WVLTS = 3

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
    compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")

memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)
sym_buf = runner.get_id("buf")

runner.load()
runner.run()
# Stop the runtime even if the launch or the copy fails, so the simulator /
# device connection is not left running.
try:
    print("step 1: call f_run to start broadcasting")
    runner.launch("f_run", nonblock=False)

    print("step 2: copy D2H")
    # The D2H buffer must be of type u32
    out_tensors_u32 = np.zeros(WIDTH*HEIGHT*NUM_WVLTS, np.uint32)
    runner.memcpy_d2h(out_tensors_u32, sym_buf, 0, 0, WIDTH, HEIGHT, NUM_WVLTS,
        streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
    # remove upper 16-bit of each u32
    result = memcpy_view(out_tensors_u32, np.dtype(np.float16))
finally:
    runner.stop()

# Values 10.0 .. 21.0 from send.csl, halved; PE k holds entries 3k .. 3k+2.
oracle = [5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5]
np.testing.assert_allclose(result, oracle, atol=0.0001, rtol=0)
print("SUCCESS!")
commands.sh
#!/usr/bin/env bash
# Compile code.csl into ./out, then run the host driver against it.
# Abort on the first failing command.
set -e

# Compiler flags, one per line for readability.
cslc_args=(
  ./code.csl
  --fabric-dims=11,3
  --fabric-offsets=4,1
  -o out
  --params=MEMCPYD2H_DATA_1_ID:1
  --params=LAUNCH_ID:2
  --memcpy
  --channels=1
  --width-west-buf=0
  --width-east-buf=0
)
cslc "${cslc_args[@]}"

cs_python run.py --name out