OLD | NEW |
1 /* | 1 /* |
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 | 10 |
(...skipping 49 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
60 // | 60 // |
61 // If you pass in storage through the ctor, that storage is copied into the | 61 // If you pass in storage through the ctor, that storage is copied into the |
62 // matrix. TODO(claguna): albeit tricky, allow for data to be referenced | 62 // matrix. TODO(claguna): albeit tricky, allow for data to be referenced |
63 // instead of copied, and owned by the user. | 63 // instead of copied, and owned by the user. |
64 template <typename T> | 64 template <typename T> |
65 class Matrix { | 65 class Matrix { |
66 public: | 66 public: |
67 Matrix() : num_rows_(0), num_columns_(0) {} | 67 Matrix() : num_rows_(0), num_columns_(0) {} |
68 | 68 |
69 // Allocates space for the elements and initializes all values to zero. | 69 // Allocates space for the elements and initializes all values to zero. |
70 Matrix(int num_rows, int num_columns) | 70 Matrix(size_t num_rows, size_t num_columns) |
71 : num_rows_(num_rows), num_columns_(num_columns) { | 71 : num_rows_(num_rows), num_columns_(num_columns) { |
72 Resize(); | 72 Resize(); |
73 scratch_data_.resize(num_rows_ * num_columns_); | 73 scratch_data_.resize(num_rows_ * num_columns_); |
74 scratch_elements_.resize(num_rows_); | 74 scratch_elements_.resize(num_rows_); |
75 } | 75 } |
76 | 76 |
77 // Copies |data| into the new Matrix. | 77 // Copies |data| into the new Matrix. |
78 Matrix(const T* data, int num_rows, int num_columns) | 78 Matrix(const T* data, size_t num_rows, size_t num_columns) |
79 : num_rows_(0), num_columns_(0) { | 79 : num_rows_(0), num_columns_(0) { |
80 CopyFrom(data, num_rows, num_columns); | 80 CopyFrom(data, num_rows, num_columns); |
81 scratch_data_.resize(num_rows_ * num_columns_); | 81 scratch_data_.resize(num_rows_ * num_columns_); |
82 scratch_elements_.resize(num_rows_); | 82 scratch_elements_.resize(num_rows_); |
83 } | 83 } |
84 | 84 |
85 virtual ~Matrix() {} | 85 virtual ~Matrix() {} |
86 | 86 |
87 // Deep copy an existing matrix. | 87 // Deep copy an existing matrix. |
88 void CopyFrom(const Matrix& other) { | 88 void CopyFrom(const Matrix& other) { |
89 CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_); | 89 CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_); |
90 } | 90 } |
91 | 91 |
92 // Copy |data| into the Matrix. The current data is lost. | 92 // Copy |data| into the Matrix. The current data is lost. |
93 void CopyFrom(const T* const data, int num_rows, int num_columns) { | 93 void CopyFrom(const T* const data, size_t num_rows, size_t num_columns) { |
94 Resize(num_rows, num_columns); | 94 Resize(num_rows, num_columns); |
95 memcpy(&data_[0], data, num_rows_ * num_columns_ * sizeof(data_[0])); | 95 memcpy(&data_[0], data, num_rows_ * num_columns_ * sizeof(data_[0])); |
96 } | 96 } |
97 | 97 |
98 Matrix& CopyFromColumn(const T* const* src, | 98 Matrix& CopyFromColumn(const T* const* src, |
99 size_t column_index, | 99 size_t column_index, |
100 int num_rows) { | 100 size_t num_rows) { |
101 Resize(1, num_rows); | 101 Resize(1, num_rows); |
102 for (int i = 0; i < num_columns_; ++i) { | 102 for (size_t i = 0; i < num_columns_; ++i) { |
103 data_[i] = src[i][column_index]; | 103 data_[i] = src[i][column_index]; |
104 } | 104 } |
105 | 105 |
106 return *this; | 106 return *this; |
107 } | 107 } |
108 | 108 |
109 void Resize(int num_rows, int num_columns) { | 109 void Resize(size_t num_rows, size_t num_columns) { |
110 if (num_rows != num_rows_ || num_columns != num_columns_) { | 110 if (num_rows != num_rows_ || num_columns != num_columns_) { |
111 num_rows_ = num_rows; | 111 num_rows_ = num_rows; |
112 num_columns_ = num_columns; | 112 num_columns_ = num_columns; |
113 Resize(); | 113 Resize(); |
114 } | 114 } |
115 } | 115 } |
116 | 116 |
117 // Accessors and mutators. | 117 // Accessors and mutators. |
118 int num_rows() const { return num_rows_; } | 118 size_t num_rows() const { return num_rows_; } |
119 int num_columns() const { return num_columns_; } | 119 size_t num_columns() const { return num_columns_; } |
120 T* const* elements() { return &elements_[0]; } | 120 T* const* elements() { return &elements_[0]; } |
121 const T* const* elements() const { return &elements_[0]; } | 121 const T* const* elements() const { return &elements_[0]; } |
122 | 122 |
123 T Trace() { | 123 T Trace() { |
124 CHECK_EQ(num_rows_, num_columns_); | 124 CHECK_EQ(num_rows_, num_columns_); |
125 | 125 |
126 T trace = 0; | 126 T trace = 0; |
127 for (int i = 0; i < num_rows_; ++i) { | 127 for (size_t i = 0; i < num_rows_; ++i) { |
128 trace += elements_[i][i]; | 128 trace += elements_[i][i]; |
129 } | 129 } |
130 return trace; | 130 return trace; |
131 } | 131 } |
132 | 132 |
133 // Matrix Operations. Returns *this to support method chaining. | 133 // Matrix Operations. Returns *this to support method chaining. |
134 Matrix& Transpose() { | 134 Matrix& Transpose() { |
135 CopyDataToScratch(); | 135 CopyDataToScratch(); |
136 Resize(num_columns_, num_rows_); | 136 Resize(num_columns_, num_rows_); |
137 return Transpose(scratch_elements()); | 137 return Transpose(scratch_elements()); |
(...skipping 137 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
275 | 275 |
276 CopyDataToScratch(); | 276 CopyDataToScratch(); |
277 Resize(num_rows_, rhs.num_columns_); | 277 Resize(num_rows_, rhs.num_columns_); |
278 return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements()); | 278 return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements()); |
279 } | 279 } |
280 | 280 |
281 std::string ToString() const { | 281 std::string ToString() const { |
282 std::ostringstream ss; | 282 std::ostringstream ss; |
283 ss << std::endl << "Matrix" << std::endl; | 283 ss << std::endl << "Matrix" << std::endl; |
284 | 284 |
285 for (int i = 0; i < num_rows_; ++i) { | 285 for (size_t i = 0; i < num_rows_; ++i) { |
286 for (int j = 0; j < num_columns_; ++j) { | 286 for (size_t j = 0; j < num_columns_; ++j) { |
287 ss << elements_[i][j] << " "; | 287 ss << elements_[i][j] << " "; |
288 } | 288 } |
289 ss << std::endl; | 289 ss << std::endl; |
290 } | 290 } |
291 ss << std::endl; | 291 ss << std::endl; |
292 | 292 |
293 return ss.str(); | 293 return ss.str(); |
294 } | 294 } |
295 | 295 |
296 protected: | 296 protected: |
297 void SetNumRows(const int num_rows) { num_rows_ = num_rows; } | 297 void SetNumRows(const size_t num_rows) { num_rows_ = num_rows; } |
298 void SetNumColumns(const int num_columns) { num_columns_ = num_columns; } | 298 void SetNumColumns(const size_t num_columns) { num_columns_ = num_columns; } |
299 T* data() { return &data_[0]; } | 299 T* data() { return &data_[0]; } |
300 const T* data() const { return &data_[0]; } | 300 const T* data() const { return &data_[0]; } |
301 const T* const* scratch_elements() const { return &scratch_elements_[0]; } | 301 const T* const* scratch_elements() const { return &scratch_elements_[0]; } |
302 | 302 |
303 // Resize the matrix. If an increase in capacity is required, the current | 303 // Resize the matrix. If an increase in capacity is required, the current |
304 // data is lost. | 304 // data is lost. |
305 void Resize() { | 305 void Resize() { |
306 size_t size = num_rows_ * num_columns_; | 306 size_t size = num_rows_ * num_columns_; |
307 data_.resize(size); | 307 data_.resize(size); |
308 elements_.resize(num_rows_); | 308 elements_.resize(num_rows_); |
309 | 309 |
310 for (int i = 0; i < num_rows_; ++i) { | 310 for (size_t i = 0; i < num_rows_; ++i) { |
311 elements_[i] = &data_[i * num_columns_]; | 311 elements_[i] = &data_[i * num_columns_]; |
312 } | 312 } |
313 } | 313 } |
314 | 314 |
315 // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly. | 315 // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly. |
316 void CopyDataToScratch() { | 316 void CopyDataToScratch() { |
317 scratch_data_ = data_; | 317 scratch_data_ = data_; |
318 scratch_elements_.resize(num_rows_); | 318 scratch_elements_.resize(num_rows_); |
319 | 319 |
320 for (int i = 0; i < num_rows_; ++i) { | 320 for (size_t i = 0; i < num_rows_; ++i) { |
321 scratch_elements_[i] = &scratch_data_[i * num_columns_]; | 321 scratch_elements_[i] = &scratch_data_[i * num_columns_]; |
322 } | 322 } |
323 } | 323 } |
324 | 324 |
325 private: | 325 private: |
326 int num_rows_; | 326 size_t num_rows_; |
327 int num_columns_; | 327 size_t num_columns_; |
328 std::vector<T> data_; | 328 std::vector<T> data_; |
329 std::vector<T*> elements_; | 329 std::vector<T*> elements_; |
330 | 330 |
331 // Stores temporary copies of |data_| and |elements_| for in-place operations | 331 // Stores temporary copies of |data_| and |elements_| for in-place operations |
332 // where referring to original data is necessary. | 332 // where referring to original data is necessary. |
333 std::vector<T> scratch_data_; | 333 std::vector<T> scratch_data_; |
334 std::vector<T*> scratch_elements_; | 334 std::vector<T*> scratch_elements_; |
335 | 335 |
336 // Helpers for Transpose and Multiply operations that unify in-place and | 336 // Helpers for Transpose and Multiply operations that unify in-place and |
337 // out-of-place solutions. | 337 // out-of-place solutions. |
338 Matrix& Transpose(const T* const* src) { | 338 Matrix& Transpose(const T* const* src) { |
339 for (int i = 0; i < num_rows_; ++i) { | 339 for (size_t i = 0; i < num_rows_; ++i) { |
340 for (int j = 0; j < num_columns_; ++j) { | 340 for (size_t j = 0; j < num_columns_; ++j) { |
341 elements_[i][j] = src[j][i]; | 341 elements_[i][j] = src[j][i]; |
342 } | 342 } |
343 } | 343 } |
344 | 344 |
345 return *this; | 345 return *this; |
346 } | 346 } |
347 | 347 |
348 Matrix& Multiply(const T* const* lhs, int num_rows_rhs, const T* const* rhs) { | 348 Matrix& Multiply(const T* const* lhs, |
349 for (int row = 0; row < num_rows_; ++row) { | 349 size_t num_rows_rhs, |
350 for (int col = 0; col < num_columns_; ++col) { | 350 const T* const* rhs) { |
| 351 for (size_t row = 0; row < num_rows_; ++row) { |
| 352 for (size_t col = 0; col < num_columns_; ++col) { |
351 T cur_element = 0; | 353 T cur_element = 0; |
352 for (int i = 0; i < num_rows_rhs; ++i) { | 354 for (size_t i = 0; i < num_rows_rhs; ++i) { |
353 cur_element += lhs[row][i] * rhs[i][col]; | 355 cur_element += lhs[row][i] * rhs[i][col]; |
354 } | 356 } |
355 | 357 |
356 elements_[row][col] = cur_element; | 358 elements_[row][col] = cur_element; |
357 } | 359 } |
358 } | 360 } |
359 | 361 |
360 return *this; | 362 return *this; |
361 } | 363 } |
362 | 364 |
363 DISALLOW_COPY_AND_ASSIGN(Matrix); | 365 DISALLOW_COPY_AND_ASSIGN(Matrix); |
364 }; | 366 }; |
365 | 367 |
366 } // namespace webrtc | 368 } // namespace webrtc |
367 | 369 |
368 #endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_ | 370 #endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_ |
OLD | NEW |