-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmap.cpp
More file actions
1668 lines (1523 loc) · 91.6 KB
/
map.cpp
File metadata and controls
1668 lines (1523 loc) · 91.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "map.hpp"
#include <algorithm>
#include <chrono>
#include <iostream>
#include <sstream>
#include <fstream>
#include <utility>
#include <vector>
#include "common/tt_backend_api_types.hpp"
#include "detail/tt_metal.hpp"
#include "host_api.hpp"
#include "impl/buffers/buffer.hpp"
#include "impl/buffers/circular_buffer_types.hpp"
#include "common/work_split.hpp"
#include "tt_metal/impl/device/device.hpp"
namespace current {
/// Construct a dataflow map over the given kernels and streams.
///
/// @param kernels  Kernels participating in the map, stored as raw pointers (~Map() does not free them).
/// @param streams  Streams feeding / draining the kernels, stored as raw pointers (~Map() does not free them).
/// @param max_parallelization_factor Upper bound on the number of cores each kernel is assigned in execute().
/// @param tiles_per_cb Capacity, in tiles, of every circular buffer created in execute().
Map::Map(std::vector<Kernel *> kernels, std::vector<Stream *> streams, uint32_t max_parallelization_factor, uint32_t tiles_per_cb)
    : kernels(std::move(kernels)), streams(std::move(streams)), max_parallelization_factor(max_parallelization_factor), tiles_per_cb(tiles_per_cb) {
    // The constructor parameters shadow the members and have just been moved from,
    // so any validation must go through `this->`.
    for (size_t i = 1; i < this->streams.size(); i++) {
        // TODO: Eventually we want to support streams of different sizes e.g for reduction kernels,
        // but right now we just check that all streams have the same number of elements.
        // assert(this->streams[i]->n_elements == this->streams[0]->n_elements && "All streams must have the same number of elements!");
    }
}
// TODO: Validate that port connections are valid (check that types are the same.)
/// Connect output port `src_out` of kernel `src` to input port `dst_in` of kernel `dst`.
///
/// Both kernels are resolved to their indices in `kernels` via get_kernel_index() and the
/// resulting endpoint pair is forwarded to the Endpoint-based add_connection overload.
void Map::add_connection(Kernel *src, std::string src_out, Kernel *dst, std::string dst_in) {
    // TODO: Add error handling.
    auto src_kernel_idx = get_kernel_index(src);
    auto dst_kernel_idx = get_kernel_index(dst);
    // Port names are sink parameters (taken by value), so move them instead of copying again.
    Endpoint src_endpoint = {Endpoint::EndpointType::Kernel, src_kernel_idx, std::move(src_out)};
    Endpoint dst_endpoint = {Endpoint::EndpointType::Kernel, dst_kernel_idx, std::move(dst_in)};
    tt::log_info("[CURRENT] Adding connection from kernel {} to kernel {}", src_kernel_idx, dst_kernel_idx);
    add_connection(src_endpoint, dst_endpoint);
}
/// Connect stream `src` to input port `dst_in` of kernel `dst`.
///
/// Stream endpoints carry an empty port name; the kernel side is resolved via get_kernel_index().
void Map::add_connection(Stream *src, Kernel *dst, std::string dst_in) {
    auto src_stream_idx = get_stream_index(src);
    auto dst_kernel_idx = get_kernel_index(dst);
    Endpoint src_endpoint = {Endpoint::EndpointType::Stream, src_stream_idx, ""};
    // dst_in is a by-value sink parameter; move it into the endpoint.
    Endpoint dst_endpoint = {Endpoint::EndpointType::Kernel, dst_kernel_idx, std::move(dst_in)};
    tt::log_info("[CURRENT] Adding connection from stream {} to kernel {}", src_stream_idx, dst_kernel_idx);
    add_connection(src_endpoint, dst_endpoint);
}
/// Connect output port `src_out` of kernel `src` to stream `dst`.
///
/// Stream endpoints carry an empty port name; the kernel side is resolved via get_kernel_index().
void Map::add_connection(Kernel *src, std::string src_out, Stream *dst) {
    auto src_kernel_idx = get_kernel_index(src);
    auto dst_stream_idx = get_stream_index(dst);
    // src_out is a by-value sink parameter; move it into the endpoint.
    Endpoint src_endpoint = {Endpoint::EndpointType::Kernel, src_kernel_idx, std::move(src_out)};
    Endpoint dst_endpoint = {Endpoint::EndpointType::Stream, dst_stream_idx, ""};
    tt::log_info("[CURRENT] Adding connection from kernel {} to stream {}", src_kernel_idx, dst_stream_idx);
    // TODO: Need to do checks whether these are valid port names and that the ports have not already been connected.
    add_connection(src_endpoint, dst_endpoint);
}
// void Map::parallelize(std::vector<CoreCoord> &cores) {
// auto total_cores = cores.size();
// for (size_t i = 0; i < kernels.size(); i++) {
// }
// }
// Split `total` tiles as evenly as possible across `parts` cores and return the
// {tile count, first tile offset} assigned to core `index` (0-based).
// The first `total % parts` cores each take one extra tile, so the per-core ranges are
// contiguous, non-overlapping, and together cover every tile exactly once.
// Returns {0, 0} when `parts` is 0.
static std::pair<uint32_t, uint32_t> split_tiles_for_core(uint32_t total, uint32_t parts, uint32_t index) {
    if (parts == 0) {
        return {0, 0};
    }
    const uint32_t base = total / parts;
    const uint32_t remainder = total % parts;
    const uint32_t count = base + (index < remainder ? 1u : 0u);
    const uint32_t offset = base * index + std::min(index, remainder);
    return {count, offset};
}

/// Build, compile, and run the dataflow program described by this map on device 0.
///
/// Steps: validate connections and propagate tile counts, open the device and create a
/// program, allocate and upload the stream buffers, generate the device kernels, assign each
/// kernel up to `max_parallelization_factor` cores, then on every assigned core create
/// semaphores, circular buffers, the reader/compute/writer kernels and their runtime args.
/// Finally the program is compiled and enqueued.
///
/// @return Wall-clock time from EnqueueProgram() until Finish() returns (compilation and
///         buffer uploads are not included).
///
/// NOTE(review): the device is only closed by ~Map(); calling execute() twice on the same Map
/// re-opens device 0 and replaces `runtime` without an explicit CloseDevice — confirm whether
/// that is supported.
std::chrono::steady_clock::duration Map::execute() {
    check_connections();
    propagate_counts();

    // 1. Create device and program.
    auto device = tt_metal::CreateDevice(0);
    if (!device) {
        std::cerr << "Failed to create device!\n";
        exit(1);
    }
    auto program = tt_metal::CreateProgram();

    // 2. Core grid setup.
    // TODO: Have this configurable by user and dynamic by runtime scheduling.
    // Right now we request max_parallelization_factor cores for every kernel.
    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
    auto num_cores_x = compute_with_storage_grid_size.x;
    auto num_cores_y = compute_with_storage_grid_size.y;
    auto core_set = num_cores_to_corerange_set({0, 0}, kernels.size() * max_parallelization_factor, {num_cores_x, num_cores_y});
    runtime.emplace(Runtime {
        .device = device,
        .program = std::move(program),
        .num_cores_x = static_cast<uint32_t>(num_cores_x),
        .num_cores_y = static_cast<uint32_t>(num_cores_y),
        .core_set = std::move(core_set)
    });
    tt::log_info("[CURRENT] num_cores_x: {}, num_cores_y: {}", runtime->num_cores_x, runtime->num_cores_y);
    tt::log_info("[CURRENT] core_set: {}", runtime->core_set.str());
    tt::log_info("[CURRENT] Total cores: {}", runtime->core_set.num_cores());

    // 3. Input & Output DRAM buffer setup.
    for (size_t i = 0; i < streams.size(); i++) {
        auto *stream = streams[i];
        if (stream->is_gather_stream()) {
            auto *gather_stream = dynamic_cast<GatherStream*>(stream);
            // Index buffer.
            tt_metal::InterleavedBufferConfig config = {
                .device = runtime->device,
                .size = gather_stream->n_elements * gather_stream->element_size,
                // .page_size = stream->element_size * TILE_WIDTH * TILE_HEIGHT * gather_stream->accesses_per_token, // We fetch n indices for every n accesses, so index tile has to be n times larger.
                .page_size = gather_stream->n_elements * gather_stream->element_size, // TODO: What should this be?
                .buffer_type = tt_metal::BufferType::DRAM
            };
            std::cout << "STREAM " << i << ": size: " << config.size << std::endl;
            gather_stream->device_buffer = tt_metal::CreateBuffer(config);
            // TODO: Does this need to be blocking?
            // TODO: What if there's a mismatch between the host data size and the device buffer size?
            tt_metal::EnqueueWriteBuffer(runtime->device->command_queue(), gather_stream->device_buffer, gather_stream->host_data, true);
            gather_stream->device_buffer_address = gather_stream->device_buffer->address();
            gather_stream->device_buffer_noc_coordinates = gather_stream->device_buffer->noc_coordinates();
            // Set up data buffer as well.
            if (!gather_stream->use_sram) {
                // Note: assuming underlying data is already 32 byte aligned and accesses are padded.
                auto total_size_bytes = gather_stream->data_buffer.size() * 4;
                std::cout << "data gather stream page size & total size: " << total_size_bytes << "\n";
                tt_metal::InterleavedBufferConfig data_config = {
                    .device = runtime->device,
                    .size = total_size_bytes,
                    .page_size = total_size_bytes, // TODO: Random access doesn't work across pages, so we have page_size=total_size. Is this a problem?
                    .buffer_type = tt_metal::BufferType::DRAM,
                };
                gather_stream->data_buffer_device = tt_metal::CreateBuffer(data_config);
                std::cout << "Created gather stream!\n";
                tt_metal::EnqueueWriteBuffer(runtime->device->command_queue(),
                                             gather_stream->data_buffer_device,
                                             gather_stream->data_buffer,
                                             true);
                gather_stream->data_buffer_address = gather_stream->data_buffer_device->address();
                gather_stream->data_buffer_noc_coordinates = gather_stream->data_buffer_device->noc_coordinates();
            } else {
                // We are placing the entire gather stream data buffer in L1.
                // The host data is copied into each consuming core's L1 later, per core.
                auto total_size_bytes = gather_stream->data_buffer.size() * 4;
                std::cout << "data gather stream page size & total size: " << total_size_bytes << "\n";
                tt_metal::InterleavedBufferConfig data_config = {
                    .device = runtime->device,
                    .size = total_size_bytes,
                    .page_size = total_size_bytes, // TODO: Random access doesn't work across pages, so we have page_size=total_size. Is this a problem?
                    .buffer_type = tt_metal::BufferType::L1,
                };
                gather_stream->data_buffer_device = tt_metal::CreateBuffer(data_config);
                std::cout << "Created gather stream!\n";
                gather_stream->data_buffer_address = gather_stream->data_buffer_device->address();
                gather_stream->data_buffer_noc_coordinates = gather_stream->data_buffer_device->noc_coordinates();
            }
        } else {
            // Round the buffer size up to a whole number of tile-sized pages.
            auto page_size = stream->element_size * TILE_WIDTH * TILE_HEIGHT;
            auto total_size_bytes = ((stream->n_elements * stream->element_size + page_size - 1) / page_size) * page_size;
            tt_metal::InterleavedBufferConfig config = {
                .device = runtime->device,
                .size = total_size_bytes,
                .page_size = page_size, // TODO: Not sure what is optimal for this.
                .buffer_type = tt_metal::BufferType::DRAM
            };
            std::cout << "STREAM " << i << ": size: " << config.size << std::endl;
            stream->device_buffer = tt_metal::CreateBuffer(config);
            // TODO: Does this need to be blocking?
            // TODO: What if there's a mismatch between the host data size and the device buffer size?
            tt_metal::EnqueueWriteBuffer(runtime->device->command_queue(), stream->device_buffer, stream->host_data, true);
            stream->device_buffer_address = stream->device_buffer->address();
            stream->device_buffer_noc_coordinates = stream->device_buffer->noc_coordinates();
        }
    }

    // 4. Generate device kernels.
    generate_device_kernels();

    // Vector of cores we have available to assign to kernels.
    std::vector<CoreCoord> cores = corerange_to_cores(runtime->core_set);
    auto total_cores = cores.size();
    std::cout << "Total cores: " << total_cores << std::endl;
    size_t cores_used = 0;
    for (size_t i = 0; i < kernels.size(); i++) {
        // Each kernel is spread across up to max_parallelization_factor consecutive cores,
        // taken from the front of the remaining pool.
        // TODO: Look into core placement strategies (RaftLib thesis)
        // possibly doing automatic parallelization of kernels.
        // NOTE: If the pool runs out, later kernels get fewer (possibly zero) cores.
        size_t cores_availible = total_cores - cores_used;
        size_t cores_to_assign = std::min(cores_availible, (size_t)max_parallelization_factor);
        std::vector<CoreCoord> kernel_cores;
        for (size_t j = 0; j < cores_to_assign; j++) {
            kernel_cores.push_back(cores[cores_used + j]);
        }
        kernels[i]->core_spec = kernel_cores;
        cores_used += cores_to_assign;
        std::cout << "Kernel " << i << " assigned to cores: ";
        for (const auto& core : kernel_cores) {
            std::cout << "(" << core.x << ", " << core.y << ") ";
        }
        std::cout << std::endl;
    }

    // Create per-core resources: semaphores, circular buffers, device kernels and runtime args.
    for (size_t i = 0; i < kernels.size(); i++) {
        auto kernel = kernels[i];
        // Parallelize kernel across multiple cores.
        auto parallelization_factor = kernel->core_spec.size(); // # of cores this kernel is parallelized across.
        for (size_t j = 0; j < parallelization_factor; j++) {
            auto core = kernel->core_spec[j];
            // Create semaphores for each kernel.
            // TODO: Is overwriting across multiple cores, though we might not even need to store this ID for later.
            kernel->sender_semaphore_id = tt_metal::CreateSemaphore(runtime->program, core, INVALID);
            kernel->receiver_semaphore_id = tt_metal::CreateSemaphore(runtime->program, core, INVALID);
            kernel->l1_valid_value_semaphore_id = tt_metal::CreateSemaphore(runtime->program, core, VALID);
            auto incoming_connections = get_incoming_connections(kernel);
            auto outgoing_connections = get_outgoing_connections(kernel);

            // Create circular buffers for each incoming and outgoing connection.
            auto num_input_cbs = 0;
            auto num_intermed_cbs = 0;
            for (size_t k = 0; k < incoming_connections.size(); k++) {
                if (incoming_connections[k].source.is_stream() && streams[incoming_connections[k].source.index]->is_gather_stream()) {
                    auto *stream = dynamic_cast<GatherStream*>(streams[incoming_connections[k].source.index]);
                    // Init L1 data buffer.
                    if (stream->use_sram) {
                        tt::tt_metal::detail::WriteToDeviceL1(runtime->device, core, stream->data_buffer_device->address(), stream->data_buffer);
                    }
                    // Setup intermed CB for gather stream indices.
                    auto intermed_cb_index = num_intermed_cbs + INTERMED_CB_START;
                    num_intermed_cbs++;
                    auto index_tile_size_bytes = TILE_SIZE * sizeof(uint32_t) * stream->accesses_per_token;
                    std::cout << "Index CB tile size bytes: " << index_tile_size_bytes << "\n";
                    tt_metal::CircularBufferConfig index_cb_config = CircularBufferConfig(
                        tiles_per_cb * index_tile_size_bytes,
                        {{intermed_cb_index, tt::DataFormat::UInt32}}
                    ).set_page_size(intermed_cb_index, index_tile_size_bytes); // TODO: Not sure what to set this page size to.
                    // TODO: Does this handle even need to be stored anywhere?
                    auto index_cb = tt_metal::CreateCircularBuffer(runtime->program, core, index_cb_config);
                    std::cout << "Creating circular buffer at index " << intermed_cb_index << "!\n";
                    // Setup input CBs for streaming to compute.
                    // For gather stream we have one input CB for every access per token.
                    for (size_t access = 0; access < stream->accesses_per_token; access++) {
                        // Each port is typed with a specific data format.
                        auto cb_index = num_input_cbs + IN_CB_START;
                        num_input_cbs++;
                        auto input_port = kernel->get_input_port(incoming_connections[k].dest.port);
                        auto input_port_index = kernel->get_input_port_index(input_port.name);
                        auto tile_size_bytes = TILE_WIDTH * TILE_HEIGHT * tt::datum_size(input_port.data_format);
                        tt_metal::CircularBufferConfig cb_config = CircularBufferConfig(
                            tiles_per_cb * tile_size_bytes,
                            {{cb_index, input_port.data_format}}
                        ).set_page_size(cb_index, tile_size_bytes); // TODO: Not sure what to set this page size to.
                        // TODO: Again, overwriting across multiple cores.
                        kernel->input_ports[input_port_index].cb = tt_metal::CreateCircularBuffer(runtime->program, core, cb_config);
                        std::cout << "Creating circular buffer at index " << cb_index << "!\n";
                    }
                } else {
                    // Each port is typed with a specific data format.
                    auto cb_index = num_input_cbs + IN_CB_START;
                    num_input_cbs++;
                    auto input_port = kernel->get_input_port(incoming_connections[k].dest.port);
                    auto input_port_index = kernel->get_input_port_index(input_port.name);
                    auto tile_size_bytes = TILE_WIDTH * TILE_HEIGHT * tt::datum_size(input_port.data_format);
                    tt_metal::CircularBufferConfig cb_config = CircularBufferConfig(
                        tiles_per_cb * tile_size_bytes,
                        {{cb_index, input_port.data_format}}
                    ).set_page_size(cb_index, tile_size_bytes); // TODO: Not sure what to set this page size to.
                    // TODO: Again, overwriting across multiple cores.
                    kernel->input_ports[input_port_index].cb = tt_metal::CreateCircularBuffer(runtime->program, core, cb_config);
                    std::cout << "Creating circular buffer at index " << cb_index << "!\n";
                }
            }
            for (size_t k = 0; k < outgoing_connections.size(); k++) {
                auto cb_index = k + OUT_CB_START;
                auto output_port = kernel->get_output_port(outgoing_connections[k].source.port);
                auto output_port_index = kernel->get_output_port_index(output_port.name);
                auto tile_size_bytes = TILE_WIDTH * TILE_HEIGHT * tt::datum_size(output_port.data_format);
                tt_metal::CircularBufferConfig cb_config = CircularBufferConfig(
                    tiles_per_cb * tile_size_bytes,
                    {{cb_index, output_port.data_format}}
                ).set_page_size(cb_index, tile_size_bytes); // TODO: Not sure what to set this page size to.
                // TODO: Again, overwriting across multiple cores.
                kernel->output_ports[output_port_index].cb = tt_metal::CreateCircularBuffer(runtime->program, core, cb_config);
            }

            // Create device kernels.
            std::cout << "Reader kernel path: " << kernel->generated_reader_kernel_path << "\n";
            auto reader = tt_metal::CreateKernel(
                runtime->program,
                kernel->generated_reader_kernel_path,
                core,
                // TODO: Can also do compile-time args here? I think this might be useful.
                DataMovementConfig {
                    .processor = DataMovementProcessor::RISCV_0,
                    .noc = NOC::RISCV_0_default,
                    .compile_args = {},
                    .defines = {}
                } // TODO: What to do for this?
            );
            kernel->reader_kernel = reader;
            // DST is capable of storing 16 32x32 tiles of 2B datum size. In full mode, DST is not double buffered, so the full 16 tiles are available for LLKs to use.
            // In half mode, we treat DST as a ping pong buffer, where each half contains 8 tiles. That means that an LLK can only index into 8 different tiles in DST.
            // This is useful to overlap the MATH and PACK threads.
            // Example: MATH will acquire the first half of DST and populate it with 8 tiles of output from some math LLK.
            // MATH releases first half and acquires second half.
            // PACK acquires first half and writes tiles from DST to L1.
            // Meanwhile, MATH is producing results into the second half of DST.
            // We continue ping ponging to keep MATH and PACK busy at the same time.
            // I haven't heard about "TILE" mode, not sure if that's used anywhere.
            // HALF mode is most common and afaik expected to be the default.
            std::cout << "Compute kernel path: " << kernel->generated_compute_kernel_path << "\n";
            auto compute = tt_metal::CreateKernel(
                runtime->program,
                kernel->generated_compute_kernel_path,
                core,
                ComputeConfig{
                    // TODO: Also need to figure out what the heck to do for this.
                    .math_fidelity = MathFidelity::LoFi,
                    .math_approx_mode = false,
                    .compile_args = {},
                    .defines = {}
                }
            );
            kernel->compute_kernel = compute;
            std::cout << "Writer kernel path: " << kernel->generated_writer_kernel_path << "\n";
            auto writer = tt_metal::CreateKernel(
                runtime->program,
                kernel->generated_writer_kernel_path,
                core,
                DataMovementConfig {
                    .processor = DataMovementProcessor::RISCV_1,
                    .noc = NOC::RISCV_1_default,
                    .compile_args = {},
                    .defines = {}
                }
            );
            kernel->writer_kernel = writer;

            // Set runtime args.
            std::vector<uint32_t> reader_args;
            std::vector<uint32_t> compute_args;
            for (const auto& connection : incoming_connections) {
                // Give core j its share of the connection's tiles; the first
                // (n_tiles % parallelization_factor) cores take one extra tile so none are dropped.
                auto [n_tiles, tile_offset] = split_tiles_for_core(
                    static_cast<uint32_t>(connection.n_tiles),
                    static_cast<uint32_t>(parallelization_factor),
                    static_cast<uint32_t>(j));
                std::cout << "n_tiles: " << n_tiles << " tile_offset: " << tile_offset << std::endl;
                if (connection.source.is_stream()) {
                    // For every incoming stream connection, we need to know how many tiles we expect to read and what the DRAM address is.
                    auto *stream = streams[connection.source.index];
                    if (!stream->is_gather_stream()) {
                        reader_args.push_back(n_tiles);
                        reader_args.push_back(stream->device_buffer_address);
                        reader_args.push_back(tile_offset);
                        compute_args.push_back(n_tiles); // Compute also needs to know how many tiles to read in.
                    } else {
                        auto *gather_stream = dynamic_cast<GatherStream*>(stream);
                        // n_tiles *= gather_stream->accesses_per_token; // TODO: Unverified whether index tiles must be scaled by accesses_per_token.
                        std::cout << "Gather stream index n_tiles: " << n_tiles << "\n";
                        // # of index tiles.
                        reader_args.push_back(n_tiles);
                        // Address of index buffer.
                        reader_args.push_back(gather_stream->device_buffer_address);
                        reader_args.push_back(tile_offset);
                        reader_args.push_back(gather_stream->device_buffer_noc_coordinates.x);
                        reader_args.push_back(gather_stream->device_buffer_noc_coordinates.y);
                        // Address of data buffer.
                        reader_args.push_back(gather_stream->data_buffer_address);
                        // Data buffer noc coordinates.
                        reader_args.push_back(gather_stream->data_buffer_noc_coordinates.x);
                        reader_args.push_back(gather_stream->data_buffer_noc_coordinates.y);
                        compute_args.push_back(n_tiles);
                    }
                } else {
                    // Incoming connection is another kernel: core j of the sender feeds core j of this kernel.
                    // TODO: I'm not sure if this is correct with the way I'm doing parallelization. Would make more sense to transform the program graph itself.
                    auto *sender = kernels[connection.source.index];
                    // .at() turns a core-count mismatch between sender and receiver into an exception instead of UB.
                    CoreCoord sender_core = sender->core_spec.at(j);
                    auto sender_worker_core = runtime->device->worker_core_from_logical_core(sender_core);
                    uint32_t sender_x = static_cast<uint32_t>(sender_worker_core.x);
                    uint32_t sender_y = static_cast<uint32_t>(sender_worker_core.y);
                    reader_args.push_back(n_tiles);
                    reader_args.push_back(sender_x);
                    reader_args.push_back(sender_y);
                    reader_args.push_back(kernel->sender_semaphore_id);
                    reader_args.push_back(kernel->receiver_semaphore_id);
                    reader_args.push_back(kernel->l1_valid_value_semaphore_id);
                    compute_args.push_back(n_tiles);
                }
            }
            SetRuntimeArgs(runtime->program, kernel->reader_kernel, core, reader_args);
            SetRuntimeArgs(runtime->program, kernel->compute_kernel, core, compute_args);

            std::vector<uint32_t> writer_args;
            for (const auto& connection : outgoing_connections) {
                // Same split as the reader side so producer and consumer cores agree on tile counts.
                auto [n_tiles, tile_offset] = split_tiles_for_core(
                    static_cast<uint32_t>(connection.n_tiles),
                    static_cast<uint32_t>(parallelization_factor),
                    static_cast<uint32_t>(j));
                if (connection.dest.is_stream()) {
                    // TODO: Here we are explicitly setting the # of tiles we expect to write to be the same as the capacity of the stream.
                    // This can get a bit tricky if the compute does any sort of reduction and the user does not correctly set the capacity.
                    // I think if reduction kernels get implemented then we need a way of automatically determining the # of tiles to write at each stage of the program.
                    auto stream = streams[connection.dest.index];
                    writer_args.push_back(n_tiles);
                    writer_args.push_back(stream->device_buffer_address);
                    writer_args.push_back(tile_offset);
                } else {
                    // Outgoing connection is another kernel: core j of this kernel feeds core j of the receiver.
                    auto receiver = kernels[connection.dest.index];
                    CoreCoord receiver_core = receiver->core_spec.at(j);
                    auto receiver_worker_core = runtime->device->worker_core_from_logical_core(receiver_core);
                    uint32_t receiver_x = static_cast<uint32_t>(receiver_worker_core.x);
                    uint32_t receiver_y = static_cast<uint32_t>(receiver_worker_core.y);
                    // NOTE: The receiver's mailbox address is not passed. If it is ever needed it must be looked up on
                    // `receiver` (connection.dest.port is the receiver's input port), not on this kernel.
                    writer_args.push_back(n_tiles);
                    writer_args.push_back(receiver_x);
                    writer_args.push_back(receiver_y);
                    writer_args.push_back(kernel->sender_semaphore_id);
                    writer_args.push_back(kernel->receiver_semaphore_id);
                    writer_args.push_back(kernel->l1_valid_value_semaphore_id);
                }
            }
            SetRuntimeArgs(runtime->program, kernel->writer_kernel, core, writer_args);
        }
    }

    tt_metal::detail::CompileProgram(runtime->device, runtime->program);

    // Collect benchmark metrics: time only the program execution itself.
    auto start = std::chrono::steady_clock::now();
    tt_metal::EnqueueProgram(runtime->device->command_queue(), runtime->program, false);
    tt_metal::Finish(runtime->device->command_queue());
    auto end = std::chrono::steady_clock::now();
    // The device stays open; ~Map() closes it.
    return end - start;
}
/// Release the device opened by execute(), if execute() was ever called.
///
/// Kernels and streams are held as raw pointers and are not freed here.
Map::~Map() {
    if (!runtime) {
        // execute() never ran, so there is no device handle to close.
        return;
    }
    tt_metal::CloseDevice(runtime->device);
}
/// Report whether any connection in the map terminates at `kernel`.
///
/// @param kernel Kernel to look up; resolved to its index via get_kernel_index().
/// @return true if at least one connection's destination endpoint is a kernel with that
///         index (the source may be a stream or another kernel), false otherwise.
bool Map::has_incoming_connection(Kernel *kernel) {
    const size_t target_idx = get_kernel_index(kernel);
    return std::any_of(connections.begin(), connections.end(),
                       [target_idx](const Connection& candidate) {
                           return candidate.dest.is_kernel() && candidate.dest.index == target_idx;
                       });
}
/// Collect every connection whose destination endpoint is `kernel`.
///
/// @param kernel Kernel to look up; resolved to its index via get_kernel_index().
/// @return Copies of the matching connections, in the same order as `connections`.
std::vector<Map::Connection> Map::get_incoming_connections(Kernel *kernel) {
    const size_t target_idx = get_kernel_index(kernel);
    std::vector<Connection> matches;
    for (const Connection& candidate : connections) {
        const bool ends_at_kernel = candidate.dest.endpoint_type == Endpoint::EndpointType::Kernel;
        if (!ends_at_kernel || candidate.dest.index != target_idx) {
            continue;
        }
        matches.push_back(candidate);
    }
    return matches;
}
/// Collect every connection whose source endpoint is `kernel`.
///
/// @param kernel Kernel to look up; resolved to its index via get_kernel_index().
/// @return Copies of the matching connections, in the same order as `connections`.
std::vector<Map::Connection> Map::get_outgoing_connections(Kernel *kernel) {
    const size_t target_idx = get_kernel_index(kernel);
    std::vector<Connection> matches;
    for (const Connection& candidate : connections) {
        const bool starts_at_kernel = candidate.source.endpoint_type == Endpoint::EndpointType::Kernel;
        if (!starts_at_kernel || candidate.source.index != target_idx) {
            continue;
        }
        matches.push_back(candidate);
    }
    return matches;
}
/// Map a tt::DataFormat to the enumerator spelling emitted into generated device-kernel source
/// (e.g. the `.data_format = ...` field of an InterleavedAddrGenFast initializer).
///
/// Only Float16_b, Bfp8_b and UInt32 are supported. Any other format prints
/// "Unsupported data format!" to stderr and terminates the process with exit code 1.
std::string data_format_to_string(tt::DataFormat data_format) {
    struct FormatSpelling {
        tt::DataFormat format;
        const char *spelling;
    };
    static const FormatSpelling known_formats[] = {
        {tt::DataFormat::Float16_b, "DataFormat::Float16_b"},
        {tt::DataFormat::Bfp8_b, "DataFormat::Bfp8_b"},
        {tt::DataFormat::UInt32, "DataFormat::UInt32"},
    };
    for (const FormatSpelling &entry : known_formats) {
        if (entry.format == data_format) {
            return entry.spelling;
        }
    }
    std::cerr << "Unsupported data format!\n";
    exit(1);
}
void Map::generate_reader_device_kernel(
Kernel *kernel,
std::vector<Connection> incoming_connections
) {
// Generate reader kernel.
std::stringstream rs;
rs << "#include <cstdint>\n";
rs << "#include \"dataflow_api.h\"\n";
rs << "#include \"debug/dprint.h\"\n";
rs << "#include \"hostdevcommon/common_values.hpp\"\n";
rs << "#include \"debug/dprint.h\"\n";
rs << "void kernel_main() {\n";
rs << " DPRINT << \"READER0: Starting!\" << ENDL();\n";
// Reader params from kernel args
uint32_t total_args = 0;
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto connection = incoming_connections[i];
auto port = kernel->get_input_port(connection.dest.port);
if (connection.source.is_stream()) {
if (streams[connection.source.index]->is_gather_stream()) {
auto *gather_stream = dynamic_cast<GatherStream*>(streams[connection.source.index]);
// Kernel args for index data should be the same as a normal stream.
rs << " uint32_t " << port.name << "_ntiles = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER0: " << port.name << "_ntiles: \" << " << port.name << "_ntiles << ENDL();\n";
total_args++;
// For every incoming stream connection, we need to get it's address and create an address generator.
rs << " uint32_t " << port.name << "_addr = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER0: " << port.name << "_addr: \" << " << port.name << "_addr << ENDL();\n";
total_args++;
// TODO: Handle offsets for gather streams.
rs << " uint32_t " << port.name << "_tile_offset = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER0: " << port.name << "_tile_offset: \" << " << port.name << "_tile_offset << ENDL();\n";
total_args++;
rs << " uint32_t " << port.name << "_noc_x = get_arg_val<uint32_t>(" << total_args << ");\n";
total_args++;
rs << " uint32_t " << port.name << "_noc_y = get_arg_val<uint32_t>(" << total_args << ");\n";
total_args++;
rs << " uint64_t " << port.name << "_noc_addr = get_noc_addr(" << port.name << "_noc_x, " << port.name << "_noc_y, " << port.name << "_addr);\n";
rs << "\n";
// Address generator.
// TODO: Do we need this? How does this even work?
// rs << " const InterleavedAddrGenFast<true> " << port.name << "_addr_gen = {\n";
// rs << " .bank_base_address = " << port.name << "_addr, \n";
// rs << " .page_size = " << std::to_string(TILE_WIDTH * TILE_HEIGHT * gather_stream->element_size * gather_stream->accesses_per_token) << ", \n";
// rs << " .data_format = " << data_format_to_string(gather_stream->format) << ", \n";
// rs << " };\n\n";
// Args for data buffer of gather stream. This doesn't use an address generator since we are doing non-contiguous reads.
rs << " uint32_t " << port.name << "_data_addr = get_arg_val<uint32_t>(" << total_args << ");\n";
total_args++;
rs << " uint32_t " << port.name << "_data_dram_noc_x = get_arg_val<uint32_t>(" << total_args << ");\n";
total_args++;
rs << " uint32_t " << port.name << "_data_dram_noc_y = get_arg_val<uint32_t>(" << total_args << ");\n";
total_args++;
rs << " uint64_t " << port.name << "_data_dram_noc_addr = get_noc_addr(" << port.name << "_data_dram_noc_x, " << port.name << "_data_dram_noc_y, " << port.name << "_data_addr);\n";
rs << "\n";
} else {
auto *stream = streams[connection.source.index];
// Total # of tiles this kernel will read from this stream.
// Stream -> Kernel, get the input port index.
rs << " uint32_t " << port.name << "_ntiles = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER0: " << port.name << "_ntiles: \" << " << port.name << "_ntiles << ENDL();\n";
total_args++;
// For every incoming stream connection, we need to get it's address and create an address generator.
rs << " uint32_t " << port.name << "_addr = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER0: " << port.name << "_addr: \" << " << port.name << "_addr << ENDL();\n";
total_args++;
rs << " uint32_t " << port.name << "_tile_offset = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER0: " << port.name << "_tile_offset: \" << " << port.name << "_tile_offset << ENDL();\n";
total_args++;
// Address generator.
// TODO: Do we need this? How does this even work?
rs << " const InterleavedAddrGenFast<true> " << port.name << "_addr_gen = {\n";
rs << " .bank_base_address = " << port.name << "_addr, \n";
rs << " .page_size = " << TILE_WIDTH * TILE_HEIGHT * stream->element_size << ", \n";
rs << " .data_format = " << data_format_to_string(stream->format) << ", \n";
rs << " };\n\n";
}
} else {
rs << " uint32_t " << port.name << "_ntiles = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER1: " << port.name << "_ntiles: \" << " << port.name << "_ntiles << ENDL();\n";
total_args++;
rs << " uint32_t " << port.name << "_sender_noc_x = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER1: " << port.name << "_sender_noc_x: \" << " << port.name << "_sender_noc_x << ENDL();\n";
total_args++;
rs << " uint32_t " << port.name << "_sender_noc_y = get_arg_val<uint32_t>(" << total_args << ");\n";
rs << " DPRINT << \"READER1: " << port.name << "_sender_noc_y: \" << " << port.name << "_sender_noc_y << ENDL();\n";
total_args++;
rs << " uint32_t " << port.name << "_sender_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(" << total_args << "));\n";
rs << " DPRINT << \"READER1: " << port.name << "_sender_semaphore_addr: \" << " << port.name << "_sender_semaphore_addr << ENDL();\n";
total_args++;
rs << " uint32_t " << port.name << "_receiver_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(" << total_args << "));\n";
rs << " DPRINT << \"READER1: " << port.name << "_receiver_semaphore_addr: \" << " << port.name << "_receiver_semaphore_addr << ENDL();\n";
total_args++;
// rs << " uint32_t " << port.name << "_l1_valid_value_semaphore_id = get_arg_val<uint32_t>(" << total_args << ");\n";
// total_args++;
rs << " volatile tt_l1_ptr uint32_t* " << port.name << "_receiver_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(" << port.name << "_receiver_semaphore_addr);\n";
rs << " volatile tt_l1_ptr uint32_t* " << port.name << "_sender_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(" << port.name << "_sender_semaphore_addr);\n";
rs << " uint64_t " << port.name << "_sender_semaphore_noc_addr = get_noc_addr(" << port.name << "_sender_noc_x, " << port.name << "_sender_noc_y, " << port.name << "_sender_semaphore_addr);\n";
rs << " DPRINT << \"READER1: " << port.name << "_sender_semaphore_noc_addr: \" << " << port.name << "_sender_semaphore_noc_addr << ENDL();\n";
rs << "\n";
}
}
// Circular buffers.
uint32_t num_input_cbs = IN_CB_START;
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto connection = incoming_connections[i];
auto port = kernel->get_input_port(connection.dest.port);
if (connection.source.is_stream() && streams[connection.source.index]->is_gather_stream()) {
// Input is a gather stream, generate input CB for each access per token.
// Generate one intermed CB.
auto *stream = dynamic_cast<GatherStream*>(streams[connection.source.index]);
// The data will be put in an "input" CB, while the indices will be staged in an intermed CB.
rs << " constexpr uint32_t " << port.name << "_indices = " << INTERMED_CB_START + num_input_cbs << ";\n";
rs << " DPRINT << \"READER0: " << port.name << "_indices tile_size: \" << get_tile_size(" << port.name << "_indices" << ") << ENDL();\n";
// The intermed CB is solely used as a staging buffer for the indices, so we can fetch the write_ptr at the start.
rs << " uint32_t " << port.name << "_indices_write_ptr = get_write_ptr(" << port.name << "_indices);\n";
for (size_t access = 0; access < stream->accesses_per_token; access += 1) {
// Assign CBs to input ports in iteration order.
std::string in_cb_name = port.name + "_" + std::to_string(num_input_cbs);
rs << " constexpr uint32_t " << in_cb_name << " = " << num_input_cbs << ";\n";
rs << " DPRINT << \"READER0: " << in_cb_name << " tile_size: \" << get_tile_size(" << num_input_cbs << ") << ENDL();\n";
num_input_cbs++;
}
} else {
// Assign CBs to input ports in iteration order.
rs << " constexpr uint32_t " << port.name << " = " << num_input_cbs << ";\n";
rs << " DPRINT << \"READER0: " << port.name << " tile_size: \" << get_tile_size(" << port.name << ") << ENDL();\n";
num_input_cbs++;
}
}
rs << "\n";
if (incoming_connections.size() > 0) {
// Input tile stream loop.
for (size_t i = 0; i < incoming_connections.size(); i++) {
// Generate a counter variable for every incoming connection.
// Each incoming connection maps to a specific input port, which is managed by a CB.
auto connection = incoming_connections[i];
auto port = kernel->get_input_port(connection.dest.port);
rs << " uint32_t " << port.name << "_count = 0;\n";
}
// The break condition is when we've read the expected # of tiles from each input port.
// TODO: The only case when the incoming stream sizes would be different is if we are doing some sort of reduction on a stream
// e.g for each record of stream A we are dequining 4 records from stream B.
// In this case, we might be requesting/dequing N tiles at once. This needs to somehow be handled in the compute kernel.
// Right now even though we are acting as if the streams are different sizes, the compute kernel is still acting as if they are the same size.
std::string break_condition = "";
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto port = kernel->get_input_port(incoming_connections[i].dest.port);
break_condition += port.name + "_count < " + port.name + "_ntiles";
if (i != incoming_connections.size() - 1) {
break_condition += " && ";
}
}
// rs << " for(uint32_t i = 0; i < source0_n_tiles; i++) {\n";
rs << " while(" << break_condition << ") {\n";
// Wait for space in CBs
bool do_read_barrier = false;
num_input_cbs = IN_CB_START;
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto connection = incoming_connections[i];
if (!connection.source.is_stream()) {
continue;
}
auto port = kernel->get_input_port(incoming_connections[i].dest.port);
if (!streams[connection.source.index]->is_gather_stream()) {
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
rs << " cb_reserve_back(" << port.name << ", 1);\n";
rs << " }\n";
do_read_barrier = true;
num_input_cbs++;
} else {
auto *stream = dynamic_cast<GatherStream *>(streams[connection.source.index]);
// For each gather stream, we wait on an input cb per access.
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
for (size_t access = 0; access < stream->accesses_per_token; access++) {
std::string in_cb_name = port.name + "_" + std::to_string(num_input_cbs);
rs << " cb_reserve_back(" << in_cb_name << ", 1);\n";
do_read_barrier = true;
num_input_cbs++;
}
rs << " }\n";
}
}
// Read tile into CB from DRAM.
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto connection = incoming_connections[i];
if (!connection.source.is_stream()) {
continue;
}
auto *stream = streams[connection.source.index];
auto port = kernel->get_input_port(connection.dest.port);
if (!stream->is_gather_stream()) {
// Not a gather stream, just fetch stream data from DRAM to CB.
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
rs << " uint32_t " << port.name << "_write_ptr = get_write_ptr(" << port.name << ");\n";
rs << " uint32_t id = " << port.name << "_tile_offset + " << port.name << "_count;\n";
rs << " DPRINT << \"READER0: id: \" << id << ENDL();\n";
rs << " noc_async_read_tile(id, " << port.name << "_addr_gen, " << port.name << "_write_ptr);\n";
rs << " }\n";
} else {
auto *gather_stream = dynamic_cast<GatherStream*>(stream);
// Gather stream, fetch index data from DRAM to intermed CB
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
rs << " uint32_t id = "<< port.name << "_tile_offset + " << port.name << "_count;\n";
rs << " DPRINT << \"READER0: id: \" << id << ENDL();\n";
// rs << " noc_async_read_tile(id, " << port.name << "_addr_gen, " << port.name << "_indices_write_ptr);\n";
auto tile_size_bytes = TILE_SIZE * gather_stream->accesses_per_token * 4;
rs << " noc_async_read(" << port.name << "_noc_addr + (id * " << tile_size_bytes << ")" << ", " << port.name << "_indices_write_ptr, " << tile_size_bytes << ");\n"; // TODO: Don't hardcode offset and size values.
rs << " }\n";
}
}
// Wait until tile reads are done.
// TODO: Don't do if not reading from stream.
rs << "\n";
if (do_read_barrier) {
rs << " noc_async_read_barrier();\n";
do_read_barrier = false;
}
rs << "\n";
// Process gather stream read requests from index CBs.
num_input_cbs = IN_CB_START;
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto connection = incoming_connections[i];
if (!connection.source.is_stream()) {
num_input_cbs++;
continue;
}
auto *stream = streams[connection.source.index];
if (!stream->is_gather_stream()) {
num_input_cbs++;
continue;
}
auto *gather_stream = dynamic_cast<GatherStream*>(stream);
if (!gather_stream->use_sram) {
do_read_barrier = true;
auto port = kernel->get_input_port(connection.dest.port);
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
// auto tmp_n_in_cbs = num_input_cbs;
// for (size_t access = 0; access < gather_stream->accesses_per_token; access++) {
// std::string in_cb_name = port.name + "_" + std::to_string(tmp_n_in_cbs);
// rs << " uint32_t " << in_cb_name << "_write_ptr = get_write_ptr(" << in_cb_name << ");\n";
// tmp_n_in_cbs++;
// }
// rs << " uint32_t index;\n";
// rs << " for (int i = 0; i < " << TILE_SIZE << "; i++) {\n";
// for (size_t access = 0; access < gather_stream->accesses_per_token; access++) {
// std::string in_cb_name = port.name + "_" + std::to_string(num_input_cbs);
// rs << " index = *(((uint32_t *)" << port.name << "_indices_write_ptr) + i + " << access << ") * 32;\n";
// // rs << " DPRINT << \"[READER 0] index: \" << index << ENDL();\n";
// rs << " uint32_t " << in_cb_name << "_offset = i * " << datum_size(gather_stream->data_format) << ";\n";
// rs << " noc_async_read(" << port.name << "_data_dram_noc_addr + index, " << in_cb_name << "_write_ptr + " << in_cb_name << "_offset, " << datum_size(gather_stream->data_format) << ");\n";
// // rs << " DPRINT << \"[READER 0] data: \" << float(*(((uint16_t *)" << port.name << "_write_ptr) + i)) << ENDL();\n";
// num_input_cbs++;
// }
// rs << " }\n";
// First generate write pointer declarations
auto tmp_n_in_cbs = num_input_cbs;
for (size_t access = 0; access < gather_stream->accesses_per_token; access++) {
std::string in_cb_name = port.name + "_" + std::to_string(tmp_n_in_cbs);
rs << " uint32_t " << in_cb_name << "_write_ptr = get_write_ptr(" << in_cb_name << ");\n";
tmp_n_in_cbs++;
}
rs << " uint32_t index;\n";
rs << " for (int i = 0; i < " << TILE_SIZE << "; i++) {\n";
// For each access, read from its interleaved position
for (size_t access = 0; access < gather_stream->accesses_per_token; access++) {
std::string in_cb_name = port.name + "_" + std::to_string(num_input_cbs);
// Calculate interleaved index position: i * accesses_per_token + access
rs << " index = *(((uint32_t *)" << port.name << "_indices_write_ptr) + (i * " << std::to_string(gather_stream->accesses_per_token) << " + " << access << ")) * 32;\n";
rs << " uint32_t " << in_cb_name << "_offset = i * " << datum_size(gather_stream->data_format) << ";\n";
rs << " noc_async_read(" << port.name << "_data_dram_noc_addr + index, "
<< in_cb_name << "_write_ptr + " << in_cb_name << "_offset, "
<< datum_size(gather_stream->data_format) << ");\n";
num_input_cbs++;
}
rs << " }\n";
rs << " }\n";
} else {
auto port = kernel->get_input_port(connection.dest.port);
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
auto tmp_n_in_cbs = num_input_cbs;
for (size_t access = 0; access < gather_stream->accesses_per_token; access++) {
std::string in_cb_name = port.name + "_" + std::to_string(tmp_n_in_cbs);
rs << " uint16_t *" << in_cb_name << "_write_ptr = (uint16_t *)get_write_ptr(" << in_cb_name << ");\n";
tmp_n_in_cbs++;
}
rs << " uint16_t *" << port.name << "_sram_ptr = (uint16_t *)" << port.name << "_data_addr;\n";
rs << " uint32_t index;\n";
rs << " for (int i = 0; i < " << TILE_SIZE << "; i++) {\n";
// For each access.
for (size_t access = 0; access < gather_stream->accesses_per_token; access++) {
std::string in_cb_name = port.name + "_" + std::to_string(num_input_cbs);
// Calculate interleaved index position: i * accesses_per_token + access
rs << " index = *(((uint32_t *)" << port.name << "_indices_write_ptr) + (i * " << std::to_string(gather_stream->accesses_per_token) << " + " << access << "));\n";
// rs << " index = *(((uint32_t *)" << port.name << "_indices_write_ptr) + i);\n";
// rs << " DPRINT << \"[READER 0] index: \" << index << ENDL();\n";
// rs << " uint32_t " << port.name << "_offset = i * " << datum_size(gather_stream->data_format) << ";\n";
rs << " " << in_cb_name << "_write_ptr[i] = " << port.name << "_sram_ptr[index];\n";
// rs << " noc_async_read(" << port.name << "_data_dram_noc_addr + index, " << port.name << "_write_ptr + " << port.name << "_offset, " << datum_size(gather_stream->data_format) << ");\n";
// rs << " DPRINT << \"[READER 0] data: \" << float(*(((uint16_t *)" << port.name << "_write_ptr) + i)) << ENDL();\n";
num_input_cbs++;
}
rs << " }\n";
rs << " }\n";
}
}
rs << "\n";
if (do_read_barrier) {
rs << " noc_async_read_barrier();\n";
do_read_barrier = false;
}
rs << "\n";
// Push tiles into CBs and increment counters.
// Signals to compute engine that a tile is ready to be processed.
num_input_cbs = IN_CB_START;
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto connection = incoming_connections[i];
if (!connection.source.is_stream()) {
num_input_cbs++;
continue;
}
if (!streams[connection.source.index]->is_gather_stream()) {
auto port = kernel->get_input_port(incoming_connections[i].dest.port);
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
rs << " cb_push_back(" << port.name << ", 1);\n";
rs << " " << port.name << "_count++;\n";
rs << " }\n";
num_input_cbs++;
} else {
auto port = kernel->get_input_port(incoming_connections[i].dest.port);
auto *stream = dynamic_cast<GatherStream*>(streams[connection.source.index]);
rs << " if (" << port.name << "_count < " << port.name << "_ntiles) {\n";
for (size_t access = 0; access < stream->accesses_per_token; access++) {
std::string in_cb_name = port.name + "_" + std::to_string(num_input_cbs);
rs << " cb_push_back(" << in_cb_name << ", 1);\n";
num_input_cbs++;
}
rs << " " << port.name << "_count++;\n";
rs << " }\n";
}
}
// Do receiver stuff.
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto connection = incoming_connections[i];
if (connection.source.is_stream()) {
continue;
}
auto port = kernel->get_input_port(incoming_connections[i].dest.port);
// Wait for space in CB.
rs << " DPRINT << \"READER1: Waiting for space in CB " << port.name << "\" << ENDL();\n";
rs << " cb_reserve_back(" << port.name << ", 1);\n";
// Reset receiver's own semaphore value to INVALID.
rs << " noc_semaphore_set(" << port.name << "_receiver_semaphore_addr_ptr, INVALID);\n";
// Tell sender we're ready -- atomic increment sender's semaphore.
rs << " DPRINT << \"READER1: Telling sender we're ready\" << ENDL();\n";
// rs << " noc_semaphore_inc(" << port.name << "_sender_semaphore_noc_addr, 1);\n";
rs << " *(" << port.name << "_sender_semaphore_addr_ptr) = get_write_ptr(" << port.name << ");\n";
// rs << " DPRINT << \"READER1: in cb ptr: << \" *(" << port.name << "_sender_semaphore_addr_ptr) << ENDL();\n";
rs << " noc_semaphore_set_remote(" << port.name << "_sender_semaphore_addr, " << port.name << "_sender_semaphore_noc_addr);\n";
// Wait on receiver's own semaphore value to become VALID (set by sender after it sends the data).
rs << " DPRINT << \"READER1: Waiting on receiver's semaphore\" << ENDL();\n";
rs << " noc_semaphore_wait(" << port.name << "_receiver_semaphore_addr_ptr, VALID);\n";
rs << " DPRINT << \"READER1: Receiver's semaphore is VALID!\" << ENDL();\n";
// Push tile into CB.
rs << " DPRINT << \"READER1: Pushing tile into CB\" << ENDL();\n";
rs << " cb_push_back(" << port.name << ", 1);\n";
// Increment counter.
rs << " " << port.name << "_count++;\n";
}
rs << " }\n";
// End tile stream loop.
rs << " DPRINT << \"READER0: Done!\" << ENDL();\n";
}
rs << "}\n";
rs << "\n";
std::string filename = "reader" + std::to_string(get_kernel_index(kernel)) + ".cpp";
kernel->generated_reader_kernel_path = GENERATED_KERNELS_PATH / filename;
auto reader_kernel_file = std::ofstream(kernel->generated_reader_kernel_path);
if (!reader_kernel_file.is_open()) {
tt::log_error("[CURRENT] Failed to open file for writing: {}", kernel->generated_reader_kernel_path);
exit(1);
}
reader_kernel_file << rs.str();
reader_kernel_file.close();
}
void Map::generate_compute_device_kernel(
Kernel *kernel,
std::vector<Connection> incoming_connections,
std::vector<Connection> outgoing_connections
) {
std::stringstream cs;
// Includes
cs << "#include \"compute_kernel_api/common.h\"\n";
cs << "#include \"compute_kernel_api/tile_move_copy.h\"\n";
cs << "#include \"compute_kernel_api/eltwise_binary.h\"\n";
cs << "#include \"compute_kernel_api/eltwise_unary/eltwise_unary.h\"\n";
cs << "#include \"compute_kernel_api/matmul.h\"\n";
cs << "#include \"compute_kernel_api.h\"\n";
cs << "#include \"cmath_common.h\"\n";
cs << "#include \"sfpi.h\"\n";
cs << "#include \"debug/dprint.h\"\n";
cs << "\n";
// SFPU computation
cs << "namespace sfpi {\n";
// cs << "template< int ITERATIONS = 16 >\n";
cs << "sfpi_inline void compute() {\n";
// Set destination write address.
cs << " math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(0);\n";
// If we don't have a specifed compute kernel, don't generate anything.
if (!kernel->sfpi_kernel_string.empty()) {
// TODO: Do a better optimization if we don't have a compute kernel.
// Can probably avoid any call to the sfpi function, don't need to do sfpi init? idk
cs << " for (int i = 0; i < 32; i++) {\n";
// Get input variables.
auto total_incoming = 0;
for (size_t i = 0; i < incoming_connections.size(); i++) {
auto conn = incoming_connections[i];
if (conn.source.is_stream() && streams[conn.source.index]->is_gather_stream()) {
auto *stream = dynamic_cast<GatherStream*>(streams[conn.source.index]);
for (size_t access = 0; access < stream->accesses_per_token; access++) {
cs << " vFloat in" << std::to_string(total_incoming) << " = dst_reg[" << std::to_string(total_incoming) << " * 32 + i];\n";
total_incoming++;
}
} else {
cs << " vFloat in" << std::to_string(total_incoming) << " = dst_reg[" << std::to_string(total_incoming) << " * 32 + i];\n";
total_incoming++;
}
}
// Declare output variables.
for (size_t i = 0; i < outgoing_connections.size(); i++) {
cs << " vFloat out" << i << ";\n";
}
cs << kernel->sfpi_kernel_string;
// Assign output variables.
for (size_t i = 0; i < outgoing_connections.size(); i++) {
cs << " dst_reg[" << i << " * 32 + i] = out" << i << ";\n";
}
cs << " }\n";
}
// cs << " for (int i = 0; i < ITERATIONS; i++) {\n";
// cs << " vFloat in = dst_reg[i];\n";
// cs << " vFloat a = in + 1.0f;\n";
// cs << " vFloat out = a;\n";
// cs << " dst_reg[i] = out;\n";
// cs << " }\n";
cs << "}\n";
cs << "}\n";
cs << "\n";
// Main function.
cs << "namespace NAMESPACE {\n";
cs << "void MAIN {\n";