vec_cmpv_xx
功能说明
逐element比较两个tensor大小,如果为真则对应比特位为1,否则为0,支持多种比较模式。
函数原型
vec_cmpv_xx (dst, src0, src1, repeat_times, src0_rep_stride, src1_rep_stride)
PIPE:Vector
参数说明
参数名称 |
输入/输出 |
含义 |
|---|---|---|
instruction |
输入 |
指令名称,支持以下几种比较:
|
dst |
输出 |
目的操作数,tensor中起始element,支持uint64, uint32, uint16, uint8。 Tensor的scope为Unified Buffer。 |
src0 |
输入 |
源操作数0,tensor中起始element。 Tensor的scope为Unified Buffer。 Atlas 200/300/500 推理产品,支持的数据类型为:Tensor(float16)。 Atlas 训练系列产品,支持的数据类型为:Tensor(float16/float32)。 Atlas推理系列产品AI Core,支持的数据类型为:Tensor(float16/float32)。 Atlas推理系列产品Vector Core,支持的数据类型为:Tensor(float16/float32)。 Atlas A2训练系列产品/Atlas 800I A2推理产品,src0/src1支持的数据类型为:Tensor(float16/float32)。当指令为vec_cmpv_eq,src0/src1可以支持Tensor(float16/float32/int32)。 Atlas 200/500 A2推理产品,src0/src1支持的数据类型为:Tensor(float16/float32)。当指令为vec_cmpv_eq,src0/src1可以支持Tensor(uint8/int8/float16/float32/int32)。 |
src1 |
输入 |
源操作数1,tensor中起始element。 Tensor的scope为Unified Buffer。 数据类型需保证与src0类型一致。 |
repeat_times |
输入 |
重复迭代次数。
|
src0_rep_stride |
输入 |
相邻迭代间,源操作数0相同block地址步长。 |
src1_rep_stride |
输入 |
相邻迭代间,源操作数1相同block地址步长。 |
返回值
无。
支持的型号
Atlas 200/300/500 推理产品
Atlas 训练系列产品
Atlas推理系列产品AI Core
Atlas推理系列产品Vector Core
Atlas A2训练系列产品/Atlas 800I A2推理产品
Atlas 200/500 A2推理产品
注意事项
- 无mask参数。
- dst连续产生。比如当源操作数为float16,目的操作数为uint16时,相邻迭代间dst跳8个elements;当源操作数float32,目的操作数为uint16时,跳4个elements。
- src0_rep_stride/src1_rep_stride
;单位:block_size ;支持的数据类型为:Scalar(int16/int32/int64/uint16/uint32/uint64)、立即数(int)、Expr(int16/int32/int64/uint16/uint32/uint64)。 - 为了节省地址空间,开发者可以定义一个Tensor,供源操作数与目的操作数同时使用(即地址重叠),相关约束如下:
- 对于单次repeat(repeat_times=1),且源操作数与目的操作数之间要求100%完全重叠,不支持部分重叠。
- 对于多次repeat(repeat_times>1),若源操作数与目的操作数之间存在依赖,即第N次迭代的目的操作数是第N+1次的源操作数,这种情况是不支持地址重叠的。
- 操作数地址偏移对齐要求请见通用约束。
调用示例
- 调用示例1
from tbe import tik
tik_instance = tik.Tik()
src0_gm = tik_instance.Tensor("float16", (128,), name="src0_gm", scope=tik.scope_gm)
src1_gm = tik_instance.Tensor("float16", (128,), name="src1_gm", scope=tik.scope_gm)
src0_ub = tik_instance.Tensor("float16", (128,), name="src0_ub", scope=tik.scope_ubuf)
src1_ub = tik_instance.Tensor("float16", (128,), name="src1_ub", scope=tik.scope_ubuf)
dst_gm = tik_instance.Tensor("uint16", (16,), name="dst_gm", scope=tik.scope_gm)
dst_ub = tik_instance.Tensor("uint16", (16,), name="dst_ub", scope=tik.scope_ubuf)
# 拷贝用户输入数据到src ubuf
tik_instance.data_move(src0_ub, src0_gm, 0, 1, 8, 0, 0)
tik_instance.data_move(src1_ub, src1_gm, 0, 1, 8, 0, 0)
# 将dst_ub初始化为全5
tik_instance.vec_dup(16, dst_ub, 5, 1, 1)
tik_instance.vec_cmpv_eq(dst_ub, src0_ub, src1_ub, 1, 8, 8)
# 将计算结果拷贝到目标gm
tik_instance.data_move(dst_gm, dst_ub, 0, 1, 1, 0, 0)
tik_instance.BuildCCE(kernel_name="vec_cmpv_eq", inputs=[src0_gm, src1_gm], outputs=[dst_gm])
结果示例:
输入数据(float16):
src0_gm = {1,2,3,...,128}
src1_gm = {2,2,2,...,2}
输出结果:
dst_gm = {2,0,0,0,0,0,0,0,5,5,5,5,5,5,5,5}
- 调用示例2
"""
将两组各256个源操作数,经过指令经vec_cmpv_gt处理,处理得结果前一半src0数据与src1数据相等,后一半数据src0数据大于src1
"""
from tbe import tik
tik_instance = tik.Tik()
dtype_size = {
"int8": 1,
"uint8": 1,
"int16": 2,
"uint16": 2,
"float16": 2,
"int32": 4,
"uint32": 4,
"float32": 4,
"int64": 8,
}
src_shape = (2, 128)
dst_shape = (16, )
src_dtype = "float16"
dst_dtype = "uint16"
elements = 2 * 128
# 迭代次数,当前示例进行了2次迭代,可根据需要调整对应的迭代次数
repeat_times = 2
# 迭代间目的操作数前一次repeat头与后一次repeat头之间的距离,单位32B, src0 间隔8个block,src1间隔7个block,所以第二次迭代src0数据都大于src1
src0_rep_stride = 8
src1_rep_stride = 7
src0_gm = tik_instance.Tensor(src_dtype, src_shape, name="src0_gm", scope=tik.scope_gm)
src1_gm = tik_instance.Tensor(src_dtype, src_shape, name="src1_gm", scope=tik.scope_gm)
dst_gm = tik_instance.Tensor(dst_dtype, dst_shape, name="dst_gm", scope=tik.scope_gm)
src0_ub = tik_instance.Tensor(src_dtype, src_shape, name="src0_ub", scope=tik.scope_ubuf)
src1_ub = tik_instance.Tensor(src_dtype, src_shape, name="src1_ub", scope=tik.scope_ubuf)
dst_ub = tik_instance.Tensor(dst_dtype, dst_shape, name="dst_ub", scope=tik.scope_ubuf)
# 搬移的片段数
nburst = 1
# 每次搬运的片段长度,单位32B
burst = elements * dtype_size[src_dtype] // 32 // nburst
# 前burst尾与后burst头的距离,单位32B
dst_stride, src_stride = 0, 0
# 拷贝用户输入数据到src ubuf
tik_instance.data_move(src0_ub, src0_gm, 0, nburst, burst, src_stride, dst_stride)
tik_instance.data_move(src1_ub, src1_gm, 0, nburst, burst, src_stride, dst_stride)
tik_instance.vec_cmpv_gt(dst_ub, src0_ub, src1_ub, repeat_times, src0_rep_stride, src1_rep_stride)
# 将计算结果拷贝到目标gm
tik_instance.data_move(dst_gm, dst_ub, 0, nburst, 1, src_stride, dst_stride)
tik_instance.BuildCCE(kernel_name="vec_cmpv_gt", inputs=[src0_gm, src1_gm], outputs=[dst_gm])
示例结果
输入数据(src0_gm):
[[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13.
14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.
28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41.
42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55.
56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69.
70. 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83.
84. 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97.
98. 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.
112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.
126. 127.]
[128. 129. 130. 131. 132. 133. 134. 135. 136. 137. 138. 139. 140. 141.
142. 143. 144. 145. 146. 147. 148. 149. 150. 151. 152. 153. 154. 155.
156. 157. 158. 159. 160. 161. 162. 163. 164. 165. 166. 167. 168. 169.
170. 171. 172. 173. 174. 175. 176. 177. 178. 179. 180. 181. 182. 183.
184. 185. 186. 187. 188. 189. 190. 191. 192. 193. 194. 195. 196. 197.
198. 199. 200. 201. 202. 203. 204. 205. 206. 207. 208. 209. 210. 211.
212. 213. 214. 215. 216. 217. 218. 219. 220. 221. 222. 223. 224. 225.
226. 227. 228. 229. 230. 231. 232. 233. 234. 235. 236. 237. 238. 239.
240. 241. 242. 243. 244. 245. 246. 247. 248. 249. 250. 251. 252. 253.
254. 255.]]
输入数据(src1_gm):
[[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13.
14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.
28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41.
42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55.
56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69.
70. 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83.
84. 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97.
98. 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.
112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.
126. 127.]
[128. 129. 130. 131. 132. 133. 134. 135. 136. 137. 138. 139. 140. 141.
142. 143. 144. 145. 146. 147. 148. 149. 150. 151. 152. 153. 154. 155.
156. 157. 158. 159. 160. 161. 162. 163. 164. 165. 166. 167. 168. 169.
170. 171. 172. 173. 174. 175. 176. 177. 178. 179. 180. 181. 182. 183.
184. 185. 186. 187. 188. 189. 190. 191. 192. 193. 194. 195. 196. 197.
198. 199. 200. 201. 202. 203. 204. 205. 206. 207. 208. 209. 210. 211.
212. 213. 214. 215. 216. 217. 218. 219. 220. 221. 222. 223. 224. 225.
226. 227. 228. 229. 230. 231. 232. 233. 234. 235. 236. 237. 238. 239.
240. 241. 242. 243. 244. 245. 246. 247. 248. 249. 250. 251. 252. 253.
254. 255.]]
输出数据(dst_gm):
[ 0 0 0 0 0 0 0 0 65535 65535 65535 65535
65535 65535 65535 65535]