-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathrope.cc
More file actions
84 lines (69 loc) · 3.22 KB
/
rope.cc
File metadata and controls
84 lines (69 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#include <aie_api/aie.hpp>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
template <typename T, int N>
void rope_kernel_interleaved(const T *restrict input, const T *restrict lut, T *restrict output, int32_t dims)
{
event0();
for (int v = 0; v < dims; v += N) {
::aie::vector<T, N> x = ::aie::load_v<N>(input + v);
::aie::vector<T, N> cache = ::aie::load_v<N>(lut + v);
// Extract even and odd elements
::aie::vector<T, N / 2> x_even = ::aie::filter_even(x, 1);
::aie::vector<T, N / 2> x_odd = ::aie::filter_odd(x, 1);
::aie::vector<T, N / 2> cos_val = ::aie::filter_even(cache, 1);
::aie::vector<T, N / 2> sin_val = ::aie::filter_odd(cache, 1);
// Perform ROPE calculations
::aie::vector<T, N / 2> even_cos = ::aie::mul(x_even, cos_val);
::aie::vector<T, N / 2> even_sin = ::aie::mul(x_even, sin_val);
::aie::vector<T, N / 2> odd_cos = ::aie::mul(x_odd, cos_val);
::aie::vector<T, N / 2> odd_sin = ::aie::mul(x_odd, sin_val);
::aie::vector<T, N / 2> output_even = ::aie::sub(even_cos, odd_sin);
::aie::vector<T, N / 2> output_odd = ::aie::add(even_sin, odd_cos);
auto [low, high] = ::aie::interleave_zip(output_even, output_odd, 1);
::aie::vector<T, N> y = ::aie::concat(low, high);
::aie::store_v(output + v, y);
}
event1();
}
template <typename T, int N>
void rope_kernel_two_halves(const T *restrict input, const T *restrict lut, T *restrict output, int32_t dims)
{
event0();
auto dims_half = dims / 2;
for (int v = 0, i = 0; v < dims_half; v += N, i += 2 * N) {
// Extract the two halves of inputs
::aie::vector<T, N> x1 = ::aie::load_v<N>(input + v);
::aie::vector<T, N> x2 = ::aie::load_v<N>(input + v + dims_half);
::aie::vector<T, 2 * N> cache = ::aie::load_v<2 * N>(lut + i);
// Extract angles
::aie::vector<T, N> cos_val = ::aie::filter_even(cache, 1);
::aie::vector<T, N> sin_val = ::aie::filter_odd(cache, 1);
// Perform ROPE calculations for the first half: x1*cos-x2*sin
::aie::vector<T, N> x1_cos = ::aie::mul(x1, cos_val);
::aie::vector<T, N> x2_sin = ::aie::mul(x2, sin_val);
::aie::vector<T, N> y_first_half = ::aie::sub(x1_cos, x2_sin);
::aie::store_v(output + v, y_first_half);
// Perform ROPE calculations for the second half: x2*cos+x1*sin
::aie::vector<T, N> x2_cos = ::aie::mul(x2, cos_val);
::aie::vector<T, N> x1_sin = ::aie::mul(x1, sin_val);
::aie::vector<T, N> y_second_half = ::aie::add(x2_cos, x1_sin);
::aie::store_v(output + v + dims_half, y_second_half);
}
event1();
}
extern "C" {
void rope(bfloat16 *input, bfloat16 *lut, bfloat16 *output, int32_t dims)
{
#if defined(TWO_HALVES)
rope_kernel_two_halves<bfloat16, 16>(input, lut, output, dims); // For the two-halves method used in HF transformers
#elif defined(INTERLEAVED)
rope_kernel_interleaved<bfloat16, 16>(
input, lut, output, dims); // For the interleaved method used in the Llama paper
#endif
}
}